# SpatialEncoder3d

Bases: Module

3D spatial encoder, as introduced in *One Transformer Can Understand Both 2D & 3D Molecular Data*.

This module encodes the pairwise relation between each node pair $$(i,j)$$ in 3D geometric space using the Gaussian Basis Kernel function:

$$\psi _{(i,j)} ^k = \frac{1}{\sqrt{2\pi} \lvert \sigma^k \rvert} \exp{\left ( -\frac{1}{2} \left( \frac{\gamma_{(i,j)} \lVert r_i - r_j \rVert + \beta_{(i,j)} - \mu^k}{\lvert \sigma^k \rvert} \right) ^2 \right)}, \quad k=1,\dots,K,$$

where $$K$$ is the number of Gaussian Basis kernels. $$r_i$$ is the Cartesian coordinate of node $$i$$. $$\gamma_{(i,j)}, \beta_{(i,j)}$$ are learnable scaling factors and biases determined by node types. $$\mu^k, \sigma^k$$ are learnable centers and standard deviations of the Gaussian Basis kernels.
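The formula above maps naturally onto a batched tensor computation. A minimal sketch (not DGL's implementation; the `gaussian_basis` name and the `1e-5` stability epsilon are illustrative):

```python
import math
import torch

def gaussian_basis(dist, gamma, beta, mu, sigma):
    """Gaussian Basis Kernel features for pairwise distances (sketch).

    dist:  (B, N, N) pairwise distances ||r_i - r_j||
    gamma: (B, N, N) per-pair learnable scaling factors
    beta:  (B, N, N) per-pair learnable biases
    mu:    (K,) learnable kernel centers
    sigma: (K,) learnable kernel standard deviations
    returns: (B, N, N, K) kernel values psi^k_(i,j)
    """
    # Broadcast (B, N, N, 1) against (K,) to get one value per kernel.
    x = gamma.unsqueeze(-1) * dist.unsqueeze(-1) + beta.unsqueeze(-1) - mu
    s = sigma.abs() + 1e-5  # |sigma^k|; epsilon avoids division by zero
    return torch.exp(-0.5 * (x / s) ** 2) / (math.sqrt(2 * math.pi) * s)
```

In the module itself, the $$(B, N, N, K)$$ kernel features are further projected to `num_heads` channels to serve as per-head attention bias.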

Parameters:
• num_kernels (int) β Number of Gaussian Basis Kernels to be applied. Each Gaussian Basis Kernel contains a learnable kernel center and a learnable standard deviation.

• num_heads (int, optional) β Number of attention heads if multi-head attention mechanism is applied. Default : 1.

• max_node_type (int, optional) β Maximum number of node types. Each node type has a corresponding learnable scaling factor and a bias. Default : 100.

Examples

>>> import torch as th
>>> import dgl
>>> from dgl.nn import SpatialEncoder3d

>>> coordinate = th.rand(1, 4, 3)
>>> node_type = th.tensor([[1, 0, 2, 1]])
>>> spatial_encoder = SpatialEncoder3d(num_kernels=4,
...                                    max_node_type=3)
>>> out = spatial_encoder(coordinate, node_type=node_type)
>>> print(out.shape)
torch.Size([1, 4, 4, 8])

forward(coord, node_type=None)
Parameters:
• coord (torch.Tensor) β 3D coordinates of nodes in shape $$(B, N, 3)$$, where $$B$$ is the batch size, $$N$$: is the maximum number of nodes.

• node_type (torch.Tensor, optional) –

Node type ids of the nodes. Default: None.

• If specified, node_type should be a tensor of shape $$(B, N)$$. The scaling factors and biases of the Gaussian kernels for each node pair are determined by the pair's node types.

• Otherwise, node_type defaults to zeros of shape $$(B, N)$$.
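Per the default above, omitting node_type is equivalent to passing all-zero type ids. A sketch in plain PyTorch, just illustrating the fallback shape and the pairwise distances that feed the kernels:

```python
import torch

coord = torch.rand(2, 5, 3)               # (B, N, 3) node coordinates
node_type = torch.zeros(coord.shape[:2],  # default fallback: every node is type 0
                        dtype=torch.long)  # shape (B, N)
dist = torch.cdist(coord, coord)          # (B, N, N) pairwise ||r_i - r_j||
```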

Returns:

Attention bias as a 3D spatial encoding, of shape $$(B, N, N, H)$$, where $$H$$ is num_heads.

Return type:

torch.Tensor