BatchedDGLGraph – Enable batched graph operations
class dgl.BatchedDGLGraph(graph_list, node_attrs, edge_attrs)
Class for batched DGL graphs.
A BatchedDGLGraph merges a list of small graphs into one large graph so that message passing and readout can be performed over a batch of graphs simultaneously.

The nodes and edges are re-indexed with new ids in the batched graph according to the rule below:
item | Graph 1 | Graph 2 | … | Graph k
raw id | 0, …, N1 | 0, …, N2 | … | …, Nk
new id | 0, …, N1 | N1 + 1, …, N1 + N2 + 1 | … | …, N1 + … + Nk + k - 1

The batched graph is read-only, i.e. nodes and edges cannot be added to it. A RuntimeError is raised on any attempt to do so.

Modifying the features of a BatchedDGLGraph has no effect on the original graphs. See the examples below for a workaround.

Parameters:
- graph_list (iterable) – A collection of DGLGraph objects to be batched.
- node_attrs (None, str or iterable, optional) – The node attributes to be batched. If None, the BatchedDGLGraph object will not have any node attributes. By default, all node attributes are batched. An error is raised if graphs that contain nodes have different attributes. If str or iterable, this should specify exactly which node attributes to batch.
- edge_attrs (None, str or iterable, optional) – Same as node_attrs, but for edge attributes.
Examples
Instantiation:

Create two DGLGraph objects.
>>> import dgl
>>> import torch as th
>>> g1 = dgl.DGLGraph()
>>> g1.add_nodes(2)                                # Add 2 nodes
>>> g1.add_edge(0, 1)                              # Add edge 0 -> 1
>>> g1.ndata['hv'] = th.tensor([[0.], [1.]])       # Initialize node features
>>> g1.edata['he'] = th.tensor([[0.]])             # Initialize edge features
>>> g2 = dgl.DGLGraph()
>>> g2.add_nodes(3)                                # Add 3 nodes
>>> g2.add_edges([0, 2], [1, 1])                   # Add edges 0 -> 1, 2 -> 1
>>> g2.ndata['hv'] = th.tensor([[2.], [3.], [4.]]) # Initialize node features
>>> g2.edata['he'] = th.tensor([[1.], [2.]])       # Initialize edge features
Merge two DGLGraph objects into one BatchedDGLGraph object. When merging a list of graphs, we can choose to include only a subset of the attributes.

>>> bg = dgl.batch([g1, g2], edge_attrs=None)
>>> bg.edata
{}
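The attribute selection rules described under Parameters can be sketched in plain Python. This is an illustration only, not DGL's implementation: the helper `select_attrs`, the sentinel `ALL_ATTRS`, and the choice of `ValueError` are all assumptions, since the text above only says that "an error" is raised.

```python
ALL_ATTRS = object()  # hypothetical sentinel meaning "batch every attribute"


def select_attrs(graph_attr_names, attrs=ALL_ATTRS):
    """Return the attribute names to batch (hypothetical helper, not DGL API).

    graph_attr_names holds one collection of attribute names per graph
    that actually contains nodes (or edges).
    """
    if attrs is None:
        # None: the batched graph carries no attributes at all.
        return []
    if isinstance(attrs, str):
        # A single attribute name.
        return [attrs]
    if attrs is ALL_ATTRS:
        # Default: batch everything, but all graphs must agree on the names.
        if not graph_attr_names:
            return []
        first = set(graph_attr_names[0])
        for names in graph_attr_names[1:]:
            if set(names) != first:
                # The exception type is an assumption made for this sketch.
                raise ValueError("graphs have different attributes")
        return sorted(first)
    # Any other iterable: exactly the names listed.
    return list(attrs)


print(select_attrs([{"he"}, {"he"}], attrs=None))  # []
print(select_attrs([{"hv"}, {"hv"}]))              # ['hv']
```
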
Below one can see that the nodes are re-indexed. The edges are re-indexed in the same way.
>>> bg.nodes()
tensor([0, 1, 2, 3, 4])
>>> bg.ndata['hv']
tensor([[0.],
        [1.],
        [2.],
        [3.],
        [4.]])
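The same re-indexing can be reproduced by hand. The sketch below uses plain Python and is not how DGL implements batching; the helper name `reindex_edges` and the edge-list representation are assumptions made for illustration. Each graph's raw node ids are shifted by the number of nodes in the graphs batched before it.

```python
def reindex_edges(num_nodes_per_graph, edges_per_graph):
    """Shift raw node ids so every graph gets its own id range.

    Hypothetical helper for illustration, not part of DGL.
    """
    batched_edges = []
    offset = 0
    for num_nodes, edges in zip(num_nodes_per_graph, edges_per_graph):
        for src, dst in edges:
            batched_edges.append((src + offset, dst + offset))
        # The next graph's ids start after all nodes seen so far.
        offset += num_nodes
    return batched_edges


# g1 has 2 nodes and edge 0 -> 1; g2 has 3 nodes and edges 0 -> 1, 2 -> 1.
batched = reindex_edges([2, 3], [[(0, 1)], [(0, 1), (2, 1)]])
print(batched)  # [(0, 1), (2, 3), (4, 3)]
```
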
Property:
We can still get a brief summary of the graphs that constitute the batched graph.
>>> bg.batch_size
2
>>> bg.batch_num_nodes
[2, 3]
>>> bg.batch_num_edges
[1, 2]
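These per-graph counts are enough to tell which graph a batched node came from. A rough sketch in plain Python follows; the helper `node_to_graph` is hypothetical and not a DGL function.

```python
def node_to_graph(batch_num_nodes):
    """Map each batched node id to the index of its original graph."""
    membership = []
    for graph_index, num_nodes in enumerate(batch_num_nodes):
        membership.extend([graph_index] * num_nodes)
    return membership


# With batch_num_nodes == [2, 3], nodes 0-1 belong to graph 0
# and nodes 2-4 belong to graph 1.
print(node_to_graph([2, 3]))  # [0, 0, 1, 1, 1]
```
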
Readout:
Another common demand for graph neural networks is graph readout, which is a function that takes in the node attributes and/or edge attributes for a graph and outputs a vector summarizing the information in the graph.
BatchedDGLGraph also supports performing readout for a batch of graphs at once.

Below we take the built-in readout function sum_nodes() as an example, which sums over a particular kind of node attribute for each graph.

>>> dgl.sum_nodes(bg, 'hv') # Sum the node attribute 'hv' for each graph.
tensor([[1.],    # 0 + 1
        [9.]])   # 2 + 3 + 4
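To make the per-graph summation concrete, here is a minimal sketch in plain Python that mimics the result above using nested lists instead of tensors. The helper `sum_nodes_by_graph` is an assumption for illustration and says nothing about how sum_nodes() is implemented internally.

```python
def sum_nodes_by_graph(features, batch_num_nodes):
    """Sum node feature rows separately for each graph in the batch."""
    results = []
    start = 0
    for num_nodes in batch_num_nodes:
        rows = features[start:start + num_nodes]
        # Column-wise sum over the rows that belong to this graph.
        results.append([sum(column) for column in zip(*rows)])
        start += num_nodes
    return results


hv = [[0.0], [1.0], [2.0], [3.0], [4.0]]  # batched 'hv' from the example
print(sum_nodes_by_graph(hv, [2, 3]))      # [[1.0], [9.0]]
```
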
Message passing:
For message passing and related operations, BatchedDGLGraph acts exactly the same as DGLGraph.

Update Attributes:
Updating the attributes of the batched graph has no effect on the original graphs.
>>> bg.edata['he'] = th.zeros(3, 2)
>>> g2.edata['he']
tensor([[1.],
        [2.]])
Instead, we can decompose the batched graph back into a list of graphs and use them to replace the original graphs.
>>> g1, g2 = dgl.unbatch(bg) # returns a list of DGLGraph objects
>>> g2.edata['he']
tensor([[0., 0.],
        [0., 0.]])
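Conceptually, decomposing the batch slices the batched attribute rows back into per-graph chunks using the per-graph counts. The plain-Python sketch below illustrates this for the edge attribute set above; `split_by_graph` is a hypothetical helper, not the unbatch implementation.

```python
def split_by_graph(rows, counts):
    """Split batched rows into one chunk per graph, given per-graph counts."""
    chunks = []
    start = 0
    for count in counts:
        chunks.append(rows[start:start + count])
        start += count
    return chunks


# Three batched edges with the 2-dimensional zero feature assigned above,
# and batch_num_edges == [1, 2].
batched_he = [[0.0, 0.0] for _ in range(3)]
g1_he, g2_he = split_by_graph(batched_he, [1, 2])
print(g1_he)  # [[0.0, 0.0]]
print(g2_he)  # [[0.0, 0.0], [0.0, 0.0]]
```
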
Merge and decompose

batch(graph_list[, node_attrs, edge_attrs]) | Batch a collection of DGLGraph and return a BatchedDGLGraph object that is independent of the graph_list.
unbatch(graph) | Return the list of graphs in this batch.
Query batch summary

BatchedDGLGraph.batch_size | Number of graphs in this batch.
BatchedDGLGraph.batch_num_nodes | Number of nodes of each graph in this batch.
BatchedDGLGraph.batch_num_edges | Number of edges of each graph in this batch.
Graph Readout

sum_nodes(graph, input[, weight]) | Sums all the values of node field input in graph, optionally multiplies the field by a scalar node field weight.
sum_edges(graph, input[, weight]) | Sums all the values of edge field input in graph, optionally multiplies the field by a scalar edge field weight.
mean_nodes(graph, input[, weight]) | Averages all the values of node field input in graph, optionally multiplies the field by a scalar node field weight.
mean_edges(graph, input[, weight]) | Averages all the values of edge field input in graph, optionally multiplies the field by a scalar edge field weight.
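The optional weight can be pictured as one scalar per node that scales that node's feature row before the per-graph sum. The sketch below is plain Python under that reading of the description; `weighted_sum_nodes` and the sample weights are assumptions for illustration, not the library's code.

```python
def weighted_sum_nodes(features, weights, batch_num_nodes):
    """Per-graph sum of feature rows, each scaled by a scalar node weight."""
    results = []
    start = 0
    for num_nodes in batch_num_nodes:
        total = [0.0] * len(features[0])
        graph_rows = features[start:start + num_nodes]
        graph_weights = weights[start:start + num_nodes]
        for row, weight in zip(graph_rows, graph_weights):
            for i, value in enumerate(row):
                total[i] += weight * value
        results.append(total)
        start += num_nodes
    return results


hv = [[0.0], [1.0], [2.0], [3.0], [4.0]]
w = [1.0, 2.0, 1.0, 0.0, 0.5]  # made-up scalar weights, one per node
# Graph 1: 1*0 + 2*1 = 2; graph 2: 1*2 + 0*3 + 0.5*4 = 4
print(weighted_sum_nodes(hv, w, [2, 3]))  # [[2.0], [4.0]]
```
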