"""Built-in message function."""
from __future__ import absolute_import
import sys
from itertools import product
from .base import BuiltinFunction, TargetCode
# Public API of this module. The generated binary message functions
# (u_add_v, e_dot_u, ...) are appended to this list at import time by
# _register_builtin_message_func, so it must stay a mutable list.
__all__ = [
    "copy_u",
    "copy_e",
    "BinaryMessageFunction",
    "CopyMessageFunction",
]
class MessageFunction(BuiltinFunction):
    """Common base class of every builtin message function.

    Concrete subclasses are expected to override :attr:`name`; the base
    implementation only raises ``NotImplementedError``.
    """

    @property
    def name(self):
        """Name of this builtin message function (must be overridden)."""
        raise NotImplementedError()
class BinaryMessageFunction(MessageFunction):
    """Builtin message function of the form ``<lhs>_<op>_<rhs>``.

    It combines a feature of the ``lhs`` target with a feature of the ``rhs``
    target using the binary operation named by ``binary_op``.

    See Also
    --------
    u_mul_e
    """

    def __init__(self, binary_op, lhs, rhs, lhs_field, rhs_field, out_field):
        # Operation name followed by the target codes of both operands.
        self.binary_op, self.lhs, self.rhs = binary_op, lhs, rhs
        # Field names read from each operand, and the field written as output.
        self.lhs_field, self.rhs_field = lhs_field, rhs_field
        self.out_field = out_field

    @property
    def name(self):
        """Return the generated name, e.g. ``"u_add_v"``."""
        lhs_str = TargetCode.CODE2STR[self.lhs]
        rhs_str = TargetCode.CODE2STR[self.rhs]
        return "{left}_{op}_{right}".format(
            left=lhs_str, op=self.binary_op, right=rhs_str
        )
class CopyMessageFunction(MessageFunction):
    """Builtin message function that copies one feature field as the message.

    ``target`` selects where the feature is read from (a ``TargetCode``).

    See Also
    --------
    copy_u
    """

    def __init__(self, target, in_field, out_field):
        # Source target code, the field to read, and the field to write.
        self.target, self.in_field, self.out_field = target, in_field, out_field

    @property
    def name(self):
        """Return the generated name, e.g. ``"copy_u"``."""
        target_str = TargetCode.CODE2STR[self.target]
        return "copy_{}".format(target_str)
def copy_u(u, out):
    """Builtin message function that computes message using source node
    feature.

    Parameters
    ----------
    u : str
        The source feature field.
    out : str
        The output message field.

    Returns
    -------
    CopyMessageFunction
        A copy message function whose target is ``TargetCode.SRC``.

    Examples
    --------
    >>> import dgl
    >>> message_func = dgl.function.copy_u('h', 'm')

    The above example is equivalent to the following user defined function:

    >>> def message_func(edges):
    ...     return {'m': edges.src['h']}
    """
    return CopyMessageFunction(TargetCode.SRC, u, out)
def copy_e(e, out):
    """Builtin message function that computes message using edge feature.

    Parameters
    ----------
    e : str
        The edge feature field.
    out : str
        The output message field.

    Returns
    -------
    CopyMessageFunction
        A copy message function whose target is ``TargetCode.EDGE``.

    Examples
    --------
    >>> import dgl
    >>> message_func = dgl.function.copy_e('h', 'm')

    The above example is equivalent to the following user defined function:

    >>> def message_func(edges):
    ...     return {'m': edges.data['h']}
    """
    return CopyMessageFunction(TargetCode.EDGE, e, out)
###############################################################################
# Generate all of the following builtin message functions:
# element-wise message functions:
# u_add_v, u_sub_v, u_mul_v, u_div_v
# u_add_e, u_sub_e, u_mul_e, u_div_e
# v_add_u, v_sub_u, v_mul_u, v_div_u
# v_add_e, v_sub_e, v_mul_e, v_div_e
# e_add_u, e_sub_u, e_mul_u, e_div_u
# e_add_v, e_sub_v, e_mul_v, e_div_v
#
# dot message functions:
# u_dot_v, u_dot_e, v_dot_e
# v_dot_u, e_dot_u, e_dot_v
# Maps the single-letter operand prefix used in generated function names
# ("u" = source, "v" = destination, "e" = edge) to its TargetCode.
_TARGET_MAP = dict(
    u=TargetCode.SRC,
    v=TargetCode.DST,
    e=TargetCode.EDGE,
)
def _gen_message_builtin(lhs, rhs, binary_op):
    """Generate the builtin message function ``<lhs>_<binary_op>_<rhs>``.

    Parameters
    ----------
    lhs : str
        Operand prefix of the left-hand side; a key of ``_TARGET_MAP``
        (``"u"``, ``"v"`` or ``"e"``).
    rhs : str
        Operand prefix of the right-hand side; a key of ``_TARGET_MAP``.
    binary_op : str
        Name of the binary operation. It is used verbatim in the generated
        name and passed through to :class:`BinaryMessageFunction`.

    Returns
    -------
    callable
        ``func(lhs_field, rhs_field, out)``, which returns a
        :class:`BinaryMessageFunction`. Its ``__name__`` and ``__qualname__``
        are both set to the generated name, and ``__doc__`` is filled in.
    """
    name = "{}_{}_{}".format(lhs, binary_op, rhs)
    lhs_target = _TARGET_MAP[lhs]
    rhs_target = _TARGET_MAP[rhs]
    docstring = """Builtin message function that computes a message on an edge
    by performing element-wise {op} between features of {lhs} and {rhs}
    if the features have the same shape; otherwise, it first broadcasts the features
    to a new shape and performs the element-wise operation.

    Broadcasting follows NumPy semantics. Please see
    https://numpy.org/doc/stable/user/basics.broadcasting.html
    for more details about the NumPy broadcasting semantics.

    Parameters
    ----------
    lhs_field : str
        The feature field of {lhs}.
    rhs_field : str
        The feature field of {rhs}.
    out : str
        The output message field.

    Examples
    --------
    >>> import dgl
    >>> message_func = dgl.function.{name}('h', 'h', 'm')
    """.format(
        op=binary_op,
        lhs=TargetCode.CODE2STR[lhs_target],
        rhs=TargetCode.CODE2STR[rhs_target],
        name=name,
    )

    def func(lhs_field, rhs_field, out):
        return BinaryMessageFunction(
            binary_op,
            lhs_target,
            rhs_target,
            lhs_field,
            rhs_field,
            out,
        )

    # Make the closure introspect like an ordinary module-level function.
    # Without __qualname__ it reports "_gen_message_builtin.<locals>.func",
    # which leaks into repr() and makes pickle reject the public u_add_v etc.
    # as local objects (pickle locates functions by module + qualified name).
    func.__name__ = name
    func.__qualname__ = name
    func.__doc__ = docstring
    return func
def _register_builtin_message_func():
    """Create every ``<lhs>_<op>_<rhs>`` message function and export it.

    Each generated function is attached to this module under its own name and
    appended to ``__all__``. Pairs whose two operands are the same (such as
    ``u_add_u``) are skipped.
    """
    this_module = sys.modules[__name__]
    operands = ("u", "v", "e")
    operators = ("add", "sub", "mul", "div", "dot")
    # product() yields the same order as nested lhs/rhs/op loops.
    for lhs, rhs, binary_op in product(operands, operands, operators):
        if lhs == rhs:
            continue
        generated = _gen_message_builtin(lhs, rhs, binary_op)
        setattr(this_module, generated.__name__, generated)
        __all__.append(generated.__name__)


# Populate the module namespace and __all__ at import time.
_register_builtin_message_func()