dgl.in_subgraph(g, nodes)

Return the subgraph induced on the inbound edges of all the edge types of the given nodes.

An edge-induced subgraph is equivalent to creating a new graph from the given edges while keeping the same number of nodes as the input graph. In addition to extracting the subgraph, DGL does the following:

  • Copy the features of the extracted nodes and edges to the resulting graph. The copy is lazy and incurs data movement only when needed.

  • Store the IDs of the extracted edges in the edata of the resulting graph under name dgl.EID.
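The selection rule described above can be illustrated without DGL. The following is a minimal pure-Python sketch, not DGL's implementation: the helper name `in_edge_subgraph` and the plain-list edge representation are assumptions made for illustration. It keeps every edge whose destination is one of the given nodes and returns the original edge IDs alongside, which is the role `dgl.EID` plays in the resulting graph.

```python
def in_edge_subgraph(src, dst, nodes):
    """Hypothetical helper: select the inbound edges of ``nodes``.

    ``src`` and ``dst`` are parallel lists; edge ``i`` goes from
    ``src[i]`` to ``dst[i]``. Returns the kept sources, the kept
    destinations, and the original IDs of the kept edges.
    """
    wanted = set(nodes)
    # An edge is inbound to the node set if its destination is in the set.
    eids = [eid for eid, v in enumerate(dst) if v in wanted]
    sub_src = [src[e] for e in eids]
    sub_dst = [dst[e] for e in eids]
    return sub_src, sub_dst, eids


# A 5-node cycle: 0 -> 1 -> 2 -> 3 -> 4 -> 0
src = [0, 1, 2, 3, 4]
dst = [1, 2, 3, 4, 0]

sub_src, sub_dst, eids = in_edge_subgraph(src, dst, [2, 0])
print(sub_src, sub_dst, eids)  # [1, 4] [2, 0] [1, 4]
```

Node IDs are not relabeled in this sketch, which mirrors the fact that the resulting graph has the same number of nodes as the input graph.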

If the graph is heterogeneous, DGL extracts a subgraph per relation and composes them as the resulting graph. Thus, the resulting graph has the same set of relations as the input one.
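The per-relation behavior for heterogeneous graphs can be sketched in plain Python as well. This is an illustration under stated assumptions, not DGL's code: the helper name `in_edges_per_relation` and the dictionary layout are hypothetical. For each relation, only the nodes given for that relation's destination type are used to select edges, and every relation stays in the result even when it keeps no edges.

```python
def in_edges_per_relation(relations, nodes_by_type):
    """Hypothetical helper: pick inbound edge IDs for each relation.

    ``relations`` maps (src_type, edge_type, dst_type) to a pair of
    parallel lists (src_ids, dst_ids). ``nodes_by_type`` maps a node
    type to the node IDs whose inbound edges should be kept.
    """
    kept = {}
    for (src_type, edge_type, dst_type), (_, dst) in relations.items():
        # Only the destination type of a relation decides which edges match.
        wanted = set(nodes_by_type.get(dst_type, ()))
        kept[(src_type, edge_type, dst_type)] = [
            eid for eid, v in enumerate(dst) if v in wanted
        ]
    return kept


relations = {
    ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 2, 1]),
    ('user', 'follows', 'user'): ([0, 1, 1], [1, 2, 2]),
}

kept = in_edges_per_relation(relations, {'user': [2], 'game': [2]})
for relation, eids in kept.items():
    print(relation, eids)
# ('user', 'plays', 'game') [2]
# ('user', 'follows', 'user') [1, 2]
```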

Parameters

  • g (DGLGraph) – The input graph.

  • nodes (nodes or dict[str, nodes]) –

    The nodes whose inbound edges form the subgraph. The allowed node formats are:

    • Int Tensor: Each element is an ID. The tensor must have the same device type and ID data type as the graph’s.

    • iterable[int]: Each element is an ID.

    If the graph is homogeneous, the above formats can be passed directly. Otherwise, the argument must be a dictionary whose keys are node types and whose values are the nodes in one of the above formats.


Returns

The subgraph.

Return type

DGLGraph

Notes

This function discards the batch information. To preserve it, call dgl.DGLGraph.set_batch_num_nodes() and dgl.DGLGraph.set_batch_num_edges() on the returned graph.


Examples

The following examples use the PyTorch backend.

>>> import dgl
>>> import torch

Extract a subgraph from a homogeneous graph.

>>> g = dgl.graph(([0, 1, 2, 3, 4], [1, 2, 3, 4, 0]))  # 5-node cycle
>>> g.edata['w'] = torch.arange(10).view(5, 2)
>>> sg = dgl.in_subgraph(g, [2, 0])
>>> sg
Graph(num_nodes=5, num_edges=2,
      edata_schemes={'w': Scheme(shape=(2,), dtype=torch.int64),
                     '_ID': Scheme(shape=(), dtype=torch.int64)})
>>> sg.edges()
(tensor([1, 4]), tensor([2, 0]))
>>> sg.edata[dgl.EID]  # original edge IDs
tensor([1, 4])
>>> sg.edata['w']  # also extract the features
tensor([[2, 3],
        [8, 9]])

Extract a subgraph from a heterogeneous graph.

>>> g = dgl.heterograph({
...     ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 2, 1]),
...     ('user', 'follows', 'user'): ([0, 1, 1], [1, 2, 2])})
>>> sub_g = g.in_subgraph({'user': [2], 'game': [2]})
>>> sub_g
Graph(num_nodes={'game': 3, 'user': 3},
      num_edges={('user', 'plays', 'game'): 1, ('user', 'follows', 'user'): 2},
      metagraph=[('user', 'game', 'plays'), ('user', 'user', 'follows')])

See also