dgl.distributed
DGL distributed module contains classes and functions to support distributed Graph Neural Network training and inference on a cluster of machines.
This includes a few submodules:
distributed data structures including distributed graph, distributed tensor and distributed embeddings.
distributed sampling.
distributed workload split at runtime.
graph partition.
Initialization
Initialize DGL's distributed module.
Distributed Graph
- class dgl.distributed.DistGraph(graph_name, gpb=None, part_config=None)
The class for accessing a distributed graph.
This class provides a subset of DGLGraph APIs for accessing partitioned graph data in distributed GNN training and inference. Thus, its main use case is to work with distributed sampling APIs to generate mini-batches and perform forward and backward computation on the mini-batches.
The class can run in two modes: the standalone mode and the distributed mode.
When a user runs the training script normally, DistGraph will be in the standalone mode. In this mode, the input data must be constructed by partition_graph() with only one partition. This mode is used for testing and debugging purposes. In this mode, users have to provide part_config so that DistGraph can load the input graph.
When a user runs the training script with the distributed launch script, DistGraph will be set into the distributed mode. This is used for actual distributed training. All data of partitions are loaded by the DistGraph servers, which are created by DGL’s launch script. DistGraph connects with the servers to access the partitioned graph data.
Currently, the DistGraph servers and clients run on the same set of machines in the distributed mode. DistGraph uses shared memory to access the partition data on the local machine. This gives the best performance for distributed training.
Users may want to run DistGraph servers and clients on separate sets of machines. In this case, a user may want to disable shared memory by passing disable_shared_mem=True when creating DistGraphServer. When shared memory is disabled, a user has to pass a partition book.
- Parameters:
graph_name (str) – The name of the graph. This name has to be the same as the one used for partitioning a graph in dgl.distributed.partition.partition_graph().
gpb (GraphPartitionBook, optional) – The partition book object. Normally, users do not need to provide the partition book. This argument is necessary only when users want to run server processes and trainer processes on different machines.
part_config (str, optional) – The path of the partition configuration file generated by dgl.distributed.partition.partition_graph(). It is used in the standalone mode.
Examples
The example shows the creation of DistGraph in the standalone mode.
>>> dgl.distributed.partition_graph(g, 'graph_name', 1, num_hops=1, part_method='metis',
...                                 out_path='output/')
>>> g = dgl.distributed.DistGraph('graph_name', part_config='output/graph_name.json')
The example shows the creation of DistGraph in the distributed mode.
>>> g = dgl.distributed.DistGraph('graph-name')
The code below shows mini-batch training using DistGraph.
>>> import dgl
>>> import numpy as np
>>> import torch as th
>>> # `nodes` (the training node IDs) and `model` are assumed to be defined elsewhere.
>>> def sample(seeds):
...     seeds = th.LongTensor(np.asarray(seeds))
...     frontier = dgl.distributed.sample_neighbors(g, seeds, 10)
...     return dgl.to_block(frontier, seeds)
>>> dataloader = dgl.distributed.DistDataLoader(dataset=nodes, batch_size=1000,
...                                             collate_fn=sample, shuffle=True)
>>> for block in dataloader:
...     feat = g.ndata['features'][block.srcdata[dgl.NID]]
...     labels = g.ndata['labels'][block.dstdata[dgl.NID]]
...     pred = model(block, feat)
Note
DGL’s distributed training by default runs server processes and trainer processes on the same set of machines. If users need to run them on different sets of machines, it requires manually setting up servers and trainers. The setup is not fully tested yet.
- barrier()
Barrier for all client nodes.
This API blocks the current process until all the clients invoke this API. Please use this API with caution.
- property device
Get the device context of this graph.
Examples
The following example uses PyTorch backend.
>>> g = dgl.heterograph({
...     ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 2, 1])
... })
>>> print(g.device)
device(type='cpu')
>>> g = g.to('cuda:0')
>>> print(g.device)
device(type='cuda', index=0)
- Return type:
Device context object
- property edata
Return the data view of all the edges.
- Returns:
The data view in the distributed graph storage.
- Return type:
EdgeDataView
- edge_attr_schemes()
Return the edge feature schemes.
Each feature scheme is a named tuple that stores the shape and data type of the edge feature.
- Returns:
The schemes of edge feature columns.
- Return type:
dict of str to schemes
Examples
The following uses PyTorch backend.
>>> g.edge_attr_schemes()
{'h': Scheme(shape=(4,), dtype=torch.float32)}
- property edges
Return an edge view
- property etypes
Return the list of edge types of this graph.
Examples
>>> g = DistGraph("test")
>>> g.etypes
['_E']
- find_edges(edges, etype=None)
Given an edge ID array, return the source and destination node ID arrays s and d. s[i] and d[i] are the source and destination node IDs of edge edges[i].
- Parameters:
edges (Int Tensor) – Each element is an edge ID. The tensor must have the same device type and ID data type as the graph’s.
etype (str or (str, str, str), optional) –
The type names of the edges. The allowed type name formats are:
(str, str, str) for source node type, edge type and destination node type;
or one str edge type name if the name can uniquely identify a triplet format in the graph.
Can be omitted if the graph has only one type of edges.
- Returns:
tensor – The source node ID array.
tensor – The destination node ID array.
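To make the indexing contract of find_edges() concrete, here is a single-machine sketch that stores a homogeneous graph as two hypothetical Python lists, `src` and `dst` (COO format: edge `e` goes from `src[e]` to `dst[e]`). It only illustrates the semantics of the returned arrays; it is not how DistGraph fetches edges from its servers.

```python
from typing import List, Sequence, Tuple


def find_edges(src: Sequence[int], dst: Sequence[int],
               edges: Sequence[int]) -> Tuple[List[int], List[int]]:
    """Return (s, d) such that s[i], d[i] are the endpoints of edges[i]."""
    num_edges = len(src)
    s, d = [], []
    for eid in edges:
        # Reject IDs outside the edge ID range instead of wrapping around.
        if not 0 <= eid < num_edges:
            raise IndexError(f"edge ID {eid} out of range [0, {num_edges})")
        s.append(src[eid])
        d.append(dst[eid])
    return s, d


# Hypothetical toy graph with 4 edges: 0->1, 0->2, 1->1, 1->3.
src = [0, 0, 1, 1]
dst = [1, 2, 1, 3]
print(find_edges(src, dst, [3, 0]))  # ([1, 0], [3, 1])
```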
- get_edge_partition_policy(etype)
Get the partition policy for an edge type.
When creating a new distributed tensor, we need to provide a partition policy that indicates how to distribute data of the distributed tensor in a cluster of machines. When we load a distributed graph in the cluster, we have pre-defined partition policies for each node type and each edge type. By providing the edge type, we can reference the pre-defined partition policy for the edge type.
- Parameters:
- Returns:
The partition policy for the edge type.
- Return type:
- get_etype_id(etype)
Return the ID of the given edge type.
etype can also be None. If so, there should be only one edge type in the graph.
- get_node_partition_policy(ntype)
Get the partition policy for a node type.
When creating a new distributed tensor, we need to provide a partition policy that indicates how to distribute data of the distributed tensor in a cluster of machines. When we load a distributed graph in the cluster, we have pre-defined partition policies for each node type and each edge type. By providing the node type, we can reference the pre-defined partition policy for the node type.
- Parameters:
ntype (str) – The node type
- Returns:
The partition policy for the node type.
- Return type:
- get_ntype_id(ntype)
Return the ID of the given node type.
ntype can also be None. If so, there should be only one node type in the graph.
- get_partition_book()
Get the partition information.
- Returns:
Object that stores all graph partition information.
- Return type:
- property idtype
The dtype of graph index
- Returns:
th.int32/th.int64 or tf.int32/tf.int64 etc.
- Return type:
backend dtype object
See also
long, int
- in_degrees(v='__ALL__')
Return the in-degree(s) of the given nodes.
It computes the in-degree(s). It does not support heterogeneous graphs yet.
- Parameters:
v (node IDs) –
The node IDs. The allowed formats are:
int: A single node.
Int Tensor: Each element is a node ID. The tensor must have the same device type and ID data type as the graph’s.
iterable[int]: Each element is a node ID.
If not given, return the in-degrees of all the nodes.
- Returns:
The in-degree(s) of the node(s) in a Tensor. The i-th element is the in-degree of the i-th input node. If v is an int, return an int too.
- Return type:
int or Tensor
Examples
The following example uses PyTorch backend.
>>> import dgl
>>> import torch
Query for all nodes.
>>> g.in_degrees()
tensor([0, 2, 1, 1])
Query for nodes 1 and 2.
>>> g.in_degrees(torch.tensor([1, 2]))
tensor([2, 1])
- property local_partition
Return the local partition on the client.
DistGraph provides a global view of the distributed graph. Internally, it may contain a partition of the graph if it is co-located with the server. When servers and clients run on separate sets of machines, this returns None.
- Returns:
The local partition
- Return type:
- property ndata
Return the data view of all the nodes.
- Returns:
The data view in the distributed graph storage.
- Return type:
NodeDataView
- node_attr_schemes()
Return the node feature schemes.
Each feature scheme is a named tuple that stores the shape and data type of the node feature.
- Returns:
The schemes of node feature columns.
- Return type:
dict of str to schemes
Examples
The following uses PyTorch backend.
>>> g.node_attr_schemes()
{'h': Scheme(shape=(4,), dtype=torch.float32)}
- property nodes
Return a node view
- property ntypes
Return the list of node types of this graph.
Examples
>>> g = DistGraph("test")
>>> g.ntypes
['_U']
- num_edges(etype=None)
Return the total number of edges in the distributed graph.
- Parameters:
etype (str or (str, str, str), optional) –
The type name of the edges. The allowed type name formats are:
(str, str, str) for source node type, edge type and destination node type;
or one str edge type name if the name can uniquely identify a triplet format in the graph.
If not provided, return the total number of edges regardless of the types in the graph.
- Returns:
The number of edges
- Return type:
Examples
>>> g = dgl.distributed.DistGraph('ogb-product')
>>> print(g.num_edges())
123718280
- num_nodes(ntype=None)
Return the total number of nodes in the distributed graph.
- Parameters:
ntype (str, optional) – The node type name. If given, it returns the number of nodes of the type. If not given (default), it returns the total number of nodes of all types.
- Returns:
The number of nodes
- Return type:
Examples
>>> g = dgl.distributed.DistGraph('ogb-product')
>>> print(g.num_nodes())
2449029
- number_of_edges(etype=None)
Alias of
num_edges()
- number_of_nodes(ntype=None)
Alias of
num_nodes()
- out_degrees(u='__ALL__')
Return the out-degree(s) of the given nodes.
It computes the out-degree(s). It does not support heterogeneous graphs yet.
- Parameters:
u (node IDs) –
The node IDs. The allowed formats are:
int: A single node.
Int Tensor: Each element is a node ID. The tensor must have the same device type and ID data type as the graph’s.
iterable[int]: Each element is a node ID.
If not given, return the out-degrees of all the nodes.
- Returns:
The out-degree(s) of the node(s) in a Tensor. The i-th element is the out-degree of the i-th input node. If u is an int, return an int too.
- Return type:
int or Tensor
Examples
The following example uses PyTorch backend.
>>> import dgl
>>> import torch
Query for all nodes.
>>> g.out_degrees()
tensor([2, 2, 0, 0])
Query for nodes 1 and 2.
>>> g.out_degrees(torch.tensor([1, 2]))
tensor([2, 0])
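The in_degrees() and out_degrees() examples above do not show the graph they were run on. The sketch below assumes a hypothetical 4-node graph with edges 0->1, 0->2, 1->1 and 1->3, which reproduces both sets of example outputs, and counts degrees from COO lists with the standard library. It also mirrors the documented return convention: a single int query returns an int, any other query returns a list (standing in for a Tensor).

```python
from collections import Counter
from typing import Iterable, List, Optional, Sequence, Union

NodeQuery = Optional[Union[int, Iterable[int]]]


def _degrees(endpoints: Sequence[int], num_nodes: int,
             nodes: NodeQuery) -> Union[int, List[int]]:
    # Count how often each node appears at this end of an edge.
    counts = Counter(endpoints)
    if nodes is None:  # like the default '__ALL__': query every node
        return [counts[n] for n in range(num_nodes)]
    if isinstance(nodes, int):  # a single node ID yields a plain int
        return counts[nodes]
    return [counts[n] for n in nodes]


def in_degrees(src, dst, num_nodes, v: NodeQuery = None):
    return _degrees(dst, num_nodes, v)  # in-degree counts destinations


def out_degrees(src, dst, num_nodes, u: NodeQuery = None):
    return _degrees(src, num_nodes, u)  # out-degree counts sources


src, dst = [0, 0, 1, 1], [1, 2, 1, 3]
print(in_degrees(src, dst, 4))           # [0, 2, 1, 1]
print(in_degrees(src, dst, 4, [1, 2]))   # [2, 1]
print(out_degrees(src, dst, 4))          # [2, 2, 0, 0]
print(out_degrees(src, dst, 4, [1, 2]))  # [2, 0]
```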
Distributed Tensor
- class dgl.distributed.DistTensor(shape, dtype, name=None, init_func=None, part_policy=None, persistent=False, is_gdata=True, attach=True)
Distributed tensor.
DistTensor references a distributed tensor sharded and stored in a cluster of machines. It has the same interface as a PyTorch tensor for accessing its metadata (e.g., shape and data type). To access data in a distributed tensor, it supports slicing rows and writing data to rows. It does not support any operators of a deep learning framework, such as addition and multiplication.
Currently, distributed tensors are designed to store node data and edge data of a distributed graph. Therefore, their first dimensions have to be the number of nodes or edges in the graph. The tensors are sharded in the first dimension based on the partition policy of nodes or edges. When a distributed tensor is created, the partition policy is automatically determined based on the first dimension if the partition policy is not provided. If the first dimension matches the number of nodes of a node type, DistTensor will use the partition policy for this particular node type; if the first dimension matches the number of edges of an edge type, DistTensor will use the partition policy for this particular edge type. If DGL cannot determine the partition policy automatically (e.g., multiple node types or edge types have the same number of nodes or edges), users have to explicitly provide the partition policy.
A distributed tensor can be either named or anonymous. When a distributed tensor has a name, the tensor can be persistent if persistent=True. Normally, DGL destroys the distributed tensor in the system when the DistTensor object goes away. However, a persistent tensor lives in the system even if the DistTensor object disappears in the trainer process. The persistent tensor has the same life span as the DGL servers. DGL does not allow an anonymous tensor to be persistent.
When a DistTensor object is created, it may reference an existing distributed tensor or create a new one. A distributed tensor is identified by the name passed to the constructor. If the name exists, DistTensor will reference the existing one. In this case, the shape and the data type must match the existing tensor. If the name doesn’t exist, a new tensor will be created in the kvstore.
When a distributed tensor is created, its values are initialized to zero. Users can define an initialization function to control how the values are initialized. The init function takes two input arguments, shape and data type, and returns a tensor. Below is an example of an init function:
def init_func(shape, dtype):
    return torch.ones(shape, dtype=dtype)
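The automatic choice of a partition policy described above can be sketched as a lookup over the per-type node and edge counts. The function name `infer_policy_str` and its dictionary inputs are hypothetical; the returned `'node~<ntype>'` and `'edge~<src>:<etype>:<dst>'` strings follow the policy-name examples given for PartitionPolicy. When more than one type matches, the sketch refuses to guess, matching the rule that users must then pass part_policy explicitly.

```python
from typing import Dict, Tuple

CanonicalEtype = Tuple[str, str, str]


def infer_policy_str(first_dim: int,
                     num_nodes: Dict[str, int],
                     num_edges: Dict[CanonicalEtype, int]) -> str:
    """Pick the node or edge policy whose row count equals first_dim."""
    candidates = [f"node~{ntype}" for ntype, n in num_nodes.items()
                  if n == first_dim]
    candidates += [f"edge~{':'.join(etype)}" for etype, n in num_edges.items()
                   if n == first_dim]
    if len(candidates) == 1:
        return candidates[0]
    if not candidates:
        raise ValueError(f"no node or edge type has {first_dim} rows")
    # Several types share the same count: the caller must pass part_policy.
    raise ValueError(f"ambiguous first dimension {first_dim}: {candidates}; "
                     "provide part_policy explicitly")


num_nodes = {"user": 1000, "game": 300}
num_edges = {("user", "plays", "game"): 5000,
             ("user", "follows", "user"): 1000}
print(infer_policy_str(300, num_nodes, num_edges))   # node~game
print(infer_policy_str(5000, num_nodes, num_edges))  # edge~user:plays:game
```

Here a first dimension of 1000 is rejected because it matches both the `user` nodes and the `follows` edges.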
- Parameters:
shape (tuple) – The shape of the tensor. The first dimension has to be the number of nodes or the number of edges of a distributed graph.
dtype (dtype) – The dtype of the tensor. The data type has to be the one in the deep learning framework.
name (string, optional) – The name of the tensor. The name can uniquely identify the tensor in a system so that another DistTensor object can refer to the distributed tensor.
init_func (callable, optional) – The function to initialize data in the tensor. If the init function is not provided, the values of the tensor are initialized to zero.
part_policy (PartitionPolicy, optional) – The partition policy of the rows of the tensor to different machines in the cluster. Currently, it only supports node partition policy or edge partition policy. If it is not provided, the system tries to determine the right partition policy automatically.
persistent (bool) – Whether the created tensor lives after the DistTensor object is destroyed.
is_gdata (bool) – Whether the created tensor is node or edge data (ndata/edata) of the graph.
attach (bool) – Whether to attach the group ID to the name to make it globally unique.
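The naming rules above (reuse by name with a matching shape and dtype, zero or custom initialization, no persistent anonymous tensors, persistent data surviving its handle) can be modeled with a small in-process registry. `TensorRegistry` is a hypothetical stand-in for the server-side kvstore that stores flat Python lists instead of framework tensors; it ignores sharding and the cross-process synchronization that real DistTensor creation requires.

```python
from typing import Callable, Dict, List, Optional, Tuple

Shape = Tuple[int, ...]
InitFunc = Callable[[Shape, str], List[float]]


class TensorRegistry:
    """Toy model of how named distributed tensors are created or reused."""

    def __init__(self) -> None:
        # name -> (shape, dtype, flat data, persistent flag)
        self._tensors: Dict[str, Tuple[Shape, str, List[float], bool]] = {}
        self._num_anonymous = 0

    def create(self, shape: Shape, dtype: str, name: Optional[str] = None,
               init_func: Optional[InitFunc] = None,
               persistent: bool = False) -> str:
        if name is None:
            if persistent:
                raise ValueError("an anonymous tensor cannot be persistent")
            name = f"anonymous-{self._num_anonymous}"
            self._num_anonymous += 1
        elif name in self._tensors:
            # An existing name is reused; shape and dtype must match.
            old_shape, old_dtype, _, _ = self._tensors[name]
            if (old_shape, old_dtype) != (shape, dtype):
                raise ValueError(f"{name!r} exists with another shape or dtype")
            return name
        numel = 1
        for dim in shape:
            numel *= dim
        # Zero initialization unless an init function is supplied.
        data = init_func(shape, dtype) if init_func else [0.0] * numel
        self._tensors[name] = (shape, dtype, data, persistent)
        return name

    def release(self, name: str) -> None:
        """Drop the tensor when its handle goes away, unless it is persistent."""
        if not self._tensors[name][3]:
            del self._tensors[name]

    def __contains__(self, name: str) -> bool:
        return name in self._tensors


reg = TensorRegistry()
reg.create((4, 2), "float32", name="h", persistent=True)
reg.release("h")
print("h" in reg)  # True: a persistent tensor outlives its handle
```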
Examples
>>> init = lambda shape, dtype: th.ones(shape, dtype=dtype)
>>> arr = dgl.distributed.DistTensor((g.num_nodes(), 2), th.int32, init_func=init)
>>> print(arr[0:3])
tensor([[1, 1],
        [1, 1],
        [1, 1]], dtype=torch.int32)
>>> arr[0:3] = th.ones((3, 2), dtype=th.int32) * 2
>>> print(arr[0:3])
tensor([[2, 2],
        [2, 2],
        [2, 2]], dtype=torch.int32)
Note
The creation of DistTensor is a synchronized operation. When a trainer process tries to create a DistTensor object, the creation succeeds only when all trainer processes do the same.
- property dtype
Return the data type of the distributed tensor.
- Returns:
The data type of the tensor.
- Return type:
dtype
- property name
Return the name of the distributed tensor
- Returns:
The name of the tensor.
- Return type:
- property part_policy
Return the partition policy
- Returns:
The partition policy of the distributed tensor.
- Return type:
Distributed Node Embedding
- class dgl.distributed.DistEmbedding(num_embeddings, embedding_dim, name=None, init_func=None, part_policy=None)
Distributed node embeddings.
DGL provides a distributed embedding to support models that require learnable embeddings. DGL’s distributed embeddings are mainly used for learning node embeddings of graph models. Because distributed embeddings are part of a model, they are updated by mini-batches. The distributed embeddings have to be updated by DGL’s optimizers instead of the optimizers provided by the deep learning frameworks (e.g., PyTorch and MXNet).
To support efficient training on a graph with many nodes, the embeddings support sparse updates. That is, only the embeddings involved in a mini-batch computation are updated. Please refer to Distributed Optimizers for available optimizers in DGL.
Distributed embeddings are sharded and stored in a cluster of machines in the same way as dgl.distributed.DistTensor, except that distributed embeddings are trainable. Because distributed embeddings are sharded in the same way as nodes and edges of a distributed graph, accessing them is usually much more efficient than accessing the sparse embeddings provided by the deep learning frameworks.
- Parameters:
num_embeddings (int) – The number of embeddings. Currently, the number of embeddings has to be the same as the number of nodes or the number of edges.
embedding_dim (int) – The dimension size of embeddings.
name (str, optional) – The name of the embeddings. The name can uniquely identify embeddings in a system so that another DistEmbedding object can refer to the same embeddings.
init_func (callable, optional) – The function to create the initial data. If the init function is not provided, the values of the embeddings are initialized to zero.
part_policy (PartitionPolicy, optional) – The partition policy that assigns embeddings to different machines in the cluster. Currently, it only supports node partition policy or edge partition policy. If it is not provided, the system tries to determine the right partition policy automatically.
Examples
>>> def initializer(shape, dtype):
...     arr = th.zeros(shape, dtype=dtype)
...     arr.uniform_(-1, 1)
...     return arr
>>> emb = dgl.distributed.DistEmbedding(g.num_nodes(), 10, init_func=initializer)
>>> optimizer = dgl.distributed.optim.SparseAdagrad([emb], lr=0.001)
>>> for blocks in dataloader:
...     # `nids` is assumed to hold the IDs of the nodes in this mini-batch.
...     feats = emb(nids)
...     loss = th.sum(feats + 1)  # reduce to a scalar so backward() works
...     loss.backward()
...     optimizer.step()
Note
When a DistEmbedding object is used in the forward computation, users have to invoke the optimizer’s step() afterwards. Otherwise, memory will leak.
Distributed embedding optimizer
- class dgl.distributed.optim.SparseAdagrad(params, lr, eps=1e-10)
Distributed node embedding optimizer using the Adagrad algorithm.
This optimizer implements a distributed sparse version of the Adagrad algorithm for optimizing dgl.distributed.DistEmbedding. Being sparse means it only updates the embeddings whose gradients have updates, which are usually a very small portion of the total embeddings.
Adagrad maintains a \(G_{t,i,j}\) for every parameter in the embeddings, where \(G_{t,i,j}=G_{t-1,i,j} + g_{t,i,j}^2\) and \(g_{t,i,j}\) is the gradient of the dimension \(j\) of embedding \(i\) at step \(t\).
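A minimal single-process sketch of one sparse Adagrad step follows. Beyond the accumulator \(G\) defined above, it assumes the usual Adagrad update \(w \leftarrow w - lr \cdot g / (\sqrt{G} + eps)\) and that gradients for repeated IDs in a mini-batch are summed before the update; the function `sparse_adagrad_step` and the dict-of-lists storage are hypothetical stand-ins for embedding rows kept on DGL’s servers. Rows absent from the mini-batch are never touched, which is what makes the update sparse.

```python
import math
from collections import defaultdict
from typing import Dict, List, Sequence


def sparse_adagrad_step(weights: Dict[int, List[float]],
                        state: Dict[int, List[float]],
                        ids: Sequence[int],
                        grads: Sequence[Sequence[float]],
                        lr: float, eps: float = 1e-10) -> None:
    """Update only the embedding rows listed in ids (in place)."""
    # Sum gradients of duplicate IDs so each row is updated exactly once.
    summed: Dict[int, List[float]] = defaultdict(lambda: [0.0] * len(grads[0]))
    for i, g in zip(ids, grads):
        summed[i] = [a + b for a, b in zip(summed[i], g)]
    for i, g in summed.items():
        G = state.setdefault(i, [0.0] * len(g))
        for j, g_ij in enumerate(g):
            G[j] += g_ij * g_ij  # G_t = G_{t-1} + g_t^2
            weights[i][j] -= lr * g_ij / (math.sqrt(G[j]) + eps)


weights = {0: [1.0, 1.0], 1: [1.0, 1.0], 2: [1.0, 1.0]}
state: Dict[int, List[float]] = {}
# Row 1 appears twice in the mini-batch; rows 0 and 2 are untouched.
sparse_adagrad_step(weights, state, ids=[1, 1],
                    grads=[[1.0, 2.0], [1.0, 2.0]], lr=0.1)
print(weights[1], state[1])  # roughly [0.9, 0.9] and exactly [4.0, 16.0]
```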
NOTE: The support of sparse Adagrad optimizer is experimental.
- Parameters:
params (list[dgl.distributed.DistEmbedding]) – The list of dgl.distributed.DistEmbedding.
lr (float) – The learning rate.
eps (float, Optional) – The term added to the denominator to improve numerical stability. Default: 1e-10
- load(f)
Load the local state of the optimizer from a file on each rank.
NOTE: This needs to be called on all ranks.
- Parameters:
f (Union[str, os.PathLike]) – The path of the file to load from.
- save(f)
Save the local state_dict to disk on each rank.
The saved dict contains 2 parts:
‘params’: hyperparameters of the optimizer.
- ‘emb_states’: partial optimizer states; for each embedding, it contains 2 items:
`ids`: global IDs of the nodes/edges stored on this rank.
`states`: state data corresponding to `ids`.
NOTE: This needs to be called on all ranks.
- Parameters:
f (Union[str, os.PathLike]) – The path of the file to save to.
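The per-rank layout described above can be sketched with `pickle`. The key names inside ‘params’, the nesting of ‘emb_states’ by an embedding name, and the helper pair `save_local_state` / `load_local_state` are all assumptions made for illustration; the sketch only shows that each rank writes and reads its own file holding hyperparameters plus (`ids`, `states`) pairs.

```python
import os
import pickle
import tempfile
from typing import Any, Dict, List


def save_local_state(path: str, params: Dict[str, Any],
                     emb_states: Dict[str, Dict[str, List[Any]]]) -> None:
    """Write this rank's slice of the optimizer state to its own file."""
    with open(path, "wb") as f:
        pickle.dump({"params": params, "emb_states": emb_states}, f)


def load_local_state(path: str) -> Dict[str, Any]:
    """Read back the slice written by the same rank."""
    with open(path, "rb") as f:
        return pickle.load(f)


rank = 0  # hypothetical: each rank uses its own file name
path = os.path.join(tempfile.mkdtemp(), f"optim_rank{rank}.pkl")
save_local_state(
    path,
    params={"lr": 0.001, "eps": 1e-10},
    # For each embedding: global IDs held by this rank and their state rows.
    emb_states={"emb": {"ids": [4, 9], "states": [[0.5, 0.1], [0.0, 2.0]]}},
)
restored = load_local_state(path)
print(restored["emb_states"]["emb"]["ids"])  # [4, 9]
```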
- step()
The step function.
The step function is invoked at the end of every batch to push the gradients of the embeddings involved in a mini-batch to DGL’s servers and update the embeddings.
- class dgl.distributed.optim.SparseAdam(params, lr, betas=(0.9, 0.999), eps=1e-08)
Distributed node embedding optimizer using the Adam algorithm.
This optimizer implements a distributed sparse version of the Adam algorithm for optimizing dgl.distributed.DistEmbedding. Being sparse means it only updates the embeddings whose gradients have updates, which are usually a very small portion of the total embeddings.
Adam maintains \(Gm_{t,i,j}\) and \(Gp_{t,i,j}\) for every parameter in the embeddings, where \(Gm_{t,i,j} = \beta_1 Gm_{t-1,i,j} + (1-\beta_1) g_{t,i,j}\) and \(Gp_{t,i,j} = \beta_2 Gp_{t-1,i,j} + (1-\beta_2) g_{t,i,j}^2\), with \((\beta_1, \beta_2)\) given by betas. The embedding parameter is then decreased by \(lr \cdot \frac{Gm_{t,i,j} / (1 - \beta_1^t)}{\sqrt{Gp_{t,i,j} / (1 - \beta_2^t)} + \epsilon}\), where \(g_{t,i,j}\) is the gradient of the dimension \(j\) of embedding \(i\) at step \(t\) and \(\epsilon\) is eps.
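The Adam update can be sketched for a single embedding row in the same spirit. As assumptions, the step counter \(t\) is tracked per row (so a row’s bias correction depends on how many times that row has been updated), \(\epsilon\) is added to the denominator, and `SparseAdamRowState` is a hypothetical container; in a sparse setting only rows present in the mini-batch would be passed to `adam_row_update`.

```python
import math
from dataclasses import dataclass, field
from typing import List


@dataclass
class SparseAdamRowState:
    """Adam statistics for one embedding row (hypothetical layout)."""
    dim: int
    step: int = 0
    gm: List[float] = field(default_factory=list)  # first moment, Gm
    gp: List[float] = field(default_factory=list)  # second moment, Gp

    def __post_init__(self) -> None:
        self.gm = self.gm or [0.0] * self.dim
        self.gp = self.gp or [0.0] * self.dim


def adam_row_update(w: List[float], g: List[float], s: SparseAdamRowState,
                    lr: float, beta1: float = 0.9, beta2: float = 0.999,
                    eps: float = 1e-8) -> None:
    """Apply one Adam step to row w in place, using gradient g."""
    s.step += 1
    t = s.step
    for j, g_j in enumerate(g):
        s.gm[j] = beta1 * s.gm[j] + (1 - beta1) * g_j
        s.gp[j] = beta2 * s.gp[j] + (1 - beta2) * g_j * g_j
        m_hat = s.gm[j] / (1 - beta1 ** t)  # bias-corrected first moment
        v_hat = s.gp[j] / (1 - beta2 ** t)  # bias-corrected second moment
        w[j] -= lr * m_hat / (math.sqrt(v_hat) + eps)


w = [0.0, 0.0]
state = SparseAdamRowState(dim=2)
adam_row_update(w, [0.5, -2.0], state, lr=0.01)
print(w)  # about [-0.01, 0.01]: the first step moves each coordinate by ~lr
```

The printed result illustrates a known property of Adam: after bias correction, the first step has magnitude close to lr regardless of the gradient’s scale.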
NOTE: The support of sparse Adam optimizer is experimental.
- Parameters:
params (list[dgl.distributed.DistEmbedding]) – The list of dgl.distributed.DistEmbedding.
lr (float) – The learning rate.
betas (tuple[float, float], Optional) – Coefficients used for computing running averages of gradient and its square. Default: (0.9, 0.999)
eps (float, Optional) – The term added to the denominator to improve numerical stability. Default: 1e-8
- load(f)
Load the local state of the optimizer from a file on each rank.
NOTE: This needs to be called on all ranks.
- Parameters:
f (Union[str, os.PathLike]) – The path of the file to load from.
- save(f)
Save the local state_dict to disk on each rank.
The saved dict contains 2 parts:
‘params’: hyperparameters of the optimizer.
- ‘emb_states’: partial optimizer states; for each embedding, it contains 2 items:
`ids`: global IDs of the nodes/edges stored on this rank.
`states`: state data corresponding to `ids`.
NOTE: This needs to be called on all ranks.
- Parameters:
f (Union[str, os.PathLike]) – The path of the file to save to.
- step()
The step function.
The step function is invoked at the end of every batch to push the gradients of the embeddings involved in a mini-batch to DGL’s servers and update the embeddings.
Distributed workload split
Split nodes and return a subset for the local rank.
Split edges and return a subset for the local rank.
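The two split helpers hand each trainer rank its own share of the training nodes or edges. The sketch below shows only the simplest form of that idea, an even contiguous split of a boolean mask across ranks; the function name `split_even` is hypothetical, and the real DGL helpers may also take partition locality into account, which this sketch ignores.

```python
from typing import List, Sequence


def split_even(mask: Sequence[bool], num_ranks: int, rank: int) -> List[int]:
    """Return the IDs selected by mask that belong to this rank."""
    ids = [i for i, selected in enumerate(mask) if selected]
    # The first (len(ids) % num_ranks) ranks get one extra element each.
    base, extra = divmod(len(ids), num_ranks)
    start = rank * base + min(rank, extra)
    end = start + base + (1 if rank < extra else 0)
    return ids[start:end]


train_mask = [True, False, True, True, True, False, True]  # IDs 0, 2, 3, 4, 6
print([split_even(train_mask, 2, r) for r in range(2)])  # [[0, 2, 3], [4, 6]]
```

Every selected ID lands on exactly one rank, and rank sizes differ by at most one.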
Distributed Sampling
Distributed DataLoader
- class dgl.distributed.DistDataLoader(dataset, batch_size, shuffle=False, collate_fn=None, drop_last=False, queue_size=None)
DGL customized multiprocessing dataloader.
DistDataLoader provides an interface similar to PyTorch’s DataLoader to generate mini-batches with multiprocessing. It utilizes the worker processes created by dgl.distributed.initialize() to parallelize sampling.
- Parameters:
dataset (a tensor) – Tensors of node IDs or edge IDs.
batch_size (int) – The number of samples per batch to load.
shuffle (bool, optional) – Set to True to have the data reshuffled at every epoch (default: False).
collate_fn (callable, optional) – The function is typically used to sample neighbors of the nodes in a batch or the endpoint nodes of the edges in a batch.
drop_last (bool, optional) – Set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)
queue_size (int, optional) – Size of multiprocessing queue.
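How batch_size, shuffle, drop_last and collate_fn interact can be sketched with a plain generator. `iterate_batches` is a hypothetical helper that runs in one process and yields batches in order, whereas DistDataLoader samples in the worker processes created by dgl.distributed.initialize() and does not guarantee batch order.

```python
import random
from typing import Any, Callable, Iterator, List, Optional, Sequence


def iterate_batches(dataset: Sequence[int], batch_size: int,
                    shuffle: bool = False,
                    collate_fn: Optional[Callable[[List[int]], Any]] = None,
                    drop_last: bool = False,
                    seed: Optional[int] = None) -> Iterator[Any]:
    """Yield mini-batches of IDs; call once per epoch."""
    order = list(dataset)
    if shuffle:
        # A fresh permutation on every call unless a fixed seed is given.
        random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        batch = order[start:start + batch_size]
        if drop_last and len(batch) < batch_size:
            break  # skip the final incomplete batch
        # In DGL, collate_fn typically samples neighbors of these seed IDs.
        yield collate_fn(batch) if collate_fn is not None else batch


print(list(iterate_batches(range(5), 2)))                  # [[0, 1], [2, 3], [4]]
print(list(iterate_batches(range(5), 2, drop_last=True)))  # [[0, 1], [2, 3]]
```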
Examples
>>> import dgl
>>> import numpy as np
>>> import torch as th
>>> # `nodes` (the training node IDs) and `model` are assumed to be defined elsewhere.
>>> g = dgl.distributed.DistGraph('graph-name')
>>> def sample(seeds):
...     seeds = th.LongTensor(np.asarray(seeds))
...     frontier = dgl.distributed.sample_neighbors(g, seeds, 10)
...     return dgl.to_block(frontier, seeds)
>>> dataloader = dgl.distributed.DistDataLoader(dataset=nodes, batch_size=1000,
...                                             collate_fn=sample, shuffle=True)
>>> for block in dataloader:
...     feat = g.ndata['features'][block.srcdata[dgl.NID]]
...     labels = g.ndata['labels'][block.dstdata[dgl.NID]]
...     pred = model(block, feat)
Note
When performing DGL’s distributed sampling with multiprocessing, users have to use this class instead of PyTorch’s DataLoader because DGL’s RPC requires that all processes establish connections with servers before invoking any of DGL’s distributed APIs. Therefore, this dataloader uses the worker processes created in dgl.distributed.initialize().
Note
This dataloader does not guarantee the iteration order. For example, if dataset = [1, 2, 3, 4], batch_size = 2 and shuffle = False, the order of [1, 2] and [3, 4] is not guaranteed.
Distributed Graph Sampling Operators
Sample from the neighbors of the given nodes from a distributed graph.
Sample from the neighbors of the given nodes from a distributed graph.
Given an edge ID array, return the source and destination node ID array.
Return the subgraph induced on the inbound edges of the given nodes.
Partition
Graph partition book
- class dgl.distributed.GraphPartitionBook
The base class of the graph partition book.
For distributed training, a graph is partitioned into multiple parts and is loaded in multiple machines. The partition book contains all necessary information to locate nodes and edges in the cluster.
The partition book contains various partition information, including
the number of partitions,
the partition ID that a node or edge belongs to,
the node IDs and the edge IDs that a partition has,
the local IDs of nodes and edges in a partition.
Currently, only one class implements GraphPartitionBook: RangePartitionBook. It calculates the mapping between node/edge IDs and partition IDs from a small amount of metadata, because nodes/edges have been relabeled so that IDs in the same partition fall in a contiguous ID range.
A graph partition book is constructed automatically when a graph is partitioned. When a graph partition is loaded, a graph partition book is loaded as well. Please see partition_graph(), load_partition() and load_partition_book() for more details.
- property canonical_etypes
Get the list of canonical edge types
- map_to_homo_nid(ids, ntype)
Map type-wise node IDs and type IDs to homogeneous node IDs.
- Parameters:
ids (tensor) – Type-wise node IDs.
ntype (str) – The node type.
- Returns:
Homogeneous node IDs.
- Return type:
Tensor
- map_to_per_etype(ids)
Map homogeneous edge IDs to type-wise IDs and edge types.
- Parameters:
ids (tensor) – Homogeneous edge IDs.
- Returns:
edge type IDs and type-wise edge IDs.
- Return type:
(tensor, tensor)
- map_to_per_ntype(ids)
Map homogeneous node IDs to type-wise IDs and node types.
- Parameters:
ids (tensor) – Homogeneous node IDs.
- Returns:
node type IDs and type-wise node IDs.
- Return type:
(tensor, tensor)
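Under the range-based layout described for RangePartitionBook (IDs of one partition are contiguous), converting between homogeneous and type-wise node IDs reduces to arithmetic on ID ranges. As a simplifying assumption, the sketch lets every node type occupy one contiguous homogeneous range inside each partition and numbers a type’s nodes partition by partition; `TYPE_RANGES` and both helpers are hypothetical and use integer type IDs instead of type names.

```python
from typing import Dict, List, Tuple

# Hypothetical layout: 2 partitions, node types 'user' (ID 0) and 'game' (ID 1).
# For each type, its homogeneous [start, end) range in partition 0, 1, ...
TYPE_RANGES: Dict[int, List[Tuple[int, int]]] = {
    0: [(0, 3), (5, 9)],   # user: 3 nodes in partition 0, 4 in partition 1
    1: [(3, 5), (9, 10)],  # game: 2 nodes in partition 0, 1 in partition 1
}


def map_to_per_ntype(homo_id: int) -> Tuple[int, int]:
    """Homogeneous node ID -> (node type ID, type-wise node ID)."""
    for ntype_id, ranges in TYPE_RANGES.items():
        offset = 0  # number of nodes of this type in earlier partitions
        for start, end in ranges:
            if start <= homo_id < end:
                return ntype_id, offset + (homo_id - start)
            offset += end - start
    raise KeyError(homo_id)


def map_to_homo_nid(type_wise_id: int, ntype_id: int) -> int:
    """(type-wise node ID, node type ID) -> homogeneous node ID."""
    offset = 0
    for start, end in TYPE_RANGES[ntype_id]:
        if type_wise_id < offset + (end - start):
            return start + (type_wise_id - offset)
        offset += end - start
    raise KeyError((type_wise_id, ntype_id))


print(map_to_per_ntype(6))    # (0, 4): the fifth 'user' node
print(map_to_homo_nid(4, 0))  # 6
```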
- metadata()
Return the partition meta data.
The meta data includes:
The machine ID.
Number of nodes and edges of each partition.
Examples
>>> print(g.get_partition_book().metadata())
[{'machine_id' : 0, 'num_nodes' : 3000, 'num_edges' : 5000},
 {'machine_id' : 1, 'num_nodes' : 2000, 'num_edges' : 4888},
 ...]
- nid2partid(nids, ntype)
From global node IDs to partition IDs
- Parameters:
nids (tensor) – global node IDs
ntype (str) – The node type
- Returns:
partition IDs
- Return type:
tensor
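Because nodes are relabeled so that each partition owns a contiguous ID range, nid2partid() in a range-based partition book can be answered with a binary search over the range boundaries. The sketch below is a homogeneous-graph simplification with a hypothetical `PART_BOUNDS` list (the last partition’s size is made up); it uses `bisect` from the Python standard library, not DGL internals.

```python
from bisect import bisect_right
from typing import List, Sequence

# Hypothetical: partition p owns global node IDs [PART_BOUNDS[p], PART_BOUNDS[p + 1]).
PART_BOUNDS = [0, 3000, 5000, 9000]


def nid2partid(nids: Sequence[int],
               bounds: Sequence[int] = PART_BOUNDS) -> List[int]:
    """Map global node IDs to the IDs of the partitions that own them."""
    num_nodes = bounds[-1]
    parts = []
    for nid in nids:
        if not 0 <= nid < num_nodes:
            raise IndexError(f"node ID {nid} out of range [0, {num_nodes})")
        # bisect_right finds the first boundary greater than nid.
        parts.append(bisect_right(bounds, nid) - 1)
    return parts


print(nid2partid([0, 2999, 3000, 8999]))  # [0, 0, 1, 2]
```

Storing only the boundaries keeps the lookup table at one integer per partition, which is why the partition book needs so little metadata.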
- num_partitions()
Return the number of partitions.
- Returns:
number of partitions
- Return type:
- property partid
Get the current partition ID
- Returns:
The partition ID of current machine
- Return type:
Move the partition book to shared memory.
- Parameters:
graph_name (str) – The graph name. This name will be used to read the partition book from shared memory in another process.
- class dgl.distributed.PartitionPolicy(policy_str, partition_book)
This defines a partition policy for a distributed tensor or distributed embedding.
When DGL shards tensors and stores them in a cluster of machines, it requires partition policies that map rows of the tensors to machines in the cluster.
Although an arbitrary partition policy can be defined, DGL currently supports two partition policies for mapping nodes and edges to machines. To define a partition policy from a graph partition book, users need to specify the policy name (‘node’ or ‘edge’).
- Parameters:
policy_str (str) – Partition policy name, e.g., ‘edge~_N:_E:_N’ or ‘node~_N’.
partition_book (GraphPartitionBook) – A graph partition book
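The policy names in the examples combine a policy kind and a type name separated by `~`, and an edge policy spells its canonical edge type as `src:etype:dst`. Inferring the format only from those two examples, a parser could look like the sketch below; `parse_policy_str` is hypothetical and not part of DGL’s API.

```python
from typing import Tuple, Union

TypeName = Union[str, Tuple[str, str, str]]


def parse_policy_str(policy_str: str) -> Tuple[str, TypeName]:
    """Split a policy name into its kind and its node or edge type."""
    kind, sep, type_part = policy_str.partition("~")
    if not sep or kind not in ("node", "edge"):
        raise ValueError(f"unrecognized policy name: {policy_str!r}")
    if kind == "node":
        return kind, type_part
    pieces = type_part.split(":")
    if len(pieces) != 3:
        raise ValueError(f"edge policy needs 'src:etype:dst', got {type_part!r}")
    return kind, (pieces[0], pieces[1], pieces[2])


print(parse_policy_str("node~_N"))        # ('node', '_N')
print(parse_policy_str("edge~_N:_E:_N"))  # ('edge', ('_N', '_E', '_N'))
```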
- property partition_book
Get partition book
- Returns:
The graph partition book
- Return type:
- property policy_str
Get the policy name
- Returns:
The name of the partition policy.
- Return type:
Split and Load Partitions
Load data of a partition from the data path.
Load node/edge feature data from a partition.
Load a graph partition book from the partition config file.
Partition a graph for distributed training and store the partitions on files.