TypedLinear

class dgl.nn.pytorch.TypedLinear(in_size, out_size, num_types, regularizer=None, num_bases=None)

Bases: Module

Linear transformation according to types.

For each sample $$x$$ in the input batch $$X$$, apply the linear transformation $$xW_t$$, where $$t$$ is the type of $$x$$.
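To make the per-type rule concrete, here is a minimal pure-PyTorch sketch of the unregularized computation. It assumes the weights are stored as one tensor `W` of shape `(num_types, in_size, out_size)`; that layout and the helper name `typed_linear_ref` are hypothetical and are not DGL's implementation.

```python
import torch

def typed_linear_ref(x: torch.Tensor, x_type: torch.Tensor, W: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference: y[i] = x[i] @ W[x_type[i]].

    x: (N, in_size), x_type: (N,) integer types, W: (num_types, in_size, out_size).
    """
    # Gather one weight matrix per row: (N, in_size, out_size)
    W_per_row = W[x_type]
    # Batched row-vector times matrix: (N, 1, in) @ (N, in, out) -> (N, 1, out)
    return torch.bmm(x.unsqueeze(1), W_per_row).squeeze(1)

torch.manual_seed(0)
num_types, in_size, out_size = 5, 32, 64
x = torch.randn(100, in_size)
x_type = torch.randint(0, num_types, (100,))
W = torch.randn(num_types, in_size, out_size)
y = typed_linear_ref(x, x_type, W)
print(y.shape)  # torch.Size([100, 64])
```

Gathering `W[x_type]` materializes one weight matrix per row, which is simple to read but memory-hungry when N is large.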

The module supports two regularization methods, basis-decomposition and block-diagonal-decomposition, both proposed in "Modeling Relational Data with Graph Convolutional Networks".

The basis regularization decomposes $$W_t$$ by:

$W_t^{(l)} = \sum_{b=1}^B a_{tb}^{(l)}V_b^{(l)}$

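where $$B$$ is the number of bases, the basis matrices $$V_b^{(l)}$$ are shared by all types, and each type $$t$$ combines them linearly with its own coefficients $$a_{tb}^{(l)}$$. The superscript $$(l)$$ is the layer index from the paper's notation.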
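The basis formula can be sketched directly in PyTorch. This assumes, for illustration only, that the bases are stored as `V` of shape `(B, in_size, out_size)` and the coefficients as `a` of shape `(num_types, B)`; DGL's internal layout may differ.

```python
import torch

torch.manual_seed(0)
num_types, num_bases, in_size, out_size = 5, 4, 32, 64

# Shared basis matrices V_b: (B, in_size, out_size)
V = torch.randn(num_bases, in_size, out_size)
# Per-type coefficients a_tb: (num_types, B)
a = torch.randn(num_types, num_bases)

# W_t = sum_b a_tb * V_b, computed for all types at once -> (num_types, in_size, out_size)
W = torch.einsum("tb,bio->tio", a, V)

x = torch.randn(100, in_size)
x_type = torch.randint(0, num_types, (100,))
# Apply the composed per-type weights row by row, exactly as in the unregularized case.
y = torch.bmm(x.unsqueeze(1), W[x_type]).squeeze(1)
print(y.shape)  # torch.Size([100, 64])
```

Only `V` and `a` would be learned; each type's full matrix is derived from them, so adding a type costs just `B` extra coefficients.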

The block-diagonal-decomposition regularization instead constrains each $$W_t$$ to be block-diagonal, i.e. the direct sum of $$B$$ smaller blocks. Here too, $$B$$ is called the number of bases:

$W_t^{(l)} = \oplus_{b=1}^B Q_{tb}^{(l)}$

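where $$Q_{tb}^{(l)} \in \mathbb{R}^{(d^{(l)}/B)\times(d^{(l+1)}/B)}$$ is the $$b$$-th diagonal block for type $$t$$, and $$d^{(l)}$$ and $$d^{(l+1)}$$ are the input and output feature sizes (in_size and out_size). Both sizes must therefore be divisible by $$B$$.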
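The decomposition can be checked numerically. The sketch below assumes the blocks are stored as a tensor `Q` of shape `(num_types, num_bases, in_size // num_bases, out_size // num_bases)` (an illustrative layout, not necessarily DGL's) and compares two equivalent computations: building each $$W_t$$ explicitly with `torch.block_diag`, and multiplying each chunk of the input row by its own block.

```python
import torch

torch.manual_seed(0)
num_types, num_bases, in_size, out_size = 5, 4, 32, 64
assert in_size % num_bases == 0 and out_size % num_bases == 0
bi, bo = in_size // num_bases, out_size // num_bases  # block shape: 8 x 16

# Q[t, b] is the b-th diagonal block of W_t: (num_types, B, in/B, out/B)
Q = torch.randn(num_types, num_bases, bi, bo)

# Explicit form: W_t is the direct sum (block-diagonal stack) of Q[t, 0], ..., Q[t, B-1].
W = torch.stack([torch.block_diag(*Q[t]) for t in range(num_types)])  # (num_types, in, out)

x = torch.randn(100, in_size)
x_type = torch.randint(0, num_types, (100,))

# Chunked form: split each input row into B chunks and multiply chunk b by block b,
# without ever building the zero-filled (in_size x out_size) matrix.
x_chunks = x.view(-1, num_bases, bi)                    # (N, B, in/B)
y = torch.einsum("nbi,nbio->nbo", x_chunks, Q[x_type])  # (N, B, out/B)
y = y.reshape(-1, out_size)                             # (N, out_size)

# Dense reference using the explicit block-diagonal matrices.
y_dense = torch.bmm(x.unsqueeze(1), W[x_type]).squeeze(1)
print(torch.allclose(y, y_dense, atol=1e-4))  # True
```

The chunked form needs about $$1/B$$ of the multiplications of a dense product, which is what makes this regularizer cheap as well as compact.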

Parameters:
• in_size (int) – Input feature size.

• out_size (int) – Output feature size.

• num_types (int) – Total number of types.

• regularizer (str, optional) –

Which weight regularizer to use, 'basis' or 'bdd':

• 'basis' is short for basis-decomposition.

• 'bdd' is short for block-diagonal-decomposition.

Default: None, which applies no regularization.

• num_bases (int, optional) – Number of bases. Required when regularizer is specified; typically smaller than num_types. Default: None.
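One way to see why num_bases is usually chosen smaller than num_types is to count the weights each option implies. The hypothetical helper `weight_count` below counts only the matrices and coefficients in the formulas above; DGL's actual parameter storage may differ.

```python
def weight_count(in_size: int, out_size: int, num_types: int,
                 regularizer=None, num_bases=None) -> int:
    """Hypothetical helper: number of weights implied by each formula above."""
    if regularizer is None:
        # One full (in_size x out_size) matrix per type.
        return num_types * in_size * out_size
    if regularizer == "basis":
        # B shared full matrices plus a (num_types x B) coefficient table.
        return num_bases * in_size * out_size + num_types * num_bases
    if regularizer == "bdd":
        # B blocks of shape (in/B x out/B) per type.
        if in_size % num_bases or out_size % num_bases:
            raise ValueError("in_size and out_size must be divisible by num_bases")
        return num_types * num_bases * (in_size // num_bases) * (out_size // num_bases)
    raise ValueError(f"unknown regularizer: {regularizer!r}")

for reg, nb in [(None, None), ("basis", 4), ("bdd", 4)]:
    print(reg, weight_count(32, 64, 5, reg, nb))
# None 10240
# basis 8212
# bdd 2560
```

With basis-decomposition, each extra type adds only `num_bases` coefficients instead of a full `in_size x out_size` matrix, so the savings grow with the number of types.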

Examples

No regularization.

>>> from dgl.nn import TypedLinear
>>> import torch
>>>
>>> x = torch.randn(100, 32)
>>> x_type = torch.randint(0, 5, (100,))
>>> m = TypedLinear(32, 64, 5)
>>> y = m(x, x_type)
>>> print(y.shape)
torch.Size([100, 64])


With basis regularization.

>>> x = torch.randn(100, 32)
>>> x_type = torch.randint(0, 5, (100,))
>>> m = TypedLinear(32, 64, 5, regularizer='basis', num_bases=4)
>>> y = m(x, x_type)
>>> print(y.shape)
torch.Size([100, 64])

forward(x, x_type, sorted_by_type=False)

Forward computation.

Parameters:
• x (torch.Tensor) – A 2D input tensor. Shape: (N, D1), where D1 = in_size.

• x_type (torch.Tensor) – A 1D integer tensor giving the type of each row of x (one-to-one correspondence). Shape: (N,)

• sorted_by_type (bool, optional) – Whether the inputs have already been sorted by type. Forward computation on pre-sorted inputs may be faster. Default: False.

Returns:

y β The transformed output tensor. Shape: (N, D2)

Return type:

torch.Tensor
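For sorted_by_type=True, a natural way to prepare the input is to sort rows by type and undo the permutation afterwards. The sketch below assumes "sorted" means non-decreasing type order and uses hypothetical weights `W` of shape `(num_types, in_size, out_size)`. It also illustrates one plausible reason sorted input can be faster: once each type's rows are contiguous, a single dense matmul per slice replaces a per-row weight gather. This is not a description of DGL's internal implementation.

```python
import torch

def typed_linear_ref(x, x_type, W):
    # Unsorted reference: y[i] = x[i] @ W[x_type[i]] via a per-row weight gather.
    return torch.bmm(x.unsqueeze(1), W[x_type]).squeeze(1)

torch.manual_seed(0)
num_types = 5
x = torch.randn(100, 32)
x_type = torch.randint(0, num_types, (100,))
W = torch.randn(num_types, 32, 64)  # hypothetical per-type weights

# 1. Sort rows by type so each type occupies a contiguous slice.
sorted_type, perm = torch.sort(x_type)
x_sorted = x[perm]

# 2. One dense matmul per type over its slice (empty slices are allowed).
counts = torch.bincount(sorted_type, minlength=num_types).tolist()
chunks = torch.split(x_sorted, counts)
y_sorted = torch.cat([chunk @ W[t] for t, chunk in enumerate(chunks)])

# 3. Scatter results back to the original row order.
y = torch.empty_like(y_sorted)
y[perm] = y_sorted

print(torch.allclose(y, typed_linear_ref(x, x_type, W), atol=1e-4))  # True
```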

reset_parameters()

Reset the module's parameters.