Source code for dgl.nn.pytorch.conv.tagconv

"""Torch Module for Topology Adaptive Graph Convolutional layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import torch as th
from torch import nn

from .... import function as fn
from .graphconv import EdgeWeightNorm


class TAGConv(nn.Module):
    r"""Topology Adaptive Graph Convolutional layer from `Topology Adaptive
    Graph Convolutional Networks <https://arxiv.org/pdf/1710.10370.pdf>`__

    .. math::
        H^{K} = {\sum}_{k=0}^K (D^{-1/2} A D^{-1/2})^{k} X {\Theta}_{k},

    where :math:`A` denotes the adjacency matrix,
    :math:`D_{ii} = \sum_{j=0} A_{ij}` its diagonal degree matrix,
    :math:`{\Theta}_{k}` denotes the linear weights to sum the results of
    different hops together.

    Parameters
    ----------
    in_feats : int
        Input feature size. i.e, the number of dimensions of :math:`X`.
    out_feats : int
        Output feature size. i.e, the number of dimensions of :math:`H^{K}`.
    k: int, optional
        Number of hops :math:`K`. Default: ``2``.
    bias: bool, optional
        If True, adds a learnable bias to the output. Default: ``True``.
    activation: callable activation function/layer or None, optional
        If not None, applies an activation function to the updated node
        features. Default: ``None``.

    Attributes
    ----------
    lin : torch.nn.Linear
        The learnable linear module. It maps the concatenation of the
        ``k + 1`` hop features (``in_feats * (k + 1)`` values per node) to
        ``out_feats`` values.

    Example
    -------
    >>> import dgl
    >>> import numpy as np
    >>> import torch as th
    >>> from dgl.nn import TAGConv
    >>>
    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
    >>> feat = th.ones(6, 10)
    >>> conv = TAGConv(10, 2, k=2)
    >>> res = conv(g, feat)
    >>> res
    tensor([[ 0.5490, -1.6373],
            [ 0.5490, -1.6373],
            [ 0.5490, -1.6373],
            [ 0.5513, -1.8208],
            [ 0.5215, -1.6044],
            [ 0.3304, -1.9927]], grad_fn=<AddmmBackward>)
    """

    def __init__(
        self,
        in_feats,
        out_feats,
        k=2,
        bias=True,
        activation=None,
    ):
        super().__init__()
        self._in_feats = in_feats
        self._out_feats = out_feats
        self._k = k
        self._activation = activation
        # The input feature plus one propagated copy per hop are concatenated
        # along the last dimension before the linear projection, hence
        # ``k + 1`` times the input width.
        self.lin = nn.Linear(in_feats * (self._k + 1), out_feats, bias=bias)

        self.reset_parameters()

    def reset_parameters(self):
        r"""

        Description
        -----------
        Reinitialize learnable parameters.

        Note
        ----
        The weight of the linear module is initialized using Glorot (Xavier)
        normal initialization with the gain recommended for ReLU. The bias,
        if present, is not modified by this method.
        """
        gain = nn.init.calculate_gain("relu")
        nn.init.xavier_normal_(self.lin.weight, gain=gain)

    def forward(self, graph, feat, edge_weight=None):
        r"""

        Description
        -----------
        Compute topology adaptive graph convolution.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : torch.Tensor
            The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
            is size of input feature, :math:`N` is the number of nodes.
        edge_weight: torch.Tensor, optional
            edge_weight to use in the message passing process. This is
            equivalent to using weighted adjacency matrix in the equation
            above, and
            :math:`\tilde{D}^{-1/2}\tilde{A} \tilde{D}^{-1/2}`
            is based on
            :class:`dgl.nn.pytorch.conv.graphconv.EdgeWeightNorm`.

        Returns
        -------
        torch.Tensor
            The output feature of shape :math:`(N, D_{out})` where
            :math:`D_{out}` is size of output feature.

        Raises
        ------
        AssertionError
            If ``graph`` is not homogeneous.
        """
        with graph.local_scope():
            assert graph.is_homogeneous, "Graph is not homogeneous"

            weighted = edge_weight is not None
            if weighted:
                # Normalisation is folded into the per-edge weights, so node
                # features are propagated without extra scaling.
                graph.edata["_edge_weight"] = EdgeWeightNorm("both")(
                    graph, edge_weight
                )
                msg_func = fn.u_mul_e("h", "_edge_weight", "m")
                norm = None
            else:
                # In-degrees are clamped to at least 1 so that isolated nodes
                # do not produce an infinite scaling factor. ``.to(feat)``
                # matches both the dtype and the device of ``feat``.
                norm = th.pow(
                    graph.in_degrees().to(feat).clamp(min=1), -0.5
                )
                # Append singleton dimensions so ``norm`` broadcasts against
                # ``feat`` whatever its number of trailing dimensions.
                shp = norm.shape + (1,) * (feat.dim() - 1)
                norm = th.reshape(norm, shp)
                msg_func = fn.copy_u("h", "m")

            # D-1/2 A D -1/2 X
            fstack = [feat]
            for _ in range(self._k):
                rst = fstack[-1]
                if not weighted:
                    rst = rst * norm

                graph.ndata["h"] = rst
                graph.update_all(msg_func, fn.sum(msg="m", out="h"))
                rst = graph.ndata["h"]

                if not weighted:
                    rst = rst * norm
                fstack.append(rst)

            # Mix the features of all hops with a single linear projection.
            rst = self.lin(th.cat(fstack, dim=-1))

            if self._activation is not None:
                rst = self._activation(rst)

            return rst