Making a Gradient Descent Animation in Python


Learn how to plot the trajectory of a point over a complex surface

Photo by Todd Diemer on Unsplash

Let me tell you how I created an animation of gradient descent just to illustrate a point in a blog post. It was worth it, since I learned more Python by doing it and unlocked a new skill: making animated plots.

Animation of two different points descending a saddle surface.
Gradient descent animation created in Python. Image by the author.

I’ll walk you through the steps of the process I followed.

A little bit of background

Just a few days ago, I published a blog post about gradient descent as an optimization algorithm used for training Artificial Neural Networks.

I wanted to include an animated figure to show how choosing different initialization points for a gradient descent optimization can produce different results.

That’s when I stumbled upon these amazing animations created by Alec Radford years ago and shared in a Reddit comment, illustrating the differences between some advanced gradient descent algorithms, like Adagrad, Adadelta and RMSprop.

Since I’ve been pushing myself to replace Matlab with Python, I decided to give it a go and try to code a similar animation myself, using a “vanilla” gradient descent algorithm to start with.
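For reference, here is a minimal sketch of what “vanilla” gradient descent looks like on the same saddle function used later in this post, f(x, y) = x² − y². The learning rate, the number of steps and the two starting points are assumptions chosen for illustration, not the settings behind the animation above. Starting exactly on the line y = 0, the point slides into the saddle point at the origin. A tiny offset in y lets it escape downhill instead, which is the sensitivity to initialization the animation is meant to show.

```python
import numpy as np

def f(theta):
    """Saddle surface f(x, y) = x**2 - y**2."""
    x, y = theta
    return x**2 - y**2

def grad_f(theta):
    """Analytic gradient of f: (2x, -2y)."""
    x, y = theta
    return np.array([2 * x, -2 * y])

def gradient_descent(theta0, lr=0.1, n_steps=50):
    """Vanilla gradient descent: theta <- theta - lr * grad_f(theta).

    Returns every visited point, not just the last one, so the path can be plotted.
    lr and n_steps are assumed values for this sketch.
    """
    theta = np.asarray(theta0, dtype=float)
    path = [theta.copy()]
    for _ in range(n_steps):
        theta = theta - lr * grad_f(theta)
        path.append(theta.copy())
    return np.array(path)

# Two starting points that differ only slightly in y (assumed values)
path_a = gradient_descent([0.8, 0.0])
path_b = gradient_descent([0.8, 0.01])

print(path_a[-1])  # stays on y = 0 and approaches the saddle point (0, 0)
print(path_b[-1])  # |y| keeps growing: the point escapes down the saddle
```

Keeping the whole path, rather than only the final point, is what makes it possible to draw the trajectory frame by frame later on.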

Let’s go step by step.

Plot the surface used for optimization

The very first thing we do is import the libraries we’ll need and define the mathematical function we want to represent.

I wanted to use a saddle point surface, so I defined the following equation:

$$f(x, y) = x^2 - y^2$$
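Gradient descent only needs the partial derivatives of this surface. For the saddle f(x, y) = x² − y² (the same function coded below), they work out as shown here. With a learning rate η (a symbol introduced here for illustration), the vanilla update then moves the two coordinates in opposite directions:

```latex
\nabla f(x, y)
  = \begin{pmatrix} \partial f / \partial x \\ \partial f / \partial y \end{pmatrix}
  = \begin{pmatrix} 2x \\ -2y \end{pmatrix},
\qquad
\begin{pmatrix} x_{k+1} \\ y_{k+1} \end{pmatrix}
  = \begin{pmatrix} x_k \\ y_k \end{pmatrix} - \eta \, \nabla f(x_k, y_k)
  = \begin{pmatrix} (1 - 2\eta)\, x_k \\ (1 + 2\eta)\, y_k \end{pmatrix}
```

For 0 < η < 1, x shrinks toward 0 at every step while any nonzero y grows in magnitude. That is why the origin is a saddle point rather than a minimum, and why the starting point decides where the descent ends up.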

We also create a grid of points for plotting our surface. np.mgrid is perfect for this. The complex number 81j, passed as the step length, indicates how many points to create between the start and stop values, both included (81 points).
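A quick check makes the complex-step behavior concrete. With 81j, np.mgrid returns 81 evenly spaced samples that include both endpoints, so the spacing is 2/80 = 0.025. A real-valued step, by contrast, is treated like the step of np.arange, and the stop value is normally left out. The comparison with np.meshgrid at the end is only there to show how the two arrays are laid out.

```python
import numpy as np

# Same grid as in the post: 81 samples from -1 to 1 (inclusive) along each axis
x, y = np.mgrid[-1:1:81j, -1:1:81j]

print(x.shape, y.shape)   # (81, 81) (81, 81)
print(x[0, 0], x[-1, 0])  # both endpoints, -1 and 1, are included
print(x[1, 0] - x[0, 0])  # spacing of 2/80 = 0.025, up to floating-point rounding

# x varies along the first axis and y along the second,
# the same layout as np.meshgrid(..., indexing="ij")
xs = np.linspace(-1, 1, 81)
X, Y = np.meshgrid(xs, xs, indexing="ij")
print(np.allclose(x, X) and np.allclose(y, Y))  # True
```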

import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # registers the 3D projection (optional on Matplotlib >= 3.2)
import matplotlib.pyplot as plt

# Create a function to compute the surface
def f(theta):
    x = theta[0]
    y = theta[1]
    return x**2 - y**2

# Make a grid of points for plotting: 81 points per axis, endpoints included
x, y = np.mgrid[-1:1:81j, -1:1:81j]
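With the function and the grid in place, the surface itself can be drawn. Here is a minimal sketch using Matplotlib’s 3D toolkit and reusing f and the grid from above. The colormap, transparency, view angle, figure size and the output filename saddle_surface.png are assumptions for illustration, not the settings used for the animation in this post.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # assumption: render off-screen so the script also runs without a display
import matplotlib.pyplot as plt

def f(theta):
    """Saddle surface f(x, y) = x**2 - y**2; works element-wise on whole grids."""
    x, y = theta[0], theta[1]
    return x**2 - y**2

# 81 x 81 grid over [-1, 1] x [-1, 1], as in the post
x, y = np.mgrid[-1:1:81j, -1:1:81j]
z = f((x, y))  # evaluates every grid point at once, element-wise

fig = plt.figure(figsize=(6, 5))
ax = fig.add_subplot(projection="3d")  # 3D axes; no Axes3D import needed on Matplotlib >= 3.2
ax.plot_surface(x, y, z, cmap="viridis", alpha=0.8)
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_zlabel("f(x, y)")
ax.view_init(elev=30, azim=-60)  # assumed viewing angle

fig.savefig("saddle_surface.png", dpi=100)  # hypothetical output filename
```

To open an interactive window instead, drop the matplotlib.use("Agg") line and replace the savefig call with plt.show().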
