JumpingKnowledge

class dgl.nn.pytorch.utils.JumpingKnowledge(mode='cat', in_feats=None, num_layers=None)[source]

Bases: torch.nn.modules.module.Module

The Jumping Knowledge aggregation module from Representation Learning on Graphs with Jumping Knowledge Networks (https://arxiv.org/abs/1806.03536)

It aggregates the output representations of multiple GNN layers with

concatenation

\[h_i^{(1)} \, \Vert \, \ldots \, \Vert \, h_i^{(T)}\]

or max pooling

\[\max \left( h_i^{(1)}, \ldots, h_i^{(T)} \right)\]

or LSTM attention

\[\sum_{t=1}^T \alpha_i^{(t)} h_i^{(t)}\]

with attention scores \(\alpha_i^{(t)}\) obtained from a BiLSTM.
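
For reference, the 'cat' and 'max' aggregations correspond to plain tensor operations. A minimal PyTorch sketch (the tensors here are illustrative, not part of the API):

>>> import torch as th
>>> feats = [th.zeros(3, 4), th.ones(3, 4)]  # outputs of T=2 GNN layers
>>> th.cat(feats, dim=-1).shape                     # 'cat': concatenation
torch.Size([3, 8])
>>> th.stack(feats, dim=-1).max(dim=-1)[0].shape    # 'max': elementwise max
torch.Size([3, 4])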

Parameters
  • mode (str) – The aggregation to apply. It can be ‘cat’, ‘max’, or ‘lstm’, corresponding to the equations above in order. Default: ‘cat’.

  • in_feats (int, optional) – This argument is only required if mode is 'lstm'. The output representation size of a single GNN layer. Note that all GNN layers need to have the same output representation size.

  • num_layers (int, optional) – This argument is only required if mode is 'lstm'. The number of GNN layers for output aggregation.

Examples

>>> import dgl
>>> import torch as th
>>> from dgl.nn import JumpingKnowledge
>>> # Output representations of two GNN layers
>>> num_nodes = 3
>>> in_feats = 4
>>> feat_list = [th.zeros(num_nodes, in_feats), th.ones(num_nodes, in_feats)]
>>> # Case 1: default mode 'cat'
>>> model = JumpingKnowledge()
>>> model(feat_list).shape
torch.Size([3, 8])
>>> # Case 2: max pooling
>>> model = JumpingKnowledge(mode='max')
>>> model(feat_list).shape
torch.Size([3, 4])
>>> # Case 3: LSTM attention
>>> model = JumpingKnowledge(mode='lstm', in_feats=in_feats, num_layers=len(feat_list))
>>> model(feat_list).shape
torch.Size([3, 4])
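
In practice, the module sits on top of a stack of GNN layers whose per-layer outputs are collected into a list. A minimal sketch with two GraphConv layers from dgl.nn (the random graph and feature values are illustrative only):

>>> from dgl.nn import GraphConv
>>> g = dgl.rand_graph(num_nodes, 6)  # random graph with 3 nodes and 6 edges
>>> conv1 = GraphConv(in_feats, in_feats, allow_zero_in_degree=True)
>>> conv2 = GraphConv(in_feats, in_feats, allow_zero_in_degree=True)
>>> x = th.randn(num_nodes, in_feats)
>>> h1 = th.relu(conv1(g, x))
>>> h2 = th.relu(conv2(g, h1))
>>> jk = JumpingKnowledge(mode='max')
>>> jk([h1, h2]).shape
torch.Size([3, 4])
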
forward(feat_list)[source]

Aggregate output representations across multiple GNN layers.

Parameters

feat_list (list[Tensor]) – feat_list[i] is the output representations of the i-th GNN layer.

Returns

The aggregated representations.

Return type

Tensor

reset_parameters()[source]

Reinitialize learnable parameters. This only takes effect for the 'lstm' mode; the 'cat' and 'max' modes have no learnable parameters.
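
A minimal usage sketch: only the 'lstm' instance carries learnable parameters, so that is the mode where calling reset_parameters has an effect.

>>> model = JumpingKnowledge(mode='lstm', in_feats=4, num_layers=2)
>>> model.reset_parameters()  # reinitializes the BiLSTM used for attention scores
>>> model = JumpingKnowledge(mode='max')
>>> model.reset_parameters()  # no-op: 'max' has no learnable parameters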