
Differentiable and Accelerated Spherical Harmonic Transforms


In JAX and PyTorch

Many areas of science and engineering encounter data defined on the sphere. Modelling and analysis of such data often requires the spherical counterpart to the Fourier transform: the spherical harmonic transform. We provide a brief overview of the spherical harmonic transform and present a new differentiable algorithm tailored for acceleration on GPUs [1]. This algorithm is implemented in the recently released S2FFT Python package, which supports both JAX and PyTorch.

[Image created by authors.]

Increasingly often we are interested in analysing data that lives on the sphere. The range of applications is remarkable, stretching from quantum chemistry, biomedical imaging, climate physics and geophysics, to the wider cosmos.

The most well-known areas in which one encounters data on the sphere are within the physical sciences, particularly in atmospheric science, geophysical modelling, and astrophysics.

Examples of the most widely known cases of spherical data, such as the Earth (left) and an artist's impression of astronomical observations (right). [Earth image sourced from Wikipedia; astrophysics image sourced from Wikipedia.]

These problems are naturally spherical as observations are made at each point on the surface of a sphere: the surface of the Earth for geophysics and the sky for astrophysics. Other examples come from applications like computer graphics and vision, where 360° panoramic cameras capture the world around you in every direction.

In many cases the spherical nature of the problem at hand is fairly easy to see; however, this is not always the case. Perhaps surprisingly, spherical data is quite frequently encountered within the biological disciplines, although the spherical aspect is often much less obvious! Since we are often interested in local directions in biological studies, such as the direction in which water diffuses within the brain, we encounter spherical data.

Diffusion tensor imaging of neuronal connections within the human brain. Within each voxel, water is free to diffuse in any direction, so the problem is naturally spherical. [Animation by Alfred Anwander, CC-BY licence.]

Given the prevalence of such data, it is not surprising that many spherical analysis techniques have been developed. A frequency analysis of the data can be insightful, often affording a statistical summary or an effective representation for further analysis or modelling. Recently, geometric deep learning techniques have proven highly effective for the analysis of data on complex domains [2–6], particularly for highly complex problems such as molecular modelling and protein interactions (see our previous post, A Brief Introduction to Geometric Deep Learning).

So we have data on the sphere and a variety of techniques by which spherical data may be analysed, but we need mathematical tools to do so. Specifically, we need to know how to decompose spherical data into its frequencies efficiently.

The Fourier transform provides a frequency decomposition that is often used to calculate statistical correlations within data. Many physical systems can also be described more straightforwardly in frequency space, as each frequency may evolve independently.

To extend the standard Fourier transform to the sphere, we need the meeting of minds of two eighteenth-century French mathematicians: Joseph Fourier and Adrien-Marie Legendre.

Joseph Fourier (left) and Adrien-Marie Legendre (right). Tragically, the caricature of Legendre is the only known image of him. [Fourier image sourced from Wikipedia. Legendre image sourced from Wikipedia.]

First, let us consider how to decompose Euclidean data into its constituent frequencies. Such a transform of the data was first derived by Joseph Fourier and is given by
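the familiar integral below, written here in one common convention (factors of 2π vary between fields):

```latex
\hat{f}(k) = \int_{-\infty}^{\infty} f(x) \, e^{-2\pi i k x} \, \mathrm{d}x
```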

which is found almost everywhere and is a staple of undergraduate physics for good reason! It works by projecting our data f(x) onto a set of trigonometric functions, called a basis. One can do effectively the same thing on the sphere, but the basis functions are now given by the spherical harmonics Yₗₘ:
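In the standard orthonormal convention, the forward transform on the sphere projects onto the conjugated harmonics:

```latex
f_{\ell m} = \int_{0}^{2\pi} \int_{0}^{\pi} f(\theta, \phi) \, Y^{*}_{\ell m}(\theta, \phi) \, \sin\theta \, \mathrm{d}\theta \, \mathrm{d}\phi
```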

where (θ, ϕ) are the usual spherical polar coordinates.

Spherical harmonic basis functions (real component). [Sourced from Wikipedia.]

The spherical harmonics (shown above) can be broken down further into the product of an exponential and associated Legendre functions, à la Adrien-Marie Legendre, as
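follows, with a standard orthonormalisation (and the Condon-Shortley phase absorbed into P):

```latex
Y_{\ell m}(\theta, \phi) = \sqrt{\frac{2\ell + 1}{4\pi} \, \frac{(\ell - m)!}{(\ell + m)!}} \; P_{\ell}^{m}(\cos\theta) \, e^{i m \phi}
```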

And so the spherical harmonic transform can be written as a Fourier transform followed by an associated Legendre transform. The real difficulty comes in evaluating the Legendre part of the transform: it is either computationally expensive or memory hungry, depending on the method one chooses.
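To make the two-stage structure concrete, here is a minimal pure-NumPy sketch of a naive forward transform on a Gauss-Legendre grid: an FFT over ϕ followed by quadrature against normalised associated Legendre functions in θ. This is an illustrative toy (m ≥ 0 only, with per-coefficient quadrature cost), not an optimised algorithm:

```python
import math
import numpy as np

def legendre_plm(l, m, x):
    """Associated Legendre function P_l^m(x) (Condon-Shortley phase), m >= 0."""
    pmm = np.ones_like(x)
    if m > 0:
        # P_m^m = (-1)^m (2m - 1)!! (1 - x^2)^(m/2)
        pmm = (-1) ** m * np.prod(np.arange(1, 2 * m, 2)) * (1 - x**2) ** (m / 2)
    if l == m:
        return pmm
    pll = x * (2 * m + 1) * pmm  # P_{m+1}^m
    for ll in range(m + 2, l + 1):
        # (l - m) P_l^m = (2l - 1) x P_{l-1}^m - (l + m - 1) P_{l-2}^m
        pmm, pll = pll, ((2 * ll - 1) * x * pll - (ll + m - 1) * pmm) / (ll - m)
    return pll

def ylm_norm(l, m):
    """Orthonormalisation constant of the spherical harmonic Y_lm."""
    return math.sqrt((2 * l + 1) / (4 * math.pi)
                     * math.factorial(l - m) / math.factorial(l + m))

def forward_sht(f, L):
    """Naive forward transform of f sampled on Gauss-Legendre theta x equiangular phi.

    f: complex array of shape (L, 2L); returns flm[l, m] for 0 <= m <= l < L.
    """
    x, w = np.polynomial.legendre.leggauss(f.shape[0])  # cos(theta) nodes, weights
    fm = np.fft.fft(f, axis=1) * (2 * np.pi / f.shape[1])  # phi integrals, one per m
    flm = np.zeros((L, L), dtype=complex)
    for m in range(L):
        for l in range(m, L):
            # Legendre part: quadrature over theta (i.e. over x = cos(theta))
            flm[l, m] = ylm_norm(l, m) * np.sum(w * legendre_plm(l, m, x) * fm[:, m])
    return flm

# Sanity check: transforming a sampled Y_21 recovers a single unit coefficient.
L = 4
x, _ = np.polynomial.legendre.leggauss(L)
theta, phi = np.arccos(x), np.arange(2 * L) * np.pi / L
T, P = np.meshgrid(theta, phi, indexing="ij")
f = ylm_norm(2, 1) * legendre_plm(2, 1, np.cos(T)) * np.exp(1j * P)
flm = forward_sht(f, L)
print(abs(flm[2, 1]))  # ~1.0; all other coefficients ~0
```

Sampling a single harmonic on the grid and recovering a single unit coefficient is a handy sanity check for any transform implementation.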

The rise of differentiable programming is opening up many new types of analysis. In particular, many applications require spherical transforms that are differentiable.

Machine learning models on the sphere require differentiable transforms so that models may be trained by gradient-based optimisation algorithms, i.e. through back-propagation.

Emerging physics-enhanced machine learning [7] and hybrid data-driven, model-based approaches [8] also require differentiable physics models, which in many cases themselves require differentiable spherical transforms.

With this in mind, it is clear that for modern applications an efficient algorithm for the spherical harmonic transform is necessary but not sufficient. Differentiability is essential.

This is all well and good, but how does one efficiently evaluate the spherical harmonic transform? A variety of algorithms have been developed, along with some great software packages. However, for modern applications we need one that is differentiable, can run on hardware accelerators like GPUs, and is computationally scalable.

By redesigning the core algorithms from the ground up (as described in detail in our corresponding paper [1]), we recently developed a Python package called S2FFT that should fit the bill.

S2FFT is implemented in JAX, a differentiable programming framework developed by Google, and also includes a PyTorch frontend.

S2FFT is a Python package implementing differentiable and accelerated spherical harmonic transforms, with interfaces in JAX and PyTorch. [Image created by authors.]

S2FFT provides two operating modes: precompute the associated Legendre functions, which are then accessed at run time; or compute them on-the-fly during the transform. The precompute approach is about as fast as one can get, but the memory required to store all Legendre function values scales cubically with resolution, which can be a problem! The second approach instead recursively computes Legendre terms on-the-fly, and so can be scaled to very high resolutions.
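To see why cubic scaling bites, a back-of-the-envelope estimate helps. At band-limit L there are roughly L(L+1)/2 (ℓ, m) pairs and of order L θ-rings, so tabulating every Legendre value in double precision needs about 4L³ bytes. The numbers below are purely illustrative, not S2FFT's exact footprint:

```python
def legendre_precompute_gib(L, bytes_per_value=8):
    """Rough memory to tabulate P_l^m(cos theta) over all l, m and theta rings."""
    n_values = L * (L * (L + 1) // 2)  # ~L theta rings x L(L+1)/2 (l, m) pairs
    return n_values * bytes_per_value / 2**30

for L in (512, 2048, 8192):
    print(f"L = {L:5d}: ~{legendre_precompute_gib(L):8.1f} GiB")
# L = 512 needs ~0.5 GiB, L = 2048 ~32 GiB, and L = 8192 over 2 TiB.
```

Doubling the band-limit multiplies the storage by roughly eight, which is exactly why on-the-fly recursions matter at high resolution.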

In addition, S2FFT supports a hybrid automatic and manual differentiation approach so that gradients can be computed efficiently.

The package is designed to support a number of different sampling schemes on the sphere. At launch, equiangular (McEwen & Wiaux [9], Driscoll & Healy [10]), Gauss-Legendre, and HEALPix [11] sampling schemes are supported, although others can easily be added in the future.

Different sampling schemes on the sphere supported by S2FFT. [Original figure created by authors.]

The S2FFT package is available on PyPI, so anyone can install it straightforwardly by running:

pip install s2fft

Or, to pick up PyTorch support, by running:

pip install "s2fft[torch]"

From here the top-level transforms can be called simply by:

import s2fft

# f: signal samples on the sphere; L: harmonic band-limit

# Compute forward spherical harmonic transform
flm = s2fft.forward_jax(f, L)

# Compute inverse spherical harmonic transform
f = s2fft.inverse_jax(flm, L)

These functions can be used out of the box and integrated as layers within existing models, both in JAX and PyTorch, with full support for both forward and reverse mode differentiation.
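For instance, gradients can be taken straight through a transform with jax.grad. This is a sketch only: f_target and the initial flm are hypothetical placeholders you would supply yourself.

```python
import jax
import jax.numpy as jnp
import s2fft

L = 64
f_target = ...  # an observed map on the sphere, matching the sampling scheme
flm = ...       # initial harmonic coefficients

def loss(flm):
    # Map candidate coefficients back to the sphere and compare with the target.
    f = s2fft.inverse_jax(flm, L)
    return jnp.sum(jnp.abs(f - f_target) ** 2)

# Gradients flow through the spherical harmonic transform via autodiff.
grad_flm = jax.grad(loss)(flm)
```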

With researchers becoming increasingly interested in differentiable programming for scientific applications, there is a critical need for modern software packages that implement the foundational mathematical methods on which science is often based, such as the spherical harmonic transform.

We hope S2FFT will be of great use in the coming years and we are excited to see what people will use it for!

[1] Price & McEwen, Differentiable and accelerated spherical harmonic and Wigner transforms, arXiv:2311.14670 (2023).

[2] Bronstein, Bruna, Cohen, Veličković, Geometric Deep Learning: Grids, Groups, Graphs, Geodesics, and Gauges, arXiv:2104.13478 (2021).

[3] Ocampo, Price & McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023).

[4] Cobb, Wallis, Mavor-Parker, Marignier, Price, d’Avezac, McEwen, Efficient Generalised Spherical CNNs, ICLR (2021).

[5] Cohen, Geiger, Koehler, Welling, Spherical CNNs, ICLR (2018).

[6] Jumper et al., Highly accurate protein structure prediction with AlphaFold, Nature (2021).

[7] Karniadakis et al., Physics-informed machine learning, Nature Reviews Physics (2021).

[8] Campagne et al., Jax-cosmo: An end-to-end differentiable and GPU accelerated cosmology library, arXiv:2302.05163 (2023).

[9] McEwen & Wiaux, A novel sampling theorem on the sphere, IEEE TSP (2012).

[10] Driscoll & Healy, Computing Fourier Transforms and Convolutions on the 2-Sphere, AAM (1994).

[11] Gorski et al., HEALPix: a Framework for High Resolution Discretization, and Fast Evaluation of Data Distributed on the Sphere, ApJ (2005).
