class dgl.nn.pytorch.graph_transformer.BiasedMultiheadAttention(feat_size, num_heads, bias=True, attn_bias_type='add', attn_drop=0.1)

Bases: torch.nn.modules.module.Module

Dense Multi-Head Attention Module with Graph Attention Bias.

Compute attention between nodes using an attention bias derived from the graph structure, as introduced in Do Transformers Really Perform Bad for Graph Representation?

\[\text{Attn}=\text{softmax}(\dfrac{QK^T}{\sqrt{d}} \circ b)\]

\(Q\) and \(K\) are the feature representations of the nodes, and \(d\) is the corresponding feat_size. \(b\) is the attention bias, which is either added to or multiplied with the scaled scores, depending on the operator \(\circ\) (selected by attn_bias_type).
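
To make the formula concrete, here is a minimal single-head sketch in plain Python with no learned projections and no dropout. biased_attention is a hypothetical helper, not part of DGL, and the sketch assumes the softmax is taken over each row, i.e. over all key nodes j for a given query node i.

```python
import math


def biased_attention(q, k, b, bias_type="add"):
    """Single-head softmax(Q K^T / sqrt(d) o b) on plain lists (hypothetical helper).

    q, k: N x d lists of floats (node feature representations).
    b:    N x N list of floats (attention bias, e.g. derived from graph structure).
    """
    d = len(q[0])
    n = len(q)
    attn = []
    for i in range(n):
        # Scaled dot-product scores of query node i against every key node j.
        scores = [sum(qi * kj for qi, kj in zip(q[i], k[j])) / math.sqrt(d) for j in range(n)]
        # Combine the scores with the bias using the chosen operator.
        if bias_type == "add":
            scores = [s + b[i][j] for j, s in enumerate(scores)]
        elif bias_type == "mul":
            scores = [s * b[i][j] for j, s in enumerate(scores)]
        else:
            raise ValueError("bias_type must be 'add' or 'mul'")
        # Numerically stable softmax over the key dimension j.
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        attn.append([e / total for e in exps])
    return attn


q = [[1.0, 0.0], [0.0, 1.0]]
plain = biased_attention(q, q, [[0.0, 0.0], [0.0, 0.0]], "add")
boosted = biased_attention(q, q, [[0.0, 2.0], [0.0, 0.0]], "add")
```

A zero additive bias and an all-ones multiplicative bias both leave the plain scaled dot-product attention unchanged, while a positive additive bias on entry (i, j) shifts attention weight toward node j.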

Parameters

  • feat_size (int) – Feature size.

  • num_heads (int) – Number of attention heads. feat_size must be divisible by num_heads.

  • bias (bool, optional) – If True, the linear projections use a bias term. Default: True.

  • attn_bias_type (str, optional) –

    The type of attention bias used to modify the attention scores. Must be ‘add’ or ‘mul’. Default: ‘add’.

    • ‘add’ is for additive attention bias.

    • ‘mul’ is for multiplicative attention bias.

  • attn_drop (float, optional) – Dropout probability on attention weights. Default: 0.1.
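
The divisibility requirement on num_heads comes from splitting the projected features evenly across heads, so each head works on head_dim = feat_size // num_heads features. The sketch below illustrates that layout with a plain nn.Linear projection and a hypothetical split_heads helper; it shows a common multi-head layout and is an assumption, not a copy of DGL's internals.

```python
import torch
import torch.nn as nn


def split_heads(x: torch.Tensor, num_heads: int) -> torch.Tensor:
    """Hypothetical helper: (batch_size, N, feat_size) -> (batch_size, num_heads, N, head_dim)."""
    batch_size, n, feat_size = x.shape
    if feat_size % num_heads != 0:
        raise ValueError(f"feat_size={feat_size} is not divisible by num_heads={num_heads}")
    head_dim = feat_size // num_heads
    # Carve the last dimension into (num_heads, head_dim), then move heads before nodes.
    return x.reshape(batch_size, n, num_heads, head_dim).transpose(1, 2)


feat_size, num_heads = 512, 8
q_proj = nn.Linear(feat_size, feat_size, bias=True)  # bias=True mirrors the `bias` argument
x = torch.rand(16, 100, feat_size)
q = split_heads(q_proj(x), num_heads)
print(q.shape)  # torch.Size([16, 8, 100, 64])
```

With feat_size=512 and num_heads=8, each head sees 64 features; a feat_size that num_heads does not divide cannot be split this way.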


Example

>>> import torch as th
>>> from dgl.nn import BiasedMultiheadAttention
>>> ndata = th.rand(16, 100, 512)
>>> bias = th.rand(16, 100, 100, 8)
>>> net = BiasedMultiheadAttention(feat_size=512, num_heads=8)
>>> out = net(ndata, bias)

forward(ndata, attn_bias=None, attn_mask=None)

Forward computation.

Parameters

  • ndata (torch.Tensor) – A 3D input tensor. Shape: (batch_size, N, feat_size), where N is the maximum number of nodes.

  • attn_bias (torch.Tensor, optional) – The attention bias used to modify the attention scores. Shape: (batch_size, N, N, num_heads).

  • attn_mask (torch.Tensor, optional) – The attention mask used to exclude invalid positions from attention; invalid positions are indicated by non-zero values. Shape: (batch_size, N, N).
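
When graphs of different sizes are padded to N nodes, attn_mask can be built from the per-graph node counts. The sketch below uses a hypothetical key_padding_mask helper and marks only padded key columns as invalid. This is a design assumption: a softmax over a row that is entirely -inf produces NaN, so leaving every query row with at least one valid key (each graph having at least one node) avoids that case.

```python
import torch


def key_padding_mask(num_nodes, max_nodes):
    """Hypothetical helper: (batch_size, N, N) bool mask, True (non-zero) marks invalid positions.

    num_nodes: list with the real node count of each graph in the batch.
    Only padded key columns are marked, so each query row keeps a valid entry.
    """
    idx = torch.arange(max_nodes)                  # (N,)
    counts = torch.tensor(num_nodes).unsqueeze(1)  # (batch_size, 1)
    pad = idx.unsqueeze(0) >= counts               # (batch_size, N): True for padded nodes
    # Broadcast over the query dimension: mask[b, i, j] = pad[b, j].
    return pad.unsqueeze(1).expand(-1, max_nodes, -1)


mask = key_padding_mask([3, 5], max_nodes=5)
print(mask.shape)  # torch.Size([2, 5, 5])
```

Here the first graph has 3 real nodes, so columns 3 and 4 are masked in every row of mask[0], while the second graph fills all 5 slots and is left unmasked.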


Returns

y – The output tensor. Shape: (batch_size, N, feat_size)

Return type

torch.Tensor

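
To see the whole forward contract in one place, here is a hedged pure-PyTorch sketch that accepts the same arguments and tensor shapes as documented above. BiasedMHASketch is a hypothetical name; it does not use DGL and is not DGL's source. Its assumptions are: scores are scaled by the per-head dimension (the usual multi-head convention), attn_bias is permuted to put heads first, non-zero mask entries are filled with -inf before the softmax, and dropout is applied to the attention weights.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class BiasedMHASketch(nn.Module):
    """Hypothetical reference re-implementation of biased multi-head attention (not DGL's code)."""

    def __init__(self, feat_size, num_heads, bias=True, attn_bias_type="add", attn_drop=0.1):
        super().__init__()
        if feat_size % num_heads != 0:
            raise ValueError("feat_size must be divisible by num_heads")
        self.num_heads = num_heads
        self.head_dim = feat_size // num_heads
        self.attn_bias_type = attn_bias_type
        self.q_proj = nn.Linear(feat_size, feat_size, bias=bias)
        self.k_proj = nn.Linear(feat_size, feat_size, bias=bias)
        self.v_proj = nn.Linear(feat_size, feat_size, bias=bias)
        self.out_proj = nn.Linear(feat_size, feat_size, bias=bias)
        self.dropout = nn.Dropout(attn_drop)

    def forward(self, ndata, attn_bias=None, attn_mask=None):
        bsz, n, feat_size = ndata.shape

        def heads(x):  # (bsz, N, feat_size) -> (bsz, H, N, head_dim)
            return x.reshape(bsz, n, self.num_heads, self.head_dim).transpose(1, 2)

        q, k, v = heads(self.q_proj(ndata)), heads(self.k_proj(ndata)), heads(self.v_proj(ndata))
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)  # (bsz, H, N, N)
        if attn_bias is not None:
            bias = attn_bias.permute(0, 3, 1, 2)  # (bsz, N, N, H) -> (bsz, H, N, N)
            scores = scores + bias if self.attn_bias_type == "add" else scores * bias
        if attn_mask is not None:
            # Non-zero entries mark invalid positions; the mask is broadcast over heads.
            scores = scores.masked_fill(attn_mask.bool().unsqueeze(1), float("-inf"))
        attn = self.dropout(F.softmax(scores, dim=-1))
        out = (attn @ v).transpose(1, 2).reshape(bsz, n, feat_size)
        return self.out_proj(out)
```

It can be called like the DGL example, e.g. BiasedMHASketch(feat_size=512, num_heads=8)(ndata, bias), which returns a tensor of shape (16, 100, 512) for the inputs used there.
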
reset_parameters()

Reset the parameters of the projection matrices, using the same settings as Graphormer.
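
The docstring does not spell out the Graphormer settings. Assuming they follow the fairseq-style multi-head attention initialization that Graphormer builds on (Xavier-uniform with gain 1/sqrt(2) for the query, key, and value weights; plain Xavier-uniform weights and a zero bias for the output projection), a minimal sketch looks like this. reset_projection_params is a hypothetical helper, not DGL's method.

```python
import math

import torch.nn as nn


def reset_projection_params(q_proj, k_proj, v_proj, out_proj):
    """Hypothetical Graphormer/fairseq-style initialization (an assumption, not DGL's code)."""
    # Q/K/V weights: Xavier-uniform with a reduced gain of 1/sqrt(2).
    for proj in (q_proj, k_proj, v_proj):
        nn.init.xavier_uniform_(proj.weight, gain=1 / math.sqrt(2))
    # Output projection: standard Xavier-uniform weights and a zero bias.
    nn.init.xavier_uniform_(out_proj.weight)
    if out_proj.bias is not None:
        nn.init.constant_(out_proj.bias, 0.0)


feat_size = 512
projs = [nn.Linear(feat_size, feat_size) for _ in range(4)]
reset_projection_params(*projs)
```

Xavier-uniform draws from U(-a, a) with a = gain * sqrt(6 / (fan_in + fan_out)), so the reduced gain narrows the initial range of the Q/K/V weights.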