torchdeq.grad
The torchdeq.grad module offers a factory function, backward_factory, which is designed to facilitate the customization of various differentiation methods during the backward pass.
This function is integral to the construction of the backward computational graph in the DEQ class, as it is invoked multiple times to generate gradient functors.
While the backward_factory function is a powerful tool, it is generally not recommended for direct use outside of the library. Instead, users should primarily interact with the DEQ class via the torch.core entry point for most DEQ computations. This approach ensures the appropriate and efficient use of the library’s features.
- torchdeq.grad.backward_factory(grad_type='ift', hook_ift=False, b_solver=None, b_solver_kwargs={}, sup_gap=-1, sup_loc=None, tau=1.0, **grad_factory_kwargs)
Factory for the backward pass of implicit deep learning, e.g., DEQ (implicit models), Hamburger (optimization layers), etc. This function implements various gradients like Implicit Differentiation (IFT), 1-step Grad and Phantom Grad.
- Implicit Differentiation:
[2018-ICML] Reviving and Improving Recurrent Back-Propagation
[2019-NeurIPS] Deep Equilibrium Models
[2019-NeurIPS] Meta-Learning with Implicit Gradients
…
- 1-step Grad & Higher-order Grad:
[2021-ICLR] Is Attention Better Than Matrix Decomposition?
[2022-AAAI] JFB: Jacobian-Free Backpropagation for Implicit Networks
[2021-NeurIPS] On Training Implicit Models
…
- Parameters:
grad_type (str, int, optional) – Gradient type to use. grad_type should be
'ift'
for IFT or an int for PhantomGrad. Default'ift'
. Set to'ift'
to enable the implicit differentiation (IFT) mode. When passing a numberk
to this function, it runs UPG with stepsk
and damping factortau
.hook_ift (bool, optional) –
Set to
True
to enable an \(\Omega(1)\) memory (w.r.t. activations) implementation using the Pytorch hook for IFT.Set to
False
to enable the \(\Omega(2)\) memory implementation usingtorch.autograd.Function
to avoid the (potential) segment fault in older PyTorch versions.Note that the
torch.autograd.Function
implementation is more stable than this hook in numerics and execution, even though they should be conceptually the same. For PyTorch version < 1.7.1 on some machines, this \(\Omega(1)\) hook seems to trigger a segment fault after some training steps. This issue is not caused by TorchDEQ but rather due to the hook.remove() call and some interactions between Python and PyTorch. Thetorch.autograd.Function
implementation also introduces slightly better numerical stability when the forward solver introduces some fixed point errors.Default
False
.b_solver (str, optional) – Solver for the IFT backward pass. Default None. Supported solvers:
'anderson'
,'broyden'
,'fixed_point_iter'
,'simple_fixed_point_iter'
.b_solver_kwargs (dict, optional) – Collection of backward solver kwargs, e.g., max_iter (int, optional), max steps for the backward solver, stop_mode (str, optional), criterion for convergence, etc. See torchdeq.solver for all kwargs.
sup_gap (int, optional) – The gap for uniformly sampling trajectories from PhantomGrad. Sample every
sup_gap
states ifsup_gap > 0
. Default -1.sup_loc (list[int], optional) – Specifies trajectory steps or locations in PhantomGrad from which to sample. Default None.
tau (float, optional) – Damping factor for PhantomGrad. Default 1.0. 0.5-0.7 is recommended for MDEQ. 1.0 for DEQ flow. For DEQ flow, the gating function in GRU naturally produces adaptive tau values.
grad_factory_kwargs – Extra arguments are ignored.
- Returns:
A gradient functor for implicit deep learning. The function takes trainer, func and z_pred as arguments and returns a list of tensors with the gradient information.
- Args:
- trainer (torch.nn.Module):
the module that employs implicit deep learning.
- func (type):
function that defines the f in z = f(z).
- z_pred (torch.Tensor):
latent state to run the backward pass.
- writer (callable, optional):
Callable function to monitor the backward pass. It should accept the solver statistics dictionary as input. Default None.
- Returns:
- list[torch.Tensor]:
a list of tensors that tracks the gradient info. These tensors can be directly applied to downstream networks, while all the gradient info will be automatically tracked in the backward pass.
- Return type:
callable