class dgl.nn.pytorch.utils.LaplacianPosEnc(model_type, num_layer, k, lpe_dim, n_head=1, batch_norm=False, num_post_layer=0)

Bases: torch.nn.modules.module.Module

Laplacian Positional Encoder (LPE), as introduced in GraphGPS: General Powerful Scalable Graph Transformers. This module learns a Laplacian positional encoding using either a Transformer or a DeepSet encoder.

  • model_type (str) – Encoder model type for LPE. Must be either “Transformer” or “DeepSet”.

  • num_layer (int) – Number of layers in the Transformer/DeepSet encoder.

  • k (int) – Number of Laplacian eigenvectors to use, namely those associated with the k smallest non-trivial eigenvalues.

  • lpe_dim (int) – Output size of the final Laplacian encoding.

  • n_head (int, optional) – Number of heads in the Transformer encoder. Default: 1.

  • batch_norm (bool, optional) – If True, apply batch normalization to the raw LaplacianPE features. Default: False.

  • num_post_layer (int, optional) – If num_post_layer > 0, apply an MLP of num_post_layer layers after pooling. Default: 0.
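To make the roles of k, lpe_dim, and the pooling step more concrete, the following is a minimal sketch of a DeepSet-style encoder. It is not DGL's implementation: the class name DeepSetLPESketch, the hidden_dim argument, and the layer layout are hypothetical, and it assumes that padded frequencies are marked with NaN eigenvalues.

```python
import torch
import torch.nn as nn


class DeepSetLPESketch(nn.Module):
    """Hypothetical DeepSet-style encoder, for illustration only."""

    def __init__(self, lpe_dim, hidden_dim=8):
        super().__init__()
        # Each (eigenvector entry, eigenvalue) pair is embedded independently.
        self.phi = nn.Sequential(
            nn.Linear(2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, lpe_dim),
            nn.ReLU(),
        )

    def forward(self, eig_vals, eig_vecs):
        # Pair every eigenvector entry with its eigenvalue: (N, k) -> (N, k, 2).
        pairs = torch.stack((eig_vecs, eig_vals), dim=2)
        # Assumption: padded frequencies carry NaN eigenvalues.
        pad = torch.isnan(pairs).any(dim=2)           # (N, k)
        pairs = torch.nan_to_num(pairs, nan=0.0)
        h = self.phi(pairs)                           # (N, k, lpe_dim)
        # Zero out padded frequencies so they do not contribute to the sum.
        h = h.masked_fill(pad.unsqueeze(2), 0.0)
        # Sum pooling over the k frequencies: (N, k, lpe_dim) -> (N, lpe_dim).
        return h.sum(dim=1)


torch.manual_seed(0)
encoder = DeepSetLPESketch(lpe_dim=16)
vals, vecs = torch.rand(5, 3), torch.rand(5, 3)
out = encoder(vals, vecs)  # one 16-dimensional encoding per node
```

Sum pooling makes the result independent of the number of padded frequencies, which is why the sketch masks them before pooling. An MLP applied to the pooled output would correspond to the post-pooling layers that num_post_layer describes.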


>>> import dgl
>>> from dgl import LaplacianPE
>>> from dgl.nn import LaplacianPosEnc
>>> # Precompute the k=5 Laplacian eigenvectors and eigenvalues as node features.
>>> transform = LaplacianPE(k=5, feat_name='eigvec', eigval_name='eigval', padding=True)
>>> g = dgl.graph(([0,1,2,3,4,2,3,1,4,0], [2,3,1,4,0,0,1,2,3,4]))
>>> g = transform(g)
>>> EigVals, EigVecs = g.ndata['eigval'], g.ndata['eigvec']
>>> # Transformer-based encoder.
>>> TransformerLPE = LaplacianPosEnc(model_type="Transformer", num_layer=3, k=5,
...                                  lpe_dim=16, n_head=4)
>>> PosEnc = TransformerLPE(EigVals, EigVecs)
>>> # DeepSet-based encoder with a 2-layer MLP after pooling.
>>> DeepSetLPE = LaplacianPosEnc(model_type="DeepSet", num_layer=3, k=5,
...                              lpe_dim=16, num_post_layer=2)
>>> PosEnc = DeepSetLPE(EigVals, EigVecs)
forward(EigVals, EigVecs)
  • EigVals (Tensor) – Laplacian eigenvalues of shape \((N, k)\), where the same k eigenvalues are repeated for each of the N nodes. Can be obtained with LaplacianPE.

  • EigVecs (Tensor) – Laplacian eigenvectors of shape \((N, k)\). Can be obtained with LaplacianPE.


Returns the Laplacian positional encodings of shape \((N, lpe_dim)\), where \(N\) is the number of nodes in the input graph.
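The \((N, k)\) layout of EigVals can be illustrated with plain Python lists. This is only a sketch of the shape convention: the helper repeat_eigvals is hypothetical, and padding missing eigenvalues with NaN is an assumption about what LaplacianPE produces when padding=True.

```python
import math


def repeat_eigvals(eigvals, num_nodes, k, pad_value=math.nan):
    """Build an (N, k) nested list with the same k eigenvalues in every row.

    Hypothetical helper: if fewer than k eigenvalues are given, the row is
    padded with pad_value (NaN here, which is an assumption).
    """
    row = list(eigvals[:k]) + [pad_value] * max(0, k - len(eigvals))
    # Copy the row for every node so that rows are independent lists.
    return [list(row) for _ in range(num_nodes)]


# A graph with 3 nodes and only 2 available eigenvalues, requested k = 3.
rows = repeat_eigvals([0.5, 1.5], num_nodes=3, k=3)
```

Every row holds the same values because eigenvalues are a property of the graph rather than of a single node, whereas each row of EigVecs differs from node to node.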

Return type