
torch.nn.init

Warning

All the functions in this module are intended to be used to initialize neural network parameters, so they all run in torch.no_grad() mode and will not be taken into account by autograd.
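Here is a minimal sketch of what this warning means in practice. Autograd normally rejects an in-place fill on a leaf tensor with requires_grad=True. The same fill through nn.init succeeds and leaves no graph history. The shape is arbitrary.

```python
import torch
from torch import nn

# A leaf tensor that requires gradients, like the weight of an nn.Linear.
w = torch.empty(3, 5, requires_grad=True)

# Calling the in-place method directly is rejected by autograd for a leaf
# that requires grad.
try:
    w.normal_()
    direct_fill_failed = False
except RuntimeError:
    direct_fill_failed = True

# nn.init performs the same fill under torch.no_grad(), so it succeeds.
nn.init.normal_(w)

print(direct_fill_failed)  # True
print(w.requires_grad)     # True: the flag itself is unchanged
print(w.grad_fn)           # None: the fill left no autograd history
```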

torch.nn.init.calculate_gain(nonlinearity, param=None)

Return the recommended gain value for the given nonlinearity function.

The values are as follows (nonlinearity – gain):

  • Linear / Identity – \(1\)

  • Conv{1,2,3}D – \(1\)

  • Sigmoid – \(1\)

  • Tanh – \(\frac{5}{3}\)

  • ReLU – \(\sqrt{2}\)

  • Leaky ReLU – \(\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}\)

  • SELU – \(\frac{3}{4}\)

Warning

To implement Self-Normalizing Neural Networks, you should use nonlinearity='linear' instead of nonlinearity='selu'. This gives the initial weights a variance of 1 / N, which is necessary to induce a stable fixed point in the forward pass. In contrast, the default gain for SELU sacrifices the normalization effect for more stable gradient flow in rectangular layers.
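Here is a rough numerical check of this advice, as a sketch. It uses kaiming_normal_ in its default 'fan_in' mode, so std = gain / sqrt(fan_in). With nonlinearity='linear' the weight variance comes out near 1/N, where N = fan_in. With nonlinearity='selu' that variance is scaled by (3/4)^2. The shape (512, 1024) and the seed are arbitrary choices.

```python
import torch
from torch import nn

torch.manual_seed(0)
fan_in = 1024  # N: inputs per output unit for a weight of shape (out, in)

w_linear = torch.empty(512, fan_in)
w_selu = torch.empty(512, fan_in)

# std = gain / sqrt(fan_in) in the default 'fan_in' mode
nn.init.kaiming_normal_(w_linear, nonlinearity='linear')  # gain 1
nn.init.kaiming_normal_(w_selu, nonlinearity='selu')      # gain 3/4

print(w_linear.var().item() * fan_in)  # close to 1.0
print(w_selu.var().item() * fan_in)    # close to (3/4)**2 = 0.5625
```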

Parameters:
  • nonlinearity – the non-linear function (nn.functional name)

  • param – optional parameter for the non-linear function

Examples

>>> gain = nn.init.calculate_gain('leaky_relu', 0.2)  # leaky_relu with negative_slope=0.2
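You can cross-check the gain values above against calculate_gain itself. This is a small sketch. It uses the names 'linear', 'conv1d', 'conv2d', 'conv3d', 'sigmoid', 'tanh', 'relu', 'selu' and 'leaky_relu'. The slope 0.2 only mirrors the example above.

```python
import math
from torch import nn

# Entries listed with gain 1
for name in ('linear', 'conv1d', 'conv2d', 'conv3d', 'sigmoid'):
    print(name, nn.init.calculate_gain(name))

print(nn.init.calculate_gain('tanh'))  # 5/3
print(nn.init.calculate_gain('relu'))  # sqrt(2)
print(nn.init.calculate_gain('selu'))  # 3/4

# Leaky ReLU: param is the negative slope
slope = 0.2
print(nn.init.calculate_gain('leaky_relu', slope))
print(math.sqrt(2 / (1 + slope ** 2)))  # same value, from the formula above
```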
torch.nn.init.uniform_(tensor, a=0.0, b=1.0, generator=None)

Fill the input Tensor with values drawn from the uniform distribution \(\mathcal{U}(a, b)\).

Parameters:
  • tensor (Tensor) – an n-dimensional torch.Tensor

  • a (float) – the lower bound of the uniform distribution

  • b (float) – the upper bound of the uniform distribution

  • generator (Optional[Generator]) – the torch Generator to sample from (default: None)

Return type:

Tensor

Examples

>>> w = torch.empty(3, 5)
>>> nn.init.uniform_(w)
torch.nn.init.normal_(tensor, mean=0.0, std=1.0, generator=None)

Fill the input Tensor with values drawn from the normal distribution \(\mathcal{N}(\text{mean}, \text{std}^2)\).

Parameters:
  • tensor (Tensor) – an n-dimensional torch.Tensor

  • mean (float) – the mean of the normal distribution

  • std (float) – the standard deviation of the normal distribution

  • generator (Optional[Generator]) – the torch Generator to sample from (default: None)

Return type:

Tensor

Examples

>>> w = torch.empty(3, 5)
>>> nn.init.normal_(w)
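Note that std is a standard deviation, not a variance, matching the \(\mathcal{N}(\text{mean}, \text{std}^2)\) notation. Here is a quick empirical sketch on a large tensor. The size, mean, std and seed are chosen arbitrarily.

```python
import torch
from torch import nn

torch.manual_seed(0)
w = torch.empty(1000, 1000)
nn.init.normal_(w, mean=0.5, std=0.1)

# With a million samples, the sample statistics sit close to the arguments.
print(w.mean().item())  # approximately 0.5
print(w.std().item())   # approximately 0.1
print(w.var().item())   # approximately 0.01, i.e. std ** 2
```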
torch.nn.init.constant_(tensor, val)

Fill the input Tensor with the value \(\text{val}\).

Parameters:
  • tensor (Tensor) – an n-dimensional torch.Tensor

  • val (float) – the value to fill the tensor with

Return type:

Tensor

Examples

>>> w = torch.empty(3, 5)
>>> nn.init.constant_(w, 0.3)
torch.nn.init.ones_(tensor)

Fill the input Tensor with the scalar value 1.

Parameters:

tensor (Tensor) – an n-dimensional torch.Tensor

Return type:

Tensor

Examples

>>> w = torch.empty(3, 5)
>>> nn.init.ones_(w)
torch.nn.init.zeros_(tensor)

Fill the input Tensor with the scalar value 0.

Parameters:

tensor (Tensor) – an n-dimensional torch.Tensor

Return type:

Tensor

Examples

>>> w = torch.empty(3, 5)
>>> nn.init.zeros_(w)
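constant_, ones_ and zeros_ are often applied to individual parameters of a module, for example to zero a layer's bias. Here is a sketch. The layer sizes are arbitrary, and the choice of which parameter gets which value is only illustrative. ones_ and zeros_ behave like constant_ with val=1 and val=0.

```python
import torch
from torch import nn

layer = nn.Linear(5, 3)  # weight shape (3, 5), bias shape (3,)
norm = nn.LayerNorm(3)   # weight and bias of shape (3,)

nn.init.zeros_(layer.bias)
nn.init.constant_(layer.weight, 0.3)
nn.init.ones_(norm.weight)

# The parameters keep requires_grad=True; only their values change.
print(layer.bias)
print(layer.weight.requires_grad)  # True
```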
torch.nn.init.eye_(tensor)

Fill the 2-dimensional input Tensor with the identity matrix.

Preserves the identity of the inputs in Linear layers, where as many inputs are preserved as possible.

Parameters:

tensor – a 2-dimensional torch.Tensor

Examples

>>> w = torch.empty(3, 5)
>>> nn.init.eye_(w)
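For a non-square weight, eye_ writes ones on the main diagonal only, so the result matches torch.eye with the same shape. This sketch uses a bias-free nn.Linear(5, 3) initialized this way. It passes the first three input features through unchanged. The shapes are arbitrary.

```python
import torch
from torch import nn

w = torch.empty(3, 5)
nn.init.eye_(w)
print(torch.equal(w, torch.eye(3, 5)))  # True

# In a Linear layer (weight shape: out_features x in_features), an
# identity-initialized weight copies the first min(3, 5) input features.
layer = nn.Linear(5, 3, bias=False)
nn.init.eye_(layer.weight)
x = torch.randn(2, 5)
y = layer(x)
print(torch.allclose(y, x[:, :3]))  # True
```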
torch.nn.init.dirac_(tensor, groups=1)

Fill the {3, 4, 5}-dimensional input Tensor with the Dirac delta function.

Preserves the identity of the inputs in Convolutional layers, where as many input channels are preserved as possible. If groups > 1, each group of channels preserves identity.

Parameters:
  • tensor – a {3, 4, 5}-dimensional torch.Tensor

  • groups (int, optional) – number of groups in the conv layer (default: 1)

Examples

>>> w = torch.empty(3, 16, 5, 5)
>>> nn.init.dirac_(w)
>>> w = torch.empty(3, 24, 5, 5)
>>> nn.init.dirac_(w, 3)
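This sketch shows the identity property for a convolution. The weight has as many output channels as input channels, and its 3x3 kernel is filled by dirac_. With padding=1, the output equals the input. The padding is an assumption chosen so the spatial size is preserved. The channel count and image size are arbitrary.

```python
import torch
import torch.nn.functional as F
from torch import nn

weight = torch.empty(4, 4, 3, 3)  # (out_channels, in_channels, kH, kW)
nn.init.dirac_(weight)

x = torch.randn(1, 4, 8, 8)       # (batch, channels, H, W)
y = F.conv2d(x, weight, padding=1)

# Each output channel copies the matching input channel.
print(torch.allclose(y, x, atol=1e-5))  # True
```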
torch.nn.init.xavier_uniform_(tensor, gain=1.0, generator=None)

Fill the input Tensor with values using a Xavier uniform distribution.

The method is described in Understanding the difficulty of training deep feedforward neural networks - Glorot, X. & Bengio, Y. (2010). The resulting tensor will have values sampled from \(\mathcal{U}(-a, a)\) where

\[a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}} \]

Also known as Glorot initialization.

Parameters:
  • tensor (Tensor) – an n-dimensional torch.Tensor

  • gain (float) – an optional scaling factor

  • generator (Optional[Generator]) – the torch Generator to sample from (default: None)

Return type:

Tensor

Examples

>>> w = torch.empty(3, 5)
>>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))
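You can reproduce the bound \(a\) by hand. For a 2-dimensional weight of shape (out_features, in_features), fan_in is the number of columns and fan_out the number of rows. This sketch uses the same shape and ReLU gain as the example above. A small slack covers float32 rounding of the samples.

```python
import math
import torch
from torch import nn

torch.manual_seed(0)
w = torch.empty(3, 5)  # (fan_out, fan_in) = (3, 5)
gain = nn.init.calculate_gain('relu')
nn.init.xavier_uniform_(w, gain=gain)

fan_in, fan_out = w.shape[1], w.shape[0]
bound = gain * math.sqrt(6.0 / (fan_in + fan_out))  # sqrt(2) * sqrt(6 / 8)
print(bound)                                        # about 1.2247
print(w.abs().max().item() <= bound + 1e-6)         # True
```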
torch.nn.init.xavier_normal_(tensor, gain=1.0, generator=None)

Fill the input Tensor with values using a Xavier normal distribution.

The method is described in Understanding the difficulty of training deep feedforward neural networks - Glorot, X. & Bengio, Y. (2010). The resulting tensor will have values sampled from \(\mathcal{N}(0, \text{std}^2)\) where

\[\text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}} \]

Also known as Glorot initialization.

Parameters:
  • tensor (Tensor) – an n-dimensional torch.Tensor

  • gain (float) – an optional scaling factor

  • generator (Optional[Generator]) – the torch Generator to sample from (default: None)

Return type:

Tensor

Examples

>>> w = torch.empty(3, 5)
>>> nn.init.xavier_normal_(w)
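Both Xavier variants target the same variance, \(\text{gain}^2 \cdot \frac{2}{\text{fan\_in} + \text{fan\_out}}\). The uniform bound \(a\) gives a variance of \(a^2 / 3\), which equals \(\text{std}^2\) above. This empirical sketch checks that on a large tensor. The shape and the seed are arbitrary.

```python
import math
import torch
from torch import nn

torch.manual_seed(0)
fan_out, fan_in = 1000, 500
w_normal = nn.init.xavier_normal_(torch.empty(fan_out, fan_in))
w_uniform = nn.init.xavier_uniform_(torch.empty(fan_out, fan_in))

# Both target std = gain * sqrt(2 / (fan_in + fan_out)), with gain = 1 here.
expected_std = math.sqrt(2.0 / (fan_in + fan_out))
print(expected_std)            # about 0.0365
print(w_normal.std().item())   # close to expected_std
print(w_uniform.std().item())  # also close to expected_std
```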
torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu', generator=None)

Fill the input Tensor with values using a Kaiming uniform distribution.

The method is described in Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification - He, K. et al. (2015). The resulting tensor will have values sampled from \(\mathcal{U}(-\text{bound}, \text{bound})\) where

\[\text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}} \]

Also known as He initialization.

Parameters:
  • tensor (Tensor) – an n-dimensional torch.Tensor

  • a (float) – the negative slope of the rectifier used after this layer (only used with 'leaky_relu')

  • mode (str) – either 'fan_in' (default) or 'fan_out'. Choosing 'fan_in' preserves the magnitude of the variance of the activations in the forward pass. Choosing 'fan_out' preserves the magnitude of the variance of the gradients in the backward pass.

  • nonlinearity (str) – the non-linear function (nn.functional name), recommended for use only with 'relu' or 'leaky_relu' (default).

  • generator (Optional[Generator]) – the torch Generator to sample from (default: None)

Examples

>>> w = torch.empty(3, 5)
>>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
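For convolution weights, fan_in and fan_out include the kernel size. For a weight of shape (out_channels, in_channels, kH, kW), fan_in = in_channels·kH·kW and fan_out = out_channels·kH·kW. This sketch computes the bound by hand for both modes. It uses a hypothetical 3x3 convolution from 8 to 16 channels.

```python
import math
import torch
from torch import nn

torch.manual_seed(0)
w = torch.empty(16, 8, 3, 3)  # (out_channels, in_channels, kH, kW)

fan_in = 8 * 3 * 3     # 72
fan_out = 16 * 3 * 3   # 144
gain = math.sqrt(2.0)  # calculate_gain('relu')

results = {}
for mode, fan in (('fan_in', fan_in), ('fan_out', fan_out)):
    nn.init.kaiming_uniform_(w, mode=mode, nonlinearity='relu')
    bound = gain * math.sqrt(3.0 / fan)
    # Record the analytic bound and the largest sampled magnitude.
    results[mode] = (bound, w.abs().max().item())
    print(mode, round(bound, 4))
# fan_in 0.2887
# fan_out 0.2041
```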
torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu', generator=None)

Fill the input Tensor with values using a Kaiming normal distribution.

The method is described in Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification - He, K. et al. (2015). The resulting tensor will have values sampled from \(\mathcal{N}(0, \text{std}^2)\) where

\[\text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}} \]

Also known as He initialization.

Parameters:
  • tensor (Tensor) – an n-dimensional torch.Tensor

  • a (float) – the negative slope of the rectifier used after this layer (only used with 'leaky_relu')

  • mode (str) – either 'fan_in' (default) or 'fan_out'. Choosing 'fan_in' preserves the magnitude of the variance of the activations in the forward pass. Choosing 'fan_out' preserves the magnitude of the variance of the gradients in the backward pass.

  • nonlinearity (str) – the non-linear function (nn.functional name), recommended for use only with 'relu' or 'leaky_relu' (default).

  • generator (Optional[Generator]) – the torch Generator to sample from (default: None)

Examples

>>> w = torch.empty(3, 5)
>>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
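This sketch checks the standard deviation \(\text{std} = \text{gain} / \sqrt{\text{fan\_mode}}\) in 'fan_out' mode, as in the example above. The hypothetical convolution weight has shape (256, 64, 3, 3), so fan_out = 256 · 3 · 3 = 2304. With the ReLU gain, the expected std is sqrt(2 / 2304) ≈ 0.0295.

```python
import math
import torch
from torch import nn

torch.manual_seed(0)
w = torch.empty(256, 64, 3, 3)  # (out_channels, in_channels, kH, kW)
nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')

fan_out = 256 * 3 * 3           # out_channels * kH * kW = 2304
expected_std = math.sqrt(2.0) / math.sqrt(fan_out)
print(round(expected_std, 4))   # 0.0295
print(w.std().item())           # close to expected_std
```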
torch.nn.init.trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0, generator=None)

Fill the input Tensor with values drawn from a truncated normal distribution.

The values are effectively drawn from the normal distribution \(\mathcal{N}(\text{mean}, \text{std}^2)\) with values outside \([a, b]\) redrawn until they are within the bounds. The method used for generating the random values works best when \(a \leq \text{mean} \leq b\).

Parameters:
  • tensor (Tensor) – an n-dimensional torch.Tensor

  • mean (float) – the mean of the normal distribution

  • std (float) – the standard deviation of the normal distribution

  • a (float) – the minimum cutoff value

  • b (float) – the maximum cutoff value

  • generator (Optional[Generator]) – the torch Generator to sample from (default: None)

Return type:

Tensor

Examples

>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
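The cutoffs \(a\) and \(b\) are values in the same units as the samples, not multiples of std. Take std=0.02: the default cutoffs a=-2.0 and b=2.0 then lie about 100 standard deviations out and truncate essentially nothing. To cut at two standard deviations, pass a=-0.04 and b=0.04 explicitly. Here is a sketch; the tensor sizes and the seed are arbitrary.

```python
import torch
from torch import nn

torch.manual_seed(0)
std = 0.02

# Default a=-2.0, b=2.0: the cutoffs are far out, so practically untruncated.
w_default = nn.init.trunc_normal_(torch.empty(256, 256), std=std)

# Cut at two standard deviations: the cutoffs are given in absolute units.
w_cut = nn.init.trunc_normal_(torch.empty(256, 256), std=std, a=-2 * std, b=2 * std)

print(w_default.abs().max().item())  # well above 0.04
print(w_cut.abs().max().item())      # at most 0.04
```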
torch.nn.init.orthogonal_(tensor, gain=1, generator=None)

Fill the input Tensor with a (semi) orthogonal matrix.

Described in Exact solutions to the nonlinear dynamics of learning in deep linear neural networks - Saxe, A. et al. (2013). The input tensor must have at least 2 dimensions, and for tensors with more than 2 dimensions the trailing dimensions are flattened.

Parameters:
  • tensor – an n-dimensional torch.Tensor, where \(n \geq 2\)

  • gain – optional scaling factor

  • generator (Optional[Generator]) – the torch Generator to sample from (default: None)

Examples

>>> w = torch.empty(3, 5)
>>> nn.init.orthogonal_(w)
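This sketch shows what "semi-orthogonal" means for the 3x5 example. The smaller dimension is the one made orthonormal. Here the three rows are orthonormal, so W Wᵀ equals the 3x3 identity, scaled by gain² when a gain is passed. For a 4-dimensional convolution weight, the same check applies to the flattened (out_channels, in_channels·kH·kW) matrix. The shapes and the gain are arbitrary.

```python
import torch
from torch import nn

torch.manual_seed(0)

# 2-D case: 3 rows < 5 columns, so the rows are orthonormal.
w = nn.init.orthogonal_(torch.empty(3, 5))
print(torch.allclose(w @ w.T, torch.eye(3), atol=1e-5))  # True

# 4-D case with a gain: trailing dimensions are flattened to (8, 4*3*3).
gain = 2.0
k = nn.init.orthogonal_(torch.empty(8, 4, 3, 3), gain=gain)
flat = k.reshape(8, -1)
print(torch.allclose(flat @ flat.T, gain ** 2 * torch.eye(8), atol=1e-5))  # True
```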
torch.nn.init.sparse_(tensor, sparsity, std=0.01, generator=None)

Fill the 2D input Tensor as a sparse matrix.

The non-zero elements will be drawn from the normal distribution \(\mathcal{N}(0, \text{std}^2)\), with \(\text{std} = 0.01\) by default, as described in Deep learning via Hessian-free optimization - Martens, J. (2010).

Parameters:
  • tensor – a 2-dimensional torch.Tensor

  • sparsity – the fraction of elements in each column to be set to zero

  • std – the standard deviation of the normal distribution used to generate the non-zero values

  • generator (Optional[Generator]) – the torch Generator to sample from (default: None)

Examples

>>> w = torch.empty(3, 5)
>>> nn.init.sparse_(w, sparsity=0.1)
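As a sketch, sparse_ appears to zero the same number of entries in every column, namely ceil(sparsity · rows); treat that exact count as an assumption about the current implementation. The sizes below are arbitrary and chosen so the count comes out exact. An 8x4 tensor with sparsity=0.25 ends up with 2 zeros per column, and the remaining entries come from \(\mathcal{N}(0, \text{std}^2)\).

```python
import torch
from torch import nn

torch.manual_seed(0)
w = torch.empty(8, 4)
nn.init.sparse_(w, sparsity=0.25)

# Count exact zeros in each column.
zeros_per_column = (w == 0).sum(dim=0)
print(zeros_per_column)  # tensor([2, 2, 2, 2])
```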
