"""GraphBolt Itemset."""
import textwrap
from typing import Dict, Iterable, Iterator, Optional, Sized, Tuple, Union
import torch
__all__ = ["ItemSet", "ItemSetDict"]
def is_scalar(x):
"""Checks if the input is a scalar."""
return (
len(x.shape) == 0 if isinstance(x, torch.Tensor) else isinstance(x, int)
)
class ItemSet:
r"""A wrapper of iterable data or tuple of iterable data.
All itemsets that represent an iterable of items should subclass it. Such
form of itemset is particularly useful when items come from a stream. This
class requires each input itemset to be iterable.
Parameters
----------
    items: Union[int, torch.Tensor, Iterable, Tuple[Iterable]]
        The items to be iterated over. If it is a single integer or a scalar
        tensor, a `torch.arange()` range of that length is iterated over. If
        it is a multi-dimensional iterable such as a `torch.Tensor`, it is
        iterated over along the first dimension. If it is a tuple, each entry
        of the tuple is an iterable of items.
names: Union[str, Tuple[str]], optional
The names of the items. If it is a tuple, each name corresponds to an
item in the tuple. The naming is arbitrary, but in general practice,
the names should be chosen from ['seed_nodes', 'node_pairs', 'labels',
'seeds', 'negative_srcs', 'negative_dsts'] to align with the attributes
of class `dgl.graphbolt.MiniBatch`.
Examples
--------
>>> import torch
>>> from dgl import graphbolt as gb
1. Integer: number of nodes.
>>> num = 10
>>> item_set = gb.ItemSet(num, names="seed_nodes")
>>> list(item_set)
[tensor(0), tensor(1), tensor(2), tensor(3), tensor(4), tensor(5),
tensor(6), tensor(7), tensor(8), tensor(9)]
>>> item_set[:]
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> item_set.names
('seed_nodes',)
    2. Torch scalar: number of nodes. Unlike a plain integer, the dtype is
    customizable.
>>> num = torch.tensor(10, dtype=torch.int32)
>>> item_set = gb.ItemSet(num, names="seed_nodes")
>>> list(item_set)
[tensor(0, dtype=torch.int32), tensor(1, dtype=torch.int32),
tensor(2, dtype=torch.int32), tensor(3, dtype=torch.int32),
tensor(4, dtype=torch.int32), tensor(5, dtype=torch.int32),
tensor(6, dtype=torch.int32), tensor(7, dtype=torch.int32),
tensor(8, dtype=torch.int32), tensor(9, dtype=torch.int32)]
>>> item_set[:]
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int32)
>>> item_set.names
('seed_nodes',)
3. Single iterable: seed nodes.
>>> node_ids = torch.arange(0, 5)
>>> item_set = gb.ItemSet(node_ids, names="seed_nodes")
>>> list(item_set)
[tensor(0), tensor(1), tensor(2), tensor(3), tensor(4)]
>>> item_set[:]
tensor([0, 1, 2, 3, 4])
>>> item_set.names
('seed_nodes',)
4. Tuple of iterables with same shape: seed nodes and labels.
>>> node_ids = torch.arange(0, 5)
>>> labels = torch.arange(5, 10)
>>> item_set = gb.ItemSet(
... (node_ids, labels), names=("seed_nodes", "labels"))
>>> list(item_set)
[(tensor(0), tensor(5)), (tensor(1), tensor(6)), (tensor(2), tensor(7)),
(tensor(3), tensor(8)), (tensor(4), tensor(9))]
>>> item_set[:]
(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9]))
>>> item_set.names
('seed_nodes', 'labels')
5. Tuple of iterables with different shape: node pairs and negative dsts.
>>> node_pairs = torch.arange(0, 10).reshape(-1, 2)
>>> neg_dsts = torch.arange(10, 25).reshape(-1, 3)
>>> item_set = gb.ItemSet(
... (node_pairs, neg_dsts), names=("node_pairs", "negative_dsts"))
>>> list(item_set)
[(tensor([0, 1]), tensor([10, 11, 12])),
(tensor([2, 3]), tensor([13, 14, 15])),
(tensor([4, 5]), tensor([16, 17, 18])),
(tensor([6, 7]), tensor([19, 20, 21])),
(tensor([8, 9]), tensor([22, 23, 24]))]
>>> item_set[:]
    (tensor([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]),
tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18], [19, 20, 21],
[22, 23, 24]]))
>>> item_set.names
('node_pairs', 'negative_dsts')
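
    6. Iterable of indices. Indexing is delegated to the wrapped iterables,
    so an ItemSet backed by tensors also accepts a tensor of indices
    (fancy indexing):

    >>> item_set = gb.ItemSet(torch.arange(0, 5), names="seed_nodes")
    >>> item_set[torch.tensor([0, 2, 4])]
    tensor([0, 2, 4])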
"""
def __init__(
self,
items: Union[int, torch.Tensor, Iterable, Tuple[Iterable]],
        names: Optional[Union[str, Tuple[str]]] = None,
) -> None:
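        # A tuple of iterables and a scalar are stored as-is; a single
        # iterable is wrapped into a 1-tuple so that iteration and indexing
        # can treat all non-scalar inputs uniformly.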
if isinstance(items, tuple) or is_scalar(items):
self._items = items
else:
self._items = (items,)
if names is not None:
num_items = (
len(self._items) if isinstance(self._items, tuple) else 1
)
if isinstance(names, tuple):
self._names = names
else:
self._names = (names,)
assert num_items == len(self._names), (
f"Number of items ({num_items}) and "
f"names ({len(self._names)}) must match."
)
else:
self._names = None
def __iter__(self) -> Iterator:
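        # Scalar input: lazily yield 0 .. N-1, preserving the dtype of a
        # 0-dim tensor input (a plain Python int defaults to torch.int64).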
if is_scalar(self._items):
dtype = getattr(self._items, "dtype", torch.int64)
yield from torch.arange(self._items, dtype=dtype)
return
if len(self._items) == 1:
yield from self._items[0]
return
if isinstance(self._items[0], Sized):
items_len = len(self._items[0])
            # Use a for-loop to iterate over the items. This avoids a long
            # wait when the items are torch tensors, since zipping tensors
            # forces an up-front `self.unbind(0)` to slice them into items.
            # Although for-loops are slower than zip, they prevent excessive
            # wait times during the loading phase, and the impact on overall
            # performance during the training/testing stage is minimal.
            # For more details, see https://github.com/dmlc/dgl/pull/6293.
for i in range(items_len):
yield tuple(item[i] for item in self._items)
else:
# If the items are not Sized, we use zip to iterate over them.
zip_items = zip(*self._items)
for item in zip_items:
yield tuple(item)
def __len__(self) -> int:
if is_scalar(self._items):
return int(self._items)
if isinstance(self._items[0], Sized):
return len(self._items[0])
raise TypeError(
f"{type(self).__name__} instance doesn't have valid length."
)
def __getitem__(self, idx: Union[int, slice, Iterable]) -> Tuple:
try:
len(self)
except TypeError:
raise TypeError(
f"{type(self).__name__} instance doesn't support indexing."
)
if is_scalar(self._items):
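            # Scalar input: materialize only the requested IDs on the fly
            # rather than storing the full range.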
if isinstance(idx, slice):
start, stop, step = idx.indices(int(self._items))
dtype = getattr(self._items, "dtype", torch.int64)
return torch.arange(start, stop, step, dtype=dtype)
if isinstance(idx, int):
if idx < 0:
idx += self._items
if idx < 0 or idx >= self._items:
raise IndexError(
f"{type(self).__name__} index out of range."
)
return (
torch.tensor(idx, dtype=self._items.dtype)
if isinstance(self._items, torch.Tensor)
else idx
)
raise TypeError(
f"{type(self).__name__} indices must be integer or slice."
)
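        # Non-scalar input: delegate indexing to the wrapped iterable(s); a
        # single iterable returns its element(s), a tuple returns a tuple.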
if len(self._items) == 1:
return self._items[0][idx]
return tuple(item[idx] for item in self._items)
@property
def names(self) -> Tuple[str]:
"""Return the names of the items."""
return self._names
def __repr__(self) -> str:
ret = (
f"{self.__class__.__name__}(\n"
f" items={self._items},\n"
f" names={self._names},\n"
f")"
)
return ret
class ItemSetDict:
r"""Dictionary wrapper of **ItemSet**.
Each item is retrieved by iterating over each itemset and returned with
corresponding key as a dict.
Parameters
----------
    itemsets: Dict[str, ItemSet]
        The itemsets to wrap, keyed by name, e.g. a node or edge type.
Examples
--------
>>> import torch
>>> from dgl import graphbolt as gb
1. Single iterable: seed nodes.
>>> node_ids_user = torch.arange(0, 5)
>>> node_ids_item = torch.arange(5, 10)
>>> item_set = gb.ItemSetDict({
... "user": gb.ItemSet(node_ids_user, names="seed_nodes"),
... "item": gb.ItemSet(node_ids_item, names="seed_nodes")})
>>> list(item_set)
[{"user": tensor(0)}, {"user": tensor(1)}, {"user": tensor(2)},
{"user": tensor(3)}, {"user": tensor(4)}, {"item": tensor(5)},
{"item": tensor(6)}, {"item": tensor(7)}, {"item": tensor(8)},
{"item": tensor(9)}}]
>>> item_set[:]
{"user": tensor([0, 1, 2, 3, 4]), "item": tensor([5, 6, 7, 8, 9])}
>>> item_set.names
('seed_nodes',)
2. Tuple of iterables with same shape: seed nodes and labels.
>>> node_ids_user = torch.arange(0, 2)
>>> labels_user = torch.arange(0, 2)
>>> node_ids_item = torch.arange(2, 5)
>>> labels_item = torch.arange(2, 5)
>>> item_set = gb.ItemSetDict({
... "user": gb.ItemSet(
... (node_ids_user, labels_user),
... names=("seed_nodes", "labels")),
... "item": gb.ItemSet(
... (node_ids_item, labels_item),
... names=("seed_nodes", "labels"))})
>>> list(item_set)
[{"user": (tensor(0), tensor(0))}, {"user": (tensor(1), tensor(1))},
{"item": (tensor(2), tensor(2))}, {"item": (tensor(3), tensor(3))},
{"item": (tensor(4), tensor(4))}}]
>>> item_set[:]
{"user": (tensor([0, 1]), tensor([0, 1])),
"item": (tensor([2, 3, 4]), tensor([2, 3, 4]))}
>>> item_set.names
('seed_nodes', 'labels')
3. Tuple of iterables with different shape: node pairs and negative dsts.
>>> node_pairs_like = torch.arange(0, 4).reshape(-1, 2)
>>> neg_dsts_like = torch.arange(4, 10).reshape(-1, 3)
>>> node_pairs_follow = torch.arange(0, 6).reshape(-1, 2)
>>> neg_dsts_follow = torch.arange(6, 15).reshape(-1, 3)
>>> item_set = gb.ItemSetDict({
... "user:like:item": gb.ItemSet(
... (node_pairs_like, neg_dsts_like),
... names=("node_pairs", "negative_dsts")),
... "user:follow:user": gb.ItemSet(
... (node_pairs_follow, neg_dsts_follow),
... names=("node_pairs", "negative_dsts"))})
>>> list(item_set)
[{"user:like:item": (tensor([0, 1]), tensor([4, 5, 6]))},
{"user:like:item": (tensor([2, 3]), tensor([7, 8, 9]))},
{"user:follow:user": (tensor([0, 1]), tensor([ 6, 7, 8, 9, 10, 11]))},
{"user:follow:user": (tensor([2, 3]), tensor([12, 13, 14, 15, 16, 17]))},
{"user:follow:user": (tensor([4, 5]), tensor([18, 19, 20, 21, 22, 23]))}]
>>> item_set[:]
{"user:like:item": (tensor([[0, 1], [2, 3]]),
tensor([[4, 5, 6], [7, 8, 9]])),
"user:follow:user": (tensor([[0, 1], [2, 3], [4, 5]]),
    tensor([[ 6,  7,  8],
            [ 9, 10, 11],
            [12, 13, 14]]))}
>>> item_set.names
('node_pairs', 'negative_dsts')
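
    4. Integer indexing. A single index is resolved against the cumulative
    sizes of the itemsets, so it transparently crosses itemset boundaries
    (continuing the example above):

    >>> item_set[0]
    {"user:like:item": (tensor([0, 1]), tensor([4, 5, 6]))}
    >>> item_set[3]
    {"user:follow:user": (tensor([2, 3]), tensor([ 9, 10, 11]))}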
"""
def __init__(self, itemsets: Dict[str, ItemSet]) -> None:
self._itemsets = itemsets
        self._names = next(iter(itemsets.values())).names
assert all(
self._names == itemset.names for itemset in itemsets.values()
), "All itemsets must have the same names."
try:
# For indexable itemsets, we compute the offsets for each itemset
# in advance to speed up indexing.
offsets = [0] + [
len(itemset) for itemset in self._itemsets.values()
]
self._offsets = torch.tensor(offsets).cumsum(0)
except TypeError:
self._offsets = None
def __iter__(self) -> Iterator:
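        # Chain the itemsets in insertion order, wrapping each item in a
        # single-key dict so its originating itemset stays identifiable.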
for key, itemset in self._itemsets.items():
for item in itemset:
yield {key: item}
def __len__(self) -> int:
return sum(len(itemset) for itemset in self._itemsets.values())
def __getitem__(self, idx: Union[int, slice]) -> Dict[str, Tuple]:
if self._offsets is None:
raise TypeError(
f"{type(self).__name__} instance doesn't support indexing."
)
total_num = self._offsets[-1]
if isinstance(idx, int):
if idx < 0:
idx += total_num
if idx < 0 or idx >= total_num:
raise IndexError(f"{type(self).__name__} index out of range.")
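            # Binary-search the cumulative offsets for the itemset that owns
            # this global index, then rebase the index to be local to it.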
offset_idx = torch.searchsorted(self._offsets, idx, right=True)
offset_idx -= 1
idx -= self._offsets[offset_idx]
key = list(self._itemsets.keys())[offset_idx]
return {key: self._itemsets[key][idx]}
elif isinstance(idx, slice):
start, stop, step = idx.indices(total_num)
assert step == 1, "Step must be 1."
assert start < stop, "Start must be smaller than stop."
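            # Find the first itemset that overlaps [start, stop), then slice
            # every overlapping itemset with bounds rebased to its offset.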
data = {}
offset_idx_start = max(
1, torch.searchsorted(self._offsets, start, right=False)
)
keys = list(self._itemsets.keys())
for offset_idx in range(offset_idx_start, len(self._offsets)):
key = keys[offset_idx - 1]
data[key] = self._itemsets[key][
max(0, start - self._offsets[offset_idx - 1]) : stop
- self._offsets[offset_idx - 1]
]
if stop <= self._offsets[offset_idx]:
break
return data
raise TypeError(f"{type(self).__name__} indices must be int or slice.")
@property
def names(self) -> Tuple[str]:
"""Return the names of the items."""
return self._names
def __repr__(self) -> str:
ret = (
"{Classname}(\n"
" itemsets={itemsets},\n"
" names={names},\n"
")"
)
itemsets_str = textwrap.indent(
repr(self._itemsets), " " * len(" itemsets=")
).strip()
return ret.format(
Classname=self.__class__.__name__,
itemsets=itemsets_str,
names=self._names,
)