# Source code for dgl.graphbolt.itemset

"""GraphBolt Itemset."""

import textwrap
from typing import Dict, Iterable, Iterator, Tuple, Union

import torch

# Names exported by a star-import of this module; ``is_scalar`` is
# intentionally not listed here.
__all__ = ["ItemSet", "ItemSetDict"]


def is_scalar(x):
    """Tell whether ``x`` should be treated as a scalar.

    A ``torch.Tensor`` is a scalar when it has zero dimensions. Any other
    object is a scalar only if it is an ``int`` instance.
    """
    if isinstance(x, torch.Tensor):
        return x.dim() == 0
    return isinstance(x, int)


class ItemSet:
    r"""A wrapper of iterable data or tuple of iterable data.

    All itemsets that represent an iterable of items should subclass it. Such
    form of itemset is particularly useful when items come from a stream. This
    class requires each input itemset to be iterable.

    Parameters
    ----------
    items: Union[int, Iterable, Tuple[Iterable]]
        The items to be iterated over. If it is a single integer, a `range()`
        object will be created and iterated over. If it's multi-dimensional
        iterable such as `torch.Tensor`, it will be iterated over the first
        dimension. If it is a tuple, each item in the tuple is an iterable of
        items.
    names: Union[str, Tuple[str]], optional
        The names of the items. If it is a tuple, each name corresponds to an
        item in the tuple. The naming is arbitrary, but in general practice,
        the names should be chosen from ['labels', 'seeds', 'indexes'] to
        align with the attributes of class `dgl.graphbolt.MiniBatch`.

    Examples
    --------
    >>> import torch
    >>> from dgl import graphbolt as gb

    1. Integer: number of nodes.

    >>> num = 10
    >>> item_set = gb.ItemSet(num, names="seeds")
    >>> list(item_set)
    [tensor(0), tensor(1), tensor(2), tensor(3), tensor(4), tensor(5),
     tensor(6), tensor(7), tensor(8), tensor(9)]
    >>> item_set[:]
    tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
    >>> item_set.names
    ('seeds',)

    2. Torch scalar: number of nodes. Customizable dtype compared to Integer.

    >>> num = torch.tensor(10, dtype=torch.int32)
    >>> item_set = gb.ItemSet(num, names="seeds")
    >>> list(item_set)
    [tensor(0, dtype=torch.int32), tensor(1, dtype=torch.int32),
     tensor(2, dtype=torch.int32), tensor(3, dtype=torch.int32),
     tensor(4, dtype=torch.int32), tensor(5, dtype=torch.int32),
     tensor(6, dtype=torch.int32), tensor(7, dtype=torch.int32),
     tensor(8, dtype=torch.int32), tensor(9, dtype=torch.int32)]
    >>> item_set[:]
    tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int32)
    >>> item_set.names
    ('seeds',)

    3. Single iterable: seed nodes.

    >>> node_ids = torch.arange(0, 5)
    >>> item_set = gb.ItemSet(node_ids, names="seeds")
    >>> list(item_set)
    [tensor(0), tensor(1), tensor(2), tensor(3), tensor(4)]
    >>> item_set[:]
    tensor([0, 1, 2, 3, 4])
    >>> item_set.names
    ('seeds',)

    4. Tuple of iterables with same shape: seed nodes and labels.

    >>> node_ids = torch.arange(0, 5)
    >>> labels = torch.arange(5, 10)
    >>> item_set = gb.ItemSet(
    ...     (node_ids, labels), names=("seeds", "labels"))
    >>> list(item_set)
    [(tensor(0), tensor(5)), (tensor(1), tensor(6)), (tensor(2), tensor(7)),
     (tensor(3), tensor(8)), (tensor(4), tensor(9))]
    >>> item_set[:]
    (tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9]))
    >>> item_set.names
    ('seeds', 'labels')

    5. Tuple of iterables with different shape: seeds and labels.

    >>> seeds = torch.arange(0, 10).reshape(-1, 2)
    >>> labels = torch.tensor([1, 1, 0, 0, 0])
    >>> item_set = gb.ItemSet(
    ...     (seeds, labels), names=("seeds", "labels"))
    >>> list(item_set)
    [(tensor([0, 1]), tensor(1)), (tensor([2, 3]), tensor(1)),
     (tensor([4, 5]), tensor(0)), (tensor([6, 7]), tensor(0)),
     (tensor([8, 9]), tensor(0))]
    >>> item_set[:]
    (tensor([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]),
     tensor([1, 1, 0, 0, 0]))
    >>> item_set.names
    ('seeds', 'labels')

    6. Tuple of iterables with different shape: hyperlink and labels.

    >>> seeds = torch.arange(0, 10).reshape(-1, 5)
    >>> labels = torch.tensor([1, 0])
    >>> item_set = gb.ItemSet(
    ...     (seeds, labels), names=("seeds", "labels"))
    >>> list(item_set)
    [(tensor([0, 1, 2, 3, 4]), tensor(1)),
     (tensor([5, 6, 7, 8, 9]), tensor(0))]
    >>> item_set[:]
    (tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]),
     tensor([1, 0]))
    >>> item_set.names
    ('seeds', 'labels')
    """

    def __init__(
        self,
        items: Union[int, torch.Tensor, Iterable, Tuple[Iterable]],
        names: Union[str, Tuple[str]] = None,
    ) -> None:
        if is_scalar(items):
            # An integer (or 0-dim tensor) N stands for the range [0, N); the
            # scalar itself is stored and expanded lazily on access.
            self._length = int(items)
            self._items = items
            self._num_items = 1
        elif isinstance(items, tuple):
            # The first element decides whether the itemset is sized; an
            # unsized first element (e.g. a generator) disables len/indexing.
            try:
                self._length = len(items[0])
            except TypeError:
                self._length = None
            if self._length is not None:
                if any(self._length != len(item) for item in items):
                    raise ValueError("Size mismatch between items.")
            self._items = items
            self._num_items = len(items)
        else:
            try:
                self._length = len(items)
            except TypeError:
                self._length = None
            # A single iterable is stored as a 1-tuple so that every
            # non-scalar itemset has the same internal layout.
            self._items = (items,)
            self._num_items = 1

        if names is not None:
            if isinstance(names, tuple):
                self._names = names
            else:
                self._names = (names,)
            assert self._num_items == len(self._names), (
                f"Number of items ({self._num_items}) and "
                f"names ({len(self._names)}) must match."
            )
        else:
            self._names = None

    def __iter__(self) -> Iterator:
        if is_scalar(self._items):
            dtype = getattr(self._items, "dtype", torch.int64)
            yield from torch.arange(self._items, dtype=dtype)
            return

        if self._num_items == 1:
            yield from self._items[0]
            return

        if self._length is not None:
            # Use for-loop to iterate over the items. It can avoid a long
            # waiting time when the items are torch tensors. Since torch
            # tensors need to call self.unbind(0) to slice themselves.
            # While for-loops are slower than zip, they prevent excessive
            # wait times during the loading phase, and the impact on overall
            # performance during the training/testing stage is minimal.
            # For more details, see https://github.com/dmlc/dgl/pull/6293.
            for i in range(self._length):
                yield tuple(item[i] for item in self._items)
        else:
            # If the items are not Sized, we use zip to iterate over them.
            yield from zip(*self._items)

    def __getitem__(self, idx: Union[int, slice, Iterable]) -> Tuple:
        """Return the item(s) selected by ``idx``.

        Parameters
        ----------
        idx: Union[int, slice, Iterable]
            An integer, a slice, or an iterable of integers (for example a
            list or an integer ``torch.Tensor``).

        Raises
        ------
        TypeError
            If the itemset has no length, or ``idx`` has an unsupported type.
        IndexError
            If an integer index of an integer-backed itemset is out of range.
        """
        if self._length is None:
            raise TypeError(
                f"{type(self).__name__} instance doesn't support indexing."
            )
        if is_scalar(self._items):
            return self._index_range(idx)
        if self._num_items == 1:
            return self._items[0][idx]
        return tuple(item[idx] for item in self._items)

    def _index_range(self, idx):
        """Index the implicit range ``[0, length)`` of a scalar itemset."""
        dtype = getattr(self._items, "dtype", torch.int64)
        if isinstance(idx, slice):
            start, stop, step = idx.indices(self._length)
            return torch.arange(start, stop, step, dtype=dtype)
        if isinstance(idx, int):
            if idx < 0:
                idx += self._length
            if idx < 0 or idx >= self._length:
                raise IndexError(f"{type(self).__name__} index out of range.")
            return (
                torch.tensor(idx, dtype=self._items.dtype)
                if isinstance(self._items, torch.Tensor)
                else idx
            )
        if isinstance(idx, Iterable):
            return self._gather_range(idx, dtype)
        raise TypeError(
            f"{type(self).__name__} indices must be integer or slice."
        )

    def _gather_range(self, idx, dtype):
        """Resolve an iterable of integer positions against ``[0, length)``."""
        try:
            positions = torch.as_tensor(idx)
        except (TypeError, ValueError, RuntimeError) as err:
            # Strings, dicts, generators, ... cannot be turned into a tensor.
            raise TypeError(
                f"{type(self).__name__} indices must be integer or slice."
            ) from err
        if positions.numel() == 0:
            # An empty selection is valid whatever dtype torch inferred.
            return positions.to(dtype)
        if (
            positions.dtype == torch.bool
            or positions.is_floating_point()
            or positions.is_complex()
        ):
            raise TypeError(
                f"{type(self).__name__} iterable indices must contain "
                "integers."
            )
        length = self._length
        if bool(((positions < -length) | (positions >= length)).any()):
            raise IndexError(f"{type(self).__name__} index out of range.")
        # In the range [0, length) the value at position p is p itself, so the
        # result is the normalized positions cast to the itemset's dtype.
        positions = torch.where(positions < 0, positions + length, positions)
        return positions.to(dtype)

    @property
    def names(self) -> Tuple[str]:
        """Return the names of the items."""
        return self._names

    @property
    def num_items(self) -> int:
        """Return the number of the items."""
        return self._num_items

    def __len__(self):
        if self._length is None:
            raise TypeError(
                f"{type(self).__name__} instance doesn't have valid length."
            )
        return self._length

    def __repr__(self) -> str:
        ret = (
            f"{self.__class__.__name__}(\n"
            f" items={self._items},\n"
            f" names={self._names},\n"
            f")"
        )
        return ret
class ItemSetDict:
    r"""Dictionary wrapper of **ItemSet**.

    Each item is retrieved by iterating over each itemset and returned with
    corresponding key as a dict.

    Parameters
    ----------
    itemsets: Dict[str, ItemSet]

    Examples
    --------
    >>> import torch
    >>> from dgl import graphbolt as gb

    1. Single iterable: seed nodes.

    >>> node_ids_user = torch.arange(0, 5)
    >>> node_ids_item = torch.arange(5, 10)
    >>> item_set = gb.ItemSetDict({
    ...     "user": gb.ItemSet(node_ids_user, names="seeds"),
    ...     "item": gb.ItemSet(node_ids_item, names="seeds")})
    >>> list(item_set)
    [{'user': tensor(0)}, {'user': tensor(1)}, {'user': tensor(2)},
     {'user': tensor(3)}, {'user': tensor(4)}, {'item': tensor(5)},
     {'item': tensor(6)}, {'item': tensor(7)}, {'item': tensor(8)},
     {'item': tensor(9)}]
    >>> item_set[:]
    {'user': tensor([0, 1, 2, 3, 4]), 'item': tensor([5, 6, 7, 8, 9])}
    >>> item_set.names
    ('seeds',)

    2. Tuple of iterables with same shape: seed nodes and labels.

    >>> node_ids_user = torch.arange(0, 2)
    >>> labels_user = torch.arange(0, 2)
    >>> node_ids_item = torch.arange(2, 5)
    >>> labels_item = torch.arange(2, 5)
    >>> item_set = gb.ItemSetDict({
    ...     "user": gb.ItemSet(
    ...         (node_ids_user, labels_user),
    ...         names=("seeds", "labels")),
    ...     "item": gb.ItemSet(
    ...         (node_ids_item, labels_item),
    ...         names=("seeds", "labels"))})
    >>> list(item_set)
    [{'user': (tensor(0), tensor(0))}, {'user': (tensor(1), tensor(1))},
     {'item': (tensor(2), tensor(2))}, {'item': (tensor(3), tensor(3))},
     {'item': (tensor(4), tensor(4))}]
    >>> item_set[:]
    {'user': (tensor([0, 1]), tensor([0, 1])),
     'item': (tensor([2, 3, 4]), tensor([2, 3, 4]))}
    >>> item_set.names
    ('seeds', 'labels')

    3. Tuple of iterables with different shape: seeds and labels.

    >>> seeds_like = torch.arange(0, 4).reshape(-1, 2)
    >>> labels_like = torch.tensor([1, 0])
    >>> seeds_follow = torch.arange(0, 6).reshape(-1, 2)
    >>> labels_follow = torch.tensor([1, 1, 0])
    >>> item_set = gb.ItemSetDict({
    ...     "user:like:item": gb.ItemSet(
    ...         (seeds_like, labels_like),
    ...         names=("seeds", "labels")),
    ...     "user:follow:user": gb.ItemSet(
    ...         (seeds_follow, labels_follow),
    ...         names=("seeds", "labels"))})
    >>> list(item_set)
    [{'user:like:item': (tensor([0, 1]), tensor(1))},
     {'user:like:item': (tensor([2, 3]), tensor(0))},
     {'user:follow:user': (tensor([0, 1]), tensor(1))},
     {'user:follow:user': (tensor([2, 3]), tensor(1))},
     {'user:follow:user': (tensor([4, 5]), tensor(0))}]
    >>> item_set[:]
    {'user:like:item': (tensor([[0, 1], [2, 3]]), tensor([1, 0])),
     'user:follow:user': (tensor([[0, 1], [2, 3], [4, 5]]),
                          tensor([1, 1, 0]))}
    >>> item_set.names
    ('seeds', 'labels')

    4. Tuple of iterables with different shape: hyperlink and labels.

    >>> first_seeds = torch.arange(0, 6).reshape(-1, 3)
    >>> first_labels = torch.tensor([1, 0])
    >>> second_seeds = torch.arange(0, 2).reshape(-1, 1)
    >>> second_labels = torch.tensor([1, 0])
    >>> item_set = gb.ItemSetDict({
    ...     "query:user:item": gb.ItemSet(
    ...         (first_seeds, first_labels),
    ...         names=("seeds", "labels")),
    ...     "user": gb.ItemSet(
    ...         (second_seeds, second_labels),
    ...         names=("seeds", "labels"))})
    >>> list(item_set)
    [{'query:user:item': (tensor([0, 1, 2]), tensor(1))},
     {'query:user:item': (tensor([3, 4, 5]), tensor(0))},
     {'user': (tensor([0]), tensor(1))},
     {'user': (tensor([1]), tensor(0))}]
    >>> item_set[:]
    {'query:user:item': (tensor([[0, 1, 2], [3, 4, 5]]), tensor([1, 0])),
     'user': (tensor([[0], [1]]), tensor([1, 0]))}
    >>> item_set.names
    ('seeds', 'labels')
    """

    def __init__(self, itemsets: Dict[str, "ItemSet"]) -> None:
        self._itemsets = itemsets
        self._names = itemsets[list(itemsets.keys())[0]].names
        assert all(
            self._names == itemset.names for itemset in itemsets.values()
        ), "All itemsets must have the same names."
        try:
            # For indexable itemsets, we compute the offsets for each itemset
            # in advance to speed up indexing.
            offsets = [0] + [
                len(itemset) for itemset in self._itemsets.values()
            ]
            self._offsets = torch.tensor(offsets).cumsum(0)
        except TypeError:
            # At least one itemset has no length, so indexing is unavailable.
            self._offsets = None

    def __iter__(self) -> Iterator:
        for key, itemset in self._itemsets.items():
            for item in itemset:
                yield {key: item}

    def __len__(self) -> int:
        return sum(len(itemset) for itemset in self._itemsets.values())

    def __getitem__(self, idx: Union[int, slice]) -> Dict[str, Tuple]:
        """Return the item(s) at the global position(s) ``idx``.

        Positions are counted across all itemsets in insertion order. An
        integer returns a single-key dict; a slice (step 1 only) returns one
        entry per itemset visited.

        Raises
        ------
        TypeError
            If any itemset has no length, or ``idx`` is not int or slice.
        IndexError
            If an integer ``idx`` is out of range.
        """
        if self._offsets is None:
            raise TypeError(
                f"{type(self).__name__} instance doesn't support indexing."
            )
        # Work with plain Python ints from here on: arithmetic against the
        # offsets tensor would silently turn the local index into a 0-dim
        # tensor, which itemsets that require a real ``int`` reject.
        total_num = int(self._offsets[-1])
        keys = list(self._itemsets.keys())
        if isinstance(idx, int):
            if idx < 0:
                idx += total_num
            if idx < 0 or idx >= total_num:
                raise IndexError(f"{type(self).__name__} index out of range.")
            # Index of the last offset that is <= idx, i.e. the owning itemset.
            offset_idx = (
                int(torch.searchsorted(self._offsets, idx, right=True)) - 1
            )
            local_idx = idx - int(self._offsets[offset_idx])
            key = keys[offset_idx]
            return {key: self._itemsets[key][local_idx]}
        elif isinstance(idx, slice):
            start, stop, step = idx.indices(total_num)
            assert step == 1, "Step must be 1."
            assert start < stop, "Start must be smaller than stop."
            data = {}
            offsets = self._offsets.tolist()
            # NOTE(review): with ``right=False`` a ``start`` that falls exactly
            # on an itemset boundary begins at the preceding itemset, which
            # then contributes an empty slice. Kept as-is; confirm intended.
            offset_idx_start = max(
                1, int(torch.searchsorted(self._offsets, start, right=False))
            )
            for offset_idx in range(offset_idx_start, len(offsets)):
                key = keys[offset_idx - 1]
                base = offsets[offset_idx - 1]
                data[key] = self._itemsets[key][
                    max(0, start - base) : stop - base
                ]
                if stop <= offsets[offset_idx]:
                    break
            return data
        else:
            raise TypeError(
                f"{type(self).__name__} indices must be int or slice."
            )

    @property
    def names(self) -> Tuple[str]:
        """Return the names of the items."""
        return self._names

    def __repr__(self) -> str:
        ret = (
            "{Classname}(\n"
            " itemsets={itemsets},\n"
            " names={names},\n"
            ")"
        )

        itemsets_str = textwrap.indent(
            repr(self._itemsets), " " * len(" itemsets=")
        ).strip()

        return ret.format(
            Classname=self.__class__.__name__,
            itemsets=itemsets_str,
            names=self._names,
        )