## Markov chains, Metropolis-Hastings, Gibbs sampling, and how they relate to Bayesian inference

This post is an introduction to Markov chain Monte Carlo (MCMC) sampling methods. We will consider two methods specifically, namely the Metropolis-Hastings algorithm and Gibbs sampling. We will introduce them and prove why they work, implement practical examples in Python, and finally explain how sampling is applied to Bayesian inference and why it is so essential there.

MCMC methods are a family of sampling methods which make use of Markov chains to generate dependent data samples. Their basic idea is to construct Markov chains that are easy to sample from and whose stationary distribution is our target distribution, so that when following them we obtain, in the limit, samples from the target distribution.

Why do we need this? In a previous post I introduced basic sampling methods, among others covering rejection and importance sampling for complex distributions. These generate independent data samples, whereas here we generate dependent ones, as mentioned. This does not answer the question yet, but it is an important distinction. In the previous posts we also saw that the presented methods suffer from severe limitations: it is hard to find suitable proposal distributions, especially in high-dimensional spaces, leading to high variance and wasteful computations.

MCMC methods, i.e. following a (simple) Markov chain, fare better in these circumstances, in particular because they need less information about the distribution we want to sample from: we only have to be able to evaluate it up to a constant factor. That is, we do not need to evaluate the full pdf `p(x)` for a given `x`; it suffices to be able to compute `zp(x)` for some unknown constant `z`. At the end of this post we will see why this is so powerful, by applying it to solve a Bayesian inference problem. In many tutorials and explanations this last part is only given briefly as a side note, but I think it deserves more attention, especially for beginners to Bayesian inference.
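To make the "up to a constant factor" point concrete, here is a minimal sketch. The Gaussian density and the constant `z = 42.0` are purely illustrative assumptions. It shows that ratios of an unnormalised density `zp(x)` equal ratios of the normalised density `p(x)`, and such ratios are all that the methods discussed in this post need:

```python
import math


def normal_pdf(x: float, mean: float = 0.0, std: float = 1.0) -> float:
    """Normalised Gaussian density p(x)."""
    return math.exp(-0.5 * ((x - mean) / std) ** 2) / (std * math.sqrt(2 * math.pi))


z = 42.0  # arbitrary "unknown" constant (illustrative assumption)


def unnormalised(x: float) -> float:
    """z * p(x): the only quantity we assume we can evaluate."""
    return z * normal_pdf(x)


x, y = 0.3, 1.2
ratio_normalised = normal_pdf(y) / normal_pdf(x)
ratio_unnormalised = unnormalised(y) / unnormalised(x)
print(ratio_normalised, ratio_unnormalised)  # the two ratios agree
```

The constant cancels in the ratio, so we never have to know it.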

Naturally, MCMC methods have disadvantages too: because the samples are correlated, the *effective* sample size shrinks, and the methods sometimes fail to converge or converge only very slowly.

Since, as the name suggests (and as we have stated multiple times by now), MCMC methods are based on Markov chains, we introduce these first.

Markov chains are a way of modelling stochastic processes as a sequence of events. Here, the *Markov property* states that the next state depends only on the current state, and not on any earlier history.

(Small excursus: many practical ML methods require this property, such as reinforcement learning (RL). Requiring this one-step dependency may seem very limiting and impractical. Note, however, that we can simply expand the state space to arbitrary dimension, in particular by including past events in the current state, and thus circumvent this “limitation” entirely.)

Formally, let us consider a random variable `X`, and denote its per-timestep realisations by `X₀`, `X₁`, … How `X` develops over time is given by a transition function `P`, where

$$P(X_{t+1} = j \mid X_t = i) = p$$

denotes that the probability of `X` transitioning from state `i` to state `j` is `p`.

To fully specify a Markov chain, we additionally need to define an initial distribution for `X`, denoted by `π₀`. With this, we can follow the Markov chain from `π₀` by iteratively applying `P`, yielding the per-timestep distributions `π₁`, `π₂`, …

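Let us visualise this with an example. We choose the following transition matrix:

$$P = \begin{pmatrix} 0.3 & 0.5 & 0.75 \\ 0.1 & 0.1 & 0.1 \\ 0.6 & 0.4 & 0.15 \end{pmatrix}$$

Note that in our notation, for convenience, the entry with index `ij` denotes the probability of transitioning from state `j` to state `i`, so that each column sums to 1.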

We now take a random initial distribution and follow the Markov chain for 30 steps. This can be implemented as follows in Python:

```python
import numpy as np

P = np.asarray([[0.3, 0.5, 0.75], [0.1, 0.1, 0.1], [0.6, 0.4, 0.15]])
print(f"Transition matrix P: {P}")

# Generate a random initial distribution (normalise to obtain a valid distribution).
pi = np.random.rand(3)
pi /= np.sum(pi)
print(f"Initial distribution pi_0: {pi}")

for i in range(30):
    # One step of the chain: P is column-stochastic, so pi_{t+1} = P @ pi_t.
    pi = np.matmul(P, pi)
    if i % 5 == 0:
        print(f"Distribution after {i + 1} steps: {pi}")
```

When executing this program, we will get output similar to this (the first two prints are omitted):

```
Distribution after 1 steps: [0.51555326 0.1 0.38444674]
Distribution after 6 steps: [0.499713 0.1 0.400287]
Distribution after 11 steps: [0.5000053 0.1 0.3999947]
Distribution after 16 steps: [0.4999999 0.1 0.4000001]
Distribution after 21 steps: [0.5 0.1 0.4]
Distribution after 26 steps: [0.5 0.1 0.4]
```

As we can see, this Markov chain converges, for any initial distribution, to the distribution `[0.5, 0.1, 0.4]`, which we call the stationary distribution of this Markov chain.
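As a cross-check, the stationary distribution can also be obtained in two other ways. The following is a minimal sketch, not part of the original example: the seed, the number of steps and the start state are arbitrary choices. It computes the eigenvector of `P` belonging to eigenvalue 1, and separately counts how often a single simulated trajectory visits each state:

```python
import numpy as np

# Column-stochastic transition matrix from the example:
# P[i, j] is the probability of moving from state j to state i.
P = np.asarray([[0.3, 0.5, 0.75], [0.1, 0.1, 0.1], [0.6, 0.4, 0.15]])

# 1) Eigenvector of P for eigenvalue 1, normalised to sum to 1.
eigenvalues, eigenvectors = np.linalg.eig(P)
idx = np.argmin(np.abs(eigenvalues - 1.0))
stationary = np.real(eigenvectors[:, idx])
stationary = stationary / stationary.sum()

# 2) Empirical visit frequencies of one long simulated trajectory.
rng = np.random.default_rng(0)
num_steps = 20_000
state = 0
counts = np.zeros(3)
for _ in range(num_steps):
    # The current state's column holds the distribution of the next state.
    state = rng.choice(3, p=P[:, state])
    counts[state] += 1
frequencies = counts / num_steps

print(stationary)
print(frequencies)
```

The second variant is the one MCMC builds on: we never compute the stationary distribution explicitly, we just follow the chain and collect the visited states as samples.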

Before moving on, we introduce a criterion, needed in the following sections, to establish that a distribution is stationary for a Markov chain: *detailed balance*. We say a Markov chain satisfies the detailed balance criterion if there exists a distribution `π` satisfying:

$$P_{ij} \, \pi_j = P_{ji} \, \pi_i \quad \text{for all } i, j$$

I.e., under the distribution `π`, the probability of transitioning from state `j` to state `i` is the same as that of the reverse transition. Intuitively, it also makes sense why this yields a stationary distribution: the probability mass flowing between any pair of states balances out, so `π` does not change. Feel free to convince yourself that this criterion is satisfied for the Markov chain defined above, and that indeed `[0.5, 0.1, 0.4]` is the distribution satisfying it.
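For readers who would rather let the computer do the convincing, here is a small sketch of that check (the variable names are my own). It builds the matrix of probability flows between states and tests whether it is symmetric:

```python
import numpy as np

P = np.asarray([[0.3, 0.5, 0.75], [0.1, 0.1, 0.1], [0.6, 0.4, 0.15]])
pi = np.asarray([0.5, 0.1, 0.4])

# flow[i, j] = P[i, j] * pi[j]: probability mass moving from state j to state i.
flow = P * pi[np.newaxis, :]

# Detailed balance holds if the flow from j to i equals the flow from i to j.
satisfies_detailed_balance = np.allclose(flow, flow.T)

# Summing the flow into each state recovers pi, i.e. pi is stationary.
is_stationary = np.allclose(flow.sum(axis=1), pi)

print(satisfies_detailed_balance, is_stationary)
```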

Equipped with this knowledge, we now introduce one of the most common and frequently used MCMC algorithms, namely the Metropolis-Hastings algorithm. To recap, what we are trying to do is sample values from a difficult probability distribution `f(x)`, the target distribution.

Let’s begin with an overview of the algorithm. Essentially, it consists of the following steps:

1. Select an arbitrary initial value `x₀` within the target distribution’s support
2. Draw `y₁` using a proposal distribution `q`
3. Compute `p₁` (see below)
4. Draw `u₁` from the uniform distribution over [0, 1]
5. Set `x₁ = y₁` if `u₁ ≤ p₁`, else set `x₁ = x₀`
6. Repeat steps 2–5

`p₁` is given by:

$$p_1 = \min\left(\frac{f(y_1)}{f(x_0)} \cdot \frac{q(x_0 \mid y_1)}{q(y_1 \mid x_0)}, \; 1\right)$$

## Example

Let’s demonstrate this using a concrete example, implemented in Python. The setup: the target distribution we want to sample from is a Gaussian distribution, and our proposal distribution is another Gaussian. This is of course not a practical real-world example. Nonetheless, I hope that this simplified setting aids understanding instead of confusing the reader. Note that in this example, all values of interest are 1D.

The corresponding Python code looks as follows:

```python
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats

NUM_SAMPLES = 10000

# Target distribution: Gaussian with mean 5 and standard deviation 2
f = scipy.stats.norm(5, 2)

# Plot target distribution
x = np.linspace(-5, 15, 5000)
plt.plot(x, f.pdf(x))

# Step 1
x = np.random.uniform(-2, 2)

# Proposal distribution: Gaussian jump with mean 0 and standard deviation 1
q = scipy.stats.norm(0, 1)

samples = []
for i in range(NUM_SAMPLES):
    # Step 2
    y = x + q.rvs()
    # Step 3
    p = min(f.pdf(y) / f.pdf(x) * q.pdf(x - y) / q.pdf(y - x), 1)
    # Step 4
    u = np.random.uniform(0, 1)
    # Step 5
    x = y if u <= p else x
    samples.append(x)

plt.hist(samples, density=True, bins=30)
plt.show()
```

Let’s go over this in more detail. First, we use scipy’s stats module to represent our target distribution `f` and plot its pdf. We then define an initial value `x` to start sampling from, simply drawing one value from a uniform distribution. Next, we enter the sampling loop, iteratively generating `NUM_SAMPLES` values according to the algorithm introduced above. As proposal distribution we use another Gaussian `q`, which yields a new value `y` obtained by “jumping” away from `x` according to this Gaussian. It is worth noting that the conditional evaluation of `q` equals `q`’s pdf evaluated at the jump distance: intuitively, the further we jump, the less likely the jump becomes.
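A small side observation, sketched below with a hand-written Gaussian density (`q_pdf` is a hypothetical helper, not taken from the example code): because a zero-mean Gaussian jump is symmetric, jumping from `x` to `y` is exactly as likely as jumping back. The proposal term in the acceptance probability therefore equals 1 in this setup, and only the ratio of target densities matters.

```python
import math


def q_pdf(jump: float, std: float = 1.0) -> float:
    """Density of a zero-mean Gaussian jump, as used for the proposal."""
    return math.exp(-0.5 * (jump / std) ** 2) / (std * math.sqrt(2 * math.pi))


x, y = 1.0, 2.5
forward = q_pdf(y - x)   # q(y | x): density of jumping from x to y
backward = q_pdf(x - y)  # q(x | y): density of jumping back from y to x
correction = backward / forward
print(correction)
```

For non-symmetric proposals this term does not cancel, which is why the general algorithm keeps it.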

Executing this program should plot the pdf of `f` together with a histogram of the samples that closely follows it.

We see that we correctly sampled from the “unknown” distribution `f`.
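The same idea can be written more compactly and more robustly. The sketch below is my own variant, not the code from the example. It assumes a symmetric Gaussian random-walk proposal, works with an *unnormalised log-density* (the constant of the Gaussian is simply dropped), and uses NumPy’s `Generator` instead of scipy. The function name `random_walk_metropolis`, its parameters and the seed are illustrative choices:

```python
import numpy as np


def log_target(x: float) -> float:
    """Unnormalised log-density of N(5, 2^2); the constant term is dropped."""
    return -0.5 * ((x - 5.0) / 2.0) ** 2


def random_walk_metropolis(log_f, x0, num_samples, step_std=1.0, seed=0):
    rng = np.random.default_rng(seed)
    x = x0
    samples = np.empty(num_samples)
    for i in range(num_samples):
        # Propose a jump away from the current value.
        y = x + rng.normal(0.0, step_std)
        # Symmetric proposal: accept with probability min(1, f(y) / f(x)),
        # evaluated in log space.
        if np.log(rng.uniform()) <= log_f(y) - log_f(x):
            x = y
        samples[i] = x
    return samples


samples = random_walk_metropolis(log_target, x0=0.0, num_samples=20_000)
print(samples.mean(), samples.std())
```

The sample mean and standard deviation should land close to 5 and 2, even though the sampler never saw the normalisation constant.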

## Proof of Correctness

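To prove the correctness of the Metropolis-Hastings algorithm, we need to show that the stationary distribution of the Markov chain it uses is indeed the target distribution. For this, we use the notion of detailed balance introduced above. Remember, this involves showing that

$$f(x^{(t-1)}) \, P(x^{(t)} \mid x^{(t-1)}) = f(x^{(t)}) \, P(x^{(t-1)} \mid x^{(t)})$$

where `P` denotes the transition density of the chain, i.e. it doesn’t matter whether we first visit state `(t-1)` and then transition to `(t)`, or vice versa.

Thus let’s evaluate the left side of this equation, and simply plug in our proposal distribution and the acceptance criterion:

$$f(x^{(t-1)}) \, P(x^{(t)} \mid x^{(t-1)}) = f(x^{(t-1)}) \, q(x^{(t)} \mid x^{(t-1)}) \min\left(\frac{f(x^{(t)})}{f(x^{(t-1)})} \cdot \frac{q(x^{(t-1)} \mid x^{(t)})}{q(x^{(t)} \mid x^{(t-1)})}, \; 1\right)$$

A quick reformulation yields:

$$f(x^{(t-1)}) \, P(x^{(t)} \mid x^{(t-1)}) = \min\left(f(x^{(t)}) \, q(x^{(t-1)} \mid x^{(t)}), \; f(x^{(t-1)}) \, q(x^{(t)} \mid x^{(t-1)})\right)$$

When doing this analogously for the right side of the above target equation, we obtain the same result, since the expression is symmetric in the two states, which concludes the proof.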
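The argument can also be checked numerically on a small discrete state space. The following sketch is an illustration under made-up assumptions: the target weights `f` and the random, non-symmetric proposal matrix `q` are arbitrary. It builds the Metropolis-Hastings transition matrix explicitly and verifies detailed balance and stationarity:

```python
import numpy as np

rng = np.random.default_rng(0)
num_states = 4

# Unnormalised target weights and an arbitrary, non-symmetric proposal.
# q[j, i] is the probability of proposing state j when in state i
# (columns sum to 1).
f = np.asarray([1.0, 3.0, 2.0, 4.0])
q = rng.random((num_states, num_states))
q /= q.sum(axis=0, keepdims=True)

# Metropolis-Hastings transition matrix T[j, i] = P(next = j | current = i).
T = np.zeros((num_states, num_states))
for i in range(num_states):
    for j in range(num_states):
        if i != j:
            accept = min(f[j] * q[i, j] / (f[i] * q[j, i]), 1.0)
            T[j, i] = q[j, i] * accept
    # Rejected proposals (and proposals of i itself) keep the chain in state i.
    T[i, i] = 1.0 - T[:, i].sum()

target = f / f.sum()
flow = T * target[np.newaxis, :]  # flow[j, i]: mass moving from i to j

detailed_balance_holds = np.allclose(flow, flow.T)
target_is_stationary = np.allclose(T @ target, target)
print(detailed_balance_holds, target_is_stationary)
```

Note that the check only ever uses ratios of `f`, so it would pass unchanged for any rescaling of the weights.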

## Discussion

We stated in the introduction that MCMC methods like Metropolis-Hastings are superior to and more efficient than, e.g., rejection sampling. This is true, but we should still put some effort into choosing `q`, as this choice influences the speed of convergence. Consider again the acceptance criterion:

$$p = \min\left(\frac{f(y)}{f(x)} \cdot \frac{q(x \mid y)}{q(y \mid x)}, \; 1\right)$$

If we defined `q` to be equal to `f` (i.e. `q(y | x) = f(y)`), the two fractions would cancel and we would get 1, i.e. accept all samples, which is the best case. Naturally, this is not possible, as we cannot sample from `f` (which is why we are doing all this and sampling from `q` instead, after all). Still, this gives some intuition on how to select `q`. Conversely, if `q` is poorly chosen, we will reject many samples and obtain many repeated, highly correlated samples, which is a problem (the chain is “stuck” in some region).
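To get a feeling for this, the sketch below measures the fraction of accepted proposals for different jump sizes. It is an illustrative experiment, not part of the original example: the target is the Gaussian N(5, 2²) from before in unnormalised form, and the three step sizes and the seed are arbitrary choices.

```python
import numpy as np


def acceptance_rate(step_std: float, num_steps: int = 5_000, seed: int = 0) -> float:
    """Fraction of accepted random-walk proposals for a N(5, 2^2) target."""
    rng = np.random.default_rng(seed)

    def log_f(x: float) -> float:
        return -0.5 * ((x - 5.0) / 2.0) ** 2  # unnormalised log-density

    x = 5.0
    accepted = 0
    for _ in range(num_steps):
        y = x + rng.normal(0.0, step_std)
        if np.log(rng.uniform()) <= log_f(y) - log_f(x):
            x = y
            accepted += 1
    return accepted / num_steps


rates = {step_std: acceptance_rate(step_std) for step_std in (0.1, 2.0, 50.0)}
print(rates)
```

Very large jumps are rarely accepted, so the chain gets stuck. Tiny jumps are almost always accepted, but the chain then moves so slowly that consecutive samples are again highly correlated, so a high acceptance rate on its own is not the goal for a random-walk proposal.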

These considerations are related to the terms *effective sample size* and *burn-in*. Since MCMC methods produce correlated rather than independent samples, we have to take this into account when analysing them. Specifically, this gives rise to the term effective sample size, which can be viewed as the actual sample size “cleaned” of the effects of correlation. Further, it is common to throw away the first `N` elements obtained by an MCMC algorithm (burn-in): this is mainly done to compensate for “bad” initialisations, which lie in regions of low probability that the chain has a hard time getting out of.
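Both terms can be illustrated with a few lines of code. The estimator below is a rough heuristic sketch, assuming we truncate the autocorrelation sum at the first non-positive value; dedicated libraries use more careful estimators. The AR(1) chain merely stands in for correlated MCMC output, and the burn-in length of 500 is an arbitrary choice:

```python
import numpy as np


def effective_sample_size(samples) -> float:
    """Rough ESS estimate: N / (1 + 2 * sum of autocorrelations).

    The sum is truncated at the first non-positive autocorrelation,
    which is a simple heuristic rather than a definitive estimator.
    """
    x = np.asarray(samples, dtype=float)
    n = len(x)
    x = x - x.mean()
    autocov = np.correlate(x, x, mode="full")[n - 1:] / n
    rho = autocov / autocov[0]
    tau = 1.0
    for k in range(1, n):
        if rho[k] <= 0:
            break
        tau += 2.0 * rho[k]
    return n / tau


rng = np.random.default_rng(0)
n = 5_000
independent = rng.normal(size=n)

# Strongly correlated AR(1) chain as a stand-in for MCMC output, started
# far away from its stationary region to mimic a bad initialisation.
chain = np.empty(n)
chain[0] = 20.0
for t in range(1, n):
    chain[t] = 0.9 * chain[t - 1] + rng.normal()

burn_in = 500  # illustrative choice
kept = chain[burn_in:]

ess_independent = effective_sample_size(independent)
ess_chain = effective_sample_size(kept)
print(ess_independent, ess_chain)
```

For the independent draws the estimate stays close to the nominal sample size, while for the correlated chain it is only a small fraction of it.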

As a second example of MCMC sampling methods we will have a look at Gibbs sampling. Since we already introduced the underlying ideas and proved correctness for one MCMC method, we will go considerably faster this time. I still wanted to include it to give this tutorial sufficient depth, and refer the reader to other resources for more details, or to do the maths themselves.

## Overview

Gibbs sampling is used for sampling from distributions over multiple variables, where sampling from the joint distribution `p(X, Y)` is hard, but we do know how to sample from the conditional distributions `p(X | Y)` and `p(Y | X)`. Making use of this, the employed Markov chain alternates between sampling values for `X` and `Y`, each time conditioning on the most recent value of the other variable. Overall, this is pretty quick to introduce and implement, which we will do in the following section.

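For the sake of introduction, we consider a two-dimensional multivariate normal distribution. This distribution is characterised by a mean vector `μ` and a covariance matrix `Σ`. Conveniently, the needed conditionals are again normal distributions, given (shown here for `x₁`) by:

$$x_1 \mid x_2 \sim \mathcal{N}\left(\mu_1 + \Sigma_{12} \Sigma_{22}^{-1} (x_2 - \mu_2), \; \Sigma_{11} - \Sigma_{12} \Sigma_{22}^{-1} \Sigma_{21}\right)$$

where the second argument is the conditional variance.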

## Implementation

To start, let’s use `scipy.stats` to define and plot our target distribution:

```python
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import scipy.stats

MEAN = np.asarray([0, 0])
VARIANCE = np.asarray([[0.25, 0.3], [0.3, 1]])


def plot_multivariate(
    mean: npt.NDArray[np.float32], variance: npt.NDArray[np.float32]
) -> None:
    multivariate_normal = scipy.stats.multivariate_normal(mean, variance)

    num_ticks = 100
    min_axis_value = -5
    max_axis_value = 5
    x = np.linspace(min_axis_value, max_axis_value, num_ticks)
    y = np.linspace(min_axis_value, max_axis_value, num_ticks)
    X, Y = np.meshgrid(x, y)
    pos = np.array([X.flatten(), Y.flatten()]).T

    fig = plt.figure()
    ax = fig.add_subplot(projection="3d")
    ax.plot_surface(
        X,
        Y,
        multivariate_normal.pdf(pos).reshape((num_ticks, num_ticks)),
        cmap="viridis",
        linewidth=0,
    )
    plt.show()


plot_multivariate(MEAN, VARIANCE)
```

This should display a 3D surface plot of the density.

Next, we execute the Gibbs sampling procedure as described above:

```python
def get_cond_distr_x(
    mean: npt.NDArray[np.float32], variance: npt.NDArray[np.float32], y: float
) -> Any:
    cond_mean = mean[0] + variance[0, 1] * 1 / variance[1, 1] * (y - mean[1])
    cond_var = variance[0, 0] - variance[0, 1] * 1 / variance[1, 1] * variance[1, 0]
    # scipy.stats.norm expects the standard deviation, not the variance.
    return scipy.stats.norm(cond_mean, np.sqrt(cond_var))


def get_cond_distr_y(
    mean: npt.NDArray[np.float32], variance: npt.NDArray[np.float32], x: float
) -> Any:
    cond_mean = mean[1] + variance[1, 0] * 1 / variance[0, 0] * (x - mean[0])
    cond_var = variance[1, 1] - variance[1, 0] * 1 / variance[0, 0] * variance[0, 1]
    # scipy.stats.norm expects the standard deviation, not the variance.
    return scipy.stats.norm(cond_mean, np.sqrt(cond_var))


def gibbs_sampling(
    mean: npt.NDArray[np.float32],
    variance: npt.NDArray[np.float32],
    num_samples: int = 50000,
) -> npt.NDArray[np.float32]:
    xs = []
    ys = []
    x = 0
    for i in range(num_samples):
        # Alternate between the two conditionals, always using the latest value.
        y = get_cond_distr_y(mean, variance, x).rvs()
        x = get_cond_distr_x(mean, variance, y).rvs()
        xs.append(x)
        ys.append(y)
    return np.stack((xs, ys)).transpose(1, 0)


sampled_points = gibbs_sampling(MEAN, VARIANCE)
```
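As a numerical sanity check of the procedure, the following sketch re-implements the sampler with NumPy only and compares the empirical mean and covariance of the samples with the target parameters. It is a compact variant written for this check, not the code above: `conditional_params`, the seed and the number of samples are my own choices. Note that normal samplers expect a standard deviation, so the conditional variance is passed through a square root:

```python
import numpy as np

mean = np.asarray([0.0, 0.0])
cov = np.asarray([[0.25, 0.3], [0.3, 1.0]])


def conditional_params(mean, cov, i, other_value):
    """Mean and standard deviation of coordinate i given the other coordinate."""
    j = 1 - i
    cond_mean = mean[i] + cov[i, j] / cov[j, j] * (other_value - mean[j])
    cond_var = cov[i, i] - cov[i, j] / cov[j, j] * cov[j, i]
    return cond_mean, np.sqrt(cond_var)


rng = np.random.default_rng(0)
num_samples = 20_000
points = np.empty((num_samples, 2))
x = 0.0
for n in range(num_samples):
    # Sample y given the latest x, then x given the new y.
    y = rng.normal(*conditional_params(mean, cov, 1, x))
    x = rng.normal(*conditional_params(mean, cov, 0, y))
    points[n] = (x, y)

empirical_mean = points.mean(axis=0)
empirical_cov = np.cov(points, rowvar=False)
print(empirical_mean)
print(empirical_cov)
```

If the variance were passed where the standard deviation is expected, the empirical covariance would noticeably deviate from the target, which makes this a cheap but useful test.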

Finally, we convince ourselves that the obtained distribution is correct by drawing a 3D histogram:

```python
def draw_3d_hist(points: npt.NDArray[np.float32]) -> None:
    # Adapted from https://matplotlib.org/stable/gallery/mplot3d/hist3d.html.
    fig = plt.figure()
    ax = fig.add_subplot(projection="3d")
    hist, xedges, yedges = np.histogram2d(
        points[:, 0], points[:, 1], bins=50, range=[[-5, 5], [-5, 5]]
    )

    # Construct arrays for the anchor positions of the bars
    # (the lower edges of the histogram bins).
    xpos, ypos = np.meshgrid(xedges[:-1], yedges[:-1], indexing="ij")
    xpos = xpos.ravel()
    ypos = ypos.ravel()
    zpos = 0

    # Construct arrays with the dimensions of the bars: one bin wide and deep,
    # with the bin count as height.
    dx = (xedges[1] - xedges[0]) * np.ones_like(xpos)
    dy = (yedges[1] - yedges[0]) * np.ones_like(ypos)
    dz = hist.ravel()

    ax.bar3d(xpos, ypos, zpos, dx, dy, dz, zsort="average")
    plt.show()


draw_3d_hist(sampled_points)
```

To conclude this post, let’s look at a real-world use case where such sampling becomes incredibly handy and helps solve important problems: Bayesian inference. We will introduce this in more detail in a future post. For now: it means solving for the “full” probability distribution of a given problem, and specifically calculating the probability distribution of the parameters `θ` given the data `D`:

$$p(\theta \mid D) = \frac{p(D \mid \theta) \, p(\theta)}{p(D)}$$

These terms are commonly called:

- `p(θ | D)`: the posterior
- `p(D | θ)`: the likelihood
- `p(θ)`: the prior
- `p(D)`: the evidence

What makes solving this equation particularly difficult is the evidence, as it requires marginalising over all possible parameter values, i.e.:

$$p(D) = \int p(D \mid \theta) \, p(\theta) \, d\theta$$

This integral is usually hard to compute, or even intractable.

Thus, numerical approximations in the form of MCMC methods are the starting point and the standard choice for solving such problems. What is especially useful about methods like Metropolis-Hastings is their lax requirement of only needing to evaluate distributions up to a normalisation constant, and the evidence is just such a constant! This means we can formulate and work with the posterior distribution without the tricky denominator.

Let’s demonstrate this with an example: we flip an (unfair) coin `N` times, and are interested in finding out the probability `θ` of the coin landing heads. In particular, we don’t just want to arrive at a point estimate, but instead go Bayesian and model the full posterior.

Let’s analyse the individual terms in more detail: the outcome of a coin throw follows a Bernoulli distribution. Denoting by `Nₕ` the observed number of heads and by `Nₜ` the corresponding number of tails, for a given parameter value `θ` this yields the following likelihood of the observed throws `D`:

$$p(D \mid \theta) = \theta^{N_h} (1 - \theta)^{N_t}$$

Next, we need to find a suitable prior, i.e. encode some kind of belief about the parameter value. Since we don’t have to worry about solving the problem analytically (see: conjugate priors), we are free to choose any prior. Thus, we simply pick a normal distribution with mean 0.5 and standard deviation 0.2, expressing that we expect the coin to be around 50:50, while still covering all of [0, 1].

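These two terms are enough to run the Metropolis-Hastings algorithm. In it, we need to calculate `f(y)/f(x)` for some parameter values `y` and `x` and the density function `f` we are interested in. In our case, as mentioned, this is the posterior `p(θ|D)` of the parameter `θ` given the observed throws `D`. We also mentioned already that the evidence cancels out, since it is a constant. What remains is the following:

$$\frac{f(y)}{f(x)} = \frac{p(D \mid y) \, p(y)}{p(D \mid x) \, p(x)} = \left(\frac{y}{x}\right)^{N_h} \left(\frac{1 - y}{1 - x}\right)^{N_t} \frac{p(y)}{p(x)}$$

where `p(·)` in the last factor denotes the prior density.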
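One practical remark: for many throws, the powers in this ratio can overflow or underflow in floating point. A common remedy is to compute the ratio in log space. The sketch below is a hedged illustration, not the code used in this post: the function names and default arguments are my own, and the prior is assumed to be normal with mean 0.5 and standard deviation 0.2 as described above. It is only valid for `x` and `y` strictly between 0 and 1:

```python
import math


def log_posterior_ratio(y, x, num_heads, num_tails, prior_mean=0.5, prior_std=0.2):
    """log of f(y) / f(x) for the coin example; the evidence cancels."""
    log_likelihood = num_heads * math.log(y / x) + num_tails * math.log(
        (1 - y) / (1 - x)
    )
    # Ratio of two Gaussian prior densities: the normalisation constants cancel.
    log_prior = -0.5 * ((y - prior_mean) / prior_std) ** 2 + 0.5 * (
        (x - prior_mean) / prior_std
    ) ** 2
    return log_likelihood + log_prior


def direct_ratio(y, x, num_heads, num_tails, prior_mean=0.5, prior_std=0.2):
    """The same ratio computed directly, for comparison on small counts."""
    likelihood = (y / x) ** num_heads * ((1 - y) / (1 - x)) ** num_tails
    prior = math.exp(-0.5 * ((y - prior_mean) / prior_std) ** 2) / math.exp(
        -0.5 * ((x - prior_mean) / prior_std) ** 2
    )
    return likelihood * prior


from_log = math.exp(log_posterior_ratio(0.35, 0.3, 30, 70))
direct = direct_ratio(0.35, 0.3, 30, 70)
print(from_log, direct)

# With many throws, the log version stays finite where direct powers may not.
large_case = log_posterior_ratio(0.9, 0.1, 1500, 3500)
print(large_case)
```

In the sampler, the acceptance test `u <= p` then becomes a comparison of `log(u)` with the log ratio.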

With these considerations and the above introduction to the Metropolis-Hastings algorithm, porting this to Python should be no hard feat:

```python
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats

NUM_THROWS = 100  # Number of coin tosses
THETA_TRUE = 0.3  # True probability for landing heads
THETA_PRIOR = 0.5  # Prior estimate for heads probability
NUM_SAMPLES = 100000  # Number of MCMC steps

# Define the unfair coin and generate data from it
unfair_coin = scipy.stats.bernoulli(THETA_TRUE)
D = np.asarray([unfair_coin.rvs() for _ in range(NUM_THROWS)])

# Define prior distribution: normal with mean 0.5 and standard deviation 0.2
prior = scipy.stats.norm(THETA_PRIOR, 0.2)


def likelihood_ratio(theta_1, theta_2):
    return (theta_1 / theta_2) ** np.sum(D == 1) * (
        (1 - theta_1) / (1 - theta_2)
    ) ** np.sum(D == 0)


def norm_ratio(theta_1, theta_2):
    return prior.pdf(theta_1) / prior.pdf(theta_2)


# Step 1
x = np.random.uniform(0, 1)

# Proposal distribution
q = scipy.stats.norm(0, 0.1)

samples = []
for i in range(NUM_SAMPLES):
    # Step 2
    y = x + q.rvs()
    # Step 3
    ratio = likelihood_ratio(y, x) * norm_ratio(y, x)
    p = min(ratio * q.pdf(x - y) / q.pdf(y - x), 1)
    # Step 4
    u = np.random.uniform(0, 1)
    # Step 5: proposals outside [0, 1] are always rejected
    x = y if u <= p and 0 <= y <= 1 else x
    samples.append(x)

plt.hist(samples, density=True, bins=100)
plt.show()
```

Running this should display a histogram of the sampled posterior.

In this histogram, the algorithm correctly found the true `θ` of about 0.3, as shown by the mode of the posterior distribution. We can also observe some variance, which is fine and actually expected and desired: it is one of the reasons we do full Bayesian inference. Throwing a coin “only” 100 times gives us a first estimate of what its true heads probability looks like, but to be more certain we would like to see more examples.

So let’s increase our dataset to 5000 throws and inspect the output in this case.

Now, the posterior distribution indeed has a much lower variance, as expected.
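Since this toy problem has only one parameter, we can cross-check the behaviour without MCMC by evaluating the unnormalised posterior on a grid and normalising it numerically. This is a minimal sketch under stated assumptions: the head and tail counts are hypothetical values matching a true `θ` of 0.3, the prior is a normal distribution with mean 0.5 and standard deviation 0.2, and the helper names are my own.

```python
import numpy as np


def grid_posterior(num_heads, num_tails, prior_mean=0.5, prior_std=0.2):
    """Posterior over theta on a regular grid, normalised numerically."""
    theta = np.linspace(0.001, 0.999, 999)
    d_theta = theta[1] - theta[0]
    log_unnormalised = (
        num_heads * np.log(theta)
        + num_tails * np.log(1 - theta)
        - 0.5 * ((theta - prior_mean) / prior_std) ** 2
    )
    # Subtract the maximum before exponentiating to avoid underflow.
    unnormalised = np.exp(log_unnormalised - log_unnormalised.max())
    # In 1D, the evidence integral reduces to a simple Riemann sum.
    posterior = unnormalised / (unnormalised.sum() * d_theta)
    return theta, posterior, d_theta


def posterior_std(theta, posterior, d_theta):
    mean = (theta * posterior).sum() * d_theta
    return np.sqrt((((theta - mean) ** 2) * posterior).sum() * d_theta)


# Hypothetical head/tail counts for 100 and 5000 throws.
theta, post_100, d_theta = grid_posterior(30, 70)
_, post_5000, _ = grid_posterior(1500, 3500)

mode_100 = theta[np.argmax(post_100)]
std_100 = posterior_std(theta, post_100, d_theta)
std_5000 = posterior_std(theta, post_5000, d_theta)
print(mode_100, std_100, std_5000)
```

Such a grid only works in very low dimensions, which is exactly why MCMC is the standard tool once there are more than a handful of parameters.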

In this post we introduced Markov chain Monte Carlo (MCMC) methods, which are powerful methods for numerical sampling. Such methods allow us to efficiently sample from complex distributions without overly strict requirements on tractability: specifically, we only have to be able to evaluate the distributions of interest up to a constant factor. MCMC methods work by constructing a Markov chain whose stationary distribution is the target distribution. We can thus follow this chain and effectively sample from the distribution we are after.

As a first algorithm we introduced Metropolis-Hastings and proved why it is correct. It works by introducing a proposal distribution, which is used to “jump” from the current point to a new one. This new point is accepted with a certain probability, which is determined by the ratio of the probability densities at these points.

Next, we discussed Gibbs sampling, which is a technique for sampling from multi-dimensional distributions. The core idea is to use conditional distributions and to iterate through sampling a new value for each dimension while keeping the others fixed.

Finally, we gave a practical example of Bayesian inference. Solving this problem analytically requires solving a complex integral, making it a prime example (and a quite common one) for numerical approximation. We demonstrated how to estimate the posterior distribution of a Bernoulli variable simulating an unfair coin toss.

I hope this post was informative for you and shed some light into this exciting field. Thanks for reading!

*All images, unless noted otherwise, were generated by the author.*

This post is Part 3 of a series about sampling.
