Source code for dgl.nn.pytorch.conv.gineconv

"""Torch Module for Graph Isomorphism Network layer variant with edge features"""
# pylint: disable= no-member, arguments-differ, invalid-name
import torch as th
import torch.nn.functional as F
from torch import nn

from .... import function as fn
from ....utils import expand_as_pair


[docs]class GINEConv(nn.Module): r"""Graph Isomorphism Network with Edge Features, introduced by `Strategies for Pre-training Graph Neural Networks <https://arxiv.org/abs/1905.12265>`__ .. math:: h_i^{(l+1)} = f_\Theta \left((1 + \epsilon) h_i^{l} + \sum_{j\in\mathcal{N}(i)}\mathrm{ReLU}(h_j^{l} + e_{j,i}^{l})\right) where :math:`e_{j,i}^{l}` is the edge feature. Parameters ---------- apply_func : callable module or None The :math:`f_\Theta` in the formula. If not None, it will be applied to the updated node features. The default value is None. init_eps : float, optional Initial :math:`\epsilon` value, default: ``0``. learn_eps : bool, optional If True, :math:`\epsilon` will be a learnable parameter. Default: ``False``. Examples -------- >>> import dgl >>> import torch >>> import torch.nn as nn >>> from dgl.nn import GINEConv >>> g = dgl.graph(([0, 1, 2], [1, 1, 3])) >>> in_feats = 10 >>> out_feats = 20 >>> nfeat = torch.randn(g.num_nodes(), in_feats) >>> efeat = torch.randn(g.num_edges(), in_feats) >>> conv = GINEConv(nn.Linear(in_feats, out_feats)) >>> res = conv(g, nfeat, efeat) >>> print(res.shape) torch.Size([4, 20]) """ def __init__(self, apply_func=None, init_eps=0, learn_eps=False): super(GINEConv, self).__init__() self.apply_func = apply_func # to specify whether eps is trainable or not. if learn_eps: self.eps = nn.Parameter(th.FloatTensor([init_eps])) else: self.register_buffer("eps", th.FloatTensor([init_eps])) def message(self, edges): r"""User-defined Message Function""" return {"m": F.relu(edges.src["hn"] + edges.data["he"])}
[docs] def forward(self, graph, node_feat, edge_feat): r"""Forward computation. Parameters ---------- graph : DGLGraph The graph. node_feat : torch.Tensor or pair of torch.Tensor If a torch.Tensor is given, it is the input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes. If a pair of torch.Tensor is given, the pair must contain two tensors of shape :math:`(N_{in}, D_{in})` and :math:`(N_{out}, D_{in})`. If ``apply_func`` is not None, :math:`D_{in}` should fit the input feature size requirement of ``apply_func``. edge_feat : torch.Tensor Edge feature. It is a tensor of shape :math:`(E, D_{in})` where :math:`E` is the number of edges. Returns ------- torch.Tensor The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` is the output feature size of ``apply_func``. If ``apply_func`` is None, :math:`D_{out}` should be the same as :math:`D_{in}`. """ with graph.local_scope(): feat_src, feat_dst = expand_as_pair(node_feat, graph) graph.srcdata["hn"] = feat_src graph.edata["he"] = edge_feat graph.update_all(self.message, fn.sum("m", "neigh")) rst = (1 + self.eps) * feat_dst + graph.dstdata["neigh"] if self.apply_func is not None: rst = self.apply_func(rst) return rst