SpatialEncoder3d

class dgl.nn.pytorch.graph_transformer.SpatialEncoder3d(num_kernels, num_heads=1, max_node_type=1)

Bases: torch.nn.modules.module.Module

3D Spatial Encoder, as introduced in One Transformer Can Understand Both 2D & 3D Molecular Data. This module encodes the pairwise relation between each atom pair \((i,j)\) in 3D geometric space using the following Gaussian Basis Kernel function:

\(\psi _{(i,j)} ^k = -\frac{1}{\sqrt{2\pi} \lvert \sigma^k \rvert} \exp{\left ( -\frac{1}{2} \left( \frac{\gamma_{(i,j)} \lvert \lvert r_i - r_j \rvert \rvert + \beta_{(i,j)} - \mu^k}{\lvert \sigma^k \rvert} \right) ^2 \right)},k=1,...,K,\)

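where \(K\) is the number of Gaussian Basis kernels, \(r_i\) is the Cartesian coordinate of atom \(i\), \(\mu^k\) and \(\sigma^k\) are the learnable center and width of the \(k\)-th kernel, and \(\gamma_{(i,j)}, \beta_{(i,j)}\) are learnable scaling factors (a multiplicative scale and an additive shift) applied to the distance between atoms \(i\) and \(j\).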
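To make the formula concrete, the following plain-Python sketch evaluates \(\psi_{(i,j)}^k\) for a single atom pair. It is not the module's implementation: the function name `gaussian_basis_kernels` is hypothetical, and the scalar inputs `gamma`, `beta`, `mus`, and `sigmas` stand in for the learnable parameters \(\gamma_{(i,j)}, \beta_{(i,j)}, \mu^k, \sigma^k\). The leading minus sign is kept exactly as printed in the formula.

```python
import math


def gaussian_basis_kernels(r_i, r_j, gamma, beta, mus, sigmas):
    """Hypothetical helper: evaluate psi_(i,j)^k for k = 1..K.

    r_i, r_j     -- 3D Cartesian coordinates of atoms i and j
    gamma, beta  -- scale and shift for the pair (i, j)
    mus, sigmas  -- kernel centers mu^k and widths sigma^k (length K each)
    """
    # Affine-transformed Euclidean distance: gamma * ||r_i - r_j|| + beta
    x = gamma * math.dist(r_i, r_j) + beta
    psi = []
    for mu, sigma in zip(mus, sigmas):
        width = abs(sigma)  # the formula uses |sigma^k|
        z = (x - mu) / width
        # Leading minus sign kept as printed in the formula above.
        psi.append(-math.exp(-0.5 * z * z) / (math.sqrt(2 * math.pi) * width))
    return psi


# Atoms 5 units apart; two kernels centered at 5 and 4.
psi = gaussian_basis_kernels((0.0, 0.0, 0.0), (3.0, 4.0, 0.0),
                             gamma=1.0, beta=0.0,
                             mus=[5.0, 4.0], sigmas=[1.0, -1.0])
```

The first kernel, centered exactly at the transformed distance, has the largest magnitude, \(1/\sqrt{2\pi}\); the second, one width away, is smaller by a factor of \(e^{-1/2}\). Taking \(\lvert\sigma^k\rvert\) means a negative width behaves like its positive counterpart.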

Parameters
  • num_kernels (int) – Number of Gaussian Basis Kernels to be applied. Each Gaussian Basis Kernel contains a learnable kernel center and a learnable scaling factor.

  • num_heads (int, optional) – Number of attention heads if the multi-head attention mechanism is applied. Default : 1.

  • max_node_type (int, optional) – Maximum number of node types. Default : 1.
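`max_node_type` bounds how many distinct \((\gamma_{(i,j)}, \beta_{(i,j)})\) pairs can exist. One plausible scheme, assumed here rather than taken from the DGL source, flattens each ordered pair of node types \((t_i, t_j)\) into a single index \(t_i \cdot T + t_j\), where \(T\) is `max_node_type`, so that a table of \(T^2\) entries holds one scaling pair per type combination. The helper `pair_index_table` and the lists `gammas` and `betas` are hypothetical.

```python
def pair_index_table(node_type, max_node_type):
    """Hypothetical helper: map every ordered node pair (i, j) to the index
    node_type[i] * max_node_type + node_type[j] in [0, max_node_type**2)."""
    for t in node_type:
        if not 0 <= t < max_node_type:
            raise ValueError(f"node type {t} is outside [0, {max_node_type})")
    return [[ti * max_node_type + tj for tj in node_type] for ti in node_type]


max_node_type = 3
# Hypothetical per-type-pair parameters, one entry per (t_i, t_j) combination.
gammas = [1.0] * max_node_type ** 2
betas = [0.0] * max_node_type ** 2

# Same node types as in the Examples section.
table = pair_index_table([1, 0, 2, 1], max_node_type)
gamma_01, beta_01 = gammas[table[0][1]], betas[table[0][1]]
```

Under this scheme, `max_node_type=1` maps every pair to index 0, so all atom pairs share a single \((\gamma, \beta)\).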

Examples

>>> import torch as th
>>> import dgl
>>> from dgl.nn import SpatialEncoder3d
>>> u = th.tensor([0, 0, 0, 1, 1, 2, 3, 3])
>>> v = th.tensor([1, 2, 3, 0, 3, 0, 0, 1])
>>> g = dgl.graph((u, v))
>>> coordinate = th.rand(4, 3)
>>> node_type = th.tensor([1, 0, 2, 1])
>>> spatial_encoder = SpatialEncoder3d(num_kernels=4,
...                                    num_heads=8,
...                                    max_node_type=3)
>>> out = spatial_encoder(g, coordinate, node_type=node_type)
>>> print(out.shape)
torch.Size([1, 4, 4, 8])

forward(g, coord, node_type=None)
Parameters
  • g (DGLGraph) – A DGLGraph to be encoded. It must be homogeneous.

  • coord (torch.Tensor) – 3D coordinates of nodes in g, of shape \((N, 3)\), where \(N\) is the number of nodes in g.

  • node_type (torch.Tensor, optional) –

    Node types of g. Default : None.

    • If max_node_type is not 1, node_type must be a tensor of shape \((N,)\). The scaling factors of each pair of nodes are determined by their node types.

    • Otherwise, node_type should be None.
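The rule for node_type can be written down as an explicit check. This is a hypothetical sketch (the name `check_node_type` is not part of DGL); it uses a plain Python sequence in place of a tensor and checks only the length, not the dtype or device.

```python
def check_node_type(node_type, num_nodes, max_node_type):
    """Hypothetical check of the node_type contract described above."""
    if max_node_type == 1:
        # A single node type: no per-node types are needed.
        if node_type is not None:
            raise ValueError("node_type should be None when max_node_type is 1")
        return
    # Several node types: one type per node, i.e. shape (N,).
    if node_type is None:
        raise ValueError("node_type is required when max_node_type is not 1")
    if len(node_type) != num_nodes:
        raise ValueError(f"expected {num_nodes} node types, got {len(node_type)}")


check_node_type(None, num_nodes=4, max_node_type=1)          # accepted
check_node_type([1, 0, 2, 1], num_nodes=4, max_node_type=3)  # accepted
```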

Returns

Attention bias, used as the 3D spatial encoding, of shape \((B, n, n, H)\), where \(B\) is the batch size, \(n\) is the maximum number of nodes among the unbatched graphs in g, and \(H\) is num_heads.
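When g is a batch of graphs with different node counts, each graph's pairwise matrix has to be padded to \(n \times n\) so that the results can be stacked. The sketch below shows that layout for the raw distances \(\lVert r_i - r_j \rVert\) that feed the kernels, using nested lists instead of tensors. The helper name and the padding value of 0 are assumptions, not documented DGL behavior; in the module, each \((i, j)\) entry additionally carries \(H\) values, one per head, which gives the \((B, n, n, H)\) shape.

```python
import math


def padded_pairwise_distances(graph_coords, pad_value=0.0):
    """Hypothetical helper: build a B x n x n nested list of pairwise
    distances, where B = len(graph_coords) and n is the largest node count.
    Entries involving padded (nonexistent) nodes are set to pad_value."""
    n = max(len(coords) for coords in graph_coords)
    batch = []
    for coords in graph_coords:
        mat = [[pad_value] * n for _ in range(n)]
        for i, r_i in enumerate(coords):
            for j, r_j in enumerate(coords):
                mat[i][j] = math.dist(r_i, r_j)  # ||r_i - r_j||
        batch.append(mat)
    return batch


# Two graphs with 2 and 3 nodes, so B = 2 and n = 3.
batch = padded_pairwise_distances([
    [(0.0, 0.0, 0.0), (3.0, 4.0, 0.0)],
    [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 2.0, 0.0)],
])
```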

Return type

torch.Tensor