RelGraphConv

class dgl.nn.pytorch.conv.RelGraphConv(in_feat, out_feat, num_rels, regularizer=None, num_bases=None, bias=True, activation=None, self_loop=True, dropout=0.0, layer_norm=False)

Bases: torch.nn.modules.module.Module

Relational graph convolution layer from Modeling Relational Data with Graph Convolutional Networks

It can be described as follows:

\[h_i^{(l+1)} = \sigma(\sum_{r\in\mathcal{R}} \sum_{j\in\mathcal{N}^r(i)}e_{j,i}W_r^{(l)}h_j^{(l)}+W_0^{(l)}h_i^{(l)})\]

where \(\mathcal{R}\) is the set of relations, \(\mathcal{N}^r(i)\) is the set of neighbors of node \(i\) under relation \(r\), \(e_{j,i}\) is the normalizer of edge \((j, i)\) (passed to forward() as norm), \(\sigma\) is an activation function, and \(W_0\) is the self-loop weight.
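To make the update rule concrete, here is a minimal NumPy sketch that evaluates it with explicit loops. It is not how RelGraphConv is implemented: bias, dropout and layer norm are omitted, \(\sigma = \tanh\) is an arbitrary choice, and the function name rgcn_reference is hypothetical. It also assumes one normalizer suggested in the original R-GCN paper, \(e_{j,i} = 1/|\mathcal{N}^r(i)|\), and computes it internally, whereas RelGraphConv takes \(e_{j,i}\) from the norm argument of forward().

```python
import numpy as np

def rgcn_reference(h, src, dst, etype, W, W0, activation=np.tanh):
    """Hypothetical loop-based reference for the R-GCN update rule.

    h:     (N, D_in) node features h^{(l)}
    src:   (E,) source node j of each edge
    dst:   (E,) destination node i of each edge
    etype: (E,) relation id r of each edge
    W:     (R, D_out, D_in) per-relation weights W_r
    W0:    (D_out, D_in) self-loop weight W_0
    """
    num_nodes, num_rels = h.shape[0], W.shape[0]
    # Assumed normalizer e_{j,i} = 1 / |N^r(i)|: count in-edges per (dst, relation).
    counts = np.zeros((num_nodes, num_rels))
    np.add.at(counts, (dst, etype), 1)
    norm = 1.0 / counts[dst, etype]
    # Self-loop term W_0 h_i for every node.
    out = h @ W0.T
    # Add each neighbor message e_{j,i} W_r h_j to its destination node i.
    for j, i, r, e in zip(src, dst, etype, norm):
        out[i] += e * (W[r] @ h[j])
    return activation(out)

# Same graph and edge types as the Examples section; weights are random.
rng = np.random.default_rng(0)
h = rng.normal(size=(6, 10))
src = np.array([0, 1, 2, 3, 2, 5])
dst = np.array([1, 2, 3, 4, 0, 3])
etype = np.array([0, 1, 2, 0, 1, 2])
W = rng.normal(size=(3, 2, 10))
W0 = rng.normal(size=(2, 10))
out = rgcn_reference(h, src, dst, etype, W, W0)
```

Node 5 has no incoming edges, so its output reduces to \(\sigma(W_0 h_5)\); node 3 receives two relation-2 edges, each scaled by 1/2.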

The basis regularization decomposes \(W_r\) by:

\[W_r^{(l)} = \sum_{b=1}^B a_{rb}^{(l)}V_b^{(l)}\]

where \(B\) is the number of bases and the shared basis matrices \(V_b^{(l)}\) are linearly combined with relation-specific coefficients \(a_{rb}^{(l)}\).
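The following NumPy sketch illustrates the basis decomposition math and the parameter saving it brings. The sizes are arbitrary and the (R, D_out, D_in) layout follows the formula; it is an assumption for illustration, not RelGraphConv's internal parameter layout.

```python
import numpy as np

num_rels, num_bases, d_in, d_out = 5, 2, 4, 3
rng = np.random.default_rng(0)

# B shared basis matrices V_b, each of shape (d_out, d_in).
V = rng.normal(size=(num_bases, d_out, d_in))
# One coefficient a_{rb} per (relation, basis) pair.
a = rng.normal(size=(num_rels, num_bases))

# W_r = sum_b a_{rb} V_b, computed for every relation r at once.
W = np.einsum("rb,boi->roi", a, V)

# Parameter count: full per-relation weights vs. the basis decomposition.
full_params = num_rels * d_out * d_in   # 5 * 3 * 4 = 60
basis_params = V.size + a.size          # 2 * 3 * 4 + 5 * 2 = 34
```

The saving grows with the number of relations, since each extra relation costs only \(B\) coefficients instead of a full \(d_{out} \times d_{in}\) matrix.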

The block-diagonal-decomposition regularization decomposes \(W_r\) into a block-diagonal matrix with \(B\) blocks. We refer to \(B\) as the number of bases.

Formally, the block-diagonal regularization writes \(W_r\) as a direct sum of blocks:

\[W_r^{(l)} = \oplus_{b=1}^B Q_{rb}^{(l)}\]

where \(B\) is the number of bases and the blocks satisfy \(Q_{rb}^{(l)} \in \mathbb{R}^{(d^{(l+1)}/B) \times (d^{(l)}/B)}\), with \(d^{(l)}\) denoting the feature dimension at layer \(l\).
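A small NumPy sketch can show what the direct sum means in practice: multiplying by the assembled block-diagonal \(W_r\) is the same as splitting \(h_j\) into \(B\) chunks and applying each block to its own chunk. The helper block_diag_weight and the sizes are hypothetical, chosen only for illustration.

```python
import numpy as np

def block_diag_weight(Q_r):
    """Hypothetical helper: assemble W_r = Q_{r1} ⊕ ... ⊕ Q_{rB}.

    Q_r has shape (B, d_out/B, d_in/B); the result has shape (d_out, d_in).
    """
    num_blocks, bo, bi = Q_r.shape
    W_r = np.zeros((num_blocks * bo, num_blocks * bi))
    for b in range(num_blocks):
        W_r[b * bo:(b + 1) * bo, b * bi:(b + 1) * bi] = Q_r[b]
    return W_r

num_rels, B, d_in, d_out = 3, 2, 4, 6
rng = np.random.default_rng(0)
# Q[r, b] is the b-th block of relation r.
Q = rng.normal(size=(num_rels, B, d_out // B, d_in // B))
h_j = rng.normal(size=d_in)

r = 1
# Dense route: build W_r explicitly, then multiply.
dense = block_diag_weight(Q[r]) @ h_j
# Blockwise route: apply block b to the b-th chunk of h_j, then concatenate.
blockwise = np.einsum("boi,bi->bo", Q[r], h_j.reshape(B, d_in // B)).reshape(-1)
```

The blockwise route never materializes the zero off-diagonal entries, which is why both feature sizes must be divisible by \(B\).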

Parameters
  • in_feat (int) – Input feature size; i.e., the number of dimensions of \(h_j^{(l)}\).

  • out_feat (int) – Output feature size; i.e., the number of dimensions of \(h_i^{(l+1)}\).

  • num_rels (int) – Number of relations.

  • regularizer (str, optional) –

    Which weight regularizer to use (“basis”, “bdd” or None):

    • “basis” is for basis-decomposition.

    • “bdd” is for block-diagonal-decomposition.

    • None applies no regularization.

    Default: None.

  • num_bases (int, optional) – Number of bases. Only takes effect when a regularizer is applied. If None, it defaults to the number of relations (num_rels). Default: None. Note that in_feat and out_feat must both be divisible by num_bases when using the “bdd” regularizer.

  • bias (bool, optional) – If True, adds a learnable bias. Default: True.

  • activation (callable, optional) – Activation function. Default: None.

  • self_loop (bool, optional) – If True, includes the self-loop message \(W_0^{(l)}h_i^{(l)}\). Default: True.

  • dropout (float, optional) – Dropout rate. Default: 0.0.

  • layer_norm (bool, optional) – If True, applies layer normalization. Default: False.
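The rules governing regularizer and num_bases can be restated as a small validation function. This is a hypothetical helper, not part of DGL; it only encodes the constraints listed above (the default of num_rels and the divisibility requirement for “bdd”).

```python
def resolve_num_bases(in_feat, out_feat, num_rels, regularizer=None, num_bases=None):
    """Hypothetical helper restating the documented num_bases rules (not DGL code)."""
    if regularizer not in (None, "basis", "bdd"):
        raise ValueError(f"unknown regularizer: {regularizer!r}")
    if regularizer is None:
        return None  # num_bases has no effect without a regularizer
    if num_bases is None:
        num_bases = num_rels  # documented default
    if regularizer == "bdd" and (in_feat % num_bases or out_feat % num_bases):
        raise ValueError("in_feat and out_feat must be divisible by num_bases for 'bdd'")
    return num_bases

# Settings from the Examples section pass; 10 features cannot be split into 3 blocks.
print(resolve_num_bases(10, 2, 3, regularizer="basis", num_bases=2))  # 2
```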

Examples

>>> import dgl
>>> import torch as th
>>> from dgl.nn import RelGraphConv
>>>
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> feat = th.ones(6, 10)
>>> conv = RelGraphConv(10, 2, 3, regularizer='basis', num_bases=2)
>>> etype = th.tensor([0,1,2,0,1,2])
>>> res = conv(g, feat, etype)
>>> # The values below depend on the random weight initialization.
>>> res
tensor([[ 0.3996, -2.3303],
        [-0.4323, -0.1440],
        [ 0.3996, -2.3303],
        [ 2.1046, -2.8654],
        [-0.4323, -0.1440],
        [-0.1309, -1.0000]], grad_fn=<AddBackward0>)
forward(g, feat, etypes, norm=None, *, presorted=False)

Forward computation.

Parameters
  • g (DGLGraph) – The graph.

  • feat (torch.Tensor) – A 2D tensor of node features. Shape: \((|V|, D_{in})\).

  • etypes (torch.Tensor or list[int]) – A 1D integer tensor of edge types. Shape: \((|E|,)\).

  • norm (torch.Tensor, optional) – A 1D tensor of edge normalizer values \(e_{j,i}\). Shape: \((|E|,)\).

  • presorted (bool, optional) – Whether the edges of the input graph have been sorted by their types. Running forward on a pre-sorted graph may be faster. Graphs created by to_homogeneous() automatically satisfy this condition. See also reorder_graph() for sorting edges manually.
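One plausible reason sorting helps can be shown without DGL: once edges are grouped by type, each relation's edges occupy one contiguous slice, so messages can be computed with one matrix multiply per relation instead of per-edge weight lookups. The NumPy sketch below illustrates that idea under stated assumptions; it is not DGL's kernel, and the (in, out) weight layout, x @ W[r], is chosen for convenience (it equals \(W_r h_j\) with W[r] holding \(W_r^\top\)).

```python
import numpy as np

rng = np.random.default_rng(0)
num_rels, d_in, d_out = 3, 4, 2
etype = np.array([2, 0, 1, 0, 2, 1, 0])
x = rng.normal(size=(len(etype), d_in))       # one source feature row per edge
W = rng.normal(size=(num_rels, d_in, d_out))  # assumed (in, out) layout per relation

# Sort edges by type; a stable sort keeps the original order within each type.
perm = np.argsort(etype, kind="stable")
x_sorted, etype_sorted = x[perm], etype[perm]
# Edges of relation r now occupy the contiguous slice [offsets[r], offsets[r + 1]).
counts = np.bincount(etype_sorted, minlength=num_rels)
offsets = np.concatenate([[0], np.cumsum(counts)])

# One matrix multiply per relation segment.
msg_sorted = np.empty((len(etype), d_out))
for r in range(num_rels):
    lo, hi = offsets[r], offsets[r + 1]
    msg_sorted[lo:hi] = x_sorted[lo:hi] @ W[r]

# Undo the permutation to recover messages in the original edge order.
msg = np.empty_like(msg_sorted)
msg[perm] = msg_sorted
```

With presorted=True the sorting and permutation steps can be skipped entirely, which is where the potential speedup comes from.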

Returns

New node features. Shape: \((|V|, D_{out})\).

Return type

torch.Tensor