# Source code for dgl.nn.pytorch.network_emb
"""Network Embedding NN Modules"""
# pylint: disable= invalid-name
import random
import torch
from torch import nn
from torch.nn import init
import torch.nn.functional as F
import tqdm
from ...base import NID
from ...convert import to_homogeneous, to_heterogeneous
from ...random import choice
from ...sampling import random_walk
__all__ = ['DeepWalk', 'MetaPath2Vec']
class DeepWalk(nn.Module):
    """DeepWalk module from `DeepWalk: Online Learning of Social Representations
    <https://arxiv.org/abs/1403.6652>`__

    For a graph, it learns the node representations from scratch by maximizing the similarity of
    node pairs that are nearby (positive node pairs) and minimizing the similarity of other
    random node pairs (negative node pairs).

    Parameters
    ----------
    g : DGLGraph
        Graph for learning node embeddings
    emb_dim : int, optional
        Size of each embedding vector. Default: 128
    walk_length : int, optional
        Number of nodes in a random walk sequence. Default: 40
    window_size : int, optional
        In a random walk :attr:`w`, a node :attr:`w[j]` is considered close to a node
        :attr:`w[i]` if :attr:`i - window_size <= j <= i + window_size`. Default: 5
    neg_weight : float, optional
        Weight of the loss term for negative samples in the total loss. Default: 1.0
    negative_size : int, optional
        Number of negative samples to use for each positive sample. Default: 5
    fast_neg : bool, optional
        If True, it samples negative node pairs within a batch of random walks. Default: True
    sparse : bool, optional
        If True, gradients with respect to the learnable weights will be sparse.
        Default: True

    Attributes
    ----------
    node_embed : nn.Embedding
        Embedding table of the nodes

    Examples
    --------

    >>> import torch
    >>> from dgl.data import CoraGraphDataset
    >>> from dgl.nn import DeepWalk
    >>> from torch.optim import SparseAdam
    >>> from torch.utils.data import DataLoader
    >>> from sklearn.linear_model import LogisticRegression

    >>> dataset = CoraGraphDataset()
    >>> g = dataset[0]
    >>> model = DeepWalk(g)
    >>> dataloader = DataLoader(torch.arange(g.num_nodes()), batch_size=128,
    ...                         shuffle=True, collate_fn=model.sample)
    >>> optimizer = SparseAdam(model.parameters(), lr=0.01)
    >>> num_epochs = 5

    >>> for epoch in range(num_epochs):
    ...     for batch_walk in dataloader:
    ...         loss = model(batch_walk)
    ...         optimizer.zero_grad()
    ...         loss.backward()
    ...         optimizer.step()

    >>> train_mask = g.ndata['train_mask']
    >>> test_mask = g.ndata['test_mask']
    >>> X = model.node_embed.weight.detach()
    >>> y = g.ndata['label']
    >>> clf = LogisticRegression().fit(X[train_mask].numpy(), y[train_mask].numpy())
    >>> clf.score(X[test_mask].numpy(), y[test_mask].numpy())
    """

    def __init__(self,
                 g,
                 emb_dim=128,
                 walk_length=40,
                 window_size=5,
                 neg_weight=1,
                 negative_size=5,
                 fast_neg=True,
                 sparse=True):
        super().__init__()

        assert walk_length >= window_size + 1, \
            f'Expect walk_length >= window_size + 1, got {walk_length} and {window_size + 1}'

        self.g = g
        self.emb_dim = emb_dim
        self.window_size = window_size
        self.walk_length = walk_length
        self.neg_weight = neg_weight
        self.negative_size = negative_size
        self.fast_neg = fast_neg

        num_nodes = g.num_nodes()

        # center node embedding
        self.node_embed = nn.Embedding(num_nodes, emb_dim, sparse=sparse)
        # context node embedding
        self.context_embed = nn.Embedding(num_nodes, emb_dim, sparse=sparse)
        self.reset_parameters()

        if not fast_neg:
            # Cast explicitly so the fractional power never depends on
            # integer-to-float type promotion of the degree tensor.
            neg_prob = g.out_degrees().float().pow(0.75)
            # categorical distribution for true negative sampling
            self.neg_prob = neg_prob / neg_prob.sum()

        # Get list index pairs for positive samples.
        # Given i, positive index pairs are (i - window_size, i), ... ,
        # (i - 1, i), (i + 1, i), ..., (i + window_size, i)
        idx_list_src = []
        idx_list_dst = []

        for i in range(walk_length):
            for j in range(max(0, i - window_size), i):
                idx_list_src.append(j)
                idx_list_dst.append(i)
            for j in range(i + 1, min(walk_length, i + 1 + window_size)):
                idx_list_src.append(j)
                idx_list_dst.append(i)

        # Positions are relative to a single walk; forward() offsets them per batch row.
        self.idx_list_src = torch.LongTensor(idx_list_src)
        self.idx_list_dst = torch.LongTensor(idx_list_dst)

    def reset_parameters(self):
        """Reinitialize learnable parameters"""
        init_range = 1.0 / self.emb_dim
        init.uniform_(self.node_embed.weight.data, -init_range, init_range)
        init.constant_(self.context_embed.weight.data, 0)

    def sample(self, indices):
        """Sample random walks

        Parameters
        ----------
        indices : torch.Tensor
            Nodes from which we perform random walk

        Returns
        -------
        torch.Tensor
            Random walks in the form of node ID sequences. The Tensor
            is of shape :attr:`(len(indices), walk_length)`.

        Notes
        -----
        NOTE(review): per the DGL documentation, ``random_walk`` pads walks that stop
        early with -1; :meth:`forward` does not filter such entries, so the graph is
        presumably expected to have no dead-end nodes -- confirm with callers.
        """
        # length counts edges, so walk_length nodes need walk_length - 1 steps
        return random_walk(self.g, indices, length=self.walk_length - 1)[0]

    def forward(self, batch_walk):
        """Compute the loss for the batch of random walks

        Parameters
        ----------
        batch_walk : torch.Tensor
            Random walks in the form of node ID sequences. The Tensor
            is of shape :attr:`(batch_size, walk_length)`.

        Returns
        -------
        torch.Tensor
            Loss value
        """
        batch_size = len(batch_walk)
        device = batch_walk.device

        # Flatten to (batch_size * walk_length, emb_dim) so that one flat index
        # addresses "position p of walk b" as b * walk_length + p.
        batch_node_embed = self.node_embed(batch_walk).view(-1, self.emb_dim)
        batch_context_embed = self.context_embed(batch_walk).view(-1, self.emb_dim)

        batch_idx_list_offset = torch.arange(batch_size) * self.walk_length
        batch_idx_list_offset = batch_idx_list_offset.unsqueeze(1)
        idx_list_src = batch_idx_list_offset + self.idx_list_src.unsqueeze(0)
        idx_list_dst = batch_idx_list_offset + self.idx_list_dst.unsqueeze(0)
        idx_list_src = idx_list_src.view(-1).to(device)
        idx_list_dst = idx_list_dst.view(-1).to(device)

        pos_src_emb = batch_node_embed[idx_list_src]
        pos_dst_emb = batch_context_embed[idx_list_dst]

        # Repeat every positive destination index negative_size times in a row.
        # This stays in int64, unlike adding a float zeros tensor and casting back,
        # which silently loses precision for indices above 2**24.
        neg_idx_list_src = idx_list_dst.repeat_interleave(self.negative_size)
        neg_src_emb = batch_node_embed[neg_idx_list_src]

        if self.fast_neg:
            # Draw negatives from the nodes already present in this batch of walks.
            neg_idx_list_dst = list(range(batch_size * self.walk_length)) \
                * (self.negative_size * self.window_size * 2)
            random.shuffle(neg_idx_list_dst)
            neg_idx_list_dst = neg_idx_list_dst[:len(neg_idx_list_src)]
            neg_idx_list_dst = torch.LongTensor(neg_idx_list_dst).to(device)
            neg_dst_emb = batch_context_embed[neg_idx_list_dst]
        else:
            # Draw negatives from the whole graph following the degree-based distribution.
            neg_dst = choice(self.g.num_nodes(), size=len(neg_src_emb), prob=self.neg_prob)
            neg_dst_emb = self.context_embed(neg_dst.to(device))

        pos_score = torch.sum(torch.mul(pos_src_emb, pos_dst_emb), dim=1)
        pos_score = torch.clamp(pos_score, max=6, min=-6)
        pos_score = torch.mean(-F.logsigmoid(pos_score))

        neg_score = torch.sum(torch.mul(neg_src_emb, neg_dst_emb), dim=1)
        neg_score = torch.clamp(neg_score, max=6, min=-6)
        neg_score = torch.mean(-F.logsigmoid(-neg_score)) * self.negative_size * self.neg_weight

        return torch.mean(pos_score + neg_score)
class MetaPath2Vec(nn.Module):
    r"""metapath2vec module from `metapath2vec: Scalable Representation Learning for
    Heterogeneous Networks <https://dl.acm.org/doi/pdf/10.1145/3097983.3098036>`__

    To achieve efficient optimization, we leverage the negative sampling technique for the
    training process. Repeatedly for each node in meta-path, we treat it as the center node
    and sample nearby positive nodes within context size and draw negative samples among all
    types of nodes from all meta-paths. Then we can use the center-context paired nodes and
    context-negative paired nodes to update the network.

    Parameters
    ----------
    g : DGLGraph
        Graph for learning node embeddings. Two different canonical edge types
        :attr:`(utype, etype, vtype)` are not allowed to have same :attr:`etype`.
    metapath : list[str]
        A sequence of edge types in the form of a string. It defines a new edge type by composing
        multiple edge types in order. Note that the start node type and the end one are commonly
        the same.
    window_size : int
        In a random walk :attr:`w`, a node :attr:`w[j]` is considered close to a node
        :attr:`w[i]` if :attr:`i - window_size <= j <= i + window_size`.
    emb_dim : int, optional
        Size of each embedding vector. Default: 128
    negative_size : int, optional
        Number of negative samples to use for each positive sample. Default: 5
    sparse : bool, optional
        If True, gradients with respect to the learnable weights will be sparse.
        Default: True

    Attributes
    ----------
    node_embed : nn.Embedding
        Embedding table of all nodes
    local_to_global_nid : dict[str, list]
        Mapping from type-specific node IDs to global node IDs

    Examples
    --------

    >>> import torch
    >>> import dgl
    >>> from torch.optim import SparseAdam
    >>> from torch.utils.data import DataLoader
    >>> from dgl.nn.pytorch import MetaPath2Vec

    >>> # Define a model
    >>> g = dgl.heterograph({
    ...     ('user', 'uc', 'company'): dgl.rand_graph(100, 1000).edges(),
    ...     ('company', 'cp', 'product'): dgl.rand_graph(100, 1000).edges(),
    ...     ('company', 'cu', 'user'): dgl.rand_graph(100, 1000).edges(),
    ...     ('product', 'pc', 'company'): dgl.rand_graph(100, 1000).edges()
    ... })
    >>> model = MetaPath2Vec(g, ['uc', 'cu'], window_size=1)

    >>> # Use the source node type of etype 'uc'
    >>> dataloader = DataLoader(torch.arange(g.num_nodes('user')), batch_size=128,
    ...                         shuffle=True, collate_fn=model.sample)
    >>> optimizer = SparseAdam(model.parameters(), lr=0.025)

    >>> for (pos_u, pos_v, neg_v) in dataloader:
    ...     loss = model(pos_u, pos_v, neg_v)
    ...     optimizer.zero_grad()
    ...     loss.backward()
    ...     optimizer.step()

    >>> # Get the embeddings of all user nodes
    >>> user_nids = torch.LongTensor(model.local_to_global_nid['user'])
    >>> user_emb = model.node_embed(user_nids)
    """

    def __init__(self,
                 g,
                 metapath,
                 window_size,
                 emb_dim=128,
                 negative_size=5,
                 sparse=True):
        super().__init__()

        assert len(metapath) + 1 >= window_size, \
            f'Expect len(metapath) >= window_size - 1, got {metapath} and {window_size}'

        self.hg = g
        self.emb_dim = emb_dim
        self.metapath = metapath
        self.window_size = window_size
        self.negative_size = negative_size

        # convert edge metapath to node metapath
        # get initial source node type
        src_type, _, _ = g.to_canonical_etype(metapath[0])
        node_metapath = [src_type]
        for etype in metapath:
            _, _, dst_type = g.to_canonical_etype(etype)
            node_metapath.append(dst_type)
        self.node_metapath = node_metapath

        # Convert the graph into a homogeneous one for global to local node ID mapping
        g = to_homogeneous(g)
        # Convert it back to the hetero one for local to global node ID mapping
        hg = to_heterogeneous(g, self.hg.ntypes, self.hg.etypes)
        # Build a fresh dict instead of overwriting the frame returned by ndata in place.
        self.local_to_global_nid = {
            ntype: list(nids.cpu().numpy()) for ntype, nids in hg.ndata[NID].items()}

        num_nodes_total = hg.num_nodes()
        node_frequency = torch.zeros(num_nodes_total)
        # random walk: one metapath walk from every node of the starting type
        for idx in tqdm.trange(hg.num_nodes(node_metapath[0])):
            traces, _ = random_walk(g=hg, nodes=[idx], metapath=metapath)
            for tr in traces.cpu().numpy():
                tr_nids = torch.LongTensor(self._trace_to_global_nids(tr))
                # index_add_ accumulates repeated IDs, so a node visited twice in one
                # trace is counted twice; ``freq[ids] += 1`` would count it only once.
                node_frequency.index_add_(0, tr_nids, torch.ones(len(tr_nids)))

        # Smoothed visit-frequency distribution used to draw negative samples.
        neg_prob = node_frequency.pow(0.75)
        self.neg_prob = neg_prob / neg_prob.sum()

        # center node embedding
        self.node_embed = nn.Embedding(num_nodes_total, self.emb_dim, sparse=sparse)
        # context node embedding
        self.context_embed = nn.Embedding(num_nodes_total, self.emb_dim, sparse=sparse)
        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize learnable parameters"""
        init_range = 1.0 / self.emb_dim
        init.uniform_(self.node_embed.weight.data, -init_range, init_range)
        init.constant_(self.context_embed.weight.data, 0)

    def _trace_to_global_nids(self, trace):
        """Map one trace of type-specific node IDs to global node IDs.

        The trace is cut at the first negative entry: per the DGL documentation,
        ``random_walk`` pads walks that terminate early with -1, and indexing the
        ID table with -1 would silently select the last node of that type.
        """
        global_nids = []
        for ntype, local_nid in zip(self.node_metapath, trace):
            if local_nid < 0:
                break
            global_nids.append(self.local_to_global_nid[ntype][local_nid])
        return global_nids

    @staticmethod
    def _context_pairs(tr_nids, window_size):
        """Return (centers, contexts) for all node pairs at most window_size apart."""
        centers = []
        contexts = []
        for i, u in enumerate(tr_nids):
            start = max(i - window_size, 0)
            # Enumerate with absolute positions so the self-pair check compares like
            # with like; the upper bound is inclusive to match the documented window.
            for j, v in enumerate(tr_nids[start:i + window_size + 1], start):
                if j == i:
                    continue
                centers.append(u)
                contexts.append(v)
        return centers, contexts

    def sample(self, indices):
        """Sample positive and negative samples

        Parameters
        ----------
        indices : torch.Tensor
            Node IDs of the source node type from which we perform random walks

        Returns
        -------
        torch.Tensor
            Positive center nodes
        torch.Tensor
            Positive context nodes
        torch.Tensor
            Negative context nodes
        """
        traces, _ = random_walk(g=self.hg, nodes=indices, metapath=self.metapath)
        u_list = []
        v_list = []
        for tr in traces.cpu().numpy():
            tr_nids = self._trace_to_global_nids(tr)
            centers, contexts = self._context_pairs(tr_nids, self.window_size)
            u_list.extend(centers)
            v_list.extend(contexts)

        # negative_size negatives per positive pair, drawn from the frequency distribution
        neg_v = choice(self.hg.num_nodes(), size=len(u_list) * self.negative_size,
                       prob=self.neg_prob).reshape(len(u_list), self.negative_size)

        return torch.LongTensor(u_list), torch.LongTensor(v_list), neg_v

    def forward(self, pos_u, pos_v, neg_v):
        r"""Compute the loss for the batch of positive and negative samples

        Parameters
        ----------
        pos_u : torch.Tensor
            Positive center nodes
        pos_v : torch.Tensor
            Positive context nodes
        neg_v : torch.Tensor
            Negative context nodes

        Returns
        -------
        torch.Tensor
            Loss value
        """
        emb_u = self.node_embed(pos_u)
        emb_v = self.context_embed(pos_v)
        emb_neg_v = self.context_embed(neg_v)

        score = torch.sum(torch.mul(emb_u, emb_v), dim=1)
        score = torch.clamp(score, max=10, min=-10)
        score = -F.logsigmoid(score)

        # Squeeze only the trailing singleton produced by bmm; a bare squeeze() would
        # also drop the batch or negative dimension when either has size 1 and make
        # the sum over dim=1 fail.
        neg_score = torch.bmm(emb_neg_v, emb_u.unsqueeze(2)).squeeze(2)
        neg_score = torch.clamp(neg_score, max=10, min=-10)
        neg_score = -torch.sum(F.logsigmoid(-neg_score), dim=1)

        return torch.mean(score + neg_score)