RedditDataset
- class dgl.data.RedditDataset(self_loop=False, raw_dir=None, force_reload=False, verbose=False, transform=None)
Bases: DGLBuiltinDataset
Reddit dataset for community detection (node classification)
This is a graph dataset built from Reddit posts made in September 2014. The node label is the community, or "subreddit", that a post belongs to. The authors sampled 50 large communities and built a post-to-post graph, connecting two posts if the same user commented on both. In total, the dataset contains 232,965 posts with an average degree of 492. We use the first 20 days for training and the remaining days for testing (with 30% of the remaining posts used for validation).
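The graph construction and the day-based split described above can be sketched in plain Python. This is a toy illustration, not DGL's or the original authors' preprocessing code: the `(user, post)` comment format, the 0-based day index, and the random choice of validation posts are all assumptions, and `build_post_graph` / `split_by_day` are hypothetical helpers.

```python
from collections import defaultdict
from itertools import combinations
import random


def build_post_graph(comments):
    """Return undirected post-post edges as (u, v) tuples with u < v.

    `comments` is an iterable of (user, post) pairs (a hypothetical input
    format). Two posts are linked if at least one user commented on both.
    """
    posts_by_user = defaultdict(set)
    for user, post in comments:
        posts_by_user[user].add(post)

    edges = set()
    for posts in posts_by_user.values():
        # Every pair of distinct posts sharing this commenter becomes an edge.
        for u, v in combinations(sorted(posts), 2):
            edges.add((u, v))
    return edges


def split_by_day(post_day, train_days=20, val_frac=0.3, seed=0):
    """Split posts into train/val/test lists by posting day.

    `post_day` maps post id -> day index (0 = first day, an assumption).
    Posts from the first `train_days` days go to training. The remaining
    posts are shuffled with a fixed seed and `val_frac` of them become
    validation; choosing them at random is an assumption.
    """
    train = sorted(p for p, d in post_day.items() if d < train_days)
    rest = sorted(p for p, d in post_day.items() if d >= train_days)
    random.Random(seed).shuffle(rest)
    n_val = round(val_frac * len(rest))
    return train, sorted(rest[:n_val]), sorted(rest[n_val:])


if __name__ == "__main__":
    comments = [("alice", 0), ("alice", 1), ("bob", 1), ("bob", 2), ("carol", 3)]
    print(sorted(build_post_graph(comments)))  # [(0, 1), (1, 2)]

    post_day = {post: post for post in range(30)}  # toy data: post i made on day i
    train, val, test = split_by_day(post_day)
    print(len(train), len(val), len(test))  # 20 3 7
```

Post 3 ends up isolated because only one user commented on it; in the toy split, the 10 posts after day 19 are divided 3/7 between validation and test.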
Reference: http://snap.stanford.edu/graphsage/
Statistics
Nodes: 232,965
Edges: 114,615,892
Node feature size: 602
Number of training samples: 153,431
Number of validation samples: 23,831
Number of test samples: 55,703
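A quick arithmetic check shows how the statistics fit together: the three splits cover every node exactly once, edges per node matches the quoted average degree of 492, and validation is about 30% of the posts outside the training period. Treating the edge count as directed edges (each undirected link stored in both directions) is an assumption made to reconcile the two figures.

```python
# Figures copied from the Statistics list above.
num_nodes = 232_965
num_edges = 114_615_892
num_train, num_val, num_test = 153_431, 23_831, 55_703

# The three splits together account for every node.
assert num_train + num_val + num_test == num_nodes

# Edges per node; this equals the quoted average degree of 492 if each
# undirected link is counted as two directed edges (an assumption).
avg_degree = num_edges / num_nodes
print(f"average degree: {avg_degree:.2f}")  # 491.99

# Share of all nodes used for training.
train_share = num_train / num_nodes
print(f"training share: {train_share:.3f}")  # 0.659

# Validation share of the posts outside the first 20 days.
val_share = num_val / (num_val + num_test)
print(f"validation share of non-training posts: {val_share:.3f}")  # 0.300
```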
- Parameters:
self_loop (bool): Whether to load the dataset with self-loop connections. Default: False
raw_dir (str): Directory where the raw data will be downloaded, or that already contains the input data directory. Default: ~/.dgl/
force_reload (bool): Whether to reload the dataset. Default: False
verbose (bool): Whether to print out progress information. Default: False
transform (callable, optional): A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object will be transformed before every access.
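The `transform` parameter means the stored graph is never modified; the callable runs each time the graph is accessed. The pure-Python sketch below imitates that behaviour with a self-loop-adding transform. `ToyGraph`, `add_self_loops`, and `ToySingleGraphDataset` are hypothetical stand-ins, not DGL classes; in recent DGL releases one would instead pass a built-in transform object such as `dgl.AddSelfLoop()`.

```python
class ToyGraph:
    """Hypothetical stand-in for a DGLGraph: a node count plus a directed edge list."""

    def __init__(self, num_nodes, edges):
        self.num_nodes = num_nodes
        self.edges = list(edges)


def add_self_loops(g):
    """Example transform: return a NEW graph with one (v, v) edge per node.

    Any self-loops already present are kept; this sketch does not deduplicate.
    """
    loops = [(v, v) for v in range(g.num_nodes)]
    return ToyGraph(g.num_nodes, g.edges + loops)


class ToySingleGraphDataset:
    """Hypothetical single-graph dataset: the stored graph is left untouched
    and `transform`, if given, is applied on every __getitem__ call."""

    def __init__(self, graph, transform=None):
        self._graph = graph
        self._transform = transform

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        if idx != 0:
            raise IndexError("this dataset holds a single graph")
        if self._transform is None:
            return self._graph
        return self._transform(self._graph)


if __name__ == "__main__":
    base = ToyGraph(3, [(0, 1), (1, 2)])
    data = ToySingleGraphDataset(base, transform=add_self_loops)
    print(len(data[0].edges))  # 5: two original edges plus three self-loops
    print(len(base.edges))     # 2: the stored graph is unchanged
```

Because the transform runs on access, each `data[0]` call returns a freshly transformed graph; an expensive transform therefore costs time on every access rather than once at load time.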
Examples
>>> from dgl.data import RedditDataset
>>> data = RedditDataset()
>>> g = data[0]
>>> num_classes = data.num_classes
>>>
>>> # get node feature
>>> feat = g.ndata['feat']
>>>
>>> # get data split
>>> train_mask = g.ndata['train_mask']
>>> val_mask = g.ndata['val_mask']
>>> test_mask = g.ndata['test_mask']
>>>
>>> # get labels
>>> label = g.ndata['label']
>>>
>>> # Train, Validation and Test
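The masks retrieved in the example are typically used to restrict training and evaluation to one split. The sketch below shows the idea with plain Python lists standing in for `g.ndata['label']`, the three masks, and some model predictions; the toy values and the helpers `masked_indices` / `masked_accuracy` are hypothetical. With the PyTorch backend, assuming the masks are boolean tensors, the same step is usually written with boolean indexing, e.g. `(pred[test_mask] == label[test_mask]).float().mean()`.

```python
def masked_indices(mask):
    """Positions where a boolean node mask is True."""
    return [i for i, m in enumerate(mask) if m]


def masked_accuracy(pred, label, mask):
    """Fraction of correct predictions among the nodes selected by `mask`.

    Returns 0.0 for an empty mask (a choice made for this sketch).
    """
    idx = masked_indices(mask)
    if not idx:
        return 0.0
    correct = sum(pred[i] == label[i] for i in idx)
    return correct / len(idx)


# Toy stand-ins for g.ndata['label'], model predictions, and the split masks.
label = [0, 1, 1, 0, 2, 2]
pred = [0, 1, 0, 0, 2, 1]
train_mask = [True, True, True, False, False, False]
val_mask = [False, False, False, True, False, False]
test_mask = [False, False, False, False, True, True]

if __name__ == "__main__":
    for name, mask in [("train", train_mask), ("val", val_mask), ("test", test_mask)]:
        print(f"{name}: {masked_accuracy(pred, label, mask):.3f}")
    # train: 0.667
    # val: 1.000
    # test: 0.500
```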
- __getitem__(idx)
Get graph by index
- Parameters:
idx (int): Item index
- Returns:
graph structure, node labels, node features and splitting masks:
ndata['label']: node label
ndata['feat']: node feature
ndata['train_mask']: mask for the training node set
ndata['val_mask']: mask for the validation node set
ndata['test_mask']: mask for the test node set
- Return type: