"""Graphormer Layer"""
import torch.nn as nn
from .biased_mha import BiasedMHA
class GraphormerLayer(nn.Module):
r"""Graphormer Layer with Dense Multi-Head Attention, as introduced
in `Do Transformers Really Perform Bad for Graph Representation?
<https://arxiv.org/pdf/2106.05234>`__
Parameters
----------
feat_size : int
Feature size.
hidden_size : int
Hidden size of feedforward layers.
num_heads : int
Number of attention heads. :attr:`feat_size` must be divisible by this number.
attn_bias_type : str, optional
The type of attention bias used to modify the attention scores. One of
'add' or 'mul'. Default: 'add'.
* 'add' is for additive attention bias.
* 'mul' is for multiplicative attention bias.
norm_first : bool, optional
If True, layer normalization is applied before the attention and
feedforward operations (pre-LN). Otherwise, it is applied afterwards
(post-LN). Default: False.
dropout : float, optional
Dropout probability. Default: 0.1.
attn_dropout : float, optional
Attention dropout probability. Default: 0.1.
activation : callable activation layer, optional
Activation function. Default: nn.ReLU().
Examples
--------
>>> import torch as th
>>> from dgl.nn import GraphormerLayer
>>> batch_size = 16
>>> num_nodes = 100
>>> feat_size = 512
>>> num_heads = 8
>>> nfeat = th.rand(batch_size, num_nodes, feat_size)
>>> bias = th.rand(batch_size, num_nodes, num_nodes, num_heads)
>>> net = GraphormerLayer(
...     feat_size=feat_size,
...     hidden_size=2048,
...     num_heads=num_heads,
... )
>>> out = net(nfeat, bias)
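An optional boolean attention mask, in which `True` marks positions to be
excluded from attention, can also be passed (a sketch; see :meth:`forward`
for the expected shape and semantics):
>>> mask = th.zeros(batch_size, num_nodes, num_nodes, dtype=th.bool)
>>> out = net(nfeat, bias, attn_mask=mask)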
"""
def __init__(
self,
feat_size,
hidden_size,
num_heads,
attn_bias_type="add",
norm_first=False,
dropout=0.1,
attn_dropout=0.1,
activation=nn.ReLU(),
):
super().__init__()
self.norm_first = norm_first
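# Dense multi-head attention whose scores are modified by an additive
# or multiplicative bias, as selected by attn_bias_type.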
self.attn = BiasedMHA(
feat_size=feat_size,
num_heads=num_heads,
attn_bias_type=attn_bias_type,
attn_drop=attn_dropout,
)
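# Position-wise feedforward network:
# Linear -> activation -> dropout -> Linear -> dropout.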
self.ffn = nn.Sequential(
nn.Linear(feat_size, hidden_size),
activation,
nn.Dropout(p=dropout),
nn.Linear(hidden_size, feat_size),
nn.Dropout(p=dropout),
)
self.dropout = nn.Dropout(p=dropout)
self.attn_layer_norm = nn.LayerNorm(feat_size)
self.ffn_layer_norm = nn.LayerNorm(feat_size)
def forward(self, nfeat, attn_bias=None, attn_mask=None):
"""Forward computation.
Parameters
----------
nfeat : torch.Tensor
A 3D input tensor. Shape: (batch_size, N, :attr:`feat_size`), where
N is the maximum number of nodes.
attn_bias : torch.Tensor, optional
The attention bias used for attention modification. Shape:
(batch_size, N, N, :attr:`num_heads`).
attn_mask : torch.Tensor, optional
The attention mask used to avoid computation on invalid positions,
where invalid positions are indicated by `True` values. Shape:
(batch_size, N, N). Note: for rows corresponding to non-existent
nodes, make sure at least one entry is set to `False` to prevent
softmax from producing NaNs.
Returns
-------
y : torch.Tensor
The output tensor. Shape: (batch_size, N, :attr:`feat_size`)
"""
residual = nfeat
if self.norm_first:
nfeat = self.attn_layer_norm(nfeat)
nfeat = self.attn(nfeat, attn_bias, attn_mask)
nfeat = self.dropout(nfeat)
nfeat = residual + nfeat
if not self.norm_first:
nfeat = self.attn_layer_norm(nfeat)
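# Feedforward sub-layer with residual connection, normalized in the
# same pre-/post-norm fashion as the attention sub-layer.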
residual = nfeat
if self.norm_first:
nfeat = self.ffn_layer_norm(nfeat)
nfeat = self.ffn(nfeat)
nfeat = residual + nfeat
if not self.norm_first:
nfeat = self.ffn_layer_norm(nfeat)
return nfeat