TWIRLSConv

class dgl.nn.pytorch.conv.TWIRLSConv(input_d, output_d, hidden_d, prop_step, num_mlp_before=1, num_mlp_after=1, norm='none', precond=True, alp=0, lam=1, attention=False, tau=0.2, T=-1, p=1, use_eta=False, attn_bef=False, dropout=0.0, attn_dropout=0.0, inp_dropout=0.0)

Bases: torch.nn.modules.module.Module

Convolution combined with iteratively reweighted least squares (IRLS), from Graph Neural Networks Inspired by Classical Iterative Algorithms.
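A rough sketch of the propagation being unrolled may help. It assumes the preconditioned branch (precond=True) with no attention. Each step is then a preconditioned gradient step on the energy \(E(Y) = \|Y - F\|_F^2 + \lambda\,\mathrm{tr}(Y^\top L Y)\), where \(F\) is the output of the MLP before propagation and \(L = D - A\) is the unnormalized Laplacian. The NumPy function twirls_propagate is hypothetical and written only for illustration; it is not DGL's implementation. The step size alpha=0.5 is an arbitrary choice, not the value the library picks when alp is 0.

```python
import numpy as np

def twirls_propagate(adj, f, lam=1.0, alpha=0.5, steps=64):
    """Hypothetical sketch: unrolled preconditioned gradient descent on
    E(Y) = ||Y - F||^2 + lam * tr(Y^T L Y), with L = D - A.

    Each step computes Y <- (1 - alpha) * Y + alpha * (lam*D + I)^{-1} (lam*A Y + F).
    """
    deg = adj.sum(axis=1)                    # diagonal of the degree matrix D
    inv_precond = 1.0 / (lam * deg + 1.0)    # diagonal of (lam*D + I)^{-1}
    y = f.copy()                             # start the iteration from F
    for _ in range(steps):
        y = (1.0 - alpha) * y + alpha * inv_precond[:, None] * (lam * (adj @ y) + f)
    return y

# Undirected 4-node path graph; adding the identity mimics add_self_loop.
adj = np.array([[0, 1, 0, 0],
                [1, 0, 1, 0],
                [0, 1, 0, 1],
                [0, 0, 1, 0]], dtype=float) + np.eye(4)
f = np.random.default_rng(0).normal(size=(4, 2))   # stand-in for the MLP output F

y = twirls_propagate(adj, f, lam=1.0, steps=500)

# Closed-form minimizer of E: solve (I + lam * L) Y = F directly.
lap = np.diag(adj.sum(axis=1)) - adj
exact = np.linalg.solve(np.eye(4) + 1.0 * lap, f)
print(np.allclose(y, exact))   # True
```

At a fixed point, \((\lambda D + I)Y = \lambda A Y + F\). This rearranges to \((I + \lambda L)Y = F\), which is exactly the minimizer of \(E\), so the unrolled steps approach the closed-form solution as prop_step grows.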

Parameters
  • input_d (int) – Number of input features.

  • output_d (int) – Number of output features.

  • hidden_d (int) – Size of hidden layers.

  • prop_step (int) – Number of propagation steps.

  • num_mlp_before (int) – Number of MLP layers before propagation. Default: 1.

  • num_mlp_after (int) – Number of MLP layers after propagation. Default: 1.

  • norm (str) – The type of normalization layer inside the MLP layers. Can be 'batch', 'layer' or 'none'. Default: 'none'.

  • precond (bool) – If True, use preconditioning and the unnormalized Laplacian; otherwise, use the normalized Laplacian without preconditioning. Default: True.

  • alp (float) – The \(\alpha\) in the paper. If equal to \(0\), it is determined automatically from the other hyperparameters. Default: 0.

  • lam (float) – The \(\lambda\) in the paper. Default: 1.

  • attention (bool) – If True, add an attention layer inside propagations. Default: False.

  • tau (float) – The \(\tau\) in the paper. Default: 0.2.

  • T (float) – The \(T\) in the paper. If negative, \(T\) is set to \(\infty\). Default: -1.

  • p (float) – The \(p\) in the paper. Default: 1.

  • use_eta (bool) – If True, add a learnable weight on each dimension in attention. Default: False.

  • attn_bef (bool) – If True, add another attention layer before propagation. Default: False.

  • dropout (float) – The dropout rate in the MLP layers. Default: 0.0.

  • attn_dropout (float) – The dropout rate of attention values. Default: 0.0.

  • inp_dropout (float) – The dropout rate on input features. Default: 0.0.
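The parameters p, tau and T control how the attention layer reweights edges. The sketch below is a generic IRLS illustration, not DGL's exact formula. It assumes a truncated \(\ell_p\) edge penalty \(\rho(d) = d^p\) for \(d \le T\), held constant beyond \(T\), on the distance \(d = \|y_i - y_j\|\). IRLS replaces this penalty with a weighted quadratic whose edge weight is the derivative of \(\rho\) with respect to \(d^2\), namely \((p/2)\,d^{p-2}\). Here the distance is floored at \(\tau\) so that \(p < 2\) never divides by zero. The function irls_edge_weights is hypothetical. TWIRLSConv may place \(\tau\) and \(T\) differently and may also apply the learnable per-dimension weights enabled by use_eta.

```python
import numpy as np

def irls_edge_weights(y, src, dst, p=1.0, tau=0.2, T=np.inf):
    """Hypothetical IRLS weights for a truncated l_p edge penalty.

    rho(d) = d**p for d <= T and constant beyond T, with d = ||y_src - y_dst||.
    The weight is d rho / d(d^2) = (p / 2) * d**(p - 2).
    """
    d = np.linalg.norm(y[src] - y[dst], axis=1)   # per-edge distance
    d_floor = np.maximum(d, tau)                  # floor at tau so p < 2 cannot divide by zero
    w = (p / 2.0) * d_floor ** (p - 2.0)
    w[d > T] = 0.0                                # truncated edges get zero weight
    return w

y = np.array([[0.0, 0.0], [3.0, 4.0], [0.1, 0.0]])   # node features
src, dst = np.array([0, 0]), np.array([1, 2])         # edges 0->1 (d=5) and 0->2 (d=0.1)

w_full = irls_edge_weights(y, src, dst)          # weights 0.1 and 2.5
w_cut = irls_edge_weights(y, src, dst, T=4.0)    # weights 0.0 and 2.5
```

With p = 1, the distant pair (d = 5) receives a small weight and the near-duplicate pair a large one. Propagation therefore smooths less across edges whose endpoints already disagree, and with T = 4 the long edge is dropped entirely.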

Note

add_self_loop will be automatically called before propagation.

Example

>>> import dgl
>>> from dgl.nn import TWIRLSConv
>>> import torch as th
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> feat = th.ones(6, 10)
>>> conv = TWIRLSConv(10, 2, 128, prop_step=64)
>>> res = conv(g, feat)
>>> res.size()
torch.Size([6, 2])
forward(graph, feat)

Run TWIRLS forward.

Parameters
  • graph (DGLGraph) – The graph.

  • feat (torch.Tensor) – The initial node features.

Returns

The output features.

Return type

torch.Tensor

Note

  • Input shape: \((N, \text{input\_d})\) where \(N\) is the number of nodes.

  • Output shape: \((N, \text{output\_d})\).