torchdeq.dropout

A module containing several implementations of variational dropout.

Variational dropout is a form of dropout in which a dropout mask is sampled once per training iteration and then reused at every solver step of the fixed-point iteration, rather than being resampled at each step. This is particularly effective for implicit models: it counters overfitting while keeping the layer's update map unchanged across solver steps, so the fixed point being solved for does not shift mid-solve.

This module provides variational dropout for 1d, 2d, and 3d inputs, with both channel-wise and token-wise options.
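The reason for reusing one mask can be seen in a toy fixed-point iteration. The sketch below is plain PyTorch, not torchdeq code. The update map, the contraction `W`, the helper names, and the 1/(1 - p) scaling of kept entries are illustrative assumptions. With a fixed mask the iteration settles on a fixed point. Resampling the mask at every step, as ordinary dropout would, keeps changing the map, so the residual never shrinks.

```python
import torch

torch.manual_seed(0)
B, D, p = 4, 64, 0.5

# Toy update map f(z) = mask * tanh(z @ W + x). W is rescaled to spectral norm 0.2;
# tanh is 1-Lipschitz and kept entries are scaled by 1 / (1 - p) = 2, so with a
# fixed mask f is a contraction (factor <= 0.4) and has a unique fixed point.
W = torch.randn(D, D)
W = 0.2 * W / torch.linalg.matrix_norm(W, ord=2)
x = torch.randn(B, D)


def sample_mask(shape, p):
    # Bernoulli keep-mask with inverted-dropout scaling (assumed, as in nn.Dropout).
    return torch.empty(shape).bernoulli_(1 - p) / (1 - p)


def final_residual(resample_each_step, steps=60):
    """Run naive fixed-point iteration and return ||z_{k+1} - z_k|| at the last step."""
    mask = sample_mask((B, D), p)
    z = torch.zeros(B, D)
    residual = float("inf")
    for _ in range(steps):
        if resample_each_step:
            mask = sample_mask((B, D), p)  # ordinary dropout: new mask every call
        z_next = mask * torch.tanh(z @ W + x)
        residual = (z_next - z).norm().item()
        z = z_next
    return residual


print("fixed mask:    ", final_residual(resample_each_step=False))  # shrinks toward 0
print("resampled mask:", final_residual(resample_each_step=True))   # stays large
```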

Dropout Function

torchdeq.dropout.reset_dropout(model)

Resets the dropout masks of all variational dropout layers in the model. Call it at the beginning of each training iteration so that a new mask is sampled for that iteration.

Parameters:

model (torch.nn.Module) – A DEQ layer in which the dropout masks should be reset.
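A minimal sketch of the reset pattern follows. It is not torchdeq's implementation. `ToyVariationalDropout`, its cached `mask` attribute, lazy re-sampling when the mask is `None`, and `toy_reset_dropout` are all hypothetical stand-ins. The sketch only shows the idea: a layer keeps one mask for a whole iteration, and a reset walks `model.modules()` so the next forward pass samples a fresh one.

```python
import torch
import torch.nn as nn


class ToyVariationalDropout(nn.Module):
    """Hypothetical stand-in: caches one element-wise mask until it is cleared."""

    def __init__(self, dropout=0.5):
        super().__init__()
        self.dropout = dropout
        self.mask = None  # assumed cache attribute, sampled lazily

    def forward(self, x):
        if not self.training or self.dropout == 0.0:
            return x
        if self.mask is None:
            keep = torch.empty_like(x).bernoulli_(1 - self.dropout)
            self.mask = keep / (1 - self.dropout)  # inverted-dropout scaling (assumed)
        return self.mask * x


def toy_reset_dropout(model):
    """Hypothetical analogue of reset_dropout: clear every cached mask."""
    for module in model.modules():
        if isinstance(module, ToyVariationalDropout):
            module.mask = None


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(16, 16), ToyVariationalDropout(0.5))
x = torch.randn(4, 16)

out_a = model(x)
out_b = model(x)                 # same iteration: the cached mask is reused
first_mask = model[1].mask.clone()

toy_reset_dropout(model)         # start of the next training iteration
out_c = model(x)                 # a fresh mask is sampled lazily here
```

In a real training loop, the documented call would be `reset_dropout(model)` once per iteration, before the DEQ forward pass.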

Dropout

class torchdeq.dropout.VariationalDropout(dropout=0.5)

Applies Variational Dropout to the input tensor.

During training, randomly zeros some of the elements of the input tensor with probability ‘dropout’ using a mask tensor sampled from a Bernoulli distribution.

The same mask is reused for every forward call within a training iteration. This keeps the fixed-point map unchanged across solver steps, so the solver can converge. The mask is reset at the beginning of the next training iteration via reset_dropout.

Parameters:

dropout (float, optional) – The probability of an element being zeroed. Default: 0.5.

Shape:
  • Input: Tensor of any shape.

  • Output: Tensor of the same shape as input.

Examples

>>> import torch
>>> from torchdeq.dropout import VariationalDropout
>>> m = VariationalDropout(dropout=0.5)
>>> input = torch.randn(20, 16)
>>> output = m(input)

reset_mask(x)

Resets the dropout mask. Subclasses should implement this method according to the dimensionality of the input tensor.
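The contrast with PyTorch's built-in `torch.nn.Dropout`, which draws a fresh element-wise mask on every forward call, can be sketched as follows. The variational part is a plain-tensor illustration, not torchdeq's code. The 1/(1 - p) scaling is assumed by analogy with `nn.Dropout`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
p = 0.5
x = torch.randn(20, 16)  # same shape as in the example above

# Variational (element-wise) sketch: sample one mask for the whole iteration.
# Scaling kept entries by 1 / (1 - p) is an assumption borrowed from nn.Dropout.
mask = torch.empty_like(x).bernoulli_(1 - p) / (1 - p)
solver_outputs = [mask * x for _ in range(3)]  # e.g. three solver steps, same mask

# torch.nn.Dropout draws a fresh mask on every forward call in training mode.
std = nn.Dropout(p)
std_outputs = [std(x) for _ in range(3)]

var_consistent = all(torch.equal(o, solver_outputs[0]) for o in solver_outputs)
std_consistent = all(torch.equal(o, std_outputs[0]) for o in std_outputs)
print(var_consistent, std_consistent)  # True False
```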

class torchdeq.dropout.VariationalDropout1d(dropout=0.5, token_first=True)

Applies Variational Dropout to the input tensor.

During training, randomly zeros out entire channels of the input 1d tensor: each channel of each sample is dropped with probability ‘dropout’ at every sequence position, using a mask sampled from a Bernoulli distribution.

The channel/feature dimension of a 1d tensor is the \(*\) slice of \((B, L, *)\) for token_first=True, or \((B, *, L)\) for token_first=False.

The same mask is reused for every forward call within a training iteration. This keeps the fixed-point map unchanged across solver steps, so the solver can converge. The mask is reset at the beginning of the next training iteration via reset_dropout.

Parameters:
  • dropout (float, optional) – The probability of each channel being zeroed. Default: 0.5.

  • token_first (bool, optional) – If True, expects input tensor in shape \((B, L, D)\), otherwise expects \((B, D, L)\). Here, B is batch size, L is sequence length, and D is feature dimension. Default: True.

Shape:
  • Input: \((B, L, D)\) or \((B, D, L)\).

  • Output: \((B, L, D)\) or \((B, D, L)\) (same shape as input).

reset_mask(x)

Resets the dropout mask by sampling a new one shaped for the input tensor x.
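One way to realise channel-wise dropout for both layouts is to give the mask a singleton sequence axis, so that it broadcasts over all L positions. The helper below is a hypothetical plain-PyTorch sketch, not torchdeq's `reset_mask`. The 1/(1 - p) scaling is an assumption.

```python
import torch


def channel_mask_1d(x, p, token_first=True):
    """Sketch of a channel-wise variational mask for a 1d (sequence) tensor.

    token_first=True:  x is (B, L, D) -> mask (B, 1, D)
    token_first=False: x is (B, D, L) -> mask (B, D, 1)
    The singleton L axis broadcasts, so a dropped channel is zero at every position.
    """
    if token_first:
        B, _, D = x.shape
        shape = (B, 1, D)
    else:
        B, D, _ = x.shape
        shape = (B, D, 1)
    keep = torch.empty(shape, dtype=x.dtype, device=x.device).bernoulli_(1 - p)
    return keep / (1 - p)  # inverted-dropout scaling (assumed)


torch.manual_seed(0)
x = torch.randn(2, 5, 8)                    # (B, L, D) with token_first=True
mask = channel_mask_1d(x, p=0.5, token_first=True)
out = mask * x

# Each (sample, channel) pair is either zero at all L positions or kept everywhere.
dropped = (out == 0).all(dim=1)             # (B, D)
kept = (out != 0).all(dim=1)
print(mask.shape)  # torch.Size([2, 1, 8])
```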

class torchdeq.dropout.VariationalDropout2d(dropout=0.5, token_first=True)

Applies Variational Dropout to the input tensor.

During training, randomly zeros out entire channels of the input 2d tensor: each channel of each sample is dropped with probability ‘dropout’ at every spatial position, using a mask sampled from a Bernoulli distribution.

The channel/feature dimension of a 2d tensor is the \(*\) slice of \((B, H, W, *)\) for token_first=True, or \((B, *, H, W)\) for token_first=False.

During fixed-point solving, the same mask is applied at every solver step until convergence. The mask is reset at the beginning of the next training iteration via reset_dropout.

Parameters:
  • dropout (float, optional) – The probability of each channel being zeroed. Default: 0.5.

  • token_first (bool, optional) – If True, expects input tensor in shape \((B, H, W, D)\), otherwise expects \((B, D, H, W)\). Here, B is batch size, and D is feature dimension. Default: True.

Shape:
  • Input: \((B, H, W, D)\) or \((B, D, H, W)\).

  • Output: \((B, H, W, D)\) or \((B, D, H, W)\) (same shape as input).

reset_mask(x)

Resets the dropout mask by sampling a new one shaped for the input tensor x.
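PyTorch's `torch.nn.Dropout2d` also zeros whole channels, but it resamples its mask on every call. The hypothetical module below sketches the variational version: it caches a `(B, 1, 1, D)` or `(B, D, 1, 1)` mask until the mask is cleared. The class name, the `mask` attribute, and the 1/(1 - p) scaling are assumptions, not torchdeq's code.

```python
import torch
import torch.nn as nn


class CachedChannelDropout2d(nn.Module):
    """Hypothetical sketch: channel-wise variational dropout for 2d feature maps
    that keeps one mask until it is reset."""

    def __init__(self, dropout=0.5, token_first=True):
        super().__init__()
        self.dropout = dropout
        self.token_first = token_first
        self.mask = None

    def reset_mask(self, x):
        if self.token_first:          # x: (B, H, W, D) -> mask (B, 1, 1, D)
            shape = (x.shape[0], 1, 1, x.shape[-1])
        else:                         # x: (B, D, H, W) -> mask (B, D, 1, 1)
            shape = (x.shape[0], x.shape[1], 1, 1)
        keep = torch.empty(shape, dtype=x.dtype, device=x.device).bernoulli_(1 - self.dropout)
        self.mask = keep / (1 - self.dropout)  # inverted-dropout scaling (assumed)

    def forward(self, x):
        if not self.training or self.dropout == 0.0:
            return x
        if self.mask is None:
            self.reset_mask(x)
        return self.mask * x


torch.manual_seed(0)
layer = CachedChannelDropout2d(0.5, token_first=False)
x = torch.randn(2, 6, 4, 4)           # (B, D, H, W)
y1, y2 = layer(x), layer(x)           # both calls reuse the cached mask
same = torch.equal(y1, y2)

# Every channel map is dropped or kept as a whole.
flat = y1.flatten(2)                  # (B, D, H*W)
whole = bool(((flat == 0).all(dim=-1) | (flat != 0).all(dim=-1)).all())
```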

class torchdeq.dropout.VariationalDropout3d(dropout=0.5, token_first=True)

Applies Variational Dropout to the input tensor.

During training, randomly zeros out entire channels of the input 3d tensor: each channel of each sample is dropped with probability ‘dropout’ at every \((T, H, W)\) position, using a mask sampled from a Bernoulli distribution.

The channel/feature dimension of a 3d tensor is the \(*\) slice of \((B, T, H, W, *)\) for token_first=True, or \((B, *, T, H, W)\) for token_first=False.

During fixed-point solving, the same mask is applied at every solver step until convergence. The mask is reset at the beginning of the next training iteration via reset_dropout.

Parameters:
  • dropout (float, optional) – The probability of each channel being zeroed. Default: 0.5.

  • token_first (bool, optional) – If True, expects input tensor in shape \((B, T, H, W, D)\), otherwise expects \((B, D, T, H, W)\). Here, B is batch size, and D is feature dimension. Default: True.

Shape:
  • Input: \((B, T, H, W, D)\) or \((B, D, T, H, W)\).

  • Output: \((B, T, H, W, D)\) or \((B, D, T, H, W)\) (same shape as input).

reset_mask(x)

Resets the dropout mask by sampling a new one shaped for the input tensor x.
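The 1d, 2d, and 3d channel-wise variants differ only in how many spatial axes the mask must broadcast over. The hypothetical helper below sketches that pattern for any number of spatial axes. It sets every spatial axis of the mask to 1 and keeps B and D. This is a plain-PyTorch illustration under the assumption of 1/(1 - p) scaling, not torchdeq's implementation.

```python
import torch


def channel_mask_nd(x, p, token_first=True):
    """Sketch: channel-wise mask for a (B, *spatial, D) or (B, D, *spatial) tensor.

    All spatial axes (T, H, W for 3d inputs) are set to 1 so they broadcast,
    which drops each (sample, channel) pair across the whole volume.
    """
    B = x.shape[0]
    n_spatial = x.dim() - 2
    if token_first:
        shape = (B,) + (1,) * n_spatial + (x.shape[-1],)
    else:
        shape = (B, x.shape[1]) + (1,) * n_spatial
    keep = torch.empty(shape, dtype=x.dtype, device=x.device).bernoulli_(1 - p)
    return keep / (1 - p)  # inverted-dropout scaling (assumed)


torch.manual_seed(0)
x = torch.randn(2, 3, 4, 4, 8)                     # (B, T, H, W, D)
mask = channel_mask_nd(x, 0.5, token_first=True)
print(mask.shape)  # torch.Size([2, 1, 1, 1, 8])
```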

class torchdeq.dropout.VariationalDropToken1d(dropout=0.5, token_first=True)

Applies Variational Dropout to the input tensor.

During training, randomly zeros out entire tokens of the input 1d tensor: each token (sequence position) of each sample is dropped with probability ‘dropout’ together with all of its features, using a mask sampled from a Bernoulli distribution.

The token/sequence dimension of a 1d tensor is the \(*\) slice of \((B, *, D)\) for token_first=True, or \((B, D, *)\) for token_first=False.

During fixed-point solving, the same mask is applied at every solver step until convergence. The mask is reset at the beginning of the next training iteration via reset_dropout.

Parameters:
  • dropout (float, optional) – The probability of each token being zeroed. Default: 0.5.

  • token_first (bool, optional) – If True, expects input tensor in shape \((B, L, D)\), otherwise expects \((B, D, L)\). Here, B is batch size, and D is feature dimension. Default: True.

Shape:
  • Input: \((B, L, D)\) or \((B, D, L)\).

  • Output: \((B, L, D)\) or \((B, D, L)\) (same shape as input).

reset_mask(x)

Resets the dropout mask by sampling a new one shaped for the input tensor x.
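Token-wise dropout swaps which axis is singleton. The mask keeps the sequence axis and collapses the feature axis, so a dropped token loses all D features at once. The helper below is a hypothetical plain-PyTorch sketch, not torchdeq's code, and it assumes 1/(1 - p) scaling.

```python
import torch


def token_mask_1d(x, p, token_first=True):
    """Sketch of a token-wise variational mask for a 1d (sequence) tensor.

    token_first=True:  x is (B, L, D) -> mask (B, L, 1)
    token_first=False: x is (B, D, L) -> mask (B, 1, L)
    The singleton feature axis broadcasts, so a dropped token loses all D features.
    """
    if token_first:
        B, L, _ = x.shape
        shape = (B, L, 1)
    else:
        B, _, L = x.shape
        shape = (B, 1, L)
    keep = torch.empty(shape, dtype=x.dtype, device=x.device).bernoulli_(1 - p)
    return keep / (1 - p)  # inverted-dropout scaling (assumed)


torch.manual_seed(0)
x = torch.randn(2, 10, 4)                    # (B, L, D)
out = token_mask_1d(x, 0.3, token_first=True) * x
dropped_tokens = (out == 0).all(dim=-1)      # (B, L): True where the whole token is zero
kept_tokens = (out != 0).all(dim=-1)
```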

class torchdeq.dropout.VariationalDropToken2d(dropout=0.5, token_first=True)

Applies Variational Dropout to the input tensor.

During training, randomly zeros out entire tokens of the input 2d tensor: each token (spatial position \((h, w)\)) of each sample is dropped with probability ‘dropout’ together with all of its features, using a mask sampled from a Bernoulli distribution.

The token dimensions of a 2d tensor are the \(H, W\) axes, marked \(*\) in \((B, *, *, D)\) for token_first=True, or \((B, D, *, *)\) for token_first=False.

During fixed-point solving, the same mask is applied at every solver step until convergence. The mask is reset at the beginning of the next training iteration via reset_dropout.

Parameters:
  • dropout (float, optional) – The probability of each token being zeroed. Default: 0.5.

  • token_first (bool, optional) – If True, expects input tensor in shape \((B, H, W, D)\), otherwise expects \((B, D, H, W)\). Here, B is batch size, and D is feature dimension. Default: True.

Shape:
  • Input: \((B, H, W, D)\) or \((B, D, H, W)\).

  • Output: \((B, H, W, D)\) or \((B, D, H, W)\) (same shape as input).

reset_mask(x)

Resets the dropout mask by sampling a new one shaped for the input tensor x.
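The token_first flag should only change the layout, not which tokens are dropped. The hypothetical helper below draws one keep/drop decision per \((h, w)\) position and inserts the singleton feature axis where each layout expects it. With identically seeded CPU generators, both layouts drop the same positions. This is a plain-PyTorch sketch assuming 1/(1 - p) scaling, not torchdeq's implementation.

```python
import torch


def token_mask_2d(x, p, token_first=True, generator=None):
    """Sketch: one keep/drop decision per spatial position (h, w), shared by all D features.

    token_first=True:  x is (B, H, W, D) -> mask (B, H, W, 1)
    token_first=False: x is (B, D, H, W) -> mask (B, 1, H, W)
    """
    if token_first:
        B, H, W, _ = x.shape
    else:
        B, _, H, W = x.shape
    keep = torch.empty(B, H, W, dtype=x.dtype).bernoulli_(1 - p, generator=generator)
    mask = keep / (1 - p)  # inverted-dropout scaling (assumed)
    return mask.unsqueeze(-1) if token_first else mask.unsqueeze(1)


x_tf = torch.randn(2, 4, 5, 3)               # (B, H, W, D)
x_cf = x_tf.permute(0, 3, 1, 2)              # same data as (B, D, H, W)

m_tf = token_mask_2d(x_tf, 0.5, True, torch.Generator().manual_seed(0))
m_cf = token_mask_2d(x_cf, 0.5, False, torch.Generator().manual_seed(0))

# With the same random draws, both layouts drop exactly the same tokens.
same_tokens = torch.equal((m_tf * x_tf).permute(0, 3, 1, 2), m_cf * x_cf)
```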

class torchdeq.dropout.VariationalDropToken3d(dropout=0.5, token_first=True)

Applies Variational Dropout to the input tensor.

During training, randomly zeros out entire tokens of the input 3d tensor: each token (position \((t, h, w)\)) of each sample is dropped with probability ‘dropout’ together with all of its features, using a mask sampled from a Bernoulli distribution.

The token dimensions of a 3d tensor are the \(T, H, W\) axes, marked \(*\) in \((B, *, *, *, D)\) for token_first=True, or \((B, D, *, *, *)\) for token_first=False.

During fixed-point solving, the same mask is applied at every solver step until convergence. The mask is reset at the beginning of the next training iteration via reset_dropout.

Parameters:
  • dropout (float, optional) – The probability of each token being zeroed. Default: 0.5.

  • token_first (bool, optional) – If True, expects input tensor in shape \((B, T, H, W, D)\), otherwise expects \((B, D, T, H, W)\). Here, B is batch size, and D is feature dimension. Default: True.

Shape:
  • Input: \((B, T, H, W, D)\) or \((B, D, T, H, W)\).

  • Output: \((B, T, H, W, D)\) or \((B, D, T, H, W)\) (same shape as input).

reset_mask(x)

Resets the dropout mask by sampling a new one shaped for the input tensor x.
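For 3d inputs the same token-wise idea applies, with one decision per \((t, h, w)\) position. The hypothetical helper below builds the mask for any number of spatial axes by collapsing only the feature axis. It is a plain-PyTorch sketch assuming 1/(1 - p) scaling, not torchdeq's code.

```python
import torch


def token_mask_nd(x, p, token_first=True):
    """Sketch: token-wise mask for (B, *spatial, D) or (B, D, *spatial) inputs.

    For 3d inputs each token is one (t, h, w) position; the feature axis is set
    to 1 so a dropped token zeros all D features at that position.
    """
    if token_first:
        shape = tuple(x.shape[:-1]) + (1,)
    else:
        shape = (x.shape[0], 1) + tuple(x.shape[2:])
    keep = torch.empty(shape, dtype=x.dtype, device=x.device).bernoulli_(1 - p)
    return keep / (1 - p)  # inverted-dropout scaling (assumed)


torch.manual_seed(0)
x = torch.randn(2, 8, 3, 4, 4)                     # (B, D, T, H, W), token_first=False
mask = token_mask_nd(x, 0.25, token_first=False)    # (B, 1, T, H, W)
out = mask * x
```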