"""Module for common feature initializers."""
from __future__ import absolute_import
from . import backend as F
# Public names of this module: the signature template that custom
# initializers follow (base_initializer) and the zero initializer.
__all__ = ['base_initializer', 'zero_initializer']
def base_initializer(shape, dtype, ctx, id_range):  # pylint: disable=unused-argument
    """The function signature for feature initializer.

    Any customized feature initializer should follow this signature (see
    example below). This function is a signature template only: calling it
    always raises ``NotImplementedError``.

    Parameters
    ----------
    shape : tuple of int
        The shape of the result features. The first dimension
        is the batch dimension.
    dtype : data type object
        The data type of the returned features.
    ctx : context object
        The device context of the returned features.
    id_range : slice
        The start id and the end id of the features to be initialized.
        The id could be node or edge id depending on the scenario.
        Note that the step is always None.

    Raises
    ------
    NotImplementedError
        Always; this function only documents the expected signature.

    Examples
    --------
    If PyTorch is used as backend, the following code defines a feature
    initializer that initializes tensor value to 1

    >>> import torch
    >>> import dgl
    >>> def initializer(shape, dtype, ctx, id_range):
    ...     return torch.ones(shape, dtype=dtype, device=ctx)
    >>> g = dgl.DGLGraph()
    >>> g.set_n_initializer(initializer)

    See Also
    --------
    dgl.DGLGraph.set_n_initializer
    dgl.DGLGraph.set_e_initializer
    """
    raise NotImplementedError
def zero_initializer(shape, dtype, ctx, id_range):  # pylint: disable=unused-argument
    """Zero feature initializer.

    Delegates to the backend's ``zeros`` with the requested shape, data
    type and device context.

    Parameters
    ----------
    shape : tuple of int
        The shape of the result features.
    dtype : data type object
        The data type of the returned features.
    ctx : context object
        The device context of the returned features.
    id_range : slice
        Accepted so the function matches the feature initializer
        signature; it is not used here.

    Returns
    -------
    object
        Whatever the backend's ``zeros(shape, dtype, ctx)`` returns.

    Examples
    --------
    >>> import dgl
    >>> g = dgl.DGLGraph()
    >>> g.set_n_initializer(dgl.init.zero_initializer)

    See Also
    --------
    dgl.DGLGraph.set_n_initializer
    dgl.DGLGraph.set_e_initializer
    """
    return F.zeros(shape, dtype, ctx)