"""PPI Dataset.
(zhang hao): Used for inductive learning.
"""
import json
import numpy as np
import networkx as nx
from networkx.readwrite import json_graph
from .utils import download, extract_archive, get_download_dir, _get_dgl_url
from ..graph import DGLGraph
# Location of the PPI zip archive; PPIDataset._load resolves it through
# _get_dgl_url and hands the result to download.
_url = 'dataset/ppi.zip'
class PPIDataset(object):
    """A toy Protein-Protein Interaction network dataset.

    Adapted from https://github.com/williamleif/GraphSAGE/tree/master/example_data.

    The dataset contains 24 graphs. The average number of nodes per graph
    is 2372. Each node has 50 features and 121 labels.

    We use 20 graphs for training, 2 for validation and 2 for testing.
    """

    # Half-open [start, stop) range of graph ids that make up each split;
    # these are the values matched against the ``<mode>_graph_id.npy`` array.
    _GRAPH_ID_RANGES = {
        'train': (1, 21),
        'valid': (21, 23),
        'test': (23, 25),
    }

    def __init__(self, mode):
        """Initialize the dataset.

        Parameters
        ----------
        mode : str
            ('train', 'valid', 'test').

        Raises
        ------
        ValueError
            If ``mode`` is not one of the supported splits.
        """
        # Fail before any download happens instead of returning a
        # half-initialised object for an unknown split.
        if mode not in self._GRAPH_ID_RANGES:
            raise ValueError(
                "mode must be one of {}, got {!r}".format(
                    sorted(self._GRAPH_ID_RANGES), mode))
        self.mode = mode
        self._load()
        self._preprocess()

    def _load(self):
        """Download, extract and load the input data of the current mode.

        The following files are read, where ``<mode>`` is ``self.mode``:

        <mode>_graph.json => the graph data in json (node-link) format;
        <mode>_feats.npy => the feature vectors of the nodes as a
            numpy.ndarray of shape [n, v], where n is the number of nodes
            and v is the feature dimension;
        <mode>_labels.npy => the labels of the nodes as a numpy.ndarray
            such as [[0, 0, 1, ... 0], [0, 1, 1, 0 ... 1]] of shape [n, h],
            where n is the number of nodes and h is the label dimension;
        <mode>_graph_id.npy => a one dimensional numpy.ndarray whose length
            equals the number of nodes; each element tells which graph the
            node belongs to, e.g. [1, 1, 2, 1 ... 20].

        Sets ``self.graph``, ``self.features``, ``self.labels`` and
        ``self.graph_id``.
        """
        name = 'ppi'
        download_dir = get_download_dir()
        zip_file_path = '{}/{}.zip'.format(download_dir, name)
        download(_get_dgl_url(_url), path=zip_file_path)
        extract_archive(zip_file_path,
                        '{}/{}'.format(download_dir, name))
        print('Loading G...')
        # All files of one split share the '<dir>/ppi/<mode>' prefix.
        prefix = '{}/{}/{}'.format(download_dir, name, self.mode)
        with open('{}_graph.json'.format(prefix)) as jsonfile:
            g_data = json.load(jsonfile)
        self.labels = np.load('{}_labels.npy'.format(prefix))
        self.features = np.load('{}_feats.npy'.format(prefix))
        self.graph = DGLGraph(nx.DiGraph(json_graph.node_link_graph(g_data)))
        self.graph_id = np.load('{}_graph_id.npy'.format(prefix))

    def _preprocess(self):
        """Split the loaded graph of the current mode into per-graph samples.

        For every graph id of the current split, the indices of the nodes
        carrying that id are collected and used to take a subgraph and to
        slice the labels. The results are stored in ``<mode>_mask_list``,
        ``<mode>_graphs`` and ``<mode>_labels`` (for example
        ``train_mask_list``, ``train_graphs`` and ``train_labels``).
        """
        start, stop = self._GRAPH_ID_RANGES[self.mode]
        mask_list = []
        graphs = []
        labels = []
        for graph_id in range(start, stop):
            # Indices of the nodes that belong to this graph.
            node_ids = np.where(self.graph_id == graph_id)[0]
            mask_list.append(node_ids)
            graphs.append(self.graph.subgraph(node_ids))
            labels.append(self.labels[node_ids])
        # Keep the historical per-mode attribute names so existing callers
        # that read e.g. ``dataset.train_graphs`` keep working.
        setattr(self, '{}_mask_list'.format(self.mode), mask_list)
        setattr(self, '{}_graphs'.format(self.mode), graphs)
        setattr(self, '{}_labels'.format(self.mode), labels)

    def __len__(self):
        """Return number of samples in this dataset."""
        return len(getattr(self, '{}_mask_list'.format(self.mode)))

    def __getitem__(self, item):
        """Get the i^th sample.

        Parameters
        ----------
        item : int
            The sample index.

        Returns
        -------
        (dgl.DGLGraph, ndarray, ndarray)
            The graph, features and its label.
        """
        mask_list = getattr(self, '{}_mask_list'.format(self.mode))
        graphs = getattr(self, '{}_graphs'.format(self.mode))
        labels = getattr(self, '{}_labels'.format(self.mode))
        return graphs[item], self.features[mask_list[item]], labels[item]