EGTLayer
- class dgl.nn.pytorch.gt.EGTLayer(feat_size, edge_feat_size, num_heads, num_virtual_nodes, dropout=0, attn_dropout=0, activation=ELU(alpha=1.0), edge_update=True)
Bases: Module
EGTLayer for the Edge-augmented Graph Transformer (EGT), as introduced in `Global Self-Attention as a Replacement for Graph Convolution <https://arxiv.org/pdf/2108.03348.pdf>`_.
- Parameters:
feat_size (int) – Node feature size.
edge_feat_size (int) – Edge feature size.
num_heads (int) – Number of attention heads; feat_size must be divisible by num_heads.
num_virtual_nodes (int) – Number of virtual nodes.
dropout (float, optional) – Dropout probability. Default: 0.0.
attn_dropout (float, optional) – Attention dropout probability. Default: 0.0.
activation (callable activation layer, optional) – Activation function. Default: nn.ELU().
edge_update (bool, optional) – Whether to update the edge embedding. Default: True.
Examples
>>> import torch as th
>>> from dgl.nn import EGTLayer

>>> batch_size = 16
>>> num_nodes = 100
>>> feat_size, edge_feat_size = 128, 32
>>> nfeat = th.rand(batch_size, num_nodes, feat_size)
>>> efeat = th.rand(batch_size, num_nodes, num_nodes, edge_feat_size)
>>> net = EGTLayer(
...     feat_size=feat_size,
...     edge_feat_size=edge_feat_size,
...     num_heads=8,
...     num_virtual_nodes=4,
... )
>>> out = net(nfeat, efeat)
- forward(nfeat, efeat, mask=None)
Forward computation.

Note: nfeat and efeat should be padded with embeddings of the virtual nodes if num_virtual_nodes > 0, while mask should be padded with 0 values for the virtual nodes. The padding should be put at the beginning. See the sketch after this method's documentation.

- Parameters:
nfeat (torch.Tensor) – A 3D input tensor. Shape: (batch_size, N, feat_size), where N is the sum of the maximum number of nodes and the number of virtual nodes.
efeat (torch.Tensor) – Edge embedding used for attention computation and self-update. Shape: (batch_size, N, N, edge_feat_size).
mask (torch.Tensor, optional) – The attention mask used for avoiding computation at invalid positions, where valid positions are indicated by 0 and invalid positions by -inf. Shape: (batch_size, N, N). Default: None.
- Returns:
nfeat (torch.Tensor) – The output node embedding. Shape: (batch_size, N, feat_size).
efeat (torch.Tensor, optional) – The output edge embedding. Shape: (batch_size, N, N, edge_feat_size). Returned only if edge_update is True.
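The padding and masking contract above can be sketched as follows. This is a minimal illustration, not part of the API: the batch size, node counts, and the num_real bookkeeping are assumptions, and only attention to padded key positions is masked here, so outputs at padded positions should simply be ignored downstream.

>>> import torch as th
>>> from dgl.nn import EGTLayer

>>> # Illustrative sizes (assumptions for this sketch).
>>> batch_size, max_nodes, num_virtual_nodes = 2, 5, 4
>>> feat_size, edge_feat_size = 128, 32
>>> N = max_nodes + num_virtual_nodes
>>> # Inputs are assumed to already carry virtual-node embeddings at the
>>> # beginning of the node dimension (positions 0..num_virtual_nodes-1).
>>> nfeat = th.rand(batch_size, N, feat_size)
>>> efeat = th.rand(batch_size, N, N, edge_feat_size)
>>> # Hypothetical per-graph real-node counts: the second graph has only
>>> # 3 real nodes, so its trailing positions are padding.
>>> num_real = [5, 3]
>>> mask = th.zeros(batch_size, N, N)
>>> for i, n in enumerate(num_real):
...     valid = num_virtual_nodes + n  # virtual nodes stay valid (0)
...     mask[i, :, valid:] = float("-inf")  # block attention to padding
>>> net = EGTLayer(feat_size, edge_feat_size, num_heads=8,
...                num_virtual_nodes=num_virtual_nodes)
>>> nfeat_out, efeat_out = net(nfeat, efeat, mask=mask)  # edge_update=True by default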