dgl.DGLGraph.pull
DGLGraph.pull(v, message_func='default', reduce_func='default', apply_node_func='default', inplace=False)
Pull messages from the node(s)’ predecessors and then update their features.
Optionally, apply a function to update the node features after the messages are received.
- reduce_func is skipped for nodes with no incoming messages.
- If no node in v has an incoming message, the call degrades to an apply_nodes().
- If some nodes in v have no incoming message, their new feature values are calculated by the column initializer (see set_n_initializer()). The feature shapes and dtypes are inferred.
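The order of operations described above can be illustrated with a toy model in plain Python. This is a minimal sketch, not DGL's implementation: the function `pull` below, its list-based feature storage, and the `(src, dst)` edge list are assumptions made for illustration, and the column-initializer case is not modelled.

```python
from collections import defaultdict

def pull(features, edges, v, message_func, reduce_func, apply_node_func=None):
    """Toy model of a pull step (hypothetical helper, not the DGL API).

    features: list of scalar features indexed by node ID
    edges:    list of (src, dst) pairs
    v:        a node ID or an iterable of node IDs to update
    Returns a new feature list; the input list is left untouched.
    """
    targets = [v] if isinstance(v, int) else list(v)
    target_set = set(targets)

    # 1. Build one message per in-edge of the target nodes,
    #    reading only the features as they were before the update.
    mailbox = defaultdict(list)
    for src, dst in edges:
        if dst in target_set:
            mailbox[dst].append(message_func(features[src]))

    updated = list(features)
    for node in targets:
        messages = mailbox.get(node)
        # 2. Reduce only where messages arrived; nodes with no
        #    incoming message skip the reduce step.
        if messages:
            updated[node] = reduce_func(messages)
        # 3. The optional apply function runs on every target node,
        #    so with no messages at all only this step has an effect.
        if apply_node_func is not None:
            updated[node] = apply_node_func(updated[node])
    return updated

x = [0.0, 1.0, 2.0]
copy_src = lambda h: h  # stand-in for a "copy the source feature" message

no_edges = pull(x, [], [0, 1, 2], copy_src, sum)
chain = pull(x, [(0, 1), (1, 2)], 2, copy_src, sum)
doubled = pull(x, [], [0, 1, 2], copy_src, sum, apply_node_func=lambda h: 2 * h)
```

In this model, `no_edges` equals the input because every reduce is skipped, `chain` only changes node 2 (it receives the feature of node 1), and `doubled` shows the call reducing to a pure node-wise apply when no target has an incoming message.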
Parameters:
- v (int, iterable of int, or tensor) – The node(s) to be updated.
- message_func (callable, optional) – Message function on the edges. The function should be an Edge UDF.
- reduce_func (callable, optional) – Reduce function on the node. The function should be a Node UDF.
- apply_node_func (callable, optional) – Apply function on the nodes. The function should be a Node UDF.
- inplace (bool, optional) – If True, update will be done in place, but autograd will break.
Examples
Create a graph for demo.
Note
Here we use PyTorch syntax for demo. The general idea applies to other frameworks with minor syntax changes (e.g. replace torch.tensor with mxnet.ndarray).
>>> import dgl
>>> import torch as th
>>> g = dgl.DGLGraph()
>>> g.add_nodes(3)
>>> g.ndata['x'] = th.tensor([[0.], [1.], [2.]])
Use the built-in message function copy_src() for copying node features as the message.
>>> m_func = dgl.function.copy_src('x', 'm')
>>> g.register_message_func(m_func)
Use the built-in message reducing function sum(), which sums the received messages and replaces the old node features with the result.
>>> m_reduce_func = dgl.function.sum('m', 'x')
>>> g.register_reduce_func(m_reduce_func)
As no edges exist, nothing happens.
>>> g.pull(g.nodes())
>>> g.ndata['x']
tensor([[0.],
        [1.],
        [2.]])
Add edges 0 -> 1, 1 -> 2. Pull messages for the node \(2\).
>>> g.add_edges([0, 1], [1, 2])
>>> g.pull(2)
>>> g.ndata['x']
tensor([[0.],
        [1.],
        [1.]])
The feature of node \(2\) changes but the feature of node \(1\) remains the same as we did not pull() (and reduce) messages for it.
See also