SSTDataset
- class dgl.data.SSTDataset(mode='train', glove_embed_file=None, vocab_file=None, raw_dir=None, force_reload=False, verbose=False, transform=None)
Bases: DGLBuiltinDataset
Stanford Sentiment Treebank dataset.
Each sample is the constituency tree of a sentence. The leaf nodes represent words; each word is an int value stored in the x feature field. Non-leaf nodes have the special value PAD_WORD in the x field. Each node also has a sentiment annotation with 5 classes (very negative, negative, neutral, positive and very positive). The sentiment label is an int value stored in the y feature field.
Official site: http://nlp.stanford.edu/sentiment/index.html
Statistics:
Train examples: 8,544
Dev examples: 1,101
Test examples: 2,210
Number of classes for each node: 5
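The node layout described above can be illustrated without downloading the dataset. The following is a minimal sketch on toy data, not DGL code: the PAD_WORD value of -1 is a hypothetical stand-in (the real value is defined by the dataset object), and the mapping from label integers to class names assumes the order in which the classes are listed above.

```python
from collections import Counter

# Assumed label order, following the order in which the classes are listed above.
SENTIMENT_NAMES = ["very negative", "negative", "neutral", "positive", "very positive"]

# Hypothetical stand-in for the dataset's PAD_WORD value; the real value
# is defined by the dataset object and is not asserted here.
PAD_WORD = -1

# Toy node features for a 5-node tree over a three-word sentence:
# nodes 0 and 1 are internal, nodes 2-4 are leaves holding word ids.
x = [PAD_WORD, PAD_WORD, 12, 7, 431]
y = [3, 3, 2, 2, 4]


def leaf_word_ids(x, pad_word=PAD_WORD):
    """Return the word ids of every node whose x value is not the pad value."""
    return [w for w in x if w != pad_word]


def label_histogram(y):
    """Count how many nodes carry each sentiment class, keyed by class name."""
    counts = Counter(y)
    return {SENTIMENT_NAMES[k]: counts[k] for k in sorted(counts)}


leaves = leaf_word_ids(x)
hist = label_histogram(y)
print(leaves)
print(hist)
```

With a real sample, the same logic would apply to the x and y node features of the tree, converted to Python lists or handled with tensor operations.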
- Parameters:
mode (str, optional) – Should be one of ['train', 'dev', 'test', 'tiny']. Default: train
glove_embed_file (str, optional) – The path to the pretrained GloVe embedding file. Default: None
vocab_file (str, optional) – Optional vocabulary file. If not given, the default vocabulary file is used. Default: None
raw_dir (str) – Directory that the input data is downloaded to, or that already contains it. Default: ~/.dgl/
force_reload (bool) – Whether to reload the dataset. Default: False
verbose (bool) – Whether to print out progress information. Default: False
transform (callable, optional) – A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object will be transformed before every access.
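The "transformed before every access" behavior of the transform parameter can be sketched with a toy container. This is an illustration of the access-time pattern only, not DGL's implementation: ToyDataset and tag_sample are hypothetical names, and plain dicts stand in for DGLGraph objects.

```python
class ToyDataset:
    """Hypothetical stand-in for a dataset that accepts a `transform` callable."""

    def __init__(self, samples, transform=None):
        self._samples = list(samples)
        self._transform = transform

    def __len__(self):
        return len(self._samples)

    def __getitem__(self, idx):
        sample = self._samples[idx]
        # The transform runs at access time, so the stored sample stays untouched.
        if self._transform is None:
            return sample
        return self._transform(sample)


calls = []


def tag_sample(sample):
    """Toy transform: record the call and return a modified copy of the sample."""
    calls.append(sample["id"])
    return {**sample, "tagged": True}


data = ToyDataset([{"id": 0}, {"id": 1}], transform=tag_sample)
first = data[0]
again = data[0]
print(first, calls)
```

Because the callable runs on each access, accessing the same index twice invokes it twice. Under this pattern an expensive transform is paid for on every lookup, so it may be worth keeping the transform cheap.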
- vocab
Vocabulary of the dataset.
- Type: OrderedDict
- pretrained_emb
Pretrained GloVe embedding with respect to the vocabulary.
- Type: Tensor
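How the two attributes relate can be sketched on toy data. The example below rests on two assumptions that the text above does not state explicitly: that the vocabulary maps each token to an integer id, and that row i of the embedding corresponds to the token with id i. A nested list stands in for the tensor, and embed_tokens is a hypothetical helper.

```python
from collections import OrderedDict

# Toy vocabulary: token -> integer id (assumed layout).
vocab = OrderedDict([("the", 0), ("movie", 1), ("great", 2)])

# Toy embedding table with one row per vocabulary id (assumed alignment).
emb = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]


def embed_tokens(tokens, vocab, emb):
    """Look up each known token's id and return the matching embedding rows."""
    return [emb[vocab[t]] for t in tokens if t in vocab]


# Reverse map for turning ids stored on leaf nodes back into tokens.
id_to_token = {idx: tok for tok, idx in vocab.items()}

rows = embed_tokens(["great", "movie", "unseen"], vocab, emb)
print(rows)
print(id_to_token[1])
```

Tokens missing from the vocabulary are skipped in this sketch; how the dataset itself handles out-of-vocabulary words is not covered here.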
Notes
All samples are loaded and preprocessed in memory first.
Examples
>>> from dgl.data import SSTDataset
>>> # get dataset
>>> train_data = SSTDataset()
>>> dev_data = SSTDataset(mode='dev')
>>> test_data = SSTDataset(mode='test')
>>> tiny_data = SSTDataset(mode='tiny')
>>>
>>> len(train_data)
8544
>>> train_data.num_classes
5
>>> glove_embed = train_data.pretrained_emb
>>> train_data.vocab_size
19536
>>> train_data[0]
Graph(num_nodes=71, num_edges=70,
      ndata_schemes={'x': Scheme(shape=(), dtype=torch.int64), 'y': Scheme(shape=(), dtype=torch.int64), 'mask': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={})
>>> for tree in train_data:
...     input_ids = tree.ndata['x']
...     labels = tree.ndata['y']
...     mask = tree.ndata['mask']
...     # your code here