# Source code for dgl.function.reducer

"""Built-in reducer function."""
# pylint: disable=redefined-builtin
from __future__ import absolute_import

from .. import backend as F
from .base import BuiltinFunction

# Public names exported by ``from <this module> import *``.
__all__ = [
    "sum",
    "max",
]

class ReduceFunction(BuiltinFunction):
    """Abstract base for builtin reduce functions.

    Subclasses are expected to override ``name``, ``is_spmv_supported``
    and ``__call__``; every hook defined here just raises
    ``NotImplementedError``.
    """

    @property
    def name(self):
        """Identifier of this builtin function (must be overridden)."""
        raise NotImplementedError

    def is_spmv_supported(self):
        """Tell whether this reducer can use the SPMV optimization."""
        raise NotImplementedError

    def __call__(self, nodes):
        """Run the reduction directly on ``nodes``.

        This is the unoptimized fallback used when no optimization is
        available. It should ONLY be invoked by the DGL framework.
        """
        raise NotImplementedError


class SimpleReduceFunction(ReduceFunction):
    """Builtin reducer that maps one message field onto one output field.

    Parameters
    ----------
    name : str
        Identifier of the reducer (e.g. ``"sum"``).
    reduce_op : callable
        Backend reduction, invoked as ``reduce_op(messages, 1)``.
    msg_field : str
        Mailbox key holding the incoming messages.
    out_field : str
        Key under which the reduced result is returned.
    """

    def __init__(self, name, reduce_op, msg_field, out_field):
        self._name = name
        self.reduce_op = reduce_op
        self.msg_field = msg_field
        self.out_field = out_field

    @property
    def name(self):
        """Identifier this reducer was constructed with."""
        return self._name

    def is_spmv_supported(self):
        """Return whether the SPMV optimization applies to this reducer.

        Only the ``"sum"`` reducer is supported at the moment.
        """
        return self._name == "sum"

    def __call__(self, nodes):
        """Reduce ``nodes.mailbox[msg_field]`` and wrap it in a dict.

        ``reduce_op`` receives ``1`` as its dimension argument; the result
        is returned as ``{out_field: reduced}``.
        """
        # Fetch the output key before touching the mailbox, matching the
        # key-then-value evaluation order of the dict display it replaces.
        out_key = self.out_field
        messages = nodes.mailbox[self.msg_field]
        # NOTE(review): dim 1 is presumably the per-node message axis of the
        # mailbox tensor; confirm against the mailbox layout.
        return {out_key: self.reduce_op(messages, 1)}

def sum(msg, out):
    """Builtin reduce function that aggregates messages by sum.

    Parameters
    ----------
    msg : str
        The message field.
    out : str
        The output node feature field.

    Returns
    -------
    SimpleReduceFunction
        A reducer named ``"sum"`` that applies the backend ``sum`` to the
        ``msg`` mailbox field and stores the result under ``out``.

    Examples
    --------
    >>> import dgl
    >>> reduce_func = dgl.function.sum(msg='m', out='h')

    The above example is equivalent to the following user defined function
    (if using PyTorch):

    >>> import torch
    >>> def reduce_func(nodes):
    ...     return {'h': torch.sum(nodes.mailbox['m'], dim=1)}
    """
    return SimpleReduceFunction("sum", F.sum, msg, out)
def max(msg, out):
    """Builtin reduce function that aggregates messages by max.

    Parameters
    ----------
    msg : str
        The message field.
    out : str
        The output node feature field.

    Returns
    -------
    SimpleReduceFunction
        A reducer named ``"max"`` that applies the backend ``max`` to the
        ``msg`` mailbox field and stores the result under ``out``.

    Examples
    --------
    >>> import dgl
    >>> reduce_func = dgl.function.max(msg='m', out='h')

    The above example is equivalent to the following user defined function
    (if using PyTorch). Note the ``[0]``: ``torch.max`` with ``dim`` returns
    a ``(values, indices)`` pair, and only the values are the reduced
    feature.

    >>> import torch
    >>> def reduce_func(nodes):
    ...     return {'h': torch.max(nodes.mailbox['m'], dim=1)[0]}
    """
    return SimpleReduceFunction("max", F.max, msg, out)