Source code for pcax.utils._optim

__all__ = ["Optim"]

from jaxtyping import PyTree
import optax
import jax.tree_util as jtu
import equinox as eqx

from ..core._module import BaseModule
from ..core._parameter import Param, BaseParam, set, get
from ..core._static import static


########################################################################################################################
#
# OPTIM
#
# Optim offers a simple interface to the optax library. Being a 'BaseModule' it can be pass through pcax transformations
# with its state being tracked and updated. Note that init can be called anytime to reset the optimizer state. This
# can be helpful when the optimizer is used in a loop and the state needs to be reset at each iteration (for example,
# this may be the case for the vode opimitzer after each mini-batch).
#
# DEV NOTE: currently all the state is stored as a single parameter. This may be insufficient for advanced learning
# techniques that require, for example, differentiating with respect to some of the optimizer's values.
# In such case, this optimizer class should be upgradable to the same design pattern used for 'Layers', which substitues
# each individual weights with a different parameter (which when '.update' is called, can be firstly replaced with their
# values, similarly to the 'Layer.__call__' method).
#
########################################################################################################################


[docs] class Optim(BaseModule): """Optim inherits from core.BaseModule and thus it is a pytree. It is a thin wrapper around the optax library."""
[docs] def __init__(self, optax_opt: optax.GradientTransformation, parameters: PyTree | None = None): """Optim constructor. Args: optax_opt (optax.GradientTransformation): the optax constructor function. parameters (PyTree | None, optional): target parameters. The init method can be called separately by passing None. """ self.optax_opt = static(optax_opt) self.state = Param(None) self.filter = static(None) if parameters is not None: self.init(parameters)
[docs] def step( self, module: PyTree, grads: PyTree, scale_by_batch_size: bool = False, apply_updates: bool = True, mul: float = None ) -> None: """Performs a gradient update step similarly to Pytorch's 'optimizer.step()' by calling first 'optax_opt.update' and then 'eqx.apply_updates'. Args: module (PyTree): the module storing the target parameters. grads (PyTree): the computed gradients to apply. Provided gradients must match the same structure of the module used to initialise the optimizer. """ # Filter out the Params that do not have a gradient (this, for example, includes all the StaticParam whose # current state my differ from the original parameter structure and thus be incomparible with the gradients # structure). By doing so, and by enforcing the basic requirement that the gradients for all target paramemters # are provided, we can safely assume that the filtered gradients structure matches the filtered target # parameters structure. # For example 'grads' could contain gradients computed for parameters not targeted by this optimiser without # causing any issue since they will be filtered out automatically. module = eqx.filter(module, self.filter.get(), is_leaf=lambda x: isinstance(x, BaseParam)) if scale_by_batch_size is True: grads = jtu.tree_map( lambda x, f: x.set(x * x.shape[0]) if f is True else None, grads, self.filter.get(), is_leaf=lambda x: isinstance(x, BaseParam), ) elif mul is not None: grads = jtu.tree_map( lambda x, f: x.set(x * mul) if f is True else None, grads, self.filter.get(), is_leaf=lambda x: isinstance(x, BaseParam), ) else: grads = eqx.filter(grads, self.filter.get(), is_leaf=lambda x: isinstance(x, BaseParam)) updates, state = self.optax_opt.update( grads, self.state.get(), module, ) self.state.set(state) if apply_updates: self.apply_updates(module, updates) return updates
[docs] def apply_updates(self, module: PyTree, updates: PyTree) -> None: """Applies the updates to the module parameters. Args: module (PyTree): the module storing the target parameters. updates (PyTree): the updates to apply. Provided updates must match the same structure of the module used to initialise the optimizer. """ jtu.tree_map( lambda u, p: set(p, eqx.apply_updates(get(p), get(u))), updates, module, is_leaf=lambda x: isinstance(x, BaseParam), )
[docs] def init(self, parameters: PyTree) -> None: # We compute a static filter identifying the parameters given to be optimised. This is useful to filter out # he remaining parameters and allow them to change structure without affecting the functioning of the # optimizer. self.filter.set( jtu.tree_map(lambda x: get(x) is not None, parameters, is_leaf=lambda x: isinstance(x, BaseParam)) ) parameters = eqx.filter(parameters, self.filter.get(), is_leaf=lambda x: isinstance(x, BaseParam)) self.state.set(self.optax_opt.init(parameters))
[docs] def clear(self) -> None: self.state.set(None) self.filter.set(None)