# 5.2 边分类/回归

```src = np.random.randint(0, 100, 500)
dst = np.random.randint(0, 100, 500)
# 同时建立反向边
edge_pred_graph = dgl.graph((np.concatenate([src, dst]), np.concatenate([dst, src])))
# 建立点和边特征，以及边的标签
edge_pred_graph.ndata['feature'] = torch.randn(100, 10)
edge_pred_graph.edata['feature'] = torch.randn(1000, 10)
edge_pred_graph.edata['label'] = torch.randn(1000)
# 进行训练、验证和测试集划分
edge_pred_graph.edata['train_mask'] = torch.zeros(1000, dtype=torch.bool).bernoulli(0.6)
```

## 与节点分类在模型实现上的差别

```import dgl.function as fn
class DotProductPredictor(nn.Module):
def forward(self, graph, h):
# h是从5.1节的GNN模型中计算出的节点表示
with graph.local_scope():
graph.ndata['h'] = h
graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))
return graph.edata['score']
```

```class MLPPredictor(nn.Module):
def __init__(self, in_features, out_classes):
super().__init__()
self.W = nn.Linear(in_features * 2, out_classes)

def apply_edges(self, edges):
h_u = edges.src['h']
h_v = edges.dst['h']
score = self.W(torch.cat([h_u, h_v], 1))
return {'score': score}

def forward(self, graph, h):
# h是从5.1节的GNN模型中计算出的节点表示
with graph.local_scope():
graph.ndata['h'] = h
graph.apply_edges(self.apply_edges)
return graph.edata['score']
```

## 模型的训练

```class Model(nn.Module):
def __init__(self, in_features, hidden_features, out_features):
super().__init__()
self.sage = SAGE(in_features, hidden_features, out_features)
self.pred = DotProductPredictor()
def forward(self, g, x):
h = self.sage(g, x)
return self.pred(g, h)
```

```node_features = edge_pred_graph.ndata['feature']
edge_label = edge_pred_graph.edata['label']
model = Model(10, 20, 5)
for epoch in range(10):
pred = model(edge_pred_graph, node_features)
loss = ((pred[train_mask] - edge_label[train_mask]) ** 2).mean()
loss.backward()
opt.step()
print(loss.item())
```

## 异构图上的边预测模型的训练

```class HeteroDotProductPredictor(nn.Module):
def forward(self, graph, h, etype):
# h是从5.1节中对每种类型的边所计算的节点表示
with graph.local_scope():
graph.ndata['h'] = h   #一次性为所有节点类型的 'h'赋值
graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)
return graph.edges[etype].data['score']
```

```class MLPPredictor(nn.Module):
def __init__(self, in_features, out_classes):
super().__init__()
self.W = nn.Linear(in_features * 2, out_classes)

def apply_edges(self, edges):
h_u = edges.src['h']
h_v = edges.dst['h']
score = self.W(torch.cat([h_u, h_v], 1))
return {'score': score}

def forward(self, graph, h, etype):
# h是从5.1节中对异构图的每种类型的边所计算的节点表示
with graph.local_scope():
graph.ndata['h'] = h   #一次性为所有节点类型的 'h'赋值
graph.apply_edges(self.apply_edges, etype=etype)
return graph.edges[etype].data['score']
```

```class Model(nn.Module):
def __init__(self, in_features, hidden_features, out_features, rel_names):
super().__init__()
self.sage = RGCN(in_features, hidden_features, out_features, rel_names)
self.pred = HeteroDotProductPredictor()
def forward(self, g, x, etype):
h = self.sage(g, x)
return self.pred(g, h, etype)
```

```model = Model(10, 20, 5, hetero_graph.etypes)
user_feats = hetero_graph.nodes['user'].data['feature']
item_feats = hetero_graph.nodes['item'].data['feature']
label = hetero_graph.edges['click'].data['label']
node_features = {'user': user_feats, 'item': item_feats}
```

```opt = torch.optim.Adam(model.parameters())
for epoch in range(10):
pred = model(hetero_graph, node_features, 'click')
loss = ((pred[train_mask] - label[train_mask]) ** 2).mean()
loss.backward()
opt.step()
print(loss.item())
```

## 在异构图中预测已有边的类型

```dec_graph = hetero_graph['user', :, 'item']  # every relation type from 'user' to 'item'; each edge's type id is read from its edata[dgl.ETYPE] below
```

```edge_label = dec_graph.edata[dgl.ETYPE]  # per-edge type id stored under dgl.ETYPE, used as the classification target
```

```class HeteroMLPPredictor(nn.Module):
def __init__(self, in_dims, n_classes):
super().__init__()
self.W = nn.Linear(in_dims * 2, n_classes)

def apply_edges(self, edges):
x = torch.cat([edges.src['h'], edges.dst['h']], 1)
y = self.W(x)
return {'score': y}

def forward(self, graph, h):
# h是从5.1节中对异构图的每种类型的边所计算的节点表示
with graph.local_scope():
graph.ndata['h'] = h   #一次性为所有节点类型的 'h'赋值
graph.apply_edges(self.apply_edges)
return graph.edata['score']
```

```class Model(nn.Module):
def __init__(self, in_features, hidden_features, out_features, rel_names):
super().__init__()
self.sage = RGCN(in_features, hidden_features, out_features, rel_names)
self.pred = HeteroMLPPredictor(out_features, len(rel_names))
def forward(self, g, x, dec_graph):
h = self.sage(g, x)
return self.pred(dec_graph, h)
```

```model = Model(10, 20, 5, hetero_graph.etypes)
user_feats = hetero_graph.nodes['user'].data['feature']
item_feats = hetero_graph.nodes['item'].data['feature']
node_features = {'user': user_feats, 'item': item_feats}