
Iteratively train an ML model on a dataset

In the previous tutorial, we loaded an entire dataset into memory to perform a simple analysis.

Here, we’ll iterate over the files within the dataset to train an ML model.

import lamindb as ln
import anndata as ad
import numpy as np
πŸ’‘ lamindb instance: testuser1/test-scrna
ln.track()
πŸ’‘ notebook imports: anndata==0.9.2 lamindb==0.61.0 numpy==1.26.2 scvi-tools==1.0.4 torch==2.1.1
πŸ’‘ saved: Transform(uid='Qr1kIHvK506rz8', name='Iteratively train an ML model on a dataset', short_name='scrna5', version='0', type=notebook, updated_at=2023-11-20 19:13:47 UTC, created_by_id=1)
πŸ’‘ saved: Run(uid='8dO6EOMPuYD7TRHl4jvh', run_at=2023-11-20 19:13:47 UTC, transform_id=5, created_by_id=1)

Setup

dataset_v2 = ln.Dataset.filter(name="My versioned scRNA-seq dataset", version="2").one()

dataset_v2
Dataset(uid='2NsHmQCCFoPE1H8QgVBs', name='My versioned scRNA-seq dataset', version='2', hash='-J1PZEjWCBP0OptD6HtZ', visibility=0, updated_at=2023-11-20 19:13:24 UTC, transform_id=2, run_id=2, initial_version_id=1, created_by_id=1)
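
Before iterating, it can help to get an overview of the linked files as a DataFrame (a minimal sketch, assuming the files accessor supports the usual queryset .df() method):

dataset_v2.files.all().df()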

We import scvi-tools.

import scvi
/opt/hostedtoolcache/Python/3.9.18/x64/lib/python3.9/site-packages/scvi/_settings.py:63: UserWarning: Since v1.0.0, scvi-tools no longer uses a random seed by default. Run `scvi.settings.seed = 0` to reproduce results from previous versions.
  self.seed = seed
/opt/hostedtoolcache/Python/3.9.18/x64/lib/python3.9/site-packages/scvi/_settings.py:70: UserWarning: Setting `dl_pin_memory_gpu_training` is deprecated in v1.0 and will be removed in v1.1. Please pass in `pin_memory` to the data loaders instead.
  self.dl_pin_memory_gpu_training = (

Similar to what we did in the previous tutorial, we could load the entire dataset into memory and train a model in 4 lines of code.

How would this look?
data_train = dataset_v2.load(join="inner")
scvi.model.SCVI.setup_anndata(data_train)
vae = scvi.model.SCVI(data_train)
vae.train(max_epochs=1)  # we use max_epochs=1 to be able to run it on CI

Let us instead load all file records:

file1, file2 = dataset_v2.files.list()

We’d like some context on what the first file contains and where it’s from:

file1.describe()
file1.view_flow()
File(uid='7cZL0wUTAhbYPoRkwJ94', key='scrna/conde22.h5ad', suffix='.h5ad', accessor='AnnData', description='Human immune cells from Conde22', size=57612943, hash='9sXda5E7BYiVoDOQkTC0KB', hash_type='sha1-fl', visibility=0, key_is_virtual=True, updated_at=2023-11-20 19:12:55 UTC)

Provenance:
  πŸ—ƒοΈ storage: Storage(uid='1DHSaXuk', root='/home/runner/work/lamin-usecases/lamin-usecases/docs/test-scrna', type='local', updated_at=2023-11-20 19:12:34 UTC, created_by_id=1)
  πŸ“” transform: Transform(uid='Nv48yAceNSh8z8', name='scRNA-seq', short_name='scrna', version='0', type='notebook', updated_at=2023-11-20 19:12:38 UTC, created_by_id=1)
  πŸ‘£ run: Run(uid='krXmpiuQnyAVVwpu9Rst', run_at=2023-11-20 19:12:38 UTC, transform_id=1, created_by_id=1)
  πŸ‘€ created_by: User(uid='DzTjkKse', handle='testuser1', name='Test User1', updated_at=2023-11-20 19:12:34 UTC)
  ⬇️ input_of (core.Run): ['2023-11-20 19:13:01 UTC', '2023-11-20 19:13:30 UTC']
Features:
  var: FeatureSet(uid='BYD1KRzo3u4Y4DImNYZ4', n=36390, type='number', registry='bionty.Gene', hash='rMZltwoBCMdVPVR8x6nJ', updated_at=2023-11-20 19:12:52 UTC, created_by_id=1)
    'MIR1302-2HG', 'FAM138A', 'OR4F5', 'None', 'None', 'None', 'None', 'None', 'None', 'None', 'OR4F29', 'None', 'OR4F16', 'None', 'LINC01409', 'FAM87B', 'LINC01128', 'LINC00115', 'FAM41C', 'None', ...
  obs: FeatureSet(uid='TQSf0AMgXjOMz53oPxBi', n=4, registry='core.Feature', hash='5Nc89cKbUXM3R-6eoEru', updated_at=2023-11-20 19:12:53 UTC, created_by_id=1)
    πŸ”— cell_type (32, bionty.CellType): 'classical monocyte', 'T follicular helper cell', 'memory B cell', 'alveolar macrophage', 'naive thymus-derived CD4-positive, alpha-beta T cell', 'effector memory CD8-positive, alpha-beta T cell, terminally differentiated', 'alpha-beta T cell', 'CD4-positive helper T cell', 'naive thymus-derived CD8-positive, alpha-beta T cell', 'macrophage', ...
    πŸ”— assay (4, bionty.ExperimentalFactor): 'single-cell RNA sequencing', '10x 3' v3', '10x 5' v2', '10x 5' v1'
    πŸ”— tissue (17, bionty.Tissue): 'blood', 'thoracic lymph node', 'spleen', 'lung', 'mesenteric lymph node', 'lamina propria', 'liver', 'jejunal epithelium', 'omentum', 'bone marrow', ...
    πŸ”— donor (12, core.ULabel): 'D496', '621B', 'A29', 'A36', 'A35', '637C', 'A52', 'A37', 'D503', '640C', ...
Labels:
  🏷️ organism (1, bionty.Organism): 'human'
  🏷️ tissues (17, bionty.Tissue): 'blood', 'thoracic lymph node', 'spleen', 'lung', 'mesenteric lymph node', 'lamina propria', 'liver', 'jejunal epithelium', 'omentum', 'bone marrow', ...
  🏷️ cell_types (32, bionty.CellType): 'classical monocyte', 'T follicular helper cell', 'memory B cell', 'alveolar macrophage', 'naive thymus-derived CD4-positive, alpha-beta T cell', 'effector memory CD8-positive, alpha-beta T cell, terminally differentiated', 'alpha-beta T cell', 'CD4-positive helper T cell', 'naive thymus-derived CD8-positive, alpha-beta T cell', 'macrophage', ...
  🏷️ experimental_factors (4, bionty.ExperimentalFactor): 'single-cell RNA sequencing', '10x 3' v3', '10x 5' v2', '10x 5' v1'
  🏷️ ulabels (12, core.ULabel): 'D496', '621B', 'A29', 'A36', 'A35', '637C', 'A52', 'A37', 'D503', '640C', ...
[data lineage graph of file1]

We need to decide which features to use for training the model.

Because each file is validated, all files are indexed by ensembl_gene_id in the var slot of the AnnData objects.

shared_genes = file1.features["var"] & file2.features["var"]
shared_genes_ensembl = shared_genes.list("ensembl_gene_id")
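
As a quick sanity check, we can count the shared genes; this should match n_vars of the training AnnData below (749):

len(shared_genes_ensembl)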

Train the model

Let us load the first file into memory:

data_train1 = file1.load().raw[:, shared_genes_ensembl].to_adata()
data_train1
AnnData object with n_obs Γ— n_vars = 1648 Γ— 749
    obs: 'donor', 'tissue', 'cell_type', 'assay'
    var: 'feature_name', 'feature_reference', 'feature_biotype'
    uns: 'default_embedding'
    obsm: 'X_umap'

Train the model on this first file:

scvi.model.SCVI.setup_anndata(data_train1)
vae = scvi.model.SCVI(data_train1)
vae.train(max_epochs=1)  # we use max_epochs=1 to run it on CI
vae.save("saved_models/scvi1")
2023-11-20 19:13:51,236:INFO - Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2023-11-20 19:13:51,236:INFO - Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2023-11-20 19:13:51,238:INFO - Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO: GPU available: False, used: False
2023-11-20 19:13:51,419:INFO - GPU available: False, used: False
INFO: TPU available: False, using: 0 TPU cores
2023-11-20 19:13:51,420:INFO - TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
2023-11-20 19:13:51,422:INFO - IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
2023-11-20 19:13:51,424:INFO - HPU available: False, using: 0 HPUs
Training:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 1/1:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 1/1: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1/1 [00:00<00:00,  5.99it/s]
Epoch 1/1: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1/1 [00:00<00:00,  5.99it/s, v_num=1, train_loss_step=1.01e+3, train_loss_epoch=1.06e+3]
INFO: `Trainer.fit` stopped: `max_epochs=1` reached.
2023-11-20 19:13:51,614:INFO - `Trainer.fit` stopped: `max_epochs=1` reached.
Epoch 1/1: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1/1 [00:00<00:00,  5.63it/s, v_num=1, train_loss_step=1.01e+3, train_loss_epoch=1.06e+3]

Load the second file and resume training the model:

data_train2 = file2.load().raw[:, shared_genes_ensembl].to_adata()
vae = scvi.model.SCVI.load("saved_models/scvi1", data_train2)
vae.train(max_epochs=1)
vae.save("saved_models/scvi1", overwrite=True)
INFO     File saved_models/scvi1/model.pt already downloaded
/opt/hostedtoolcache/Python/3.9.18/x64/lib/python3.9/site-packages/scvi/data/fields/_base_field.py:64: UserWarning: adata.X does not contain unnormalized count data. Are you sure this is what you want?
  self.validate_field(adata)
INFO: GPU available: False, used: False
2023-11-20 19:13:51,724:INFO - GPU available: False, used: False
INFO: TPU available: False, using: 0 TPU cores
2023-11-20 19:13:51,726:INFO - TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
2023-11-20 19:13:51,727:INFO - IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
2023-11-20 19:13:51,729:INFO - HPU available: False, using: 0 HPUs
/opt/hostedtoolcache/Python/3.9.18/x64/lib/python3.9/site-packages/lightning/pytorch/loops/fit_loop.py:281: PossibleUserWarning: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=10). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  rank_zero_warn(
Training:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 1/1:   0%|          | 0/1 [00:00<?, ?it/s]
/opt/hostedtoolcache/Python/3.9.18/x64/lib/python3.9/site-packages/scvi/module/_vae.py:477: UserWarning: The value argument must be within the support of the distribution
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
Epoch 1/1: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1/1 [00:00<00:00, 51.64it/s, v_num=1, train_loss_step=823, train_loss_epoch=823]
INFO: `Trainer.fit` stopped: `max_epochs=1` reached.
2023-11-20 19:13:51,767:INFO - `Trainer.fit` stopped: `max_epochs=1` reached.
Epoch 1/1: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1/1 [00:00<00:00, 34.16it/s, v_num=1, train_loss_step=823, train_loss_epoch=823]
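
This save/load pattern generalizes to any number of files. Here is a minimal sketch of the full loop under the same preprocessing as above (model_path is a hypothetical directory name):

model_path = "saved_models/scvi_iterative"
for i, file in enumerate(dataset_v2.files.list()):
    adata = file.load().raw[:, shared_genes_ensembl].to_adata()
    if i == 0:
        # initialize the model on the first file
        scvi.model.SCVI.setup_anndata(adata)
        vae = scvi.model.SCVI(adata)
    else:
        # resume training from the saved checkpoint
        vae = scvi.model.SCVI.load(model_path, adata)
    vae.train(max_epochs=1)
    vae.save(model_path, overwrite=True)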

Save the model

weights = ln.File("saved_models/scvi1/model.pt", description="My trained model")
weights.save()
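
In a later session, the trained weights can be queried back by their description and cached locally (a sketch, assuming lamindb 0.61's stage() method, which returns a local filepath):

weights = ln.File.filter(description="My trained model").one()
weights_path = weights.stage()  # caches the file locally and returns its path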

Save latent representation as a new dataset

latent1 = vae.get_latent_representation(data_train1)
latent2 = vae.get_latent_representation(data_train2)

adata_latent1 = ad.AnnData(X=latent1, obs=data_train1.obs)
adata_latent2 = ad.AnnData(X=latent2, obs=data_train2.obs)
INFO     Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup

Because the latent representation is low-dimensional, we can typically fit a very large number of observations into memory.

Hence, let’s store it as a single concatenated AnnData.

adata_latent = ad.concat([adata_latent1, adata_latent2])
dataset_v2_latent = ln.Dataset(
    adata_latent,
    name="Latent representation of scRNA-seq dataset v2",
    description="For the original data, see dataset T5x0SkRJNviE0jYGbJKt",
)
dataset_v2_latent.save()
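
Like any dataset, the latent representation can be queried back and loaded into memory in a later session (a minimal sketch):

dataset_latent = ln.Dataset.filter(
    name="Latent representation of scRNA-seq dataset v2"
).one()
adata_latent = dataset_latent.load()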

Let us look at the data flow:

dataset_v2_latent.view_flow()
[data lineage graph of dataset_v2_latent]

Compare this with the model:

weights.view_flow()
[data lineage graph of the trained model weights]

Annotate with labels:

dataset_v2_latent.labels.add_from(dataset_v2)

dataset_v2_latent.describe()
Dataset(uid='LpNcJPg9IpsSHACi74VT', name='Latent representation of scRNA-seq dataset v2', description='For the original data, see dataset T5x0SkRJNviE0jYGbJKt', hash='-MuaunRLs-BLiHka92uKow', visibility=0, updated_at=2023-11-20 19:13:51 UTC)

Provenance:
  πŸ’« transform: Transform(uid='Qr1kIHvK506rz8', name='Iteratively train an ML model on a dataset', short_name='scrna5', version='0', type=notebook, updated_at=2023-11-20 19:13:47 UTC, created_by_id=1)
  πŸ‘£ run: Run(uid='8dO6EOMPuYD7TRHl4jvh', run_at=2023-11-20 19:13:47 UTC, transform_id=5, created_by_id=1)
  πŸ“„ file: File(uid='LpNcJPg9IpsSHACi74VT', suffix='.h5ad', accessor='AnnData', description='See dataset LpNcJPg9IpsSHACi74VT', size=220242, hash='-MuaunRLs-BLiHka92uKow', hash_type='md5', visibility=0, key_is_virtual=True, updated_at=2023-11-20 19:13:51 UTC, storage_id=1, transform_id=5, run_id=5, created_by_id=1)
  πŸ‘€ created_by: User(uid='DzTjkKse', handle='testuser1', name='Test User1', updated_at=2023-11-20 19:12:34 UTC)
Features:
  external: FeatureSet(uid='vsmOlJZsEiV43a9cyOW1', n=5, registry='core.Feature', hash='9Pofnceb4Lp-eeHoz6TY', updated_at=2023-11-20 19:13:52 UTC, created_by_id=1)
    πŸ”— cell_type (40, bionty.CellType): 'classical monocyte', 'T follicular helper cell', 'memory B cell', 'alveolar macrophage', 'naive thymus-derived CD4-positive, alpha-beta T cell', 'effector memory CD8-positive, alpha-beta T cell, terminally differentiated', 'alpha-beta T cell', 'CD4-positive helper T cell', 'naive thymus-derived CD8-positive, alpha-beta T cell', 'macrophage', ...
    πŸ”— assay (4, bionty.ExperimentalFactor): '10x 3' v3', '10x 5' v2', '10x 5' v1', 'single-cell RNA sequencing'
    πŸ”— tissue (17, bionty.Tissue): 'blood', 'thoracic lymph node', 'spleen', 'lung', 'mesenteric lymph node', 'lamina propria', 'liver', 'jejunal epithelium', 'omentum', 'bone marrow', ...
    πŸ”— organism (1, bionty.Organism): 'human'
    πŸ”— donor (12, core.ULabel): 'D496', '621B', 'A29', 'A36', 'A35', '637C', 'A52', 'A37', 'D503', '640C', ...
Labels:
  🏷️ organism (1, bionty.Organism): 'human'
  🏷️ tissues (17, bionty.Tissue): 'blood', 'thoracic lymph node', 'spleen', 'lung', 'mesenteric lymph node', 'lamina propria', 'liver', 'jejunal epithelium', 'omentum', 'bone marrow', ...
  🏷️ cell_types (40, bionty.CellType): 'classical monocyte', 'T follicular helper cell', 'memory B cell', 'alveolar macrophage', 'naive thymus-derived CD4-positive, alpha-beta T cell', 'effector memory CD8-positive, alpha-beta T cell, terminally differentiated', 'alpha-beta T cell', 'CD4-positive helper T cell', 'naive thymus-derived CD8-positive, alpha-beta T cell', 'macrophage', ...
  🏷️ experimental_factors (4, bionty.ExperimentalFactor): '10x 3' v3', '10x 5' v2', '10x 5' v1', 'single-cell RNA sequencing'
  🏷️ ulabels (12, core.ULabel): 'D496', '621B', 'A29', 'A36', 'A35', '637C', 'A52', 'A37', 'D503', '640C', ...

Use a PyTorch-compatible IndexedDataset class

If you need to train your model on a collection of files, you can use ln.Dataset.indexed() with the PyTorch DataLoader. It doesn’t load anything into memory and thus allows you to work with very large AnnData files.

from torch.utils.data import DataLoader, WeightedRandomSampler

Create a Dataset for training.

file_train1 = ln.File(data_train1, description="Conde22_train")
file_train1.save()
file_train2 = ln.File(data_train2, description="pbmc10x_train")
file_train2.save()

Files in the dataset should have the same variables; we already took care of this by subsetting both files to the shared genes.

dataset_train = ln.Dataset([file_train1, file_train2], name="Dataset for training")
dataset_train.save()
ds = dataset_train.indexed(labels=["cell_type"])

This is compatible with the PyTorch DataLoader because it implements __getitem__ over a list of AnnData files.

ds[5]
[array([ 0.,  0.,  0.,  2.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0., 17.,  0.,  0.,  0.,  2.,  0.,  0.,  2.,  1.,
         0.,  0.,  1.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,
         0.,  0.,  0.,  0.,  0.,  4.,  0.,  2.,  3.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,
         3.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,
         1.,  0.,  3.,  0.,  1.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,
         2.,  0.,  0.,  5.,  6.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  1.,
         0.,  0.,  0.,  1.,  1.,  1.,  1.,  3.,  0.,  0.,  4.,  1.,  3.,
         0.,  0.,  0.,  0.,  0.,  2.,  0.,  2.,  1.,  0.,  0.,  1.,  0.,
         0.,  0.,  1.,  0.,  1.,  0.,  0.,  1.,  0.,  1.,  2.,  0.,  1.,
         0.,  0.,  1.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         5.,  0.,  2.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,
         0.,  2.,  1.,  0.,  0.,  1.,  3.,  4.,  1.,  0.,  2.,  1.,  1.,
         1.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  1.,  2.,  0.,  0.,  3.,
         0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., 96.,  0.,  6.,
         1.,  2.,  0.,  0.,  1.,  0.,  6.,  1.,  0.,  0.,  1.,  0.,  2.,
         0.,  3.,  0.,  0.,  2., 10.,  0.,  0.,  0.,  5.,  1., 26.,  2.,
        14.,  6.,  5.,  0.,  3.,  0.,  8., 10.,  0.,  1.,  0.,  1.,  0.,
         1.,  1.,  0.,  5.,  1.,  0.,  3.,  1.,  1.,  1.,  0.,  0.,  0.,
         2.,  1.,  3.,  0.,  0.,  1.,  3.,  3.,  0.,  0.,  2.,  0.,  1.,
         0.,  4.,  1.,  0.,  0.,  1.,  0.,  0.,  1.,  1.,  7.,  1.,  0.,
         0.,  0.,  0.,  0.,  0.,  2.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  4.,  0.,  0.,  0.,  0.,  1.,  0.,  1.,  0.,  0.,
         0.,  1.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  2.,  0.,  0.,  0.,
         0.,  0.,  0.,  2.,  0.,  1.,  2.,  0.,  1.,  0.,  2.,  0.,  1.,
         0.,  1.,  0.,  0.,  1.,  0.,  0.,  0., 17.,  0.,  0.,  0.,  0.,
         4.,  0.,  0.,  1.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         2.,  1., 13.,  0.,  1.,  1.,  1.,  0.,  0.,  1.,  0.,  1.,  0.,
         0.,  0.,  0.,  2.,  2.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  2.,  2.,  1.,  0.,  0.,  0.,  6.,  0.,  2.,  0.,  0.,
         0.,  0.,  1.,  1.,  0.,  1.,  0.,  2.,  0.,  0.,  1.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  1.,  0.,  1.,  0.,  2.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  2.,  0.,  0.,  0.,  1.,  1.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  4.,  0.,  6.,  1.,  0.,  0.,  0.,  2.,  0.,  2.,
         1.,  0.,  0.,  0.,  0.,  1.,  1.,  0.,  0.,  0.,  1.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  5.,  0.,  7.,  1.,  0.,
         0.,  0.,  0.,  1., 10.,  1.,  6.,  0.,  0.,  1.,  4.,  0.,  0.,
         0.,  0.,  0.,  2.,  0.,  1.,  0.,  1.,  0.,  0.,  1.,  0.,  1.,
         0.,  0.,  0.,  0.,  0.,  0.,  1.,  2.,  0.,  0.,  0.,  0.,  1.,
         0.,  0.,  0.,  0.,  1.,  2.,  1.,  0.,  0.,  0.,  4.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  3.,  0.,  0.,  0.,  0.,
         0.,  0.,  2.,  0.,  0.,  0.,  3.,  0.,  0.,  3.,  0.,  0.,  0.,
         0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  2.,  0.,  0.,  2.,
         0.,  6.,  0.,  2.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  2.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  1.,  1.,  2.,  0.,  0.,  0.,  1.,
         1.,  1.,  0.,  0.,  0.,  0.,  0.,  1.,  1.,  0.,  0.,  0.,  2.,
         0.,  0.,  0.,  0.,  0.,  1.,  3.,  4.,  0.,  0.,  0.,  1.,  0.,
        10.,  0.,  1.,  0.,  1.,  1.,  0.,  3.,  0.,  0.,  0.,  0.,  0.,
         1.,  0.,  0.,  4.,  0.,  0.,  0.,  0.,  0., 15.,  0.,  6.,  0.,
         1.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,
         0.,  1.,  0.,  0.,  0.,  1.,  0.,  0.,  1.,  2.,  3.,  0.,  6.,
         1.,  1.,  0.,  1.,  1.,  0.,  2.,  0.,  1.,  0.,  0.,  0.,  1.,
         1.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,
         0.,  4.,  0.,  1.,  1.,  0.,  0.,  0.,  2.,  0.,  0.,  0.,  2.,
         1.,  1.,  0.,  0.,  0.,  0.,  1.,  1.,  1.,  2.,  0.,  1.,  0.,
         0.,  0., 53.,  1.,  0.,  1.,  0., 35.]),
 26]

Labels are encoded into integers.

ds.encoders
[{'CD8-positive, alpha-beta memory T cell, CD45RO-positive': 0,
  'naive B cell': 1,
  'CD16-positive, CD56-dim natural killer cell, human': 2,
  'lymphocyte': 3,
  'plasmablast': 4,
  'CD4-positive, alpha-beta T cell': 5,
  'dendritic cell': 6,
  'germinal center B cell': 7,
  'alveolar macrophage': 8,
  'effector memory CD4-positive, alpha-beta T cell, terminally differentiated': 9,
  'CD16-negative, CD56-bright natural killer cell, human': 10,
  'effector memory CD4-positive, alpha-beta T cell': 11,
  'effector memory CD8-positive, alpha-beta T cell, terminally differentiated': 12,
  'naive thymus-derived CD4-positive, alpha-beta T cell': 13,
  'classical monocyte': 14,
  'animal cell': 15,
  'mast cell': 16,
  'plasma cell': 17,
  'CD4-positive helper T cell': 18,
  'naive thymus-derived CD8-positive, alpha-beta T cell': 19,
  'CD8-positive, alpha-beta memory T cell': 20,
  'plasmacytoid dendritic cell': 21,
  'alpha-beta T cell': 22,
  'CD14-positive, CD16-negative classical monocyte': 23,
  'megakaryocyte': 24,
  'B cell, CD19-positive': 25,
  'memory B cell': 26,
  'progenitor cell': 27,
  'non-classical monocyte': 28,
  'T follicular helper cell': 29,
  'gamma-delta T cell': 30,
  'CD8-positive, CD25-positive, alpha-beta regulatory T cell': 31,
  'mucosal invariant T cell': 32,
  'group 3 innate lymphoid cell': 33,
  'regulatory T cell': 34,
  'conventional dendritic cell': 35,
  'cytotoxic T cell': 36,
  'macrophage': 37,
  'dendritic cell, human': 38,
  'CD38-positive naive B cell': 39}]
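
To map the encoded integers back to label names, invert the encoder dictionary (a minimal sketch; recall that ds[5] above returned the label 26, i.e. 'memory B cell'):

decoder = {v: k for k, v in ds.encoders[0].items()}
decoder[26]  # -> 'memory B cell'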
To balance cell types during training, pass per-observation weights to a WeightedRandomSampler:

# the label_key used for the weights doesn't have to be among the labels passed at init
sampler = WeightedRandomSampler(
    weights=ds.get_labels_weights("cell_type"), num_samples=len(ds)
)
dl = DataLoader(ds, batch_size=128, sampler=sampler)
for batch in dl:
    pass  # iterate once through all batches as a smoke test
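
In an actual training loop, each batch unpacks into a count matrix and integer labels; with default collation, PyTorch stacks them into tensors. A hedged sketch of consuming the batches (the forward/backward steps are placeholders):

for x, y in dl:
    x = x.float()  # counts for the shared genes, shape (batch_size, n_genes)
    y = y.long()   # integer-encoded cell_type labels
    # model forward pass, loss computation, and optimizer step would go here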
# clean up test instance
!lamin delete --force test-scrna
!rm -r ./test-scrna
πŸ’‘ deleting instance testuser1/test-scrna
βœ…     deleted instance settings file: /home/runner/.lamin/instance--testuser1--test-scrna.env
βœ…     instance cache deleted
βœ…     deleted '.lndb' sqlite file
❗     consider manually deleting your stored data: /home/runner/work/lamin-usecases/lamin-usecases/docs/test-scrna