SpatialEncoder3d
class dgl.nn.pytorch.graph_transformer.SpatialEncoder3d(num_kernels, num_heads=1, max_node_type=1)
Bases: torch.nn.modules.module.Module
3D Spatial Encoder, as introduced in One Transformer Can Understand Both 2D & 3D Molecular Data. This module encodes the pair-wise relation between each atom pair \((i,j)\) in 3D geometric space according to the Gaussian Basis Kernel function:
\(\psi _{(i,j)} ^k = -\frac{1}{\sqrt{2\pi} \lvert \sigma^k \rvert} \exp{\left ( -\frac{1}{2} \left( \frac{\gamma_{(i,j)} \lvert \lvert r_i - r_j \rvert \rvert + \beta_{(i,j)} - \mu^k}{\lvert \sigma^k \rvert} \right) ^2 \right)},k=1,...,K,\)
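where \(K\) is the number of Gaussian Basis kernels, \(r_i\) is the Cartesian coordinate of atom \(i\), and \(\mu^k\) and \(\sigma^k\) are the learnable center and scaling factor of the \(k\)-th kernel. \(\gamma_{(i,j)}\) and \(\beta_{(i,j)}\) are learnable factors that scale and shift the distance \(\lvert \lvert r_i - r_j \rvert \rvert\) before it enters the kernels. When node types are given, they are determined by the types of atoms \(i\) and \(j\).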
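To make the formula concrete, the sketch below evaluates \(\psi_{(i,j)}^k\) for every atom pair with plain PyTorch. It is an illustration of the equation only, not DGL's internal code. The helper name `gaussian_basis_kernels` is hypothetical, and it assumes \(\gamma\) and \(\beta\) are already available as dense \((N, N)\) tensors and \(\mu\), \(\sigma\) as length-\(K\) vectors.

```python
import math

import torch


def gaussian_basis_kernels(coord, gamma, beta, mu, sigma):
    """Evaluate psi_{(i,j)}^k from the formula above for every atom pair.

    Hypothetical helper, for illustration only.
    coord: (N, 3) Cartesian coordinates r_i
    gamma, beta: (N, N) per-pair scale and shift gamma_{(i,j)}, beta_{(i,j)}
    mu, sigma: (K,) kernel centers mu^k and scaling factors sigma^k
    Returns a tensor of shape (N, N, K).
    """
    # ||r_i - r_j|| for every pair of atoms, shape (N, N)
    dist = torch.cdist(coord, coord, p=2)
    # Scaled and shifted distance, with a trailing axis to broadcast over K kernels
    x = (gamma * dist + beta).unsqueeze(-1)  # (N, N, 1)
    s = sigma.abs()  # |sigma^k|, shape (K,)
    z = (x - mu) / s  # (N, N, K)
    norm = 1.0 / (math.sqrt(2 * math.pi) * s)
    # Keeps the leading minus sign exactly as written in the formula
    return -norm * torch.exp(-0.5 * z ** 2)


# Usage: 3 atoms, K = 4 kernels with evenly spaced centers
coord = torch.rand(3, 3)
K = 4
psi = gaussian_basis_kernels(coord, torch.ones(3, 3), torch.zeros(3, 3),
                             torch.linspace(0.0, 3.0, K), torch.ones(K))
print(psi.shape)  # torch.Size([3, 3, 4])
```

Broadcasting against the trailing kernel axis computes all \(K\) kernels for all \(N^2\) pairs in one pass, which is why the distance tensor is given a singleton last dimension.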
Parameters
num_kernels (int) – Number of Gaussian Basis Kernels to be applied. Each Gaussian Basis Kernel contains a learnable kernel center and a learnable scaling factor.
num_heads (int, optional) – Number of attention heads, if a multi-head attention mechanism is applied. Default: 1.
max_node_type (int, optional) – Maximum number of node types. Default: 1.
Examples
>>> import torch as th
>>> import dgl
>>> from dgl.nn import SpatialEncoder3d

>>> u = th.tensor([0, 0, 0, 1, 1, 2, 3, 3])
>>> v = th.tensor([1, 2, 3, 0, 3, 0, 0, 1])
>>> g = dgl.graph((u, v))
>>> coordinate = th.rand(4, 3)
>>> node_type = th.tensor([1, 0, 2, 1])
>>> spatial_encoder = SpatialEncoder3d(num_kernels=4,
...                                    num_heads=8,
...                                    max_node_type=3)
>>> out = spatial_encoder(g, coordinate, node_type=node_type)
>>> print(out.shape)
torch.Size([1, 4, 4, 8])
forward(g, coord, node_type=None)
Parameters
g (DGLGraph) – The graph to be encoded, which must be homogeneous.
coord (torch.Tensor) – 3D coordinates of the nodes in g, of shape \((N, 3)\), where \(N\) is the number of nodes in g.
node_type (torch.Tensor, optional) – Node types of g. Default: None. If max_node_type is not 1, node_type must be a tensor of shape \((N,)\), and the scaling factors of each pair of nodes are determined by their node types. Otherwise, node_type should be None.
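The rule that the scaling factors of a node pair are determined by the two node types can be pictured as a table lookup. The sketch below is a hypothetical illustration, not DGL's internal code: it assumes one learnable scalar per ordered pair of types, stored in `nn.Embedding` tables with `max_node_type ** 2` rows and indexed as `t_i * max_node_type + t_j`. The helper name `pair_scale_shift` is made up for this example.

```python
import torch
import torch.nn as nn


def pair_scale_shift(node_type, gamma_table, beta_table, max_node_type):
    """Look up gamma_{(i,j)} and beta_{(i,j)} for every node pair.

    Hypothetical helper, for illustration only.
    node_type: (N,) integer tensor with values in [0, max_node_type)
    gamma_table, beta_table: nn.Embedding(max_node_type ** 2, 1), i.e. one
        learnable scalar per ordered pair of node types (assumed layout)
    Returns two (N, N) tensors.
    """
    # Flatten each ordered type pair (t_i, t_j) into one index t_i * T + t_j
    pair_index = node_type.unsqueeze(-1) * max_node_type + node_type.unsqueeze(0)
    # Embedding output is (N, N, 1); drop the trailing singleton axis
    gamma = gamma_table(pair_index).squeeze(-1)
    beta = beta_table(pair_index).squeeze(-1)
    return gamma, beta


# Same node types as in the example above, with max_node_type = 3
max_node_type = 3
gamma_table = nn.Embedding(max_node_type ** 2, 1)
beta_table = nn.Embedding(max_node_type ** 2, 1)
node_type = torch.tensor([1, 0, 2, 1])
gamma, beta = pair_scale_shift(node_type, gamma_table, beta_table, max_node_type)
print(gamma.shape)  # torch.Size([4, 4])
```

With `max_node_type=1`, every pair maps to index 0, so all pairs share a single scale and shift; this matches the case where node_type is None.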
Returns
The attention bias used as the 3D spatial encoding, of shape \((B, n, n, H)\), where \(B\) is the batch size, \(n\) is the maximum number of nodes among the unbatched graphs in g, and \(H\) is num_heads.
Return type
torch.Tensor