
Distributed Link Prediction

In this tutorial, we walk through the steps of performing distributed GNN training for a link prediction task. It assumes that you have read Distributed Node Classification and Stochastic Training of GNN for Link Prediction. The general pipeline is shown below.

(Figure: overall pipeline of distributed link prediction training.)

Partition a graph

In this tutorial, we use the ogbl-citation2 graph as an example to illustrate graph partitioning. We first load it as a DGL graph and use AsLinkPredDataset to split it into a training graph, validation edges, and test edges.

import os
os.environ['DGLBACKEND'] = 'pytorch'
import dgl
import torch as th
from ogb.linkproppred import DglLinkPropPredDataset

# Load ogbl-citation2 and split it into a training graph, validation edges and test edges.
data = DglLinkPropPredDataset(name='ogbl-citation2')
data = dgl.data.AsLinkPredDataset(data, [0.8, 0.1, 0.1])
graph_train = data[0]

# Partition the training graph into 4 parts and save them under 4part_data/.
dgl.distributed.partition_graph(graph_train, graph_name='ogbl-citation2', num_parts=4,
                                out_path='4part_data',
                                balance_edges=True)

Then, we store the validation and test edges with the graph partitions.

import pickle
with open('4part_data/val.pkl', 'wb') as f:
    pickle.dump(data.val_edges, f)
with open('4part_data/test.pkl', 'wb') as f:
    pickle.dump(data.test_edges, f)

Distributed training script

The distributed link prediction script is very similar to the distributed node classification script, with just a few modifications.

Initialize network communication

We first initialize DGL's network communication and PyTorch's distributed communication.

import dgl
import torch as th
dgl.distributed.initialize(ip_config='ip_config.txt')
th.distributed.init_process_group(backend='gloo')

The configuration file ip_config.txt has the following format:

ip_addr1 [port1]
ip_addr2 [port2]

Each row specifies a machine: the first column is its IP address and the second column is the port for connecting to the DGL server on that machine. The port is optional; the default is 30050.
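For example, a four-machine cluster could use an ip_config.txt like the following (the addresses below are placeholders; replace them with the private IP addresses of your own machines):

172.31.19.1 30050
172.31.23.205 30050
172.31.29.175 30050
172.31.16.98 30050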

Reference to the distributed graph

DGL's servers load the graph partitions automatically. After the servers load the partitions, trainers connect to the servers and can reference the distributed graph in the cluster as shown below.

g = dgl.distributed.DistGraph('ogbl-citation2')

As shown in the code, we refer to a distributed graph by its name. This name is basically the one passed to the partition_graph function as shown in the section above.

Get training and validation edge IDs

For distributed training, each trainer works on its own subset of training edges and nodes. We can get the subset assigned to the current trainer by invoking edge_split and node_split, which split the edges and nodes evenly across trainers. We can also get the validation and test edges by loading the pickle files saved earlier.

import pickle
train_eids = dgl.distributed.edge_split(
    th.ones((g.number_of_edges(),), dtype=th.bool),
    g.get_partition_book(), force_even=True)
train_nids = dgl.distributed.node_split(
    th.ones((g.number_of_nodes(),), dtype=th.bool),
    g.get_partition_book())
with open('4part_data/val.pkl', 'rb') as f:
    global_valid_eid = pickle.load(f)
with open('4part_data/test.pkl', 'rb') as f:
    global_test_eid = pickle.load(f)

Define a GNN model

For distributed training, we define a GNN model exactly in the same way as for mini-batch or full-graph training. The code below defines the GraphSAGE model.

import torch.nn as nn
import torch.nn.functional as F
import dgl.nn as dglnn
import torch.optim as optim

class SAGE(nn.Module):
    def __init__(self, in_feats, n_hidden, n_classes, n_layers):
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean'))
        for i in range(1, n_layers - 1):
            self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean'))

    def forward(self, blocks, x):
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            x = layer(block, x)
            if l != self.n_layers - 1:
                x = F.relu(x)
        return x

num_hidden = 256
num_layers = 2
lr = 0.001
# For link prediction the model outputs node embeddings rather than class
# scores, so the output size is simply the hidden size.
model = SAGE(g.ndata['feat'].shape[1], num_hidden, num_hidden, num_layers)
optimizer = optim.Adam(model.parameters(), lr=lr)

For distributed training, we need to convert the model into a distributed model with PyTorch's DistributedDataParallel.

model = th.nn.parallel.DistributedDataParallel(model)

We also define an EdgePredictor, which computes edge scores from pairs of node representations. Here we use the dot product between the source and destination node representations.

from dgl.nn import EdgePredictor
predictor = EdgePredictor('dot')

Distributed mini-batch sampler

We use DistEdgeDataLoader, the distributed counterpart of EdgeDataLoader, to create a distributed mini-batch dataloader for link prediction, as sketched below.
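The following is a minimal sketch rather than definitive tutorial code: it assumes a two-layer neighbor sampler with fanouts [25, 10], one uniformly sampled negative edge per positive edge, and a batch size of 1000; all of these values are illustrative and can be tuned.

# Sample a fixed number of neighbors per layer for each seed edge's endpoints.
sampler = dgl.dataloading.MultiLayerNeighborSampler([25, 10])
dataloader = dgl.dataloading.DistEdgeDataLoader(
    g, train_eids, sampler,
    # One uniformly sampled negative edge per positive edge (illustrative choice).
    negative_sampler=dgl.dataloading.negative_sampler.Uniform(1),
    batch_size=1000, shuffle=True, drop_last=False)

The loader yields (input_nodes, pos_graph, neg_graph, mfgs) tuples, which is the form consumed by the training loop below.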

Training loop

The training loop for distributed training is exactly the same as in single-process training.

import sklearn.metrics as skm

for epoch in range(10):
    for step, (input_nodes, pos_graph, neg_graph, mfgs) in enumerate(dataloader):
        # Slice the distributed node features for the input nodes of this mini-batch.
        node_inputs = mfgs[0].srcdata[dgl.NID]
        batch_inputs = g.ndata['feat'][node_inputs]

        # Compute the representations of all nodes in the mini-batch.
        batch_pred = model(mfgs, batch_inputs)

        # Score the positive edges.
        pos_src, pos_dst = pos_graph.edges()
        pos_score = predictor(batch_pred[pos_src], batch_pred[pos_dst])

        # Score the negative edges.
        neg_src, neg_dst = neg_graph.edges()
        neg_score = predictor(batch_pred[neg_src], batch_pred[neg_dst])

        # Binary cross entropy: positive edges are labeled 1, negative edges 0.
        score = th.cat([pos_score, neg_score])
        label = th.cat([th.ones_like(pos_score), th.zeros_like(neg_score)])
        loss = F.binary_cross_entropy_with_logits(score, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Inference

In the inference stage, we use the trained model to compute the embeddings of all nodes.

def inference(model, graph, node_features):
    with th.no_grad():
        # Iterate over all nodes in the graph to compute their embeddings.
        sampler = dgl.dataloading.MultiLayerNeighborSampler([25, 10])
        dataloader = dgl.dataloading.DistNodeDataLoader(
            graph, th.arange(graph.number_of_nodes()), sampler,
            batch_size=1024,
            shuffle=False,
            drop_last=False)

        result = []
        for input_nodes, output_nodes, mfgs in dataloader:
            node_inputs = mfgs[0].srcdata[dgl.NID]
            inputs = node_features[node_inputs]
            result.append(model(mfgs, inputs))

        return th.cat(result)

node_reprs = inference(model, g, g.ndata['feat'])

The test edges are encoded as ((positive_edge_src, positive_edge_dst), (negative_edge_src, negative_edge_dst)), so we can construct the ground-truth labels from the positive and negative pairs.

test_pos_src = global_test_eid[0][0]
test_pos_dst = global_test_eid[0][1]
test_neg_src = global_test_eid[1][0]
test_neg_dst = global_test_eid[1][1]
test_labels = th.cat([th.ones_like(test_pos_src), th.zeros_like(test_neg_src)]).cpu().numpy()

Then, we use the dot product predictor to get the scores of the positive and negative test pairs and compute metrics such as AUC:

h_pos_src = node_reprs[test_pos_src]
h_pos_dst = node_reprs[test_pos_dst]
h_neg_src = node_reprs[test_neg_src]
h_neg_dst = node_reprs[test_neg_dst]
score_pos = predictor(h_pos_src, h_pos_dst)
score_neg = predictor(h_neg_src, h_neg_dst)

test_preds = th.cat([score_pos, score_neg]).cpu().numpy()
auc = skm.roc_auc_score(test_labels, test_preds)

Set up distributed training environment

Setting up the distributed training environment is similar to distributed node classification; see Set up distributed training environment in that tutorial for details. A typical launch command is sketched below.
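As a rough sketch, assuming the training script is saved as train_dist_link_prediction.py, the workspace (containing the partitioned data and ip_config.txt, copied to every machine) lives under ~/workspace, and DGL's launch tool (tools/launch.py in the DGL repository) is used, the job could be started with something like:

python3 ~/workspace/dgl/tools/launch.py \
  --workspace ~/workspace/ \
  --num_trainers 1 \
  --num_samplers 0 \
  --num_servers 1 \
  --part_config 4part_data/ogbl-citation2.json \
  --ip_config ip_config.txt \
  "python3 train_dist_link_prediction.py"

The script name and paths here are placeholders; the part_config file is the JSON produced by partition_graph in the partitioning step above.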
