dgl.DGLGraph.filter_edges
DGLGraph.filter_edges(predicate, edges='__ALL__')
Return a tensor of edge IDs that satisfy the given predicate.
Parameters:
- predicate (callable) – A function of signature func(edges) -> tensor. edges are EdgeBatch objects, as in a user-defined function (UDF). The returned tensor should be a 1-D boolean tensor in which each element indicates whether the corresponding edge in the batch satisfies the predicate.
- edges (valid edges type) – Edges on which to apply predicate. See send() for valid edges types. Defaults to all edges.
Returns: The IDs of the edges that satisfy the predicate.
Return type: tensor
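The contract above (the predicate sees a batch of edges and returns one boolean per edge; filter_edges returns the IDs whose entry is True) can be modelled on plain Python lists. This is a minimal sketch under stated assumptions, not DGL's implementation: EdgeBatchSketch, filter_edges_sketch and the edge feature 'w' are hypothetical names, and lists stand in for framework tensors so the sketch runs without DGL or PyTorch.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Sequence, Union


# Hypothetical stand-in for DGL's EdgeBatch (not the real class): each view
# maps a feature name to a list with one entry per edge in the batch.
@dataclass
class EdgeBatchSketch:
    src: Dict[str, list]   # source-node features, one per edge
    dst: Dict[str, list]   # destination-node features, one per edge
    data: Dict[str, list]  # edge features, one per edge


def filter_edges_sketch(
    src_ids: Sequence[int],
    dst_ids: Sequence[int],
    ndata: Dict[str, list],
    edata: Dict[str, list],
    predicate: Callable[[EdgeBatchSketch], List[bool]],
    edges: Union[str, Sequence[int]] = "__ALL__",
) -> List[int]:
    """Hypothetical list-based model of filter_edges (not DGL code)."""
    # '__ALL__' selects every edge; otherwise use the given edge IDs.
    eids = list(range(len(src_ids))) if edges == "__ALL__" else list(edges)
    # Gather per-edge views of node and edge features.
    batch = EdgeBatchSketch(
        src={k: [v[src_ids[e]] for e in eids] for k, v in ndata.items()},
        dst={k: [v[dst_ids[e]] for e in eids] for k, v in ndata.items()},
        data={k: [v[e] for e in eids] for k, v in edata.items()},
    )
    mask = predicate(batch)
    # The predicate must return exactly one boolean per edge in the batch.
    if len(mask) != len(eids):
        raise ValueError("predicate must return one boolean per edge")
    # Keep the IDs whose mask entry is True.
    return [e for e, keep in zip(eids, mask) if keep]


# Example graph: edge i goes from src[i] to dst[i].
src, dst = [0, 1, 2], [2, 2, 1]
ndata = {"x": [1.0, -1.0, 1.0]}
edata = {"w": [0.5, 2.0, 3.0]}  # hypothetical edge feature


def dst_is_one(batch):
    return [v == 1 for v in batch.dst["x"]]


all_hits = filter_edges_sketch(src, dst, ndata, edata, dst_is_one)
subset_hits = filter_edges_sketch(src, dst, ndata, edata, dst_is_one, edges=[1, 2])
heavy = filter_edges_sketch(src, dst, ndata, edata, lambda b: [w > 1 for w in b.data["w"]])
print(all_hits, subset_hits, heavy)
```

Passing an explicit edges list restricts both what the predicate sees and which IDs can be returned, which is why the second call yields only [1]. In DGL itself the predicate operates on framework tensors and returns a boolean tensor; the lists here are only a stand-in.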
Examples
Construct a graph object for the demo.
Note
Here we use PyTorch syntax for the demo. The general idea applies to other frameworks with minor syntax changes (e.g. replace torch.tensor with mxnet.ndarray).
>>> import dgl
>>> import torch as th
>>> g = dgl.DGLGraph()
>>> g.add_nodes(3)
>>> g.ndata['x'] = th.tensor([[1.], [-1.], [1.]])
>>> g.add_edges([0, 1, 2], [2, 2, 1])
Define a function for filtering edges whose destinations have node feature 1.
>>> def has_dst_one(edges):
...     return (edges.dst['x'] == 1).squeeze(1)
Filter the edges whose destination nodes have feature 1.
>>> g.filter_edges(has_dst_one)
tensor([0, 1])
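To see why the example returns tensor([0, 1]), the following plain-Python trace recomputes, for each edge of the demo graph, the value that edges.dst['x'] holds and the mask that has_dst_one produces. It is an illustrative sketch only: the list names are hypothetical and no DGL or PyTorch calls are made.

```python
# Demo graph: edge i goes from src[i] to dst[i].
src = [0, 1, 2]
dst = [2, 2, 1]
x = [1.0, -1.0, 1.0]  # node feature 'x', flattened from shape (3, 1)

# What edges.dst['x'] holds: the destination node's feature for each edge.
dst_x = [x[d] for d in dst]
# The boolean mask has_dst_one would return, one entry per edge.
mask = [v == 1 for v in dst_x]
# filter_edges keeps the edge IDs where the mask is True.
kept = [eid for eid, keep in enumerate(mask) if keep]
print(dst_x, mask, kept)
```

Edges 0 and 1 both point to node 2 (feature 1), while edge 2 points to node 1 (feature -1), so only IDs 0 and 1 survive.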
See also