"""Path Encoder"""
import torch as th
import torch.nn as nn

class PathEncoder(nn.Module):
    r"""Path Encoder, as introduced in Edge Encoding of
    `Do Transformers Really Perform Bad for Graph Representation?
    <https://arxiv.org/abs/2106.05234>`__

    This module is a learnable path embedding module and encodes the shortest
    path between each pair of nodes as attention bias.

    Parameters
    ----------
    max_len : int
        Maximum number of edges in each path to be encoded.
        The excess part of each path is truncated, i.e. edges with serial
        number no less than :attr:`max_len` are ignored.
    feat_dim : int
        Dimension of edge features in the input graph.
    num_heads : int, optional
        Number of attention heads if multi-head attention mechanism is applied.
        Default : 1.

    Examples
    --------
    >>> import torch as th
    >>> import dgl
    >>> from dgl.nn import PathEncoder
    >>> from dgl import shortest_dist

    >>> g = dgl.graph(([0,0,0,1,1,2,3,3], [1,2,3,0,3,0,0,1]))
    >>> edata = th.rand(8, 16)
    >>> # Since shortest_dist returns -1 for unreachable node pairs,
    >>> # edata[-1] should be filled with zero padding.
    >>> edata = th.cat(
    ...     (edata, th.zeros(1, 16)), dim=0
    ... )
    >>> dist, path = shortest_dist(g, root=None, return_paths=True)
    >>> path_data = edata[path[:, :, :2]]
    >>> path_encoder = PathEncoder(2, 16, num_heads=8)
    >>> out = path_encoder(dist.unsqueeze(0), path_data.unsqueeze(0))
    >>> print(out.shape)
    torch.Size([1, 4, 4, 8])
    """

    def __init__(self, max_len, feat_dim, num_heads=1):
        super().__init__()
        self.max_len = max_len
        self.feat_dim = feat_dim
        self.num_heads = num_heads
        # One learnable feat_dim-vector per (edge position, head) pair,
        # stored flat and reshaped to (max_len, num_heads, feat_dim) in forward.
        self.embedding_table = nn.Embedding(max_len * num_heads, feat_dim)

    def forward(self, dist, path_data):
        """
        Parameters
        ----------
        dist : Tensor
            Shortest path distance matrix of the batched graph with zero
            padding, of shape :math:`(B, N, N)`, where :math:`B` is the batch
            size of the batched graph, and :math:`N` is the maximum number of
            nodes.
        path_data : Tensor
            Edge feature along the shortest path with zero padding, of shape
            :math:`(B, N, N, L, d)`, where :math:`L` is the maximum length of
            the shortest paths, and :math:`d` is :attr:`feat_dim`. Only the
            first :attr:`max_len` edges along the :math:`L` axis are encoded;
            any further edges are ignored.

        Returns
        -------
        torch.Tensor
            Return attention bias as path encoding, of shape
            :math:`(B, N, N, H)`, where :math:`B` is the batch size of the
            input graph, :math:`N` is the maximum number of nodes, and
            :math:`H` is :attr:`num_heads`.
        """
        # Only the first ``max_len`` edge positions have learnable embeddings;
        # drop extra edges so the contraction below always lines up (this is
        # the truncation promised by the class docstring).
        path_len = min(path_data.shape[-2], self.max_len)
        path_data = path_data[..., :path_len, :]
        # (max_len * H, d) -> (max_len, H, d); keep rows for encoded positions.
        edge_embedding = self.embedding_table.weight.reshape(
            self.max_len, self.num_heads, -1
        )[:path_len]
        # For every node pair and head: sum over path edges of the dot
        # product between the edge feature and its positional embedding.
        path_encoding = th.einsum(
            "bxyld,lhd->bxyh", path_data, edge_embedding
        )
        # Average over the number of encoded edges. The lower bound of 1
        # avoids division by zero for pairs with distance 0 and for
        # unreachable pairs (distance -1); the upper bound matches truncation.
        num_edges = th.clamp(dist, min=1, max=max(path_len, 1))
        return path_encoding / num_edges.unsqueeze(-1)