__all__ = ["BaseParam", "Param", "ParamDict", "ParamCache", "get", "set"]
import abc
from typing import Tuple, Dict, Any, Type
import functools
import jax
########################################################################################################################
#
# PARAMETER
#
# pcax is inspired by objax (https://github.com/google/objax) and equinox (https://github.com/patrick-kidger/equinox),
# two other JAX libraries. The core idea is that each JAX array is wrapped in a Param object that pcax can track through
# JAX transformations, without the need to respect the strict functional programming paradigm of JAX.
#
########################################################################################################################
# Core #################################################################################################################
class _BaseParamMeta(abc.ABCMeta):
"""
Metaclass to register all parameters in the JAX pytree flatten/unflatten util.
A parameter is flattened by separating its '_value' from its other attributes, which are considered static.
"""
def __new__(mcs, name, bases, dct):
_cls = super().__new__(mcs, name, bases, dct)
jax.tree_util.register_pytree_with_keys(
_cls,
flatten_func=_BaseParamMeta.flatten_parameter,
flatten_with_keys=_BaseParamMeta.flatten_parameter_with_keys,
unflatten_func=functools.partial(_BaseParamMeta.unflatten_parameter, cls=_cls),
)
return _cls
@staticmethod
def flatten_parameter(param: "BaseParam") -> Tuple[Any, Dict[str, Any]]:
_aux_data = dict.copy(param.__dict__)
del _aux_data["_value"]
return (param._value,), _aux_data
@staticmethod
def flatten_parameter_with_keys(param: "BaseParam") -> Tuple[Any, Dict[str, Any]]:
_aux_data = dict.copy(param.__dict__)
del _aux_data["_value"]
return ((jax.tree_util.GetAttrKey("value"), param._value),), _aux_data
@staticmethod
def unflatten_parameter(aux_data: Dict[str, Any], children: Any, *, cls: Type["BaseParam"]) -> "BaseParam":
_param = object.__new__(cls)
_param.__dict__ = dict.copy(aux_data)
_param._value = children[0]
return _param
[docs]
class BaseParam(metaclass=_BaseParamMeta):
"""
Base abstract class for all parameters. It is used to detect whether an object is a parameter or not.
"""
[docs]
def __init__(self, value: jax.Array | Any | None = None):
"""
_BaseParam constructor.
Args:
value: the value (usually a tensor) to wrap. pcax will treat (only!) such values as dynamic.
"""
self._value = value
[docs]
@abc.abstractmethod
def get(self):
raise NotImplementedError()
[docs]
@abc.abstractmethod
def set(self, value):
raise NotImplementedError()
def __bool__(self):
raise TypeError(
"To prevent accidental errors parameters can not be used as Python bool. "
"To check if variable is `None` use `is None` or `is not None` instead."
)
# Parameter ###########################################################################################################
class DynamicParam(BaseParam):
pass
[docs]
class Param(DynamicParam):
"""
The base class to represent and store a dynamic value in pcax.
This is mainly used to wrap jax.Arrays and track them through JAX transformations.
"""
[docs]
def get(self) -> jax.Array:
return self._value
[docs]
def set(self, value: jax.Array) -> "Param":
self._value = value
return self
def __repr__(self):
rvalue = (
f"[{','.join(map(str, self.shape))}], {self.dtype}"
if isinstance(self._value, jax.Array)
else repr(self._value)
)
t = f"{self.__class__.__name__}({rvalue})"
return t
# Python looks up special methods only on classes, not instances. This means
# these methods needs to be defined explicitly rather than relying on
# __getattr__.
def __neg__(self):
return self._value.__neg__() # noqa: E704
def __pos__(self):
return self._value.__pos__() # noqa: E704
def __abs__(self):
return self._value.__abs__() # noqa: E704
def __invert__(self):
return self._value.__invert__() # noqa: E704
def __eq__(self, __other):
return self._value.__eq__(get(__other)) # noqa: E704
def __ne__(self, __other):
return self._value.__ne__(get(__other)) # noqa: E704
def __lt__(self, __other):
return self._value.__lt__(get(__other)) # noqa: E704
def __le__(self, __other):
return self._value.__le__(get(__other)) # noqa: E704
def __gt__(self, __other):
return self._value.__gt__(get(__other)) # noqa: E704
def __ge__(self, __other):
return self._value.__ge__(get(__other)) # noqa: E704
def __add__(self, __other):
return self._value.__add__(get(__other)) # noqa: E704
def __radd__(self, __other):
return self._value.__radd__(get(__other)) # noqa: E704
def __iadd__(self, __other):
self._value = self._value.__add__(get(__other)) # noqa: E704
return self
def __sub__(self, __other):
return self._value.__sub__(get(__other)) # noqa: E704
def __rsub__(self, __other):
return self._value.__rsub__(get(__other)) # noqa: E704
def __isub__(self, __other):
self._value = self._value.__sub__(get(__other)) # noqa: E704
return self
def __mul__(self, __other):
return self._value.__mul__(get(__other)) # noqa: E704
def __rmul__(self, __other):
return self._value.__rmul__(get(__other)) # noqa: E704
def __imul__(self, __other):
self._value = self._value.__mul__(get(__other)) # noqa: E704
return self
def __div__(self, __other):
return self._value.__div__(get(__other)) # noqa: E704
def __rdiv__(self, __other):
return self._value.__rdiv__(get(__other)) # noqa: E704
def __idiv__(self, __other):
self._value = self._value.__div__(get(__other)) # noqa: E704
return self
def __truediv__(self, __other):
return self._value.__truediv__(get(__other)) # noqa: E704
def __rtruediv__(self, __other):
return self._value.__rtruediv__(get(__other)) # noqa: E704
def __floordiv__(self, __other):
return self._value.__floordiv__(get(__other)) # noqa: E704
def __rfloordiv__(self, __other):
return self._value.__rfloordiv__(get(__other)) # noqa: E704
def __divmod__(self, __other):
return self._value.__divmod__(get(__other)) # noqa: E704
def __rdivmod__(self, __other):
return self._value.__rdivmod__(get(__other)) # noqa: E704
def __mod__(self, __other):
return self._value.__mod__(get(__other)) # noqa: E704
def __rmod__(self, __other):
return self._value.__rmod__(get(__other)) # noqa: E704
def __pow__(self, __other):
return self._value.__pow__(get(__other)) # noqa: E704
def __rpow__(self, __other):
return self._value.__rpow__(get(__other)) # noqa: E704
def __matmul__(self, __other):
return self._value.__matmul__(get(__other)) # noqa: E704
def __rmatmul__(self, __other):
return self._value.__rmatmul__(get(__other)) # noqa: E704
def __and__(self, __other):
return self._value.__and__(get(__other)) # noqa: E704
def __rand__(self, __other):
return self._value.__rand__(get(__other)) # noqa: E704
def __or__(self, __other):
return self._value.__or__(get(__other)) # noqa: E704
def __ror__(self, __other):
return self._value.__ror__(get(__other)) # noqa: E704
def __xor__(self, __other):
return self._value.__xor__(get(__other)) # noqa: E704
def __rxor__(self, __other):
return self._value.__rxor__(get(__other)) # noqa: E704
def __lshift__(self, __other):
return self._value.__lshift__(get(__other)) # noqa: E704
def __rlshift__(self, __other):
return self._value.__rlshift__(get(__other)) # noqa: E704
def __rshift__(self, __other):
return self._value.__rshift__(get(__other)) # noqa: E704
def __rrshift__(self, __other):
return self._value.__rrshift__(get(__other)) # noqa: E704
def __round__(self, ndigits=None):
return self._value.__round__(ndigits) # noqa: E704
def __getitem__(self, __idx):
return self._value.__getitem__(__idx)
def __array__(self, dtype=None):
return self._value.__array__(dtype)
def __getattr__(self, __name):
return getattr(self._value, __name)
@property
def dtype(self):
"""Wrapped value data type."""
return self._value.dtype
@property
def shape(self):
"""Wrapped value shape."""
return self._value.shape
@property
def ndim(self):
"""Number of dimentions of wrapped value."""
return self._value.ndim
[docs]
class ParamDict(DynamicParam):
def __init__(self, value: Dict[str, jax.Array | Any | None] = None):
super().__init__(value)
def __getitem__(self, __key: str) -> Any:
return self._value[__key]
def __setitem__(self, __key: str, __value: jax.Array) -> None:
# Clearing a parameter equates setting its _value to None,
# so we need to reset it to an empty dictionary when necessary.
if self._value is None:
self._value = {}
self._value[__key] = __value
def __contains__(self, __key: str) -> bool:
return __key in self._value
[docs]
def get(self, key: str | None = None, default: jax.Array | Any | None = None) -> Any:
return self._value.get(key, default) if key is not None else self._value
[docs]
def set(self, value) -> None:
self._value = value
def __repr__(self):
return f"{self.__class__.__name__}(params={repr(self._value)})"
[docs]
class ParamCache:
"""
A simple sentinel class used to identify all parameters used as a temporary cache.
"""
pass
# Utils ################################################################################################################
[docs]
def get(x: Any | BaseParam) -> Any:
"""Return the value encapsulated in the input argument if it is a BaseParam, otherwise return the input argument
itself. Used in ambiguous situations to ensure that the input is a value and not a BaseParam.
Args:
x (Any | BaseParam): input argument
Returns:
Any: value encapsulated in the input argument if it is a BaseParam, otherwise the input argument itself.
"""
if isinstance(x, BaseParam):
return x.get()
else:
return x
[docs]
def set(obj: Any, x: Any | BaseParam) -> Any | BaseParam:
"""Set the value of the input object and returns it if it is a BaseParam, otherwise return the new value itself.
Used in ambiguous situations to ensure that the input object is correctly updated.
Returns:
Any | BaseParam: the updated input object if it is a BaseParam, otherwise the new value itself.
"""
if isinstance(obj, BaseParam):
obj.set(get(x))
else:
obj = set(x)
return obj