Operations on Spherical Tensors

using the e3nn repository

tutorial by: Tess E. Smidt (blondegeek)

We're going to use the SphericalTensor class like we did in data_types.ipynb.

In [1]:
import torch
from e3nn.tensor.spherical_tensor import SphericalTensor
from e3nn import rs

# Make double precision the default floating dtype, so tensors created later
# without an explicit dtype (e.g. the torch.zeros signals below) are float64.
torch.set_default_dtype(torch.double)

Spherical tensors can be added.

In [2]:
mul, L_max = 1, 1  # `mul` is not referenced by the other visible cells

# One coefficient per real spherical harmonic up to L_max: (L_max + 1)**2 entries.
# Index 0 is the L=0 component; indices 1, 2, 3 are the L=1 components,
# labelled y, z, x in the SH legend printed further down.
signal_1 = torch.zeros((L_max + 1) ** 2)
signal_2 = torch.zeros_like(signal_1)
signal_1[1] = 1.  # L=1, first component -> y
signal_2[3] = 1.  # L=1, third component -> x

sphten_1, sphten_2 = (SphericalTensor(sig, L_max) for sig in (signal_1, signal_2))
In [3]:
import plotly
from plotly.subplots import make_subplots
import plotly.graph_objects as go

n = 50  # resolution argument passed to SphericalTensor.plot for every panel

def plot_operation(input1, input2, output):
    """Plot two operand SphericalTensors and an operation's result side by side.

    Args:
        input1: first operand; must provide ``.plot(relu=..., n=...)`` returning
            ``(r, f)`` like SphericalTensor does.
        input2: second operand, same requirements.
        output: result of the operation, drawn in the third panel.

    Returns:
        A plotly ``Figure`` with one row of three 3D surface subplots
        (input1, input2, output). Call ``.show()`` on it to render.
    """
    rows, cols = 1, 3
    # make_subplots needs one spec per subplot cell; every panel is a 3D scene.
    specs = [[{'is_3d': True} for _ in range(cols)] for _ in range(rows)]
    fig = make_subplots(rows=rows, cols=cols, specs=specs)
    for col, sphten in enumerate([input1, input2, output], start=1):
        # r holds surface coordinates with x, y, z on its last axis; f holds the
        # values used to colour the surface (a tensor, since .numpy() is called on it).
        r, f = sphten.plot(relu=False, n=n)
        trace = go.Surface(x=r[..., 0], y=r[..., 1], z=r[..., 2], surfacecolor=f.numpy())
        trace.showscale = False
        fig.add_trace(trace, row=1, col=col)
    # Each 3D subplot gets its own scene (scene, scene2, scene3).
    # update_layout(scene_aspectmode=...) only reaches the first one, so apply the
    # data-proportional aspect ratio to every scene instead.
    fig.update_scenes(aspectmode='data')
    return fig

# Panels, left to right: sphten_1 (y), sphten_2 (x), and their sum (x + y).
new_sphten = sphten_1 + sphten_2
fig = plot_operation(sphten_1, sphten_2, new_sphten)
fig.show()

We can compute the dot product of two spherical tensors.

In [4]:
# y and x are orthogonal functions, so this prints 0.
dot_product = sphten_1.dot(sphten_2)
print(dot_product)

# y dotted with itself (identical functions) prints 1.
dot_product = sphten_1.dot(sphten_1)
print(dot_product)
tensor(0.)
tensor(1.)

We CANNOT multiply two spherical tensors, but we can compute their tensor product and use Clebsch-Gordan coefficients to combine the two tensor indices into one.

In [5]:
# Note that the product of two SphericalTensors is an IrrepTensor
new_irrten = sphten_1 @ sphten_2
print(type(new_irrten))

print("input1 Rs", sphten_1.Rs)
print("input2 Rs", sphten_2.Rs)
print("output Rs", new_irrten.Rs)

print(new_irrten.tensor)
# But we can take the non-trivial components and convert back to a SphericalTensor.
# - L=0: replaced by a single zero. Both L=0 outputs are zero here; see the printed tensor.
# - L=1: keep only the first L=1 block. In the printed tensor it is the only L=1 block
#   with a nonzero entry. The other blocks pair an L=1 input with the inputs' L=0
#   components, which are zero.
# - L=2: keep the last Rs entry.
# Offsets come from rs.dim of the preceding Rs entries rather than hard-coded numbers.
l1_start = rs.dim(new_irrten.Rs[:1])   # skip the leading L=0 entry of Rs
l2_start = rs.dim(new_irrten.Rs[:-1])  # skip everything before the final (L=2) entry
new_sphten_signal = torch.cat([
        torch.zeros(1),  # L=0
        new_irrten.tensor[l1_start:l1_start + 3],  # (nonzero) L=1, 2L+1 = 3 components
        new_irrten.tensor[l2_start:]  # L=2
])
print(new_sphten_signal.shape)
# A SphericalTensor with lmax=2 needs (2 + 1)**2 = 9 coefficients.
assert new_sphten_signal.shape == ((2 + 1) ** 2,)
new_sphten = SphericalTensor(new_sphten_signal, lmax=2)

print("")
# plots functions proportional to y, x, and -(z + xy) (both coefficients are -0.707)
print("Now we have contributions to z (cross product) and xy (outer product).")
fig = plot_operation(sphten_1, sphten_2, new_sphten)
print("SH:", "  1      y      z      x      xy     yz     *      zx     %",)
print("new", new_sphten.signal.numpy().round(3))
print("* == 2z^2 - x^2 - y^2")
print("% == x^2 - y^2")
fig.show()
<class 'e3nn.tensor.irrep_tensor.IrrepTensor'>
input1 Rs [(1, 0, 0), (1, 1, 0)]
input2 Rs [(1, 0, 0), (1, 1, 0)]
output Rs [(2, 0, 0), (3, 1, 0), (1, 2, 0)]
tensor([ 0.0000,  0.0000,  0.0000, -0.7071,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000, -0.7071,  0.0000,  0.0000,  0.0000,  0.0000])
torch.Size([9])

Now we have contributions to z (cross product) and xy (outer product).
SH:   1      y      z      x      xy     yz     *      zx     %
new [ 0.     0.    -0.707  0.    -0.707  0.     0.     0.     0.   ]
* == 2z^2 - x^2 - y^2
% == x^2 - y^2
In [ ]: