dgl.DGLGraph.filter_nodes
DGLGraph.filter_nodes(predicate, nodes='__ALL__', ntype=None)
Return the IDs of the nodes with the given node type that satisfy the given predicate.
- Parameters
predicate (callable) – A function of signature func(nodes) -> Tensor, where nodes is a dgl.NodeBatch object. Its output should be a 1D boolean tensor in which each element indicates whether the corresponding node in the batch satisfies the predicate.
nodes (node ID(s), optional) – The node(s) to query. The allowed formats are:
  - Tensor: A 1D tensor that contains the node(s) to query. Its data type and device should match the idtype and device of the graph.
  - iterable[int]: Similar to the tensor, but stores node IDs in a sequence (e.g. list, tuple, numpy.ndarray).
By default, all nodes are considered.
ntype (str, optional) – The node type for query. If the graph has multiple node types, one must specify the argument. Otherwise, it can be omitted.
- Returns
A 1D tensor that contains the ID(s) of the node(s) that satisfy the predicate.
- Return type
Tensor
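The contract above can be illustrated with a plain-Python sketch that runs without DGL: the predicate sees the whole batch of queried nodes and returns one boolean per node, and the result keeps the original IDs of the nodes whose entry is True. `filter_ids` is a hypothetical helper, not part of DGL, and node features are simplified to single floats.

```python
from typing import Callable, List, Optional, Sequence


def filter_ids(
    features: Sequence[float],
    predicate: Callable[[List[float]], List[bool]],
    nodes: Optional[Sequence[int]] = None,
) -> List[int]:
    """Hypothetical helper mirroring the filter_nodes contract on plain lists."""
    # Default: consider every node, like nodes='__ALL__'.
    ids = list(range(len(features))) if nodes is None else list(nodes)
    # The predicate is applied once to the whole batch of queried nodes.
    mask = predicate([features[v] for v in ids])
    # Keep the original node IDs (not batch positions) where the mask is True.
    return [v for v, keep in zip(ids, mask) if keep]


def feature_is_one(batch: List[float]) -> List[bool]:
    # One boolean per node in the batch, like a 1D boolean tensor
    return [x == 1.0 for x in batch]


h = [0.0, 1.0, 1.0, 0.0]
print(filter_ids(h, feature_is_one))          # [1, 2]
print(filter_ids(h, feature_is_one, [2, 3]))  # [2]
```

Querying nodes 2 and 3 returns `[2]` rather than `[0]`, which is the point of the sketch: returned values are node IDs, not indices into the queried subset.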
Examples
The following examples use the PyTorch backend.
Define a predicate function.
Filter nodes for a homogeneous graph.
Filter only among the nodes with IDs 0 and 1.
Filter nodes for a heterogeneous graph.
>>> import dgl
>>> import torch
>>> g = dgl.heterograph({
...     ('user', 'plays', 'game'): (torch.tensor([0, 1, 1, 2]),
...                                 torch.tensor([0, 0, 1, 1]))})
>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [1.]])
>>> g.nodes['game'].data['h'] = torch.tensor([[0.], [1.]])
>>> # Filter for 'user' nodes
>>> print(g.filter_nodes(nodes_with_feature_one, ntype='user'))
tensor([1, 2])