Plotting with Radial Functions

using the e3nn repository

tutorial by: Tess E. Smidt (blondegeek)

We are going to use class methods of the `SphericalTensor` class to compute and plot angular and radial Fourier transforms of a geometry. Representing geometry this way allows us to use it as input to or output from our models.

In [1]:
from functools import partial
import numpy as np
import torch
import e3nn
from spherical import plot_data_on_grid, SphericalTensor, projection
import e3nn.o3 as o3
import e3nn.rs as rs

import plotly
import plotly.graph_objects as go

import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

import math

Examples of RadialModels for plotting

Note: these differ from the models in `e3nn.radial` because they do not inherit from `torch.nn.Module` and have NO LEARNED PARAMETERS.

In [2]:
def ConstantRadialModel():
    """Return a radial function that ignores distance entirely.

    The returned callable maps a tensor ``r`` of any shape to a tensor of
    ones with shape ``r.shape + (1,)``, i.e. a single constant basis
    function. The ones are created with torch's default dtype and device,
    not those of ``r``.
    """
    def radial_function(r):
        return torch.ones((*r.shape, 1))

    return radial_function

def FixedCosineRadialModel(max_radius, number_of_basis, min_radius=0.):
    """Return a fixed (non-learned) radial basis of squared-cosine bumps.

    ``number_of_basis`` centers are placed uniformly on
    ``[min_radius, max_radius]`` (both ends included, via ``torch.linspace``).
    Basis function ``k`` evaluated at distance ``r`` is ``cos^2(pi/2 * u)``
    for ``|u| <= 1`` and (numerically) zero otherwise, where
    ``u = (r - center_k) / step`` and ``step`` is the distance between
    neighbouring centers. Neighbouring bumps overlap, so for ``r`` between
    the first and last center the basis values sum to 1.

    Args:
        max_radius: Position of the last center.
        number_of_basis: Number of basis functions. Must be at least 2 so
            that a center spacing is defined.
        min_radius: Position of the first center.

    Returns:
        A function mapping a tensor ``r`` of any shape to a tensor of shape
        ``r.shape + (number_of_basis,)``.

    Raises:
        ValueError: If ``number_of_basis < 2`` or ``max_radius == min_radius``
            (the center spacing would be undefined or zero).
    """
    if number_of_basis < 2:
        raise ValueError(
            f"number_of_basis must be at least 2 to define a center spacing, got {number_of_basis}"
        )
    if max_radius == min_radius:
        raise ValueError("max_radius and min_radius must differ (center spacing would be zero)")

    radii = torch.linspace(min_radius, max_radius, number_of_basis)
    # Distance between neighbouring centers; each bump is nonzero over +/- one step.
    step = radii[1] - radii[0]

    def radial_function(r):
        # Leading singleton axes let the centers broadcast against r.unsqueeze(-1).
        radial_shape = [1] * r.dim() + [number_of_basis]
        centers = radii.reshape(*radial_shape)
        u = (r.unsqueeze(-1) - centers).div(step)
        # The relu chain computes angle = (1 + clamp(1 - u, 0, 2)) * pi / 2, which
        # lies in [pi/2, 3pi/2]. cos^2(angle) equals cos^2(pi * u / 2) for |u| <= 1
        # and is ~0 once |u| >= 1, giving compactly supported bumps.
        return u.add(1).relu().sub(2).neg().relu().add(1).mul(math.pi / 2).cos().pow(2)

    return radial_function

def FixedGaussianRadialModel(max_radius, number_of_basis, min_radius=0.):
    """Return a fixed (non-learned) radial basis of Gaussians.

    Centers are ``torch.linspace(min_radius, max_radius, number_of_basis)``.
    Basis function ``k`` at distance ``r`` is ``exp(-gamma * (r - c_k)**2)``,
    where ``gamma = 1 / ((max_radius - min_radius) / number_of_basis)``.

    Note: the width used for ``gamma`` divides the range by
    ``number_of_basis``, which is not the same as the actual center spacing
    ``(max_radius - min_radius) / (number_of_basis - 1)`` produced by linspace.

    Returns:
        A function mapping a tensor ``r`` of any shape to a tensor of shape
        ``r.shape + (number_of_basis,)``.
    """
    centers = torch.linspace(min_radius, max_radius, number_of_basis)
    width = (max_radius - min_radius) / number_of_basis
    gamma = 1. / width

    def radial_function(r):
        # r[..., None] adds a trailing basis axis that broadcasts against the centers.
        offsets = r[..., None] - centers
        return (-gamma * offsets ** 2).exp()

    return radial_function

Set up coordinates for a tetrahedron and set lmax

In [3]:
# Four alternating vertices of the unit cube (no two share a cube edge) form a
# regular tetrahedron with edge length sqrt(2).
tetra_coords = torch.tensor([
    [0., 0., 0.],
    [1., 1., 0.],
    [1., 0., 1.],
    [0., 1., 1.],
])
# Shift so the centroid sits at the origin.
tetra_coords = tetra_coords - tetra_coords.mean(dim=0)

# Maximum angular degree (lmax) passed to the spherical-harmonic projection.
lmax = 3

Create and plot spherical harmonic projection with radial functions

(Also called an angular and radial Fourier transform.)

In [4]:
# Number of radial basis functions, and the position of the last radial center.
n_radial = 3
max_radius = 2.
# Build the spherical-harmonic projection (with fixed cosine radial functions)
# of the tetrahedron, up to degree lmax.
sphten = SphericalTensor.from_geometry_with_radial(
    tetra_coords,
    FixedCosineRadialModel(max_radius, n_radial),
    lmax,
)
# Sample the projection for plotting. The columns of x are used as x/y/z
# coordinates below and f as the values; 5. presumably sets the extent of the
# sampling grid -- TODO confirm against SphericalTensor.plot_with_radial.
x, f = sphten.plot_with_radial(5.)
In [5]:
# Iso range is half of the largest |f|, symmetric about zero so positive and
# negative lobes are treated alike.
plot_max = 0.5 * float(f.abs().max())
trace = go.Volume(
    x=x[:, 0],
    y=x[:, 1],
    z=x[:, 2],
    value=f,
    isomin=-plot_max,
    isomax=plot_max,
    opacity=0.3,  # keep low so all nested isosurfaces remain visible
    surface_count=10,  # more isosurfaces give a smoother volume rendering
    colorscale='RdBu',
)
# Last expression of the cell, so the notebook renders the figure.
go.Figure(data=[trace])