dgl.DGLGraph.prop_nodes
DGLGraph.prop_nodes(nodes_generator, message_func='default', reduce_func='default', apply_node_func='default')

Propagate messages using graph traversal by triggering pull() on nodes.

The traversal order is specified by nodes_generator. It generates node frontiers, each of which is a list or a tensor of node IDs. Nodes in the same frontier are triggered together, while different frontiers are triggered one after another, in the order in which they are generated.

Parameters:
- nodes_generator (iterable, each element is a list or a tensor of node IDs) – The generator of node frontiers. It specifies which nodes perform pull() at each timestep.
- message_func (callable, optional) – Message function on the edges. The function should be an Edge UDF. If not given, the message function registered on the graph is used.
- reduce_func (callable, optional) – Reduce function on the nodes. The function should be a Node UDF. If not given, the reduce function registered on the graph is used.
- apply_node_func (callable, optional) – Apply function on the nodes. The function should be a Node UDF.
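The frontier semantics described above can be sketched in plain Python. This is a minimal illustration, not DGL's implementation: it assumes scalar node features, a message function that copies the source feature and a reduce function that sums the messages (as in the example below), and the name prop_nodes_sketch is hypothetical. It is also a choice of this sketch that a node with no incoming edges keeps its feature unchanged.

```python
# Hypothetical pure-Python sketch of prop_nodes semantics (not DGL code).
# Assumes: copy-source messages and a sum reduce, scalar features.

def prop_nodes_sketch(x, edges, frontiers):
    """Return new node features after pulling along each frontier in order.

    x         -- list of node features (one number per node)
    edges     -- list of (src, dst) pairs
    frontiers -- any iterable of lists of node ids
    """
    x = list(x)  # work on a copy; the caller's list is not mutated
    for frontier in frontiers:
        # Compute every update from the features as they were *before*
        # this frontier, so all nodes in one frontier update together.
        updates = {}
        for v in frontier:
            msgs = [x[u] for (u, w) in edges if w == v]
            if msgs:  # sketch choice: no in-edges -> feature unchanged
                updates[v] = sum(msgs)
        # Commit the whole frontier at once, then move to the next one.
        for v, val in updates.items():
            x[v] = val
    return x


edges = [(0, 1), (1, 2), (1, 3), (2, 3)]
print(prop_nodes_sketch([1.0, 2.0, 3.0, 4.0], edges, [[1, 2], [3]]))
# [1.0, 1.0, 2.0, 3.0]
```

The two-phase loop (collect updates, then commit) is what makes nodes within a frontier "triggered together"; the outer loop over frontiers is what enforces the generating order.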
Examples
Create a graph for the demo.
Note
Here we use PyTorch syntax for the demo. The general idea applies to other frameworks with minor syntax changes (e.g. replace torch.tensor with mxnet.ndarray).

>>> import dgl
>>> import torch as th
>>> g = dgl.DGLGraph()
>>> g.add_nodes(4)
>>> g.ndata['x'] = th.tensor([[1.], [2.], [3.], [4.]])
>>> g.add_edges([0, 1, 1, 2], [1, 2, 3, 3])
Prepare a message function and a reduce function for the demo.

>>> def send_source(edges):
...     return {'m': edges.src['x']}
>>> g.register_message_func(send_source)
>>> def simple_reduce(nodes):
...     return {'x': nodes.mailbox['m'].sum(1)}
>>> g.register_reduce_func(simple_reduce)
First pull messages for nodes \(1, 2\) with edges 0 -> 1 and 1 -> 2, and then pull messages for node \(3\) with edges 1 -> 3 and 2 -> 3.

>>> g.prop_nodes([[1, 2], [3]])
>>> g.ndata['x']
tensor([[1.],
        [1.],
        [2.],
        [3.]])
In the first stage, we pull messages for nodes \(1, 2\). The feature of node \(1\) is replaced by that of node \(0\), i.e. 1. The feature of node \(2\) is replaced by that of node \(1\), i.e. 2. Both replacements happen simultaneously, so node \(2\) receives node \(1\)'s original feature 2 rather than its updated feature 1.
In the second stage, we pull messages for node \(3\). The feature of node \(3\) becomes the sum of the features of nodes \(1\) and \(2\), i.e. 1 + 2 = 3.
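The result above depends on how nodes are grouped into frontiers and on the order in which the frontiers arrive. The plain-Python sketch below is not DGL code; it assumes the same copy-source message and sum reduce as the example, scalar features, and the names pull_frontiers and one_node_per_frontier are hypothetical. It replays the example and two variations: splitting nodes \(1\) and \(2\) into separate frontiers lets node \(2\) see node \(1\)'s already-updated feature, and reversing the frontier order makes node \(3\) sum the original features 2 + 3.

```python
# Hypothetical sketch (not DGL code): copy-source messages, sum reduce.

def pull_frontiers(x, edges, frontiers):
    x = list(x)  # work on a copy of the features
    for frontier in frontiers:
        # Read all messages first, so one frontier updates simultaneously.
        new = {v: sum(x[u] for u, w in edges if w == v)
               for v in frontier if any(w == v for _, w in edges)}
        for v, val in new.items():
            x[v] = val
    return x


def one_node_per_frontier(nodes):
    # Hypothetical helper: a generator yielding each node as its own frontier.
    for v in nodes:
        yield [v]


x0 = [1.0, 2.0, 3.0, 4.0]
edges = [(0, 1), (1, 2), (1, 3), (2, 3)]
print(pull_frontiers(x0, edges, [[1, 2], [3]]))                     # [1.0, 1.0, 2.0, 3.0]
print(pull_frontiers(x0, edges, one_node_per_frontier([1, 2, 3])))  # [1.0, 1.0, 1.0, 2.0]
print(pull_frontiers(x0, edges, [[3], [1, 2]]))                     # [1.0, 1.0, 2.0, 5.0]
```

Since nodes_generator only has to be an iterable of frontiers, a Python generator such as one_node_per_frontier is accepted by this sketch just as a list of lists is.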