dgl.out_subgraph

dgl.out_subgraph(g, nodes)
Return the subgraph induced on the outbound edges of all the edge types of the given nodes.
An edge-induced subgraph is equivalent to creating a new graph with the same number of nodes using the given edges. In addition to extracting the subgraph, DGL does the following:
Copy the features of the extracted nodes and edges to the resulting graph. The copy is lazy and incurs data movement only when needed.
Store the IDs of the extracted edges in the edata of the resulting graph under name dgl.EID.
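A pure-Python sketch of the steps above for a homogeneous graph, using plain lists instead of tensors. The function name `out_edge_subgraph` and its dict-based return value are hypothetical, not part of DGL; the `"_ID"` key mirrors the `'_ID'` entry shown in the edata schema of the homogeneous example. Unlike DGL, this sketch copies features eagerly, and it orders the extracted edges by the order of the query nodes, which DGL is not guaranteed to do in every version.

```python
def out_edge_subgraph(num_nodes, src, dst, edge_feats, nodes):
    """Hypothetical pure-Python model of an out-edge-induced subgraph.

    num_nodes: number of nodes in the input graph (kept unchanged).
    src, dst: parallel lists; edge i goes from src[i] to dst[i].
    edge_feats: dict mapping a feature name to a per-edge list.
    nodes: iterable of node IDs whose outbound edges are extracted.
    """
    # Collect the IDs of every edge leaving one of the given nodes,
    # grouped by the order in which the nodes are listed.
    eids = [e for n in nodes for e in range(len(src)) if src[e] == n]
    sub = {
        "num_nodes": num_nodes,  # same node set as the input graph
        "src": [src[e] for e in eids],
        "dst": [dst[e] for e in eids],
        # Gather each edge feature by original edge ID (DGL does this lazily).
        "edata": {name: [vals[e] for e in eids] for name, vals in edge_feats.items()},
    }
    # Record the original edge IDs, as DGL does under dgl.EID.
    sub["edata"]["_ID"] = eids
    return sub


# The 5-node cycle from the example, with 'w' = arange(10) reshaped to (5, 2).
sg = out_edge_subgraph(
    5, [0, 1, 2, 3, 4], [1, 2, 3, 4, 0],
    {"w": [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]}, [2, 0],
)
print(sg["src"], sg["dst"], sg["edata"]["_ID"], sg["edata"]["w"])
```

Keeping `num_nodes` unchanged is what makes this edge-induced rather than node-induced: no node is dropped or relabeled, so original node IDs remain valid in the subgraph.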
If the graph is heterogeneous, DGL extracts a subgraph per relation and composes them as the resulting graph. Thus, the resulting graph has the same set of relations as the input graph.
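The per-relation behavior can be sketched in pure Python as follows. The name `hetero_out_subgraph` and the `{(src_type, etype, dst_type): (src_list, dst_list)}` input layout are assumptions made for illustration, not DGL's internal representation. The point it shows: every relation of the input appears in the result, and a relation whose source node type has no query nodes simply keeps zero edges.

```python
def hetero_out_subgraph(relations, nodes):
    """Hypothetical model of per-relation extraction on a heterogeneous graph.

    relations: dict mapping (src_type, etype, dst_type) to (src_list, dst_list).
    nodes: dict mapping a node type to the node IDs whose out-edges are kept.
    Returns a dict with the same keys, each mapped to (src, dst, original_eids).
    """
    result = {}
    for canonical_etype, (src, dst) in relations.items():
        src_type = canonical_etype[0]
        # Node types missing from `nodes` contribute no query nodes, so the
        # relation is still present in the result but with zero edges.
        query = nodes.get(src_type, [])
        eids = [e for n in query for e in range(len(src)) if src[e] == n]
        result[canonical_etype] = ([src[e] for e in eids], [dst[e] for e in eids], eids)
    return result


relations = {
    ("user", "plays", "game"): ([0, 1, 1, 2], [0, 0, 2, 1]),
    ("user", "follows", "user"): ([0, 1, 1], [1, 2, 2]),
}
print(hetero_out_subgraph(relations, {"user": [1]}))
```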
- Parameters
g (DGLGraph) – The input graph.
nodes (nodes or dict[str, nodes]) –
The nodes to form the subgraph. The allowed node formats are:
Int Tensor: Each element is a node ID. The tensor must have the same device type and ID data type as the graph’s.
iterable[int]: Each element is a node ID.
If the graph is homogeneous, one can directly pass the above formats. Otherwise, the argument must be a dictionary with keys being node types and values being the nodes.
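The rule for the `nodes` argument can be sketched as a normalization step that always produces a `{node_type: [ids]}` dictionary. The helper `normalize_nodes` is hypothetical and handles only plain iterables of integers; it is not DGL's own argument parser. `"_N"` is used below as the single node-type name of a homogeneous graph.

```python
def normalize_nodes(ntypes, nodes):
    """Hypothetical helper: turn the `nodes` argument into {ntype: [ids]}.

    ntypes: list of node type names of the graph.
    nodes: either an iterable of node IDs or a dict {ntype: iterable of IDs}.
    """
    if isinstance(nodes, dict):
        unknown = set(nodes) - set(ntypes)
        if unknown:
            raise KeyError(f"unknown node types: {sorted(unknown)}")
        return {nt: [int(n) for n in ids] for nt, ids in nodes.items()}
    # A bare iterable is only unambiguous when the graph has one node type.
    if len(ntypes) != 1:
        raise TypeError("heterogeneous graphs need a dict keyed by node type")
    return {ntypes[0]: [int(n) for n in nodes]}


print(normalize_nodes(["_N"], [2, 0]))                   # homogeneous graph
print(normalize_nodes(["game", "user"], {"user": [1]}))  # heterogeneous graph
```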
- Returns
The subgraph.
- Return type
DGLGraph

Notes
This function discards the batch information. Please use dgl.DGLGraph.set_batch_num_nodes() and dgl.DGLGraph.set_batch_num_edges() on the transformed graph to maintain the information.

Examples
The following example uses PyTorch backend.
>>> import dgl
>>> import torch
Extract a subgraph from a homogeneous graph.
>>> g = dgl.graph(([0, 1, 2, 3, 4], [1, 2, 3, 4, 0]))  # 5-node cycle
>>> g.edata['w'] = torch.arange(10).view(5, 2)
>>> sg = dgl.out_subgraph(g, [2, 0])
>>> sg
Graph(num_nodes=5, num_edges=2,
      ndata_schemes={}
      edata_schemes={'w': Scheme(shape=(2,), dtype=torch.int64), '_ID': Scheme(shape=(), dtype=torch.int64)})
>>> sg.edges()
(tensor([2, 0]), tensor([3, 1]))
>>> sg.edata[dgl.EID]  # original edge IDs
tensor([2, 0])
>>> sg.edata['w']  # also extract the features
tensor([[4, 5],
        [0, 1]])
Extract a subgraph from a heterogeneous graph.
>>> g = dgl.heterograph({
...     ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 2, 1]),
...     ('user', 'follows', 'user'): ([0, 1, 1], [1, 2, 2])})
>>> sub_g = g.out_subgraph({'user': [1]})
>>> sub_g
Graph(num_nodes={'game': 3, 'user': 3},
      num_edges={('user', 'plays', 'game'): 2, ('user', 'follows', 'user'): 2},
      metagraph=[('user', 'game', 'plays'), ('user', 'user', 'follows')])
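Regarding the note on batch information: because the subgraph keeps every node, the per-graph node counts of a batched input carry over unchanged, but the per-graph edge counts must be recomputed. Below is a pure-Python sketch for a homogeneous batched graph. It assumes the subgraph's edges are grouped by member graph in edge-ID order (for example, query nodes given in ascending order) and raises otherwise; `edges_per_graph` is a hypothetical helper, not a DGL function.

```python
from bisect import bisect_right
from itertools import accumulate


def edges_per_graph(src_ids, batch_num_nodes):
    """Hypothetical helper: count subgraph edges per member graph of a batch.

    src_ids: source node IDs of the subgraph's edges, in edge-ID order.
    batch_num_nodes: node counts of the batched member graphs.
    """
    # Exclusive upper bound of each member graph's node-ID range.
    bounds = list(accumulate(batch_num_nodes))
    counts = [0] * len(batch_num_nodes)
    last = 0
    for s in src_ids:
        # Edges never cross member graphs, so the source node decides the graph.
        g = bisect_right(bounds, s)
        if g < last:
            raise ValueError("edges are not grouped by member graph")
        last = g
        counts[g] += 1
    return counts


# Two member graphs with 3 and 2 nodes; edges leave nodes 0, 2, 3, 4.
print(edges_per_graph([0, 2, 3, 4], [3, 2]))
```

With PyTorch, the counts could then be applied as `sg.set_batch_num_nodes(bg.batch_num_nodes())` and `sg.set_batch_num_edges(torch.tensor(counts))`, where `bg` is the batched input and `counts = edges_per_graph(sg.edges()[0].tolist(), bg.batch_num_nodes().tolist())`.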
See also