# Plotting with Radial Functions

### Tutorial by: Tess E. Smidt (blondegeek)

We are going to use class methods of the FourierTensor class to plot angular and radial Fourier transforms of a geometry. Data in this form can then serve as input to, or output from, our models.

In [1]:
from functools import partial
import numpy as np
import torch
import e3nn
from e3nn.tensor.spherical_tensor import (
    SphericalTensor,
    projection
)
from e3nn.tensor.fourier_tensor import (
    FourierTensor,
    plot_on_grid
)
import e3nn.o3 as o3
import e3nn.rs as rs

import plotly
import plotly.graph_objects as go

import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

import math

torch.set_default_dtype(torch.float64)


## Examples of RadialModels for plotting

Note that these differ from the models in e3nn.radial: they do not inherit from torch.nn.Module and have NO LEARNED PARAMETERS.

In [2]:
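def ConstantRadialModel():
    # Assumption: a constant model returns 1 at every radius, with a single basis channel.
    def radial_function(r):
        shape = r.shape
        return torch.ones(*shape, 1)
    return radial_function

def GaussianRadialModel(max_radius, number_of_basis, min_radius=0.):
    # Name and signature are assumptions inferred from the variables used in the body.
    spacing = (max_radius - min_radius) / number_of_basis
    gamma = 1. / spacing
    # Assumption: centers are evenly spaced between min_radius and max_radius.
    radii = torch.linspace(min_radius, max_radius, number_of_basis)

    def radial_function(r):
        shape = r.shape
        # One singleton axis per axis of r plus a trailing basis axis, so centers broadcast against r.
        radial_shape = [1] * len(shape) + [number_of_basis]
        centers = radii.reshape(*radial_shape)
        # Output shape: r.shape + (number_of_basis,)
        return torch.exp(-gamma * (r.unsqueeze(-1) - centers) ** 2)

    return radial_function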
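To see what a Gaussian radial basis of this kind produces, here is a minimal sketch in plain Python (no torch), evaluated at a single radius. The helper name `gaussian_radial_basis` and the choice of evenly spaced centers that include both endpoints are assumptions made for illustration; only the `spacing`, `gamma`, and exponential expressions come from the cell above.

```python
import math

def gaussian_radial_basis(r, max_radius, number_of_basis, min_radius=0.0):
    """Evaluate Gaussian bumps at one radius r (hypothetical helper).

    Requires number_of_basis >= 2 so the center spacing is well defined.
    """
    spacing = (max_radius - min_radius) / number_of_basis
    gamma = 1.0 / spacing
    # Assumption: centers run evenly from min_radius to max_radius inclusive.
    step = (max_radius - min_radius) / (number_of_basis - 1)
    centers = [min_radius + i * step for i in range(number_of_basis)]
    # One value per basis function: largest for the center closest to r.
    return [math.exp(-gamma * (r - c) ** 2) for c in centers]

values = gaussian_radial_basis(r=1.0, max_radius=2.0, number_of_basis=3)
print(values)
```

With centers at 0, 1, and 2, a point at radius 1 sits exactly on the middle center, so the middle basis function dominates and the two neighbours contribute equally.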



## Set up coordinates for tetrahedra and set lmax

In [3]:
# The easiest way to construct a tetrahedron is to use opposite corners of a box.
tetra_coords = torch.tensor(
    [[0., 0., 0.], [1., 1., 0.], [1., 0., 1.], [0., 1., 1.]]
)
# Subtract the mean over the points axis so the tetrahedron is centered on the origin.
tetra_coords -= tetra_coords.mean(-2)

lmax = 3
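As a quick sanity check on the geometry, the sketch below repeats the centering step in plain Python and confirms that the four points form a regular tetrahedron. The variable names are illustrative and not part of the notebook.

```python
import itertools
import math

corners = [[0., 0., 0.], [1., 1., 0.], [1., 0., 1.], [0., 1., 1.]]

# Same operation as subtracting the mean over the points axis.
centroid = [sum(p[k] for p in corners) / len(corners) for k in range(3)]
centered = [[p[k] - centroid[k] for k in range(3)] for p in corners]

# A regular tetrahedron has six equal edges and four equal vertex radii.
edge_lengths = [math.dist(a, b) for a, b in itertools.combinations(centered, 2)]
radii = [math.hypot(*p) for p in centered]

print(edge_lengths)
print(radii)
```

Every edge has length sqrt(2) and every vertex lies at distance sqrt(3)/2 from the origin, which is useful to keep in mind when choosing the range of the radial functions.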


## Create and plot spherical harmonic projection with radial functions

(Also called an angular and radial Fourier Transform)
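For orientation, one common way to write such a projection for a set of points is sketched below. This is an assumption about the general form, not a statement of the exact normalization or conventions used by FourierTensor: $R_n$ denotes the radial basis functions, $Y_{lm}$ the spherical harmonics up to `lmax`, and $\mathbf{r}_i$ the point positions.

```latex
c_{nlm} = \sum_{i} R_n\!\left(\lVert \mathbf{r}_i \rVert\right)\, Y_{lm}\!\left(\hat{\mathbf{r}}_i\right),
\qquad
f(\mathbf{r}) \approx \sum_{n} \sum_{l=0}^{l_{\max}} \sum_{m=-l}^{l}
    c_{nlm}\, R_n\!\left(\lVert \mathbf{r} \rVert\right)\, Y_{lm}\!\left(\hat{\mathbf{r}}\right)
```

Under this form each radial function carries $(l_{\max}+1)^2$ angular coefficients, so `lmax = 3` gives 16 coefficients per radial function.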

In [4]:
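# Number of radial basis functions.
n_radial = 3
# x: sample points (one row per point, columns x, y, z); f: values at those points.
# Both are consumed by the plotting cell below.
x, f = sphten.plot(5.)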



In [5]:
plot_max = float(f.abs().max()) * 0.5
trace = go.Volume(
    x=x[:,0], y=x[:,1], z=x[:,2], value=f,
    isomin=-plot_max,
    isomax=plot_max,
    opacity=0.3, # needs to be small to see through all surfaces
    surface_count=10, # needs to be a large number for good volume rendering
    colorscale='RdBu'
)
go.Figure([trace])
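The plotting cell clips the color range at half the largest absolute value and asks for 10 surfaces. The sketch below shows which iso-values that corresponds to, assuming the surfaces are spaced evenly between `isomin` and `isomax` with both ends included. The helper `isosurface_levels` is hypothetical and only mirrors the arithmetic; it does not call plotly.

```python
def isosurface_levels(values, fraction=0.5, surface_count=10):
    """Return evenly spaced iso-values in [-plot_max, plot_max] (hypothetical helper)."""
    # Same clipping rule as the cell above: a fraction of the largest magnitude.
    plot_max = max(abs(v) for v in values) * fraction
    # Assumption: levels are evenly spaced and include both endpoints.
    step = 2 * plot_max / (surface_count - 1)
    return [-plot_max + i * step for i in range(surface_count)]

levels = isosurface_levels([-2.0, 0.5, 4.0], fraction=0.5, surface_count=5)
print(levels)
```

Because the range is symmetric around zero, a diverging colorscale such as 'RdBu' maps negative and positive lobes to opposite colors with zero in the middle.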