# Source code for dgl.data.tree

"""Tree-structured data.
Including:
    - Stanford Sentiment Treebank
"""
from __future__ import absolute_import

from collections import OrderedDict
import networkx as nx

import numpy as np
import os

from .dgl_dataset import DGLBuiltinDataset
from .. import backend as F
from .utils import _get_dgl_url, save_graphs, save_info, load_graphs, \
    load_info, deprecate_property
from ..convert import from_networkx

# Public API of this module: the dataset class and its short alias ``SST``.
__all__ = ['SST', 'SSTDataset']


class SSTDataset(DGLBuiltinDataset):
    r"""Stanford Sentiment Treebank dataset.

    Each sample is the constituency tree of a sentence. The leaf nodes
    represent words. The word is an int value stored in the ``x`` feature
    field. The non-leaf node has a special value ``PAD_WORD`` in the ``x``
    field. Each node also has a sentiment annotation: 5 classes (very
    negative, negative, neutral, positive and very positive). The sentiment
    label is an int value stored in the ``y`` feature field.

    Official site: `<http://nlp.stanford.edu/sentiment/index.html>`_

    Statistics:

    - Train examples: 8,544
    - Dev examples: 1,101
    - Test examples: 2,210
    - Number of classes for each node: 5

    Parameters
    ----------
    mode : str, optional
        Should be one of ['train', 'dev', 'test', 'tiny']
        Default: train
    glove_embed_file : str, optional
        The path to pretrained glove embedding file. It is only used when
        ``mode='train'``; it is ignored for the other modes.
        Default: None
    vocab_file : str, optional
        Optional vocabulary file. If not given, the default vocabulary file
        is used.
        Default: None
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose : bool
        Whether to print out progress information. Default: False.
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access.

    Attributes
    ----------
    vocab : OrderedDict
        Vocabulary of the dataset
    num_classes : int
        Number of classes for each node
    pretrained_emb: Tensor
        Pretrained glove embedding with respect to the vocabulary.
    vocab_size : int
        The size of the vocabulary

    Notes
    -----
    All the samples will be loaded and preprocessed in the memory first.

    Examples
    --------
    >>> # get dataset
    >>> train_data = SSTDataset()
    >>> dev_data = SSTDataset(mode='dev')
    >>> test_data = SSTDataset(mode='test')
    >>> tiny_data = SSTDataset(mode='tiny')
    >>>
    >>> len(train_data)
    8544
    >>> train_data.num_classes
    5
    >>> glove_embed = train_data.pretrained_emb
    >>> train_data.vocab_size
    19536
    >>> train_data[0]
    Graph(num_nodes=71, num_edges=70,
      ndata_schemes={'x': Scheme(shape=(), dtype=torch.int64), 'y': Scheme(shape=(), dtype=torch.int64), 'mask': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={})
    >>> for tree in train_data:
    ...     input_ids = tree.ndata['x']
    ...     labels = tree.ndata['y']
    ...     mask = tree.ndata['mask']
    ...     # your code here
    """

    PAD_WORD = -1  # special pad word id
    UNK_WORD = -1  # out-of-vocabulary word id
    # NOTE: UNK_WORD shares its value with PAD_WORD, so ``x`` alone cannot tell
    # an unknown leaf from an internal node; use ``ndata['mask']`` (1 on
    # leaves, 0 on internal nodes) to distinguish them.

    def __init__(self,
                 mode='train',
                 glove_embed_file=None,
                 vocab_file=None,
                 raw_dir=None,
                 force_reload=False,
                 verbose=False,
                 transform=None):
        assert mode in ['train', 'dev', 'test', 'tiny'], \
            "mode must be one of 'train', 'dev', 'test', 'tiny'"
        _url = _get_dgl_url('dataset/sst.zip')
        # Pretrained embeddings are only built for the training split.
        self._glove_embed_file = glove_embed_file if mode == 'train' else None
        self.mode = mode
        self._vocab_file = vocab_file
        super(SSTDataset, self).__init__(name='sst',
                                         url=_url,
                                         raw_dir=raw_dir,
                                         force_reload=force_reload,
                                         verbose=verbose,
                                         transform=transform)

    def process(self):
        """Build the vocabulary, optional GloVe embedding, and the trees."""
        from nltk.corpus.reader import BracketParseCorpusReader

        self._vocab = self._load_vocab()

        # Check the GloVe file once instead of once per vocabulary word.
        use_glove = (self._glove_embed_file is not None
                     and os.path.exists(self._glove_embed_file))
        glove_emb = self._load_glove() if use_glove else {}

        files = ['{}.txt'.format(self.mode)]
        corpus = BracketParseCorpusReader(self.raw_path, files)
        sents = corpus.parsed_sents(files[0])

        self._pretrained_emb = None
        if use_glove:
            self._pretrained_emb, fail_cnt = self._build_pretrained_emb(glove_emb)
            print('Miss word in GloVe {0:.4f}'.format(
                1.0 * fail_cnt / len(self._pretrained_emb)))

        self._trees = [self._build_tree(sent) for sent in sents]

    def _load_vocab(self):
        """Read the vocabulary file into an OrderedDict mapping word -> id.

        Each stripped line is a word; its id is the dict size at the time it
        is inserted.
        """
        vocab = OrderedDict()
        vocab_file = self._vocab_file if self._vocab_file is not None \
            else os.path.join(self.raw_path, 'vocab.txt')
        with open(vocab_file, encoding='utf-8') as vf:
            for line in vf:
                vocab[line.strip()] = len(vocab)
        return vocab

    def _load_glove(self):
        """Return ``{lower-cased word: vector}`` for GloVe words in the vocab.

        The file is streamed line by line so a large GloVe file is never held
        in memory in full.
        """
        glove_emb = {}
        with open(self._glove_embed_file, 'r', encoding='utf-8') as pf:
            for line in pf:
                sp = line.split(' ')
                word = sp[0].lower()
                if word in self._vocab:
                    glove_emb[word] = np.asarray([float(x) for x in sp[1:]])
        return glove_emb

    def _build_pretrained_emb(self, glove_emb):
        """Stack one embedding row per vocabulary word, in vocabulary order.

        Words found in ``glove_emb`` (lower-cased lookup) reuse the GloVe
        vector. The others get a vector drawn from U(-0.05, 0.05).

        Returns
        -------
        (Tensor, int)
            The embedding tensor and the number of words missing from GloVe.
        """
        # Match the dimensionality of the loaded vectors so np.stack does not
        # fail on non-300d GloVe files; 300 is the fallback when none matched.
        emb_dim = len(next(iter(glove_emb.values()))) if glove_emb else 300
        rows = []
        fail_cnt = 0
        for word in self._vocab:
            vec = glove_emb.get(word.lower())
            if vec is None:
                fail_cnt += 1
                vec = np.random.uniform(-0.05, 0.05, emb_dim)
            rows.append(vec)
        return F.tensor(np.stack(rows, 0)), fail_cnt

    def _build_tree(self, root):
        """Convert an NLTK constituency tree into a DGLGraph.

        Nodes are numbered in pre-order, with the root as node 0. Edges point
        from each child to its parent.
        """
        g = nx.DiGraph()

        def _rec_build(nid, node):
            for child in node:
                cid = g.number_of_nodes()
                if isinstance(child[0], (str, bytes)):
                    # leaf node
                    word = self.vocab.get(child[0].lower(), self.UNK_WORD)
                    g.add_node(cid, x=word, y=int(child.label()), mask=1)
                else:
                    g.add_node(cid, x=SSTDataset.PAD_WORD,
                               y=int(child.label()), mask=0)
                    _rec_build(cid, child)
                g.add_edge(cid, nid)

        # add root
        g.add_node(0, x=SSTDataset.PAD_WORD, y=int(root.label()), mask=0)
        _rec_build(0, root)
        return from_networkx(g, node_attrs=['x', 'y', 'mask'])

    def _cache_paths(self):
        """Return ``(graph_path, vocab_path, emb_path)`` under ``save_path``."""
        return (os.path.join(self.save_path, self.mode + '_dgl_graph.bin'),
                os.path.join(self.save_path, 'vocab.pkl'),
                os.path.join(self.save_path, 'emb.pkl'))

    def has_cache(self):
        """Whether the per-mode graph cache and the vocabulary cache exist."""
        graph_path, vocab_path, _ = self._cache_paths()
        return os.path.exists(graph_path) and os.path.exists(vocab_path)

    def save(self):
        """Save trees, vocabulary and (if built) the pretrained embedding."""
        graph_path, vocab_path, emb_path = self._cache_paths()
        save_graphs(graph_path, self._trees)
        save_info(vocab_path, {'vocab': self.vocab})
        # Use ``is not None``: a multi-element tensor has no well-defined truth
        # value, so ``if tensor:`` raises instead of saving the embedding.
        if self.pretrained_emb is not None:
            save_info(emb_path, {'embed': self.pretrained_emb})

    def load(self):
        """Load trees, vocabulary and the optional embedding from the cache."""
        graph_path, vocab_path, emb_path = self._cache_paths()
        self._trees = load_graphs(graph_path)[0]
        self._vocab = load_info(vocab_path)['vocab']
        # NOTE(review): emb.pkl is not mode-specific, so a dev/test dataset
        # loaded from cache also picks up an embedding saved by a train run.
        self._pretrained_emb = None
        if os.path.exists(emb_path):
            self._pretrained_emb = load_info(emb_path)['embed']

    @property
    def vocab(self):
        r""" Vocabulary

        Returns
        -------
        OrderedDict
        """
        return self._vocab

    @property
    def pretrained_emb(self):
        r"""Pre-trained word embedding, if given."""
        return self._pretrained_emb

    def __getitem__(self, idx):
        r""" Get graph by index

        Parameters
        ----------
        idx : int

        Returns
        -------
        :class:`dgl.DGLGraph`

            graph structure, word id for each node, node labels and masks.

            - ``ndata['x']``: word id of the node
            - ``ndata['y']``: label of the node
            - ``ndata['mask']``: 1 if the node is a leaf, otherwise 0
        """
        if self._transform is None:
            return self._trees[idx]
        return self._transform(self._trees[idx])

    def __len__(self):
        r"""Number of graphs in the dataset."""
        return len(self._trees)

    @property
    def vocab_size(self):
        r"""Vocabulary size."""
        return len(self._vocab)

    @property
    def num_classes(self):
        r"""Number of classes for each node."""
        return 5
# Short alias for SSTDataset; both names are exported via ``__all__``.
SST = SSTDataset