SubgraphSampler
class dgl.graphbolt.SubgraphSampler(datapipe)
Bases: torch.utils.data.datapipes.datapipe.IterDataPipe[torch.utils.data.datapipes.iter.callable.T_co]
A subgraph sampler that samples a subgraph of a larger graph, starting from a given set of nodes.
Functional name: sample_subgraph.
This class is the base class of all subgraph samplers. Any subclass of SubgraphSampler should implement the sample_subgraphs() method.
Parameters
datapipe (DataPipe) – The datapipe.
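The subclassing contract described above can be pictured with a plain-Python sketch that does not depend on DGL. `ToySubgraphSampler`, `OneHopSampler`, and the `in_neighbors` adjacency dict are hypothetical stand-ins, not GraphBolt's implementation. The sketch assumes the base class iterates over an upstream iterable and delegates each batch of seeds to `sample_subgraphs()`, which subclasses override.

```python
class ToySubgraphSampler:
    """Hypothetical stand-in for a base sampler; not the GraphBolt class."""

    def __init__(self, datapipe):
        self.datapipe = datapipe  # any iterable of seed batches

    def __iter__(self):
        # Delegate every batch of seeds to the subclass hook.
        for seeds in self.datapipe:
            yield self.sample_subgraphs(seeds)

    def sample_subgraphs(self, seeds, seeds_timestamp=None):
        raise NotImplementedError("subclasses must implement sample_subgraphs()")


class OneHopSampler(ToySubgraphSampler):
    """Returns (input nodes, [edges]) for the full one-hop in-neighborhood."""

    def __init__(self, datapipe, in_neighbors):
        super().__init__(datapipe)
        self.in_neighbors = in_neighbors  # dict: node -> list of source nodes

    def sample_subgraphs(self, seeds, seeds_timestamp=None):
        # Collect every (source, destination) edge that points at a seed.
        edges = [
            (src, dst)
            for dst in seeds
            for src in self.in_neighbors.get(dst, [])
        ]
        # Input nodes are the seeds plus every source node that was reached.
        input_nodes = sorted(set(seeds) | {src for src, _ in edges})
        return input_nodes, [edges]


in_neighbors = {0: [1, 2], 1: [2], 2: []}
results = list(OneHopSampler([[0], [1]], in_neighbors))
```

Each element of `results` mirrors the documented return shape: the input nodes first, then a list with one entry per sampled layer.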
sample_subgraphs(seeds, seeds_timestamp=None)
Sample subgraphs from the given seeds.
Any subclass of SubgraphSampler should implement this method.
Parameters
seeds (Union[torch.Tensor, Dict[str, torch.Tensor]]) – The seed nodes.
Returns
Union[torch.Tensor, Dict[str, torch.Tensor]] – The input nodes.
List[SampledSubgraph] – The sampled subgraphs.
Examples
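>>> import torch
>>> from torch.utils.data import functional_datapipe
>>> from dgl.graphbolt import SubgraphSampler
>>>
>>> @functional_datapipe("my_sample_subgraph")
>>> class MySubgraphSampler(SubgraphSampler):
>>>     def __init__(self, datapipe, graph, fanouts):
>>>         super().__init__(datapipe)
>>>         self.graph = graph
>>>         self.fanouts = fanouts
>>>
>>>     # Signature matches the base method documented above.
>>>     def sample_subgraphs(self, seeds, seeds_timestamp=None):
>>>         # Sample subgraphs from the given seeds.
>>>         subgraphs = []
>>>         subgraphs_nodes = []
>>>         # Walk the fanouts in reverse, starting from the seed layer.
>>>         for fanout in reversed(self.fanouts):
>>>             subgraph = self.graph.sample_neighbors(seeds, fanout)
>>>             # Prepend so the layer sampled last comes first in the list.
>>>             subgraphs.insert(0, subgraph)
>>>             subgraphs_nodes.append(subgraph.nodes)
>>>             # The sampled nodes become the seeds of the next layer.
>>>             seeds = subgraph.nodes
>>>         subgraphs_nodes = torch.unique(torch.cat(subgraphs_nodes))
>>>         return subgraphs_nodes, subgraphs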
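To make the layer-by-layer loop in the example easier to follow, here is a minimal sketch on a toy graph using only plain Python. `toy_sample_neighbors` and `sample_layers` are hypothetical helpers, not GraphBolt APIs. Two assumptions are made for illustration: the helper keeps the first `fanout` in-neighbors instead of sampling randomly, so the result is deterministic, and the next frontier is taken to be the current seeds plus their sampled source nodes.

```python
def toy_sample_neighbors(in_neighbors, seeds, fanout):
    """Hypothetical helper: keep at most `fanout` in-edges per seed.

    Takes the first `fanout` neighbors rather than a random subset,
    so the output is deterministic.
    """
    return [
        (src, dst)
        for dst in seeds
        for src in in_neighbors.get(dst, [])[:fanout]
    ]


def sample_layers(in_neighbors, seeds, fanouts):
    layers = []
    frontier = list(seeds)
    # Walk the fanouts in reverse, starting from the seed (output) layer.
    for fanout in reversed(fanouts):
        edges = toy_sample_neighbors(in_neighbors, frontier, fanout)
        layers.insert(0, edges)  # the layer sampled last ends up first
        # Assumption: the next frontier is the seeds plus their sources.
        frontier = sorted(set(frontier) | {src for src, _ in edges})
    return frontier, layers


# Toy graph: node -> list of in-neighbors (sources of incoming edges).
in_neighbors = {0: [1, 2], 1: [3], 2: [3, 4], 3: [], 4: []}
input_nodes, layers = sample_layers(in_neighbors, seeds=[0], fanouts=[1, 2])
```

With `fanouts=[1, 2]`, the seed layer is sampled with fanout 2 and the layer behind it with fanout 1, so `layers[0]` is the layer closest to the input nodes.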