ClusterGCNSampler

class dgl.dataloading.ClusterGCNSampler(g, k, cache_path='cluster_gcn.pkl', balance_ntypes=None, balance_edges=False, mode='k-way', prefetch_ndata=None, prefetch_edata=None, output_device=None)

Bases: dgl.dataloading.base.Sampler

Cluster sampler from Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks

This sampler first partitions the graph with METIS partitioning, then it caches the nodes of each partition to the file at the given cache path.

The sampler then selects graph partitions according to the provided partition IDs, takes the union of all nodes in those partitions, and returns the induced subgraph in its sample method.
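For illustration, a minimal sketch of calling the sampler directly, assuming a homogeneous graph on CPU (the random graph and the choice of 100 partitions here are arbitrary):

>>> import dgl
>>> import torch
>>> g = dgl.rand_graph(1000, 5000)  # homogeneous graph on CPU
>>> sampler = dgl.dataloading.ClusterGCNSampler(g, 100)
>>> # Union of the nodes in partitions 0, 1 and 2, returned as an induced subgraph
>>> subg = sampler.sample(g, torch.tensor([0, 1, 2]))

In practice the sampler is usually driven by a DataLoader instead, as shown in the Examples section below.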

Parameters
  • g (DGLGraph) – The original graph. Must be homogeneous and on CPU.

  • k (int) – The number of partitions.

  • cache_path (str) – The path to the cache file storing the partition result.

  • balance_ntypes – Passed to dgl.metis_partition_assignment().

  • balance_edges – Passed to dgl.metis_partition_assignment().

  • mode – Passed to dgl.metis_partition_assignment().

  • prefetch_ndata (list[str], optional) –

    The node data to prefetch for the subgraph (see the sketch after the parameter list).

    See guide-minibatch-prefetching for a detailed explanation of prefetching.

  • prefetch_edata (list[str], optional) –

    The edge data to prefetch for the subgraph.

    See guide-minibatch-prefetching for a detailed explanation of prefetching.

  • output_device (device, optional) – The device of the output subgraphs or MFGs. Default is the same device as that of the minibatch of partition indices.
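As a sketch of the prefetching options, assuming the graph stores node data 'feat' and 'label' and edge data 'weight' (these field names are hypothetical):

>>> sampler = dgl.dataloading.ClusterGCNSampler(
...     g, 100,
...     prefetch_ndata=['feat', 'label'],  # materialized in subg.ndata of each sampled subgraph
...     prefetch_edata=['weight'])         # materialized in subg.edata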

Examples

Node classification

With this sampler, the data loader will accept the list of partition IDs as indices to iterate over. For instance, the following code first splits the graph into 1000 partitions using METIS; at each iteration, the data loader yields a subgraph induced by the nodes of 20 randomly selected partitions.

>>> num_parts = 1000
>>> sampler = dgl.dataloading.ClusterGCNSampler(g, num_parts)
>>> dataloader = dgl.dataloading.DataLoader(
...     g, torch.arange(num_parts), sampler,
...     batch_size=20, shuffle=True, drop_last=False, num_workers=4)
>>> for subg in dataloader:
...     train_on(subg)
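Each subg above is a DGLGraph induced by the nodes of the selected partitions, so a training step can read its node data directly. A minimal sketch of the train_on function used above, assuming node data 'feat' and 'label' and a previously constructed model and opt(imizer), all hypothetical:

>>> import torch.nn.functional as F
>>> def train_on(subg):
...     feat = subg.ndata['feat']    # node features of the sampled partitions
...     label = subg.ndata['label']
...     logits = model(subg, feat)   # any DGL-compatible GNN module
...     loss = F.cross_entropy(logits, label)
...     opt.zero_grad()
...     loss.backward()
...     opt.step()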