Operations on Spherical Tensors

using the se3cnn repository

tutorial by: Tess E. Smidt

code by: Mario Geiger et al. Citation (DOI 10.5281/zenodo.3348277):

@misc{mario_geiger_2019_3348277,
  author       = {Mario Geiger and
                  Tess Smidt and
                  Wouter Boomsma and
                  Maurice Weiler and
                  Michał Tyszkiewicz and
                  Jes Frellsen and
                  Benjamin K. Miller},
  title        = {mariogeiger/se3cnn: Point cloud support},
  month        = jul,
  year         = 2019,
  doi          = {10.5281/zenodo.3348277},
  url          = {https://doi.org/10.5281/zenodo.3348277}
}

We'll use the SphericalTensor class from the spherical module, as we did in data_types.ipynb.

In [1]:
import torch
from spherical import SphericalTensor

torch.set_default_dtype(torch.float64)

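Spherical tensors of the same type (same Rs) can be added: their coefficients add component by component, so the resulting function on the sphere is the sum of the two input functions.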

In [2]:
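Rs = [(1, 1)]  # one copy (multiplicity 1) of the L = 1 representation
# signal length: each (mult, L) pair contributes mult * (2L + 1) components
sum_Ls = sum((2 * L + 1) for mult, L in Rs for _ in range(mult))
signal_1 = torch.zeros(sum_Ls)
signal_1[0] = 1.  # y  (L = 1 components are ordered y, z, x)
signal_2 = torch.zeros(sum_Ls)
signal_2[2] = 1.  # x
sphten_1 = SphericalTensor(signal_1, Rs)
sphten_2 = SphericalTensor(signal_2, Rs)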
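The length of each signal, sum_Ls, is the dimension of the representation Rs: every (mult, L) pair contributes mult copies of an irreducible representation of dimension 2L + 1. A minimal pure-Python sketch of the same count follows; rs_dim is a hypothetical helper for illustration, not part of se3cnn.

```python
def rs_dim(Rs):
    """Total dimension of a representation given as (multiplicity, L) pairs.

    Hypothetical helper for illustration: each pair contributes
    mult copies of a (2L + 1)-dimensional irreducible representation.
    """
    return sum(mult * (2 * L + 1) for mult, L in Rs)


print(rs_dim([(1, 1)]))          # 3: one L = 1 irrep, as in the cell above
print(rs_dim([(2, 0), (1, 2)]))  # 7: two scalars (1 each) plus one L = 2 irrep (5)
```

The closed form mult * (2L + 1) gives the same number as the nested generator in the cell above, which adds 2L + 1 once per copy.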
In [3]:
from plotly.subplots import make_subplots

n = 50  # sampling parameter passed to SphericalTensor.plot

def plot_operation(input1, input2, output):
    """Plot two inputs and the result of an operation side by side in 3D."""
    rows = 1
    cols = 3
    specs = [[{'is_3d': True} for _ in range(cols)]
             for _ in range(rows)]
    fig = make_subplots(rows=rows, cols=cols, specs=specs)
    for i, sphten in enumerate([input1, input2, output]):
        trace = sphten.plot(relu=False, n=n)
        trace.showscale = False  # hide the color bar on each surface
        fig.add_trace(trace, row=1, col=i + 1)
    # equal axis scaling in every 3D subplot (scene, scene2, scene3)
    fig.update_scenes(aspectmode='data')
    return fig

new_sphten = sphten_1 + sphten_2
# plots functions proportional to y, x, and (x + y)
fig = plot_operation(sphten_1, sphten_2, new_sphten)
fig.show()
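Why the third plot is proportional to x + y: a spherical tensor with Rs = [(1, 1)] stores three coefficients, and the function it represents is linear in them, f(r) = c0 * y + c1 * z + c2 * x on the unit sphere. A minimal pure-Python sketch of that check follows, assuming (as the # y and # x comments above suggest) that the L = 1 components are ordered (y, z, x), and ignoring the common normalization constant of the real spherical harmonics. eval_l1 is a hypothetical helper, not the SphericalTensor API.

```python
import math


def eval_l1(signal, point):
    """Evaluate an L = 1 signal at a unit vector (hypothetical helper).

    Assumes the components are ordered (y, z, x) and drops the common
    normalization constant, so f(point) = c0 * y + c1 * z + c2 * x.
    """
    x, y, z = point
    return signal[0] * y + signal[1] * z + signal[2] * x


coeffs_1 = [1.0, 0.0, 0.0]  # y
coeffs_2 = [0.0, 0.0, 1.0]  # x
coeffs_sum = [a + b for a, b in zip(coeffs_1, coeffs_2)]  # coefficient-wise sum

s = 1 / math.sqrt(3)
for point in [(1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (s, s, s), (s, -s, s)]:
    x, y, z = point
    summed = eval_l1(coeffs_sum, point)
    separate = eval_l1(coeffs_1, point) + eval_l1(coeffs_2, point)
    # Both checks print True: the summed coefficients evaluate to x + y.
    print(point, math.isclose(summed, separate), math.isclose(summed, x + y))
```

Under these assumptions the summed function vanishes along the z axis and is largest along (1, 1, 0)/√2, which is what "proportional to (x + y)" means geometrically.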