from functools import partial
import numpy as np
import torch
import e3nn
from e3nn.tensor.spherical_tensor import (
SphericalTensor,
projection
)
from e3nn.tensor.fourier_tensor import (
FourierTensor,
plot_on_grid
)
import e3nn.o3 as o3
import e3nn.rs as rs
import plotly
import plotly.graph_objects as go
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline
import math
# Make float64 the default dtype for every tensor created without an explicit
# dtype in this notebook (e.g. torch.tensor / torch.linspace calls).
torch.set_default_dtype(torch.float64)
Note: these differ from the models in `e3nn.radial`
because they do not inherit from `torch.nn.Module`
and have **no learned parameters**.
def ConstantRadialModel():
    """Return a radial function that maps every radius to the constant 1.

    The returned callable has no learned parameters.

    Returns:
        A function ``radial_function(r)`` that, for a tensor ``r`` of any
        shape ``S``, returns a tensor of ones with shape ``(*S, 1)`` -- i.e. a
        single constant radial basis function. The result uses the torch
        default dtype and lives on the same device as ``r``.
    """
    def radial_function(r):
        # Allocate on r's device so the result can be combined with r when r
        # is not on the CPU; the dtype is still the torch default.
        return torch.ones(*r.shape, 1, device=r.device)

    return radial_function
def FixedCosineRadialModel(max_radius, number_of_basis, min_radius=0.):
    """Return a fixed (parameter-free) basis of cosine-squared radial bumps.

    ``number_of_basis`` centers ``c_i`` are placed evenly on
    ``[min_radius, max_radius]`` (both endpoints included). With ``step`` the
    distance between neighbouring centers, basis function ``i`` at radius
    ``r`` is ``cos(pi/2 * (r - c_i) / step) ** 2`` when ``|r - c_i| < step``
    and (numerically) zero otherwise. Neighbouring bumps overlap so that, for
    ``r`` between the first and last center, the basis values sum to 1.

    Args:
        max_radius: position of the last center.
        number_of_basis: number of basis functions; must be at least 2 so
            that a center spacing is defined.
        min_radius: position of the first center.

    Returns:
        A function ``radial_function(r)`` mapping a tensor ``r`` of shape
        ``S`` to a tensor of shape ``(*S, number_of_basis)``.

    Raises:
        ValueError: if ``number_of_basis < 2`` or ``max_radius == min_radius``
            (the center spacing would be undefined or zero).
    """
    if number_of_basis < 2:
        raise ValueError(
            "FixedCosineRadialModel needs number_of_basis >= 2 to define a "
            f"center spacing, got {number_of_basis}"
        )
    radii = torch.linspace(min_radius, max_radius, number_of_basis)
    step = float(radii[1] - radii[0])
    if step == 0.:
        raise ValueError("FixedCosineRadialModel needs max_radius != min_radius")

    def radial_function(r):
        # Signed distance from r to every center in units of the spacing; the
        # new trailing axis indexes the basis function (broadcasts over S).
        x = (r.unsqueeze(-1) - radii.to(r.device)) / step
        # cos^2 bump with half-width one spacing: clamping to [-1, 1] pins the
        # value at cos(+-pi/2)^2 ~ 0 outside the bump's support.
        return torch.cos(x.clamp(-1., 1.) * (math.pi / 2)) ** 2

    return radial_function
def FixedGaussianRadialModel(max_radius, number_of_basis, min_radius=0.):
    """Return a fixed (parameter-free) basis of Gaussian radial functions.

    ``number_of_basis`` centers ``c_i`` are placed evenly on
    ``[min_radius, max_radius]`` (both endpoints included). Basis function
    ``i`` at radius ``r`` is ``exp(-gamma * (r - c_i) ** 2)`` with
    ``gamma = number_of_basis / (max_radius - min_radius)``.

    Args:
        max_radius: position of the last center.
        number_of_basis: number of basis functions.
        min_radius: position of the first center.

    Returns:
        A function ``radial_function(r)`` mapping a tensor ``r`` of shape
        ``S`` to a tensor of shape ``(*S, number_of_basis)``.
    """
    spacing = (max_radius - min_radius) / number_of_basis
    radii = torch.linspace(min_radius, max_radius, number_of_basis)
    # NOTE(review): `spacing` divides by number_of_basis, whereas the actual
    # distance between linspace centers is (max - min) / (number_of_basis - 1),
    # and gamma = 1 / spacing (not 1 / spacing**2) multiplies a squared
    # distance. Kept unchanged to preserve outputs -- confirm the intended width.
    gamma = 1. / spacing

    def radial_function(r):
        # r.unsqueeze(-1) has shape (*S, 1); the 1-D centers broadcast along
        # the new trailing axis, which indexes the basis function.
        return torch.exp(-gamma * (r.unsqueeze(-1) - radii.to(r.device)) ** 2)

    return radial_function
Define the geometry (a regular tetrahedron) and `lmax`
# The easiest way to construct a regular tetrahedron is to take alternating
# (pairwise non-adjacent) corners of a unit cube: every pair of the four
# points is a face diagonal, so all six edge lengths equal sqrt(2).
tetra_coords = torch.tensor(
    [[0., 0., 0.], [1., 1., 0.], [1., 0., 1.], [0., 1., 1.]]
)
# Center the tetrahedron on the origin by subtracting the centroid.
tetra_coords -= tetra_coords.mean(-2)
# Maximum degree passed to the from_geometry expansions in this notebook
# (presumably the highest spherical-harmonic order l -- confirm in e3nn).
lmax = 3
(This projection is also called an angular and radial Fourier transform.)
# Number of radial basis functions and the position of the outermost one,
# passed to FixedCosineRadialModel(max_radius, number_of_basis).
n_radial = 3
max_radius = 2.
# Expand the tetrahedron points in a combined angular (up to lmax) and fixed
# cosine radial basis; the radial model has no learned parameters.
sphten = FourierTensor.from_geometry(tetra_coords, FixedCosineRadialModel(max_radius, n_radial), lmax)
# Sample the expansion for plotting. NOTE(review): 5. is presumably the extent
# of the sampling region; below, x is indexed as (N, 3) positions and f is used
# as per-point values -- confirm against FourierTensor.plot.
x, f = sphten.plot(5.)
# Symmetric iso-range at +/- half the peak magnitude, so positive and negative
# values are treated alike by the diverging RdBu colorscale.
plot_max = float(f.abs().max()) * 0.5
trace = go.Volume(
x=x[:,0], y=x[:,1], z=x[:,2], value=f,
isomin=-plot_max,
isomax=plot_max,
opacity=0.3, # needs to be small to see through all surfaces
surface_count=10, # more iso-surfaces give a smoother volume rendering, at higher cost
colorscale='RdBu'
)
# As the last expression of the notebook cell, the figure is displayed.
go.Figure([trace])
# Angular-only expansion of the same tetrahedron (no radial model), up to lmax.
sphten = SphericalTensor.from_geometry(tetra_coords, lmax)
# Sample the signal on the sphere. NOTE(review): n=100 presumably sets the grid
# resolution and relu=True presumably clips negative values -- confirm against
# SphericalTensor.plot.
r, f = sphten.plot(n=100, relu=True)
# The last axis of r supplies the x/y/z surface coordinates; f colors the surface.
trace = go.Surface(x=r[..., 0], y=r[..., 1], z=r[..., 2], surfacecolor=f)
# As the last expression of the notebook cell, the figure is displayed.
go.Figure(trace)