Datatypes in E(3) Neural Networks

using the se3cnn repository

tutorial by: Tess E. Smidt

code by:

DOI

@misc{mario_geiger_2019_3348277,
  author       = {Mario Geiger and
                  Tess Smidt and
                  Wouter Boomsma and
                  Maurice Weiler and
                  Michał Tyszkiewicz and
                  Jes Frellsen and
                  Benjamin K. Miller},
  title        = {mariogeiger/se3cnn: Point cloud support},
  month        = jul,
  year         = 2019,
  doi          = {10.5281/zenodo.3348277},
  url          = {https://doi.org/10.5281/zenodo.3348277}
}

Our data types are geometry and features on that geometry expressed as geometric tensors

Most properties of physical systems are expressed in terms of geometric tensors. Scalars (mass), vectors (velocity, force, polarization), matrices (polarizability, moment of inertia), and higher-rank tensors are all geometric tensors.
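What makes these quantities geometric tensors is how they transform under a rotation $R$: a scalar is unchanged, a vector $v$ becomes $R v$, and a rank-2 tensor $M$ becomes $R M R^T$. The NumPy sketch below is our own illustration (not se3cnn code). It uses a rotation about the z axis and checks that rotation-invariant quantities (the length of the vector, the trace and eigenvalues of the matrix) survive the transformation.

```python
import numpy as np

def rotation_z(angle):
    """Rotation matrix for a rotation by `angle` radians about the z axis."""
    c, s = np.cos(angle), np.sin(angle)
    return np.array([[c, -s, 0.0],
                     [s, c, 0.0],
                     [0.0, 0.0, 1.0]])

R = rotation_z(np.pi / 3)

mass = 2.0                              # scalar: unchanged by rotation
force = np.array([1.0, 0.0, 0.5])       # vector: v' = R v
inertia = np.diag([1.0, 2.0, 3.0])      # rank-2 tensor: M' = R M R^T

mass_rot = mass
force_rot = R @ force
inertia_rot = R @ inertia @ R.T

# Invariants agree before and after the rotation.
print(np.linalg.norm(force), np.linalg.norm(force_rot))
print(np.trace(inertia), np.trace(inertia_rot))
print(np.linalg.eigvalsh(inertia_rot))  # eigenvalues of the rotated tensor
```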

Geometric tensors: Cartesian tensors and spherical tensors

Geometric tensors are commonly expressed with Cartesian indices $(x, y, z)$; we will call these Cartesian tensors. However, geometric tensors can be represented equally well as spherical tensors.

Whereas the indices of a Cartesian tensor can be interpreted as components along $x$, $y$, and $z$, the indices of a spherical tensor label which spherical harmonic each component is associated with. The two representations are related by a change of basis, so they can be used interchangeably.
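As a concrete example of this equivalence, a rank-2 Cartesian tensor (9 numbers) splits into a trace, an antisymmetric part, and a symmetric traceless part, which transform like spherical tensors with $L = 0, 1, 2$ and have $1 + 3 + 5 = 9$ components. The sketch below is a NumPy illustration of that split; the function name `decompose_rank2` is hypothetical, and it shows the decomposition itself rather than the specific change of basis se3cnn uses internally.

```python
import numpy as np

def decompose_rank2(T):
    """Split a 3x3 Cartesian tensor into pieces that transform like L = 0, 1, 2."""
    trace_part = np.trace(T) / 3.0 * np.eye(3)            # L = 0: the trace
    antisym_part = (T - T.T) / 2.0                         # L = 1: antisymmetric part
    sym_traceless_part = (T + T.T) / 2.0 - trace_part      # L = 2: symmetric and traceless
    return trace_part, antisym_part, sym_traceless_part

rng = np.random.default_rng(0)
T = rng.normal(size=(3, 3))
parts = decompose_rank2(T)
assert np.allclose(sum(parts), T)   # the three pieces add back up to T

# Count the independent components of each piece: apply the decomposition to
# each of the 9 Cartesian basis matrices and measure the rank of the images.
basis = np.eye(9).reshape(9, 3, 3)
dims = []
for L in range(3):
    images = np.stack([decompose_rank2(E)[L].ravel() for E in basis])
    dims.append(int(np.linalg.matrix_rank(images)))
print(dims)   # [1, 3, 5], i.e. 2L + 1 components for L = 0, 1, 2
```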

We use spherical tensors in our network because our convolutional filters are expressed in terms of spherical harmonics (more about that later).

Wikipedia has a great overview of spherical harmonics. As a quick recap, the spherical harmonics are the Fourier basis for functions on the unit sphere. They have two indices, most commonly called the "degree" $L$ and the "order" $m$, and are usually parameterized by the spherical coordinate angles $\theta$ and $\phi$.

They are written $Y_{L}^{m}(\theta, \phi)$ for complex spherical harmonics and $Y_{Lm}(\theta, \phi)$ for real spherical harmonics.
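Being a Fourier basis means the spherical harmonics are orthonormal under integration over the sphere, $\int Y_{Lm} Y_{L'm'} \, d\Omega = \delta_{LL'} \delta_{mm'}$. The sketch below checks this numerically for the real harmonics with $L = 0$ and $L = 1$. It assumes the common convention $Y_{00} = 1/(2\sqrt{\pi})$ and $(Y_{1,-1}, Y_{10}, Y_{11}) = \sqrt{3/(4\pi)}\,(y, z, x)$ on the unit sphere; se3cnn's ordering and signs may differ, and `real_sph_harm_l01` is a hypothetical helper, not a library function.

```python
import numpy as np

def real_sph_harm_l01(theta, phi):
    """Real spherical harmonics for L = 0 and L = 1 (orthonormal convention).

    Returns an array whose 4 rows are Y_00, Y_1,-1, Y_10, Y_11.
    """
    x = np.sin(theta) * np.cos(phi)
    y = np.sin(theta) * np.sin(phi)
    z = np.cos(theta)
    c0 = 0.5 / np.sqrt(np.pi)
    c1 = np.sqrt(3.0 / (4.0 * np.pi))
    return np.stack([c0 * np.ones_like(z), c1 * y, c1 * z, c1 * x])

# Quadrature on the sphere: Gauss-Legendre nodes in cos(theta), uniform grid in phi.
# dOmega = sin(theta) dtheta dphi = d(cos theta) dphi.
n_theta, n_phi = 16, 32
cos_t, w_t = np.polynomial.legendre.leggauss(n_theta)
phi = np.arange(n_phi) * 2.0 * np.pi / n_phi
theta_grid, phi_grid = np.meshgrid(np.arccos(cos_t), phi, indexing="ij")
weights = np.outer(w_t, np.full(n_phi, 2.0 * np.pi / n_phi))

Y = real_sph_harm_l01(theta_grid, phi_grid)          # shape (4, n_theta, n_phi)
gram = np.einsum("atp,btp,tp->ab", Y, Y, weights)    # inner products <Y_a, Y_b>
print(np.round(gram, 10))                             # should be the 4x4 identity
```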

In se3cnn, we use real spherical harmonics. For each degree $L$ there are $2L + 1$ functions, indexed by the order $m = -L, \dots, L$, and all functions of degree $L$ have the same frequency. Note that these degrees must be integers because of the periodic boundary conditions of the sphere; half-integer values appear only for representations of $SU(2)$.
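Because each degree $L$ contributes $2L + 1$ orders, a spherical tensor holding one copy of every harmonic up to $L_{max}$ has $\sum_{L=0}^{L_{max}} (2L + 1) = (L_{max} + 1)^2$ components. A quick count in plain Python (independent of se3cnn) for $L_{max} = 3$:

```python
L_max = 3

# For every degree L, the order m runs over -L, ..., L (2L + 1 values).
pairs = [(L, m) for L in range(L_max + 1) for m in range(-L, L + 1)]

per_degree = [2 * L + 1 for L in range(L_max + 1)]
print(per_degree)   # [1, 3, 5, 7]
print(len(pairs))   # 16, which equals (L_max + 1) ** 2
```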

Representation Lists in se3cnn

To keep track of which entries of a spherical tensor correspond to which spherical harmonic, we use representation lists, usually stored in a variable named Rs.

Rs is a list of tuples (mult, L) where mult is the multiplicity (or number of copies) and L is the degree of the spherical harmonic.

For example, the Rs of a single vector is Rs_vec = [(1, 1)], and the Rs of two vectors is Rs_2vec = [(2, 1)].
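Two things we often want from an Rs are the total number of components and the position of each component in the feature vector. The helpers below are a plain-Python sketch; `rs_dim` and `rs_layout` are hypothetical names, not se3cnn functions, and `rs_layout` assumes the copies of each (mult, L) block are stored one after another with $m$ running from $-L$ to $L$.

```python
def rs_dim(Rs):
    """Total number of components described by a representation list.

    Accepts tuples (mult, L) or (mult, L, parity); parity does not change
    the dimension.
    """
    return sum(mult * (2 * L + 1) for mult, L, *_ in Rs)

def rs_layout(Rs):
    """Return (copy, L, m, index) for every component, assuming each copy of a
    (mult, L) block is stored contiguously with m = -L, ..., L."""
    layout = []
    index = 0
    for mult, L, *_ in Rs:
        for copy in range(mult):
            for m in range(-L, L + 1):
                layout.append((copy, L, m, index))
                index += 1
    return layout

Rs_vec = [(1, 1)]
Rs_2vec = [(2, 1)]
Rs_mixed = [(2, 0), (1, 1), (1, 2)]   # two scalars, one vector, one L = 2 tensor

print(rs_dim(Rs_vec), rs_dim(Rs_2vec), rs_dim(Rs_mixed))   # 3 6 10
print(rs_layout(Rs_2vec)[3])   # (1, 1, -1, 3): first entry of the second vector
```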

You will sometimes see an Rs with three integers per tuple, (mult, L, parity), where the first two are the same as before and parity indicates whether that part of the tensor has the same (0) or the opposite (1) parity as the spherical harmonic of degree L. Spherical harmonics with odd $L$ have odd parity (they flip sign under inversion, $Y_L(-\hat{r}) = -Y_L(\hat{r})$), and spherical harmonics with even $L$ have even parity (they do NOT change under inversion).
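The parity rule $Y_L(-\hat{r}) = (-1)^L Y_L(\hat{r})$ follows from the fact that $r^L Y_L$ is a homogeneous polynomial of degree $L$ in $(x, y, z)$. The sketch below checks it with unnormalized polynomials for $L = 0, 1, 2$ (a common choice of basis, not necessarily se3cnn's ordering or normalization), then applies the convention from this paragraph, parity 0 = same as $Y_L$ and 1 = opposite, so an entry picks up the sign $(-1)^{L + \text{parity}}$ under inversion. The helper names are hypothetical.

```python
import numpy as np

def solid_harmonics_l012(r):
    """Unnormalized real spherical harmonics times |r|^L, i.e. homogeneous
    polynomials of degree L, for L = 0, 1, 2. Returns {L: array of 2L + 1 values}."""
    x, y, z = r
    return {
        0: np.array([1.0]),
        1: np.array([y, z, x]),
        2: np.array([x * y, y * z, 3 * z**2 - (x**2 + y**2 + z**2), x * z, x**2 - y**2]),
    }

r = np.array([0.3, -1.2, 0.7])
Y_r, Y_minus_r = solid_harmonics_l012(r), solid_harmonics_l012(-r)
for L in range(3):
    # Inversion r -> -r multiplies every degree-L harmonic by (-1)^L.
    assert np.allclose(Y_minus_r[L], (-1) ** L * Y_r[L])

def inversion_sign(L, parity=0):
    """Sign picked up under inversion by an Rs entry (mult, L, parity), using
    the convention described above: parity 0 = same as Y_L, 1 = opposite."""
    return (-1) ** (L + parity)

# A velocity is an ordinary vector; an angular momentum is a pseudovector.
print(inversion_sign(1, 0), inversion_sign(1, 1))   # -1 1
```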

Spherical Harmonics

First, let's draw the spherical harmonics using the SphericalTensor class defined in spherical.py. This is a handy helper class that I've written for this tutorial so we can quickly manipulate and plot spherical tensors.

In [1]:
import torch
import numpy as np
import se3cnn.SO3 as SO3
from spherical import SphericalTensor  # small helper class written for this tutorial to handle spherical tensors
import plotly
from plotly.subplots import make_subplots

torch.set_default_dtype(torch.float64)

# One row per degree L, and enough columns for the 2 * L_max + 1 entries of the largest L.
L_max = 3
rows = L_max + 1
cols = 2 * L_max + 1

# Every panel is a 3D plot.
specs = [[{'is_3d': True} for i in range(cols)]
         for j in range(rows)]
fig = make_subplots(rows=rows, cols=cols, specs=specs)

for L in range(L_max + 1):
    for m in range(0, 2 * L + 1):
        # A one-hot spherical tensor of degree L selects a single spherical harmonic.
        tensor = torch.zeros(2 * L + 1)
        tensor[m] = 1.0
        sphten = SphericalTensor(tensor, Rs=[(1, L)])
        # Center each row: the middle entry of every L lands in column L_max + 1.
        row, col = L + 1, (L_max - L) + m + 1
        trace = sphten.plot(relu=False, n=60)
        # Keep the color bar only for the last panel (L = L_max, last entry).
        if m != 2 * L_max:
            trace.showscale = False
        fig.add_trace(trace, row=row, col=col)

fig.show()