pcax.utils package
Module contents
- class pcax.utils.Mask(x: Type | m | Callable, map_to: Tuple[Any, Any] | None = None)
Bases: object
Mask can be used either to remove unwanted tensors (by setting them to None, similarly to equinox.filter) or to map them to a different value depending on whether they pass the filter (by using map_to). In either case, the filter acts on the Params of a pytree and treats them as leaves, while the value that gets modified is the Param’s value. As a consequence, it is impossible to treat different children of the same Param differently when that Param's value is itself a pytree.
- __call__(pydag: Any, is_pytree: bool = False) → Any
Applies the mask to the given pydag.
- Parameters:
pydag (Any) – target pydag.
is_pytree (bool, optional) – to guarantee that each parameter is masked only once (and thus the mask will not have duplicates), we have to ensure we are working with a pytree. Defaults to False.
- Returns:
the masked pydag (enforced to be a pytree).
- Return type:
Any
- __init__(x: Type | m | Callable, map_to: Tuple[Any, Any] | None = None)
Mask constructor.
- Parameters:
x (Type | m | Callable) – the filter to apply recursively to the parameters. If it is a type, parameters are filtered by type. If it is a callable, it must return a boolean, and parameters are filtered by the result of calling it. If it is a mask (i.e., an m instance), it is applied recursively to the parameters.
map_to (Tuple[Any, Any] | None, optional) – if not None, each parameter is set to one of the two values depending on the boolean result of the mask. If None, parameters are set to None when the mask is False and remain unchanged otherwise. Defaults to None.
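The following is a minimal pure-Python sketch of the filtering rules described above. It works on a flat dict of parameter objects instead of a real pytree. `Param`, `WeightParam` and `apply_mask` are hypothetical names used only for illustration, and the sketch handles only type and callable filters, not m instances. The reference does not say which element of map_to corresponds to a passing filter, so the `(value_if_false, value_if_true)` ordering used here is an assumption.

```python
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union


class Param:
    """Hypothetical stand-in for a pcax parameter: a leaf holding a value."""

    def __init__(self, value: Any):
        self.value = value


class WeightParam(Param):
    """Hypothetical subclass used to demonstrate filtering by type."""


Filter = Union[Type, Callable[[Any], bool]]


def apply_mask(
    params: Dict[str, Param],
    x: Filter,
    map_to: Optional[Tuple[Any, Any]] = None,
) -> Dict[str, Any]:
    """Return the masked value of each Param (sketch of the Mask semantics)."""
    # A type filters with isinstance; any other callable must return a bool.
    test = (lambda p: isinstance(p, x)) if isinstance(x, type) else x
    out = {}
    for name, p in params.items():
        passed = bool(test(p))
        if map_to is None:
            # Default: keep the Param's value if it passes, otherwise None.
            out[name] = p.value if passed else None
        else:
            # Assumed ordering: (value_if_false, value_if_true).
            out[name] = map_to[1] if passed else map_to[0]
    return out


params = {"w": WeightParam(1.0), "b": Param(0.5)}
print(apply_mask(params, WeightParam))                        # {'w': 1.0, 'b': None}
print(apply_mask(params, WeightParam, map_to=(False, True)))  # {'w': True, 'b': False}
```

The mask decides per Param and only ever rewrites the Param's value. This is why, as noted above, children inside a single Param's value cannot be masked separately.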
- class pcax.utils.Optim(optax_opt: GradientTransformation, parameters: PyTree | None = None)
Bases: BaseModule
Optim inherits from core.BaseModule and is therefore a pytree. It is a thin wrapper around the optax library.
- __init__(optax_opt: GradientTransformation, parameters: PyTree | None = None)
Optim constructor.
- Parameters:
optax_opt (optax.GradientTransformation) – the optax optimizer to wrap (e.g., the object returned by optax.sgd or optax.adam).
parameters (PyTree | None, optional) – the target parameters. If None is passed, the init method can be called separately later.
- apply_updates(module: PyTree, updates: PyTree) → None
Applies the updates to the module parameters.
- Parameters:
module (PyTree) – the module storing the target parameters.
updates (PyTree) – the updates to apply. They must have the same structure as the module used to initialise the optimizer.
- step(module: PyTree, grads: PyTree, scale_by_batch_size: bool = False, apply_updates: bool = True, mul: float = None) → None
Performs a gradient update step, similar to PyTorch’s ‘optimizer.step()’, by first calling ‘optax_opt.update’ and then ‘eqx.apply_updates’.
- Parameters:
module (PyTree) – the module storing the target parameters.
grads (PyTree) – the computed gradients to apply. They must have the same structure as the module used to initialise the optimizer.
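Optim.step follows optax's two-phase pattern. First, the optimizer turns the gradients (and its internal state) into updates. Then the updates are added to the parameters. The sketch below imitates that pattern in plain Python over a dict of floats, so it runs without JAX.

Everything in it is hypothetical:

- The `SGD` class only mimics the `init`/`update` shape of an optax GradientTransformation.
- The `apply_updates` and `step` helpers are illustrative.
- The sketch returns new values, whereas pcax's step updates the module in place and returns None.

```python
from typing import Dict, Tuple

Params = Dict[str, float]


class SGD:
    """Hypothetical optimizer mirroring the (init, update) shape of optax."""

    def __init__(self, learning_rate: float):
        self.learning_rate = learning_rate

    def init(self, params: Params) -> Tuple:
        # Plain SGD keeps no state.
        return ()

    def update(self, grads: Params, state: Tuple) -> Tuple[Params, Tuple]:
        # Updates are the negative, scaled gradients.
        updates = {k: -self.learning_rate * g for k, g in grads.items()}
        return updates, state


def apply_updates(params: Params, updates: Params) -> Params:
    # Updates must have the same structure (here: the same keys) as the params.
    if params.keys() != updates.keys():
        raise ValueError("updates do not match the parameter structure")
    return {k: params[k] + updates[k] for k in params}


def step(opt: SGD, state: Tuple, params: Params, grads: Params) -> Tuple[Params, Tuple]:
    # 1) turn gradients into updates, 2) apply the updates to the parameters.
    updates, state = opt.update(grads, state)
    return apply_updates(params, updates), state


opt = SGD(learning_rate=0.25)
params = {"w": 1.0, "b": 0.0}
state = opt.init(params)
params, state = step(opt, state, params, grads={"w": 2.0, "b": -1.0})
print(params)  # {'w': 0.5, 'b': 0.25}
```

Keeping the update computation separate from its application mirrors the apply_updates flag in the step signature. That flag presumably exists so that updates can be computed without being applied, but the reference does not document it.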
- pcax.utils.load_params(model: PyTree, path: str, filter: Callable[[Any], bool] | Type[BaseParam] = LayerParam) → None
Function to load the parameters of a model from a file. The ‘.npz’ extension is automatically added to the file name. The model must have the same structure as the one used to save the parameters and must already be initialized:
```python
from pcax.utils import load_params

model = Model()  # your own model class, with the same structure used when saving
load_params(model, "model.npz")
```
- Parameters:
model (PyTree) – target model.
path (str) – the path to the file containing the model parameters to load.
filter (Callable[[Any], bool] | Type[BaseParam], optional) – filter function or type identifying the parameters to load. The default value ‘LayerParam’ selects all the weights of the layers in the model.
- Raises:
KeyError – the file does not contain all the parameters required by the model.
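The loading contract above can be illustrated with NumPy's .npz format: the model must have the same structure as the saved one, and a KeyError is raised when the file lacks a required entry. This is a sketch under assumptions:

- Parameters are flattened into a dict keyed by hypothetical path strings such as "linear.weight".
- `load_flat_params` is an illustrative helper, not pcax's implementation.

The only library behaviour relied on is NumPy's own: `np.savez` appends ".npz" to a file name that lacks it.

```python
import os
import tempfile
from typing import Dict

import numpy as np


def load_flat_params(model: Dict[str, np.ndarray], path: str) -> Dict[str, np.ndarray]:
    """Hypothetical helper: read every entry that `model` requires from an .npz file."""
    if not path.endswith(".npz"):
        path += ".npz"  # same extension handling as np.savez
    with np.load(path) as data:
        missing = [name for name in model if name not in data.files]
        if missing:
            # The file does not contain all the parameters required by the model.
            raise KeyError(f"missing parameters: {missing}")
        loaded = {name: np.array(data[name]) for name in model}  # copy before the file closes
    for name, value in loaded.items():
        if value.shape != model[name].shape:
            raise ValueError(f"shape mismatch for {name}: {value.shape} vs {model[name].shape}")
    return loaded


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model")
    # np.savez appends ".npz", so this writes <tmp>/model.npz.
    np.savez(path, **{"linear.weight": np.ones((2, 3)), "linear.bias": np.zeros(2)})

    # An already-initialised "model": only names and shapes matter for loading.
    model = {"linear.weight": np.empty((2, 3)), "linear.bias": np.empty(2)}
    params = load_flat_params(model, path)
    print(params["linear.weight"].sum())  # 6.0

    # A model with an extra parameter cannot be loaded from this file.
    try:
        load_flat_params({**model, "linear2.weight": np.empty((3, 3))}, path)
        raised = False
    except KeyError as err:
        raised = True
        print("KeyError:", err)
```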
- class pcax.utils.m(x: Type | m | Callable | None = None)
Bases: object
m is a utility mask that can be used to combine different filters using logical operators (and more). It is needed because Python types natively support only the | operator. Available operators are:
- | for logical or. The | operator can also be used directly on the types themselves: float | int is equivalent to m(float) | m(int), and in both cases a parameter is matched if it is of either type. The other operators behave similarly.
- & for logical and
- ~ for logical not
- has to filter based on the presence of an attribute with a specific value
- has_not to filter based on the absence of an attribute with a specific value
For example:
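```python
m(A | B) & ~m(C).has(attr1=1)
```
selects parameters of class A or B, excluding those matched by m(C).has(attr1=1) (parameters of class C with an attribute attr1 equal to 1). By Python’s operator precedence, ~ applies to the whole expression m(C).has(attr1=1), not only to m(C).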
- pcax.utils.save_params(model: PyTree, path: str, filter: Callable[[Any], bool] | Type[BaseParam] = LayerParam) → None
Function to save the parameters of a model to a file. The ‘.npz’ extension is automatically added to the file name.
- Parameters:
model (PyTree) – the model to dump to disk.
path (str) – the path of the file in which to save the model. If the file already exists, it is overwritten.
filter (Callable[[Any], bool] | Type[BaseParam], optional) – filter function or type identifying the parameters to save. The default value ‘LayerParam’ selects all the weights of the layers in the model.
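Saving can be pictured as three steps: select the parameters with the filter, give each one a stable name, and write them all to a single archive. The sketch below does this over a flat dict of parameter objects using NumPy's `np.savez`, which appends ".npz" when the name lacks it and overwrites an existing file. `save_flat_params`, `Param`, `LayerParam` and `CacheParam` here are hypothetical stand-ins, not pcax's classes or implementation.

```python
import os
import tempfile
from typing import Any, Callable, Dict, Type, Union

import numpy as np


class Param:
    """Hypothetical parameter holding a NumPy value."""

    def __init__(self, value: np.ndarray):
        self.value = value


class LayerParam(Param):
    """Hypothetical stand-in for layer weights (used as the default filter)."""


class CacheParam(Param):
    """Hypothetical non-weight parameter that the default filter skips."""


def save_flat_params(
    params: Dict[str, Param],
    path: str,
    filter: Union[Callable[[Any], bool], Type[Param]] = LayerParam,
) -> None:
    # A type filter keeps instances of that class; otherwise call the predicate.
    test = (lambda p: isinstance(p, filter)) if isinstance(filter, type) else filter
    selected = {name: p.value for name, p in params.items() if test(p)}
    np.savez(path, **selected)  # appends ".npz" if missing; overwrites existing files


params = {
    "linear.weight": LayerParam(np.ones((2, 2))),
    "linear.bias": LayerParam(np.zeros(2)),
    "vode.cache": CacheParam(np.full(2, 7.0)),
}
with tempfile.TemporaryDirectory() as tmp:
    save_flat_params(params, os.path.join(tmp, "model"))
    written = os.listdir(tmp)
    with np.load(os.path.join(tmp, "model.npz")) as data:
        saved_keys = sorted(data.files)
print(saved_keys)  # ['linear.bias', 'linear.weight']
```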
- pcax.utils.step(module: EnergyModule | PyTree, status: str | None | Tuple = None, *, clear_params: Callable[[Any], bool] | Type | Tuple = None)
Applies common operations to a model before and after a step (normally a weight and/or state update). It is useful because setting the model’s status and clearing the parameter cache make it possible to control the model’s behavior. Example of usage:
```python
import pcax.utils as pxu

with pxu.step(model, pxc.STATUS.INIT, clear_params=pxc.VodeParam.Cache):
    # compute energy and gradients
    # apply gradients
    ...
```
- Parameters:
module (EnergyModule | PyTree) – the target module.
status (str | None | Tuple, optional) – the status to apply to the module and its submodules. By default, the status is set to None both before and after the step. If a tuple is provided, the first element is the status before the step and the second element is the status after the step. If a single element is provided, it is used BEFORE the step and None is used after the step.
clear_params (Callable[[Any], bool] | Type | Tuple, optional) – target parameters to clear. The value is passed directly to ‘EnergyModule.clear_params’, so refer to that method for more information. If a tuple is provided, the first element is used to call ‘.clear_params’ before the step and the second element is used after. If a single element is provided, it is used AFTER the step and no clearing happens before the step.
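The before/after rules for status and clear_params are asymmetric and easy to misread. A single status applies before the step (None after), while a single clear_params target applies after the step (nothing before). Tuples give explicit (before, after) pairs. The sketch below is a context manager that implements exactly these rules on a hypothetical `ToyModule`, which only records calls so that their order is visible.

Several details are assumptions, not pcax's behavior:

- The order of clearing versus setting the status within each phase.
- Restoring the status in a finally block.
- The `set_status` method name.

```python
from contextlib import contextmanager
from typing import Any, List, Tuple


class ToyModule:
    """Hypothetical module that records status changes and cache clears."""

    def __init__(self):
        self.status = None
        self.log: List[Tuple[str, Any]] = []

    def set_status(self, status: Any) -> None:
        self.status = status
        self.log.append(("status", status))

    def clear_params(self, target: Any) -> None:
        self.log.append(("clear", target))


@contextmanager
def step(module: ToyModule, status: Any = None, *, clear_params: Any = None):
    # status: a tuple is (before, after); a single value is used before, None after.
    before_status, after_status = status if isinstance(status, tuple) else (status, None)
    # clear_params: a tuple is (before, after); a single value is used after only.
    before_clear, after_clear = clear_params if isinstance(clear_params, tuple) else (None, clear_params)

    if before_clear is not None:
        module.clear_params(before_clear)
    module.set_status(before_status)
    try:
        yield module
    finally:
        module.set_status(after_status)
        if after_clear is not None:
            module.clear_params(after_clear)


model = ToyModule()
with step(model, "init", clear_params="cache"):
    model.log.append(("body", None))
print(model.log)  # [('status', 'init'), ('body', None), ('status', None), ('clear', 'cache')]

model2 = ToyModule()
with step(model2, ("init", "eval"), clear_params=("cache", "cache")):
    pass
print(model2.log)  # [('clear', 'cache'), ('status', 'init'), ('status', 'eval'), ('clear', 'cache')]
```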