RomanEmpireDataset

class dgl.data.RomanEmpireDataset(raw_dir=None, force_reload=False, verbose=True, transform=None)

Bases: HeterophilousGraphDataset

Roman-empire dataset from the paper ‘A Critical Look at the Evaluation of GNNs under Heterophily: Are We Really Making Progress?’ (https://arxiv.org/abs/2302.11640).

This dataset is based on the Roman Empire article from English Wikipedia, which was chosen because it is one of the longest articles on Wikipedia. Each node in the graph corresponds to one word occurrence in the text (repeated words are not merged), so the number of nodes equals the length of the article in words. Two words are connected by an edge if at least one of the following conditions holds: the words are adjacent in the text, or they are connected in the dependency tree of their sentence (one word depends on the other). The graph is therefore a chain with additional shortcut edges corresponding to syntactic dependencies between words. The class of a node is its syntactic role: the 17 most frequent roles are kept as separate classes, and all other roles are grouped into an 18th class. Node features are 300-dimensional word embeddings.
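The edge rule above can be illustrated on a single made-up sentence. The sketch below is a plain-Python illustration under stated assumptions, not the dataset authors' preprocessing code: `build_word_graph` is a hypothetical helper, the dependency heads are an invented parse, and edges are returned as undirected pairs. It shows how chain edges and dependency edges combine, and how a word pair that satisfies both conditions yields a single edge.

```python
# Hypothetical toy sketch of the Roman-empire edge rule: words become nodes,
# consecutive words are linked (chain edges), and each word is linked to its
# syntactic head (dependency edges). The tokens and head indices below are
# invented for illustration; the real dataset uses a parsed Wikipedia article.


def build_word_graph(heads):
    """Return undirected edges as sorted (u, v) pairs with u < v.

    heads[i] is the index of the syntactic head of word i, or -1 for the
    root of its sentence (an assumed convention, not the dataset's format).
    """
    edges = set()
    n = len(heads)
    # Chain edges: word i follows word i - 1 in the text.
    for i in range(1, n):
        edges.add((i - 1, i))
    # Dependency edges: word i depends on word heads[i].
    for i, h in enumerate(heads):
        if h >= 0 and h != i:
            edges.add((min(i, h), max(i, h)))
    # Using a set means a pair that satisfies both conditions is stored once.
    return sorted(edges)


# "Rome was founded long ago" with a hypothetical dependency parse:
# "founded" (index 2) is the root; "Rome", "was" and "ago" attach to it,
# and "long" attaches to "ago".
tokens = ["Rome", "was", "founded", "long", "ago"]
heads = [2, 2, -1, 4, 2]

edges = build_word_graph(heads)
print(len(tokens), "nodes")  # 5 nodes
print(edges)  # [(0, 1), (0, 2), (1, 2), (2, 3), (2, 4), (3, 4)]
```

Here (0, 2) and (2, 4) are the shortcut edges contributed by the dependency tree, while (1, 2) and (3, 4) satisfy both conditions and appear only once.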

Statistics:

  • Nodes: 22662

  • Edges: 65854

  • Classes: 18

  • Node features: 300

  • 10 train/val/test splits
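The 18 classes in the statistics come from the grouping rule in the description. A minimal sketch of such a frequency-based grouping is shown below, assuming roles arrive as strings and that class ids are assigned in order of decreasing frequency (an assumption; the actual id order used by the dataset is not documented here). `group_rare_classes` and the role labels are hypothetical.

```python
from collections import Counter


def group_rare_classes(roles, num_kept=17):
    """Map string roles to integer class ids.

    The num_kept most frequent roles get ids 0..num_kept-1 (most frequent
    first); every other role is merged into the extra class id num_kept.
    """
    counts = Counter(roles)
    # most_common orders by count; ties keep first-encountered order.
    kept = [role for role, _ in counts.most_common(num_kept)]
    role_to_id = {role: i for i, role in enumerate(kept)}
    other_id = num_kept
    return [role_to_id.get(role, other_id) for role in roles]


# Hypothetical roles; with num_kept=2 this yields 3 classes instead of 18.
roles = ["nsubj", "det", "nsubj", "obj", "det", "nsubj", "amod"]
print(group_rare_classes(roles, num_kept=2))  # [0, 1, 0, 2, 1, 0, 2]
```

With the default `num_kept=17`, any input with more than 17 distinct roles ends up with exactly 18 classes, matching the count above.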

Parameters:
  • raw_dir (str, optional) – Directory to which the raw data is downloaded and in which the processed data is stored. Default: ~/.dgl/

  • force_reload (bool, optional) – Whether to re-download the data source. Default: False

  • verbose (bool, optional) – Whether to print progress information. Default: True

  • transform (callable, optional) – A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object will be transformed before every access. Default: None
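The `transform` semantics can be pictured with a toy, DGL-free dataset. In the sketch below, `ToyDataset` and `add_self_loops` are hypothetical names and a dict stands in for a `DGLGraph`; it assumes only what the parameter description states, namely that the transform is applied on every access. With DGL itself, a built-in transform such as `dgl.AddSelfLoop()` could be passed as `transform=` instead.

```python
class ToyDataset:
    """Hypothetical stand-in for a single-graph dataset with a transform.

    A plain dict plays the role of the graph so the sketch runs without DGL.
    """

    def __init__(self, graph, transform=None):
        self._graph = graph
        self._transform = transform

    def __getitem__(self, idx):
        # Only one graph is stored, so only index 0 is valid here.
        if idx != 0:
            raise IndexError("this toy dataset holds a single graph")
        if self._transform is None:
            return self._graph
        # The transform runs on every access and its result is returned;
        # the stored graph itself is not modified.
        return self._transform(self._graph)

    def __len__(self):
        return 1


calls = []


def add_self_loops(graph):
    """Hypothetical transform: return a copy with one self-loop per node."""
    calls.append(1)
    loops = [(v, v) for v in range(graph["num_nodes"])]
    return {"num_nodes": graph["num_nodes"], "edges": graph["edges"] + loops}


ds = ToyDataset({"num_nodes": 3, "edges": [(0, 1), (1, 2)]}, transform=add_self_loops)
g1 = ds[0]
g2 = ds[0]
print(len(g1["edges"]), len(calls))  # 5 2
```

Because the transform is re-applied on each `dataset[0]` call, an expensive transform may be worth applying once and keeping the result.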

num_classes

Number of node classes

Type:

int

Examples

>>> from dgl.data import RomanEmpireDataset
>>> dataset = RomanEmpireDataset()
>>> g = dataset[0]
>>> num_classes = dataset.num_classes
>>> # get node features
>>> feat = g.ndata["feat"]
>>> # get the first data split
>>> train_mask = g.ndata["train_mask"][:, 0]
>>> val_mask = g.ndata["val_mask"][:, 0]
>>> test_mask = g.ndata["test_mask"][:, 0]
>>> # get labels
>>> label = g.ndata["label"]

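The example above reads only the first of the 10 splits. Results on datasets with several fixed splits are commonly reported as the mean and standard deviation over all of them. The sketch below shows that loop with plain Python lists standing in for the `[num_nodes, 10]` mask tensors; `column`, `masked_accuracy`, and the toy labels and predictions are hypothetical, and only 2 splits are used to keep it short.

```python
import statistics


def column(mask_rows, k):
    """Select split k from a [num_nodes][num_splits] mask, like mask[:, k]."""
    return [row[k] for row in mask_rows]


def masked_accuracy(pred, label, mask):
    """Fraction of correct predictions among nodes where mask is True.

    Assumes the mask selects at least one node.
    """
    picked = [p == y for p, y, m in zip(pred, label, mask) if m]
    return sum(picked) / len(picked)


# Hypothetical 4-node graph with 2 splits (the real dataset has 10).
label = [0, 1, 1, 2]
# One prediction list is reused for brevity; in practice a model would be
# trained on each split's train mask and evaluated on its test mask.
pred = [0, 1, 0, 2]
test_mask = [
    [True, False],
    [False, True],
    [True, True],
    [False, True],
]

num_splits = len(test_mask[0])
accs = [masked_accuracy(pred, label, column(test_mask, k)) for k in range(num_splits)]
print(accs)
print(statistics.mean(accs), statistics.pstdev(accs))
```
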
__getitem__(idx)

Gets the graph at index idx. The dataset contains a single graph, so idx must be 0.

__len__()

The number of graphs in the dataset, which is always 1.