Simple Tasks and Symmetry

using the e3nn repository

tutorial by: Tess E. Smidt (blondegeek)

There are some unintuitive consequences of using E(3) equivariant neural networks.

One example is that the symmetry of your output has to be equal to or higher than the symmetry of your input. The following four simple tasks help demonstrate this:

  • Task 1: Distort a rectangle to a square.
  • Task 2: Distort a square to a rectangle.
  • Task 3: Distort a square to a rectangle -- with symmetry breaking (using representation theory).
  • Task 4: Distort a square to a rectangle -- with symmetry breaking (using gradients to change input).

We will see that we can quickly do Task 1, but not Task 2. Only by using symmetry breaking in Task 3 and Task 4 are we able to distort a square into a rectangle.

In [1]:
import torch
from functools import partial

from e3nn import o3 
from e3nn.kernel_mod import Kernel
from e3nn.tensor.spherical_tensor import SphericalTensor

import plotly
import plotly.graph_objects as go

import matplotlib.pyplot as plt

torch.set_default_dtype(torch.float64)
In [2]:
# Define our geometry
square = torch.tensor(
    [[0., 0., 0.], [1., 0., 0.], [1., 1., 0.], [0., 1., 0.]]
)
square -= square.mean(-2)
sx, sy = 0.5, 1.5
rectangle = square * torch.tensor([sx, sy, 0.])
rectangle -= rectangle.mean(-2)

N, _ = square.shape

markersize = 15

def plot_task(ax, start, finish, title, marker=None):
    ax.plot(torch.cat([start[:, 0], start[:, 0]]), 
            torch.cat([start[:, 1], start[:, 1]]), 'o-', 
            markersize=markersize + 5 if marker else markersize, 
            marker=marker if marker else 'o')
    ax.plot(torch.cat([finish[:, 0], finish[:, 0]]), 
            torch.cat([finish[:, 1], finish[:, 1]]), 'o-', markersize=markersize)
    for i in range(N):
        ax.arrow(start[i, 0], start[i, 1], 
                 finish[i, 0] - start[i, 0], 
                 finish[i, 1] - start[i, 1],
                 length_includes_head=True, head_width=0.05, facecolor="black", zorder=100)

    ax.set_title(title)
    ax.set_axis_off()

fig, axes = plt.subplots(1, 2, figsize=(9, 6))
plot_task(axes[0], rectangle, square, "Task 1: Rectangle to Square")
plot_task(axes[1], square, rectangle, "Task 2: Square to Rectangle")

In these tasks, we want to move 4 points from one configuration to another. The input to the network will be the initial geometry and features on that geometry. The output will be used to signify the "displacement" of each point to the new configuration. We can represent displacement in a couple of different ways. The simplest way is to represent a displacement as an L=1 vector, Rs = [(1, 1)]. However, to better illustrate the symmetry properties of the network, we are instead going to use a spherical harmonic signal, or more specifically, the peak of the spherical harmonic signal, to signify the displacement of the original point.
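
For instance, here is a minimal sketch of the two encodings, mirroring the SphericalTensor.from_geometry call used in the cells below (the displacement d is made up purely for illustration):

import torch
from e3nn.tensor.spherical_tensor import SphericalTensor

L_max = 5
d = torch.tensor([0.25, -0.25, 0.])  # hypothetical displacement of a single point

# Option 1: keep the raw vector, which transforms as Rs = [(1, 1)]
Rs_vector = [(1, 1)]

# Option 2: project the displacement onto spherical harmonics up to L_max;
# the resulting signal has (L_max + 1)**2 = 36 components and peaks along d
signal = SphericalTensor.from_geometry(d, L_max).signal
print(signal.shape)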

First, we set up a very basic network that has the same representation list Rs = [(1, L) for L in range(5 + 1)] throughout the entire network. The input will be a spherical tensor with representation Rs and the output will also be a spherical tensor with representation Rs. We will interpret the output of the network as a spherical harmonic signal where the peak location will signify the desired displacement.

For these examples, we will use the e3nn.networks.GatedConvNetwork class for our model; the following cell then redefines Network with GatedConvParityNetwork, which is additionally equivariant to parity.

In [3]:
from e3nn.networks import GatedConvNetwork
L_max = 5
Rs = [(1, L) for L in range(L_max + 1)]
Network = partial(GatedConvNetwork,
                  Rs_in=Rs,
                  Rs_hidden=Rs,
                  Rs_out=Rs,
                  lmax=L_max,
                  max_radius=3.0,
                  kernel=Kernel)
In [4]:
from e3nn.networks import GatedConvParityNetwork
L_max = 5
Rs = [(1, L, (-1)**L) for L in range(L_max + 1)]
Network = partial(GatedConvParityNetwork, 
                  Rs_in=Rs, 
                  mul=5, 
                  Rs_out=Rs, 
                  lmax=L_max, 
                  max_radius=3.0, 
                  layers=1,
                  feature_product=True)

Task 1: Distort a rectangle to a square.

In this task, our input is four points in the shape of a rectangle with simple scalars (1.0) at each point. The task is to learn to displace the points to form a (more symmetric) square.

In [5]:
model = Network()

optimizer = torch.optim.Adam(model.parameters(), 1e-2)
In [6]:
input = torch.zeros((L_max + 1)**2)
input[0] = 1

displacements = square - rectangle
projections = torch.stack([
    SphericalTensor.from_geometry(r, L_max).signal 
    for r in displacements
])
In [7]:
for i in range(51):
    output = model(input.repeat(1, 4, 1), rectangle[None])[0]
    loss = (output - projections).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if i % 10 == 0:
        print(i, loss)
0 tensor(0.0028, grad_fn=<MeanBackward0>)
10 tensor(0.0002, grad_fn=<MeanBackward0>)
20 tensor(4.9935e-05, grad_fn=<MeanBackward0>)
30 tensor(1.8558e-05, grad_fn=<MeanBackward0>)
40 tensor(5.7461e-06, grad_fn=<MeanBackward0>)
50 tensor(2.6443e-06, grad_fn=<MeanBackward0>)
In [8]:
def plot_output(start, finish, features, start_label, finish_label, bound=None):
    if bound is None:
        bound = max(start.norm(dim=1).max(), finish.norm(dim=1).max()).item()
    axis = dict(
        showbackground=False,
        showticklabels=False,
        showgrid=False,
        zeroline=False,
        title='',
        nticks=3,
        range=[-bound, bound]
    )

    resolution = 500
    layout = dict(
        width=resolution,
        height=resolution,
        scene=dict(
            xaxis=axis,
            yaxis=axis,
            zaxis=axis,
            aspectmode='manual',
            aspectratio=dict(x=1, y=1, z=1),
            camera=dict(
                up=dict(x=0, y=1, z=0),
                center=dict(x=0, y=0, z=0),
                eye=dict(x=0, y=0, z=2),
                projection=dict(type='perspective'),
            ),
        ),
        paper_bgcolor='rgba(255,255,255,255)',
        plot_bgcolor='rgba(0,0,0,0)',
        margin=dict(l=0, r=0, t=0, b=0)
    )

    traces = [
        go.Scatter3d(x=start[:, 0], y=start[:, 1], z=start[:, 2], mode="markers", name=start_label),
        go.Scatter3d(x=finish[:, 0], y=finish[:, 1], z=finish[:, 2], mode="markers", name=finish_label),
    ]
    
    for center, signal in zip(start, features):
        r, f = SphericalTensor(signal.detach()).plot(center=center)
        traces += [go.Surface(x=r[..., 0], y=r[..., 1], z=r[..., 2], surfacecolor=f.numpy(), showscale=False)]
        
    return go.Figure(traces, layout=layout)
In [9]:
output = model(input.repeat(1, 4, 1), rectangle[None])[0]
fig = plot_output(rectangle, square, output, "Rectangle", "Square")
fig.show()

And let's check that it's equivariant.

In [10]:
angles = o3.rand_angles()
rot = -o3.rot(*angles)  # rotation + parity

rot_rectangle = torch.einsum('xy,jy->jx', rot, rectangle)
rot_square = torch.einsum('xy,jy->jx', rot, square)

output = model(input.repeat(1, 4, 1), rot_rectangle[None])[0]
fig = plot_output(rot_rectangle, rot_square, output, "Rectangle", "Square")
fig.show()

Task 2: Now the reverse! Distort a square to a rectangle.

In this task, our input is four points in the shape of a square with simple scalars (1.0) at each point. The task is to learn to displace the points to form a (less symmetric) rectangle. Can the network learn this task?

In [11]:
model = Network()

optimizer = torch.optim.Adam(model.parameters(), 1e-2)
In [12]:
input = torch.zeros((L_max + 1)**2)
input[0] = 1

displacements = rectangle - square
projections = torch.stack([
    SphericalTensor.from_geometry(r, L_max).signal 
    for r in displacements
])
In [13]:
for i in range(51):
    output = model(input.repeat(1, 4, 1), square[None])[0]
    loss = (output - projections).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if i % 10 == 0:
        print(i, loss)
0 tensor(0.0014, grad_fn=<MeanBackward0>)
10 tensor(0.0008, grad_fn=<MeanBackward0>)
20 tensor(0.0007, grad_fn=<MeanBackward0>)
30 tensor(0.0007, grad_fn=<MeanBackward0>)
40 tensor(0.0007, grad_fn=<MeanBackward0>)
50 tensor(0.0007, grad_fn=<MeanBackward0>)

Hmm... seems to get stuck. Let's try more iterations.

In [14]:
for i in range(51):
    output = model(input.repeat(1, 4, 1), square[None])[0]
    loss = (output - projections).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if i % 10 == 0:
        print(i, loss)
0 tensor(0.0007, grad_fn=<MeanBackward0>)
10 tensor(0.0007, grad_fn=<MeanBackward0>)
20 tensor(0.0007, grad_fn=<MeanBackward0>)
30 tensor(0.0007, grad_fn=<MeanBackward0>)
40 tensor(0.0007, grad_fn=<MeanBackward0>)
50 tensor(0.0007, grad_fn=<MeanBackward0>)

It's stuck. What's going on?

In [15]:
fig = plot_output(square, rectangle, output, "Square", "Rectangle")
fig.show()

The symmetry of the output must be equal to or higher than the symmetry of the input!

To be able to do this task, you need to give the network more information -- information that breaks the symmetry to that of the desired output. The square has a point group of $D_{4h}$ (16 elements) while the rectangle has a point group of $D_{2h}$ (8 elements).
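
As a quick numerical illustration (a sketch, not part of the original notebook, reusing the square and rectangle tensors defined above): the 90° rotation about $\hat{z}$ maps the square onto itself but not the rectangle. Rotating the input by 90° about $\hat{z}$ just permutes the square's points and leaves the scalar features alone, so equivariance forces the predicted output configuration to be mapped onto itself by that same rotation -- something a rectangle cannot satisfy.

import torch

Rz90 = torch.tensor([[0., -1., 0.],
                     [1.,  0., 0.],
                     [0.,  0., 1.]])  # 90° rotation about z

def same_point_set(a, b, tol=1e-6):
    # True if every point of a coincides with some point of b (order may differ)
    return all((b - p).norm(dim=1).min() < tol for p in a)

print(same_point_set(square @ Rz90.T, square))        # True: the square maps onto itself
print(same_point_set(rectangle @ Rz90.T, rectangle))  # False: the rectangle does not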

A technical note (for those who are interested).

In this example, if we do not use a network equivariant to parity (e.g. using GatedConvNetwork instead of GatedConvParityNetwork), we would only be sensitive to the fact that the square has $C_4$ symmetry while the rectangle has $C_2$ symmetry.

Task 3: Fixing Task 2. Distort a square into a rectangle -- now, with symmetry breaking (using representation theory)!

In this task, our input is four points in the shape of a square with simple scalars (1.0) AND a contribution for the $x^2 - y^2$ feature at each point. The task is to learn to displace the points to form a (less symmetric) rectangle. Can the network learn this task?

In [16]:
model = Network()

optimizer = torch.optim.Adam(model.parameters(), 1e-2)
In [17]:
input = torch.zeros((L_max + 1)**2)

# Breaking x and y symmetry with x^2 - y^2 component
input[0] = 1
input[8] = 1  # x^2 - y^2

displacements = rectangle - square
projections = torch.stack([
    SphericalTensor.from_geometry(r, L_max).signal 
    for r in displacements
])
In [18]:
for i in range(51):
    output = model(input.repeat(1, 4, 1), square[None])[0]
    loss = (output - projections).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if i % 10 == 0:
        print(i, loss)
0 tensor(0.0030, grad_fn=<MeanBackward0>)
10 tensor(7.3147e-05, grad_fn=<MeanBackward0>)
20 tensor(2.1801e-05, grad_fn=<MeanBackward0>)
30 tensor(8.0135e-06, grad_fn=<MeanBackward0>)
40 tensor(2.1762e-06, grad_fn=<MeanBackward0>)
50 tensor(1.0138e-06, grad_fn=<MeanBackward0>)
In [19]:
fig = plot_output(square, rectangle, output, "Square", "Rectangle")
fig.show()

What is the $x^2 - y^2$ term doing? It's breaking the symmetry along the $\hat{x}$ and $\hat{y}$ directions.

Notice how the shape below is an ellipsoid elongated in the y direction and squished in the x. This isn't the only perturbation we could've added, but it is the most symmetric one.

In [20]:
fig = plot_output(square, square, 0.3 * input.repeat(4, 1), '', '', bound=0.75)
fig.show()

Sure, but where did the $x^2 - y^2$ come from?

It's a bit of a complicated story, but at the surface level here it is: character tables are handy tabulations of how certain spherical tensor datatypes transform under the symmetry operations of a given group. The rows are irreducible representations (irreps for short) and the columns are similar elements of the group (called conjugacy classes). Character tables are most commonly seen for finite subgroups of $E(3)$, as they are used extensively in solid state physics, crystallography, chemistry, etc. Next to the part of the table with the "characters", there are often columns listing linear, quadratic, and cubic functions (meaning they are of order 1, 2, and 3) that transform in the same way as a given irrep.

So, a square has a point group symmetry of $D_{4h}$ while a rectangle has a point group symmetry of $D_{2h}$.

If we look at column headers of character tables for $D_{4h}$ and $D_{2h}$...

... we can see that the irrep $B_{1g}$ of $D_{4h}$ has -1's in the columns for all the symmetry operations that $D_{2h}$ DOESN'T have, and if we look along that row to the "quadratic functions" column we see, voilà, $x^2 - y^2$. So, to break all those symmetries that $D_{4h}$ has and $D_{2h}$ DOESN'T -- we add a non-zero contribution to the $x^2 - y^2$ component of our spherical harmonic tensors.
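
As a quick sanity check (a sketch, not from the original notebook) that $x^2 - y^2$ really singles out those operations: it flips sign under the 90° rotation about $\hat{z}$, one of the symmetries the square has and the rectangle lacks, so giving it a non-zero value breaks exactly those extra symmetries.

import torch

def b1g(p):
    # the quadratic function x^2 - y^2 listed next to B_1g in the D_4h character table
    return p[0]**2 - p[1]**2

Rz90 = torch.tensor([[0., -1., 0.],
                     [1.,  0., 0.],
                     [0.,  0., 1.]])

p = torch.tensor([0.7, 0.2, 0.])
print(b1g(p), b1g(Rz90 @ p))  # same magnitude, opposite sign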

WARNING: Character tables are written down with specific coordinate system conventions. For example, the $\hat{z}$ axis always points along the highest symmetry axis, $\hat{y}$ along the next highest, etc. We have specifically set up our problem to have a coordinate frame that matches these conventions.

A technical note (for those who are interested).

Again, in this example, if we leave out parity (by using GatedConvNetwork instead of GatedConvParityNetwork), we would only be sensitive to the fact that the square has $C_4$ symmetry while the rectangle has $C_2$ symmetry. However, you can check the character tables for the point groups $C_4$ and $C_2$ to see that the argument above still holds for the $x^2 - y^2$ order parameter.

Task 4: Fixing Task 2 without having to read character tables like in Task 3. Distort a square into a rectangle -- now, with symmetry breaking (using gradients to change the input)!

In this task, our input is four points in the shape of a square with simple scalars (1.0), AND we then LEARN how to change the inputs to break the symmetry so that we can fit a better model.

In [21]:
model = Network()

input = torch.zeros((L_max + 1)**2, requires_grad=True)
with torch.no_grad():
    input[0] = 1

displacements = rectangle - square
projections = torch.stack([
    SphericalTensor.from_geometry(r, L_max).signal 
    for r in displacements
])

param_optimizer = torch.optim.Adam(model.parameters(), 1e-2)
input_optimizer = torch.optim.Adam([input], 1e-3)

First, we'll train the model until it gets stuck.

In [22]:
for i in range(21):
    output = model(input.repeat(4, 1)[None], square[None])[0]
    loss = (output - projections).pow(2).mean()
    param_optimizer.zero_grad()
    loss.backward()
    param_optimizer.step()
    if i % 10 == 0:
        print(i, loss)
0 tensor(0.0033, grad_fn=<MeanBackward0>)
10 tensor(0.0008, grad_fn=<MeanBackward0>)
20 tensor(0.0007, grad_fn=<MeanBackward0>)

This gets stuck like before. So let's try alternating between updating our input and updating the model.

In [23]:
for i in range(101):
    output = model(input.repeat(4, 1)[None], square[None])[0]
    loss = (output - projections).pow(2).mean()
    param_optimizer.zero_grad()
    loss.backward()
    param_optimizer.step()
    if i % 10 == 0:
        print(i, 'model loss: ', loss)


    output = model(input.repeat(4, 1)[None], square[None])[0]
    loss = (output - projections).pow(2).mean()
    # L1 sparsity penalty on all L > 0 components of the input
    loss += 1e-3 * (input[1:].abs()).mean()
    input_optimizer.zero_grad()
    loss.backward()
    input_optimizer.step()
    
    # pin the scalar (L=0) feature; uncomment the lines below to restrict updates to L=2 only
    with torch.no_grad():
        input[0] = 1  # L=0
#         input[1**2:2**2] = 0  # L=1
#         input[3**2:] = 0  # L>=3
        
    if i % 10 == 0:
        print(i, 'input loss: ', loss)
0 model loss:  tensor(0.0007, grad_fn=<MeanBackward0>)
0 input loss:  tensor(0.0007, grad_fn=<AddBackward0>)
10 model loss:  tensor(0.0007, grad_fn=<MeanBackward0>)
10 input loss:  tensor(0.0007, grad_fn=<AddBackward0>)
20 model loss:  tensor(0.0007, grad_fn=<MeanBackward0>)
20 input loss:  tensor(0.0007, grad_fn=<AddBackward0>)
30 model loss:  tensor(0.0004, grad_fn=<MeanBackward0>)
30 input loss:  tensor(0.0004, grad_fn=<AddBackward0>)
40 model loss:  tensor(0.0001, grad_fn=<MeanBackward0>)
40 input loss:  tensor(7.9150e-05, grad_fn=<AddBackward0>)
50 model loss:  tensor(2.4258e-05, grad_fn=<MeanBackward0>)
50 input loss:  tensor(2.6928e-05, grad_fn=<AddBackward0>)
60 model loss:  tensor(5.6501e-06, grad_fn=<MeanBackward0>)
60 input loss:  tensor(6.2071e-06, grad_fn=<AddBackward0>)
70 model loss:  tensor(2.5831e-06, grad_fn=<MeanBackward0>)
70 input loss:  tensor(4.7611e-06, grad_fn=<AddBackward0>)
80 model loss:  tensor(8.7206e-07, grad_fn=<MeanBackward0>)
80 input loss:  tensor(3.0165e-06, grad_fn=<AddBackward0>)
90 model loss:  tensor(2.2301e-07, grad_fn=<MeanBackward0>)
90 input loss:  tensor(2.5631e-06, grad_fn=<AddBackward0>)
100 model loss:  tensor(1.0968e-07, grad_fn=<MeanBackward0>)
100 input loss:  tensor(2.3728e-06, grad_fn=<AddBackward0>)

If we examine the input, we should see that the only components that are (largely) non-zero are the scalar features (which are all 1's) and the features that transform as the $B_{1g}$ irrep of $D_{4h}$: the L=2 feature corresponding to $x^2 - y^2$, which is the 5th element of the L=2 array, and the L=4 feature corresponding to $(x^2-y^2)(7z^2 - r^2)$, which is the 7th element of the L=4 array.
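
As a sketch of the bookkeeping (assuming the m = -L, ..., +L ordering within each L block that this notebook already uses, e.g. input[8] = 1 for $x^2 - y^2$), the flattened index of component (L, m) is L**2 + L + m, and both features above sit at m = +2:

def sh_index(L, m):
    # flattened index of component m of order L, with each block ordered m = -L..+L
    return L**2 + L + m

print(sh_index(2, 2))  # 8  -> 5th entry of the L=2 block: x^2 - y^2
print(sh_index(4, 2))  # 22 -> 7th entry of the L=4 block: (x^2 - y^2)(7z^2 - r^2)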

In [24]:
for L in range(L_max + 1):
    print("L={}".format(L))
    print(input[L**2:(L+1)**2].detach().numpy().round(3))
L=0
[1.]
L=1
[-0.  0. -0.]
L=2
[-0.     0.     0.     0.     0.036]
L=3
[ 0.  0.  0.  0. -0.  0. -0.]
L=4
[ 0.     0.     0.     0.     0.     0.    -0.042  0.     0.   ]
L=5
[-0.  0.  0.  0.  0.  0.  0.  0. -0.  0. -0.]

This plot shows what the new input looks like. It's similar to the above plot from Task 3.

In [25]:
fig = plot_output(square, square, input.repeat(4, 1), '', '', bound=1)
fig.show()