Arc Institute recently unveiled the Virtual Cell Challenge. Participants are required to train a model capable of predicting the effect of silencing a gene in a (partially) unseen cell type, a task they term context generalization.
For ML engineers with little to no biology background, the jargon and required context can seem quite daunting. To encourage participation, we recapitulate the challenge in a form better suited to engineers from other disciplines.
Goal
Train a model to predict the effect on a cell of silencing a gene using CRISPR.
Doing things in the world of atoms is expensive, laborious, and error prone. What if we could test hundreds of drug candidates without ever touching a petri dish?
That is the goal of the Virtual Cell Challenge: a model (probably a neural network) that can simulate exactly what happens
to a cell when we alter some parameter. Given that tightening your feedback loop is often one of the best ways to speed up progress,
a model able to do this accurately would have significant impact.
To train this neural network, we'll need data. For the challenge, Arc has curated a dataset of ~300k single-cell RNA sequencing profiles. It may be worthwhile to revisit the Central Dogma before continuing. This post will build on that to provide the ~minimum biology knowledge you will need for the challenge.
Training data
The training set consists of a sparse matrix and some associated metadata. More specifically, we have 220k cells, and
for every cell we have a transcriptome. This transcriptome is a sparse row vector, where each
entry is the raw count of RNA molecules (transcripts) that the corresponding gene (our column) encodes for. Of the 220k cells,
~38k are unperturbed, meaning no gene has been silenced using CRISPR. These control cells are crucial, as we'll see shortly.
To grasp the dataset more concretely, let's pick a gene, TMSB4X (the most frequently silenced gene in the dataset), and compare the number of RNA molecules detected in a control cell and a
perturbed cell.

We can see that the cell with TMSB4X silenced has a greatly reduced number of transcripts compared with the control
cells.
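If you want to poke at this yourself, the comparison only takes a few lines. Below is a minimal sketch that assumes the training data is distributed as an AnnData .h5ad file with an .obs column naming the silenced gene; the file name and column labels are placeholders, so adjust them to whatever the released dataset actually uses.
import anndata as ad
import numpy as np

# Placeholder file name and .obs column labels -- check the released dataset.
adata = ad.read_h5ad("competition_train.h5ad")
gene = "TMSB4X"
gene_idx = adata.var_names.get_loc(gene)

control = adata[adata.obs["target_gene"] == "non-targeting"]
perturbed = adata[adata.obs["target_gene"] == gene]

# .X is a sparse cells x genes matrix of raw transcript counts
control_counts = np.asarray(control.X[:, gene_idx].todense()).ravel()
perturbed_counts = np.asarray(perturbed.X[:, gene_idx].todense()).ravel()

print(f"mean {gene} count, control cells:   {control_counts.mean():.1f}")
print(f"mean {gene} count, perturbed cells: {perturbed_counts.mean():.1f}")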
Modelling the challenge
The astute among you may be wondering why we don't just measure the count of RNA molecules before and after
silencing the gene: why do we need the control cells at all? Unfortunately, reading the transcriptome destroys the cell, a problem reminiscent of the observer effect.
This inability to measure the cell state before and after introduces many issues, as we are forced to use a population of basal
(a.k.a. control, unperturbed) cells as a reference point. The control cells and perturbed cells are not entirely
homogeneous even prior to the perturbation. This means we have to separate out our true signal, the perturbation, from
the noise induced by this heterogeneity.
More formally, we can model the observed gene expression in perturbed cells as:

$$X_p = B + T_p + H + \varepsilon$$

where:
- $X_p$: the observed gene expression measurements in cells with perturbation $p$.
- $B$: (a draw from) the distribution of the unperturbed, baseline cell population.
- $T_p$: the true effect caused by perturbation $p$ on the population.
- $H$: the biological heterogeneity of the baseline population.
- $\varepsilon$: experiment-specific technical noise, assumed independent of the unperturbed cell state and of $p$.
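To make the decomposition concrete, here is a toy numpy simulation using the symbols above (entirely made-up numbers, nothing from the challenge data). It shows why a population of control cells lets us recover $T_p$ even though no single cell can be measured twice.
import numpy as np

rng = np.random.default_rng(0)
n_cells, n_genes = 5000, 4

mu = rng.uniform(10, 50, size=n_genes)               # baseline mean expression (B)
T_p = np.array([-20.0, 0.0, 5.0, 0.0])               # true perturbation effect

def sample(effect):
    H = rng.normal(0, 4, size=(n_cells, n_genes))    # biological heterogeneity
    eps = rng.normal(0, 2, size=(n_cells, n_genes))  # technical noise
    return mu + effect + H + eps

X_control = sample(0.0)
X_perturbed = sample(T_p)

# Differencing the population means recovers T_p as n_cells grows
print(np.round(X_perturbed.mean(axis=0) - X_control.mean(axis=0), 1))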
Prior to the Virtual Cell Challenge, Arc released STATE, their own attempt at solving the challenge
using a pair of transformer-based models. This serves as a strong baseline for participants to start from, so we'll
explore it in detail.
STATE consists of two models, the State Transition Model (ST) and the State Embedding Model (SE). SE is designed to produce rich semantic embeddings of cells in order to improve cross-cell-type generalization. ST is the "cell simulator": it takes in either the transcriptome of a control cell or an embedding of a cell produced by SE, together with a one-hot encoded vector representing the perturbation of interest, and outputs the perturbed transcriptome.
State Transition Model (ST)

The State Transition Model is a comparatively simple transformer with a Llama backbone that operates on the following:
- A set of transcriptomes (or SE embeddings) for covariate-matched basal cells.
- A set of one-hot vectors representing the gene perturbation for each cell.
Using a covariate-matched set of control cells with paired target cells should help the model discern the
actual effect of the intended perturbation. Both the control-set tensor and the perturbation tensor are fed through independent encoders, which are simply 4-layer MLPs with GELU activations.
If working directly in gene expression space (i.e. producing a full transcriptome), the output is passed through a learned
decoder.
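To make the shape of the model concrete, here is a simplified PyTorch sketch of the architecture just described. This is not Arc's implementation: the sizes are placeholders, and a vanilla nn.TransformerEncoder stands in for the Llama backbone.
import torch
import torch.nn as nn

def mlp(d_in, d_hidden, d_out, n_layers=4):
    layers, d = [], d_in
    for _ in range(n_layers - 1):
        layers += [nn.Linear(d, d_hidden), nn.GELU()]
        d = d_hidden
    layers.append(nn.Linear(d, d_out))
    return nn.Sequential(*layers)

class StateTransitionSketch(nn.Module):
    def __init__(self, n_genes=2000, n_perts=100, d_model=512):
        super().__init__()
        self.cell_encoder = mlp(n_genes, d_model, d_model)  # basal transcriptomes or SE embeddings
        self.pert_encoder = mlp(n_perts, d_model, d_model)  # one-hot perturbation vectors
        self.backbone = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead=8, batch_first=True),
            num_layers=4,
        )
        self.decoder = mlp(d_model, d_model, n_genes)       # back to expression space

    def forward(self, basal, pert_onehot):
        # basal: (batch, set_size, n_genes); pert_onehot: (batch, set_size, n_perts)
        h = self.cell_encoder(basal) + self.pert_encoder(pert_onehot)
        h = self.backbone(h)
        return self.decoder(h)                               # predicted perturbed transcriptomes

model = StateTransitionSketch()
pred = model(torch.randn(2, 16, 2000), torch.zeros(2, 16, 100))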
ST is trained using Maximum Mean Discrepancy (MMD). Put simply, the model learns to minimize the difference between the predicted and the observed distributions of perturbed cells.
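As a rough illustration of the objective, here is a generic, biased RBF-kernel MMD estimator between a set of predicted cells and a set of observed perturbed cells; the kernel and estimator Arc actually uses may well differ.
import torch

def mmd_rbf(x, y, sigma=1.0):
    # x: (n, d) predicted cells; y: (m, d) observed perturbed cells
    def k(a, b):
        return torch.exp(-torch.cdist(a, b).pow(2) / (2 * sigma ** 2))
    return k(x, x).mean() + k(y, y).mean() - 2 * k(x, y).mean()

loss = mmd_rbf(torch.randn(64, 2000), torch.randn(64, 2000))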
State Embedding Model (SE)

The State Embedding Model is a BERT-like autoencoder. To understand it more deeply, we first have to
take a little detour for some more biological grounding.
A brief biological detour

A gene consists of exons (protein-coding sections) and introns (non-protein-coding sections). DNA is first transcribed into pre-mRNA, as shown above. The cell then performs alternative splicing, which is basically "pick and choose exons, cut out all introns". You can think of the gene as an IKEA manual for building a table. By leaving out some parts, one could also build a 3-legged table, or, with some effort, an odd bookshelf. These different objects are analogous to protein isoforms: proteins coded for by the same gene.
Back to the model
With this basic understanding, we can move on to how the SE model works. Remember, our core goal for SE is to create meaningful
cell embeddings. To do that, we must first create meaningful gene embeddings.
To produce a single gene embedding, we first obtain the amino acid sequence of each of the different protein isoforms encoded by the gene in question (e.g. TMSB4X). We then feed these sequences to ESM2, a 15B-parameter protein language model from FAIR. ESM2 produces an embedding per amino acid, and we mean-pool these to obtain a "transcript" (a.k.a. protein isoform) embedding.
Now that we have all of these protein isoform embeddings, we simply mean-pool them to get the gene embedding. Finally, we project these gene embeddings to our model dimension using a learned encoder.
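Here is a sketch of that recipe using a smaller public ESM2 checkpoint via transformers (the 8M model rather than the 15B one), with placeholder isoform sequences; Arc's exact pooling details may differ.
import torch
from transformers import AutoTokenizer, EsmModel

checkpoint = "facebook/esm2_t6_8M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
esm = EsmModel.from_pretrained(checkpoint)

isoform_sequences = ["MSDKPDMAEI", "MSDKPDM"]  # placeholder amino acid strings

isoform_embeddings = []
for seq in isoform_sequences:
    inputs = tokenizer(seq, return_tensors="pt")
    with torch.no_grad():
        hidden = esm(**inputs).last_hidden_state   # (1, seq_len, d)
    isoform_embeddings.append(hidden.mean(dim=1))  # mean pool over amino acids

# Mean pool over isoforms to get the gene embedding
gene_embedding = torch.cat(isoform_embeddings).mean(dim=0)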
We have now obtained a gene embedding, but what we actually want is a cell embedding. To do this, Arc represents each cell
as its top 2048 genes ranked by log-fold expression level.
We then construct a "cell sentence" from these 2048 gene embeddings as follows:
We add a [CLS] token and a dataset token to our sentence. The [CLS] token ends up being used as our "cell embedding" (very BERT-like)
and the dataset token is used to "disentangle dataset-specific effects". Although the genes are sorted by log-fold
expression level, Arc further encodes the magnitude of each gene's expression by incorporating the transcriptome in a
fashion analogous to positional embeddings. Through a "soft binning" algorithm and a pair of MLPs, they create
"expression encodings" which are then added to each gene embedding. This modulates each gene
embedding by how intensely that gene is expressed in the transcriptome.
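A rough sketch of this construction, with illustrative sizes, learned tokens replaced by random vectors, and the soft-binning expression encoder reduced to a comment:
import torch

n_genes, d_model, sentence_len = 20000, 512, 2048

gene_embeddings = torch.randn(n_genes, d_model)     # from the ESM2 pipeline above
counts = torch.randint(0, 200, (n_genes,)).float()  # one cell's raw transcript counts

log_expr = torch.log1p(counts)
top_genes = torch.argsort(log_expr, descending=True)[:sentence_len]
cell_sentence = gene_embeddings[top_genes]           # (2048, d_model)

# Prepend the [CLS] and dataset tokens (learned in the real model)
cls_tok, ds_tok = torch.randn(1, d_model), torch.randn(1, d_model)
cell_sentence = torch.cat([cls_tok, ds_tok, cell_sentence], dim=0)

# The soft-binned "expression encodings" would be added to each gene embedding
# here, analogously to positional embeddings.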
To train the model, they mask 1280 genes per cell, and the model is tasked with predicting them. The 1280 genes are
chosen such that they have a wide range of expression intensities. For the graphically inclined, the figure below
demonstrates the construction of the cell sentence.

Understanding how your submission will be evaluated is vital to success. The three evaluation metrics chosen by Arc are Perturbation Discrimination, Differential Expression, and Mean Absolute Error. Given that Mean Absolute Error is simple and exactly what it sounds like, we'll omit it from our analysis.
Perturbation Discrimination

Perturbation Discrimination evaluates how well your model can uncover relative differences between
perturbations. For a target perturbation $t$, we compute the Manhattan distances from our predicted transcriptome $\hat{y}_t$ to all of the measured perturbed transcriptomes in the test set (the ground
truth $y_t$ we are trying to predict, and all other perturbed transcriptomes $y_p$). We then rank where the
ground truth lands with respect to all transcriptomes as follows:

$$r_t = \sum_{p \neq t} \mathbb{1}\left[d(\hat{y}_t, y_p) < d(\hat{y}_t, y_t)\right]$$

After, we normalize by the total number of transcriptomes $T$:

$$\text{PDisc}_t = \frac{r_t}{T}$$

where $\text{PDisc}_t = 0$ would be a perfect match. The overall score for your predictions is the mean of all $\text{PDisc}_t$, denoted $\overline{\text{PDisc}}$. This is then normalized to:

$$\text{PDisc}^{\text{norm}} = 1 - 2\,\overline{\text{PDisc}}$$

We multiply by 2 because, for a random prediction, ~half of the other transcriptomes would land closer and half would land further away.
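A small numpy sketch of the metric, assuming one (pseudobulk) transcriptome per perturbation:
import numpy as np

def pdisc(pred_t, true_all, t):
    # pred_t: (n_genes,) predicted transcriptome for perturbation t
    # true_all: (T, n_genes) measured transcriptomes for all T perturbations
    dists = np.abs(true_all - pred_t).sum(axis=1)  # Manhattan distances
    rank = np.sum(dists < dists[t])                # perturbations closer than the ground truth
    return rank / len(true_all)

T, n_genes = 50, 2000
true_all = np.random.rand(T, n_genes)
preds = true_all + 0.1 * np.random.rand(T, n_genes)

scores = [pdisc(preds[t], true_all, t) for t in range(T)]
pdisc_norm = 1 - 2 * np.mean(scores)
print(pdisc_norm)  # 1.0 is perfect, ~0.0 is what a random prediction would score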
Differential Expression
Differential Expression evaluates what fraction of the truly affected genes you correctly identified as significantly affected. First, for each gene we compute a $p$-value using a Wilcoxon rank-sum test with tie correction. We do this for both our predicted perturbation distribution and the ground truth perturbation distribution.
Next, we apply the Benjamini-Hochberg procedure, essentially a statistical adjustment of the $p$-values for multiple testing: with thousands of genes and a $p$-value threshold of $\alpha$, you would otherwise expect roughly $\alpha$ times the number of genes to show up as false positives. We denote our set of predicted differentially expressed genes $G_{\text{pred}}$, and the ground truth set of differentially expressed genes $G_{\text{true}}$.
If the size of our predicted set is at most the ground truth set size, we take the intersection of the sets and divide by the true number of differentially expressed genes as follows:

$$\text{DES}_p = \frac{|G_{\text{pred}} \cap G_{\text{true}}|}{|G_{\text{true}}|}$$

If the size of our predicted set is bigger than the ground truth set size, we select the subset we predict to be most differentially expressed (our "most confident" predictions, denoted $G_{\text{pred}}^{\,k}$ with $k = |G_{\text{true}}|$), take the intersection with the ground truth set, and then divide by the true number:

$$\text{DES}_p = \frac{|G_{\text{pred}}^{\,k} \cap G_{\text{true}}|}{|G_{\text{true}}|}$$

We do this for all predicted perturbations and take the mean to obtain the final score.
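A sketch of the score for a single perturbation, using scipy's rank-sum test and Benjamini-Hochberg correction (for simplicity it skips the tie correction mentioned above, and the 0.05 threshold is an assumption):
import numpy as np
from scipy.stats import ranksums, false_discovery_control

def adjusted_pvalues(perturbed, control):
    # perturbed, control: (n_cells, n_genes) expression matrices
    pvals = np.array([
        ranksums(perturbed[:, g], control[:, g]).pvalue
        for g in range(perturbed.shape[1])
    ])
    return false_discovery_control(pvals)  # Benjamini-Hochberg adjustment

def de_score(pred_q, true_q, alpha=0.05):
    pred_set = set(np.flatnonzero(pred_q < alpha))
    true_set = set(np.flatnonzero(true_q < alpha))
    if not true_set:
        return 1.0
    if len(pred_set) > len(true_set):
        # keep only our most confident predictions (smallest adjusted p-values)
        ranked = sorted(pred_set, key=lambda g: pred_q[g])
        pred_set = set(ranked[: len(true_set)])
    return len(pred_set & true_set) / len(true_set)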
Conclusion
If this challenge has piqued your interest, how should you get started? Fortunately, Arc has provided a Colab notebook that walks through the complete process of training their STATE model. Moreover, STATE will be landing in transformers
very soon, so getting started with their pretrained models will be as simple as:
import torch
from transformers import StateEmbeddingModel

# Load the pretrained 600M-parameter State Embedding (SE) model
model_name = "arcinstitute/SE-600M"
model = StateEmbeddingModel.from_pretrained(model_name)

# Dummy input: batch of 1, sequence length 1, feature dimension 5120
input_ids = torch.randn((1, 1, 5120), dtype=torch.float32)

# Boolean mask over the feature dimension; here the second half is masked out
mask = torch.ones((1, 1, 5120), dtype=torch.bool)
mask[:, :, 2560:] = False

outputs = model(input_ids, mask)
Best of luck to all participants!
This post was originally published here.
