WeightBasis

class dgl.nn.pytorch.utils.WeightBasis(shape, num_bases, num_outputs)

Bases: torch.nn.modules.module.Module

Basis decomposition from "Modeling Relational Data with Graph Convolutional Networks" (Schlichtkrull et al., 2018).

It is defined as:

\[W_o = \sum_{b=1}^B a_{ob} V_b\]

Each output weight \(W_o\) is a linear combination of the \(B\) shared basis transformations \(V_b\), weighted by learnable coefficients \(a_{ob}\). Here \(B\) is num_bases and \(o\) ranges over the num_outputs outputs.
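To make the formula concrete, the following minimal sketch evaluates it directly in PyTorch for 2-D weights. The sizes are arbitrary assumptions, and V and a are plain random tensors standing in for the learned bases and coefficients; this is not the module's own code.

```python
import torch

# Assumed illustrative sizes: B bases, O outputs, each weight of shape (4, 3)
num_bases, num_outputs = 2, 5
shape = (4, 3)

V = torch.randn(num_bases, *shape)       # basis transformations V_b
a = torch.randn(num_outputs, num_bases)  # coefficients a_{ob}

# W_o = sum_b a_{ob} V_b for every o, computed as a single contraction
W = torch.einsum("ob,bij->oij", a, V)

# The same result via an explicit double loop, mirroring the formula term by term
W_loop = torch.stack(
    [sum(a[o, b] * V[b] for b in range(num_bases)) for o in range(num_outputs)]
)

print(W.shape)                                # torch.Size([5, 4, 3])
print(torch.allclose(W, W_loop, atol=1e-6))   # True
```

The einsum form only covers 2-D bases; for an arbitrary shape, the bases can be flattened to (B, -1), multiplied by the (O, B) coefficient matrix, and reshaped back.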

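It is useful as a form of regularization on a large parameter matrix: all outputs share the same bases, so far fewer parameters are learned than with one independent weight per output. This saving only materializes when there are more outputs than bases, which is why the number of weight outputs is usually larger than the number of bases.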
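The regularization effect can be quantified by counting parameters. The helper below is a hypothetical illustration, not part of DGL. It compares num_outputs independent weights of the given shape with the basis-decomposed version, which stores num_bases bases plus a num_outputs × num_bases coefficient matrix.

```python
from math import prod

def param_counts(shape, num_bases, num_outputs):
    """Return (params without sharing, params with basis decomposition)."""
    per_weight = prod(shape)
    full = num_outputs * per_weight                            # one independent W_o per output
    basis = num_bases * per_weight + num_outputs * num_bases   # bases V_b + coefficients a_{ob}
    return full, basis

full, basis = param_counts((64, 64), num_bases=4, num_outputs=100)
print(full, basis)  # 409600 16784
```

With equal numbers of bases and outputs the decomposition costs more, not less: param_counts((64, 64), 10, 10) gives (40960, 41060).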

Parameters
  • shape (tuple[int]) – Shape of each basis \(V_b\), and therefore of each composed weight \(W_o\).

  • num_bases (int) – Number of bases \(B\).

  • num_outputs (int) – Number of output weights \(W_o\) to compose.
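The sketch below shows one way the three parameters could map onto learnable tensors. It is a hypothetical re-implementation for illustration only (the class name BasisWeightSketch and its attribute names are assumptions, not DGL's internals). The bases are stored as a (num_bases,) + shape tensor and the coefficients as a (num_outputs, num_bases) matrix.

```python
import torch
import torch.nn as nn

class BasisWeightSketch(nn.Module):
    """Hypothetical basis-decomposition module, for illustration only."""

    def __init__(self, shape, num_bases, num_outputs):
        super().__init__()
        self.shape = tuple(shape)
        self.num_bases = num_bases
        self.num_outputs = num_outputs
        # B basis tensors V_b, each with the given shape
        self.bases = nn.Parameter(torch.empty(num_bases, *self.shape))
        # Coefficients a_{ob}: one row per output, one column per basis
        self.coeffs = nn.Parameter(torch.empty(num_outputs, num_bases))
        nn.init.xavier_uniform_(self.bases)
        nn.init.xavier_uniform_(self.coeffs)

    def forward(self):
        # Flatten each basis, mix the bases with the coefficients, restore the shape
        flat = self.bases.view(self.num_bases, -1)   # (B, prod(shape))
        weight = self.coeffs @ flat                  # (O, prod(shape))
        return weight.view(self.num_outputs, *self.shape)

module = BasisWeightSketch((16, 8), num_bases=3, num_outputs=10)
weight = module()
print(weight.shape)                                  # torch.Size([10, 16, 8])
print(sum(p.numel() for p in module.parameters()))   # 414
```

Flattening before the matrix product lets the same code handle any shape with at least one dimension.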

forward()

Forward computation: compose every output weight \(W_o\) from the bases.

Returns

weight – Composed weight tensor of shape (num_outputs,) + shape

Return type

torch.Tensor
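As an illustration of how the returned tensor might be consumed, the sketch below assumes an R-GCN-style setting in which each output index is a relation type and each output weight has shape (in_feat, out_feat). A random tensor stands in for the result of forward(), and the edge features and relation ids are made-up data.

```python
import torch

num_rels, in_feat, out_feat = 3, 4, 2
# Stand-in for forward()'s result: shape (num_outputs,) + shape
weight = torch.randn(num_rels, in_feat, out_feat)

num_edges = 6
h_src = torch.randn(num_edges, in_feat)    # source-node feature for each edge (assumed)
etype = torch.tensor([0, 2, 1, 0, 2, 1])   # relation id of each edge (assumed)

# Gather each edge's relation-specific matrix and apply it:
# (E, 1, in_feat) @ (E, in_feat, out_feat) -> (E, 1, out_feat)
msg = torch.bmm(h_src.unsqueeze(1), weight[etype]).squeeze(1)
print(msg.shape)  # torch.Size([6, 2])
```

Indexing the first dimension with the relation ids selects a per-edge weight matrix. Because every slice is built from the same bases, relations with few edges still share statistical strength with the others.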