Adding example script with custom Loader for PyTorch API documentation #97

@timonmerk

The CEBRA documentation is very comprehensive and explains the parameterization in a lot of detail.
In its current form, however, it focuses on the scikit-learn API, and there is no example script for the PyTorch API: https://cebra.ai/docs/usage.html

For many options, however, I am unsure how to set them through the scikit-learn API. For example, when using discrete behavioral data, it is currently not possible to specify empirical or discrete sampling:

prior: str = dataclasses.field(

I think this is intended, so that the cebra.CEBRA initialization and the model.fit() function are not overloaded with too many parameters?
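
To make the sampling question concrete, here is a minimal sketch of how two class priors differ on imbalanced discrete labels. It assumes that "empirical" means every sample is equally likely (so classes appear in proportion to their frequency) and that "uniform" (the value passed as prior in the script further down) means every class is equally likely. The helper sample_reference_indices is hypothetical and written in plain Python; it is not CEBRA's implementation.

```python
import random
from collections import Counter


def sample_reference_indices(labels, num_samples, prior="empirical", seed=0):
    """Draw sample indices under an 'empirical' or 'uniform' class prior.

    Hypothetical helper for illustration only; not part of CEBRA.
    """
    rng = random.Random(seed)
    if prior == "empirical":
        # Every sample is equally likely, so classes appear in proportion
        # to how often they occur in `labels`.
        return [rng.randrange(len(labels)) for _ in range(num_samples)]
    if prior == "uniform":
        # First pick a class uniformly, then a sample from that class,
        # so rare classes are drawn as often as frequent ones.
        by_class = {}
        for index, label in enumerate(labels):
            by_class.setdefault(label, []).append(index)
        classes = sorted(by_class)
        return [
            rng.choice(by_class[rng.choice(classes)]) for _ in range(num_samples)
        ]
    raise ValueError(f"unknown prior: {prior!r}")


labels = [0] * 90 + [1] * 10  # imbalanced toy labels: 90% class 0, 10% class 1
for prior in ("empirical", "uniform"):
    drawn = sample_reference_indices(labels, 10_000, prior=prior)
    # Count how often each class was drawn under this prior.
    print(prior, Counter(labels[i] for i in drawn))
```

Under these assumptions the rare class makes up roughly a tenth of the draws with the empirical prior and roughly half with the uniform prior, which is why the choice matters for imbalanced behavioral labels.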

Therefore, I thought it might help to add a minimal example to usage.rst showing how a data loader with parameters that are not exposed by the scikit-learn API can be used with PyTorch directly:

import numpy as np
import cebra.datasets
from cebra import plot_embedding
import torch

neural_data = cebra.load_data(file="neural_data.npz", key="neural")
# continuous_label = cebra.load_data(
#    file="auxiliary_behavior_data.h5",
#    key="auxiliary_variables",
#    columns=["continuous1", "continuous2", "continuous3"],
# )

discrete_label = cebra.load_data(
    file="auxiliary_behavior_data.h5", key="auxiliary_variables", columns=["discrete"],
)

# 1. Define Cebra Dataset
InputData = cebra.data.TensorDataset(
    torch.from_numpy(neural_data).type(torch.FloatTensor),
    # continuous=torch.from_numpy(np.array(continuous_label)).type(torch.FloatTensor),
    discrete=torch.from_numpy(np.array(discrete_label[:, 0])).type(torch.LongTensor),
).to("cpu")

# 2. Define Cebra Model
neural_model = cebra.models.init(
    name="offset10-model",
    num_neurons=InputData.input_dimension,
    num_units=32,
    num_output=2,
).to("cpu")

InputData.configure_for(neural_model)

# 3. Define Loss Function Criterion and Optimizer
Crit = cebra.models.criterions.LearnableCosineInfoNCE(
    # temperature=0.001,
    # min_temperature=0.0001
).to("cpu")

Opt = torch.optim.Adam(
    list(neural_model.parameters()) + list(Crit.parameters()),
    # lr=0.001,
    weight_decay=0,
)

# 4. Initialize Cebra Model
cebra_model = cebra.solver.init(
    name="single-session",
    model=neural_model,
    criterion=Crit,
    optimizer=Opt,
    tqdm_on=True,
).to("cpu")

# 5. Define Data Loader
# loader = cebra.data.single_session.ContinuousDataLoader(
#    dataset=InputData, num_steps=1000, batch_size=200
# )
loader = cebra.data.single_session.DiscreteDataLoader(
    dataset=InputData, num_steps=1000, batch_size=200, prior="uniform"
)

# 6. Fit model
cebra_model.fit(loader=loader)

# 7. Transform embedding
# The length of the model's offset sets the window length along the time axis.
window_length = len(neural_model.get_offset())
TrainBatches = np.lib.stride_tricks.sliding_window_view(
    neural_data, window_length, axis=0
)
X_train_emb = cebra_model.transform(
    torch.from_numpy(TrainBatches[:]).type(torch.FloatTensor).to("cpu")
).to("cpu")

# 8. Potentially plot embedding
# Drop the first window_length - 1 labels so labels and embedding rows align.
plot_embedding(
    X_train_emb,
    discrete_label[window_length - 1 :, 0],
    markersize=10,
)
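
The windowing in step 7 and the label slicing in step 8 are easy to get wrong, so here is a small NumPy-only sketch of the shapes involved. The toy sizes and the variable names (window, aligned_labels) are made up for illustration; window stands in for the length of the model's offset, and no CEBRA call is used.

```python
import numpy as np

# Toy sizes; `window` stands in for the length of the model's offset.
num_samples, num_neurons, window = 8, 3, 4
neural = np.arange(num_samples * num_neurons, dtype=np.float32).reshape(
    num_samples, num_neurons
)
labels = np.arange(num_samples)

# Sliding windows along the time axis. NumPy appends the window dimension
# last, giving shape (num_samples - window + 1, num_neurons, window).
windows = np.lib.stride_tricks.sliding_window_view(neural, window, axis=0)

# Window i covers time steps i .. i + window - 1, so dropping the first
# window - 1 labels pairs each window with the label of its last time step.
aligned_labels = labels[window - 1:]

print(windows.shape)         # (5, 3, 4)
print(aligned_labels.shape)  # (5,)
```

The resulting layout (windows, channels, time) is what the script passes to transform, and the number of windows matches the number of labels kept for plotting.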

Labels: documentation, enhancement
