# Source code for pcax.utils._misc
from typing import Any, Callable, Type, Tuple
from jaxtyping import PyTree
import contextlib
import enum
from ..core._tree import tree_apply
from ..predictive_coding._energy_module import EnergyModule
########################################################################################################################
#
# MISC
#
########################################################################################################################
# [docs]
@contextlib.contextmanager
def step(
    module: EnergyModule | PyTree,
    status: str | None | Tuple = None,
    *,
    clear_params: Callable[[Any], bool] | Type | Tuple | None = None,
):
    """Apply common operations to a model before and after a step.

    A "step" is normally a weight and/or state update. Setting the model's status and clearing its
    parameter cache around the step makes it possible to control the model's behavior. Example of usage:

    ```python
    with pxu.step(model, pxc.STATUS.INIT, clear_params=pxc.VodeParam.Cache):
        # compute energy and gradients
        # apply gradients
    ```

    The post-step operations (status reset and parameter clearing) are executed even if the body of the
    ``with`` block raises, so the module is never left in the pre-step status after a failure.

    Args:
        module (EnergyModule | PyTree): the target module.
        status (str | None | Tuple, optional): the status to apply to the module and its submodules. By default,
            the status is set to None both before and after the step. If a tuple (or list) is provided, the first
            element is the status before the step and the second element is the status after the step. If a
            single element is provided, it is used BEFORE the step and None is used after the step.
        clear_params (Callable[[Any], bool] | Type | Tuple | None, optional): Target parameters to clear. The
            value is directly passed to 'module.clear_params', so refer to that method for more information.
            If a tuple (or list) is provided, the first element is used to call '.clear_params' before the step
            and the second element is used after. If a single element is provided, it is used AFTER the step and
            no clearing happens before that. A None element means no clearing at that point.

    Yields:
        None: control is handed to the body of the ``with`` statement.
    """

    def _is_energy_module(node: Any) -> bool:
        return isinstance(node, EnergyModule)

    def _set_status(value: Any) -> None:
        # Apply the status to every EnergyModule found in the tree rooted at 'module'.
        tree_apply(lambda m: m._status.set(value), _is_energy_module, tree=module)

    # Normalize both arguments to (before, after) pairs. Note the asymmetry: a scalar status applies
    # BEFORE the step, whereas a scalar clear_params applies AFTER it.
    if not isinstance(status, (list, tuple)):
        status = (status, None)
    if not isinstance(clear_params, (list, tuple)):
        clear_params = (None, clear_params)

    if clear_params[0] is not None:
        module.clear_params(clear_params[0])
    _set_status(status[0])

    try:
        yield
    finally:
        # Always restore the post-step status and clear the requested parameters, even when the step
        # raised; otherwise the module would stay in the pre-step status with stale cached values.
        _set_status(status[1])
        if clear_params[1] is not None:
            module.clear_params(clear_params[1])