torchdeq.solver
The torchdeq.solver module provides a set of solvers for finding fixed points in Deep Equilibrium Models (DEQs). These solvers are used to iteratively refine the predictions of a DEQ model until they reach a stable state, or “equilibrium”.
This module includes implementations of several popular fixed-point solvers, including Anderson acceleration (anderson_solver), Broyden’s method (broyden_solver), and fixed-point iteration (fixed_point_iter). It also provides a faster version of fixed-point iteration (simple_fixed_point_iter) that omits convergence monitoring for speed improvements.
The get_solver function allows users to retrieve a specific solver by its key, and the register_solver function allows users to add their own custom solvers to the module.
Example
To retrieve a solver, call the get_solver function:
>>> solver = get_solver('anderson')
To register a user-defined solver, call the register_solver function:
>>> register_solver('newton', newton_solver)
Solver Function
- torchdeq.solver.get_solver(key)
Retrieves a fixed point solver from the registered solvers by its key.
Supported solvers: 'anderson', 'broyden', 'fixed_point_iter', 'simple_fixed_point_iter'.
- Parameters:
key (str) – The key of the solver to retrieve. This should match one of the keys used to register a solver.
- Returns:
The solver function associated with the provided key.
- Return type:
callable
- Raises:
AssertionError – If the key does not match any of the registered solvers.
Example
>>> solver = get_solver('anderson')
- torchdeq.solver.register_solver(solver_type, solver)
Registers a user-defined fixed point solver. This solver can be designated using args.f_solver and args.b_solver.
This method adds a new entry to the solver dict, with the specified solver_type as the key and the solver as the value.
- Parameters:
solver_type (str) – The type of solver to register. This will be used as the key in the solver dict.
solver (callable) – The solver function. This will be used as the value in the solver dict.
Example
>>> register_solver('newton', newton_solver)
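The registry behaviour described above can be pictured as a plain dict keyed by solver name. The following is a minimal sketch under that assumption, not torchdeq's actual source: the `_SOLVERS` dict and the placeholder `newton_solver` body are hypothetical, and only the two function names and the AssertionError on an unknown key come from this page.

```python
# Hypothetical stand-in for the module-level solver dict (an assumption,
# not torchdeq's real implementation).
_SOLVERS = {}


def register_solver(solver_type, solver):
    """Store `solver` under the string key `solver_type`."""
    _SOLVERS[solver_type] = solver


def get_solver(key):
    """Return the solver registered under `key`; unknown keys fail an assertion."""
    assert key in _SOLVERS, f"unknown solver key: {key!r}"
    return _SOLVERS[key]


def newton_solver(func, x0, **kwargs):
    """Placeholder solver: returns the (solution, stored iterates, stats) triple."""
    return x0, [], {}


register_solver('newton', newton_solver)
solver = get_solver('newton')
```

Because lookups go through a string key, a custom solver registered this way can be selected by name in the same manner as the built-in ones.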
Solver
- torchdeq.solver.fp_iter.fixed_point_iter(func, x0, max_iter=50, tol=0.001, stop_mode='abs', indexing=None, tau=1.0, return_final=False, **kwargs)
Implements the fixed-point iteration solver for solving a system of nonlinear equations.
- Parameters:
func (callable) – The function for which we seek a fixed point.
x0 (torch.Tensor) – The initial guess for the fixed point.
max_iter (int, optional) – The maximum number of iterations. Default: 50.
tol (float, optional) – The convergence criterion. Default: 1e-3.
stop_mode (str, optional) – The stopping criterion. Can be either ‘abs’ or ‘rel’. Default: ‘abs’.
indexing (list, optional) – List of iteration indices at which to store the solution. Default: None.
tau (float, optional) – Damping factor. It is used to control the step size in the direction of the solution. Default: 1.0.
return_final (bool, optional) – If True, runs all steps and returns the final solution instead of the one with the smallest residual. Default: False.
kwargs (dict, optional) – Extra arguments are ignored.
- Returns:
- A tuple containing the following.
torch.Tensor: Fixed point solution.
list[torch.Tensor]: List of the solutions at the specified iteration indices.
dict[str, torch.Tensor]: A dict containing solver statistics in a batch. Please see torchdeq.solver.stat.SolverStat for more details.
- Return type:
tuple[torch.Tensor, list[torch.Tensor], dict[str, torch.Tensor]]
Examples
>>> f = lambda z: torch.cos(z)                # Function for which we seek a fixed point
>>> z0 = torch.tensor(0.0)                    # Initial estimate
>>> z_star, _, _ = fixed_point_iter(f, z0)    # Run fixed point iterations
>>> print((z_star - f(z_star)).norm(p=1))     # Print the numerical error
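To make the roles of tol, stop_mode, and tau concrete, here is a scalar, pure-Python sketch of a damped fixed-point iteration that tracks the iterate with the smallest residual. It assumes damping takes the usual convex-combination form `tau * f(x) + (1 - tau) * x`; the helper name `damped_fixed_point` and the small constant guarding the relative error are illustrative choices, not taken from the library.

```python
import math


def damped_fixed_point(func, x0, max_iter=50, tol=1e-3, stop_mode='abs', tau=1.0):
    """Scalar sketch: iterate x <- tau * f(x) + (1 - tau) * x, keep the best iterate."""
    x = x0
    best_x, best_err, best_step = x0, float('inf'), 0
    for step in range(1, max_iter + 1):
        fx = func(x)
        abs_err = abs(fx - x)                    # |z - f(z)|
        rel_err = abs_err / (abs(fx) + 1e-9)     # |z - f(z)| / |f(z)|, guarded (assumption)
        err = abs_err if stop_mode == 'abs' else rel_err
        if err < best_err:                       # remember the lowest-residual iterate
            best_x, best_err, best_step = x, err, step
        if err < tol:                            # converged under the chosen criterion
            break
        x = tau * fx + (1.0 - tau) * x           # damped update (assumed form)
    return best_x, best_err, best_step


z_star, err, nstep = damped_fixed_point(math.cos, 0.0)
z_damped, err_damped, nstep_damped = damped_fixed_point(math.cos, 0.0, tau=0.5)
```

In this sketch, choosing `tau < 1` takes shorter steps toward `f(x)`, which can stabilise iterations that oscillate around the fixed point.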
- torchdeq.solver.fp_iter.simple_fixed_point_iter(func, x0, max_iter=50, tau=1.0, indexing=None, **kwargs)
Implements a simplified fixed-point solver for solving a system of nonlinear equations.
It runs faster by skipping statistics monitoring.
- Parameters:
func (callable) – The function for which the fixed point is to be computed.
x0 (torch.Tensor) – The initial guess for the fixed point.
max_iter (int, optional) – The maximum number of iterations. Default: 50.
tau (float, optional) – Damping factor to control the step size in the solution direction. Default: 1.0.
indexing (list, optional) – List of iteration indices at which to store the solution. Default: None.
kwargs (dict, optional) – Extra arguments are ignored.
- Returns:
- A tuple containing the following.
torch.Tensor: The approximate solution.
list[torch.Tensor]: List of the solutions at the specified iteration indices.
dict[str, torch.Tensor]: A dummy dict for solver statistics. All values are initialized as -1 of tensor shape (1, 1).
- Return type:
tuple[torch.Tensor, list[torch.Tensor], dict[str, torch.Tensor]]
Examples
>>> f = lambda z: torch.cos(z)                       # Function for which we seek a fixed point
>>> z0 = torch.tensor(0.0)                           # Initial estimate
>>> z_star, _, _ = simple_fixed_point_iter(f, z0)    # Run fixed point iterations
>>> print((z_star - f(z_star)).norm(p=1))            # Print the numerical error
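The indexing argument can be illustrated with a monitoring-free loop that snapshots the iterate at the requested steps. This is a scalar sketch only: the helper name `simple_iter_with_indexing` is hypothetical, and counting steps from 1 is an assumption rather than something this page specifies.

```python
import math


def simple_iter_with_indexing(func, x0, max_iter=50, tau=1.0, indexing=None):
    """Run a fixed number of damped steps with no convergence checks."""
    wanted = set(indexing or [])
    stored = []
    x = x0
    for step in range(1, max_iter + 1):          # 1-based step count (assumption)
        x = tau * func(x) + (1.0 - tau) * x      # damped update (assumed form)
        if step in wanted:
            stored.append(x)                     # snapshot at a requested step
    return x, stored


z_star, snapshots = simple_iter_with_indexing(math.cos, 0.0, max_iter=50, indexing=[10, 50])
```

Since no residual is computed inside the loop, every call costs exactly `max_iter` evaluations of `func`, which is the trade-off the simplified solver makes for speed.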
- torchdeq.solver.anderson.anderson_solver(func, x0, max_iter=50, tol=0.001, stop_mode='abs', indexing=None, m=6, lam=0.0001, tau=1.0, return_final=False, **kwargs)
Implements Anderson acceleration for fixed-point iteration.
Anderson acceleration speeds up the convergence of fixed-point iterations. Instead of relying on the latest iterate alone, each step builds the next estimate from a combination of up to m stored iterates, weighted so that the combined residual is as small as possible.
- Parameters:
func (callable) – The function for which we seek a fixed point.
x0 (torch.Tensor) – Initial estimate for the fixed point.
max_iter (int, optional) – Maximum number of iterations. Default: 50.
tol (float, optional) – Tolerance for stopping criteria. Default: 1e-3.
stop_mode (str, optional) – Stopping criterion. Can be ‘abs’ for absolute or ‘rel’ for relative. Default: ‘abs’.
indexing (None or list, optional) – Indices for which to store and return solutions. If None, solutions are not stored. Default: None.
m (int, optional) – Maximum number of stored residuals in Anderson mixing. Default: 6.
lam (float, optional) – Regularization parameter in Anderson mixing. Default: 1e-4.
tau (float, optional) – Damping factor. It is used to control the step size in the direction of the solution. Default: 1.0.
return_final (bool, optional) – If True, returns the final solution instead of the one with smallest residual. Default: False.
kwargs (dict, optional) – Extra arguments are ignored.
- Returns:
- A tuple containing the following.
torch.Tensor: Fixed point solution.
list[torch.Tensor]: List of the solutions at the specified iteration indices.
dict[str, torch.Tensor]: A dict containing solver statistics in a batch. Please see torchdeq.solver.stat.SolverStat for more details.
- Return type:
tuple[torch.Tensor, list[torch.Tensor], dict[str, torch.Tensor]]
Examples
>>> f = lambda z: 0.5 * (z + 2 / z)           # Function for which we seek a fixed point
>>> z0 = torch.tensor(1.0)                    # Initial estimate
>>> z_star, _, _ = anderson_solver(f, z0)     # Run Anderson acceleration
>>> print((z_star - f(z_star)).norm(p=1))     # Print the numerical error
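As a rough illustration of the mixing idea, the sketch below implements the smallest non-trivial case of Anderson acceleration in pure Python: it keeps only two residuals and solves the one-dimensional least-squares problem in closed form. It omits the regularization (lam) and damping (tau) of the documented solver, and the helper name `anderson_two_point` is hypothetical.

```python
import math


def anderson_two_point(func, x0, max_iter=50, tol=1e-10):
    """Anderson acceleration with a history of two residuals, on lists of floats."""
    x_prev, f_prev = list(x0), func(x0)
    x, fx = f_prev, func(f_prev)                 # first step is a plain iteration
    for step in range(max_iter):
        g_prev = [a - b for a, b in zip(f_prev, x_prev)]   # residual f(x) - x, previous
        g = [a - b for a, b in zip(fx, x)]                 # residual f(x) - x, current
        if math.sqrt(sum(v * v for v in g)) < tol:
            return x, step
        dg = [a - b for a, b in zip(g, g_prev)]
        denom = sum(v * v for v in dg)
        # alpha minimizes ||(1 - alpha) * g + alpha * g_prev|| = ||g - alpha * dg||.
        alpha = sum(a * b for a, b in zip(g, dg)) / denom if denom > 0.0 else 0.0
        # Mix the two function values with the same weights.
        x_new = [a - alpha * (a - b) for a, b in zip(fx, f_prev)]
        x_prev, f_prev = x, fx
        x, fx = x_new, func(x_new)
    return x, max_iter


z_star, steps = anderson_two_point(lambda z: [math.cos(z[0])], [0.0])
```

For a scalar problem this two-point variant coincides with the secant method applied to `f(x) - x`, which is why it needs far fewer evaluations than plain iteration on the same function.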
- torchdeq.solver.broyden.broyden_solver(func, x0, max_iter=50, tol=0.001, stop_mode='abs', indexing=None, LBFGS_thres=None, ls=False, return_final=False, **kwargs)
Implements Broyden’s method for solving a system of nonlinear equations.
- Parameters:
func (callable) – The function for which we seek a fixed point.
x0 (torch.Tensor) – The initial guess for the fixed point.
max_iter (int, optional) – The maximum number of iterations. Default: 50.
tol (float, optional) – The convergence criterion. Default: 1e-3.
stop_mode (str, optional) – The stopping criterion. Can be either ‘abs’ or ‘rel’. Default: ‘abs’.
indexing (list, optional) – List of iteration indices at which to store the solution. Default: None.
LBFGS_thres (int, optional) – The max_iter for the limited memory BFGS method. None for storing all. Default: None.
ls (bool, optional) – If True, perform a line search at each step. Default: False.
return_final (bool, optional) – If True, returns the final solution instead of the one with smallest residual. Default: False.
kwargs (dict, optional) – Extra arguments are ignored.
- Returns:
- A tuple containing the following.
torch.Tensor: Fixed point solution.
list[torch.Tensor]: List of the solutions at the specified iteration indices.
dict[str, torch.Tensor]: A dict containing solver statistics in a batch. Please see torchdeq.solver.stat.SolverStat for more details.
- Return type:
tuple[torch.Tensor, list[torch.Tensor], dict[str, torch.Tensor]]
Examples
>>> f = lambda z: 0.5 * (z + 2 / z)           # Function for which we seek a fixed point
>>> z0 = torch.tensor(1.0)                    # Initial estimate
>>> z_star, _, _ = broyden_solver(f, z0)      # Run Broyden's method
>>> print((z_star - f(z_star)).norm(p=1))     # Print the numerical error
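For intuition, Broyden's method can be sketched as a quasi-Newton iteration on the residual g(x) = f(x) - x, with a rank-one update of an approximate inverse Jacobian. The pure-Python version below stores the full matrix and starts it at -I; both choices, along with the helper name `broyden_sketch` and the test function, are assumptions made for illustration. It leaves out the limited-memory storage and line search that the documented solver exposes.

```python
import math


def broyden_sketch(func, x0, max_iter=50, tol=1e-10):
    """'Good' Broyden update on g(x) = f(x) - x, using plain lists."""
    n = len(x0)
    # B approximates the inverse Jacobian of g; start from -I (assumption).
    B = [[-1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]
    x = list(x0)
    g = [a - b for a, b in zip(func(x), x)]
    for step in range(max_iter):
        if math.sqrt(sum(v * v for v in g)) < tol:
            return x, step
        # Quasi-Newton step: dx = -B g(x).
        dx = [-sum(B[i][j] * g[j] for j in range(n)) for i in range(n)]
        x_new = [a + b for a, b in zip(x, dx)]
        g_new = [a - b for a, b in zip(func(x_new), x_new)]
        dg = [a - b for a, b in zip(g_new, g)]
        B_dg = [sum(B[i][j] * dg[j] for j in range(n)) for i in range(n)]
        dx_B = [sum(dx[i] * B[i][j] for i in range(n)) for j in range(n)]
        denom = sum(a * b for a, b in zip(dx_B, dg))
        if abs(denom) > 1e-30:
            # Rank-one update so that the new B maps dg to dx (secant condition).
            u = [(a - b) / denom for a, b in zip(dx, B_dg)]
            B = [[B[i][j] + u[i] * dx_B[j] for j in range(n)] for i in range(n)]
        x, g = x_new, g_new
    return x, max_iter


f = lambda z: [0.5 * math.cos(z[1]), 0.5 * math.sin(z[0])]
z_star, steps = broyden_sketch(f, [0.0, 0.0])
```

Storing the dense matrix is only practical for tiny problems like this one; keeping a bounded number of rank-one updates instead is the kind of saving a limited-memory threshold is meant to provide.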
Solver Stat
- class torchdeq.solver.stat.SolverStat(*args, **kwargs)
A class for storing solver statistics.
This class is a subclass of dict, which allows users to query the solver statistics as dictionary keys.
- Valid Keys:
'abs_lowest': The lowest absolute fixed point errors achieved, i.e., \(\|z - f(z)\|\). torch.Tensor of shape \((B,)\).
'rel_lowest': The lowest relative fixed point errors achieved, i.e., \(\|z - f(z)\| / \|f(z)\|\). torch.Tensor of shape \((B,)\).
'abs_trace': The absolute fixed point errors achieved along the solver steps. torch.Tensor of shape \((B, N)\), where \(N\) is the number of solver steps consumed.
'rel_trace': The relative fixed point errors achieved along the solver steps. torch.Tensor of shape \((B, N)\), where \(N\) is the number of solver steps consumed.
'nstep': The number of steps taken when the reported fixed point errors were achieved. torch.Tensor of shape \((B,)\).
'sradius': Optional. The largest eigenvalue (in absolute value) estimated by the power method. Available in eval mode when sradius_mode is set to True. torch.Tensor of shape \((B,)\).
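How the keys relate to one another can be shown on a single scalar problem. The sketch below records the absolute and relative error at every step of a plain iteration and then derives the lowest values and the step at which they occurred. It uses a plain dict and Python floats instead of SolverStat and batched tensors; the helper name `summarize_trace`, the 1-based step count, and the small constant guarding the division are assumptions.

```python
import math


def summarize_trace(func, x0, max_iter=20):
    """Collect per-step errors of a plain fixed-point iteration on a scalar."""
    abs_trace, rel_trace = [], []
    x = x0
    for _ in range(max_iter):
        fx = func(x)
        abs_err = abs(fx - x)                          # |z - f(z)|
        abs_trace.append(abs_err)
        rel_trace.append(abs_err / (abs(fx) + 1e-9))   # |z - f(z)| / |f(z)|, guarded
        x = fx
    abs_lowest = min(abs_trace)
    return {
        'abs_lowest': abs_lowest,
        'rel_lowest': min(rel_trace),
        'abs_trace': abs_trace,
        'rel_trace': rel_trace,
        'nstep': abs_trace.index(abs_lowest) + 1,      # 1-based step (assumption)
    }


stat = summarize_trace(math.cos, 0.0)
```

Here the traces have length N = 20 and the lowest values are single numbers; in the batched setting described above, each of these would carry an extra leading dimension of size B.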