torchdeq.solver

The torchdeq.solver module provides a set of solvers for finding fixed points z* = f(z*) in Deep Equilibrium Models (DEQs). Starting from an initial estimate, these solvers iteratively refine the model's hidden state until applying f (approximately) no longer changes it, that is, until it reaches a stable state, or “equilibrium”.

This module includes implementations of several popular fixed-point solvers, including Anderson acceleration (anderson_solver), Broyden’s method (broyden_solver), and fixed-point iteration (fixed_point_iter). It also provides a faster version of fixed-point iteration (simple_fixed_point_iter) that omits convergence monitoring for speed improvements.

The get_solver function allows users to retrieve a specific solver by its key, and the register_solver function allows users to add their own custom solvers to the module.

Example

To retrieve a solver, call get_solver:

>>> solver = get_solver('anderson')

To register a user-developed solver, call register_solver:

>>> register_solver('newton', newton_solver)

Solver Function

torchdeq.solver.get_solver(key)

Retrieves a fixed point solver from the registered solvers by its key.

Supported solvers: 'anderson', 'broyden', 'fixed_point_iter', 'simple_fixed_point_iter'.

Parameters:

key (str) – The key of the solver to retrieve. This should match one of the keys used to register a solver.

Returns:

The solver function associated with the provided key.

Return type:

callable

Raises:

AssertionError – If the key does not match any of the registered solvers.

Example

>>> solver = get_solver('anderson')

torchdeq.solver.register_solver(solver_type, solver)

Registers a user-defined fixed point solver. Once registered, the solver can be selected by its key through args.f_solver (forward solver) and args.b_solver (backward solver).

This function adds a new entry to the solver dict, with the specified solver_type as the key and the solver as the value.

Parameters:
  • solver_type (str) – The type of solver to register. This will be used as the key in the solver dict.

  • solver (callable) – The solver function. This will be used as the value in the solver dict.

Example

>>> register_solver('newton', newton_solver)
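The register_solver example assumes that a `newton_solver` already exists. The stdlib-only sketch below shows one way the pieces could fit together: a dict-based registry mirroring the documented behavior (key = solver_type, value = solver, unknown keys raise AssertionError) and a hypothetical scalar Newton solver that follows the calling convention of the built-in solvers, namely `(func, x0, max_iter, tol, indexing, ...)`, extra keyword arguments ignored, and a `(solution, stored_iterates, stats)` tuple returned. `_SOLVERS`, `register_solver_sketch`, `get_solver_sketch`, and `newton_solver` are illustrative names, not part of torchdeq; the stats keys are placeholders, and the sketch works on Python floats rather than batched tensors.

```python
import math
from typing import Callable, Dict, Iterable, List, Optional, Tuple

# Hypothetical stand-in for the module's solver dict (not torchdeq's actual code).
_SOLVERS: Dict[str, Callable] = {}


def register_solver_sketch(solver_type: str, solver: Callable) -> None:
    """Add (or overwrite) an entry: key = solver_type, value = solver."""
    _SOLVERS[solver_type] = solver


def get_solver_sketch(key: str) -> Callable:
    """Look up a registered solver; an unknown key raises AssertionError."""
    assert key in _SOLVERS, f"Unknown solver key: {key!r}"
    return _SOLVERS[key]


def newton_solver(
    func: Callable[[float], float],
    x0: float,
    max_iter: int = 50,
    tol: float = 1e-3,
    indexing: Optional[Iterable[int]] = None,
    eps: float = 1e-6,
    **kwargs,
) -> Tuple[float, List[float], dict]:
    """Hypothetical scalar Newton solver for g(z) = func(z) - z = 0.

    Follows the (solution, stored_iterates, stats) return convention and
    ignores unknown keyword arguments. The derivative g'(z) = func'(z) - 1
    is approximated by a central finite difference with step eps.
    """
    wanted = set(indexing or [])
    z, stored, nstep = x0, [], 0
    residual = abs(func(z) - z)
    while residual >= tol and nstep < max_iter:
        dg = (func(z + eps) - func(z - eps)) / (2 * eps) - 1.0
        z = z - (func(z) - z) / dg  # Newton step on g
        nstep += 1
        if nstep in wanted:
            stored.append(z)  # keep iterates at the requested step indices
        residual = abs(func(z) - z)
    # Illustrative stats only; a real torchdeq solver should follow SolverStat.
    return z, stored, {"abs_residual": residual, "nstep": nstep}


register_solver_sketch("newton", newton_solver)
solver = get_solver_sketch("newton")
z_star, _, stats = solver(math.cos, 0.0, tol=1e-10)
print(z_star, stats["nstep"])  # z_star is approximately 0.7390851332
```

A plain dict keeps registration and lookup O(1), and re-registering a key simply replaces the earlier entry.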

Solver

torchdeq.solver.fp_iter.fixed_point_iter(func, x0, max_iter=50, tol=0.001, stop_mode='abs', indexing=None, tau=1.0, return_final=False, **kwargs)

Implements the fixed-point iteration solver for solving a system of nonlinear equations.

Parameters:
  • func (callable) – The function for which we seek a fixed point.

  • x0 (torch.Tensor) – The initial guess for the fixed point.

  • max_iter (int, optional) – The maximum number of iterations. Default: 50.

  • tol (float, optional) – Tolerance for the stopping criterion. Default: 1e-3.

  • stop_mode (str, optional) – The stopping criterion. Can be either ‘abs’ or ‘rel’. Default: ‘abs’.

  • indexing (list, optional) – List of iteration indices at which to store the solution. Default: None.

  • tau (float, optional) – Damping factor. It is used to control the step size in the direction of the solution. Default: 1.0.

  • return_final (bool, optional) – If True, runs all steps and returns the final solution instead of the one with the smallest residual. Default: False.

  • kwargs (dict, optional) – Extra arguments are ignored.

Returns:

a tuple containing the following.
  • torch.Tensor: Fixed point solution.

  • list[torch.Tensor]: List of the solutions at the specified iteration indices.

  • dict[str, torch.Tensor]:

    A dict containing solver statistics in a batch. Please see torchdeq.solver.stat.SolverStat for more details.

Return type:

tuple[torch.Tensor, list[torch.Tensor], dict[str, torch.Tensor]]

Examples

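>>> import torch
>>> from torchdeq.solver.fp_iter import fixed_point_iter
>>> f = lambda z: torch.cos(z)                  # Function for which we seek a fixed point
>>> z0 = torch.zeros(1, 1)                      # Initial estimate with a batch dimension (B = 1)
>>> z_star, _, _ = fixed_point_iter(f, z0)      # Run fixed-point iterations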
>>> print((z_star - f(z_star)).norm(p=1))       # Print the numerical error
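To make the parameters concrete, here is a stdlib-only sketch of fixed-point iteration for a single unbatched vector stored as a Python list. It is not torchdeq's implementation: the damped update z <- tau * f(z) + (1 - tau) * z, the Euclidean norm, the 1e-8 stabilizer in the relative error, the 0-based step count, and the name `fixed_point_iter_sketch` are assumptions. It shows how stop_mode selects the error that is compared with tol, how return_final disables early stopping, and how indexing collects intermediate iterates.

```python
import math
from typing import Callable, List, Optional, Sequence, Tuple

Vector = List[float]


def _norm(v: Sequence[float]) -> float:
    return math.sqrt(sum(x * x for x in v))


def fixed_point_iter_sketch(
    func: Callable[[Vector], Vector],
    x0: Vector,
    max_iter: int = 50,
    tol: float = 1e-3,
    stop_mode: str = "abs",
    indexing: Optional[Sequence[int]] = None,
    tau: float = 1.0,
    return_final: bool = False,
) -> Tuple[Vector, List[Vector], dict]:
    """Damped fixed-point iteration on a single (unbatched) vector."""
    assert stop_mode in ("abs", "rel")
    wanted = set(indexing or [])
    z, stored = list(x0), []
    abs_trace: List[float] = []
    rel_trace: List[float] = []
    best_z, best_err, best_step = z, math.inf, 0
    for k in range(max_iter):
        fz = func(z)
        # Residual of the current iterate: abs = ||z - f(z)||, rel = abs / ||f(z)||.
        abs_err = _norm([a - b for a, b in zip(z, fz)])
        rel_err = abs_err / (_norm(fz) + 1e-8)  # small constant avoids division by zero
        abs_trace.append(abs_err)
        rel_trace.append(rel_err)
        err = abs_err if stop_mode == "abs" else rel_err
        if err < best_err:
            best_z, best_err, best_step = z, err, k
        if err < tol and not return_final:
            break  # stop early only when the best iterate will be returned
        # Assumed damped update: z_{k+1} = tau * f(z_k) + (1 - tau) * z_k.
        z = [tau * b + (1.0 - tau) * a for a, b in zip(z, fz)]
        if k + 1 in wanted:
            stored.append(z)
    stats = {
        "abs_lowest": min(abs_trace),
        "rel_lowest": min(rel_trace),
        "abs_trace": abs_trace,
        "rel_trace": rel_trace,
        "nstep": best_step,
    }
    return (z if return_final else best_z), stored, stats


def flip(v: Vector) -> Vector:
    # f(z) = 2 - z has the fixed point z = 1, but undamped iteration oscillates.
    return [2.0 - x for x in v]


print(fixed_point_iter_sketch(flip, [0.0], max_iter=10)[2]["abs_lowest"])  # 2.0
print(fixed_point_iter_sketch(flip, [0.0], max_iter=10, tau=0.5)[0])        # [1.0]
```

Damping matters for maps that oscillate: with tau = 1 the iterates of f(z) = 2 - z alternate between 0 and 2, while tau = 0.5 averages the two and reaches the fixed point 1 in one step.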

torchdeq.solver.fp_iter.simple_fixed_point_iter(func, x0, max_iter=50, tau=1.0, indexing=None, **kwargs)

Implements a simplified fixed-point solver for solving a system of nonlinear equations.

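It is faster because it skips convergence monitoring: no residuals or statistics are tracked, there is no tolerance-based early stopping, and every call runs all max_iter iterations.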

Parameters:
  • func (callable) – The function for which the fixed point is to be computed.

  • x0 (torch.Tensor) – The initial guess for the fixed point.

  • max_iter (int, optional) – The maximum number of iterations. Default: 50.

  • tau (float, optional) – Damping factor to control the step size in the solution direction. Default: 1.0.

  • indexing (list, optional) – List of iteration indices at which to store the solution. Default: None.

  • kwargs (dict, optional) – Extra arguments are ignored.

Returns:

a tuple containing the following.
  • torch.Tensor: The approximate solution.

  • list[torch.Tensor]: List of the solutions at the specified iteration indices.

  • dict[str, torch.Tensor]:

    A dummy dict for solver statistics. Every value is a tensor of shape (1, 1) filled with -1.

Return type:

tuple[torch.Tensor, list[torch.Tensor], dict[str, torch.Tensor]]

Examples

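>>> import torch
>>> from torchdeq.solver.fp_iter import simple_fixed_point_iter
>>> f = lambda z: torch.cos(z)                      # Function for which we seek a fixed point
>>> z0 = torch.zeros(1, 1)                          # Initial estimate with a batch dimension (B = 1)
>>> z_star, _, _ = simple_fixed_point_iter(f, z0)   # Run fixed-point iterations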
>>> print((z_star - f(z_star)).norm(p=1))           # Print the numerical error
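To illustrate the trade-off, here is a stdlib-only sketch of the simplified loop: no residual is computed, nothing stops the loop early, and each call costs exactly max_iter evaluations of func. The damped update z <- tau * f(z) + (1 - tau) * z, the placeholder keys of the dummy stats dict, and the name `simple_fixed_point_iter_sketch` are assumptions for illustration; the -1 values of shape (1, 1) follow the description of the returned dict, using nested lists instead of tensors.

```python
import math
from typing import Callable, List, Optional, Sequence, Tuple

Vector = List[float]


def simple_fixed_point_iter_sketch(
    func: Callable[[Vector], Vector],
    x0: Vector,
    max_iter: int = 50,
    tau: float = 1.0,
    indexing: Optional[Sequence[int]] = None,
) -> Tuple[Vector, List[Vector], dict]:
    """Fixed-point iteration with no residual tracking and no early stopping."""
    wanted = set(indexing or [])
    z, stored = list(x0), []
    for k in range(1, max_iter + 1):
        fz = func(z)  # the only work per step: one evaluation of func
        # Assumed damped update: z <- tau * f(z) + (1 - tau) * z.
        z = [tau * b + (1.0 - tau) * a for a, b in zip(z, fz)]
        if k in wanted:
            stored.append(z)
    # Placeholder statistics: nothing was monitored, so every entry is -1.
    keys = ("abs_lowest", "rel_lowest", "abs_trace", "rel_trace", "nstep")
    return z, stored, {key: [[-1.0]] for key in keys}


calls = []


def counted_cos(v: Vector) -> Vector:
    calls.append(1)  # record one evaluation of func
    return [math.cos(x) for x in v]


z, _, stats = simple_fixed_point_iter_sketch(counted_cos, [0.0], max_iter=50)
print(len(calls))  # 50: every step runs, even after convergence
```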

torchdeq.solver.anderson.anderson_solver(func, x0, max_iter=50, tol=0.001, stop_mode='abs', indexing=None, m=6, lam=0.0001, tau=1.0, return_final=False, **kwargs)

Implements Anderson acceleration for fixed-point iteration.

Anderson acceleration speeds up the convergence of fixed-point iteration. Rather than taking f of the latest estimate as the next iterate, it mixes the last m iterates and their function values, choosing weights that sum to one and minimize the norm of the mixed residual f(z) - z (a small least-squares problem regularized by lam).

Parameters:
  • func (callable) – The function for which we seek a fixed point.

  • x0 (torch.Tensor) – Initial estimate for the fixed point.

  • max_iter (int, optional) – Maximum number of iterations. Default: 50.

  • tol (float, optional) – Tolerance for stopping criteria. Default: 1e-3.

  • stop_mode (str, optional) – Stopping criterion. Can be ‘abs’ for absolute or ‘rel’ for relative. Default: ‘abs’.

  • indexing (None or list, optional) – Indices for which to store and return solutions. If None, solutions are not stored. Default: None.

  • m (int, optional) – Maximum number of stored residuals in Anderson mixing. Default: 6.

  • lam (float, optional) – Regularization parameter in Anderson mixing. Default: 1e-4.

  • tau (float, optional) – Damping factor. It is used to control the step size in the direction of the solution. Default: 1.0.

  • return_final (bool, optional) – If True, returns the final solution instead of the one with the smallest residual. Default: False.

  • kwargs (dict, optional) – Extra arguments are ignored.

Returns:

a tuple containing the following.
  • torch.Tensor: Fixed point solution.

  • list[torch.Tensor]: List of the solutions at the specified iteration indices.

  • dict[str, torch.Tensor]:

    A dict containing solver statistics in a batch. Please see torchdeq.solver.stat.SolverStat for more details.

Return type:

tuple[torch.Tensor, list[torch.Tensor], dict[str, torch.Tensor]]

Examples

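>>> import torch
>>> from torchdeq.solver.anderson import anderson_solver
>>> f = lambda z: 0.5 * (z + 2 / z)                 # Function for which we seek a fixed point
>>> z0 = torch.ones(1, 1)                           # Initial estimate with a batch dimension (B = 1)
>>> z_star, _, _ = anderson_solver(f, z0)           # Run Anderson acceleration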
>>> print((z_star - f(z_star)).norm(p=1))           # Print the numerical error
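The sketch below illustrates one common regularized form of Anderson mixing in plain Python (no batching, no tensors). It keeps the last m iterates X_i and evaluations F_i = f(X_i), solves for weights alpha with sum(alpha) = 1 that minimize ||sum_i alpha_i (F_i - X_i)||^2 + lam * ||alpha||^2, and sets the next iterate to tau * sum_i alpha_i F_i + (1 - tau) * sum_i alpha_i X_i. The parameters m, lam, and tau play the roles documented above, but it is an assumption that torchdeq's anderson_solver uses exactly this update; `anderson_sketch`, `_solve`, and `_dot` are hypothetical helpers, and the small linear system is solved by Gaussian elimination.

```python
import math
from collections import deque
from typing import Callable, List, Tuple

Vector = List[float]


def _dot(u: Vector, v: Vector) -> float:
    return sum(p * q for p, q in zip(u, v))


def _solve(a: List[List[float]], b: Vector) -> Vector:
    """Solve the small dense system a @ x = b (Gaussian elimination, partial pivoting)."""
    n = len(b)
    aug = [row[:] + [rhs] for row, rhs in zip(a, b)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(aug[r][col]))
        aug[col], aug[pivot] = aug[pivot], aug[col]
        for r in range(col + 1, n):
            factor = aug[r][col] / aug[col][col]
            for c in range(col, n + 1):
                aug[r][c] -= factor * aug[col][c]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):
        x[r] = (aug[r][n] - sum(aug[r][c] * x[c] for c in range(r + 1, n))) / aug[r][r]
    return x


def anderson_sketch(
    func: Callable[[Vector], Vector],
    x0: Vector,
    max_iter: int = 50,
    tol: float = 1e-3,
    m: int = 6,
    lam: float = 1e-4,
    tau: float = 1.0,
) -> Tuple[Vector, float]:
    """Regularized Anderson mixing; returns the iterate with the smallest residual."""
    xs = deque([list(x0)], maxlen=m)        # last m iterates X_i
    fs = deque([func(list(x0))], maxlen=m)  # their evaluations F_i = f(X_i)
    best_x = xs[0]
    best_err = math.sqrt(sum((f - x) ** 2 for f, x in zip(fs[0], xs[0])))
    for _ in range(max_iter):
        if best_err < tol:
            break
        gs = [[f - x for f, x in zip(fv, xv)] for fv, xv in zip(fs, xs)]  # residuals G_i
        n = len(gs)
        # Bordered system [[0, 1^T], [1, G G^T + lam I]] @ [nu, alpha] = [1, 0, ..., 0]:
        # alpha minimizes ||sum_i alpha_i G_i||^2 + lam ||alpha||^2 subject to sum(alpha) = 1.
        h = [[0.0] + [1.0] * n]
        for i in range(n):
            h.append([1.0] + [_dot(gs[i], gs[j]) + (lam if i == j else 0.0) for j in range(n)])
        alpha = _solve(h, [1.0] + [0.0] * n)[1:]
        mix_f = [sum(a * fv[d] for a, fv in zip(alpha, fs)) for d in range(len(x0))]
        mix_x = [sum(a * xv[d] for a, xv in zip(alpha, xs)) for d in range(len(x0))]
        x_new = [tau * pf + (1.0 - tau) * px for pf, px in zip(mix_f, mix_x)]
        f_new = func(x_new)
        xs.append(x_new)  # the deques drop the oldest pair once m entries are stored
        fs.append(f_new)
        err = math.sqrt(sum((f - x) ** 2 for f, x in zip(f_new, x_new)))
        if err < best_err:
            best_x, best_err = x_new, err
    return best_x, best_err


def heron(v: Vector) -> Vector:
    # Same map as in the example above; its fixed point is sqrt(2).
    return [0.5 * (z + 2.0 / z) for z in v]


x_star, residual = anderson_sketch(heron, [1.0], max_iter=100, tol=1e-8)
print(x_star, residual)  # x_star is close to sqrt(2), about 1.41421356
```

Because lam is an absolute regularizer in this form, it dominates once the residuals are much smaller than sqrt(lam); the weights then drift toward a plain average of the stored evaluations.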

torchdeq.solver.broyden.broyden_solver(func, x0, max_iter=50, tol=0.001, stop_mode='abs', indexing=None, LBFGS_thres=None, ls=False, return_final=False, **kwargs)

Implements Broyden’s method, which finds a fixed point by solving the nonlinear system f(z) - z = 0.

Parameters:
  • func (callable) – The function for which we seek a fixed point.

  • x0 (torch.Tensor) – The initial guess for the root.

  • max_iter (int, optional) – The maximum number of iterations. Default: 50.

  • tol (float, optional) – Tolerance for the stopping criterion. Default: 1e-3.

  • stop_mode (str, optional) – The stopping criterion. Can be either ‘abs’ or ‘rel’. Default: ‘abs’.

  • indexing (list, optional) – List of iteration indices at which to store the solution. Default: None.

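  • LBFGS_thres (int, optional) – Memory size of the limited-memory variant, as in limited-memory BFGS: the maximum number of low-rank updates to the inverse-Jacobian estimate that are kept. None stores all of them. Default: None.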

  • ls (bool, optional) – If True, perform a line search at each step. Default: False.

  • return_final (bool, optional) – If True, returns the final solution instead of the one with the smallest residual. Default: False.

  • kwargs (dict, optional) – Extra arguments are ignored.

Returns:

a tuple containing the following.
  • torch.Tensor: Fixed point solution.

  • list[torch.Tensor]: List of the solutions at the specified iteration indices.

  • dict[str, torch.Tensor]:

    A dict containing solver statistics in a batch. Please see torchdeq.solver.stat.SolverStat for more details.

Return type:

tuple[torch.Tensor, list[torch.Tensor], dict[str, torch.Tensor]]

Examples

>>> import torch
>>> from torchdeq.solver.broyden import broyden_solver
>>> f = lambda z: 0.5 * (z + 2 / z)                 # Function for which we seek a fixed point
>>> z0 = torch.ones(1, 1)                           # Initial estimate with a batch dimension (B = 1)
>>> z_star, _, _ = broyden_solver(f, z0)            # Run Broyden's method
>>> print((z_star - f(z_star)).norm(p=1))           # Print the numerical error
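Broyden's method treats the fixed-point problem as root finding for g(z) = f(z) - z and keeps an estimate B of the inverse Jacobian of g, corrected by a rank-one (Sherman-Morrison) update after every step, so no Jacobian is ever formed. The stdlib-only sketch below uses Broyden's "good" update with a dense matrix, B initialized to -I, and no line search (ls) or memory limit (LBFGS_thres); those choices and the name `broyden_sketch` are assumptions for illustration, not a description of torchdeq's broyden_solver.

```python
import math
from typing import Callable, List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def _matvec(a: Matrix, v: Vector) -> Vector:
    return [sum(aij * vj for aij, vj in zip(row, v)) for row in a]


def _vecmat(v: Vector, a: Matrix) -> Vector:
    return [sum(v[i] * a[i][j] for i in range(len(v))) for j in range(len(a[0]))]


def _norm(v: Vector) -> float:
    return math.sqrt(sum(x * x for x in v))


def broyden_sketch(
    func: Callable[[Vector], Vector],
    x0: Vector,
    max_iter: int = 50,
    tol: float = 1e-3,
) -> Tuple[Vector, float]:
    """Good Broyden with a dense inverse-Jacobian estimate for g(z) = f(z) - z."""
    n = len(x0)
    x = list(x0)
    g = [fx - xx for fx, xx in zip(func(x), x)]
    # B = -I is the inverse Jacobian of g when f's Jacobian is zero, so the
    # first step x - B g = x + g = f(x) is a plain fixed-point step.
    inv_jac = [[-1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]
    best_x, best_err = x, _norm(g)
    for _ in range(max_iter):
        if best_err < tol:
            break
        dx = [-v for v in _matvec(inv_jac, g)]  # quasi-Newton step: -B g
        x_new = [a + b for a, b in zip(x, dx)]
        g_new = [fx - xx for fx, xx in zip(func(x_new), x_new)]
        dg = [a - b for a, b in zip(g_new, g)]
        b_dg = _matvec(inv_jac, dg)
        denom = sum(a * b for a, b in zip(dx, b_dg))
        if abs(denom) > 1e-30:  # guard against a degenerate update
            # Rank-one update: B <- B + (dx - B dg) (dx^T B) / (dx^T B dg)
            u = [a - b for a, b in zip(dx, b_dg)]
            dx_b = _vecmat(dx, inv_jac)
            inv_jac = [[inv_jac[i][j] + u[i] * dx_b[j] / denom for j in range(n)] for i in range(n)]
        x, g = x_new, g_new
        if _norm(g) < best_err:
            best_x, best_err = x, _norm(g)
    return best_x, best_err


x_star, residual = broyden_sketch(lambda v: [0.5 * (z + 2.0 / z) for z in v], [1.0], tol=1e-10)
print(x_star)  # close to sqrt(2), about 1.41421356
```

In one dimension this update reduces to the secant method; in n dimensions each iteration costs one evaluation of f plus O(n^2) work for the dense update.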

Solver Stat

class torchdeq.solver.stat.SolverStat(*args, **kwargs)

A class for storing solver statistics.

This class is a subclass of dict, so each solver statistic is accessed by its key, e.g., stat['abs_lowest'].

Valid Keys:
  • 'abs_lowest':

    The lowest absolute fixed point errors achieved, i.e., \(\|z - f(z)\|\). torch.Tensor of shape \((B,)\).

  • 'rel_lowest':

    The lowest relative fixed point errors achieved, i.e., \(\|z - f(z)\| / \|f(z)\|\). torch.Tensor of shape \((B,)\).

  • 'abs_trace':

    The absolute fixed point errors achieved along the solver steps. torch.Tensor of shape \((B, N)\), where \(N\) is the number of solver steps taken.

  • 'rel_trace':

    The relative fixed point errors achieved along the solver steps. torch.Tensor of shape \((B, N)\), where \(N\) is the number of solver steps taken.

  • 'nstep':

    The step at which the lowest fixed point errors were achieved. torch.Tensor of shape \((B,)\).

  • 'sradius':

    Optional. The spectral radius, i.e., the largest absolute eigenvalue of the Jacobian, estimated by the power method. Available in eval mode when sradius_mode is set to True. torch.Tensor of shape \((B,)\).
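How these entries relate to a solver's trajectory can be sketched for a single sample (B = 1, so the batch dimension is dropped and plain lists stand in for tensors). The helpers `solver_stats_sketch` and `spectral_radius_sketch` are hypothetical; the 1e-8 stabilizer in the relative error and the 1-based nstep are assumptions. For 'sradius', the power method only needs Jacobian-vector products v -> J v; in practice these would come from autograd, while here an explicit 2x2 matrix stands in for the Jacobian.

```python
import math
from typing import Callable, Dict, List, Sequence

Vector = List[float]


def _norm(v: Sequence[float]) -> float:
    return math.sqrt(sum(x * x for x in v))


def solver_stats_sketch(func: Callable[[Vector], Vector], iterates: Sequence[Vector]) -> Dict[str, object]:
    """Build SolverStat-like entries for one sample from its sequence of iterates."""
    abs_trace, rel_trace = [], []
    for z in iterates:
        fz = func(z)
        abs_err = _norm([a - b for a, b in zip(z, fz)])  # ||z - f(z)||
        abs_trace.append(abs_err)
        rel_trace.append(abs_err / (_norm(fz) + 1e-8))  # ||z - f(z)|| / ||f(z)||
    best = min(range(len(abs_trace)), key=abs_trace.__getitem__)
    return {
        "abs_lowest": abs_trace[best],
        "rel_lowest": min(rel_trace),
        "abs_trace": abs_trace,
        "rel_trace": rel_trace,
        "nstep": best + 1,  # assumed 1-based step of the lowest absolute error
    }


def spectral_radius_sketch(jvp: Callable[[Vector], Vector], dim: int, n_iter: int = 50) -> float:
    """Power method: estimate the largest |eigenvalue| of J from products v -> J v."""
    v = [1.0 / math.sqrt(dim)] * dim
    estimate = 0.0
    for _ in range(n_iter):
        w = jvp(v)
        estimate = _norm(w)  # ||J v|| for a unit vector v
        if estimate == 0.0:
            break
        v = [x / estimate for x in w]
    return estimate


def flip(v: Vector) -> Vector:
    return [2.0 - x for x in v]


stats = solver_stats_sketch(flip, [[0.0], [1.0], [2.0]])
print(stats["abs_trace"], stats["nstep"])  # [2.0, 0.0, 2.0] 2

JAC = [[0.5, 0.2], [0.2, 0.2]]  # symmetric stand-in Jacobian with eigenvalues 0.6 and 0.1


def jac_vec(v: Vector) -> Vector:
    return [sum(a * x for a, x in zip(row, v)) for row in JAC]


print(spectral_radius_sketch(jac_vec, dim=2))  # approximately 0.6
```

The power method converges when one eigenvalue is strictly largest in magnitude; the error shrinks roughly by the ratio of the two largest magnitudes (here 0.1 / 0.6) per iteration.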