# Source code for dgl.ops.edge_softmax

"""dgl edge_softmax operator module."""
from ..backend import astype
from ..backend import edge_softmax as edge_softmax_internal
from ..backend import edge_softmax_hetero as edge_softmax_hetero_internal
from ..base import ALL, is_all

# Public API of this module: only the ``edge_softmax`` operator is exported.
__all__ = [
    "edge_softmax",
]

def _canonicalize_logits(graph, logits):
    """Re-key a per-relation ``logits`` dict by canonical edge type.

    Parameters
    ----------
    graph : DGLGraph
        Graph whose edge types the keys refer to.
    logits : dict
        Mapping from any edge-type spelling accepted by
        ``graph.to_canonical_etype`` to that relation's edge logits.

    Returns
    -------
    dict
        The same values keyed by canonical ``(src_type, etype, dst_type)``.

    Raises
    ------
    KeyError
        If any canonical edge type of ``graph`` has no entry in ``logits``.
    """
    canonical = {graph.to_canonical_etype(k): v for k, v in logits.items()}
    # Report every missing relation at once instead of failing on the first
    # lookup with a bare ``KeyError(rel)``.
    missing = [rel for rel in graph.canonical_etypes if rel not in canonical]
    if missing:
        raise KeyError(
            "edge_softmax: logits has no entry for edge type(s) "
            f"{missing}; every edge type of the graph must be provided."
        )
    return canonical


def edge_softmax(graph, logits, eids=ALL, norm_by="dst"):
    r"""Compute softmax over weights of incoming edges for every node.

    For a node :math:`i`, edge softmax is an operation that computes

    .. math::
      a_{ij} = \frac{\exp(z_{ij})}{\sum_{k\in\mathcal{N}(i)}\exp(z_{ik})}

    where :math:`z_{ij}` is a signal of edge :math:`j\rightarrow i`, also
    called logits in the context of softmax. :math:`\mathcal{N}(i)` is
    the set of nodes that have an edge to :math:`i`.

    By default edge softmax is normalized by destination nodes (i.e. :math:`ij`
    are incoming edges of :math:`i` in the formula above). We also support edge
    softmax normalized by source nodes (i.e. :math:`ij` are outgoing edges of
    :math:`i` in the formula). The former case corresponds to softmax in GAT and
    Transformer, and the latter case corresponds to softmax in Capsule network.
    An example of using edge softmax is in
    `Graph Attention Network <https://arxiv.org/pdf/1710.10903.pdf>`__ where
    the attention weights are computed with this operation.
    Other non-GNN examples using this are
    `Transformer <https://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf>`__,
    `Capsule <https://arxiv.org/pdf/1710.09829.pdf>`__, etc.

    Parameters
    ----------
    graph : DGLGraph
        The graph over which edge softmax will be performed.
    logits : torch.Tensor or dict of torch.Tensor
        The input edge feature. Heterogeneous graphs take a dict of tensors
        where each tensor stores the edge features of the corresponding
        relation type. Keys may be any edge-type spelling accepted by
        ``graph.to_canonical_etype``, and every edge type of the graph must
        have an entry. A dict is required when the graph has more than one
        edge type.
    eids : torch.Tensor or ALL, optional
        The IDs of the edges to apply edge softmax. If ALL, it will apply edge
        softmax to all edges in the graph. Default: ALL.
    norm_by : str, either ``'src'`` or ``'dst'``
        Normalized by source nodes or destination nodes. Default: ``'dst'``.

    Returns
    -------
    Tensor or dict[tuple[str, str, str], Tensor]
        Softmax value. When ``logits`` is a dict, the result is a dict keyed
        by canonical edge type ``(src_type, etype, dst_type)``.

    Raises
    ------
    KeyError
        If ``logits`` is a dict that lacks an entry for some edge type of
        ``graph``.

    Notes
    -----
    * Input shape: :math:`(E, *, 1)` where * means any number of
      additional dimensions, :math:`E` equals the length of ``eids``.
      If ``eids`` is ALL, :math:`E` equals the number of edges in
      the graph.
    * Return shape: :math:`(E, *, 1)`

    Examples on a homogeneous graph
    -------------------------------
    The following example uses PyTorch backend.

    >>> from dgl.nn.functional import edge_softmax
    >>> import dgl
    >>> import torch as th

    Create a :code:`DGLGraph` object and initialize its edge features.

    >>> g = dgl.graph((th.tensor([0, 0, 0, 1, 1, 2]), th.tensor([0, 1, 2, 1, 2, 2])))
    >>> edata = th.ones(6, 1).float()
    >>> edata
    tensor([[1.],
            [1.],
            [1.],
            [1.],
            [1.],
            [1.]])

    Apply edge softmax over g:

    >>> edge_softmax(g, edata)
    tensor([[1.0000],
            [0.5000],
            [0.3333],
            [0.5000],
            [0.3333],
            [0.3333]])

    Apply edge softmax over g normalized by source nodes:

    >>> edge_softmax(g, edata, norm_by='src')
    tensor([[0.3333],
            [0.3333],
            [0.3333],
            [0.5000],
            [0.5000],
            [1.0000]])

    Apply edge softmax to first 4 edges of g:

    >>> edge_softmax(g, edata[:4], th.tensor([0, 1, 2, 3]))
    tensor([[1.0000],
            [0.5000],
            [1.0000],
            [0.5000]])

    Examples on a heterogeneous graph
    ---------------------------------

    Create a heterogeneous graph and initialize its edge features.

    >>> hg = dgl.heterograph({
    ...     ('user', 'follows', 'user'): ([0, 0, 1], [0, 1, 2]),
    ...     ('developer', 'develops', 'game'): ([0, 1], [0, 1])
    ...     })
    >>> edata_follows = th.ones(3, 1).float()
    >>> edata_develops = th.ones(2, 1).float()
    >>> edata_dict = {('user', 'follows', 'user'): edata_follows,
    ...               ('developer', 'develops', 'game'): edata_develops}

    Apply edge softmax over hg normalized by source nodes:

    >>> edge_softmax(hg, edata_dict, norm_by='src')
    {('developer', 'develops', 'game'): tensor([[1.],
            [1.]]), ('user', 'follows', 'user'): tensor([[0.5000],
            [0.5000],
            [1.0000]])}
    """
    if not is_all(eids):
        # Edge IDs must match the graph's index dtype before reaching the kernel.
        eids = astype(eids, graph.idtype)
    num_etypes = graph._graph.number_of_etypes()

    if num_etypes == 1:
        if not isinstance(logits, dict):
            return edge_softmax_internal(
                graph._graph, logits, eids=eids, norm_by=norm_by
            )
        # Single-relation graph given the dict form of ``logits``: unwrap the
        # lone tensor for the tensor kernel and return a dict keyed by the
        # canonical edge type, mirroring the multi-relation path below.
        (rel,) = graph.canonical_etypes
        rel_logits = _canonicalize_logits(graph, logits)
        score = edge_softmax_internal(
            graph._graph, rel_logits[rel], eids=eids, norm_by=norm_by
        )
        return {rel: score}

    # The hetero kernel takes one positional logits tensor per relation,
    # ordered by edge-type ID, and returns scores in that same order.
    rel_logits = _canonicalize_logits(graph, logits)
    logits_list = [None] * num_etypes
    for rel in graph.canonical_etypes:
        logits_list[graph.get_etype_id(rel)] = rel_logits[rel]
    score_tuple = edge_softmax_hetero_internal(
        graph._graph, eids, norm_by, *logits_list
    )
    return {
        rel: score_tuple[graph.get_etype_id(rel)]
        for rel in graph.canonical_etypes
    }