# Source code for dgl.nn.pytorch.linear

"""Various commonly used linear modules"""
# pylint: disable= no-member, arguments-differ, invalid-name, W0235
import math

import torch
import torch.nn as nn

from ...ops import gather_mm, segment_mm

__all__ = ["TypedLinear"]

class TypedLinear(nn.Module):
    r"""Linear transformation according to types.

    For each sample of the input batch :math:`x \in X`, apply the linear
    transformation :math:`xW_t`, where :math:`t` is the type of :math:`x`.

    The module supports two regularization methods (basis-decomposition and
    block-diagonal-decomposition) proposed by `Modeling Relational Data
    with Graph Convolutional Networks <https://arxiv.org/abs/1703.06103>`__.

    The basis regularization decomposes :math:`W_t` by:

    .. math::

        W_t^{(l)} = \sum_{b=1}^B a_{tb}^{(l)} V_b^{(l)}

    where :math:`B` is the number of bases and the basis matrices
    :math:`V_b^{(l)}` are linearly combined with coefficients :math:`a_{tb}^{(l)}`.

    The block-diagonal-decomposition regularization decomposes :math:`W_t` into
    :math:`B` block-diagonal matrices. We refer to :math:`B` as the number of bases:

    .. math::

        W_t^{(l)} = \oplus_{b=1}^B Q_{tb}^{(l)}

    where :math:`B` is the number of bases and :math:`Q_{tb}^{(l)}` are block
    bases with shape :math:`\mathbb{R}^{(d^{(l+1)}/B)\times(d^{(l)}/B)}`.

    Parameters
    ----------
    in_size : int
        Input feature size.
    out_size : int
        Output feature size.
    num_types : int
        Total number of types.
    regularizer : str, optional
        Which weight regularizer to use ("basis" or "bdd"):

        - "basis" is short for basis-decomposition.
        - "bdd" is short for block-diagonal-decomposition.

        Default applies no regularization.
    num_bases : int, optional
        Number of bases. Needed when ``regularizer`` is specified. Typically
        smaller than ``num_types``. Default: ``None``.

    Examples
    --------

    No regularization.

    >>> from dgl.nn import TypedLinear
    >>> import torch
    >>>
    >>> x = torch.randn(100, 32)
    >>> x_type = torch.randint(0, 5, (100,))
    >>> m = TypedLinear(32, 64, 5)
    >>> y = m(x, x_type)
    >>> print(y.shape)
    torch.Size([100, 64])

    With basis regularization.

    >>> x = torch.randn(100, 32)
    >>> x_type = torch.randint(0, 5, (100,))
    >>> m = TypedLinear(32, 64, 5, regularizer='basis', num_bases=4)
    >>> y = m(x, x_type)
    >>> print(y.shape)
    torch.Size([100, 64])
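
    With block-diagonal-decomposition regularization (a minimal sketch
    following the examples above; note that ``in_size`` and ``out_size``
    must both be divisible by ``num_bases``).

    >>> x = torch.randn(100, 32)
    >>> x_type = torch.randint(0, 5, (100,))
    >>> m = TypedLinear(32, 64, 5, regularizer='bdd', num_bases=4)
    >>> y = m(x, x_type)
    >>> print(y.shape)
    torch.Size([100, 64])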
"""

    def __init__(
        self, in_size, out_size, num_types, regularizer=None, num_bases=None
    ):
        super().__init__()
        self.in_size = in_size
        self.out_size = out_size
        self.num_types = num_types
        if regularizer is None:
            self.W = nn.Parameter(torch.Tensor(num_types, in_size, out_size))
        elif regularizer == "basis":
            if num_bases is None:
                raise ValueError(
                    'Missing "num_bases" for basis regularization.'
                )
            self.W = nn.Parameter(torch.Tensor(num_bases, in_size, out_size))
            self.coeff = nn.Parameter(torch.Tensor(num_types, num_bases))
            self.num_bases = num_bases
        elif regularizer == "bdd":
            if num_bases is None:
                raise ValueError('Missing "num_bases" for bdd regularization.')
            if in_size % num_bases != 0 or out_size % num_bases != 0:
                raise ValueError(
                    "Input and output sizes must be divisible by num_bases."
                )
            self.submat_in = in_size // num_bases
            self.submat_out = out_size // num_bases
            self.W = nn.Parameter(
                torch.Tensor(
                    num_types, num_bases * self.submat_in * self.submat_out
                )
            )
            self.num_bases = num_bases
        else:
            raise ValueError(
                f'Supported regularizer options: "basis", "bdd", but got {regularizer}'
            )
        self.regularizer = regularizer
        self.reset_parameters()
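
    # For reference, the parameter shapes created above, per regularizer:
    #   None:    W     -- (num_types, in_size, out_size)
    #   "basis": W     -- (num_bases, in_size, out_size)
    #            coeff -- (num_types, num_bases)
    #   "bdd":   W     -- (num_types, num_bases * submat_in * submat_out)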

    def reset_parameters(self):
        """Reset parameters."""
        # Follow torch.nn.Linear's default initialization: uniform in
        # (-1/sqrt(in_size), 1/sqrt(in_size)), i.e. kaiming_uniform_ on in_size.
        if self.regularizer is None:
            nn.init.uniform_(
                self.W,
                -1 / math.sqrt(self.in_size),
                1 / math.sqrt(self.in_size),
            )
        elif self.regularizer == "basis":
            nn.init.uniform_(
                self.W,
                -1 / math.sqrt(self.in_size),
                1 / math.sqrt(self.in_size),
            )
            nn.init.xavier_uniform_(
                self.coeff, gain=nn.init.calculate_gain("relu")
            )
        elif self.regularizer == "bdd":
            nn.init.uniform_(
                self.W,
                -1 / math.sqrt(self.submat_in),
                1 / math.sqrt(self.submat_in),
            )
        else:
            raise ValueError(
                'Supported regularizer options: "basis", "bdd", '
                f"but got {self.regularizer}"
            )

    def get_weight(self):
        """Get the type-wise weight."""
        if self.regularizer is None:
            return self.W
        elif self.regularizer == "basis":
            # Combine the bases with the per-type coefficients.
            W = self.W.view(self.num_bases, self.in_size * self.out_size)
            return (self.coeff @ W).view(
                self.num_types, self.in_size, self.out_size
            )
        elif self.regularizer == "bdd":
            return self.W
        else:
            raise ValueError(
                'Supported regularizer options: "basis", "bdd", '
                f"but got {self.regularizer}"
            )
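
    # For the "basis" regularizer, get_weight() materializes, for every type t,
    #   W_t = sum_b coeff[t, b] * V_b
    # A quick equivalence check (hypothetical usage, not part of the module):
    #
    # >>> m = TypedLinear(8, 8, 3, regularizer='basis', num_bases=2)
    # >>> torch.allclose(m.get_weight()[1],
    # ...                (m.coeff[1, :, None, None] * m.W).sum(0))
    # True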

    def forward(self, x, x_type, sorted_by_type=False):
        """Forward computation.

        Parameters
        ----------
        x : torch.Tensor
            A 2D input tensor. Shape: (N, D1)
        x_type : torch.Tensor
            A 1D integer tensor storing the type of each element in ``x``,
            with one-to-one correspondence. Shape: (N,)
        sorted_by_type : bool, optional
            Whether the inputs have been sorted by type. Forward on
            pre-sorted inputs may be faster.

        Returns
        -------
        y : torch.Tensor
            The transformed output tensor. Shape: (N, D2)
        """
        w = self.get_weight()
        if self.regularizer == "bdd":
            # Batched block-diagonal multiply: one (submat_in x submat_out)
            # matmul per block, then restore the (N, out_size) shape.
            w = w.index_select(0, x_type).view(
                -1, self.submat_in, self.submat_out
            )
            x = x.view(-1, 1, self.submat_in)
            return torch.bmm(x, w).view(-1, self.out_size)
        elif sorted_by_type:
            pos_l = torch.searchsorted(
                x_type, torch.arange(self.num_types, device=x.device)
            )
            pos_r = torch.cat(
                [pos_l[1:], torch.tensor([len(x_type)], device=x.device)]
            )
            seglen = (
                pos_r - pos_l
            ).cpu()  # XXX(minjie): cause device synchronize
            return segment_mm(x, w, seglen_a=seglen)
        else:
            return gather_mm(x, w, idx_b=x_type)
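
    # A minimal usage sketch of the sorted fast path (hypothetical, continuing
    # the doctest examples above): pre-sorting the batch by type lets forward()
    # use segment_mm instead of gather_mm.
    #
    # >>> x_type, perm = torch.sort(x_type)
    # >>> x = x[perm]
    # >>> y = m(x, x_type, sorted_by_type=True)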

    def __repr__(self):
        if self.regularizer is None:
            return (
                f"TypedLinear(in_size={self.in_size}, out_size={self.out_size}, "
                f"num_types={self.num_types})"
            )
        else:
            return (
                f"TypedLinear(in_size={self.in_size}, out_size={self.out_size}, "
                f"num_types={self.num_types}, regularizer={self.regularizer}, "
                f"num_bases={self.num_bases})"
            )