torchdeq.utils

Config

torchdeq.utils.add_deq_args(parser)

Adds the command-line arguments used by TorchDEQ to an existing argument parser.

Parameters:

parser (argparse.ArgumentParser) – Argument parser to which the TorchDEQ command-line arguments are added.
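The parser must receive TorchDEQ's options before parse_args() is called. The sketch below shows that call order with plain argparse. add_demo_deq_args and its --demo_* flags are hypothetical stand-ins, because this page does not list the exact options that add_deq_args registers.

```python
import argparse


def add_demo_deq_args(parser: argparse.ArgumentParser) -> None:
    # Hypothetical stand-in for torchdeq.utils.add_deq_args. The flag names
    # and defaults below are illustrative, not TorchDEQ's actual options.
    parser.add_argument("--demo_max_iter", type=int, default=30,
                        help="maximum solver iterations (illustrative)")
    parser.add_argument("--demo_tol", type=float, default=1e-3,
                        help="solver stopping tolerance (illustrative)")


# Build the project's own parser first.
parser = argparse.ArgumentParser(description="training script")
parser.add_argument("--lr", type=float, default=1e-3)

# With TorchDEQ installed, this line would be add_deq_args(parser).
add_demo_deq_args(parser)

# Parse an explicit list so the example runs without a real command line.
args = parser.parse_args(["--demo_max_iter", "50"])
print(args.lr, args.demo_max_iter, args.demo_tol)
```

Project-specific arguments and the added arguments end up in the same Namespace, so downstream code can read both from args.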

Memory

torchdeq.utils.mem_gc(func, in_args=None)

Performs the forward and backward pass of a PyTorch Module using gradient checkpointing.

This function is designed for iterative computational graphs and for training with PyTorch DistributedDataParallel (DDP). In the forward pass, it stores no intermediate activations. During the backward pass, it first recomputes the activations and then applies the vector-Jacobian product (VJP) to compute gradients with respect to the inputs.

The function automatically tracks gradients for the parameters and input tensors that require gradients. It is particularly useful for creating computational graphs with constant memory complexity, i.e., \(\mathcal{O}(1)\) memory.

Parameters:
  • func (torch.nn.Module) – PyTorch Module for which gradients will be computed.

  • in_args (tuple, optional) – Input arguments for the function. Default None.

Returns:

The output of the func Module.

Return type:

tuple
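The recompute-then-VJP scheme described above can be sketched with a custom torch.autograd.Function. This is a simplified illustration, not TorchDEQ's implementation: _RecomputeVJP and checkpointed_call are hypothetical names, the sketch takes a single tensor input instead of an in_args tuple, and it only differentiates parameters whose requires_grad is True. PyTorch's built-in torch.utils.checkpoint.checkpoint provides a similar, more general mechanism.

```python
import torch
import torch.nn as nn


class _RecomputeVJP(torch.autograd.Function):
    """Hypothetical sketch of gradient checkpointing for one module call."""

    @staticmethod
    def forward(ctx, module, x, *params):
        # Autograd is disabled inside Function.forward, so no intermediate
        # activations are recorded; only the input and the module are kept.
        ctx.module = module
        ctx.save_for_backward(x)
        return module(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        params = [p for p in ctx.module.parameters() if p.requires_grad]
        # Recompute the forward pass, this time building a local graph.
        with torch.enable_grad():
            x_local = x.detach().requires_grad_(True)
            out = ctx.module(x_local)
        # Vector-Jacobian product of grad_out with d(out)/d(x, params).
        grads = torch.autograd.grad(out, [x_local] + params, grad_out,
                                    allow_unused=True)
        grad_x = grads[0] if ctx.needs_input_grad[1] else None
        # One return value per forward input: module, x, *params.
        return (None, grad_x, *grads[1:])


def checkpointed_call(module: nn.Module, x: torch.Tensor) -> torch.Tensor:
    # Pass trainable parameters explicitly so autograd routes their gradients.
    params = [p for p in module.parameters() if p.requires_grad]
    return _RecomputeVJP.apply(module, x, *params)


torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 4))
x = torch.randn(3, 4, requires_grad=True)
# Gradients land in x.grad and in the .grad fields of net's parameters.
checkpointed_call(net, x).sum().backward()
```

Because the forward pass builds no graph, the module's intermediate activations exist only transiently during backward, at the cost of one extra forward pass.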

Init

torchdeq.utils.mixed_init(z_shape, device=None)

Initializes a tensor of shape z_shape in which half of the values are Gaussian random values and the other half are zeros.

This initialization was proposed in the paper "Path Independent Equilibrium Models Can Better Exploit Test-Time Computation" to improve path independence.

Parameters:
  • z_shape (tuple) – Shape of the tensor to be initialized.

  • device (torch.device, optional) – The desired device of returned tensor. Default None.

Returns:

A tensor of shape z_shape with values randomly initialized and zero masked.

Return type:

torch.Tensor
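A minimal sketch of this initialization follows. It assumes that the zero mask is sampled independently for each entry with probability 0.5, so "half" holds in expectation rather than exactly. mixed_init_sketch is a hypothetical name, and this is not TorchDEQ's source.

```python
import torch


def mixed_init_sketch(z_shape, device=None):
    # Hypothetical re-implementation for illustration only.
    z = torch.randn(*z_shape, device=device)  # standard Gaussian values
    # Keep each entry with probability 0.5; the rest become exact zeros.
    mask = torch.bernoulli(torch.full(z_shape, 0.5, device=device))
    return z * mask


torch.manual_seed(0)
z0 = mixed_init_sketch((2, 16, 8, 8))
```

A tensor like z0 could serve as the initial state of the fixed-point solver in place of an all-zero initialization.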