# Source code for pcax.utils._misc

from typing import Any, Callable, Type, Tuple
from jaxtyping import PyTree
import contextlib
import enum

from ..core._tree import tree_apply
from ..predictive_coding._energy_module import EnergyModule


########################################################################################################################
#
# MISC
#
########################################################################################################################


@contextlib.contextmanager
def step(
    module: EnergyModule | PyTree,
    status: str | None | Tuple = None,
    *,
    clear_params: Callable[[Any], bool] | Type | Tuple | None = None,
):
    """Applies common operations to a model before and after a step (normally a weight and/or state update).

    It is useful as setting the model's status and clearing the parameters cache allows to control the model's
    behavior. The "after" operations are executed even if the body of the ``with`` statement raises, so the
    module is not left in the "before" status when a step fails.

    Example of usage:

    ```python
    with pxu.step(model, pxc.STATUS.INIT, clear_params=pxc.VodeParam.Cache):
        # compute energy and gradients
        # apply gradients
    ```

    Args:
        module (EnergyModule | PyTree): the target module. The status is set by calling 'tree_apply' on it with an
            'isinstance(x, EnergyModule)' filter. 'module.clear_params' is called directly on it, so it must
            provide that method whenever a non-None 'clear_params' value is used.
        status (str | None | Tuple, optional): the status to apply to the module and its submodules. By default,
            the status is set to None both before and after the step. If a tuple (or list) is provided, the first
            element is the status before the step and the second element is the status after the step. If a
            single element is provided, it is used BEFORE the step and None is used after the step.
        clear_params (Callable[[Any], bool] | Type | Tuple, optional): target parameters to clear. The value is
            directly passed to 'EnergyModule.clear_params', so refer to that method for more information. If a
            tuple (or list) is provided, the first element is used to call '.clear_params' before the step and
            the second element is used after. If a single element is provided, it is used AFTER the step and no
            clearing happens before that. A None entry means "do not clear" at that point.

    Raises:
        ValueError: if 'status' or 'clear_params' is a tuple/list whose length is not exactly 2.
    """
    # Normalise both arguments into (before, after) pairs. Note the deliberate asymmetry: a single 'status'
    # applies BEFORE the step, while a single 'clear_params' applies AFTER it.
    if isinstance(status, list | tuple):
        if len(status) != 2:
            raise ValueError(f"'status' must be a (before, after) pair, got {len(status)} element(s).")
        status_before, status_after = status
    else:
        status_before, status_after = status, None

    if isinstance(clear_params, list | tuple):
        if len(clear_params) != 2:
            raise ValueError(f"'clear_params' must be a (before, after) pair, got {len(clear_params)} element(s).")
        clear_before, clear_after = clear_params
    else:
        clear_before, clear_after = None, clear_params

    # Setup: optional pre-step clearing, then the pre-step status.
    if clear_before is not None:
        module.clear_params(clear_before)
    tree_apply(lambda m: m._status.set(status_before), lambda x: isinstance(x, EnergyModule), tree=module)

    try:
        yield
    finally:
        # Teardown runs even when the body raises, so the module's status and caches are always restored.
        tree_apply(lambda m: m._status.set(status_after), lambda x: isinstance(x, EnergyModule), tree=module)
        if clear_after is not None:
            module.clear_params(clear_after)