SpatialEncoder

class dgl.nn.pytorch.gt.SpatialEncoder(max_dist, num_heads=1)

Bases: Module

Spatial encoder, as introduced in "Do Transformers Really Perform Bad for Graph Representation?"

A learnable spatial embedding module that encodes the shortest-path distance between each pair of nodes as an attention bias.
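In the cited paper, this bias enters the attention logits directly. For nodes v_i and v_j with shortest-path distance φ(v_i, v_j), the score before the softmax is shown below, where h_i and h_j are node representations, W_Q and W_K are the query and key projections, d is the hidden size, and b_{φ(v_i, v_j)} is a learnable scalar indexed by the distance (one per head in the multi-head case). How DGL's module is wired into a full attention layer is outside this page.

```latex
A_{ij} = \frac{(h_i W_Q)(h_j W_K)^{\top}}{\sqrt{d}} + b_{\phi(v_i, v_j)}
```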

Parameters:
  • max_dist (int) – Upper bound on the shortest-path distance to be encoded for each node pair. All distances are clamped to the range [0, max_dist].

  • num_heads (int, optional) – Number of attention heads when a multi-head attention mechanism is used. Default: 1.
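The two parameters can be illustrated with a stdlib-only toy that mimics the behavior described above: a table with one row per clamped distance in [0, max_dist] and one column per head, looked up for every node pair. This is a sketch, not DGL's implementation. The class name ToySpatialEncoder, the random initialization, and the choice to return an all-zero bias for -1 (padding or unreachable) entries are assumptions made here.

```python
import random


class ToySpatialEncoder:
    """Hypothetical, stdlib-only illustration of a spatial encoder.

    Holds a table with one row per distance in [0, max_dist] and one
    column per attention head; in a real model these would be trainable.
    """

    def __init__(self, max_dist, num_heads=1, seed=0):
        rng = random.Random(seed)
        self.max_dist = max_dist
        self.num_heads = num_heads
        self.table = [
            [rng.uniform(-1.0, 1.0) for _ in range(num_heads)]
            for _ in range(max_dist + 1)
        ]

    def bias_for(self, d):
        # Assumption of this sketch: -1 marks padding/unreachable pairs
        # and receives an all-zero bias.
        if d < 0:
            return [0.0] * self.num_heads
        # Distances larger than max_dist share the row of max_dist.
        return list(self.table[min(d, self.max_dist)])

    def __call__(self, dist):
        # dist: nested lists of shape (B, N, N) -> output (B, N, N, H)
        return [
            [[self.bias_for(d) for d in row] for row in graph]
            for graph in dist
        ]


enc = ToySpatialEncoder(max_dist=2, num_heads=3)
out = enc([[[0, 1, 5], [1, 0, -1], [5, -1, 0]]])
print(len(out), len(out[0]), len(out[0][0]), len(out[0][0][0]))  # 1 3 3 3
```

Because of the clamping, the table stays at (max_dist + 1) × num_heads entries, and every pair farther apart than max_dist shares a single bias per head.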

Examples

>>> import torch as th
>>> import dgl
>>> from dgl.nn import SpatialEncoder
>>> from dgl import shortest_dist
>>> g1 = dgl.graph(([0,0,0,1,1,2,3,3], [1,2,3,0,3,0,0,1]))
>>> g2 = dgl.graph(([0,1], [1,0]))
>>> n1, n2 = g1.num_nodes(), g2.num_nodes()
>>> # use -1 padding since shortest_dist returns -1 for unreachable node pairs
>>> dist = -th.ones((2, 4, 4), dtype=th.long)
>>> dist[0, :n1, :n1] = shortest_dist(g1, root=None, return_paths=False)
>>> dist[1, :n2, :n2] = shortest_dist(g2, root=None, return_paths=False)
>>> spatial_encoder = SpatialEncoder(max_dist=2, num_heads=8)
>>> out = spatial_encoder(dist)
>>> print(out.shape)
torch.Size([2, 4, 4, 8])
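The example relies on shortest_dist to fill the padded distance tensor. A stdlib-only sketch of the same preparation step runs breadth-first search from every node, marks unreachable pairs with -1, and pads each graph to the largest node count with -1. It assumes that dist[u][v] is the hop count along directed edges from u to v. The helper names all_pairs_hops and pad_batch are hypothetical, and this is not how shortest_dist is implemented.

```python
from collections import deque


def all_pairs_hops(num_nodes, src, dst):
    """BFS from every node; dist[u][v] is the hop count u -> v, or -1."""
    adj = [[] for _ in range(num_nodes)]
    for u, v in zip(src, dst):
        adj[u].append(v)
    dist = []
    for root in range(num_nodes):
        row = [-1] * num_nodes
        row[root] = 0
        queue = deque([root])
        while queue:
            u = queue.popleft()
            for v in adj[u]:
                if row[v] == -1:  # first visit gives the shortest hop count
                    row[v] = row[u] + 1
                    queue.append(v)
        dist.append(row)
    return dist


def pad_batch(mats, n_max):
    """Pad each (n, n) matrix to (n_max, n_max) with -1, giving (B, N, N)."""
    batch = []
    for m in mats:
        n = len(m)
        padded = [row + [-1] * (n_max - n) for row in m]
        padded += [[-1] * n_max for _ in range(n_max - n)]
        batch.append(padded)
    return batch


d1 = all_pairs_hops(4, [0, 0, 0, 1, 1, 2, 3, 3], [1, 2, 3, 0, 3, 0, 0, 1])
d2 = all_pairs_hops(2, [0, 1], [1, 0])
dist = pad_batch([d1, d2], n_max=4)
print(dist[1])  # [[0, 1, -1, -1], [1, 0, -1, -1], [-1, -1, -1, -1], [-1, -1, -1, -1]]
```

The edge lists are the same ones used for g1 and g2 in the example, so the result has the (2, 4, 4) layout that the encoder expects.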

forward(dist)
Parameters:

dist (Tensor) – Shortest-path distances of the batched graph, padded with -1, as a tensor of shape (B, N, N), where B is the batch size of the batched graph and N is the maximum number of nodes in any graph of the batch.

Returns:

Attention bias used as the spatial encoding, a tensor of shape (B, N, N, H), where H is num_heads.
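In Graphormer-style attention, each head's slice of this output is added to the attention logits before the softmax. The stdlib-only sketch below shows that step for one graph and one head. The function name biased_attention_weights and the plain-list inputs are assumptions for illustration, and masking of padded nodes is left out.

```python
import math


def biased_attention_weights(scores, bias):
    """Row-wise softmax of (scores + bias) for one graph and one head.

    scores: (N, N) attention logits, e.g. q_i . k_j / sqrt(d)
    bias:   (N, N) slice out[b, :, :, h] of the spatial encoding
    """
    weights = []
    for s_row, b_row in zip(scores, bias):
        logits = [s + b for s, b in zip(s_row, b_row)]
        m = max(logits)  # subtract the max for numerical stability
        exps = [math.exp(x - m) for x in logits]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights


scores = [[0.0, 0.0], [0.0, 0.0]]
bias = [[1.0, 0.0], [0.0, 1.0]]  # a larger bias at distance 0 favors self-attention
w = biased_attention_weights(scores, bias)
print([round(x, 4) for x in w[0]])  # [0.7311, 0.2689]
```

With equal raw scores, the bias alone shifts each row's attention toward the nodes that the learned table favors.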

Return type:

torch.Tensor