We live in the era of quantification. But rigorous quantification is easier said than done. In complex systems such as biology, data can be difficult and expensive to collect. In high-stakes applications, such as medicine and finance, it is crucial to account for uncertainty. Variational inference, a methodology at the forefront of AI research, is a way to address these aspects.
This tutorial introduces you to the basics: the when, why, and how of variational inference.
Variational inference is appealing in the following three closely related use cases:
1. if you have little data (i.e., a low number of observations),
2. if you care about uncertainty,
3. for generative modelling.
We’ll touch upon each use case in our worked example.
1. Variational inference with little data
Sometimes, data collection is expensive. For example, DNA or RNA measurements can easily cost a few thousand euros per observation. In this case, you can hardcode domain knowledge in lieu of extra samples. Variational inference can help to systematically “dial down” the domain knowledge as you gather more examples, and rely more heavily on the data (Fig. 1).
2. Variational inference for uncertainty
For safety-critical applications, such as in finance and healthcare, uncertainty is important. Uncertainty can affect all facets of the model, most obviously the predicted output. Less obvious are the model’s parameters (e.g., weights and biases). Instead of the usual arrays of numbers (the weights and biases), you can endow the parameters with a distribution to make them fuzzy. Variational inference allows you to infer the range(s) of reasonable values.
3. Variational inference for generative modelling
Generative models provide a complete specification of how the data was generated. For example, how to generate an image of a cat or a dog. Usually, there is a latent representation z that carries semantic meaning (e.g., z describes a Siamese cat). Through a set of (non-linear) transformations and sampling steps, z is transformed into the actual image x (e.g., the pixel values of the Siamese cat). Variational inference is a way to infer, and sample from, the latent semantic space z. A well-known example is the variational autoencoder.
At its core, variational inference is a Bayesian undertaking [1]. In the Bayesian perspective, you still let the machine learn from the data, as usual. What is different is that you give the model a hint (a prior) and allow the solution (the posterior) to be more fuzzy. More concretely, say you have a training set X = [x₁, x₂, ..., xₘ]ᵗ of m examples. We use Bayes’ theorem:
$$p(\theta \mid X) = \frac{p(X \mid \theta)\, p(\theta)}{p(X)},$$
to infer a spread (a distribution) of solutions θ. Contrast this with a conventional machine learning approach, where we minimise a loss ℒ(θ) = −ln p(X|θ) to find one specific solution θ. Bayesian inference revolves around finding a way to determine p(θ|X): the posterior distribution of the parameters θ given the training set X. In general, this is a difficult problem. In practice, two ways are used to solve for p(θ|X): (i) using simulation (Markov chain Monte Carlo) or (ii) through optimisation.
Variational inference is about option (ii).
The evidence lower bound (ELBO)
The idea behind variational inference is to look for a distribution q(θ) that is a stand-in (a surrogate) for p(θ|X). We then try to make q_φ(θ) look similar to p(θ|X) by changing the values of φ (Fig. 2). This is done by maximising the evidence lower bound (ELBO):
$$\mathcal{L}(\phi) = \mathbb{E}\left[\ln p(X, \theta) - \ln q_\phi(\theta)\right],$$
where the expectation E[·] is taken over q_φ(θ). (Note that φ implicitly depends on the dataset X, but for notational convenience we’ll drop the explicit dependence.)
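One property makes the ELBO a sensible target. It never exceeds the log evidence ln p(X), and it equals ln p(X) exactly when q_φ(θ) is the true posterior. The sketch below checks this numerically on a hypothetical one-pixel Poisson-gamma model, where both the exact posterior and ln p(X) have closed forms. The counts x and the prior a, b are made up for illustration. This is a minimal Monte Carlo illustration and not part of the tutorial’s model.

```python
import jax.numpy as jnp
from jax import random
from jax.scipy.special import gammaln
from jax.scipy.stats import gamma, poisson

# Hypothetical toy data: counts of a single pixel across m = 5 images.
x = jnp.array([3., 5., 4., 6., 2.])
a, b = 2.0, 1.0  # hypothetical prior shape and rate
m = x.shape[0]

def log_joint(theta):
    # ln p(x, θ) = Σᵢ ln Poisson(xᵢ|θ) + ln Gamma(θ|a, b), evaluated for a batch of θ.
    log_lik = jnp.sum(poisson.logpmf(x[None, :], theta[:, None]), axis=1)
    return log_lik + gamma.logpdf(theta, a, scale=1.0 / b)

def elbo(key, alpha, beta, n_samples=10_000):
    # Monte Carlo estimate of E_q[ln p(x, θ) - ln q(θ)] with q = Gamma(alpha, rate=beta).
    theta = random.gamma(key, alpha, shape=(n_samples,)) / beta
    return jnp.mean(log_joint(theta) - gamma.logpdf(theta, alpha, scale=1.0 / beta))

# Exact posterior Gamma(a + Σx, b + m) and exact log evidence ln p(x).
alpha_post, beta_post = a + x.sum(), b + m
log_evidence = (a * jnp.log(b) - gammaln(a) + gammaln(alpha_post)
                - alpha_post * jnp.log(beta_post) - jnp.sum(gammaln(x + 1)))

key = random.PRNGKey(0)
elbo_exact = elbo(key, alpha_post, beta_post)  # q equals the posterior
elbo_off = elbo(key, 2.0, 2.0)                 # a poorly matched q
print(log_evidence, elbo_exact, elbo_off)
```

When q is the exact posterior, every sample gives the same value ln p(x, θ) − ln q(θ) = ln p(x), so the estimate has no Monte Carlo noise. With a poorly matched q, the gap to ln p(x) is the KL divergence between q and the posterior.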
For gradient-based optimisation of ℒ it looks, at first sight, like we have to be careful when taking derivatives (with respect to φ), because the expectation E[·] depends on q_φ(θ). Fortunately, autograd packages like JAX support reparameterisation tricks [2]. These let you take derivatives of random samples directly (e.g., of the gamma distribution) instead of relying on high-variance black-box variational approaches [3]. Long story short: estimate ∇ℒ(φ) with a batch [θ₁, θ₂, ...] ~ q_φ(θ) and let your autograd package worry about the details.
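As a quick sanity check of the reparameterisation claim, the sketch below differentiates Monte Carlo estimates through jax.random.gamma with respect to the shape α. It then compares the results with the exact derivatives. It assumes a unit rate and a fixed key. The moment formulas E[θ] = α and E[θ²] = α(α + 1) are standard properties of the gamma distribution.

```python
import jax
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
n_samples = 100_000

def mc_mean(alpha):
    # Monte Carlo estimate of E[θ] for θ ~ Gamma(alpha, 1); the exact value is alpha.
    theta = random.gamma(key, alpha, shape=(n_samples,))
    return jnp.mean(theta)

def mc_second_moment(alpha):
    # Monte Carlo estimate of E[θ²] = alpha * (alpha + 1).
    theta = random.gamma(key, alpha, shape=(n_samples,))
    return jnp.mean(theta ** 2)

alpha = 3.0
# JAX differentiates through the gamma sampler itself (implicit reparameterisation).
grad_mean = jax.grad(mc_mean)(alpha)             # exact derivative: 1
grad_second = jax.grad(mc_second_moment)(alpha)  # exact derivative: 2 * alpha + 1 = 7
print(grad_mean, grad_second)
```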
To solidify our understanding, let us implement variational inference from scratch using JAX. In this example, you’ll train a generative model on handwritten digits from scikit-learn. You can follow along with the Colab notebook.
To keep it simple, we’ll only analyse the digit “zero”.
from sklearn import datasets

digits = datasets.load_digits()
is_zero = digits.target == 0
X_train = digits.images[is_zero]
# Flatten image grid to a vector.
n_pixels = 64  # 8-by-8.
X_train = X_train.reshape((-1, n_pixels))
Each image is an 8-by-8 array of discrete pixel values ranging from 0 to 16. Since the pixels are count data, let’s model the pixels, x, using the Poisson distribution with a gamma prior for the rate θ. The rate θ determines the average intensity of the pixels. Thus, the joint distribution is given by:
$$p(X, \theta) = \prod_{i=1}^{m} \mathrm{Poisson}(x_i \mid \theta)\, \mathrm{Gamma}(\theta \mid a, b),$$
where a and b are the shape and rate of the gamma distribution.
The prior, in this case Gamma(θ|a, b), is the place where you infuse your domain knowledge (use case 1). For example, you may have some idea of what the “average” digit zero looks like (Fig. 4). You can use this a priori information to guide your choice of a and b. Call the Fig. 4 prior information x₀. To weigh its importance as two examples, set a = 2x₀ and b = 2.
Written in Python, this looks like:
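import jax.numpy as jnp
import jax.scipy as jsp

# Hyperparameters of the model. x_domain_knowledge holds the prior guess x₀
# (the "average" digit zero of Fig. 4), one value per pixel; it is not shown here.
a = 2. * x_domain_knowledge
b = 2.

def log_joint(θ):
    # Log prior ln Gamma(θ|a, b); JAX's gamma is parameterised by scale, so pass 1/b.
    log_p = jnp.sum(jsp.stats.gamma.logpdf(θ, a, scale=1./b))
    # Log likelihood: sum of ln Poisson(x|θ) over all training images and pixels.
    log_p += jnp.sum(jsp.stats.poisson.logpmf(X_train, θ))
    return log_p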
Note that we’ve used the JAX implementations of numpy and scipy, so that we can take derivatives.
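Here is why a = 2x₀ and b = 2 amounts to weighing the prior as two examples. Gamma(θ|a, b) has mean a/b and variance a/b². The sketch below uses made-up values for x₀, not the Fig. 4 prior. The helper gamma_prior is hypothetical.

```python
import numpy as np

# Hypothetical prior guess x0 of the average intensity for three pixels.
x0 = np.array([0.5, 8.0, 14.0])

def gamma_prior(x0, n_pseudo):
    """Hypothetical helper: shape a and rate b of a gamma prior centred on x0,
    weighted as n_pseudo examples."""
    return n_pseudo * x0, n_pseudo

a, b = gamma_prior(x0, n_pseudo=2.0)
prior_mean = a / b       # equals x0: the prior is centred on the domain knowledge
prior_var = a / b**2     # equals x0 / 2: the weight controls how tight the prior is

a10, b10 = gamma_prior(x0, n_pseudo=10.0)
prior_var_10 = a10 / b10**2  # five times smaller than prior_var
print(prior_mean, prior_var, prior_var_10)
```

Whatever the weight, the prior mean stays at x₀, and only the spread shrinks as the weight grows. A pixel with x₀ = 0 would give a = 0, which is not a valid gamma shape. Such pixels therefore need some positive floor.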
Next, we need to choose a surrogate distribution q_φ(θ). As a reminder, our goal is to change φ so that the surrogate distribution q_φ(θ) matches p(θ|X). So, the choice of q(θ) determines the level of approximation (we suppress the dependence on φ where context permits). For illustration purposes, let’s choose a variational distribution that consists of (a product of) gammas:
$$q_\phi(\theta) = \mathrm{Gamma}(\theta \mid \alpha, \beta),$$
where we used the shorthand φ = {α, β}.
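Because q_φ(θ) is a product of independent gammas, each pixel gets its own shape αⱼ and rate βⱼ. Note that jax.random.gamma samples from a gamma distribution with unit rate. This is why the code in this tutorial multiplies samples by inv_beta = 1/β. The sketch below uses hypothetical α and β for four pixels and checks that the rescaled samples have the expected means α/β.

```python
import jax.numpy as jnp
from jax import random

n_pixels = 4
alpha = jnp.array([1.0, 2.0, 5.0, 20.0])  # hypothetical shapes, one per pixel
beta = jnp.array([0.5, 1.0, 2.0, 4.0])    # hypothetical rates, one per pixel

# jax.random.gamma draws from Gamma(alpha, 1). Dividing by the rate beta
# (equivalently, multiplying by the scale 1/beta) gives Gamma(alpha, beta).
key = random.PRNGKey(1)
theta = random.gamma(key, alpha, shape=(200_000, n_pixels)) / beta

empirical_mean = theta.mean(axis=0)  # should be close to alpha / beta = [2, 2, 2.5, 5]
print(empirical_mean)
```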
Next, to implement the evidence lower bound ℒ(φ) = E[ln p(X, θ) − ln q_φ(θ)], first write down the term inside the expectation brackets:
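from functools import partial
from jax import vmap

@partial(vmap, in_axes=(0, None, None))
def evidence_lower_bound(θ_i, alpha, inv_beta):
    # Single-sample ELBO term ln p(X, θᵢ) - ln q(θᵢ), with q = Gamma(alpha, scale=inv_beta).
    elbo = log_joint(θ_i) - jnp.sum(jsp.stats.gamma.logpdf(θ_i, alpha, scale=inv_beta))
    return elbo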
Here, we used JAX’s vmap to vectorise the function so that we can run it on a batch [θ₁, θ₂, ..., θ₁₂₈]ᵗ.
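If in_axes is unfamiliar, it tells vmap which arguments to map over. Below is a minimal sketch with a hypothetical toy function scaled_sum. The 0 maps over the first axis of the first argument, giving one row per sample. None passes alpha and inv_beta unchanged to every call.

```python
import jax.numpy as jnp
from functools import partial
from jax import vmap

@partial(vmap, in_axes=(0, None, None))
def scaled_sum(theta_i, alpha, inv_beta):
    # Called once per row of the batch; alpha and inv_beta are shared across rows.
    return jnp.sum(alpha * theta_i * inv_beta)

theta_batch = jnp.arange(6.0).reshape(3, 2)  # batch of 3 samples, 2 "pixels" each
alpha = jnp.array([1.0, 2.0])
inv_beta = 0.5
out = scaled_sum(theta_batch, alpha, inv_beta)  # one value per row: [1., 4., 7.]
```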
To complete the implementation of ℒ(φ), we average evidence_lower_bound over realisations of the variational distribution θᵢ ~ q_φ(θ):
from jax import random

def loss(Φ: dict, key):
    """Stochastic estimate of the negative evidence lower bound."""
    alpha = jnp.exp(Φ['log_alpha'])
    inv_beta = jnp.exp(-Φ['log_beta'])

    # Sample a batch from the variational distribution q.
    batch_size = 128
    batch_shape = [batch_size, n_pixels]
    θ_samples = random.gamma(key, alpha, shape=batch_shape) * inv_beta

    # Compute Monte Carlo estimate of the evidence lower bound.
    elbo_loss = jnp.mean(evidence_lower_bound(θ_samples, alpha, inv_beta))

    # Turn the ELBO into a loss.
    return -elbo_loss
A few things to note here about the arguments:
- We’ve packed φ as a dictionary (or technically, a pytree) containing ln(α) and ln(β). This trick guarantees that α > 0 and β > 0 during optimisation, a requirement imposed by the gamma distribution.
- The loss is a random estimate of the ELBO. In JAX, we need a new pseudo random number generator (PRNG) key every time we sample. In this case, we use key to sample [θ₁, θ₂, ..., θ₁₂₈]ᵗ.
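Both points can be illustrated in a few lines. The sketch below uses arbitrary values. It shows that reusing a PRNG key reproduces the exact same draws and that random.split yields fresh ones. It also shows that exponentiating unconstrained log-parameters always gives strictly positive shapes and rates.

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
sample_a = random.gamma(key, 2.0, shape=(3,))
sample_b = random.gamma(key, 2.0, shape=(3,))    # same key: identical draws
key_1, key_2 = random.split(key)
sample_c = random.gamma(key_1, 2.0, shape=(3,))  # new key: different draws

# Unconstrained (log) parameters: any real value maps to a strictly positive value.
phi = {'log_alpha': jnp.array([-3.0, 0.0, 2.5]),
       'log_beta': jnp.array([1.0, -1.0, 0.0])}
alpha = jnp.exp(phi['log_alpha'])
beta = jnp.exp(phi['log_beta'])
```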
This completes the specification of the model p(X, θ), the variational distribution q_φ(θ), and the loss ℒ(φ).
Model training
Next, we minimise the loss ℒ(φ) by varying φ = {α, β} so that q_φ(θ) matches the posterior p(θ|X). How? Using good old gradient descent! For convenience, we use the Adam optimiser from Optax and initialise the parameters with the prior a and b [remember, the prior was Gamma(θ|a, b) and codified our domain knowledge].
import jax
import optax
from jax import jit

# Initialise parameters using the prior.
Φ = {
    'log_alpha': jnp.log(a),
    'log_beta': jnp.full(fill_value=jnp.log(b), shape=[n_pixels]),
}

loss_val_grad = jit(jax.value_and_grad(loss))
optimiser = optax.adam(learning_rate=0.2)
opt_state = optimiser.init(Φ)
Here, we use value_and_grad to simultaneously evaluate the loss (the negative ELBO) and its derivative. Convenient for monitoring convergence! We then just-in-time compile the resulting function (with jit) to make it snappy.
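Here is a toy illustration of the same jit(value_and_grad(...)) pattern, using a hypothetical quadratic loss on a dictionary of parameters. The gradient comes back as a dictionary with the same keys as the parameters. Optax optimisers operate on such pytrees directly.

```python
import jax
import jax.numpy as jnp

def quadratic_loss(params):
    # Hypothetical toy loss with its minimum at w = 3.
    return jnp.sum((params['w'] - 3.0) ** 2)

loss_val_grad = jax.jit(jax.value_and_grad(quadratic_loss))
value, grads = loss_val_grad({'w': jnp.array([1.0, 5.0])})
# value = (1 - 3)² + (5 - 3)² = 8; grads['w'] = 2 * (w - 3) = [-4, 4]
```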
Finally, we’ll train the model for 5000 steps. Since loss is random, we need to supply a new pseudo random number generator (PRNG) key for every evaluation. We do this by allocating 5000 keys with random.split.
n_iter = 5_000
keys = random.split(random.PRNGKey(42), num=n_iter)

for i, key in enumerate(keys):
    # loss_val_grad returns the loss (negative ELBO) and its gradient with respect to Φ.
    elbo, grads = loss_val_grad(Φ, key)
    updates, opt_state = optimiser.update(grads, opt_state)
    Φ = optax.apply_updates(Φ, updates)
Congrats! You’ve successfully trained your first model using variational inference!
You can access the notebook with the full code here on Colab.
Results
Let’s take a step back and appreciate what we’ve built (Fig. 5). For each pixel, the surrogate q(θ) describes the uncertainty about the average pixel intensity (use case 2). In particular, our choice of q(θ) captures two complementary elements:
- The typical pixel intensity.
- How much the intensity varies from image to image (the variability).
It turns out that the joint distribution p(X, θ) we chose has an exact solution:
$$p(\theta \mid X) = \mathrm{Gamma}\left(\theta \,\middle|\, a + \sum_{i=1}^{m} x_i,\; m + b\right),$$
where m is the number of samples in the training set X. Here, we see explicitly how the domain knowledge, codified in a and b, is dialed down as we gather more examples xᵢ.
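The dial-down can be made concrete. When a = b·x₀, the posterior mean (a + Σᵢ xᵢ)/(m + b) equals the weighted average (b·x₀ + m·x̄)/(b + m), where x̄ is the sample mean. The prior’s weight b/(b + m) therefore shrinks as m grows. The sketch below simulates one hypothetical pixel whose prior guess x₀ = 2 disagrees with a made-up true rate of 10.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical single pixel: prior guess x0 weighted as two examples (a = 2 * x0, b = 2),
# while the observed counts come from a different, made-up true rate.
x0, true_rate = 2.0, 10.0
a, b = 2.0 * x0, 2.0

posterior_means = {}
for m in [0, 2, 10, 100, 1000]:
    x = rng.poisson(true_rate, size=m)  # m simulated observations of the pixel
    alpha_post = a + x.sum()            # exact posterior shape: a + sum_i x_i
    beta_post = b + m                   # exact posterior rate: m + b
    posterior_means[m] = float(alpha_post / beta_post)
    print(f"m={m:4d}  prior weight={b / (b + m):.3f}  posterior mean={posterior_means[m]:.2f}")
```

With no data the posterior is the prior, centred on x₀. As m grows, the posterior mean moves toward the data.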
We can easily compare the learned shape α and rate β with the true values a + Σᵢ xᵢ and m + b. In Fig. 4 we compare the distributions, q(θ) versus p(θ|X), for 2 specific pixels. Lo and behold, a perfect match!
Bonus: generating synthetic images
Variational inference is great for generative modelling (use case 3). With the stand-in posterior q(θ) in hand, generating new synthetic images is trivial. The two steps are:
- Sample pixel intensities θ ~ q(θ).
# Extract parameters of q.
alpha = jnp.exp(Φ['log_alpha'])
inv_beta = jnp.exp(-Φ['log_beta'])

# 1) Generate pixel-level intensities for 10 images.
key_θ, key_x = random.split(key)
m_new_images = 10
new_batch_shape = [m_new_images, n_pixels]
θ_samples = random.gamma(key_θ, alpha, shape=new_batch_shape) * inv_beta
- Sample images using x ~ Poisson(x|θ).
# 2) Sample image from intensities.
X_synthetic = random.poisson(key_x, θ_samples)
You can see the result in Fig. 6. Notice that the “zero” character is slightly less sharp than expected. This follows from our modelling assumptions: we modelled the pixels as mutually independent rather than correlated. To account for pixel correlations, you can extend the model to cluster pixel intensities. This is called Poisson factorisation [4].
In this tutorial, we introduced the basics of variational inference and applied it to a toy example: learning a handwritten digit zero. Thanks to autograd, implementing variational inference from scratch takes only a few lines of Python.
Variational inference is particularly powerful if you have little data. We saw how to infuse domain knowledge and trade it off against information from the data. The inferred surrogate distribution q(θ) gives a “fuzzy” representation of the model parameters, instead of a fixed value. This is ideal if you are in a high-stakes application where uncertainty is important! Finally, we demonstrated generative modelling. Generating synthetic samples is easy once you can sample from q(θ).
In summary, by harnessing the power of variational inference, we can tackle complex problems, make informed decisions, quantify uncertainties, and ultimately unlock the true potential of data science.
Acknowledgements
I would like to thank Dorien Neijzen and Martin Banchero for proofreading.
References:
[1] Blei, David M., Alp Kucukelbir, and Jon D. McAuliffe. “Variational inference: A review for statisticians.” Journal of the American Statistical Association 112.518 (2017): 859–877.
[2] Figurnov, Mikhail, Shakir Mohamed, and Andriy Mnih. “Implicit reparameterization gradients.” Advances in Neural Information Processing Systems 31 (2018).
[3] Ranganath, Rajesh, Sean Gerrish, and David Blei. “Black box variational inference.” Artificial Intelligence and Statistics. PMLR, 2014.
[4] Gopalan, Prem, Jake M. Hofman, and David M. Blei. “Scalable recommendation with Poisson factorization.” arXiv preprint arXiv:1311.1704 (2013).



