6.5 Training GNN with DGL sparse
This tutorial demonstrates how to use the DGL sparse library to sample on a graph and train a model. It trains and tests a GraphSAGE model, using the sparse sample and compact operators to sample submatrices from the whole adjacency matrix.
Training a GNN with DGL sparse is quite similar to 6.1 Training GNN for Node Classification with Neighborhood Sampling. The major differences are the customized sampler and the matrix that represents the graph.
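The code snippets in this tutorial assume imports roughly along the following lines; the exact module paths are an assumption and may vary with your DGL and PyTorch versions:

import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl.graphbolt as gb
import dgl.sparse as dglsp
from dgl.graphbolt import SubgraphSampler
from torch.utils.data import functional_datapipe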
We customized one sampler in 6.4 Implementing Custom Graph Samplers. In this tutorial, we customize another sampler with the DGL sparse library, as shown below.
@functional_datapipe("sample_sparse_neighbor")
class SparseNeighborSampler(SubgraphSampler):
    def __init__(self, datapipe, matrix, fanouts):
        super().__init__(datapipe)
        self.matrix = matrix
        # Convert fanouts to a list of tensors.
        self.fanouts = []
        for fanout in fanouts:
            if not isinstance(fanout, torch.Tensor):
                fanout = torch.LongTensor([int(fanout)])
            self.fanouts.insert(0, fanout)

    def sample_subgraphs(self, seeds):
        sampled_matrices = []
        src = seeds

        #####################################################################
        # (HIGHLIGHT) Use the sparse sample operator to perform random
        # sampling on the neighboring nodes of the seed nodes. The sparse
        # compact operator is then employed to compact and relabel the
        # sampled matrix, resulting in the compacted matrix and the
        # relabeled row IDs.
        #####################################################################
        for fanout in self.fanouts:
            # Sample neighbors.
            sampled_matrix = self.matrix.sample(1, fanout, ids=src).coalesce()
            # Compact the sampled matrix.
            compacted_mat, row_ids = sampled_matrix.compact(0)
            sampled_matrices.insert(0, compacted_mat)
            src = row_ids

        return src, sampled_matrices
Another major difference is the matrix that represents the graph. Previously we used FusedCSCSamplingGraph for sampling. In this tutorial, we use SparseMatrix to represent the graph.
dataset = gb.BuiltinDataset("ogbn-products").load()
g = dataset.graph
# Create the sparse adjacency matrix from the graph's CSC representation.
N = g.num_nodes
A = dglsp.from_csc(g.csc_indptr, g.indices, shape=(N, N))
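Optionally, you can sanity-check the constructed matrix (purely illustrative):

print(A.shape)  # (N, N)
print(A.nnz)    # number of stored entries, i.e. edges of the graph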
The remaining code is almost the same as in the node classification tutorial. To use this sampler with DataLoader:
datapipe = gb.ItemSampler(ids, batch_size=1024)
# Use the customized sparse sampler.
datapipe = datapipe.sample_sparse_neighbor(A, fanouts)
# Use graphbolt to fetch features.
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe)
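The ids, fanouts, features, and device variables used above are not defined in this snippet; a minimal sketch of how they might be set up (the exact choices, such as a fanout of 10 per layer, are assumptions):

# Illustrative setup; adjust to your own experiment.
ids = dataset.tasks[0].train_set   # training seed nodes (with labels)
fanouts = [10, 10, 10]             # one fanout per GraphSAGE layer
features = dataset.feature         # feature store used by fetch_feature
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")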
The model definition is shown below:
class SAGEConv(nn.Module):
    r"""GraphSAGE layer from `Inductive Representation Learning on
    Large Graphs <https://arxiv.org/pdf/1706.02216.pdf>`__
    """

    def __init__(
        self,
        in_feats,
        out_feats,
    ):
        super(SAGEConv, self).__init__()
        self._in_src_feats, self._in_dst_feats = in_feats, in_feats
        self._out_feats = out_feats
        self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=False)
        self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=True)
        self.reset_parameters()

    def reset_parameters(self):
        gain = nn.init.calculate_gain("relu")
        nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)

    def forward(self, A, feat):
        feat_src = feat
        # The first A.shape[1] rows of feat are the destination (seed) nodes.
        feat_dst = feat[: A.shape[1]]
        # Aggregator type: mean.
        srcdata = self.fc_neigh(feat_src)
        # Divide by the in-degree of each destination node.
        D_hat = dglsp.diag(A.sum(0)) ** -1
        A_div = A @ D_hat
        # Aggregate neighbor features for each destination node.
        dstdata = A_div.T @ srcdata
        rst = self.fc_self(feat_dst) + dstdata
        return rst
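As a purely illustrative toy check (not part of the tutorial), a single SAGEConv layer can be applied to a small COO matrix built with dglsp.spmatrix, where the first A.shape[1] rows of the feature tensor correspond to the destination (seed) nodes:

# Hypothetical toy example: 4 source nodes, 2 destination nodes.
indices = torch.tensor([[0, 1, 2, 3],   # source node IDs (rows)
                        [0, 0, 1, 1]])  # destination node IDs (columns)
A_toy = dglsp.spmatrix(indices, shape=(4, 2))
feat = torch.randn(4, 8)                # features of all 4 sampled nodes
layer = SAGEConv(8, 16)
out = layer(A_toy, feat)                # shape: (2, 16), one row per destination node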
class SAGE(nn.Module):
    def __init__(self, in_size, hid_size, out_size):
        super().__init__()
        self.layers = nn.ModuleList()
        # Three-layer GraphSAGE-gcn.
        self.layers.append(SAGEConv(in_size, hid_size))
        self.layers.append(SAGEConv(hid_size, hid_size))
        self.layers.append(SAGEConv(hid_size, out_size))
        self.dropout = nn.Dropout(0.5)
        self.hid_size = hid_size
        self.out_size = out_size

    def forward(self, sampled_matrices, x):
        hidden_x = x
        for layer_idx, (layer, sampled_matrix) in enumerate(
            zip(self.layers, sampled_matrices)
        ):
            hidden_x = layer(sampled_matrix, hidden_x)
            if layer_idx != len(self.layers) - 1:
                hidden_x = F.relu(hidden_x)
                hidden_x = self.dropout(hidden_x)
        return hidden_x
Launch training:
features = dataset.feature

# Create GraphSAGE model.
in_size = features.size("node", None, "feat")[0]
num_classes = dataset.tasks[0].metadata["num_classes"]
out_size = num_classes
model = SAGE(in_size, 256, out_size).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)

for epoch in range(10):
    model.train()
    total_loss = 0
    for it, data in enumerate(dataloader):
        node_feature = data.node_features["feat"].float()
        blocks = data.sampled_subgraphs
        y = data.labels
        y_hat = model(blocks, node_feature)
        loss = F.cross_entropy(y_hat, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
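Testing follows the same pattern as training. Below is a minimal evaluation sketch, assuming a test_dataloader built the same way as the dataloader above but over dataset.tasks[0].test_set:

model.eval()
correct = total = 0
with torch.no_grad():
    for data in test_dataloader:
        node_feature = data.node_features["feat"].float()
        blocks = data.sampled_subgraphs
        y = data.labels
        y_hat = model(blocks, node_feature)
        correct += (y_hat.argmax(dim=1) == y).sum().item()
        total += y.numel()
print(f"Test accuracy: {correct / total:.4f}")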
For more details, please refer to the full example.