OnDiskDataset for Homogeneous Graph
This tutorial shows how to create an OnDiskDataset
for a homogeneous graph that can be used in the GraphBolt framework.
By the end of this tutorial, you will be able to:
- organize graph structure data.
- organize feature data.
- organize training/validation/test sets for specific tasks.
To create an OnDiskDataset object, you need to organize all the data, including the graph structure, feature data, and tasks, into a directory. The directory should contain a metadata.yaml file that describes the metadata of the dataset.
Now let's generate the various data step by step and organize it all together to finally instantiate an OnDiskDataset.
Install DGL package
# Install required packages.
import os
import torch
import numpy as np
os.environ['TORCH'] = torch.__version__
os.environ['DGLBACKEND'] = "pytorch"
# Install the CPU version.
device = torch.device("cpu")
!pip install --pre dgl -f https://data.dgl.ai/wheels-test/repo.html
try:
    import dgl
    import dgl.graphbolt as gb
    installed = True
except ImportError as error:
    installed = False
    print(error)
print("DGL installed!" if installed else "DGL not found!")
DGL installed!
Data preparation
To demonstrate how to organize the various data, let's first create a base directory.
base_dir = './ondisk_dataset_homograph'
os.makedirs(base_dir, exist_ok=True)
print(f"Created base directory: {base_dir}")
Created base directory: ./ondisk_dataset_homograph
Generate graph structure data
For a homogeneous graph, we just need to save the edges (namely, node pairs) into a NumPy or CSV file.
Note:
- When saving to NumPy, the array must have shape (2, N). This format is recommended because constructing the graph from it is much faster than from a CSV file.
- When saving to a CSV file, do not save the index or header.
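The cell below saves the edges as CSV with one (src, dst) pair per row. For the recommended NumPy layout, a minimal sketch is to transpose the (num_edges, 2) array to shape (2, N) before calling np.save. The file name edges.npy is a hypothetical choice, and the matching metadata.yaml entry would presumably use `format: numpy` instead of `format: csv`; check the YAML specification for your DGL version before relying on that.

```python
import os
import tempfile

import numpy as np

num_nodes = 1000
num_edges = 10 * num_nodes

# Random (src, dst) pairs, one edge per row: shape (num_edges, 2).
edges = np.random.randint(0, num_nodes, size=(num_edges, 2))

# The NumPy layout expects shape (2, N): row 0 holds sources, row 1 destinations.
edges_2xn = edges.T

out_dir = tempfile.mkdtemp()
# Hypothetical file name; the tutorial itself writes edges.csv.
edges_npy_path = os.path.join(out_dir, "edges.npy")
np.save(edges_npy_path, edges_2xn)

loaded = np.load(edges_npy_path)
print(loaded.shape)  # (2, 10000)
```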
import numpy as np
import pandas as pd
num_nodes = 1000
num_edges = 10 * num_nodes
edges_path = os.path.join(base_dir, "edges.csv")
edges = np.random.randint(0, num_nodes, size=(num_edges, 2))
print(f"Part of edges: {edges[:5, :]}")
df = pd.DataFrame(edges)
df.to_csv(edges_path, index=False, header=False)
print(f"Edges are saved into {edges_path}")
Part of edges: [[634 822]
[285 606]
[268 361]
[434 858]
[872 677]]
Edges are saved into ./ondisk_dataset_homograph/edges.csv
Generate feature data for graph
For feature data, NumPy arrays and torch tensors are currently supported.
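One detail worth knowing before saving features: np.random.rand always returns float64, while torch.rand returns float32 by default. That is consistent with the loaded NumPy features showing dtype=torch.float64 in the output further down. The sketch below, using only NumPy and hypothetical file names, shows that casting to float32 before saving halves the storage; whether float32 precision is sufficient is an assumption about your model, not something the tutorial requires.

```python
import os
import tempfile

import numpy as np

num_nodes = 1000
out_dir = tempfile.mkdtemp()

# np.random.rand returns float64 (8 bytes per value).
feat_f64 = np.random.rand(num_nodes, 5)
# Casting to float32 (4 bytes per value) halves the storage.
feat_f32 = feat_f64.astype(np.float32)

# Hypothetical file names, chosen only for this comparison.
path_f64 = os.path.join(out_dir, "node-feat-f64.npy")
path_f32 = os.path.join(out_dir, "node-feat-f32.npy")
np.save(path_f64, feat_f64)
np.save(path_f32, feat_f32)

print(feat_f64.dtype, feat_f32.dtype)    # float64 float32
print(feat_f64.nbytes, feat_f32.nbytes)  # 40000 20000
```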
# Generate node feature in numpy array.
node_feat_0_path = os.path.join(base_dir, "node-feat-0.npy")
node_feat_0 = np.random.rand(num_nodes, 5)
print(f"Part of node feature [feat_0]: {node_feat_0[:3, :]}")
np.save(node_feat_0_path, node_feat_0)
print(f"Node feature [feat_0] is saved to {node_feat_0_path}\n")
# Generate another node feature in torch tensor
node_feat_1_path = os.path.join(base_dir, "node-feat-1.pt")
node_feat_1 = torch.rand(num_nodes, 5)
print(f"Part of node feature [feat_1]: {node_feat_1[:3, :]}")
torch.save(node_feat_1, node_feat_1_path)
print(f"Node feature [feat_1] is saved to {node_feat_1_path}\n")
# Generate edge feature in numpy array.
edge_feat_0_path = os.path.join(base_dir, "edge-feat-0.npy")
edge_feat_0 = np.random.rand(num_edges, 5)
print(f"Part of edge feature [feat_0]: {edge_feat_0[:3, :]}")
np.save(edge_feat_0_path, edge_feat_0)
print(f"Edge feature [feat_0] is saved to {edge_feat_0_path}\n")
# Generate another edge feature in torch tensor
edge_feat_1_path = os.path.join(base_dir, "edge-feat-1.pt")
edge_feat_1 = torch.rand(num_edges, 5)
print(f"Part of edge feature [feat_1]: {edge_feat_1[:3, :]}")
torch.save(edge_feat_1, edge_feat_1_path)
print(f"Edge feature [feat_1] is saved to {edge_feat_1_path}\n")
Part of node feature [feat_0]: [[0.24347811 0.72126432 0.34543555 0.81000785 0.80420882]
[0.92015901 0.74593549 0.16546228 0.39865952 0.453004 ]
[0.37700494 0.85856436 0.81924908 0.74399198 0.6601044 ]]
Node feature [feat_0] is saved to ./ondisk_dataset_homograph/node-feat-0.npy
Part of node feature [feat_1]: tensor([[0.9658, 0.8694, 0.0736, 0.3241, 0.4566],
[0.7273, 0.1294, 0.3846, 0.5297, 0.9789],
[0.7273, 0.0582, 0.0536, 0.0130, 0.8722]])
Node feature [feat_1] is saved to ./ondisk_dataset_homograph/node-feat-1.pt
Part of edge feature [feat_0]: [[0.05609649 0.54517599 0.06213402 0.73013919 0.0360162 ]
[0.55684805 0.39823114 0.94442217 0.4939441 0.66789595]
[0.83361042 0.89735919 0.27657286 0.23773658 0.20972403]]
Edge feature [feat_0] is saved to ./ondisk_dataset_homograph/edge-feat-0.npy
Part of edge feature [feat_1]: tensor([[0.2361, 0.5417, 0.0181, 0.3542, 0.0013],
[0.0961, 0.4163, 0.4765, 0.2000, 0.0975],
[0.1520, 0.2778, 0.0170, 0.3812, 0.9566]])
Edge feature [feat_1] is saved to ./ondisk_dataset_homograph/edge-feat-1.pt
Generate tasks
OnDiskDataset
supports multiple tasks. For each task, we need to prepare separate training/validation/test sets. These sets usually differ across tasks. In this tutorial, let's create a node classification task and a link prediction task.
Node Classification Task
For the node classification task, we need node IDs and the corresponding labels for each training/validation/test set. As with feature data, NumPy arrays and torch tensors are supported for these sets.
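The 60/20/20 split used in the next cell can be sketched on its own as follows: shuffle all node IDs once, then cut the permutation into three slices so the sets are disjoint and together cover every node. This uses np.random.permutation instead of the tutorial's in-place np.random.shuffle; both yield a random ordering of the IDs.

```python
import numpy as np

num_nodes = 1000
num_trains = int(num_nodes * 0.6)
num_vals = int(num_nodes * 0.2)
num_tests = num_nodes - num_trains - num_vals

# Shuffle all node IDs once, then cut the permutation into three slices.
ids = np.random.permutation(num_nodes)
train_ids = ids[:num_trains]
val_ids = ids[num_trains:num_trains + num_vals]
test_ids = ids[num_trains + num_vals:]

# The three sets are disjoint and together cover every node exactly once.
all_ids = np.concatenate([train_ids, val_ids, test_ids])
assert len(np.unique(all_ids)) == num_nodes
print(len(train_ids), len(val_ids), len(test_ids))  # 600 200 200
```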
num_trains = int(num_nodes * 0.6)
num_vals = int(num_nodes * 0.2)
num_tests = num_nodes - num_trains - num_vals
ids = np.arange(num_nodes)
np.random.shuffle(ids)
nc_train_ids_path = os.path.join(base_dir, "nc-train-ids.npy")
nc_train_ids = ids[:num_trains]
print(f"Part of train ids for node classification: {nc_train_ids[:3]}")
np.save(nc_train_ids_path, nc_train_ids)
print(f"NC train ids are saved to {nc_train_ids_path}\n")
nc_train_labels_path = os.path.join(base_dir, "nc-train-labels.pt")
nc_train_labels = torch.randint(0, 10, (num_trains,))
print(f"Part of train labels for node classification: {nc_train_labels[:3]}")
torch.save(nc_train_labels, nc_train_labels_path)
print(f"NC train labels are saved to {nc_train_labels_path}\n")
nc_val_ids_path = os.path.join(base_dir, "nc-val-ids.npy")
nc_val_ids = ids[num_trains:num_trains+num_vals]
print(f"Part of val ids for node classification: {nc_val_ids[:3]}")
np.save(nc_val_ids_path, nc_val_ids)
print(f"NC val ids are saved to {nc_val_ids_path}\n")
nc_val_labels_path = os.path.join(base_dir, "nc-val-labels.pt")
nc_val_labels = torch.randint(0, 10, (num_vals,))
print(f"Part of val labels for node classification: {nc_val_labels[:3]}")
torch.save(nc_val_labels, nc_val_labels_path)
print(f"NC val labels are saved to {nc_val_labels_path}\n")
nc_test_ids_path = os.path.join(base_dir, "nc-test-ids.npy")
nc_test_ids = ids[-num_tests:]
print(f"Part of test ids for node classification: {nc_test_ids[:3]}")
np.save(nc_test_ids_path, nc_test_ids)
print(f"NC test ids are saved to {nc_test_ids_path}\n")
nc_test_labels_path = os.path.join(base_dir, "nc-test-labels.pt")
nc_test_labels = torch.randint(0, 10, (num_tests,))
print(f"Part of test labels for node classification: {nc_test_labels[:3]}")
torch.save(nc_test_labels, nc_test_labels_path)
print(f"NC test labels are saved to {nc_test_labels_path}\n")
Part of train ids for node classification: [208 110 309]
NC train ids are saved to ./ondisk_dataset_homograph/nc-train-ids.npy
Part of train labels for node classification: tensor([6, 7, 0])
NC train labels are saved to ./ondisk_dataset_homograph/nc-train-labels.pt
Part of val ids for node classification: [458 839 656]
NC val ids are saved to ./ondisk_dataset_homograph/nc-val-ids.npy
Part of val labels for node classification: tensor([8, 7, 9])
NC val labels are saved to ./ondisk_dataset_homograph/nc-val-labels.pt
Part of test ids for node classification: [ 19 697 644]
NC test ids are saved to ./ondisk_dataset_homograph/nc-test-ids.npy
Part of test labels for node classification: tensor([9, 3, 9])
NC test labels are saved to ./ondisk_dataset_homograph/nc-test-labels.pt
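Link Prediction Task
For the link prediction task, we need node pairs for each training/validation/test set, optionally accompanied by negative sources/destinations. Here, the validation and test sets also include negative destinations. As with feature data, NumPy arrays and torch tensors are supported for these sets.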
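When cutting the edge array into splits, note that `edges[-k, :]` selects a single (src, dst) row, while `edges[-k:, :]` selects the last k rows. Only the slice gives a test set with one node pair per row of negative destinations. A small NumPy sketch with the same sizes as the tutorial (it uses NumPy for the negatives instead of torch, purely to stay self-contained):

```python
import numpy as np

num_nodes = 1000
num_edges = 10 * num_nodes
edges = np.random.randint(0, num_nodes, size=(num_edges, 2))

num_trains = int(num_edges * 0.6)
num_vals = int(num_edges * 0.2)
num_tests = num_edges - num_trains - num_vals

# Integer index: a single (src, dst) pair, shape (2,).
single_pair = edges[-num_tests, :]
# Slice: the last num_tests pairs, shape (num_tests, 2).
test_pairs = edges[-num_tests:, :]

# Ten random negative destinations per test pair.
test_neg_dsts = np.random.randint(0, num_nodes, size=(num_tests, 10))

print(single_pair.shape, test_pairs.shape, test_neg_dsts.shape)
# (2,) (2000, 2) (2000, 10)
assert test_pairs.shape[0] == test_neg_dsts.shape[0]
```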
num_trains = int(num_edges * 0.6)
num_vals = int(num_edges * 0.2)
num_tests = num_edges - num_trains - num_vals
lp_train_node_pairs_path = os.path.join(base_dir, "lp-train-node-pairs.npy")
lp_train_node_pairs = edges[:num_trains, :]
print(f"Part of train node pairs for link prediction: {lp_train_node_pairs[:3]}")
np.save(lp_train_node_pairs_path, lp_train_node_pairs)
print(f"LP train node pairs are saved to {lp_train_node_pairs_path}\n")
lp_val_node_pairs_path = os.path.join(base_dir, "lp-val-node-pairs.npy")
lp_val_node_pairs = edges[num_trains:num_trains+num_vals, :]
print(f"Part of val node pairs for link prediction: {lp_val_node_pairs[:3]}")
np.save(lp_val_node_pairs_path, lp_val_node_pairs)
print(f"LP val node pairs are saved to {lp_val_node_pairs_path}\n")
lp_val_neg_dsts_path = os.path.join(base_dir, "lp-val-neg-dsts.pt")
lp_val_neg_dsts = torch.randint(0, num_nodes, (num_vals, 10))
print(f"Part of val negative dsts for link prediction: {lp_val_neg_dsts[:3]}")
torch.save(lp_val_neg_dsts, lp_val_neg_dsts_path)
print(f"LP val negative dsts are saved to {lp_val_neg_dsts_path}\n")
lp_test_node_pairs_path = os.path.join(base_dir, "lp-test-node-pairs.npy")
lp_test_node_pairs = edges[-num_tests:, :]
print(f"Part of test node pairs for link prediction: {lp_test_node_pairs[:3]}")
np.save(lp_test_node_pairs_path, lp_test_node_pairs)
print(f"LP test node pairs are saved to {lp_test_node_pairs_path}\n")
lp_test_neg_dsts_path = os.path.join(base_dir, "lp-test-neg-dsts.pt")
lp_test_neg_dsts = torch.randint(0, num_nodes, (num_tests, 10))
print(f"Part of test negative dsts for link prediction: {lp_test_neg_dsts[:3]}")
torch.save(lp_test_neg_dsts, lp_test_neg_dsts_path)
print(f"LP test negative dsts are saved to {lp_test_neg_dsts_path}\n")
Part of train node pairs for link prediction: [[634 822]
[285 606]
[268 361]]
LP train node pairs are saved to ./ondisk_dataset_homograph/lp-train-node-pairs.npy
Part of val node pairs for link prediction: [[638 687]
[ 28 478]
[804 472]]
LP val node pairs are saved to ./ondisk_dataset_homograph/lp-val-node-pairs.npy
Part of val negative dsts for link prediction: tensor([[ 97, 503, 353, 422, 890, 991, 474, 118, 125, 754],
[794, 32, 332, 869, 145, 90, 689, 846, 974, 919],
[793, 447, 801, 832, 195, 312, 477, 477, 843, 742]])
LP val negative dsts are saved to ./ondisk_dataset_homograph/lp-val-neg-dsts.pt
Part of test node pairs for link prediction: [906 728]
LP test node pairs are saved to ./ondisk_dataset_homograph/lp-test-node-pairs.npy
Part of test negative dsts for link prediction: tensor([[363, 917, 374, 593, 696, 203, 252, 505, 65, 83],
[637, 138, 71, 951, 905, 145, 962, 994, 714, 417],
[222, 599, 817, 854, 500, 325, 406, 298, 208, 664]])
LP test negative dsts are saved to ./ondisk_dataset_homograph/lp-test-neg-dsts.pt
Organize Data into YAML File
Now we need to create a metadata.yaml
file that contains the paths and data types of the graph structure, feature data, and training/validation/test sets.
Notes:
- All paths should be relative to metadata.yaml.
- The following fields are optional and are not specified in the example below:
  - in_memory: indicates whether to load the data into memory or mmap it. Default is True.
Please refer to the YAML specification for more details.
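Because every path must be relative to metadata.yaml, a quick sanity check before loading can save a confusing error later. The helper `bad_paths` below is hypothetical and deliberately simple: it scans for `path: <value>` lines rather than parsing YAML, and reports values that are absolute or do not exist next to the metadata file. The demo entry with `in_memory: false` only illustrates where that optional field would sit.

```python
import os
import re
import tempfile


def bad_paths(metadata_path):
    """Return `path:` values in a metadata file that are absolute or missing.

    A line-based check, not a YAML parser: it only looks at lines of the form
    `path: <value>` and resolves each value relative to the metadata file.
    """
    base = os.path.dirname(os.path.abspath(metadata_path))
    bad = []
    with open(metadata_path) as f:
        for line in f:
            match = re.match(r"^\s*-?\s*path:\s*(\S+)\s*$", line)
            if not match:
                continue
            value = match.group(1)
            if os.path.isabs(value) or not os.path.exists(os.path.join(base, value)):
                bad.append(value)
    return bad


# Usage with a tiny hypothetical dataset directory: edges.csv exists,
# node-feat-0.npy does not.
demo_dir = tempfile.mkdtemp()
open(os.path.join(demo_dir, "edges.csv"), "w").close()
with open(os.path.join(demo_dir, "metadata.yaml"), "w") as f:
    f.write(
        "graph:\n"
        "  edges:\n"
        "    - format: csv\n"
        "      path: edges.csv\n"
        "feature_data:\n"
        "  - domain: node\n"
        "    name: feat_0\n"
        "    format: numpy\n"
        "    in_memory: false\n"
        "    path: node-feat-0.npy\n"
    )
print(bad_paths(os.path.join(demo_dir, "metadata.yaml")))  # ['node-feat-0.npy']
```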
yaml_content = f"""
dataset_name: homogeneous_graph_nc_lp
graph:
  nodes:
    - num: {num_nodes}
  edges:
    - format: csv
      path: {os.path.basename(edges_path)}
feature_data:
  - domain: node
    name: feat_0
    format: numpy
    path: {os.path.basename(node_feat_0_path)}
  - domain: node
    name: feat_1
    format: torch
    path: {os.path.basename(node_feat_1_path)}
  - domain: edge
    name: feat_0
    format: numpy
    path: {os.path.basename(edge_feat_0_path)}
  - domain: edge
    name: feat_1
    format: torch
    path: {os.path.basename(edge_feat_1_path)}
tasks:
  - name: node_classification
    num_classes: 10
    train_set:
      - data:
          - name: seed_nodes
            format: numpy
            path: {os.path.basename(nc_train_ids_path)}
          - name: labels
            format: torch
            path: {os.path.basename(nc_train_labels_path)}
    validation_set:
      - data:
          - name: seed_nodes
            format: numpy
            path: {os.path.basename(nc_val_ids_path)}
          - name: labels
            format: torch
            path: {os.path.basename(nc_val_labels_path)}
    test_set:
      - data:
          - name: seed_nodes
            format: numpy
            path: {os.path.basename(nc_test_ids_path)}
          - name: labels
            format: torch
            path: {os.path.basename(nc_test_labels_path)}
  - name: link_prediction
    num_classes: 10
    train_set:
      - data:
          - name: node_pairs
            format: numpy
            path: {os.path.basename(lp_train_node_pairs_path)}
    validation_set:
      - data:
          - name: node_pairs
            format: numpy
            path: {os.path.basename(lp_val_node_pairs_path)}
          - name: negative_dsts
            format: torch
            path: {os.path.basename(lp_val_neg_dsts_path)}
    test_set:
      - data:
          - name: node_pairs
            format: numpy
            path: {os.path.basename(lp_test_node_pairs_path)}
          - name: negative_dsts
            format: torch
            path: {os.path.basename(lp_test_neg_dsts_path)}
"""
metadata_path = os.path.join(base_dir, "metadata.yaml")
with open(metadata_path, "w") as f:
    f.write(yaml_content)
Instantiate OnDiskDataset
Now we're ready to load the dataset via dgl.graphbolt.OnDiskDataset. When instantiating it, we just pass in the base directory where the metadata.yaml file lies.
During the first instantiation, GraphBolt preprocesses the raw data, for example by constructing a FusedCSCSamplingGraph from the edges. After preprocessing, all data, including the graph, feature data, and training/validation/test sets, is put into the preprocessed directory. Subsequent loads of the dataset skip the preprocessing stage.
After preprocessing, load() must be called explicitly to load the graph, feature data, and tasks.
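To make the preprocessing step more concrete, the sketch below shows how a CSC layout (the csc_indptr and indices arrays visible in the loaded graph's printout) can be built from an edge list by grouping edges by destination. It is an illustration of the data structure only, not GraphBolt's actual preprocessing code, and `edges_to_csc` is a hypothetical helper. It also returns the permutation from CSC position to original edge ID; because edges are reordered, edge features saved in the original order need such a mapping, which is presumably what the "edge IDs are not saved" warning at the end of the output refers to.

```python
import numpy as np


def edges_to_csc(src, dst, num_nodes):
    """Build CSC arrays (indptr, indices) from an edge list, grouped by destination.

    Also returns the permutation mapping each CSC position to its original edge ID.
    """
    # Stable sort keeps edges with the same destination in their original order.
    order = np.argsort(dst, kind="stable")
    # indptr[v]:indptr[v + 1] spans the incoming edges of node v.
    counts = np.bincount(dst, minlength=num_nodes)
    indptr = np.concatenate([[0], np.cumsum(counts)])
    # indices holds the source node of each edge, in CSC order.
    indices = src[order]
    return indptr, indices, order


# Usage on a tiny hypothetical graph with 3 nodes and 4 edges.
src = np.array([0, 2, 1, 0])
dst = np.array([1, 0, 1, 2])
indptr, indices, orig_eids = edges_to_csc(src, dst, num_nodes=3)
print(indptr)     # [0 1 3 4]
print(indices)    # [2 0 1 0]
print(orig_eids)  # [1 0 2 3]
```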
dataset = gb.OnDiskDataset(base_dir).load()
graph = dataset.graph
print(f"Loaded graph: {graph}\n")
feature = dataset.feature
print(f"Loaded feature store: {feature}\n")
tasks = dataset.tasks
nc_task = tasks[0]
print(f"Loaded node classification task: {nc_task}\n")
lp_task = tasks[1]
print(f"Loaded link prediction task: {lp_task}\n")
Start to preprocess the on-disk dataset.
Finish preprocessing the on-disk dataset.
Loaded graph: FusedCSCSamplingGraph(csc_indptr=tensor([ 0, 6, 11, ..., 9981, 9991, 10000], dtype=torch.int32),
indices=tensor([ 82, 99, 267, ..., 995, 469, 10], dtype=torch.int32),
total_num_nodes=1000, num_edges=10000,)
Loaded feature store: TorchBasedFeatureStore(
{(<OnDiskFeatureDataDomain.NODE: 'node'>, None, 'feat_0'): TorchBasedFeature(
feature=tensor([[0.2435, 0.7213, 0.3454, 0.8100, 0.8042],
[0.9202, 0.7459, 0.1655, 0.3987, 0.4530],
[0.3770, 0.8586, 0.8192, 0.7440, 0.6601],
...,
[0.3633, 0.3403, 0.7562, 0.0764, 0.5614],
[0.0192, 0.2037, 0.3350, 0.4269, 0.1275],
[0.4615, 0.0228, 0.4669, 0.6982, 0.7113]], dtype=torch.float64),
metadata={},
), (<OnDiskFeatureDataDomain.NODE: 'node'>, None, 'feat_1'): TorchBasedFeature(
feature=tensor([[0.9658, 0.8694, 0.0736, 0.3241, 0.4566],
[0.7273, 0.1294, 0.3846, 0.5297, 0.9789],
[0.7273, 0.0582, 0.0536, 0.0130, 0.8722],
...,
[0.5342, 0.0865, 0.7177, 0.5614, 0.5862],
[0.3911, 0.7904, 0.2788, 0.2150, 0.8728],
[0.8593, 0.7647, 0.9011, 0.4649, 0.4697]]),
metadata={},
), (<OnDiskFeatureDataDomain.EDGE: 'edge'>, None, 'feat_0'): TorchBasedFeature(
feature=tensor([[0.0561, 0.5452, 0.0621, 0.7301, 0.0360],
[0.5568, 0.3982, 0.9444, 0.4939, 0.6679],
[0.8336, 0.8974, 0.2766, 0.2377, 0.2097],
...,
[0.4551, 0.8815, 0.8289, 0.0957, 0.6033],
[0.5589, 0.7092, 0.3844, 0.8435, 0.7082],
[0.8129, 0.2193, 0.2167, 0.2461, 0.6876]], dtype=torch.float64),
metadata={},
), (<OnDiskFeatureDataDomain.EDGE: 'edge'>, None, 'feat_1'): TorchBasedFeature(
feature=tensor([[0.2361, 0.5417, 0.0181, 0.3542, 0.0013],
[0.0961, 0.4163, 0.4765, 0.2000, 0.0975],
[0.1520, 0.2778, 0.0170, 0.3812, 0.9566],
...,
[0.8570, 0.1460, 0.8484, 0.9372, 0.1869],
[0.4008, 0.1138, 0.7057, 0.3723, 0.0517],
[0.2654, 0.8791, 0.9277, 0.4150, 0.4404]]),
metadata={},
)}
)
Loaded node classification task: OnDiskTask(validation_set=ItemSet(
items=(tensor([458, 839, 656, 245, 833, 642, 763, 81, 718, 400, 246, 553, 141, 939,
595, 163, 979, 880, 834, 95, 346, 360, 10, 50, 467, 807, 372, 332,
964, 237, 36, 527, 757, 587, 652, 410, 895, 207, 419, 166, 630, 608,
681, 704, 919, 550, 25, 546, 722, 985, 56, 8, 244, 318, 96, 23,
625, 501, 488, 522, 277, 348, 866, 621, 847, 503, 256, 424, 682, 147,
238, 801, 710, 574, 391, 126, 113, 338, 411, 297, 84, 958, 884, 538,
302, 983, 234, 345, 232, 761, 959, 563, 696, 75, 543, 85, 941, 333,
795, 728, 299, 342, 112, 310, 752, 672, 771, 136, 403, 914, 426, 377,
739, 167, 154, 460, 926, 450, 585, 190, 435, 791, 971, 774, 434, 796,
892, 330, 196, 694, 62, 499, 836, 683, 437, 635, 71, 966, 114, 518,
570, 551, 362, 931, 515, 953, 968, 82, 135, 298, 213, 639, 772, 799,
111, 544, 711, 900, 469, 569, 654, 594, 860, 650, 261, 195, 143, 814,
252, 335, 466, 365, 825, 325, 255, 920, 984, 690, 264, 826, 977, 805,
35, 493, 280, 832, 565, 92, 956, 646, 294, 513, 26, 287, 661, 367,
541, 212, 452, 651], dtype=torch.int32), tensor([8, 7, 9, 6, 9, 1, 0, 1, 4, 0, 2, 6, 0, 7, 5, 2, 7, 1, 5, 3, 9, 5, 9, 7,
7, 1, 4, 7, 7, 9, 7, 3, 0, 4, 6, 6, 0, 3, 5, 5, 4, 5, 3, 3, 9, 4, 7, 8,
5, 8, 4, 5, 1, 7, 4, 5, 7, 2, 7, 5, 4, 6, 8, 8, 4, 5, 3, 3, 2, 1, 3, 2,
2, 8, 6, 7, 5, 4, 9, 4, 3, 7, 2, 2, 9, 7, 0, 4, 0, 0, 9, 4, 1, 8, 7, 8,
6, 1, 4, 2, 5, 7, 3, 0, 9, 8, 2, 6, 6, 5, 8, 3, 6, 3, 8, 8, 5, 6, 9, 6,
9, 1, 0, 1, 8, 9, 6, 5, 8, 5, 9, 7, 0, 1, 3, 5, 5, 3, 2, 0, 8, 2, 2, 0,
2, 8, 2, 0, 2, 0, 5, 1, 8, 4, 1, 4, 5, 6, 2, 4, 4, 3, 2, 9, 4, 4, 3, 9,
2, 9, 0, 6, 0, 3, 8, 1, 2, 6, 7, 9, 9, 8, 1, 2, 9, 0, 6, 6, 1, 6, 7, 5,
7, 9, 1, 6, 6, 2, 6, 6])),
names=('seed_nodes', 'labels'),
),
train_set=ItemSet(
items=(tensor([208, 110, 309, 481, 186, 829, 1, 699, 45, 3, 780, 253, 487, 371,
837, 210, 449, 68, 534, 34, 508, 995, 224, 122, 381, 781, 413, 940,
962, 670, 331, 189, 868, 765, 98, 275, 820, 726, 124, 97, 703, 258,
730, 818, 932, 352, 9, 951, 938, 827, 109, 459, 750, 133, 917, 281,
897, 662, 293, 479, 5, 758, 491, 375, 741, 263, 30, 908, 898, 198,
643, 477, 598, 57, 576, 70, 422, 306, 592, 13, 273, 16, 990, 120,
52, 736, 440, 40, 779, 446, 286, 249, 134, 571, 118, 545, 382, 396,
47, 731, 463, 103, 312, 901, 102, 295, 993, 506, 555, 415, 311, 421,
708, 810, 815, 762, 583, 744, 392, 379, 242, 653, 720, 300, 168, 992,
988, 137, 573, 627, 548, 142, 20, 148, 453, 924, 890, 816, 438, 28,
848, 577, 658, 945, 53, 745, 954, 813, 999, 184, 498, 525, 99, 727,
229, 637, 390, 916, 874, 432, 32, 937, 659, 457, 359, 580, 448, 296,
49, 279, 316, 2, 678, 90, 547, 150, 472, 285, 313, 108, 960, 910,
423, 439, 65, 626, 552, 61, 259, 262, 405, 676, 698, 188, 191, 179,
844, 117, 271, 620, 222, 41, 55, 409, 610, 823, 624, 873, 139, 178,
858, 185, 889, 863, 716, 740, 402, 852, 947, 824, 358, 738, 612, 243,
855, 923, 307, 266, 93, 641, 364, 536, 428, 715, 667, 719, 517, 465,
528, 63, 713, 248, 989, 153, 399, 267, 385, 946, 430, 773, 930, 756,
942, 66, 151, 751, 104, 883, 420, 680, 759, 206, 74, 725, 334, 447,
215, 80, 174, 619, 851, 955, 328, 341, 94, 835, 980, 305, 909, 567,
378, 905, 734, 929, 461, 841, 217, 530, 746, 130, 611, 950, 516, 339,
315, 613, 128, 949, 429, 436, 975, 915, 521, 324, 872, 578, 597, 961,
383, 633, 353, 629, 519, 830, 257, 239, 254, 660, 709, 806, 278, 899,
268, 556, 976, 233, 319, 106, 235, 686, 301, 197, 288, 733, 468, 398,
373, 902, 692, 64, 175, 443, 384, 925, 482, 146, 272, 105, 145, 838,
777, 314, 878, 869, 42, 981, 564, 842, 557, 350, 72, 205, 615, 494,
176, 386, 182, 140, 504, 223, 986, 797, 636, 673, 123, 374, 705, 149,
723, 39, 679, 265, 606, 322, 700, 291, 427, 691, 24, 138, 349, 107,
957, 865, 886, 769, 768, 181, 127, 778, 226, 512, 380, 354, 43, 879,
158, 793, 336, 475, 529, 329, 416, 933, 560, 356, 881, 395, 260, 480,
231, 943, 969, 363, 590, 870, 572, 404, 445, 671, 604, 628, 645, 775,
31, 982, 476, 987, 347, 495, 802, 89, 502, 888, 21, 303, 782, 721,
649, 743, 748, 510, 152, 500, 177, 269, 618, 684, 906, 340, 230, 809,
183, 783, 973, 970, 317, 792, 526, 216, 393, 170, 6, 38, 228, 401,
194, 59, 617, 119, 351, 79, 812, 132, 471, 589, 218, 688, 193, 199,
29, 283, 894, 251, 492, 417, 790, 702, 675, 634, 591, 584, 849, 48,
893, 677, 214, 996, 406, 533, 221, 803, 160, 282, 418, 304, 647, 821,
729, 14, 717, 804, 871, 674, 507, 289, 566, 657, 211, 770, 828, 609,
37, 747, 464, 764, 44, 724, 54, 561, 853, 817, 562, 433, 131, 542,
509, 412, 655, 887, 483, 867, 840, 87, 575, 695, 664, 388, 913, 742,
753, 921, 737, 963, 607, 490, 387, 219, 978, 172, 169, 707, 125, 456,
811, 640, 164, 549, 666, 857, 292, 78, 15, 568, 484, 240, 918, 203,
308, 187, 614, 936, 693, 776, 431, 605, 454, 712, 856, 408],
dtype=torch.int32), tensor([6, 7, 0, 5, 4, 0, 7, 5, 6, 6, 6, 2, 7, 1, 2, 2, 4, 8, 4, 4, 5, 0, 2, 4,
9, 0, 2, 8, 5, 3, 1, 6, 9, 4, 9, 7, 2, 0, 2, 6, 5, 2, 7, 4, 7, 5, 8, 5,
5, 3, 3, 3, 6, 9, 9, 1, 3, 4, 6, 3, 8, 2, 7, 4, 8, 9, 8, 9, 3, 9, 9, 8,
9, 3, 0, 5, 6, 1, 4, 9, 5, 1, 3, 9, 3, 5, 8, 3, 4, 0, 0, 9, 8, 0, 4, 9,
7, 0, 2, 3, 1, 3, 6, 3, 3, 4, 0, 5, 2, 7, 3, 0, 0, 5, 2, 9, 7, 4, 6, 1,
1, 6, 7, 8, 9, 1, 0, 4, 5, 6, 5, 2, 6, 5, 7, 8, 2, 6, 6, 5, 0, 4, 5, 4,
1, 0, 2, 8, 9, 6, 1, 3, 5, 6, 4, 7, 5, 2, 9, 4, 7, 2, 8, 3, 0, 0, 1, 3,
4, 4, 1, 7, 4, 5, 8, 1, 9, 4, 6, 9, 2, 9, 3, 9, 1, 7, 1, 5, 8, 5, 6, 4,
1, 8, 2, 1, 5, 6, 7, 9, 5, 8, 9, 1, 0, 6, 6, 9, 6, 1, 9, 3, 6, 4, 8, 5,
7, 2, 1, 0, 8, 0, 1, 5, 8, 6, 8, 8, 9, 9, 5, 4, 8, 7, 7, 6, 0, 5, 1, 0,
0, 6, 4, 0, 5, 0, 5, 9, 9, 9, 1, 3, 2, 5, 1, 1, 9, 5, 2, 0, 0, 1, 6, 7,
1, 9, 5, 9, 0, 0, 8, 5, 5, 3, 1, 3, 3, 0, 8, 6, 8, 3, 4, 4, 7, 4, 0, 9,
9, 8, 2, 8, 1, 7, 6, 4, 7, 2, 8, 9, 6, 2, 4, 1, 8, 2, 0, 7, 5, 8, 7, 2,
7, 8, 0, 7, 4, 6, 5, 4, 9, 0, 3, 9, 4, 7, 9, 5, 7, 4, 7, 5, 1, 5, 3, 1,
8, 4, 4, 2, 1, 6, 1, 1, 5, 5, 3, 2, 2, 9, 6, 2, 9, 5, 6, 3, 9, 7, 3, 1,
2, 1, 8, 8, 8, 7, 5, 7, 6, 4, 7, 8, 1, 2, 8, 0, 8, 0, 8, 7, 3, 3, 9, 6,
0, 5, 0, 9, 5, 3, 5, 6, 0, 4, 5, 8, 0, 4, 8, 0, 9, 8, 3, 7, 8, 9, 6, 2,
4, 7, 4, 4, 2, 8, 2, 9, 9, 9, 2, 8, 0, 0, 5, 1, 9, 8, 4, 9, 8, 3, 1, 2,
7, 4, 3, 8, 6, 2, 2, 0, 7, 3, 5, 6, 6, 4, 4, 2, 3, 0, 0, 5, 6, 4, 7, 6,
1, 7, 4, 8, 8, 0, 2, 3, 9, 7, 4, 1, 9, 0, 4, 9, 1, 1, 3, 5, 0, 2, 9, 5,
4, 8, 0, 4, 4, 6, 5, 2, 0, 7, 2, 3, 5, 5, 1, 8, 5, 3, 3, 5, 7, 8, 4, 8,
6, 1, 1, 9, 7, 8, 8, 5, 8, 7, 1, 7, 7, 5, 8, 6, 1, 9, 1, 5, 1, 0, 6, 5,
2, 4, 9, 1, 8, 9, 1, 1, 8, 7, 3, 2, 2, 5, 2, 8, 1, 4, 7, 3, 0, 0, 7, 9,
1, 3, 7, 2, 0, 9, 6, 2, 3, 8, 7, 9, 7, 8, 1, 6, 2, 6, 4, 0, 4, 5, 3, 4,
6, 7, 1, 2, 9, 3, 6, 8, 4, 9, 0, 6, 1, 4, 1, 7, 7, 1, 4, 8, 4, 0, 6, 1])),
names=('seed_nodes', 'labels'),
),
test_set=ItemSet(
items=(tensor([ 19, 697, 644, 845, 83, 616, 4, 558, 927, 200, 201, 600, 794, 531,
767, 632, 785, 451, 875, 896, 357, 540, 156, 250, 394, 33, 532, 754,
320, 559, 854, 86, 523, 593, 91, 0, 760, 944, 88, 602, 554, 885,
603, 462, 974, 474, 171, 407, 129, 17, 972, 470, 209, 343, 369, 11,
524, 69, 441, 877, 800, 327, 904, 121, 247, 227, 73, 241, 321, 732,
689, 115, 370, 67, 225, 922, 911, 505, 967, 808, 843, 912, 862, 337,
12, 952, 101, 444, 663, 891, 161, 648, 535, 882, 935, 397, 18, 784,
496, 601, 846, 831, 819, 582, 539, 787, 789, 284, 685, 58, 859, 486,
162, 497, 786, 165, 116, 588, 7, 51, 478, 714, 220, 323, 344, 687,
159, 361, 326, 485, 599, 389, 520, 537, 60, 274, 46, 270, 155, 798,
579, 735, 192, 511, 76, 100, 701, 928, 236, 77, 788, 425, 822, 669,
998, 489, 934, 864, 204, 749, 276, 442, 622, 366, 180, 173, 631, 596,
355, 157, 376, 290, 965, 903, 455, 586, 997, 27, 665, 623, 991, 473,
861, 907, 876, 755, 706, 202, 144, 581, 22, 948, 850, 766, 514, 668,
414, 368, 994, 638], dtype=torch.int32), tensor([9, 3, 9, 7, 6, 8, 7, 9, 5, 0, 6, 2, 2, 3, 8, 5, 0, 3, 1, 3, 1, 4, 1, 8,
3, 6, 5, 8, 1, 5, 5, 4, 5, 6, 6, 3, 7, 6, 3, 9, 9, 7, 2, 3, 3, 6, 1, 4,
1, 6, 4, 7, 4, 5, 0, 2, 7, 6, 9, 0, 1, 9, 7, 9, 1, 2, 1, 2, 6, 6, 6, 8,
2, 5, 6, 2, 8, 3, 1, 0, 1, 3, 9, 2, 7, 3, 9, 3, 8, 0, 7, 3, 3, 4, 1, 6,
0, 6, 8, 7, 9, 1, 6, 2, 4, 9, 6, 2, 9, 7, 1, 4, 2, 0, 6, 4, 7, 0, 4, 0,
7, 6, 9, 3, 7, 5, 5, 5, 0, 8, 7, 9, 2, 8, 2, 3, 7, 9, 4, 6, 9, 0, 8, 3,
6, 4, 5, 4, 6, 7, 9, 2, 7, 2, 2, 9, 9, 8, 3, 0, 1, 4, 6, 9, 9, 3, 5, 2,
0, 4, 6, 9, 1, 4, 3, 7, 8, 1, 8, 4, 0, 4, 5, 5, 4, 8, 7, 2, 5, 8, 2, 1,
5, 2, 6, 0, 3, 3, 4, 7])),
names=('seed_nodes', 'labels'),
),
metadata={'name': 'node_classification', 'num_classes': 10},)
Loaded link prediction task: OnDiskTask(validation_set=ItemSet(
items=(tensor([[638, 687],
[ 28, 478],
[804, 472],
...,
[411, 33],
[474, 864],
[789, 986]], dtype=torch.int32), tensor([[ 97, 503, 353, ..., 118, 125, 754],
[794, 32, 332, ..., 846, 974, 919],
[793, 447, 801, ..., 477, 843, 742],
...,
[291, 366, 981, ..., 693, 288, 259],
[812, 17, 13, ..., 163, 613, 428],
[ 36, 550, 427, ..., 450, 544, 945]], dtype=torch.int32)),
names=('node_pairs', 'negative_dsts'),
),
train_set=ItemSet(
items=(tensor([[634, 822],
[285, 606],
[268, 361],
...,
[553, 376],
[184, 33],
[650, 589]], dtype=torch.int32),),
names=('node_pairs',),
),
test_set=ItemSet(
items=(tensor([906, 728], dtype=torch.int32), tensor([[363, 917, 374, ..., 505, 65, 83],
[637, 138, 71, ..., 994, 714, 417],
[222, 599, 817, ..., 298, 208, 664],
...,
[756, 275, 846, ..., 959, 50, 12],
[654, 458, 223, ..., 985, 132, 151],
[912, 579, 766, ..., 879, 719, 998]], dtype=torch.int32)),
names=('node_pairs', 'negative_dsts'),
),
metadata={'name': 'link_prediction', 'num_classes': 10},)
/home/ubuntu/prod-doc/readthedocs.org/user_builds/dgl/envs/2.1.x/lib/python3.8/site-packages/dgl-2.1.0-py3.8-linux-x86_64.egg/dgl/graphbolt/impl/ondisk_dataset.py:464: DGLWarning: Edge feature is stored, but edge IDs are not saved.
dgl_warning("Edge feature is stored, but edge IDs are not saved.")