We build the input signals with the SphericalTensor
class, as we did in data_types.ipynb
.¶import torch
from e3nn.tensor.spherical_tensor import SphericalTensor
from e3nn import rs
torch.set_default_dtype(torch.float64)

# NOTE(review): `mul` is not used in the visible code; kept in case other
# cells reference it.
mul = 1
L_max = 1

# A signal up to degree L_max holds (L_max + 1) ** 2 coefficients. The L = 1
# slots start at index 1; the "SH:" legend printed at the end of this file
# lists them as y, z, x.
signal_1 = torch.zeros((L_max + 1) ** 2)
signal_1[1] = 1.  # pure y
signal_2 = torch.zeros((L_max + 1) ** 2)
signal_2[3] = 1.  # pure x

sphten_1, sphten_2 = [SphericalTensor(sig, L_max) for sig in (signal_1, signal_2)]
import plotly
from plotly.subplots import make_subplots
import plotly.graph_objects as go
# Grid resolution passed to SphericalTensor.plot for every panel.
n = 50


def plot_operation(input1, input2, output):
    """Plot two input SphericalTensors and their result side by side in 3D.

    Each tensor is sampled with ``sphten.plot(relu=False, n=n)``, using the
    module-level resolution ``n``. It is drawn as a plotly ``Surface`` whose
    colour is the sampled signal value.

    Args:
        input1: first operand, drawn in the left panel.
        input2: second operand, drawn in the middle panel.
        output: result of the operation, drawn in the right panel.

    Returns:
        A ``plotly.graph_objects.Figure`` with one 3D scene per tensor.
    """
    rows, cols = 1, 3
    # One 3D scene per panel; 'scene' is the make_subplots spec type for 3D axes.
    specs = [[{'type': 'scene'} for _ in range(cols)] for _ in range(rows)]
    fig = make_subplots(rows=rows, cols=cols, specs=specs)

    for col, sphten in enumerate([input1, input2, output], start=1):
        r, f = sphten.plot(relu=False, n=n)
        trace = go.Surface(
            x=r[..., 0],
            y=r[..., 1],
            z=r[..., 2],
            surfacecolor=f.numpy(),
            showscale=False,
        )
        fig.add_trace(trace, row=1, col=col)

    # Every subplot owns its own scene (scene, scene2, scene3).
    # update_layout(scene_aspectmode=...) would only reach the first one,
    # so use update_scenes to give all panels the same true aspect ratio.
    fig.update_scenes(aspectmode='data')
    return fig
# Add the two inputs. The three panels show functions proportional to
# y, x and (x + y), left to right.
new_sphten = sphten_1 + sphten_2
fig = plot_operation(sphten_1, sphten_2, new_sphten)
fig.show()
# y-like and x-like signals are orthogonal, so their dot product is printed first;
dot_product = sphten_1.dot(sphten_2)
print(dot_product)
# then the y-like signal dotted with itself.
dot_product = sphten_1.dot(sphten_1)
print(dot_product)
# The `@` product of two SphericalTensors is not a SphericalTensor;
# the printed type shows what it returns (an IrrepTensor).
new_irrten = sphten_1 @ sphten_2
print(type(new_irrten))

# Compare the irrep layouts (Rs) of the operands and of the product.
print("input1 Rs", sphten_1.Rs)
print("input2 Rs", sphten_2.Rs)
print("output Rs", new_irrten.Rs)

print(new_irrten.tensor)
# But we can take the non-trivial component and convert back to a SphericalTensor
# We want to drop the components generated from taking the product with L=0 features
# Assemble a signal for lmax=2 from three pieces: a zeroed L=0 slot, one
# 3-element L=1 slice, and the tail of the product tensor labelled L=2.
# NOTE(review): lmax=2 presumably needs (2 + 1)**2 = 9 coefficients; the
# printed shape below lets you check that.
new_sphten_signal = torch.cat([
    torch.zeros(1),  # L=0
    # NOTE(review): the offset 2 assumes the non-zero L=1 output of the product
    # sits at indices 2..4 of new_irrten.tensor; confirm against "output Rs".
    new_irrten.tensor[2:2 + 3],  # (nonzero) L=1
    # rs.dim of every Rs entry except the last gives the start offset of the
    # last entry, so this slice is that final entry (labelled L=2).
    new_irrten.tensor[rs.dim(new_irrten.Rs[:-1]):]  # L=2
])
print(new_sphten_signal.shape)
new_sphten = SphericalTensor(new_sphten_signal, lmax=2)
print("")

# The right panel now shows a function proportional to z + xy.
print("Now we have contributions to z (cross product) and xy (outer product).")
fig = plot_operation(sphten_1, sphten_2, new_sphten)

# Legend mapping coefficient positions to spherical-harmonic names,
# followed by the rounded coefficients themselves.
print("SH:", " 1 y z x xy yz * zx %")
print("new", new_sphten.signal.numpy().round(3))
print("* == 2z^2 - x^2 - y^2")
print("% == x^2 - y^2")

fig.show()