# Source code for dgl.graphbolt.negative_sampler

"""Negative samplers."""

from _collections_abc import Mapping

from torch.utils.data import functional_datapipe

from .minibatch_transformer import MiniBatchTransformer

# Names exported by ``from ... import *``.
__all__ = ["NegativeSampler"]


@functional_datapipe("sample_negative")
class NegativeSampler(MiniBatchTransformer):
    """
    A negative sampler used to generate negative samples and return
    a mix of positive and negative samples.

    Functional name: :obj:`sample_negative`.

    This base class only orchestrates the sampling; the actual negative
    generation is delegated to :meth:`_sample_with_etype`, which subclasses
    must implement.

    Parameters
    ----------
    datapipe : DataPipe
        The datapipe.
    negative_ratio : int
        The proportion of negative samples to positive samples.
    """

    def __init__(
        self,
        datapipe,
        negative_ratio,
    ):
        super().__init__(datapipe, self._sample)
        assert negative_ratio > 0, "Negative_ratio should be positive Integer."
        self.negative_ratio = negative_ratio

    def _sample(self, minibatch):
        """
        Generate a mix of positive and negative samples.

        If `seeds` in minibatch is not None, `seeds`, `labels` and `indexes`
        are replaced by the output of negative sampling for those seeds.
        Otherwise the `node_pairs` field is used, and the generated negative
        edges are written to `negative_srcs` and `negative_dsts`.

        Parameters
        ----------
        minibatch : MiniBatch
            An instance of 'MiniBatch' class. It requires the 'seeds' field,
            or, when 'seeds' is None, the 'node_pairs' field. This function
            generates negative edges corresponding to the positive edges
            defined by that field. In cases where negative edges already
            exist, this function will overwrite them.

        Returns
        -------
        MiniBatch
            The same instance, updated in place, now encompassing both
            positive and negative samples.
        """
        if minibatch.seeds is None:
            self._sample_node_pairs(minibatch)
        else:
            self._sample_seeds(minibatch)
        return minibatch

    def _sample_node_pairs(self, minibatch):
        """Fill `negative_srcs`/`negative_dsts` from `minibatch.node_pairs`."""
        node_pairs = minibatch.node_pairs
        assert node_pairs is not None
        if isinstance(node_pairs, Mapping):
            # Start from fresh dicts so negatives left over from a previous
            # assignment are overwritten rather than merged.
            minibatch.negative_srcs, minibatch.negative_dsts = {}, {}
            for etype, pos_pairs in node_pairs.items():
                self._collate(
                    minibatch,
                    self._sample_with_etype(pos_pairs, etype),
                    etype,
                )
        else:
            self._collate(minibatch, self._sample_with_etype(node_pairs))

    def _sample_seeds(self, minibatch):
        """Replace `seeds`, `labels` and `indexes` from `minibatch.seeds`."""
        seeds = minibatch.seeds
        if isinstance(seeds, Mapping):
            if minibatch.indexes is None:
                minibatch.indexes = {}
            if minibatch.labels is None:
                minibatch.labels = {}
            # Only keys that already exist are reassigned below, so the dict
            # size never changes while `seeds.items()` is being iterated.
            for etype, pos_pairs in seeds.items():
                (
                    minibatch.seeds[etype],
                    minibatch.labels[etype],
                    minibatch.indexes[etype],
                ) = self._sample_with_etype(pos_pairs, etype, use_seeds=True)
        else:
            (
                minibatch.seeds,
                minibatch.labels,
                minibatch.indexes,
            ) = self._sample_with_etype(seeds, use_seeds=True)

    def _sample_with_etype(self, node_pairs, etype=None, use_seeds=False):
        """Generate negative pairs for a given etype from positive pairs for
        a given etype. If `node_pairs` is a 2D tensor, which means `seeds` is
        used in the minibatch, corresponding labels and indexes will be
        constructed as well.

        Subclasses must override this method.

        Parameters
        ----------
        node_pairs : Tuple[Tensor, Tensor] or Tensor
            A tuple of tensors that represent source-destination node pairs
            of positive edges, where positive means the edge must exist in
            the graph. When `use_seeds` is True this is the `seeds` value of
            the minibatch instead.
        etype : str
            Canonical edge type.
        use_seeds : bool
            Whether `node_pairs` comes from `minibatch.seeds`. This selects
            the shape of the return value (see below).

        Returns
        -------
        Tuple[Tensor or None, Tensor or None]
            When `use_seeds` is False: the source and destination tensors of
            the negative edges; either element may be None.
        Tuple[Tensor, Tensor or None, Tensor or None]
            When `use_seeds` is True: the seeds to store back into the
            minibatch (a mix of positive and negative pairs), the
            corresponding labels (True for a positive edge, False for a
            negative edge), and the corresponding indexes, which indicate to
            which query an edge belongs.
        """
        raise NotImplementedError

    def _collate(self, minibatch, neg_pairs, etype=None):
        """Collates positive and negative samples into minibatch.

        Parameters
        ----------
        minibatch : MiniBatch
            The input minibatch, which contains positive node pairs, will be
            filled with negative information in this function.
        neg_pairs : Tuple[Tensor, Tensor]
            A tuple of tensors represents source-destination node pairs of
            negative edges, where negative means the edge may not exist in
            the graph. Either tensor may be None.
        etype : str
            Canonical edge type. When given, results are stored under this
            key of the (already initialized) `negative_srcs`/`negative_dsts`
            dicts; otherwise those attributes are assigned directly.
        """
        neg_src, neg_dst = neg_pairs
        # Reshape to rows of `negative_ratio` columns (presumably one row per
        # positive edge).
        if neg_src is not None:
            neg_src = neg_src.view(-1, self.negative_ratio)
        if neg_dst is not None:
            neg_dst = neg_dst.view(-1, self.negative_ratio)
        if etype is not None:
            minibatch.negative_srcs[etype] = neg_src
            minibatch.negative_dsts[etype] = neg_dst
        else:
            minibatch.negative_srcs = neg_src
            minibatch.negative_dsts = neg_dst