torchdeq.loss

Correction

torchdeq.loss.fp_correction(crit, args, weight_func='exp', return_loss_values=False, **kwargs)

Computes fixed-point correction for stabilizing Deep Equilibrium (DEQ) models.

Fixed-point correction applies the loss function to a sequence of tensors that converges to the fixed point. The loss value of each tensor tuple is weighted by the chosen weight function. This function automatically aligns the input arguments to the same length.

The currently supported weight functions include 'const' (constant), 'linear', and 'exp' (exponential).

Parameters:
  • crit (callable) – Loss function. Can be an instance of torch.nn.Module or any callable functor.

  • args (list or tuple) – List of arguments to pass to the criterion.

  • weight_func (str, optional) – Name of the weight function to use. Default ‘exp’.

  • return_loss_values (bool, optional) – Whether to return the loss values. Default False.

  • **kwargs – Additional keyword arguments for the weight function.

Returns:

The computed loss (torch.Tensor). If return_loss_values is True, a list[float] of the individual loss values is returned as well.

Return type:

torch.Tensor

Examples

>>> import torch
>>> from torchdeq.loss import fp_correction
>>> x = [torch.randn(16, 32, 32) for _ in range(3)]
>>> y = torch.randn(16, 32, 32)
>>> mask = torch.rand(16, 32, 32)
>>> crit = lambda x, y, mask: ((x - y) * mask).abs().mean()
>>> loss = fp_correction(crit, (x, y, mask))
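
When return_loss_values is set to True, the call also returns the per-step loss values. A minimal sketch reusing crit, x, y, and mask from above; the unpacking order follows the Returns description:

>>> loss, loss_values = fp_correction(crit, (x, y, mask), return_loss_values=True)
>>> # loss_values is expected to hold one float per tensor in the sequence x
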
torchdeq.loss.register_weight_func(name, func)

Registers a new weight function for fixed point correction.

The weight function should map a pair of integers (n, k) to a float that serves as the loss weight, where 'n' is the total length of the sequence that converges to the fixed point and 'k' is the position of the current state in that sequence (see the sketch below).

Parameters:
  • name (str) – Identifier to associate with the new weight function.

  • func (callable) – The weight function to register, mapping (n, k) to a float value.

Raises:

AssertionError – If func is not callable.
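
As a sketch of the (n, k) contract above, a custom weight function can be registered and then selected by name in fp_correction; the 'inv' name and the weighting scheme are illustrative only, not part of the library (crit, x, y, and mask as in the earlier example):

>>> from torchdeq.loss import register_weight_func, fp_correction
>>> # Hypothetical weighting: later states in the sequence receive larger weights
>>> register_weight_func('inv', lambda n, k: (k + 1) / n)
>>> loss = fp_correction(crit, (x, y, mask), weight_func='inv')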

Jacobian

torchdeq.loss.jac_reg(f0, z0, vecs=1, create_graph=True)

Estimates tr(J^T J) = tr(J J^T) via the Hutchinson estimator, where J is the Jacobian of the function f at z0.

Parameters:
  • f0 (torch.Tensor) – Output of the function f whose Jacobian is to be analyzed, i.e., f0 = f(z0).

  • z0 (torch.Tensor) – Input to the function f.

  • vecs (int, optional) – Number of random Gaussian vectors used in the Hutchinson estimate. Defaults to 1.

  • create_graph (bool, optional) – Whether to create the backward graph (e.g., to train on this loss). Defaults to True.

Returns:

A 1x1 torch.Tensor that encodes the (shape-normalized) Jacobian loss.

Return type:

torch.Tensor
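
A minimal usage sketch, assuming a simple stand-in for the equilibrium function (the linear layer, shapes, and the idea of weighting the regularizer are illustrative only):

>>> import torch
>>> from torchdeq.loss import jac_reg
>>> func = torch.nn.Linear(64, 64)                # stand-in for the DEQ function f
>>> z0 = torch.randn(16, 64, requires_grad=True)  # input to f
>>> f0 = func(z0)                                 # output of f at z0
>>> jac_loss = jac_reg(f0, z0)                    # Hutchinson estimate of tr(J^T J)
>>> # In training, jac_loss is typically added to the task loss with a small weight.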

torchdeq.loss.power_method(f0, z0, n_iters=100)

Estimates the spectral radius of J using the power method, where J is the Jacobian of the function f at z0.

Parameters:
  • f0 (torch.Tensor) – Output of the function f whose Jacobian is to be analyzed, i.e., f0 = f(z0).

  • z0 (torch.Tensor) – Input to the function f.

  • n_iters (int, optional) – Number of power method iterations. Default is 100.

Returns:

A tuple of (dominant eigenvector, eigenvalue with the largest absolute value).

Return type:

tuple
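
A minimal usage sketch, again with an illustrative stand-in for the equilibrium function (the linear layer and shapes are not part of the library):

>>> import torch
>>> from torchdeq.loss import power_method
>>> func = torch.nn.Linear(64, 64)                # stand-in for the DEQ function f
>>> z0 = torch.randn(16, 64, requires_grad=True)  # input to f
>>> f0 = func(z0)                                 # output of f at z0
>>> eigvec, eigval = power_method(f0, z0, n_iters=50)  # dominant eigen-pair of J at z0

A spectral radius (absolute value of the returned eigenvalue) below 1 indicates that the update is locally contractive around the fixed point.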