"""Built-in reducer function."""
# pylint: disable=redefined-builtin
from __future__ import absolute_import
from .. import backend as F
from .base import BuiltinFunction
# Public API: the builtin reducers. These names deliberately shadow Python's
# builtin ``sum``/``max`` (hence the redefined-builtin pylint pragma).
__all__ = [
    "sum",
    "max",
]
class ReduceFunction(BuiltinFunction):
    """Abstract base class for builtin reduce functions.

    Every hook defined here simply raises ``NotImplementedError``; concrete
    reducers are expected to override ``name``, ``is_spmv_supported`` and
    ``__call__``.
    """

    @property
    def name(self):
        """Name identifying this builtin function (must be overridden)."""
        raise NotImplementedError

    def is_spmv_supported(self):
        """Tell whether this reducer can use the SPMV optimization.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    def __call__(self, nodes):
        """Plain (non-optimized) evaluation of this builtin function.

        This path is taken when no optimization is available, and it is meant
        to be invoked ONLY by the DGL framework, never directly by user code.
        """
        raise NotImplementedError
class SimpleReduceFunction(ReduceFunction):
    """Builtin reducer that aggregates one message field into one node field.

    Parameters
    ----------
    name : str
        Identifier of the reduction, e.g. ``"sum"`` or ``"max"``.
    reduce_op : callable
        Invoked as ``reduce_op(data, 1)`` on the selected mailbox entry.
    msg_field : str
        Key of the mailbox entry to reduce.
    out_field : str
        Key under which the reduced value is returned.
    """

    def __init__(self, name, reduce_op, msg_field, out_field):
        self.msg_field = msg_field
        self.out_field = out_field
        self.reduce_op = reduce_op
        self._name = name

    @property
    def name(self):
        """Return the name this reducer was constructed with."""
        return self._name

    def is_spmv_supported(self):
        """Return whether the SPMV optimization is supported.

        Only the ``"sum"`` reducer qualifies at the moment.
        """
        return self._name == "sum"

    def __call__(self, nodes):
        """Reduce ``nodes.mailbox[msg_field]`` and return ``{out_field: result}``."""
        mailbox_data = nodes.mailbox[self.msg_field]
        # Reduce along dimension 1 -- presumably the per-node message axis,
        # with dimension 0 indexing nodes; TODO confirm mailbox layout.
        reduced = self.reduce_op(mailbox_data, 1)
        return {self.out_field: reduced}
def sum(msg, out):
    """Builtin reduce function that aggregates messages by sum.

    Parameters
    ----------
    msg : str
        The message field.
    out : str
        The output node feature field.

    Returns
    -------
    SimpleReduceFunction
        A reducer named ``"sum"`` that sums the mailbox field ``msg`` along
        dimension 1 and returns it under the node field ``out``. This is the
        reducer that reports SPMV optimization support.

    Examples
    --------
    >>> import dgl
    >>> reduce_func = dgl.function.sum(msg='m', out='h')

    The above example is equivalent to the following user defined function
    (if using PyTorch):

    >>> import torch
    >>> def reduce_func(nodes):
    ...     return {'h': torch.sum(nodes.mailbox['m'], dim=1)}
    """
    return SimpleReduceFunction("sum", F.sum, msg, out)
def max(msg, out):
    """Builtin reduce function that aggregates messages by max.

    Parameters
    ----------
    msg : str
        The message field.
    out : str
        The output node feature field.

    Returns
    -------
    SimpleReduceFunction
        A reducer named ``"max"`` that takes the maximum of the mailbox field
        ``msg`` along dimension 1 and returns it under the node field ``out``.
        Unlike ``sum``, it does not report SPMV optimization support.

    Examples
    --------
    >>> import dgl
    >>> reduce_func = dgl.function.max(msg='m', out='h')

    The above example is equivalent to the following user defined function
    (if using PyTorch). Note that ``torch.max`` with ``dim`` returns a
    ``(values, indices)`` pair, so only the values (``[0]``) are kept:

    >>> import torch
    >>> def reduce_func(nodes):
    ...     return {'h': torch.max(nodes.mailbox['m'], dim=1)[0]}
    """
    return SimpleReduceFunction("max", F.max, msg, out)