RedditDataset

class dgl.data.RedditDataset(self_loop=False, raw_dir=None, force_reload=False, verbose=False, transform=None)[source]

Bases: DGLBuiltinDataset

Reddit dataset for community detection (node classification)

This is a graph dataset from Reddit posts made in the month of September, 2014. The node label in this case is the community, or β€œsubreddit”, that a post belongs to. The authors sampled 50 large communities and built a post-to-post graph, connecting posts if the same user comments on both. In total this dataset contains 232,965 posts with an average degree of 492. We use the first 20 days for training and the remaining days for testing (with 30% used for validation).

Reference: http://snap.stanford.edu/graphsage/

Statistics

  • Nodes: 232,965

  • Edges: 114,615,892

  • Node feature size: 602

  • Number of training samples: 153,431

  • Number of validation samples: 23,831

  • Number of test samples: 55,703

Parameters:
  • self_loop (bool) – Whether to load the dataset with self-loop connections. Default: False

  • raw_dir (str) – Directory to download to, or that already contains, the input data. Default: ~/.dgl/

  • force_reload (bool) – Whether to reload the dataset. Default: False

  • verbose (bool) – Whether to print progress information. Default: False

  • transform (callable, optional) – A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object will be transformed before every access.
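The transform contract is simply a callable applied to the graph on every access through `__getitem__`. A toy illustration of that semantics, assuming nothing about DGL itself (`ToyDataset` is a hypothetical stand-in, not DGL code; with the real class one would pass e.g. `dgl.AddSelfLoop()`):

```python
# Hypothetical stand-in illustrating "transformed before every access":
# __getitem__ re-applies the callable each time an item is fetched.
class ToyDataset:
    def __init__(self, items, transform=None):
        self._items = items
        self._transform = transform

    def __getitem__(self, idx):
        item = self._items[idx]
        # Apply the transform lazily, on access, rather than at load time
        return self._transform(item) if self._transform else item

    def __len__(self):
        return len(self._items)

data = ToyDataset([1, 2, 3], transform=lambda x: x * 10)
first = data[0]  # transform applied on this access -> 10
```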

num_classes

Number of node classes

Type:

int

Examples

>>> data = RedditDataset()
>>> g = data[0]
>>> num_classes = data.num_classes
>>>
>>> # get node feature
>>> feat = g.ndata['feat']
>>>
>>> # get data split
>>> train_mask = g.ndata['train_mask']
>>> val_mask = g.ndata['val_mask']
>>> test_mask = g.ndata['test_mask']
>>>
>>> # get labels
>>> label = g.ndata['label']
>>>
>>> # Train, Validation and Test
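The three masks above are disjoint boolean vectors of length `g.num_nodes()` that together cover every node. A minimal NumPy sketch of how such masks select the splits (the mask sizes here are illustrative; in the real dataset the entries of `g.ndata` are torch tensors):

```python
import numpy as np

# Stand-ins for g.ndata entries on a tiny 10-node graph
num_nodes = 10
labels = np.arange(num_nodes) % 3  # hypothetical node labels

train_mask = np.zeros(num_nodes, dtype=bool)
val_mask = np.zeros(num_nodes, dtype=bool)
test_mask = np.zeros(num_nodes, dtype=bool)
train_mask[:6] = True   # first 6 nodes train
val_mask[6:8] = True    # next 2 validate
test_mask[8:] = True    # last 2 test

# Masks are disjoint and cover all nodes
assert not np.any(train_mask & val_mask)
assert train_mask.sum() + val_mask.sum() + test_mask.sum() == num_nodes

# Boolean indexing selects the labels for each split
train_labels = labels[train_mask]  # shape (6,)
```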
__getitem__(idx)[source]

Get graph by index

Parameters:

idx (int) – Item index

Returns:

graph structure, node labels, node features and splitting masks:

  • ndata['label']: node label

  • ndata['feat']: node feature

  • ndata['train_mask']: mask for training node set

  • ndata['val_mask']: mask for validation node set

  • ndata['test_mask']: mask for test node set

Return type:

dgl.DGLGraph

__len__()[source]

Number of graphs in the dataset

Return type:

int