from typing import Any, Callable, Tuple, Sequence
from jaxtyping import PyTree
import abc
import inspect
import jax
import jax.tree_util as jtu
import equinox as eqx
from ..core._tree import tree_extract, tree_inject, tree_ref, tree_unref
from ..core._random import RKG
from ..core._parameter import BaseParam
########################################################################################################################
#
# TRANSFORM
#
# pcax keeps track of the changes applied to any Parameter. To do so, it introduces its own set of transformations
# which replace the ones provided by jax (such as jit, vmap, ...). Each transformation behaves exactly as its jax
# counterpart and is a simple wrapper that allows for parameter tracking and introduces some very small QOL improvements
# over the jax version.
#
# In particular, the protocol is defined as treating any positional argument as a "pure" jax one (i.e., not tracked),
# while it introduces the possibility to also use keyword arguments for transformations, whose Parameters are instead
# tracked. The most common usage pattern is to pass (stateful) models as keyword arguments, while any simple jax.Array
# as positional argument. For example:
#
# ```
# @vmap(...)
# def eval(x: jax.Array, y: jax.Array, *, model: pcax.Module):
# y_hat = model(x) # if model has any stateful layer such as batch norm, __call__ can safely update it
# return ((y_hat - y)**2).mean() # note that we do not return model, but its parameters are kept updated nonetheless.
# ```
#
# NOTE #1: since pcax.Module is a pytree, one can safely pass it as positional argument as well, however its parameters
# will not be tracked.
#
# NOTE #2: pcax transformations automatically deal with multiple references to the same Parameter within the given
# keyword arguments (important: only for Parameters!). This enables parameter sharing.
#
# NOTE #3: all of this requires each jax transformation to be ported to pcax. The parent class _BaseTransform
# significantly automates this process but ad-hoc fixes are necessary for each function. In particular each pcax
# transformation must inherit from it and define its own _t function that, given the function self.fn with signature
#
# ```
# def fn(*args, **kwargs) -> *, params
# ```
#
# has the following signature:
#
# ```
# def _t(*args, **kwargs) -> *, params
# ```
#
# The standard way to do so is to define a _wrap_fn to wrap self.fn and rearrange input and output arguments to match
# the required structure.
#
# NOTE #4: some transformations require a kwargs_mask, which is an extra argument compared to the jax interface. It is
# necessary to specify which parameters in the kwargs must be affected by the transformation. Check
# '_AbstractTransform._process_mask' for a more detailed explanation.
#
########################################################################################################################
# Utils ################################################################################################################
def _make_tuple(x: Any):
return x if isinstance(x, tuple) else (x,)
def _repr_function(f: Callable) -> str:
"""Human readable function representation."""
_signature = inspect.signature(f)
_args = [f"{k}={v.default}" for k, v in _signature.parameters.items() if v.default is not inspect.Parameter.empty]
_args = ", ".join(_args)
while not hasattr(f, "__name__"):
if not hasattr(f, "func"):
break
f = f.func
if not hasattr(f, "__name__") and hasattr(f, "__class__"):
return f.__class__.__name__
if _args:
return f"{f.__name__}(*, {_args})"
return f.__name__
# Core #################################################################################################################
class Jit(_BaseTransform):
    """
    Wrap around jax.jit(fn, ...).

    Inside the jitted function the updated kwargs are passed through 'tree_extract' so that a list of
    parameters, rather than a complex pytree, crosses the "jit barrier". This keeps the overhead of
    injecting the new values back into the original kwargs (done outside of jit) low.
    """

    def __init__(self, fn: "_BaseTransform" | Callable, **t_kwargs: Any):
        """Store 'fn' and build its jitted wrapper; 't_kwargs' are forwarded as-is to jax.jit."""
        super().__init__(fn)

        def _wrap_fn(*args, **kwargs):
            # self.fn follows the transform protocol: it returns (result, updated kwargs).
            _result, _updated_kwargs = self.fn(*args, **kwargs)
            _extracted = tree_extract(_updated_kwargs, is_pytree=True)

            return _result, _extracted

        # The jitted callable is created once, here, and reused by every call to '_t'.
        self.wrap_fn = jax.jit(_wrap_fn, **t_kwargs)

    def _t(self, *args, **kwargs):
        """Run the jitted wrapper and return its (result, extracted parameters) pair."""
        _result, _params = self.wrap_fn(*args, **kwargs)

        return _result, _params
class ValueAndGrad(_BaseTransform):
    """
    Wrap around jax.value_and_grad(fn, ...).

    kwargs_mask must specify whether each leaf is to be differentiated (True) or not (False).

    NOTE #1: the optimizer class provided by pcax assumes that each 'BaseParam' is a leaf of the optimized pytree,
    this implies that the mask must provide a value at the 'BaseParam' level and not for its value. For example:

    ```python
    model = [Param(jax.numpy.array([1.0])), Param(jax.numpy.array([2.0]))]
    mask = [True, False] # this is correct, [Param(True), Param(False)] would be wrong.
    ```

    This is the behavior of 'Mask' which replaces each parameter (and not the value of the parameters) with the given
    values.

    NOTE #2: #1 implies that it is not possible to differentiate the behavior for different 'jax.Arrays' within a single
    parameter (in general, the value of a parameter can be anything, and it is not limited to a single 'jax.Array', see
    for example the 'ParamDict' class). So please keep this in mind when using 'ValueAndGrad' (or any other transform
    that requires a mask) and structure your models accordingly (i.e., different parameters for different behaviours).

    NOTE #3: the assumption in #1 is to simplify the implementation of the optimizer class and in general it seemed
    more intuitive and less error prone (i.e., masking substitutes the whole parameter object, and not only its value),
    and it also required to completely discard static parameters (which, otherwise, would clutter the mask with
    unnecessary static information which could cause pytree incompatibilities down the line, which is what would indeed
    happen in the optimizer). Furthermore, it allows to deal with unreffed pytrees (maybe, to be verified with the
    behaviour of each jax transformation). However, the assumption is not a strict requirement and the code can
    easily be redesigned to allow for masking of the value if deemed to be a necessary feature.

    NOTE #4: 'has_aux' is optional and defaults to False, as in jax.value_and_grad. 'argnums' defaults to '()'
    (i.e., only the masked kwargs are differentiated) and can be given either as a sequence of ints or as a single
    int; in the latter case the gradient of that positional argument is returned as is, and not wrapped in a tuple.
    """

    def __init__(self, fn: "_BaseTransform" | Callable, kwargs_mask: Any = None, **t_kwargs: Any):
        """
        Args:
            fn: function (or transformation) to differentiate.
            kwargs_mask: mask selecting which kwargs parameters are differentiated. 'None' (the default) stands for
                an empty mask, '{}'.
            **t_kwargs: keyword arguments forwarded to jax.value_and_grad.
        """
        super().__init__(fn)

        # A fresh dict per instance: a '{}' default in the signature would be shared by every instance.
        self.kwargs_mask = {} if kwargs_mask is None else kwargs_mask

        # Remember what the user asked for; missing 'has_aux' means False, as in jax.value_and_grad.
        self.has_aux = t_kwargs.get("has_aux", False)

        # The wrapped function always returns the updated kwargs as auxiliary data, so jax must always see
        # 'has_aux=True', independently of the user's choice.
        t_kwargs["has_aux"] = True
        self.t_kwargs = t_kwargs

    def _t(self, *args, **kwargs):
        """
        Differentiate self.fn with respect to the requested positional arguments and the masked kwargs.

        Returns:
            '(((loss, aux), grads), values)' if 'has_aux' was requested, '((loss, grads), values)' otherwise, where
            'aux' is the tuple of the outputs of self.fn following the first one, 'values' are the kwargs returned
            by self.fn and 'grads' is the gradient of the masked kwargs or, if 'argnums' was given, the pair
            '(positional gradients, kwargs gradient)'.
        """

        def _wrap_fn(*args):
            # The last two positional arguments are the two halves of the partitioned kwargs.
            _args, _target_kwargs, _other_kwargs = args[:-2], args[-2], args[-1]
            _kwargs = eqx.combine(_target_kwargs, _other_kwargs, is_leaf=lambda x: isinstance(x, BaseParam))

            _r, _kwargs = self.fn(*_args, **_kwargs)
            _r = _make_tuple(_r)

            # First output is the differentiated value, everything else travels as auxiliary data.
            return _r[0], (_r[1:], _kwargs)

        # jax accepts 'argnums' both as an int and as a sequence of ints: normalise it to a tuple so that the
        # index of the differentiated kwargs can be appended to it.
        _user_argnums = self.t_kwargs.get("argnums", ())
        _argnums = (_user_argnums,) if isinstance(_user_argnums, int) else tuple(_user_argnums)

        (_l, (_r, _values)), _grads = jax.value_and_grad(
            _wrap_fn, **{**self.t_kwargs, "argnums": _argnums + (len(args),)}
        )(
            *args,
            # We split kwargs to isolate the parameters we want to differentiate, following the jax syntax.
            *eqx.partition(
                kwargs,
                # we pass 'False' as rkg mask to not take its gradient.
                self._process_mask(self.kwargs_mask, kwargs, False),
                is_leaf=lambda x: isinstance(x, BaseParam),
            ),
        )

        if _argnums:
            # Separate the gradients of the positional arguments from the one of the kwargs (always the last).
            _arg_grads = _grads[:-1]
            if isinstance(_user_argnums, int):
                # Mirror jax: an int 'argnums' yields the bare gradient instead of a 1-tuple.
                _arg_grads = _arg_grads[0]
            _grads = (_arg_grads, _grads[-1])
        else:
            # Only the kwargs were differentiated: unwrap the 1-tuple.
            _grads = _grads[0]

        if self.has_aux:
            return (((_l, _r), _grads), _values)
        else:
            return ((_l, _grads), _values)
class Vmap(_BaseTransform):
    """
    Wrap around jax.vmap(fn, ...).

    kwargs_mask must specify whether each leaf is vectorised or not. It is assumed that the behaviour for
    each leaf is the same for both input and output (this could be changed by providing an 'out_kwargs_mask'
    as well as a 'in_kwargs_mask'). Both 'in_axes' and 'out_axes' must be provided in a jax supported format.

    NOTE: RKG is automatically handled by the transformation, so it must not be provided in the kwargs.
    """

    def __init__(self, fn: "_BaseTransform" | Callable, kwargs_mask: Any = None, **t_kwargs: Any):
        """
        Args:
            fn: function (or transformation) to vectorise.
            kwargs_mask: mask specifying the vmap axis of the kwargs parameters. 'None' (the default) stands for
                an empty mask, '{}'.
            **t_kwargs: keyword arguments forwarded to jax.vmap.
        """
        super().__init__(fn)

        # A fresh dict per instance: a '{}' default in the signature would be shared by every instance, and '_t'
        # writes the "__RKG" entry into the processed mask.
        self.kwargs_mask = {} if kwargs_mask is None else kwargs_mask
        self.t_kwargs = t_kwargs

    def _t(self, *args, **kwargs):
        """Vectorise self.fn over 'args' (following 'in_axes') and 'kwargs' (following the kwargs mask)."""
        _kwargs_mask = self._process_mask(self.kwargs_mask, kwargs)
        # kwargs are passed to the vmapped function as a last positional argument, hence the extra in_axes entry.
        _in_axes_mask = _make_tuple(self.t_kwargs.get("in_axes", ())) + (_kwargs_mask,)

        # Compute vaxes dimension which is necessary to split the RKG key.
        def _extract_vaxes_dim(node, mask):
            # Size, along axis 'mask', of the first leaf of 'node' exposing a shape (None if there is none).
            for param in jtu.tree_leaves(node):
                if hasattr(param, "shape"):
                    return param.shape[mask]

            return None

        # 'tree_leaves' discards the None results, so the first entry is the size of the first mapped input found.
        _vaxis_dim = jtu.tree_leaves(
            jtu.tree_map(lambda mask, node: _extract_vaxes_dim(node, mask), _in_axes_mask, (*args, kwargs))
        )[0]

        # Split the __RKG key over the vmap axis (and set the mask accordingly).
        # '_in_axes_mask[-1]' is '_kwargs_mask' itself, so the entry also applies to the 'out_axes' given below.
        _in_axes_mask[-1]["__RKG"] = 0
        kwargs["__RKG"].key.set(kwargs["__RKG"].key.split(_vaxis_dim))

        def _wrap_fn(*args):
            # Undo the packing done at call time: the last positional argument holds the kwargs.
            *_args, _kwargs = args
            _r, _kwargs = self.fn(*_args, **_kwargs)

            return _r, _kwargs

        _r, kwargs = jax.vmap(
            _wrap_fn,
            **{
                **self.t_kwargs,
                "in_axes": _in_axes_mask,
                "out_axes": (self.t_kwargs.get("out_axes", None), _kwargs_mask),
            },
        )(*args, kwargs)

        # Merge back the key value to remove the vmap axis before returning it;
        # it will automatically be injected back into the global RKG (being it a kwarg)
        kwargs["__RKG"].key.set(kwargs["__RKG"].key[0])

        return _r, kwargs