dgl.broadcast_nodes

dgl.broadcast_nodes(graph, graph_feat, *, ntype=None)

Generate node features by copying the graph-level feature graph_feat to every node of the corresponding graph.

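The operation is similar to numpy.repeat (or torch.repeat_interleave): row i of graph_feat is repeated once for each node of the i-th graph. It is commonly used to normalize node features by a graph-level vector. For example, to normalize the (non-negative) node features of each graph to the range \([0, 1]\) by dividing them by their per-graph sum: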

>>> g = dgl.batch([...])  # batch multiple graphs
>>> g.ndata['h'] = ...  # some node features
>>> h_sum = dgl.broadcast_nodes(g, dgl.sum_nodes(g, 'h'))
>>> g.ndata['h'] /= h_sum  # normalize by summation
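
The normalization pattern above can be reproduced without DGL as a rough sketch. It assumes the batch is described only by a hypothetical `num_nodes` list of per-graph node counts (the counts DGL reports via `bg.batch_num_nodes()`), uses `np.add.reduceat` as a stand-in for `dgl.sum_nodes`, and uses `np.repeat` as a stand-in for the broadcast step. Features are assumed non-negative and every graph is assumed to have at least one node.

```python
import numpy as np

def sum_normalize(h, num_nodes):
    """Divide each node's features by the feature sum of its own graph.

    h: (N, D) node features stacked graph by graph, as in a batched graph.
    num_nodes: per-graph node counts (hypothetical stand-in for
    bg.batch_num_nodes()); all counts are assumed positive.
    """
    num_nodes = np.asarray(num_nodes)
    # Index of each graph's first node: [0, n0, n0 + n1, ...]
    starts = np.concatenate(([0], np.cumsum(num_nodes)[:-1]))
    # Per-graph readout, shape (B, D); plays the role of dgl.sum_nodes
    h_sum = np.add.reduceat(h, starts, axis=0)
    # Broadcast step, shape (N, D); plays the role of dgl.broadcast_nodes
    h_sum_per_node = np.repeat(h_sum, num_nodes, axis=0)
    return h / h_sum_per_node

h = np.array([[1.0, 2.0],
              [3.0, 2.0],   # graph 0: nodes 0-1, column sums [4, 4]
              [1.0, 1.0],
              [1.0, 3.0],
              [2.0, 4.0]])  # graph 1: nodes 2-4, column sums [4, 8]
print(sum_normalize(h, [2, 3]))
```

After normalization, each column of each graph's features sums to 1, so every entry lies in \([0, 1]\).
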

Parameters:
  • graph (DGLGraph) – The graph.

  • graph_feat (tensor) – The feature to broadcast. Its shape is \((B, *)\), where \(B\) is the batch size (the number of graphs in graph); use \(B = 1\) for a single graph.

  • ntype (str, optional) – Node type. Can be omitted if there is only one node type.

Returns:

The node feature tensor with shape \((N, *)\), where \(N\) is the total number of nodes (of type ntype) in graph.

Return type:

Tensor
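
The shape contract \((B, *) \to (N, *)\) can be illustrated with a small NumPy model of the operation. This is a sketch, not DGL's implementation: `broadcast_nodes_sketch` and its `num_nodes` argument are hypothetical, the trailing dimensions `*` are taken to be `(2, 3)` to show that they pass through unchanged, and the `ValueError` is this sketch's own check rather than documented DGL behavior.

```python
import numpy as np

def broadcast_nodes_sketch(num_nodes, graph_feat):
    """Hypothetical NumPy model of broadcasting graph features to nodes.

    num_nodes: per-graph node counts, length B.
    graph_feat: array of shape (B, *).
    Returns an array of shape (N, *) with N = sum(num_nodes).
    """
    num_nodes = np.asarray(num_nodes)
    # Sketch-only check: one feature row per graph
    if graph_feat.shape[0] != len(num_nodes):
        raise ValueError("graph_feat must have one row per graph (B rows)")
    # Row i of graph_feat is copied num_nodes[i] times along axis 0;
    # all trailing dimensions are left untouched
    return np.repeat(graph_feat, num_nodes, axis=0)

graph_feat = np.arange(2 * 2 * 3).reshape(2, 2, 3)  # B = 2, * = (2, 3)
out = broadcast_nodes_sketch([2, 3], graph_feat)
print(out.shape)  # (5, 2, 3)
```
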

Examples

>>> import dgl
>>> import torch as th

Create two DGLGraph objects, batch them, and create one graph-level feature row per graph.

>>> g1 = dgl.graph(([0], [1]))                    # Graph 1
>>> g2 = dgl.graph(([0, 1], [1, 2]))              # Graph 2
>>> bg = dgl.batch([g1, g2])
>>> feat = th.rand(2, 5)
>>> feat
tensor([[0.4325, 0.7710, 0.5541, 0.0544, 0.9368],
        [0.2721, 0.4629, 0.7269, 0.0724, 0.1014]])

Broadcast the features to all nodes in the batched graph: feat[i] is copied to every node of the i-th graph in the batch.

>>> dgl.broadcast_nodes(bg, feat)
tensor([[0.4325, 0.7710, 0.5541, 0.0544, 0.9368],
        [0.4325, 0.7710, 0.5541, 0.0544, 0.9368],
        [0.2721, 0.4629, 0.7269, 0.0724, 0.1014],
        [0.2721, 0.4629, 0.7269, 0.0724, 0.1014],
        [0.2721, 0.4629, 0.7269, 0.0724, 0.1014]])
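
Since g1 has 2 nodes and g2 has 3, the result above matches a plain `numpy.repeat` of `feat` with repeat counts `[2, 3]`, the per-graph node counts that `bg.batch_num_nodes()` reports. The following sketch reproduces the output without DGL; the counts are hard-coded here as an assumption rather than read from the graph.

```python
import numpy as np

# Graph-level features copied from the example above (B = 2, D = 5)
feat = np.array([[0.4325, 0.7710, 0.5541, 0.0544, 0.9368],
                 [0.2721, 0.4629, 0.7269, 0.0724, 0.1014]])

# g1 has 2 nodes and g2 has 3 nodes, so the batch has 5 nodes in total
num_nodes = [2, 3]

# Row i of feat is repeated once per node of graph i
node_feat = np.repeat(feat, num_nodes, axis=0)
print(node_feat)
```
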

Broadcast a feature to all nodes of a single graph. In this case the feature tensor should have shape \((1, *)\).

>>> feat0 = th.unsqueeze(feat[0], 0)
>>> dgl.broadcast_nodes(g1, feat0)
tensor([[0.4325, 0.7710, 0.5541, 0.0544, 0.9368],
        [0.4325, 0.7710, 0.5541, 0.0544, 0.9368]])

See also

broadcast_edges