# Source code for dgl.nn.mxnet.conv

"""MXNet modules for graph convolutions."""
# pylint: disable= no-member, arguments-differ
import mxnet as mx
from mxnet import gluon

from ... import function as fn
from ...utils import get_ndata_name

# Names exported by ``from <this module> import *``: only the GraphConv layer.
__all__ = ['GraphConv']

class GraphConv(gluon.Block):
    r"""Apply graph convolution over an input signal.

    Graph convolution is introduced in `GCN <https://arxiv.org/abs/1609.02907>`__
    and can be described as below:

    .. math::
      h_i^{(l+1)} = \sigma(b^{(l)} + \sum_{j\in\mathcal{N}(i)}\frac{1}{c_{ij}}h_j^{(l)}W^{(l)})

    where :math:`\mathcal{N}(i)` is the neighbor set of node :math:`i`.
    :math:`c_{ij}` is equal to the product of the square root of node degrees:
    :math:`\sqrt{|\mathcal{N}(i)|}\sqrt{|\mathcal{N}(j)|}`. :math:`\sigma` is an
    activation function.

    The model parameters are initialized as in the
    `original implementation <https://github.com/tkipf/gcn/blob/master/gcn/layers.py>`__
    where the weight :math:`W^{(l)}` is initialized using Glorot uniform
    initialization and the bias is initialized to be zero.

    Notes
    -----
    When ``norm`` is True, node in-degrees are clamped to a minimum of 1 before
    the inverse square root is taken. A zero in-degree node therefore gets a
    finite normalizer rather than ``inf``, which would otherwise propagate
    ``inf``/``nan`` to the nodes it sends messages to. Such a node still
    receives no messages, so its aggregated feature is zero (before bias and
    activation). A common practice to give every node at least one incoming
    message is to add a self-loop for each node in the graph:

    >>> g = ... # some DGLGraph
    >>> g.add_edges(g.nodes(), g.nodes())

    Parameters
    ----------
    in_feats : int
        Number of input features.
    out_feats : int
        Number of output features.
    norm : bool, optional
        If True, the normalizer :math:`c_{ij}` is applied. Default: ``True``.
    bias : bool, optional
        If True, adds a learnable bias to the output. Default: ``True``.
    activation : callable activation function/layer or None, optional
        If not None, applies an activation function to the updated node features.
        Default: ``None``.

    Attributes
    ----------
    weight : mxnet.gluon.parameter.Parameter
        The learnable weight tensor, of shape ``(in_feats, out_feats)``.
    bias : mxnet.gluon.parameter.Parameter
        The learnable bias tensor, of shape ``(out_feats,)``; ``None`` when
        ``bias=False``.
    """
    def __init__(self,
                 in_feats,
                 out_feats,
                 norm=True,
                 bias=True,
                 activation=None):
        super(GraphConv, self).__init__()
        self._in_feats = in_feats
        self._out_feats = out_feats
        self._norm = norm
        # Preferred names for the temporary node-data field and message field
        # used during message passing in ``forward``.
        self._feat_name = "_gconv_feat"
        self._msg_name = "_gconv_msg"

        with self.name_scope():
            self.weight = self.params.get('weight', shape=(in_feats, out_feats),
                                          init=mx.init.Xavier())
            if bias:
                self.bias = self.params.get('bias', shape=(out_feats,),
                                            init=mx.init.Zero())
            else:
                self.bias = None

        self._activation = activation

    def forward(self, feat, graph):
        r"""Compute graph convolution.

        Notes
        -----
        * Input shape: :math:`(N, *, \text{in\_feats})` where * means any number
          of additional dimensions, :math:`N` is the number of nodes.
        * Output shape: :math:`(N, *, \text{out\_feats})` where all but the last
          dimension are the same shape as the input.

        Parameters
        ----------
        feat : mxnet.NDArray
            The input feature
        graph : DGLGraph
            The graph.

        Returns
        -------
        mxnet.NDArray
            The output feature
        """
        # Resolve a free node-data field name for this call only. It is kept in
        # a local so that running forward never mutates the layer's attributes.
        feat_name = get_ndata_name(graph, self._feat_name)

        norm = None
        if self._norm:
            # Clamp to 1 so zero in-degree nodes get a finite normalizer
            # instead of inf, which would poison their neighbors' outputs.
            # NOTE(review): in-degrees are used for both the source-side and
            # destination-side factor; for symmetric graphs these equal
            # out-degrees.
            degs = mx.nd.maximum(graph.in_degrees().astype('float32'), 1.)
            norm = mx.nd.power(degs, -0.5)
            # Reshape to (N, 1, ..., 1) so it broadcasts over every
            # non-node dimension of ``feat``.
            shp = norm.shape + (1,) * (feat.ndim - 1)
            norm = norm.reshape(shp).as_in_context(feat.context)
            feat = feat * norm

        weight = self.weight.data(feat.context)
        if self._in_feats > self._out_feats:
            # mult W first to reduce the feature size for aggregation.
            rst = self._aggregate(graph, mx.nd.dot(feat, weight), feat_name)
        else:
            # aggregate first then mult W
            rst = mx.nd.dot(self._aggregate(graph, feat, feat_name), weight)

        if norm is not None:
            rst = rst * norm

        if self.bias is not None:
            rst = rst + self.bias.data(rst.context)

        if self._activation is not None:
            rst = self._activation(rst)

        return rst

    def _aggregate(self, graph, feat, feat_name):
        """Sum ``feat`` over each node's incoming edges in ``graph``.

        ``feat`` is stored under ``feat_name`` in ``graph.ndata``. It is copied
        along every edge as a message and summed at the destination node. The
        result is then popped from ``graph.ndata`` and returned.
        """
        graph.ndata[feat_name] = feat
        graph.update_all(fn.copy_src(src=feat_name, out=self._msg_name),
                         fn.sum(msg=self._msg_name, out=feat_name))
        return graph.ndata.pop(feat_name)

    def __repr__(self):
        summary = 'GraphConv('
        summary += 'in={:d}, out={:d}, normalization={}, activation={}'.format(
            self._in_feats, self._out_feats, self._norm, self._activation)
        summary += '\n)'
        return summary