RomanEmpireDataset
- class dgl.data.RomanEmpireDataset(raw_dir=None, force_reload=False, verbose=True, transform=None)
Bases: HeterophilousGraphDataset
Roman-empire dataset from the paper "A Critical Look at the Evaluation of GNNs under Heterophily: Are We Really Making Progress?" (https://arxiv.org/abs/2302.11640).
This dataset is based on the Roman Empire article from English Wikipedia, which was selected because it is one of the longest articles on Wikipedia. Each node in the graph corresponds to one (non-unique) word in the text, so the number of nodes equals the article's length in words. Two words are connected by an edge if at least one of the following conditions holds: the words follow each other in the text, or the words are connected in the dependency tree of their sentence (one word depends on the other). The graph is therefore a chain graph with additional shortcut edges corresponding to syntactic dependencies between words. The class of a node is its syntactic role: the 17 most frequent roles were selected as unique classes, and all other roles were grouped into an 18th class. Node features are word embeddings.
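The edge rule and the class grouping described above can be sketched in plain Python. This is an illustration only, not DGL's or the paper's preprocessing code: `build_word_graph` and `group_roles` are hypothetical names, and the dependency heads and role labels used below are made-up toy inputs, not data from the dataset.

```python
from collections import Counter


def build_word_graph(num_words, heads):
    """Return a sorted list of undirected edges (u, v) with u < v.

    heads[i] is the index of the word that word i depends on in its
    sentence's dependency tree, or None if word i is a sentence root.
    """
    edges = set()
    # Chain edges: words that follow each other in the text are connected.
    for i in range(num_words - 1):
        edges.add((i, i + 1))
    # Shortcut edges: each word is connected to its syntactic head.
    # Pairs that are already adjacent in the text collapse into one edge.
    for i, h in enumerate(heads):
        if h is not None and h != i:
            edges.add((min(i, h), max(i, h)))
    return sorted(edges)


def group_roles(roles, num_kept=17):
    """Map the num_kept most frequent roles to classes 0..num_kept-1 and
    every remaining role to the extra class num_kept.

    Ordering the kept classes by frequency is an assumption of this sketch.
    """
    kept = [role for role, _ in Counter(roles).most_common(num_kept)]
    index = {role: k for k, role in enumerate(kept)}
    return [index.get(role, num_kept) for role in roles]


words = ["Rome", "was", "not", "built", "in", "a", "day"]
# Toy dependency heads: index of each word's head, None marks the root.
heads = [3, 3, 3, None, 6, 6, 3]
edges = build_word_graph(len(words), heads)
print(len(words), len(edges))  # 7 10
```

In the toy sentence, 6 chain edges plus 4 dependency edges that are not already chain edges give 10 undirected edges; `group_roles` with the default `num_kept=17` yields the 18 classes described above.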
Statistics:
Nodes: 22662
Edges: 65854
Classes: 18
Node features: 300
10 train/val/test splits
- Parameters:
raw_dir (str, optional) – Raw file directory to store the processed data. Default: ~/.dgl/
force_reload (bool, optional) – Whether to re-download the data source. Default: False
verbose (bool, optional) – Whether to print progress information. Default: True
transform (callable, optional) – A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object will be transformed before every access. Default: None
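The "transformed before every access" behaviour of `transform` can be mimicked with a small stand-in class. `ToyGraphDataset` is hypothetical and is not DGL's implementation; it assumes a single stored graph (as the `dataset[0]` usage in the examples suggests) and uses a plain dict in place of a DGLGraph so that the sketch runs without DGL. It shows that the callable is re-applied on each `__getitem__` call rather than cached.

```python
class ToyGraphDataset:
    """Hypothetical single-graph dataset that applies `transform` to the
    stored graph on every __getitem__ call; the result is never cached."""

    def __init__(self, graph, transform=None):
        self._graph = graph
        self._transform = transform

    def __getitem__(self, idx):
        if idx != 0:
            raise IndexError("this toy dataset holds a single graph")
        if self._transform is None:
            return self._graph
        # Re-applied on every access, so each call returns a fresh result.
        return self._transform(self._graph)

    def __len__(self):
        return 1


def add_flag(g):
    # Returns a new dict; the stored graph itself is not modified.
    return {**g, "transformed": True}


ds = ToyGraphDataset({"num_nodes": 3}, transform=add_flag)
print(ds[0])  # {'num_nodes': 3, 'transformed': True}
```

Applying the transform lazily keeps the stored graph unchanged, at the cost of recomputing the transform each time the graph is accessed.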
Examples
>>> from dgl.data import RomanEmpireDataset
>>> dataset = RomanEmpireDataset()
>>> g = dataset[0]
>>> num_classes = dataset.num_classes

>>> # get node features
>>> feat = g.ndata["feat"]

>>> # get the first data split
>>> train_mask = g.ndata["train_mask"][:, 0]
>>> val_mask = g.ndata["val_mask"][:, 0]
>>> test_mask = g.ndata["test_mask"][:, 0]

>>> # get labels
>>> label = g.ndata["label"]
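Because each mask column corresponds to one of the 10 splits, a common protocol is to train one model per split and report the mean and standard deviation of the test metric across splits. The sketch below shows only that aggregation step with NumPy. `split_test_accuracy` is a hypothetical helper, accuracy is just one possible metric, and `preds` stands in for predictions from models you would train separately on each split; the arrays are toy data, not the dataset.

```python
import numpy as np


def split_test_accuracy(preds, labels, test_mask):
    """Per-split test accuracy.

    preds: int array of shape (num_splits, num_nodes); row i holds the
        predictions of the model trained on split i (hypothetical input).
    labels: int array of shape (num_nodes,), playing the role of g.ndata["label"].
    test_mask: bool array of shape (num_nodes, num_splits); column i plays
        the role of g.ndata["test_mask"][:, i].
    """
    num_splits = test_mask.shape[1]
    accs = np.empty(num_splits)
    for i in range(num_splits):
        m = test_mask[:, i]  # test nodes of split i
        accs[i] = np.mean(preds[i, m] == labels[m])
    return accs


# Toy example: 6 nodes, 2 splits.
labels = np.array([0, 1, 2, 0, 1, 2])
preds = np.array([[0, 1, 2, 0, 1, 1],
                  [0, 1, 2, 0, 0, 0]])
test_mask = np.zeros((6, 2), dtype=bool)
test_mask[3:, 0] = True  # split 0 tests on nodes 3-5
test_mask[:3, 1] = True  # split 1 tests on nodes 0-2
accs = split_test_accuracy(preds, labels, test_mask)
print(accs.mean(), accs.std())
```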
- __getitem__(idx)
Gets the data object at index.
- __len__()
The number of examples in the dataset.