SetTransformerDecoder

class dgl.nn.pytorch.glob.SetTransformerDecoder(d_model, num_heads, d_head, d_ff, n_layers, k, dropouth=0.0, dropouta=0.0)

Bases: torch.nn.modules.module.Module

The decoder module from Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks. It pools the node features of each graph into a fixed-size graph-level representation.

Parameters
  • d_model (int) – Hidden size of the model.

  • num_heads (int) – The number of attention heads.

  • d_head (int) – Hidden size of each head.

  • d_ff (int) – Hidden size of the FFN (position-wise feed-forward network) layer.

  • n_layers (int) – The number of SAB (Set Attention Block) layers applied after the PMA layer.

  • k (int) – The number of seed vectors in PMA (Pooling by Multihead Attention) layer.

  • dropouth (float) – Dropout rate of each sublayer.

  • dropouta (float) – Dropout rate of attention heads.
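PMA is what turns a variable number of nodes into a fixed output size: each of the k seed vectors attends over all node features of a graph, so any graph is pooled into exactly k vectors of size d_model. The NumPy sketch below illustrates this for one graph with several heads, assuming that queries, keys, and values are projected to num_heads × d_head dimensions and the concatenated heads are projected back to d_model (consistent with the example's d_model=5, num_heads=4, d_head=4, since 5 cannot be split evenly into 4 heads). It omits the feed-forward sublayer, layer normalization, and dropout, uses random matrices in place of learned weights, and the names `pma_pool`, `softmax`, and `w_q`/`w_k`/`w_v`/`w_o` are hypothetical, not part of the DGL API.

```python
import numpy as np

def softmax(scores, axis=-1):
    # Subtract the max before exponentiating for numerical stability.
    shifted = scores - scores.max(axis=axis, keepdims=True)
    weights = np.exp(shifted)
    return weights / weights.sum(axis=axis, keepdims=True)

def pma_pool(node_feats, seeds, w_q, w_k, w_v, w_o, num_heads, d_head):
    """Simplified multi-head PMA for ONE graph (hypothetical helper).

    node_feats: (N, d_model); seeds: (k, d_model)
    Assumption: w_q, w_k, w_v are (d_model, num_heads * d_head) and
    w_o is (num_heads * d_head, d_model).
    Returns (k, d_model): one pooled vector per seed.
    """
    k, n = seeds.shape[0], node_feats.shape[0]
    # Queries come from the seeds, keys/values from the nodes; split into heads.
    q = (seeds @ w_q).reshape(k, num_heads, d_head).transpose(1, 0, 2)          # (heads, k, d_head)
    keys = (node_feats @ w_k).reshape(n, num_heads, d_head).transpose(1, 0, 2)  # (heads, N, d_head)
    vals = (node_feats @ w_v).reshape(n, num_heads, d_head).transpose(1, 0, 2)  # (heads, N, d_head)
    # Per head, each seed's attention weights over the N nodes sum to 1.
    attn = softmax(q @ keys.transpose(0, 2, 1) / np.sqrt(d_head), axis=-1)     # (heads, k, N)
    heads = attn @ vals                                                        # (heads, k, d_head)
    # Concatenate the heads and project back to d_model.
    concat = heads.transpose(1, 0, 2).reshape(k, num_heads * d_head)
    return concat @ w_o

rng = np.random.default_rng(0)
d_model, num_heads, d_head, k = 5, 4, 4, 3   # same sizes as the example below
seeds = rng.standard_normal((k, d_model))    # learnable in the real layer
w_q, w_k, w_v = (rng.standard_normal((d_model, num_heads * d_head)) for _ in range(3))
w_o = rng.standard_normal((num_heads * d_head, d_model))

x = rng.standard_normal((4, d_model))        # a graph with 4 nodes
pooled = pma_pool(x, seeds, w_q, w_k, w_v, w_o, num_heads, d_head)
print(pooled.shape)  # (3, 5)

# Reordering the nodes leaves the pooled output unchanged.
perm = rng.permutation(4)
print(np.allclose(pooled, pma_pool(x[perm], seeds, w_q, w_k, w_v, w_o, num_heads, d_head)))  # True
```

The final check shows why attention pooling suits graph readout: node order carries no meaning, and the pooled result does not depend on it.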

Examples

>>> import dgl
>>> import torch as th
>>> from dgl.nn import SetTransformerDecoder
>>>
>>> g1 = dgl.rand_graph(3, 4)  # g1 is a random graph with 3 nodes and 4 edges
>>> g1_node_feats = th.rand(3, 5)  # feature size is 5
>>> g1_node_feats
tensor([[0.8948, 0.0699, 0.9137, 0.7567, 0.3637],
        [0.8137, 0.8938, 0.8377, 0.4249, 0.6118],
        [0.5197, 0.9030, 0.6825, 0.5725, 0.4755]])
>>>
>>> g2 = dgl.rand_graph(4, 6)  # g2 is a random graph with 4 nodes and 6 edges
>>> g2_node_feats = th.rand(4, 5)  # feature size is 5
>>> g2_node_feats
tensor([[0.2053, 0.2426, 0.4111, 0.9028, 0.5658],
        [0.5278, 0.6365, 0.9990, 0.2351, 0.8945],
        [0.3134, 0.0580, 0.4349, 0.7949, 0.3891],
        [0.0142, 0.2709, 0.3330, 0.8521, 0.6925]])
>>>
>>> set_trans_dec = SetTransformerDecoder(5, 4, 4, 20, 1, 3)  # d_model=5, num_heads=4, d_head=4, d_ff=20, n_layers=1, k=3

Case 1: Input a single graph

>>> set_trans_dec(g1, g1_node_feats)
tensor([[-0.5538,  1.8726, -1.0470,  0.0276, -0.2994, -0.6317,  1.6754, -1.3189,
          0.2291,  0.0461, -0.4042,  0.8387, -1.7091,  1.0845,  0.1902]],
       grad_fn=<ViewBackward>)

Case 2: Input a batch of graphs

Build a batch of DGL graphs and concatenate all graphs’ node features into one tensor.

>>> batch_g = dgl.batch([g1, g2])
>>> batch_f = th.cat([g1_node_feats, g2_node_feats])
>>>
>>> set_trans_dec(batch_g, batch_f)
tensor([[-0.5538,  1.8726, -1.0470,  0.0276, -0.2994, -0.6317,  1.6754, -1.3189,
          0.2291,  0.0461, -0.4042,  0.8387, -1.7091,  1.0845,  0.1902],
        [-0.5511,  1.8869, -1.0156,  0.0028, -0.3231, -0.6305,  1.6845, -1.3105,
          0.2136,  0.0428, -0.3820,  0.8043, -1.7138,  1.1126,  0.1789]],
       grad_fn=<ViewBackward>)
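In Case 2 the first output row matches the Case 1 output because each graph in the batch is pooled only over its own nodes. Conceptually, the concatenated feature tensor is split back into per-graph segments using the per-graph node counts (in DGL these are available from `batch_g.batch_num_nodes()`, here 3 and 4). The NumPy sketch below illustrates that segmentation; `pool_batch` is a hypothetical name, and mean pooling is a stand-in for the attention pooling the decoder actually performs.

```python
import numpy as np

def pool_batch(batch_feats, num_nodes, pool_fn):
    """Split concatenated node features by per-graph node counts and pool each graph on its own."""
    # Split points are the running totals of the node counts, minus the last one.
    offsets = np.cumsum(num_nodes)[:-1]
    segments = np.split(batch_feats, offsets, axis=0)
    return np.stack([pool_fn(seg) for seg in segments])

def mean_pool(seg):
    # Hypothetical stand-in for attention pooling: average over the graph's nodes.
    return seg.mean(axis=0)

rng = np.random.default_rng(0)
g1_feats = rng.random((3, 5))                       # graph 1: 3 nodes, feature size 5
g2_feats = rng.random((4, 5))                       # graph 2: 4 nodes, feature size 5
batch_feats = np.concatenate([g1_feats, g2_feats])  # (7, 5), like th.cat above

out = pool_batch(batch_feats, [3, 4], mean_pool)
print(out.shape)  # (2, 5)

# Row 0 of the batched result equals pooling graph 1 by itself.
print(np.allclose(out[0], mean_pool(g1_feats)))  # True
```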

forward(graph, feat)

Compute the decoder part of Set Transformer.
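After PMA has pooled each graph into k vectors, the n_layers SAB layers let those k vectors attend to one another without changing their shape. The NumPy sketch below illustrates that stage for one graph under heavy simplifications (single head, residual connection only, no learned projections, feed-forward sublayer, layer normalization, or dropout); `self_attention_block` and `refine_pooled` are hypothetical names, not DGL functions.

```python
import numpy as np

def self_attention_block(h):
    """Simplified SAB on the k pooled vectors of one graph: (k, d_model) -> (k, d_model).

    Single head with a residual connection; no learned projections,
    feed-forward sublayer, layer normalization, or dropout.
    """
    scores = h @ h.T / np.sqrt(h.shape[1])          # (k, k): every vector scores every other
    scores -= scores.max(axis=1, keepdims=True)     # subtract row max for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)   # each row of weights sums to 1
    return h + weights @ h                          # residual connection

def refine_pooled(pooled, n_layers):
    # Apply n_layers self-attention blocks; the (k, d_model) shape never changes.
    h = pooled
    for _ in range(n_layers):
        h = self_attention_block(h)
    return h

rng = np.random.default_rng(0)
pooled = rng.standard_normal((3, 5))   # k = 3 pooled vectors with d_model = 5
out = refine_pooled(pooled, n_layers=1)
print(out.shape)  # (3, 5)
```

Because these blocks only mix the k pooled vectors, their input size is fixed regardless of how many nodes the graph had.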

Parameters
  • graph (DGLGraph) – The input graph, or a batch of graphs created with dgl.batch.

  • feat (torch.Tensor) – The input node features with shape \((N, D)\), where \(N\) is the total number of nodes in the graph (or batch) and \(D\) is the feature size, which must equal d_model.

Returns

The output feature with shape \((B, k \times D)\), where \(B\) is the batch size (1 for a single graph) and the \(k\) pooled vectors of each graph are concatenated into one row.
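The \(k \times D\) width comes from flattening: each graph yields k pooled vectors of size D, which are laid out side by side in one row (15 = 3 × 5 columns in the example above). The NumPy sketch below shows only that row-major reshape on placeholder values; it is not the decoder's actual computation.

```python
import numpy as np

batch_size, k, d_model = 2, 3, 5
# Placeholder per-seed outputs: one (k, d_model) block per graph.
per_seed = np.arange(batch_size * k * d_model, dtype=float).reshape(batch_size, k, d_model)

# Flatten the k seed outputs of each graph into one row.
flat = per_seed.reshape(batch_size, k * d_model)
print(flat.shape)  # (2, 15)

# Columns j*d_model:(j+1)*d_model of a row hold the output of seed j.
j = 1
print(np.array_equal(flat[0, j * d_model:(j + 1) * d_model], per_seed[0, j]))  # True
```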

Return type

torch.Tensor