torchdeq.core

Deep Equilibrium (DEQ) models are a class of implicit models that make predictions by solving for fixed points. This module provides the core classes and functions for implementing DEQ models in PyTorch.
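Here is a minimal sketch of plain fixed-point iteration that shows what solving for a fixed point means. It works on the scalar contraction f(z, x) = 0.5 * z + x, whose fixed point is z* = 2x. The sketch uses only the Python standard library and does not call torchdeq. The helper name naive_fixed_point is hypothetical, and the absolute-residual stopping rule is an assumption chosen for illustration.

```python
def naive_fixed_point(f, x, z0=0.0, max_iter=40, tol=1e-3):
    """Iterate z <- f(z, x) until the absolute residual |f(z, x) - z| drops below tol."""
    z = z0
    for step in range(1, max_iter + 1):
        z_next = f(z, x)
        residual = abs(z_next - z)
        z = z_next
        if residual < tol:
            return z, step
    return z, max_iter


def f(z, x):
    # Contraction in z with Lipschitz constant 0.5, so the fixed point is z* = 2x.
    return 0.5 * z + x


z_star, n_steps = naive_fixed_point(f, x=1.0)
print(z_star, n_steps)  # 1.9990234375 11
```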

The main classes in this module are DEQBase, DEQIndexing, and DEQSliced. DEQBase is the base class for DEQ models, and DEQIndexing and DEQSliced are two specific implementations of DEQ models that use different strategies for applying gradients during training.

The module also provides utility functions for creating and manipulating DEQ models, such as get_deq for creating a DEQ model based on command line arguments, register_deq for registering a new DEQ model class, and reset_deq for resetting the normalization and dropout layers of a DEQ model.

Example

To create a DEQ model, you can use the get_deq function:

>>> deq = get_deq(args)

To reset the normalization and dropout layers of a DEQ model, you can use the reset_deq function:

>>> deq_layer = DEQLayer(args)          # A PyTorch module used as f in z* = f(z*, x).
>>> reset_deq(deq_layer)

Core Function

torchdeq.core.get_deq(args=None, **kwargs)

Factory function to generate an instance of a DEQ model based on the command line arguments.

This function returns an instance of a DEQ model class based on the DEQ computational core specified in the command line arguments args.core. For example, --core indexing for DEQIndexing, --core sliced for DEQSliced, etc.

DEQIndexing and DEQSliced build different computational graphs during training but the same graph at test time.

DEQIndexing defines a computational graph with tracked gradients by indexing the internal solver states and applying the gradient function to the sampled states. This is equivalent to attaching the gradient function alongside the full solver computational graph. The maximum number of DEQ function calls is defined by args.f_max_iter.

DEQSliced slices the full sequence of solver steps into several smaller graphs computed without gradient tracking. The gradient function is applied to the state returned by each subgraph, and a new fixed-point solver then resumes from the output of the gradient function. This is equivalent to inserting the gradient function into the full solver computational graph. The maximum number of DEQ function calls is, for example, args.f_max_iter + args.n_states * args.grad.

Parameters:
  • args (Union[argparse.Namespace, dict, DEQConfig, Any]) – Configuration of the DEQ model. Default None. This can be an instance of argparse.Namespace, a dictionary, or an instance of DEQConfig. Config of any other type is processed using the get_attr function.

  • **kwargs – Additional keyword arguments to update the config.
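A small normalization helper can illustrate the accepted forms of args. This is a sketch under assumptions. to_config_dict is a hypothetical name standing in for torchdeq's own DEQConfig and get_attr handling, which may differ. The helper also assumes that keyword arguments take precedence because they are applied last.

```python
import argparse


def to_config_dict(args=None, **kwargs):
    """Hypothetical helper: normalize `args` into a plain dict, then let kwargs update it."""
    if args is None:
        config = {}
    elif isinstance(args, argparse.Namespace):
        config = dict(vars(args))
    elif isinstance(args, dict):
        config = dict(args)
    else:
        # Fallback for other config objects: collect their public, non-callable attributes.
        config = {
            name: getattr(args, name)
            for name in dir(args)
            if not name.startswith('_') and not callable(getattr(args, name))
        }
    config.update(kwargs)  # keyword arguments are applied last, so they win on conflicts
    return config


cfg = to_config_dict(argparse.Namespace(core='sliced', f_max_iter=40), f_max_iter=20)
print(cfg)  # {'core': 'sliced', 'f_max_iter': 20}
```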

Returns:

DEQ module that defines the computational graph from the specified config.

Return type:

DEQBase (torch.nn.Module)

Example

To instantiate a DEQ module, you can directly pass keyword arguments to this function:

>>> deq = get_deq(core='sliced')

Alternatively, if you’re using a config system like argparse, you can pass the parsed config as a single object:

>>> args = argparse.Namespace(core='sliced')
>>> deq = get_deq(args)

torchdeq.core.reset_deq(model)

Resets the normalization and dropout layers of the given DEQ model (usually before each training iteration).

Parameters:

model (torch.nn.Module) – The DEQ model to reset.
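A fixed-point solver evaluates the same f many times, and the iteration only makes sense if f does not change between calls. This is one reason dropout needs resetting in a DEQ. A common remedy is to sample one dropout mask per training iteration and reuse it for every solver step. The sketch below shows that pattern with a hypothetical SharedMaskDropout class on plain Python lists. It is not torchdeq's implementation and does not cover normalization layers.

```python
import random


class SharedMaskDropout:
    """Hypothetical dropout that samples one mask and reuses it until reset() is called."""

    def __init__(self, p=0.5, seed=None):
        self.p = p
        self.mask = None
        self.rng = random.Random(seed)

    def reset(self):
        # Forget the cached mask so that the next call samples a new one.
        self.mask = None

    def __call__(self, z):
        if self.mask is None:
            keep = 1.0 - self.p
            # Inverted dropout: surviving entries are scaled by 1 / keep.
            self.mask = [1.0 / keep if self.rng.random() < keep else 0.0 for _ in z]
        return [m * v for m, v in zip(self.mask, z)]


drop = SharedMaskDropout(p=0.5, seed=0)
z = [1.0, 2.0, 3.0, 4.0]
assert drop(z) == drop(z)  # the same mask is used on every solver iteration
drop.reset()               # e.g. before the next training iteration
```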

Example

>>> deq_layer = DEQLayer(args)          # A PyTorch module used as f in z* = f(z*, x).
>>> reset_deq(deq_layer)

torchdeq.core.register_deq(deq_type, core)

Registers a user-defined DEQ class for the get_deq function.

This function adds a new entry to the DEQ class dict, with the specified DEQ type as the key and the DEQ class as the value.

Parameters:
  • deq_type (str) – The type of DEQ model to register. This will be used as the key in the DEQ class dict.

  • core (type) – The class defining the DEQ model. This will be used as the value in the DEQ class dict.

Example

>>> register_deq('custom', CustomDEQ)

DEQ Class

class torchdeq.core.DEQBase(args=None, f_solver='fixed_point_iter', b_solver='fixed_point_iter', no_stat=None, f_max_iter=40, b_max_iter=40, f_tol=0.001, b_tol=1e-06, f_stop_mode='abs', b_stop_mode='abs', eval_factor=1.0, eval_f_max_iter=0, **kwargs)

Base class for Deep Equilibrium (DEQ) model.

This class is not intended to be directly instantiated as the actual DEQ module. Instead, you should create an instance of a subclass of this class.

If you are looking to implement a new computational graph for DEQ models, you can inherit from this class. This allows you to leverage other components in the library in your implementation.

Parameters:
  • args (Union[argparse.Namespace, dict, DEQConfig, Any], optional) – Configuration for the DEQ model. This can be an instance of argparse.Namespace, a dictionary, or an instance of DEQConfig. Config of any other type is processed using the get_attr function. Priority: args > norm_kwargs. Default None.

  • f_solver (str, optional) – The forward solver function. Default solver is 'fixed_point_iter'.

  • b_solver (str, optional) – The backward solver function. Default solver is 'fixed_point_iter'.

  • no_stat (bool, optional) – Skips the solver stats computation if True. Default None.

  • f_max_iter (int, optional) – Maximum number of iterations (NFE) for the forward solver. Default 40.

  • b_max_iter (int, optional) – Maximum number of iterations (NFE) for the backward solver. Default 40.

  • f_tol (float, optional) – The forward pass solver stopping criterion. Default 1e-3.

  • b_tol (float, optional) – The backward pass solver stopping criterion. Default 1e-6.

  • f_stop_mode (str, optional) – The forward pass fixed-point convergence stop mode. Default 'abs'.

  • b_stop_mode (str, optional) – The backward pass fixed-point convergence stop mode. Default 'abs'.

  • eval_factor (float, optional) – Scaling factor for the maximum number of forward-solver iterations at test time, computed as f_max_iter * eval_factor. Default 1.0.

  • eval_f_max_iter (int, optional) – The maximum number of forward-solver iterations at test time, given as an exact number. Overrides eval_factor. Default 0.

  • **kwargs – Additional keyword arguments to update the configuration.
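The interplay between eval_factor and eval_f_max_iter can be written as a small function. This sketch makes two assumptions: eval_f_max_iter = 0 means "not set", and the scaled value is truncated to an integer. eval_max_iter is a hypothetical name, not a torchdeq function.

```python
def eval_max_iter(f_max_iter=40, eval_factor=1.0, eval_f_max_iter=0):
    """Hypothetical rule for the test-time forward-solver budget."""
    if eval_f_max_iter > 0:
        # An explicit iteration count overrides the scaling factor.
        return eval_f_max_iter
    # Otherwise scale the training budget (truncating to an integer is an assumption).
    return int(f_max_iter * eval_factor)


print(eval_max_iter(40, 1.5))       # 60
print(eval_max_iter(40, 1.5, 100))  # 100
```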

forward(func, z_init, solver_kwargs=None, sradius_mode=False, backward_writer=None, **kwargs)

Defines the computation graph and gradients of DEQ. Must be overridden in subclasses.

Parameters:
  • func (callable) – The DEQ function.

  • z_init (torch.Tensor) – Initial tensor for fixed point solver.

  • solver_kwargs (dict, optional) – Additional arguments for the solver used in this forward pass. These arguments will overwrite the default solver arguments. Refer to the documentation of the specific solver for the list of accepted arguments. Default None.

  • sradius_mode (bool, optional) – If True, computes the spectral radius in validation and adds 'sradius' to the info dictionary. Default False.

  • backward_writer (callable, optional) – Callable function to monitor the backward pass. It should accept the solver statistics dictionary as input. Default None.

Raises:

NotImplementedError – If the method is not overridden.
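A subclass supplies its own forward. The sketch below mimics that contract without PyTorch. A hypothetical SketchDEQBase stores solver settings and raises NotImplementedError. PlainIterDEQ overrides forward with naive fixed-point iteration and returns a list of states and a stats dict, mirroring the documented return shape. The 'rel' stop mode (residual divided by the norm of f(z)) and the 'nstep'/'residual' keys are assumptions for illustration. They do not describe torchdeq's internals.

```python
import math


class SketchDEQBase:
    """Hypothetical stand-in for DEQBase that only stores forward-solver settings."""

    def __init__(self, f_max_iter=40, f_tol=1e-3, f_stop_mode='abs'):
        self.f_max_iter = f_max_iter
        self.f_tol = f_tol
        self.f_stop_mode = f_stop_mode

    def forward(self, func, z_init, **kwargs):
        # Subclasses must define the computational graph.
        raise NotImplementedError


class PlainIterDEQ(SketchDEQBase):
    """Subclass whose forward() runs naive fixed-point iteration on lists of floats."""

    def _residual(self, z, fz):
        abs_diff = math.sqrt(sum((a - b) ** 2 for a, b in zip(fz, z)))
        if self.f_stop_mode == 'rel':
            # Assumed relative criterion: divide by the norm of f(z).
            return abs_diff / (1e-8 + math.sqrt(sum(a * a for a in fz)))
        return abs_diff

    def forward(self, func, z_init, **kwargs):
        z = list(z_init)
        info = {'nstep': 0, 'residual': float('inf')}
        for step in range(1, self.f_max_iter + 1):
            fz = func(z)
            info['nstep'], info['residual'] = step, self._residual(z, fz)
            z = fz
            if info['residual'] < self.f_tol:
                break
        # Mirror the documented return shape: (list of states, stats dict).
        return [z], info


deq = PlainIterDEQ(f_max_iter=40, f_tol=1e-3)
z_out, info = deq.forward(lambda z: [0.5 * v + 1.0 for v in z], [0.0, 0.0])
```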

class torchdeq.core.DEQIndexing(args=None, ift=False, hook_ift=False, grad=1, tau=1.0, sup_gap=-1, sup_loc=None, n_states=1, indexing=None, **kwargs)

DEQ computational graph that samples fixed point states at specific indices.

DEQIndexing defines a computational graph with tracked gradients by indexing the internal solver states and applying the gradient function to the sampled states. This is equivalent to attaching the gradient function alongside the full solver computational graph. The maximum number of DEQ function calls is defined by args.f_max_iter.
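The indexing strategy can be sketched on a scalar problem. First run the solver for its full budget while recording the trajectory. Then apply the gradient function to states picked at chosen steps. The solver never resumes from those outputs, so its call count stays at f_max_iter. Two parts of the sketch are assumptions: the even-spacing rule in uniform_indices, and using plain f calls as the gradient function. Both helper names are hypothetical.

```python
def uniform_indices(f_max_iter, n_states):
    """Hypothetical rule: n_states evenly spaced solver steps, the last one at f_max_iter."""
    return [f_max_iter * (i + 1) // n_states for i in range(n_states)]


def indexing_forward(f, z0, f_max_iter=40, n_states=2, grad=1):
    """Hypothetical indexing graph: full solver run, then the gradient function on sampled states."""
    trajectory = [z0]  # trajectory[k] is the state after k solver steps
    z, solver_calls = z0, 0
    for _ in range(f_max_iter):  # solver phase (run without gradient tracking in practice)
        z = f(z)
        solver_calls += 1
        trajectory.append(z)
    sampled, grad_calls = [], 0
    for k in uniform_indices(f_max_iter, n_states):
        s = trajectory[k]
        for _ in range(grad):  # gradient function attached alongside the sampled state
            s = f(s)
            grad_calls += 1
        sampled.append(s)
    # The solver never resumes from `sampled`, so its trajectory is unaffected.
    return sampled, solver_calls, grad_calls


sampled, solver_calls, grad_calls = indexing_forward(lambda z: 0.5 * z + 1.0, 0.0)
print(uniform_indices(40, 2), solver_calls, grad_calls)  # [20, 40] 40 2
```

With f_max_iter=40 and n_states=2, the sampled steps are 20 and 40 and the solver makes 40 calls. The gradient function adds 2 more calls alongside the trajectory.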

Parameters:
  • args (Union[argparse.Namespace, dict, DEQConfig, Any], optional) – Configuration for the DEQ model. This can be an instance of argparse.Namespace, a dictionary, or an instance of DEQConfig. Config of any other type is processed using the get_attr function. Priority: args > norm_kwargs. Default None.

  • f_solver (str, optional) – The forward solver function. Default 'fixed_point_iter'.

  • b_solver (str, optional) – The backward solver function. Default 'fixed_point_iter'.

  • no_stat (bool, optional) – Skips the solver stats computation if True. Default None.

  • f_max_iter (int, optional) – Maximum number of iterations (NFE) for the forward solver. Default 40.

  • b_max_iter (int, optional) – Maximum number of iterations (NFE) for the backward solver. Default 40.

  • f_tol (float, optional) – The forward pass solver stopping criterion. Default 1e-3.

  • b_tol (float, optional) – The backward pass solver stopping criterion. Default 1e-6.

  • f_stop_mode (str, optional) – The forward pass fixed-point convergence stop mode. Default 'abs'.

  • b_stop_mode (str, optional) – The backward pass fixed-point convergence stop mode. Default 'abs'.

  • eval_factor (float, optional) – Scaling factor for the maximum number of forward-solver iterations at test time, computed as f_max_iter * eval_factor. Default 1.0.

  • eval_f_max_iter (int, optional) – The maximum number of forward-solver iterations at test time, given as an exact number. Overrides eval_factor. Default 0.

  • ift (bool, optional) – If True, enables implicit differentiation based on the Implicit Function Theorem (IFT). Default False.

  • hook_ift (bool, optional) – If True, enables a PyTorch backward-hook implementation of IFT, which further reduces memory usage but may affect stability. Default False.

  • grad (Union[int, list[int], tuple[int]], optional) – Specifies the steps of PhantomGrad. Multiple values can be given to use different gradient steps for the sampled trajectory states. Default 1.

  • tau (float, optional) – Damping factor for PhantomGrad. Default 1.0.

  • sup_gap (int, optional) – The gap for uniformly sampling trajectory states from PhantomGrad. If sup_gap > 0, one state is sampled every sup_gap steps. Default -1.

  • sup_loc (list[int], optional) – Specifies trajectory steps or locations in PhantomGrad from which to sample. Default None.

  • n_states (int, optional) – Number of trajectory states to sample uniformly from the solver. Gradients of the sampled states are tracked automatically. When ift=True, IFT is applied to the best fixed-point estimate, while the internal states are tracked by PhantomGrad. Default 1, in which case only the best fixed-point estimate is returned.

  • indexing (int, optional) – Samples trajectory states from the solver at the specific steps given by indexing. Similar to n_states but more flexible. Default None.

  • **kwargs – Additional keyword arguments to update the configuration.
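The sketch below shows the forward computation of the gradient function. It assumes PhantomGrad unrolls the damped update z <- tau * f(z) + (1 - tau) * z for grad steps, starting from the solver's estimate. phantom_unroll is a hypothetical name. The sketch also omits the gradient tracking that makes these steps useful in practice.

```python
def phantom_unroll(f, z_star, grad=1, tau=1.0):
    """Apply `grad` damped steps z <- tau * f(z) + (1 - tau) * z, starting at z_star.

    In a real DEQ, z_star would come from the solver without gradient tracking, and
    only these unrolled steps would be recorded for backpropagation.
    """
    z = z_star
    for _ in range(grad):
        z = tau * f(z) + (1.0 - tau) * z
    return z


def f(z):
    return 0.5 * z + 1.0


print(phantom_unroll(f, 0.0, grad=1, tau=0.5))  # 0.5
```

At an exact fixed point, every damped step returns z* unchanged. The unrolled steps therefore do not move the forward value. They only define the path that gradients flow through.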

arg_indexing

Define gradient functions through the backward factory.

forward(func, z_init, solver_kwargs=None, sradius_mode=False, backward_writer=None, **kwargs)

Defines the computation graph and gradients of DEQ.

This method carries out the forward pass computation for the DEQ model, by solving for the fixed point. During training, it also keeps track of the trajectory of the solution. In inference mode, it returns the final fixed point.

Parameters:
  • func (callable) – The DEQ function.

  • z_init (torch.Tensor) – Initial tensor for fixed point solver.

  • solver_kwargs (dict, optional) – Additional arguments for the solver used in this forward pass. These arguments will overwrite the default solver arguments. Refer to the documentation of the specific solver for the list of accepted arguments. Default None.

  • sradius_mode (bool, optional) – If True, computes the spectral radius in validation and adds 'sradius' to the info dictionary. Default False.

  • backward_writer (callable, optional) – Callable function to monitor the backward pass. It should accept the solver statistics dictionary as input. Default None.

Returns:

A tuple containing the following:

  • list[torch.Tensor]:
    During training, returns the sampled fixed point trajectory (tracked gradients) according to n_states or indexing.
    During inference, returns a list containing the fixed point solution only.
  • dict[str, torch.Tensor]:
    A dict containing solver statistics for the batch. See torchdeq.solver.stat.SolverStat for more details.

Return type:

tuple[list[torch.Tensor], dict[str, torch.Tensor]]

class torchdeq.core.DEQSliced(args=None, ift=False, hook_ift=False, grad=1, tau=1.0, sup_gap=-1, sup_loc=None, n_states=1, indexing=None, **kwargs)

DEQ computational graph that slices the full solver trajectory to apply gradients.

DEQSliced slices the full sequence of solver steps into several smaller graphs computed without gradient tracking. The gradient function is applied to the state returned by each subgraph, and a new fixed-point solver then resumes from the output of the gradient function. This is equivalent to inserting the gradient function into the full solver computational graph. The maximum number of DEQ function calls is, for example, args.f_max_iter + args.n_states * args.grad.
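A scalar sketch can check the call count for DEQSliced. Split the solver budget into n_states segments and run grad gradient-function steps after each segment. Each new segment starts from the output of those steps. The segment-splitting rule and the helper name sliced_forward are hypothetical, and the sketch omits gradient tracking.

```python
def sliced_forward(f, z0, f_max_iter=40, n_states=2, grad=1):
    """Hypothetical sliced graph: split the solver budget into n_states segments and run
    `grad` gradient-function steps after each one, resuming the solver from their output."""
    budget = [f_max_iter // n_states] * n_states
    budget[-1] += f_max_iter - sum(budget)  # the last segment absorbs any remainder
    z, calls, states = z0, 0, []
    for seg in budget:
        for _ in range(seg):   # solver segment (no gradient tracking in practice)
            z = f(z)
            calls += 1
        for _ in range(grad):  # gradient function inserted into the solver graph
            z = f(z)
            calls += 1
        states.append(z)       # the next segment starts from this state
    return states, calls


states, calls = sliced_forward(lambda z: 0.5 * z + 1.0, 0.0)
print(calls)  # 42, i.e. f_max_iter + n_states * grad = 40 + 2 * 1
```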

Parameters:
  • args (Union[argparse.Namespace, dict, DEQConfig, Any], optional) – Configuration for the DEQ model. This can be an instance of argparse.Namespace, a dictionary, or an instance of DEQConfig. Config of any other type is processed using the get_attr function. Priority: args > norm_kwargs. Default None.

  • f_solver (str, optional) – The forward solver function. Default 'fixed_point_iter'.

  • b_solver (str, optional) – The backward solver function. Default 'fixed_point_iter'.

  • no_stat (bool, optional) – Skips the solver stats computation if True. Default None.

  • f_max_iter (int, optional) – Maximum number of iterations (NFE) for the forward solver. Default 40.

  • b_max_iter (int, optional) – Maximum number of iterations (NFE) for the backward solver. Default 40.

  • f_tol (float, optional) – The forward pass solver stopping criterion. Default 1e-3.

  • b_tol (float, optional) – The backward pass solver stopping criterion. Default 1e-6.

  • f_stop_mode (str, optional) – The forward pass fixed-point convergence stop mode. Default 'abs'.

  • b_stop_mode (str, optional) – The backward pass fixed-point convergence stop mode. Default 'abs'.

  • eval_factor (float, optional) – Scaling factor for the maximum number of forward-solver iterations at test time, computed as f_max_iter * eval_factor. Default 1.0.

  • eval_f_max_iter (int, optional) – The maximum number of forward-solver iterations at test time, given as an exact number. Overrides eval_factor. Default 0.

  • ift (bool, optional) – If True, enables implicit differentiation based on the Implicit Function Theorem (IFT). Default False.

  • hook_ift (bool, optional) – If True, enables a PyTorch backward-hook implementation of IFT, which further reduces memory usage but may affect stability. Default False.

  • grad (Union[int, list[int], tuple[int]], optional) – Specifies the steps of PhantomGrad. Multiple values can be given to use different gradient steps for the sampled trajectory states. Default 1.

  • tau (float, optional) – Damping factor for PhantomGrad. Default 1.0.

  • sup_gap (int, optional) – The gap for uniformly sampling trajectory states from PhantomGrad. If sup_gap > 0, one state is sampled every sup_gap steps. Default -1.

  • sup_loc (list[int], optional) – Specifies trajectory steps or locations in PhantomGrad from which to sample. Default None.

  • n_states (int, optional) – Number of trajectory states to sample uniformly from the solver. Gradients of the sampled states are tracked automatically. When ift=True, IFT is applied to the best fixed-point estimate, while the internal states are tracked by PhantomGrad. Default 1, in which case only the best fixed-point estimate is returned.

  • indexing (int, optional) – Samples trajectory states from the solver at the specific steps given by indexing. Similar to n_states but more flexible. Default None.

  • **kwargs – Additional keyword arguments to update the configuration.

arg_indexing

Define gradient functions through the backward factory.

forward(func, z_star, solver_kwargs=None, sradius_mode=False, backward_writer=None, **kwargs)

Defines the computation graph and gradients of DEQ.

Parameters:
  • func (callable) – The DEQ function.

  • z_star (torch.Tensor) – Initial tensor for the fixed-point solver.

  • solver_kwargs (dict, optional) – Additional arguments for the solver used in this forward pass. These arguments will overwrite the default solver arguments. Refer to the documentation of the specific solver for the list of accepted arguments. Default None.

  • sradius_mode (bool, optional) – If True, computes the spectral radius in validation and adds 'sradius' to the info dictionary. Default False.

  • backward_writer (callable, optional) – Callable function to monitor the backward pass. It should accept the solver statistics dictionary as input. Default None.

Returns:

A tuple containing the following:

  • list[torch.Tensor]:
    During training, returns the sampled fixed point trajectory (tracked gradients) according to n_states or indexing.
    During inference, returns a list containing the fixed point solution only.
  • dict[str, torch.Tensor]:
    A dict containing solver statistics for the batch. See torchdeq.solver.stat.SolverStat for more details.

Return type:

tuple[list[torch.Tensor], dict[str, torch.Tensor]]