# Source code for dgl.nn.pytorch.conv.egnnconv

"""Torch Module for E(n) Equivariant Graph Convolutional Layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import torch
import torch.nn as nn

from .... import function as fn


class EGNNConv(nn.Module):
    r"""Equivariant Graph Convolutional Layer from `E(n) Equivariant Graph
    Neural Networks <https://arxiv.org/abs/2102.09844>`__

    .. math::

        m_{ij}=\phi_e(h_i^l, h_j^l, ||x_i^l-x_j^l||^2, a_{ij})

        x_i^{l+1} = x_i^l + C\sum_{j\in\mathcal{N}(i)}(x_i^l-x_j^l)\phi_x(m_{ij})

        m_i = \sum_{j\in\mathcal{N}(i)} m_{ij}

        h_i^{l+1} = \phi_h(h_i^l, m_i)

    where :math:`h_i`, :math:`x_i`, :math:`a_{ij}` are node features, coordinate
    features, and edge features respectively. :math:`\phi_e`, :math:`\phi_h`, and
    :math:`\phi_x` are two-layer MLPs. :math:`C` is a constant for normalization,
    computed as :math:`1/|\mathcal{N}(i)|`.

    .. note::

        Messages travel from edge source to edge destination, so
        :math:`\mathcal{N}(i)` is the set of in-neighbors of node :math:`i`.
        The implementation uses the source-minus-destination difference
        :math:`x_j^l - x_i^l`, scaled to unit length (a zero difference stays
        zero), in place of :math:`x_i^l - x_j^l`. The sign flip can be absorbed
        by the bias-free output layer of :math:`\phi_x`.

    Parameters
    ----------
    in_size : int
        Input feature size; i.e. the size of :math:`h_i^l`.
    hidden_size : int
        Hidden feature size; i.e. the size of hidden layer in the two-layer
        MLPs in :math:`\phi_e, \phi_x, \phi_h`.
    out_size : int
        Output feature size; i.e. the size of :math:`h_i^{l+1}`.
    edge_feat_size : int, optional
        Edge feature size; i.e. the size of :math:`a_{ij}`. Default: 0.

    Example
    -------
    >>> import dgl
    >>> import torch as th
    >>> from dgl.nn import EGNNConv
    >>>
    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
    >>> node_feat, coord_feat, edge_feat = th.ones(6, 10), th.ones(6, 3), th.ones(6, 2)
    >>> conv = EGNNConv(10, 10, 10, 2)
    >>> h, x = conv(g, node_feat, coord_feat, edge_feat)
    """

    def __init__(self, in_size, hidden_size, out_size, edge_feat_size=0):
        super(EGNNConv, self).__init__()

        self.in_size = in_size
        self.hidden_size = hidden_size
        self.out_size = out_size
        self.edge_feat_size = edge_feat_size
        # SiLU has no parameters, so one instance can be shared by all MLPs.
        act_fn = nn.SiLU()

        # \phi_e
        self.edge_mlp = nn.Sequential(
            # +1 for the radial feature: ||x_i - x_j||^2
            nn.Linear(in_size * 2 + edge_feat_size + 1, hidden_size),
            act_fn,
            nn.Linear(hidden_size, hidden_size),
            act_fn,
        )

        # \phi_h
        self.node_mlp = nn.Sequential(
            nn.Linear(in_size + hidden_size, hidden_size),
            act_fn,
            nn.Linear(hidden_size, out_size),
        )

        # \phi_x
        self.coord_mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            act_fn,
            # No bias: the scalar multiplies a direction vector, so a bias
            # would add a feature-independent push along every edge.
            nn.Linear(hidden_size, 1, bias=False),
        )

    @staticmethod
    def _normalize_coord_diff(x_diff, radial):
        """Scale coordinate differences to unit length with finite gradients.

        Parameters
        ----------
        x_diff : torch.Tensor
            Per-edge coordinate differences.
        radial : torch.Tensor
            Per-edge squared norms of ``x_diff`` with a trailing singleton
            dimension, so it broadcasts against ``x_diff``.

        Returns
        -------
        torch.Tensor
            ``x_diff`` divided by its Euclidean norm; rows whose difference is
            zero stay zero.
        """
        # Integer input: ``sqrt`` would promote to the default float dtype
        # anyway, and ``finfo`` below needs a floating dtype.
        if not radial.is_floating_point():
            radial = radial.to(torch.get_default_dtype())
        # Clamp *before* the square root. The backward of sqrt is
        # grad / (2 * sqrt(r)), which is 0 / 0 = NaN at r == 0 (self-loops,
        # coincident nodes). Values below the clamp receive zero gradient.
        # ``tiny`` is the smallest positive normal value of the dtype, so the
        # clamp only affects zero or subnormal radials.
        norm = radial.clamp_min(torch.finfo(radial.dtype).tiny).sqrt()
        # The small additive constant keeps outputs unchanged for every
        # non-degenerate edge.
        return x_diff / (norm + 1e-30)

    def message(self, edges):
        """message function for EGNN"""
        # Edge-MLP input: [h_src, h_dst, radial(, a)] concatenated along the
        # feature dimension; edge features are appended only when configured.
        parts = [edges.src["h"], edges.dst["h"], edges.data["radial"]]
        if self.edge_feat_size > 0:
            parts.append(edges.data["a"])
        f = torch.cat(parts, dim=-1)

        msg_h = self.edge_mlp(f)
        # Scalar weight from \phi_x times the normalized coordinate difference.
        msg_x = self.coord_mlp(msg_h) * edges.data["x_diff"]

        return {"msg_x": msg_x, "msg_h": msg_h}

    def forward(self, graph, node_feat, coord_feat, edge_feat=None):
        r"""
        Description
        -----------
        Compute EGNN layer.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        node_feat : torch.Tensor
            The input feature of shape :math:`(N, h_n)`. :math:`N` is the number of
            nodes, and :math:`h_n` must be the same as in_size.
        coord_feat : torch.Tensor
            The coordinate feature of shape :math:`(N, h_x)`. :math:`N` is the
            number of nodes, and :math:`h_x` can be any positive integer.
        edge_feat : torch.Tensor, optional
            The edge feature of shape :math:`(M, h_e)`. :math:`M` is the number of
            edges, and :math:`h_e` must be the same as edge_feat_size.

        Returns
        -------
        node_feat_out : torch.Tensor
            The output node feature of shape :math:`(N, h_n')` where :math:`h_n'`
            is the same as out_size.
        coord_feat_out: torch.Tensor
            The output coordinate feature of shape :math:`(N, h_x)` where :math:`h_x`
            is the same as the input coordinate feature dimension.
        """
        with graph.local_scope():
            # node feature
            graph.ndata["h"] = node_feat
            # coordinate feature
            graph.ndata["x"] = coord_feat
            # edge feature
            if self.edge_feat_size > 0:
                assert edge_feat is not None, "Edge features must be provided."
                graph.edata["a"] = edge_feat

            # get coordinate diff & radial features (source minus destination)
            graph.apply_edges(fn.u_sub_v("x", "x", "x_diff"))
            x_diff = graph.edata["x_diff"]
            radial = x_diff.square().sum(dim=1, keepdim=True)
            graph.edata["radial"] = radial
            # normalize coordinate difference (NaN-safe for zero-length edges)
            graph.edata["x_diff"] = self._normalize_coord_diff(x_diff, radial)

            graph.apply_edges(self.message)
            # mean over in-neighbors gives C = 1/|N(i)| for the coordinate update
            graph.update_all(fn.copy_e("msg_x", "m"), fn.mean("m", "x_neigh"))
            graph.update_all(fn.copy_e("msg_h", "m"), fn.sum("m", "h_neigh"))

            h_neigh, x_neigh = graph.ndata["h_neigh"], graph.ndata["x_neigh"]

            h = self.node_mlp(torch.cat([node_feat, h_neigh], dim=-1))
            x = coord_feat + x_neigh

            return h, x