# ... the SphericalTensor class, like we did in data_types.ipynb.
import torch
from e3nn.tensor.spherical_tensor import SphericalTensor
from e3nn import rs
# Use double precision for every tensor created below.
torch.set_default_dtype(torch.float64)

mul = 1
L_max = 1
# A signal up to L_max holds (L_max + 1) ** 2 coefficients: index 0 is the
# L=0 slot and indices 1..3 are the L=1 slots, ordered (y, z, x).
signal_1 = torch.zeros((L_max + 1) ** 2)
signal_1[1] = 1.  # y: offset 0 inside the L=1 block
signal_2 = torch.zeros((L_max + 1) ** 2)
signal_2[3] = 1.  # x: offset 2 inside the L=1 block

sphten_1, sphten_2 = SphericalTensor(signal_1, L_max), SphericalTensor(signal_2, L_max)
import plotly
from plotly.subplots import make_subplots
import plotly.graph_objects as go
n = 50  # default resolution forwarded as `n=` to SphericalTensor.plot


def plot_operation(input1, input2, output, resolution=None):
    """Draw two operands and the result of an operation side by side.

    Each tensor is sampled with ``.plot(relu=False, n=...)`` and rendered as a
    3D surface colored by the sampled signal, in a 1 x 3 grid of 3D subplots.

    Args:
        input1: first operand; must provide ``.plot(relu=..., n=...)``.
        input2: second operand, same interface.
        output: result of the operation, same interface.
        resolution: value forwarded as ``n`` to ``.plot``. When None, the
            module-level ``n`` is used (the original behavior).

    Returns:
        The plotly Figure; the caller decides whether to ``.show()`` it.
    """
    if resolution is None:
        resolution = n
    rows, cols = 1, 3
    # One independent dict per cell, so plotly never sees shared spec objects.
    specs = [[{'is_3d': True} for _ in range(cols)] for _ in range(rows)]
    fig = make_subplots(rows=rows, cols=cols, specs=specs)
    for col, sphten in enumerate([input1, input2, output], start=1):
        r, f = sphten.plot(relu=False, n=resolution)
        trace = go.Surface(
            x=r[..., 0], y=r[..., 1], z=r[..., 2],
            surfacecolor=f.numpy(),
            showscale=False,
        )
        fig.add_trace(trace, row=1, col=col)
    # Each 3D subplot gets its own scene (scene, scene2, scene3).
    # update_layout(scene_aspectmode=...) would only touch the first one, so
    # apply the aspect mode to every scene.
    fig.update_scenes(aspectmode='data')
    return fig
# Add the y-like and x-like tensors. The three panels show functions
# proportional to y, x, and (x + y).
new_sphten = sphten_1 + sphten_2
fig = plot_operation(sphten_1, sphten_2, output=new_sphten)
fig.show()
# y . x: the two functions are orthogonal, so this should print zero.
dot_product = sphten_1.dot(sphten_2)
print(dot_product)
# y . y: a function dotted with itself, so this should be nonzero.
dot_product = sphten_1.dot(sphten_1)
print(dot_product)
# The @ product of two SphericalTensors is an IrrepTensor, not another
# SphericalTensor. Show its type, the Rs of both operands and of the result,
# and the raw coefficients of the result.
new_irrten = sphten_1 @ sphten_2
print(type(new_irrten))

print("input1 Rs", sphten_1.Rs)
print("input2 Rs", sphten_2.Rs)
print("output Rs", new_irrten.Rs)

print(new_irrten.tensor)
# But we can take the non-trivial component and convert back to a SphericalTensor
# We want to drop the components generated from taking the product with L=0 features
# NOTE(review): the offsets below are hard-coded against new_irrten.Rs as
# printed above. `[2:2 + 3]` is assumed to be the nonzero L=1 block, and
# everything from rs.dim(Rs[:-1]) onward is assumed to be one L=2 block
# (i.e. the last Rs entry is L=2 with multiplicity 1). Verify both against
# the printed Rs/tensor if the e3nn version changes.
# The assembled signal has 1 + 3 + 5 = 9 = (2 + 1) ** 2 entries, matching lmax=2.
new_sphten_signal = torch.cat([
        torch.zeros(1),  # L=0
        new_irrten.tensor[2:2 + 3],  # (nonzero) L=1
        new_irrten.tensor[rs.dim(new_irrten.Rs[:-1]):]  # L=2
])
print(new_sphten_signal.shape)
new_sphten = SphericalTensor(new_sphten_signal, lmax=2)
print("")
# plots functions proportional to y, x, and z + xy
print("Now we have contributions to z (cross product) and xy (outer product).")
fig = plot_operation(sphten_1, sphten_2, new_sphten)
# Column legend for the printed coefficients of the lmax=2 signal.
print("SH:", "  1      y      z      x      xy     yz     *      zx     %",)
print("new", new_sphten.signal.numpy().round(3))
print("* == 2z^2 - x^2 - y^2")
print("% == x^2 - y^2")
fig.show()