WN18Dataset

class dgl.data.WN18Dataset(reverse=True, raw_dir=None, force_reload=False, verbose=True, transform=None)

Bases: dgl.data.knowledge_graph.KnowledgeGraphDataset

WN18 link prediction dataset.

The WN18 dataset was introduced in Translating Embeddings for Modeling Multi-relational Data. It includes the full set of 18 relations scraped from WordNet for roughly 41,000 synsets. By default, when the dataset is created, a reverse edge with a reversed relation type is added for each edge.
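A minimal plain-Python sketch of the reverse-edge construction described above. It assumes, as an illustration rather than a statement of DGL's internal layout, that each edge is a (head, relation, tail) triple and that the reversed type of relation r is encoded as r + num_rels, so 18 relations yield ids 0 to 35. The helper name add_reverse_edges is hypothetical.

```python
from typing import List, Tuple

Triple = Tuple[int, int, int]  # (head, relation, tail)


def add_reverse_edges(triples: List[Triple], num_rels: int) -> List[Triple]:
    """Hypothetical helper: append one reversed copy of every triple.

    Assumption: the reversed type of relation r is encoded as r + num_rels.
    """
    # Swap head and tail, and shift the relation id into the "reversed" range.
    reversed_triples = [(t, r + num_rels, h) for h, r, t in triples]
    # Original edges first, reversed copies after them.
    return triples + reversed_triples


toy = [(0, 3, 7), (7, 11, 2)]
print(add_reverse_edges(toy, num_rels=18))
# [(0, 3, 7), (7, 11, 2), (7, 21, 0), (2, 29, 7)]
```

Encoding reversed types as an offset keeps every relation id in a single contiguous range, which is convenient when the ids index a relation-embedding table.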

WN18 dataset statistics:

  • Nodes: 40943

  • Number of relation types: 18

  • Number of reversed relation types: 18

  • Label Split:

    • Train: 141442

    • Valid: 5000

    • Test: 5000
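Taken together, the statistics above imply the edge and relation-id counts computed below. This assumes, for illustration, that the returned graph stores every train, validation and test triple once, plus one reversed copy of each when reverse=True; check g.num_edges() on the loaded graph before relying on these numbers.

```python
# Split sizes from the statistics list above.
splits = {"train": 141442, "valid": 5000, "test": 5000}
num_rels = 18

num_triples = sum(splits.values())  # 151442 original triples
# Assumption: reverse=True doubles both the edge set and the relation-id range.
num_edges_with_reverse = 2 * num_triples  # 302884
num_relation_ids = 2 * num_rels  # 36 ids: 18 original plus 18 reversed

print(num_triples, num_edges_with_reverse, num_relation_ids)  # 151442 302884 36
```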

Parameters
  • reverse (bool) – Whether to add reverse edges. Default: True.

  • raw_dir (str) – Directory that the input data is downloaded to, or that already contains it. Default: ~/.dgl/

  • force_reload (bool) – Whether to reload the dataset instead of using a cached copy. Default: False

  • verbose (bool) – Whether to print out progress information. Default: True.

  • transform (callable, optional) – A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object will be transformed before every access.
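The transform parameter is applied to the graph before every access rather than once at load time. The toy class below mimics that behaviour without DGL; ToyGraphDataset and its dict-based "graph" are hypothetical stand-ins, not DGL classes. One practical consequence is that an expensive transform is recomputed on each dataset[0] call.

```python
from typing import Callable, Dict, Optional

Graph = Dict[str, list]  # hypothetical stand-in for a DGLGraph


class ToyGraphDataset:
    """Hypothetical single-graph dataset that applies `transform` on every access."""

    def __init__(self, graph: Graph, transform: Optional[Callable[[Graph], Graph]] = None):
        self._graph = graph
        self._transform = transform

    def __getitem__(self, idx: int) -> Graph:
        # Single-graph dataset: only index 0 is valid.
        if idx != 0:
            raise IndexError("this dataset has only one graph")
        if self._transform is None:
            return self._graph
        # The stored graph is left untouched; a transformed copy is returned each time.
        return self._transform(self._graph)

    def __len__(self) -> int:
        return 1


calls = []


def tag_edges(g: Graph) -> Graph:
    calls.append(1)  # record each time the transform runs
    return {**g, "edges": g["edges"] + ["tagged"]}


ds = ToyGraphDataset({"edges": ["e0", "e1"]}, transform=tag_edges)
ds[0]
ds[0]
print(len(calls))  # 2: the transform ran once per access
```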

num_nodes

Number of nodes

Type

int

num_rels

Number of relation types

Type

int

Examples

>>> import torch as th
>>> from dgl.data import WN18Dataset
>>>
>>> dataset = WN18Dataset()
>>> g = dataset[0]
>>> e_type = g.edata['e_type']
>>>
>>> # get data split
>>> train_mask = g.edata['train_mask']
>>> val_mask = g.edata['val_mask']
>>>
>>> train_set = th.arange(g.num_edges())[train_mask]
>>> val_set = th.arange(g.num_edges())[val_mask]
>>>
>>> # build train_g from the training edges, keeping all nodes
>>> train_edges = train_set
>>> train_g = g.edge_subgraph(train_edges,
...                           relabel_nodes=False)
>>> train_g.edata['e_type'] = e_type[train_edges]
>>>
>>> # build val_g from the training and validation edges
>>> val_edges = th.cat([train_edges, val_set])
>>> val_g = g.edge_subgraph(val_edges,
...                         relabel_nodes=False)
>>> val_g.edata['e_type'] = e_type[val_edges]
>>>
>>> # Train, Validation and Test
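The example relies on turning a boolean edge mask into edge ids and on making the validation graph a superset of the training graph. The plain-Python sketch below shows the same two steps on a toy six-edge mask, with lists standing in for torch tensors; mask_to_ids is a hypothetical helper that plays the role of th.arange(g.num_edges())[mask].

```python
from typing import List


def mask_to_ids(mask: List[bool]) -> List[int]:
    """Hypothetical helper: ids of the edges whose mask entry is True."""
    return [eid for eid, keep in enumerate(mask) if keep]


# Toy masks over six edges (not WN18 data).
train_mask = [True, True, False, True, False, False]
val_mask = [False, False, True, False, False, False]

train_set = mask_to_ids(train_mask)  # [0, 1, 3]
val_set = mask_to_ids(val_mask)      # [2]

# The validation graph keeps the training edges and adds the validation ones,
# mirroring the th.cat call in the example above.
val_edges = train_set + val_set      # [0, 1, 3, 2]
print(train_set, val_set, val_edges)
```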
__getitem__(idx)

Gets the graph object

Parameters

idx (int) – Item index. WN18Dataset has only one graph object.

Returns

The graph contains the following fields:

  • edata['e_type']: edge relation type

  • edata['train_edge_mask']: positive training edge mask

  • edata['val_edge_mask']: positive validation edge mask

  • edata['test_edge_mask']: positive testing edge mask

  • edata['train_mask']: training edge set mask (includes reversed training edges)

  • edata['val_mask']: validation edge set mask (includes reversed validation edges)

  • edata['test_mask']: testing edge set mask (includes reversed testing edges)

  • ndata['ntype']: node type. All 0 in this dataset

Return type

dgl.DGLGraph
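The *_edge_mask fields mark only the original (positive) edges of a split, while the *_mask fields also cover their reversed copies. The sketch below derives both from per-triple split labels, assuming (for illustration only, not as DGL's documented layout) that the reversed copies are appended after all original edges in the same order. split_masks is a hypothetical helper.

```python
from typing import Dict, List


def split_masks(labels: List[str], reverse: bool = True) -> Dict[str, List[bool]]:
    """Hypothetical helper: build *_edge_mask and *_mask lists from split labels.

    Assumption: with reverse=True, edge i + len(labels) is the reversed copy of edge i.
    """
    n = len(labels)
    num_edges = 2 * n if reverse else n
    masks: Dict[str, List[bool]] = {}
    for split in ("train", "val", "test"):
        original = [label == split for label in labels]
        # *_edge_mask: only the original edges of this split are marked.
        masks[f"{split}_edge_mask"] = original + [False] * (num_edges - n)
        # *_mask: the original edges plus their reversed copies.
        masks[f"{split}_mask"] = original + (original if reverse else [])
    return masks


m = split_masks(["train", "val", "train", "test"])
print(m["train_edge_mask"])  # [True, False, True, False, False, False, False, False]
print(m["train_mask"])       # [True, False, True, False, True, False, True, False]
```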

__len__()

The number of graphs in the dataset.