Plotting with Radial Functions

using the e3nn repository

tutorial by: Tess E. Smidt (blondegeek)

We will use methods of the `FourierTensor` and `SphericalTensor` classes to compute and plot angular and radial Fourier transforms of a geometry. Representing geometry this way lets us use it as input to, or output from, our models.

In [1]:
from functools import partial
import numpy as np
import torch
import e3nn
from e3nn.tensor.spherical_tensor import (
    SphericalTensor,
    projection
)
from e3nn.tensor.fourier_tensor import (
    FourierTensor,
    plot_on_grid
)
import e3nn.o3 as o3
import e3nn.rs as rs

import plotly
import plotly.graph_objects as go

import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

import math

# Make float64 the default dtype for floating-point tensors created from here on
# (e.g. the torch.tensor / torch.linspace / torch.ones calls in later cells).
torch.set_default_dtype(torch.float64)

Examples of RadialModels for plotting

Note: these differ from the models in `e3nn.radial` because they do not inherit from `torch.nn.Module` and have **no learned parameters**.

In [2]:
def ConstantRadialModel():
    """Return a radial function that is identically 1 with a single basis channel.

    The returned callable maps a tensor ``r`` of any shape to a tensor of ones
    with shape ``r.shape + (1,)``.
    """
    def radial_function(r):
        # One trailing "basis" axis of length 1; the values ignore r entirely.
        return torch.ones(*r.shape, 1)

    return radial_function

def FixedCosineRadialModel(max_radius, number_of_basis, min_radius=0.):
    """Build a fixed (non-learnable) cosine-squared radial basis.

    ``number_of_basis`` bump functions are centered at evenly spaced points
    from ``min_radius`` to ``max_radius`` (inclusive). With ``d`` the signed
    distance to a center divided by the center spacing, each bump equals
    ``cos(pi / 2 * d) ** 2`` for ``|d| <= 1`` and (numerically) 0 otherwise.
    Each bump therefore peaks at 1 on its own center and vanishes on the
    neighbouring centers.

    Args:
        max_radius: center of the last basis function.
        number_of_basis: number of basis functions; must be at least 2 so that
            a center spacing is defined.
        min_radius: center of the first basis function.

    Returns:
        A callable mapping a tensor ``r`` of any shape to a tensor of shape
        ``r.shape + (number_of_basis,)``.

    Raises:
        ValueError: if ``number_of_basis < 2``.
    """
    if number_of_basis < 2:
        # Previously this surfaced as an opaque IndexError on radii[1].
        raise ValueError(
            f"number_of_basis must be at least 2, got {number_of_basis}"
        )
    radii = torch.linspace(min_radius, max_radius, number_of_basis)
    # 0-d tensor; 0-d CPU tensors may be combined with tensors on any device.
    step = radii[1] - radii[0]

    def radial_function(r):
        # Follow r's device so non-CPU inputs work; r[..., None] broadcasts
        # against the 1-D centers to give shape r.shape + (number_of_basis,).
        centers = radii.to(device=r.device)
        x = (r.unsqueeze(-1) - centers).div(step)
        # The relu chain computes w = 3 - clamp(x + 1, 0, 2), i.e.
        # w = 2 - clamp(x, -1, 1), so cos(pi/2 * w)**2 == cos(pi/2 * x)**2
        # inside |x| <= 1 and cos(pi/2)**2 / cos(3pi/2)**2 (~0) outside.
        w = x.add(1).relu().sub(2).neg().relu().add(1)
        return w.mul(math.pi / 2).cos().pow(2)

    return radial_function

def FixedGaussianRadialModel(max_radius, number_of_basis, min_radius=0.):
    """Build a fixed (non-learnable) Gaussian radial basis.

    ``number_of_basis`` Gaussians are centered at evenly spaced points from
    ``min_radius`` to ``max_radius`` (inclusive). Each one is
    ``exp(-gamma * (r - center) ** 2)`` with
    ``gamma = number_of_basis / (max_radius - min_radius)``.

    Returns:
        A callable mapping a tensor ``r`` of any shape to a tensor of shape
        ``r.shape + (number_of_basis,)``.
    """
    # NOTE: this width is (max - min) / n, slightly smaller than the actual
    # center spacing (max - min) / (n - 1) produced by linspace.
    width = (max_radius - min_radius) / number_of_basis
    centers = torch.linspace(min_radius, max_radius, number_of_basis)
    gamma = 1. / width

    def radial_function(r):
        # One leading singleton axis per dimension of r, so the centers line up
        # with the new trailing axis of r[..., None].
        view_shape = [1] * r.dim() + [number_of_basis]
        offsets = r.unsqueeze(-1) - centers.reshape(*view_shape)
        return torch.exp(-gamma * offsets ** 2)

    return radial_function

Set up coordinates for a tetrahedron and set `lmax`

In [3]:
# Four alternating corners of a unit cube form a regular tetrahedron.
tetra_coords = torch.tensor([
    [0., 0., 0.],
    [1., 1., 0.],
    [1., 0., 1.],
    [0., 1., 1.],
])
# Shift the centroid to the origin.
tetra_coords = tetra_coords - tetra_coords.mean(dim=-2)

lmax = 3

Create and plot a spherical harmonic projection with radial functions

(This is also called an angular and radial Fourier transform.)

In [4]:
# Three cosine radial basis functions with centers spanning [0, max_radius].
n_radial = 3
max_radius = 2.
sphten = FourierTensor.from_geometry(
    tetra_coords,
    FixedCosineRadialModel(max_radius, n_radial),
    lmax,
)
# 5. is forwarded positionally to FourierTensor.plot (presumably the extent of
# the sampling grid -- confirm against the e3nn API). The next cell uses x as
# 3-D sample points and f as the values at those points.
x, f = sphten.plot(5.)
/pytorch/aten/src/ATen/native/BinaryOps.cpp:81: UserWarning:

Integer division of tensors using div or / is deprecated, and in a future release div will perform true division as in Python 3. Use true_divide or floor_divide (// in Python) instead.

In [5]:
# Symmetric colour range at half of the largest |f|, so positive and negative
# regions map to opposite ends of the diverging 'RdBu' colour scale.
plot_max = 0.5 * float(f.abs().max())
trace = go.Volume(
    x=x[:, 0],
    y=x[:, 1],
    z=x[:, 2],
    value=f,
    isomin=-plot_max,
    isomax=plot_max,
    opacity=0.3,  # keep low so inner isosurfaces remain visible
    surface_count=10,  # more isosurfaces give a smoother volume rendering
    colorscale='RdBu',
)
go.Figure([trace])

Create and plot a spherical harmonic projection, using its magnitude as the radius

In [6]:
sphten = SphericalTensor.from_geometry(tetra_coords, lmax)
# r is used below as the 3-D surface coordinates (last axis = x, y, z) and f as
# the surface colour; relu=True is passed through to SphericalTensor.plot.
r, f = sphten.plot(n=100, relu=True)
trace = go.Surface(
    x=r[..., 0],
    y=r[..., 1],
    z=r[..., 2],
    surfacecolor=f,
)
go.Figure([trace])
In [ ]: