__all__ = ["Optim"]
from jaxtyping import PyTree
import optax
import jax.tree_util as jtu
import equinox as eqx
from ..core._module import BaseModule
from ..core._parameter import Param, BaseParam, set, get
from ..core._static import static
########################################################################################################################
#
# OPTIM
#
# Optim offers a simple interface to the optax library. Being a 'BaseModule' it can be pass through pcax transformations
# with its state being tracked and updated. Note that init can be called anytime to reset the optimizer state. This
# can be helpful when the optimizer is used in a loop and the state needs to be reset at each iteration (for example,
# this may be the case for the vode opimitzer after each mini-batch).
#
# DEV NOTE: currently all the state is stored as a single parameter. This may be insufficient for advanced learning
# techniques that require, for example, differentiating with respect to some of the optimizer's values.
# In such case, this optimizer class should be upgradable to the same design pattern used for 'Layers', which substitues
# each individual weights with a different parameter (which when '.update' is called, can be firstly replaced with their
# values, similarly to the 'Layer.__call__' method).
#
########################################################################################################################
[docs]
class Optim(BaseModule):
"""Optim inherits from core.BaseModule and thus it is a pytree. It is a thin wrapper around the optax library."""
[docs]
def __init__(self, optax_opt: optax.GradientTransformation, parameters: PyTree | None = None):
"""Optim constructor.
Args:
optax_opt (optax.GradientTransformation): the optax constructor function.
parameters (PyTree | None, optional): target parameters. The init method can be called separately by passing
None.
"""
self.optax_opt = static(optax_opt)
self.state = Param(None)
self.filter = static(None)
if parameters is not None:
self.init(parameters)
[docs]
def step(
self, module: PyTree, grads: PyTree, scale_by_batch_size: bool = False, apply_updates: bool = True, mul: float = None
) -> None:
"""Performs a gradient update step similarly to Pytorch's 'optimizer.step()' by calling first 'optax_opt.update'
and then 'eqx.apply_updates'.
Args:
module (PyTree): the module storing the target parameters.
grads (PyTree): the computed gradients to apply. Provided gradients must match the same structure of the
module used to initialise the optimizer.
"""
# Filter out the Params that do not have a gradient (this, for example, includes all the StaticParam whose
# current state my differ from the original parameter structure and thus be incomparible with the gradients
# structure). By doing so, and by enforcing the basic requirement that the gradients for all target paramemters
# are provided, we can safely assume that the filtered gradients structure matches the filtered target
# parameters structure.
# For example 'grads' could contain gradients computed for parameters not targeted by this optimiser without
# causing any issue since they will be filtered out automatically.
module = eqx.filter(module, self.filter.get(), is_leaf=lambda x: isinstance(x, BaseParam))
if scale_by_batch_size is True:
grads = jtu.tree_map(
lambda x, f: x.set(x * x.shape[0]) if f is True else None,
grads,
self.filter.get(),
is_leaf=lambda x: isinstance(x, BaseParam),
)
elif mul is not None:
grads = jtu.tree_map(
lambda x, f: x.set(x * mul) if f is True else None,
grads,
self.filter.get(),
is_leaf=lambda x: isinstance(x, BaseParam),
)
else:
grads = eqx.filter(grads, self.filter.get(), is_leaf=lambda x: isinstance(x, BaseParam))
updates, state = self.optax_opt.update(
grads,
self.state.get(),
module,
)
self.state.set(state)
if apply_updates:
self.apply_updates(module, updates)
return updates
[docs]
def apply_updates(self, module: PyTree, updates: PyTree) -> None:
"""Applies the updates to the module parameters.
Args:
module (PyTree): the module storing the target parameters.
updates (PyTree): the updates to apply. Provided updates must match the same structure of the module used to
initialise the optimizer.
"""
jtu.tree_map(
lambda u, p: set(p, eqx.apply_updates(get(p), get(u))),
updates,
module,
is_leaf=lambda x: isinstance(x, BaseParam),
)
[docs]
def init(self, parameters: PyTree) -> None:
# We compute a static filter identifying the parameters given to be optimised. This is useful to filter out
# he remaining parameters and allow them to change structure without affecting the functioning of the
# optimizer.
self.filter.set(
jtu.tree_map(lambda x: get(x) is not None, parameters, is_leaf=lambda x: isinstance(x, BaseParam))
)
parameters = eqx.filter(parameters, self.filter.get(), is_leaf=lambda x: isinstance(x, BaseParam))
self.state.set(self.optax_opt.init(parameters))
[docs]
def clear(self) -> None:
self.state.set(None)
self.filter.set(None)