# Source code for dgl.function.message

"""Built-in message function."""
from __future__ import absolute_import

import operator

from .base import BuiltinFunction
from .. import backend as F

# Public API of this module: the three builtin message-function factories.
# The *MessageFunction classes are intentionally not exported.
__all__ = ["src_mul_edge", "copy_src", "copy_edge"]


class MessageFunction(BuiltinFunction):
    """Abstract base for the builtin message functions.

    Every member here is a placeholder that raises ``NotImplementedError``;
    concrete subclasses (e.g. the src_mul_edge / copy_src / copy_edge
    implementations) must override all of them.
    """

    @property
    def name(self):
        """Short string identifier of this builtin function."""
        raise NotImplementedError()

    @property
    def use_edge_feature(self):
        """Whether computing the message reads edge feature data."""
        raise NotImplementedError()

    def is_spmv_supported(self, g):
        """Whether the SPMV optimization can be applied on graph ``g``."""
        raise NotImplementedError()

    def __call__(self, edges):
        """Fallback (non-optimized) computation of the message.

        Used only when no optimized path is available; this is meant to be
        invoked by the DGL framework, not by user code.
        """
        raise NotImplementedError()

def _is_spmv_supported_edge_feat(g, field):
    """Check whether edge feature ``field`` of ``g`` is SPMV-compatible.

    Only scalar-per-edge features qualify: a 1-D tensor, or a 2-D tensor
    whose second dimension has size 1.

    Parameters
    ----------
    g : DGLGraph
        The graph holding the edge features.
    field : str
        Name of the edge feature to inspect.

    Returns
    -------
    bool
        True if the feature shape allows the SPMV optimization.
    """
    feat_shape = F.shape(g.get_e_repr()[field])
    num_dims = len(feat_shape)
    if num_dims == 1:
        return True
    return num_dims == 2 and feat_shape[1] == 1


class SrcMulEdgeMessageFunction(MessageFunction):
    """Builtin message function combining a source-node feature with an
    edge feature via a binary operator.

    See Also
    --------
    src_mul_edge
    """
    def __init__(self, mul_op, src_field, edge_field, out_field):
        # mul_op: binary callable applied to (source data, edge data);
        # the src_mul_edge factory passes operator.mul.
        self.mul_op = mul_op
        self.src_field = src_field
        self.edge_field = edge_field
        self.out_field = out_field

    def is_spmv_supported(self, g):
        """Report whether the SPMV optimization applies on ``g``.

        Parameters
        ----------
        g : DGLGraph
            The graph.

        Returns
        -------
        bool
            True when the edge feature is scalar per edge (see
            ``_is_spmv_supported_edge_feat``).
        """
        supported = _is_spmv_supported_edge_feat(g, self.edge_field)
        return supported

    def __call__(self, edges):
        """Fallback (non-optimized) computation of the message.

        Used only when no optimized path is available; meant to be invoked
        by the DGL framework only.
        """
        src_feat = edges.src[self.src_field]
        edge_feat = edges.data[self.edge_field]
        # Backends disagree on broadcasting rules, so give both operands the
        # same rank by appending trailing size-1 axes to the lower-rank one.
        target_rank = max(F.ndim(src_feat), F.ndim(edge_feat))

        def _pad_to_rank(tensor):
            missing = target_rank - F.ndim(tensor)
            return F.reshape(tensor, F.shape(tensor) + (1,) * missing)

        src_feat = _pad_to_rank(src_feat)
        edge_feat = _pad_to_rank(edge_feat)
        return {self.out_field: self.mul_op(src_feat, edge_feat)}

    @property
    def name(self):
        """Short string identifier of this builtin function."""
        return "src_mul_edge"

    @property
    def use_edge_feature(self):
        """Always True: the message reads the edge feature ``edge_field``."""
        return True

class CopySrcMessageFunction(MessageFunction):
    """Builtin message function that forwards a source-node feature
    unchanged as the message.

    See Also
    --------
    copy_src
    """
    def __init__(self, src_field, out_field):
        self.src_field = src_field
        self.out_field = out_field

    def is_spmv_supported(self, g):
        """Report whether the SPMV optimization applies on ``g``.

        Parameters
        ----------
        g : DGLGraph
            The graph (not inspected).

        Returns
        -------
        bool
            Always True; no edge-feature shape needs to be checked here.
        """
        return True

    def __call__(self, edges):
        """Fallback (non-optimized) computation of the message.

        Used only when no optimized path is available; meant to be invoked
        by the DGL framework only.
        """
        feat = edges.src[self.src_field]
        return {self.out_field: feat}

    @property
    def name(self):
        """Short string identifier of this builtin function."""
        return "copy_src"

    @property
    def use_edge_feature(self):
        """Always False: only source-node data is read."""
        return False

class CopyEdgeMessageFunction(MessageFunction):
    """Builtin message function that forwards an edge feature unchanged as
    the message.

    See Also
    --------
    copy_edge
    """
    def __init__(self, edge_field=None, out_field=None):
        self.edge_field = edge_field
        self.out_field = out_field

    def is_spmv_supported(self, g):
        """Report whether the SPMV optimization applies on ``g``.

        Parameters
        ----------
        g : DGLGraph
            The graph (not inspected).

        Returns
        -------
        bool
            Always False for now.
        """
        # TODO: support this with e2v spmv; once available, the check would
        # presumably be _is_spmv_supported_edge_feat(g, self.edge_field).
        return False

    def __call__(self, edges):
        """Fallback (non-optimized) computation of the message.

        Used only when no optimized path is available; meant to be invoked
        by the DGL framework only.
        """
        feat = edges.data[self.edge_field]
        return {self.out_field: feat}

    @property
    def name(self):
        """Short string identifier of this builtin function."""
        return "copy_edge"

    @property
    def use_edge_feature(self):
        """Always True: the message is the edge feature ``edge_field``."""
        return True

def src_mul_edge(src, edge, out):
    """Builtin message function that computes message by multiplying source
    node features with edge features.

    Parameters
    ----------
    src : str
        The source feature field.
    edge : str
        The edge feature field.
    out : str
        The output message field.

    Returns
    -------
    SrcMulEdgeMessageFunction
        The builtin message function object.

    Notes
    -----
    When the two features have different ranks, the non-optimized path
    appends trailing size-1 dimensions to the lower-rank operand before
    multiplying (e.g. an edge feature of shape ``(E,)`` is treated as
    ``(E, 1)``), rather than relying on the backend's own broadcasting.

    Examples
    --------
    >>> import dgl
    >>> message_func = dgl.function.src_mul_edge(src='h', edge='w', out='m')

    The above example is equivalent to the following user defined function:

    >>> def message_func(edges):
    ...     return {'m': edges.src['h'] * edges.data['w']}
    """
    return SrcMulEdgeMessageFunction(operator.mul, src, edge, out)
def copy_src(src, out):
    """Builtin message function that computes message using source node
    feature.

    Parameters
    ----------
    src : str
        The source feature field.
    out : str
        The output message field.

    Returns
    -------
    CopySrcMessageFunction
        The builtin message function object.

    Examples
    --------
    >>> import dgl
    >>> message_func = dgl.function.copy_src(src='h', out='m')

    The above example is equivalent to the following user defined function:

    >>> def message_func(edges):
    ...     return {'m': edges.src['h']}
    """
    return CopySrcMessageFunction(src, out)
def copy_edge(edge, out):
    """Builtin message function that computes message using edge feature.

    Parameters
    ----------
    edge : str
        The edge feature field.
    out : str
        The output message field.

    Returns
    -------
    CopyEdgeMessageFunction
        The builtin message function object.

    Notes
    -----
    The SPMV optimization is not yet used for this message function.

    Examples
    --------
    >>> import dgl
    >>> message_func = dgl.function.copy_edge(edge='h', out='m')

    The above example is equivalent to the following user defined function:

    >>> def message_func(edges):
    ...     return {'m': edges.data['h']}
    """
    return CopyEdgeMessageFunction(edge, out)