dgl.DGLGraph.filter_nodes
DGLGraph.filter_nodes(predicate, nodes='__ALL__')

Return a tensor of node IDs that satisfy the given predicate.
Parameters:
- predicate (callable) – A function of signature func(nodes) -> tensor. nodes is a NodeBatch object, as in other user-defined functions (UDFs). The returned tensor should be a 1-D boolean tensor in which each element indicates whether the corresponding node in the batch satisfies the predicate.
- nodes (int, iterable or tensor of ints) – The nodes to filter on. Defaults to all nodes ('__ALL__').
Returns: The filtered nodes.
Return type: tensor
Examples
Construct a graph object for the demo.
Note
Here we use PyTorch syntax for the demo. The general idea applies to other frameworks with minor syntax changes (e.g. replace torch.tensor with mxnet.ndarray).

>>> import dgl
>>> import torch as th
>>> g = dgl.DGLGraph()
>>> g.add_nodes(3)
>>> g.ndata['x'] = th.tensor([[1.], [-1.], [1.]])
Define a function for filtering nodes whose feature value is 1.

>>> def has_feature_one(nodes):
...     return (nodes.data['x'] == 1).squeeze(1)

Filter the nodes with feature value 1.

>>> g.filter_nodes(has_feature_one)
tensor([0, 2])
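The `.squeeze(1)` in `has_feature_one` matters because `g.ndata['x']` has shape (3, 1), so the comparison yields a (3, 1) mask, while the predicate must return a 1-D mask. The PyTorch-only sketch below shows that shape change and one way a 1-D boolean mask can be turned into node IDs with `torch.nonzero`, including mapping back to graph IDs when only a subset of candidate nodes is checked. It illustrates the idea under those assumptions and is not claimed to be how DGL computes the result internally.

```python
import torch as th

# Same node features as the example above: shape (3, 1).
x = th.tensor([[1.], [-1.], [1.]])

mask_2d = x == 1              # boolean, shape (3, 1)
mask_1d = mask_2d.squeeze(1)  # boolean, shape (3,): the 1-D form the predicate must return

# Positions of True entries; with all nodes as candidates these are the node IDs.
ids = th.nonzero(mask_1d, as_tuple=False).squeeze(1)
print(mask_2d.shape, mask_1d.shape)  # torch.Size([3, 1]) torch.Size([3])
print(ids)                           # tensor([0, 2])

# With a subset of candidates, index the candidate IDs with the mask
# so the result refers to graph node IDs rather than batch positions.
candidates = th.tensor([1, 2])
sub_mask = (x[candidates] == 1).squeeze(1)  # tensor([False,  True])
print(candidates[sub_mask])                 # tensor([2])
```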