torchdeq.norm

The torchdeq.norm module provides a set of tools for managing normalization in Deep Equilibrium Models (DEQs). It includes factory functions for applying, resetting, and removing normalization, as well as for registering new normalization types and modules.

The module also provides classes for specific types of normalization, such as WeightNorm and SpectralNorm.

Example

To apply normalization to a model, call the apply_norm function:

>>> apply_norm(model, 'weight_norm', filter_out=['embedding'])

To reset all normalization within a DEQ model, call the reset_norm function:

>>> reset_norm(model)

To remove normalization from a DEQ model, call the remove_norm function:

>>> remove_norm(model)

To register a user-defined normalization type, call the register_norm function:

>>> register_norm('custom_norm', CustomNorm)

To register a new module for a user-defined normalization, call the register_norm_module function:

>>> register_norm_module(Conv2d, 'custom_norm', 'weight', 0)

Norm Function

torchdeq.norm.apply_norm(model, norm_type='weight_norm', prefix_filter_out=None, filter_out=None, args=None, **norm_kwargs)

Automatically applies normalization to the weights of a given model according to norm_type.

The currently supported normalizations are 'weight_norm', 'spectral_norm', and 'none' (no normalization applied). Weights whose name contains any string in filter_out, or starts with any prefix in prefix_filter_out, are skipped.
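The skip rule above can be sketched in plain Python. This is a minimal illustration of the matching logic as described, not torchdeq's implementation; the helper name should_skip and the sample weight names are assumptions made for the example.

```python
def should_skip(name, filter_out=None, prefix_filter_out=None):
    """Return True if a weight called `name` would be skipped (hypothetical helper)."""
    # Accept a single string or a list of strings, as the parameter docs allow.
    if isinstance(filter_out, str):
        filter_out = [filter_out]
    if isinstance(prefix_filter_out, str):
        prefix_filter_out = [prefix_filter_out]

    # Substring match against filter_out.
    if any(key in name for key in (filter_out or [])):
        return True
    # Prefix match against prefix_filter_out.
    return any(name.startswith(prefix) for prefix in (prefix_filter_out or []))


# Made-up names, for illustration only.
names = ["embedding.weight", "block.conv1.weight", "head.fc.weight"]
kept = [
    n for n in names
    if not should_skip(n, filter_out=["embedding"], prefix_filter_out="head")
]
print(kept)  # ['block.conv1.weight']
```

Note that filter_out matches anywhere in the name, while prefix_filter_out matches only at the start.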

Parameters:
  • model (torch.nn.Module) – Model to apply normalization to.

  • norm_type (str, optional) – Type of normalization to be applied. Default is 'weight_norm'.

  • prefix_filter_out (list or str, optional) – Prefix, or list of prefixes, of module weight names to skip when applying normalization. Default is None.

  • filter_out (list or str, optional) – Name, or list of names, of module weights to skip when applying normalization. Default is None.

  • args (Union[argparse.Namespace, dict, DEQConfig, Any], optional) – Configuration for the DEQ model. This can be an instance of argparse.Namespace, a dictionary, or an instance of DEQConfig. A config of any other type is processed using the get_attr function. Values in args take priority over norm_kwargs. Default is None.

  • norm_kwargs – Keyword arguments for the normalization layer.
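One way to picture the "args > norm_kwargs" priority is a lookup that prefers the config object and falls back to the keyword argument. The sketch below is an assumption about how such a resolution could work, not torchdeq's code: the helper lookup and the config field target_norm are hypothetical.

```python
from argparse import Namespace

_MISSING = object()


def lookup(args, key, fallback):
    """Read `key` from a dict or attribute-style config, else use `fallback`."""
    if args is None:
        return fallback
    if isinstance(args, dict):
        value = args.get(key, _MISSING)
    else:
        value = getattr(args, key, _MISSING)
    return fallback if value is _MISSING else value


norm_kwargs = {"learn_scale": True, "target_norm": 1.0}
args = Namespace(target_norm=0.5)  # hypothetical config field, for illustration

# A value present in args overrides the one passed as a keyword argument.
resolved = {key: lookup(args, key, value) for key, value in norm_kwargs.items()}
print(resolved)  # {'learn_scale': True, 'target_norm': 0.5}
```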

Raises:

AssertionError – If the norm_type is not registered.

Example

>>> apply_norm(model, 'weight_norm', filter_out=['embedding'])

torchdeq.norm.reset_norm(model)

Automatically resets the normalization of a given DEQ model.

Parameters:

model (torch.nn.Module) – Model whose normalization is reset.

Example

>>> reset_norm(model)

torchdeq.norm.remove_norm(model)

Removes the normalization of a given DEQ model.

Parameters:

model (torch.nn.Module) – DEQ model to remove normalization from.

Example

>>> remove_norm(model)

torchdeq.norm.register_norm(norm_type, norm_class)

Registers a user-defined normalization class for the apply_norm function.

This function adds a new entry to the Norm class dict with the key as the specified norm_type and the value as the norm_class.

Parameters:
  • norm_type (str) – The type of normalization to register. This will be used as the key in the Norm class dictionary.

  • norm_class (type) – The class defining the normalization. This will be used as the value in the Norm class dictionary.
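The registration described above follows a common registry pattern: a dictionary keyed by the type string. The following is a minimal stand-alone sketch of that pattern, with the assertion on unknown types mirroring the AssertionError documented for apply_norm. The names _registry, register, and get_norm_class are illustrative assumptions, not torchdeq identifiers.

```python
_registry = {}  # stand-in for the Norm class dictionary described above


def register(norm_type, norm_class):
    """Store norm_class under the key norm_type."""
    _registry[norm_type] = norm_class


def get_norm_class(norm_type):
    """Look up a registered class, failing loudly for unknown types."""
    assert norm_type in _registry, f"Norm type {norm_type!r} is not registered"
    return _registry[norm_type]


class CustomNorm:
    """Placeholder; a real class would implement the normalization."""


register("custom_norm", CustomNorm)
print(get_norm_class("custom_norm") is CustomNorm)  # True
```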

Example

>>> register_norm('custom_norm', CustomNorm)

torchdeq.norm.register_norm_module(module_class, norm_type, names='weight', dims=0)

Registers a module class to be normalized by a user-defined normalization class in the apply_norm function.

This function adds a new entry to the _target_modules attribute of the specified normalization class in the _norm_class dictionary. The key is the module class and the value is a tuple containing the attribute name and dimension over which to compute the norm.

Parameters:
  • module_class (type) – Module class to be indexed for the user-defined normalization class.

  • norm_type (str) – The type of normalization class that the module class should be registered for.

  • names (str, optional) – Attribute name of module_class for the normalization to be applied. Default 'weight'.

  • dims (int, optional) – Dimension over which to compute the norm. Default 0.
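To illustrate the mapping described above (module class as key, a tuple of attribute name and dimension as value), here is a small sketch that uses placeholder classes instead of real torch.nn layers. The helpers register_target and find_target are hypothetical, and how torchdeq actually consults _target_modules is not shown in this reference.

```python
class Linear:  # placeholder standing in for a layer class
    def __init__(self):
        self.weight = [[1.0, 2.0]]


class Dropout:  # placeholder for a layer with nothing to normalize
    pass


class CustomNorm:
    # Maps module class -> (attribute name, dimension), as described above.
    _target_modules = {}


def register_target(norm_class, module_class, names="weight", dims=0):
    """Record which attribute of module_class to normalize, and along which dim."""
    norm_class._target_modules[module_class] = (names, dims)


def find_target(norm_class, module):
    """Return (names, dims) for a registered module instance, else None."""
    return norm_class._target_modules.get(type(module))


register_target(CustomNorm, Linear, "weight", 0)
print(find_target(CustomNorm, Linear()))   # ('weight', 0)
print(find_target(CustomNorm, Dropout()))  # None
```

Modules whose class was never registered simply have no entry, so a lookup like this leaves them untouched.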

Example

>>> register_norm_module(Conv2d, 'custom_norm', 'weight', 0)

Normalization

class torchdeq.norm.weight_norm.WeightNorm(names, dims, learn_scale: bool = True, target_norm: float = 1.0, clip: bool = False, clip_value: float = 1.0)

classmethod apply(module, deq_args=None, names=None, dims=None, learn_scale=True, target_norm=1.0, clip=False, clip_value=1.0)

Applies weight normalization to a given module.

Parameters:
  • module (torch.nn.Module) – The module to apply weight normalization to.

  • deq_args (Union[argparse.Namespace, dict, DEQConfig, Any]) – Configuration for the DEQ model. This can be an instance of argparse.Namespace, a dictionary, or an instance of DEQConfig. A config of any other type is processed using the get_attr function.

  • names (list or str, optional) – The names of the parameters to apply weight normalization to.

  • dims (list or int, optional) – The dimensions along which to normalize.

  • learn_scale (bool, optional) – If true, learn a scale factor during training. Default True.

  • target_norm (float, optional) – The target norm value. Default 1.

  • clip (bool, optional) – If true, clip the scale factor. Default False.

  • clip_value (float, optional) – The value to clip the scale factor to. Default 1.

Returns:

The WeightNorm instance.

Return type:

WeightNorm

compute_weight(module, name, dim)

Computes the weight with weight normalization.

Parameters:
  • module (torch.nn.Module) – The module which holds the weight tensor.

  • name (str) – The name of the weight parameter.

  • dim (int) – The dimension along which to normalize.

Returns:

The weight tensor after applying weight normalization.
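For intuition, the standard weight-normalization reparameterization rescales a direction tensor v by a scale g, giving w = g · v / ‖v‖. The sketch below shows this with plain Python lists so the arithmetic is visible. It assumes the PyTorch convention in which dim=0 yields one norm per row, and it is not torchdeq's implementation; the function name weight_norm_rows is hypothetical.

```python
import math


def weight_norm_rows(v, g):
    """Rescale each row of v so that its Euclidean norm equals g[i]."""
    out = []
    for row, scale in zip(v, g):
        norm = math.sqrt(sum(x * x for x in row))
        out.append([scale * x / norm for x in row])
    return out


v = [[3.0, 4.0], [0.0, 2.0]]  # unnormalized direction, one row per output unit
g = [1.0, 1.0]                # per-row scale (fixed here for illustration)
w = weight_norm_rows(v, g)
print(w)  # [[0.6, 0.8], [0.0, 1.0]]
```

Each row of w now has norm equal to the matching entry of g, regardless of the magnitude of v.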

Return type:

torch.Tensor

remove(module)

Removes weight normalization from the module.

Parameters:

module (torch.nn.Module) – The module to remove weight normalization from.

class torchdeq.norm.spectral_norm.SpectralNorm(names, dims, learn_scale: bool = True, target_norm: float = 1.0, clip: bool = False, clip_value: float = 1.0, n_power_iterations: int = 1, eps: float = 1e-12)

classmethod apply(module, deq_args=None, names=None, dims=None, learn_scale=True, target_norm=1.0, clip=False, clip_value=1.0, n_power_iterations=1, eps=1e-12)

Applies spectral normalization to a given module.

Parameters:
  • module (torch.nn.Module) – The module to apply spectral normalization to.

  • deq_args (Union[argparse.Namespace, dict, DEQConfig, Any]) – Configuration for the DEQ model. This can be an instance of argparse.Namespace, a dictionary, or an instance of DEQConfig. A config of any other type is processed using the get_attr function.

  • names (list or str, optional) – The names of the parameters to apply spectral normalization to.

  • dims (list or int, optional) – The dimensions along which to normalize.

  • learn_scale (bool, optional) – If true, learn a scale factor during training. Default True.

  • target_norm (float, optional) – The target norm value. Default 1.

  • clip (bool, optional) – If true, clip the scale factor. Default False.

  • clip_value (float, optional) – The value to clip the scale factor to. Default 1.

  • n_power_iterations (int, optional) – The number of power iterations to perform. Default 1.

  • eps (float, optional) – A small constant for numerical stability. Default 1e-12.

Returns:

The SpectralNorm instance.

Return type:

SpectralNorm

compute_weight(module, do_power_iteration, name, dim)

Computes the weight with spectral normalization.

Parameters:
  • module (torch.nn.Module) – The module which holds the weight tensor.

  • do_power_iteration (bool) – If true, do power iteration for approximating singular vectors.

  • name (str) – The name of the weight parameter.

  • dim (int) – The dimension along which to normalize.

Returns:

The computed weight tensor.
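Spectral normalization divides a weight matrix by an estimate of its largest singular value, usually obtained by power iteration. The sketch below shows that general technique on a small matrix in plain Python. It is not torchdeq's implementation: the function names are hypothetical, the starting vector is deterministic (implementations commonly use a random one), and the roles of n_power_iterations and eps are assumed from their parameter descriptions.

```python
import math


def _normalize(x, eps=1e-12):
    """Scale x to unit length, guarding against division by zero with eps."""
    norm = math.sqrt(sum(a * a for a in x))
    return [a / max(norm, eps) for a in x]


def estimate_sigma(W, n_power_iterations=1, eps=1e-12):
    """Estimate the largest singular value of W by power iteration."""
    rows, cols = len(W), len(W[0])
    u = _normalize([1.0] * rows, eps)
    v = _normalize([1.0] * cols, eps)
    for _ in range(n_power_iterations):
        # v <- W^T u / ||W^T u||
        v = _normalize(
            [sum(W[i][j] * u[i] for i in range(rows)) for j in range(cols)], eps
        )
        # u <- W v / ||W v||
        u = _normalize(
            [sum(W[i][j] * v[j] for j in range(cols)) for i in range(rows)], eps
        )
    # sigma is approximated by u^T W v
    return sum(u[i] * W[i][j] * v[j] for i in range(rows) for j in range(cols))


W = [[3.0, 0.0], [0.0, 1.0]]  # singular values are 3 and 1
sigma = estimate_sigma(W, n_power_iterations=50)
W_sn = [[x / sigma for x in row] for row in W]  # largest singular value is now ~1
```

With a single iteration the estimate is coarser. Repeated calls during training typically refine it by reusing the vector u from the previous step.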

Return type:

torch.Tensor

remove(module)

Removes spectral normalization from the module.

Parameters:

module (torch.nn.Module) – The module to remove spectral normalization from.