GraphDataLoader
class dgl.dataloading.GraphDataLoader(dataset, collate_fn=None, use_ddp=False, ddp_seed=0, **kwargs)

Bases: Generic[torch.utils.data.dataloader.T_co]

Batched graph data loader.
PyTorch dataloader for batch-iterating over a set of graphs. For each minibatch it yields the batched graph and, if provided, the corresponding label tensor.
- Parameters
  - dataset (torch.utils.data.Dataset) – The dataset to load graphs from.
  - collate_fn (Function, default is None) – The customized collate function. The default collate function is used if none is given.
  - use_ddp (boolean, optional) – If True, tells the DataLoader to split the training set appropriately across the participating processes using torch.utils.data.distributed.DistributedSampler. Overrides the sampler argument of torch.utils.data.DataLoader.
  - ddp_seed (int, optional) – The seed for shuffling the dataset in torch.utils.data.distributed.DistributedSampler. Only effective when use_ddp is True.
  - kwargs (dict) – Keyword arguments to be passed to the parent PyTorch torch.utils.data.DataLoader class. Common arguments are:
    - batch_size (int): The number of indices (samples) in each batch.
    - drop_last (bool): Whether to drop the last incomplete batch.
    - shuffle (bool): Whether to randomly shuffle the indices at each epoch.
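How batch_size, drop_last and shuffle interact can be sketched in plain Python. This is a simplified model, not PyTorch's sampler code: batch_indices is a hypothetical helper, and a seeded random.Random stands in for PyTorch's random generator, so shuffled orders will not match what PyTorch produces.

```python
import random
from typing import List, Optional


def batch_indices(n: int, batch_size: int, shuffle: bool = False,
                  drop_last: bool = False,
                  seed: Optional[int] = None) -> List[List[int]]:
    """Split dataset indices 0..n-1 into minibatches (simplified model)."""
    indices = list(range(n))
    if shuffle:
        # Assumption: random.Random replaces PyTorch's generator here,
        # so the resulting permutation differs from PyTorch's.
        random.Random(seed).shuffle(indices)
    # Consecutive chunks of batch_size indices; the last may be shorter.
    batches = [indices[i:i + batch_size] for i in range(0, n, batch_size)]
    if drop_last and batches and len(batches[-1]) < batch_size:
        # Discard the trailing batch that is smaller than batch_size.
        batches.pop()
    return batches


print(batch_indices(10, 4))                  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(batch_indices(10, 4, drop_last=True))  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```

With 10 graphs and batch_size=4, drop_last=False keeps a final batch of 2 graphs, while drop_last=True discards it.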
Examples
To train a GNN for graph classification on a set of graphs in dataset:

>>> import dgl
>>> dataloader = dgl.dataloading.GraphDataLoader(
...     dataset, batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for batched_graph, labels in dataloader:
...     train_on(batched_graph, labels)  # train_on: user-defined training step
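Conceptually, each batched_graph merges the minibatch's graphs into one larger graph whose components stay disconnected; in DGL, dgl.batch performs this merge on DGLGraph objects. The pure-Python sketch below only illustrates the idea. It assumes a hypothetical toy graph representation (a node count plus an edge list) and hypothetical helpers batch_graphs and collate; it is not DGL's implementation, and a real collate_fn would operate on DGLGraph objects and tensor labels.

```python
from typing import List, NamedTuple, Tuple

# Hypothetical toy graph: (number of nodes, list of (src, dst) edges).
Graph = Tuple[int, List[Tuple[int, int]]]


class BatchedGraph(NamedTuple):
    num_nodes: int
    edges: List[Tuple[int, int]]
    batch_num_nodes: List[int]  # node count of each member graph


def batch_graphs(graphs: List[Graph]) -> BatchedGraph:
    """Disjoint union: shift each graph's node IDs past the previous graphs."""
    offset = 0
    edges: List[Tuple[int, int]] = []
    sizes: List[int] = []
    for n, graph_edges in graphs:
        # Shift node IDs so this graph occupies a fresh ID range.
        edges.extend((u + offset, v + offset) for u, v in graph_edges)
        sizes.append(n)
        offset += n
    return BatchedGraph(offset, edges, sizes)


def collate(samples: List[Tuple[Graph, int]]) -> Tuple[BatchedGraph, List[int]]:
    # Separate graphs from labels, batch the graphs, keep labels in order.
    graphs, labels = zip(*samples)
    return batch_graphs(list(graphs)), list(labels)


samples = [((2, [(0, 1)]), 0), ((3, [(0, 1), (1, 2)]), 1)]
batched, labels = collate(samples)
print(batched)  # BatchedGraph(num_nodes=5, edges=[(0, 1), (2, 3), (3, 4)], batch_num_nodes=[2, 3])
print(labels)   # [0, 1]
```

Keeping batch_num_nodes alongside the merged graph is what makes it possible to split per-graph results (for example, graph-level readouts) back apart after a forward pass.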
With Distributed Data Parallel
If you are using PyTorch's distributed training (e.g. when using torch.nn.parallel.DistributedDataParallel), you can train the model by turning on the use_ddp option:

>>> dataloader = dgl.dataloading.GraphDataLoader(
...     dataset, use_ddp=True, batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for epoch in range(start_epoch, n_epochs):
...     dataloader.set_epoch(epoch)
...     for batched_graph, labels in dataloader:
...         train_on(batched_graph, labels)

Calling set_epoch at the start of each epoch lets the underlying DistributedSampler reshuffle; without it, every epoch uses the same ordering.
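The role of ddp_seed and set_epoch can be seen in a simplified model of how DistributedSampler assigns indices to each process: shuffle with a generator seeded by seed + epoch, pad so every rank receives the same number of indices, then give each rank every num_replicas-th index starting at its rank. This is a sketch under assumptions: shard_indices is a hypothetical helper, random.Random replaces torch's generator (so the actual permutation differs from PyTorch's), and only the default drop_last=False padding behavior is modeled.

```python
import math
import random
from typing import List


def shard_indices(n: int, num_replicas: int, rank: int, seed: int = 0,
                  epoch: int = 0, shuffle: bool = True) -> List[int]:
    """Return the dataset indices one process (rank) would iterate over."""
    indices = list(range(n))
    if shuffle:
        # All ranks share seed + epoch, so they agree on one permutation;
        # changing epoch (what set_epoch does) yields a new permutation.
        random.Random(seed + epoch).shuffle(indices)
    total_size = math.ceil(n / num_replicas) * num_replicas
    # Pad by repeating indices from the front so the split is even.
    while len(indices) < total_size:
        indices += indices[: total_size - len(indices)]
    # Strided split: rank r takes positions r, r + num_replicas, ...
    return indices[rank:total_size:num_replicas]


for r in range(4):
    print(r, shard_indices(10, num_replicas=4, rank=r, shuffle=False))
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6, 0]
# 3 [3, 7, 1]
```

Because the permutation depends only on seed + epoch, forgetting to advance the epoch gives every epoch the same order on every rank, which is the behavior the set_epoch call above avoids.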