dgl.DGLGraph.filter_edges

DGLGraph.filter_edges(predicate, edges='__ALL__')

Return a tensor of edge IDs that satisfy the given predicate.

Parameters:
  • predicate (callable) – A function with signature func(edges) -> tensor, where edges is an EdgeBatch object, as in a user-defined function (UDF). It should return a 1-D boolean tensor in which each element indicates whether the corresponding edge in the batch satisfies the predicate.
  • edges (valid edges type) – Edges on which to apply predicate. See send() for the valid edges types. Defaults to all edges.
Returns:
  The filtered edges, represented by their IDs.
Return type:
  tensor

Examples

Note

This example uses PyTorch syntax. The same idea applies to other frameworks with minor syntax changes (e.g. replace torch.tensor with mxnet.ndarray).

Construct a graph object for the demo.

>>> import dgl
>>> import torch as th
>>> g = dgl.DGLGraph()
>>> g.add_nodes(3)
>>> g.ndata['x'] = th.tensor([[1.], [-1.], [1.]])
>>> g.add_edges([0, 1, 2], [2, 2, 1])

Define a predicate that selects edges whose destination nodes have feature value 1.

>>> def has_dst_one(edges): return (edges.dst['x'] == 1).squeeze(1)

Apply the predicate to find the edges whose destination nodes have feature value 1.

>>> g.filter_edges(has_dst_one)
tensor([0, 1])
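
The masking behaviour described above can be illustrated without any framework. The following is a minimal sketch in plain Python, not DGL's implementation: the helper names (filter_edge_ids, dst_feature_is_one) and the dictionary standing in for an EdgeBatch are hypothetical, and the sketch assumes that when edges is restricted to a subset, the returned values are the original edge IDs rather than positions within the subset.

```python
# Framework-free sketch of the filtering semantics; NOT DGL's implementation.
# All names here (filter_edge_ids, dst_feature_is_one, batch) are hypothetical.

def filter_edge_ids(src, dst, node_feat, predicate, edge_ids=None):
    """Return the IDs of the edges in `edge_ids` for which `predicate` holds."""
    if edge_ids is None:
        # Default: consider every edge in the graph.
        edge_ids = list(range(len(src)))
    # Build the "batch": endpoint features, listed in the order of `edge_ids`.
    batch = {
        "src": [node_feat[src[e]] for e in edge_ids],
        "dst": [node_feat[dst[e]] for e in edge_ids],
    }
    # The predicate returns one boolean per edge in the batch.
    mask = predicate(batch)
    if len(mask) != len(edge_ids):
        raise ValueError("predicate must return one boolean per edge in the batch")
    # Keep the IDs whose mask entry is true.
    return [e for e, keep in zip(edge_ids, mask) if keep]


def dst_feature_is_one(batch):
    """Plain-Python analogue of has_dst_one."""
    return [x == 1.0 for x in batch["dst"]]


# Same graph as above: edges 0->2, 1->2, 2->1 with node features [1, -1, 1].
src, dst = [0, 1, 2], [2, 2, 1]
node_feat = [1.0, -1.0, 1.0]

all_matches = filter_edge_ids(src, dst, node_feat, dst_feature_is_one)
subset_matches = filter_edge_ids(src, dst, node_feat, dst_feature_is_one,
                                 edge_ids=[1, 2])
print(all_matches)     # [0, 1]
print(subset_matches)  # [1]
```

With every edge considered, the sketch reproduces the result above. Restricting the candidates to edges 1 and 2 leaves only edge 1, since edge 2 points to node 1, whose feature is -1.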

See also

filter_nodes()