__all__ = ["tree_extract", "tree_inject", "tree_ref", "tree_unref"]
from typing import Any, Tuple, Sequence, Callable
from jaxtyping import PyTree
import jax.tree_util as jtu
import equinox as eqx
from ..core._parameter import BaseParam, DynamicParam
from ..core._static import StaticParam
########################################################################################################################
#
# TREE
#
# A set of helper functions to simplify the management of stateful pytrees (and pydags) based on 'AbstracParams'.
# See the documentation of each function for more details.
#
########################################################################################################################
# Utils ################################################################################################################
def _cache() -> Callable[[Any], int | None]:
"""Utiliy function to create a fast cache set, which keeps track of the order in which elements are seen.
Usage:
```
cache = _cache()
if (n := cache(x)) is not None:
print(f"{x} was already observed. It is the {n}th unique observed object")
```
Returns:
Callable[[Any], int | None]: callable cache set.
"""
_data = {}
_setdefault = _data.setdefault
_n = 0
def _add(x: Any) -> int | None:
"""Cache 'test and set' function.
Args:
x (Any): element to check against the cache
Returns:
int | None: None if x is a newly seen object; otherwise i,
where i is the number of unique objects seen before x.
"""
nonlocal _n
_r = _setdefault(x, _n)
if _r == _n:
_n += 1
return None
else:
return _r
return _add
class _BaseParamRef(StaticParam):
"""
Class used to replace multiple references to the same BaseParam with a static index
to the unique BaseParam. This allows to transform pydags to pytrees (as duplicate direct
references are replaced with indices).
"""
def __init__(self, n: int) -> None:
"""_BaseParamRef constructor.
Args:
n (int): the index of the referenced parameter as leaf in the flattened pytree it belongs to.
"""
super().__init__(n)
# Core #################################################################################################################
[docs]
def tree_apply(
fn: Callable[[Any], None], filter_fn: Callable[[Any], bool], tree: PyTree, recursive: bool = True
) -> None:
"""Executes a function on the selected nodes of the pytree. Note that pydag are supported since the structure of
the pytree is preserved (i.e., the function can only modify the content of the nodes, not the nodes themselves).
This, however, implies that if a duplicate reference is present in the pytree, the function will be applied to each
occurrence of the reference (so multiple times on the same node), which must be taken into account when designing
the function. For example:
```python
p = Param(1.0)
m = [p, p]
def inc(p):
p += 1
tree_apply(inc, lambda x: isinstance(x, Param), m)
print(m) # [Param(3.0), Param(3.0)]
```
Args:
fn (Callable[[Any], None]): function to apply to the selected nodes of the pytree.
filter_fn (Callable[[Any], bool]): filter function to select the nodes on which to apply 'fn'.
tree (PyTree): input pytree.
recursive (bool, optional): whether to call 'fn' recursively or to stop after the first generation of nodes
matching 'filter_fn' is encountered. Normally is set to False for performance reasons when targeting
parameters (that are leaves of the pytree).
"""
def _wrap_fn(x):
if r := filter_fn(x):
fn(x)
return r
leaves = jtu.tree_leaves(tree, is_leaf=_wrap_fn)
if recursive:
for leaf in leaves:
for x in eqx.tree_flatten_one_level(leaf)[0]:
if x is not tree:
tree_apply(fn, filter_fn, tree=x)
[docs]
def tree_inject(
pydag: PyTree,
*,
params: PyTree = None,
values: Sequence[Any] = None,
inject_fn: Callable[[Tuple[Any, Any]], None] = lambda n, v: n.set(v),
filter_fn: Callable[[Any], bool] = lambda x: isinstance(x, DynamicParam),
is_pytree: bool = False,
strict: bool = True,
) -> PyTree:
"""Inverse function of 'extract'. Note that it doesn't modify the pydag structure, but rather the values
of its BaseParam leaves.
Args:
pydag (PyTree): input pydag.
values (Sequence[Any]): input sequence of values to inject into pydag at the selected leaves.
inject_fn (Callable[[Tuple[Any, Any]], None], optional): function that takes the target leaf and
previously extracted value to inject into the leaf. Note: the return value is ignored and does
not replace the original leaf as in 'jtu.tree_map'.
filter_fn (Callable[[Any], bool], optional): filter function to select the leaves on which to apply 'extract_fn'
is_pytree (bool, optional): whether the input pydag is a pytree and contains no references; used to
avoid unnecessary reffing.
strict (bool, optional): if True, the number of values must match the number of leaves in the pytree.
Returns:
PyTree: pytree with values injected via 'inject_fn'.
"""
assert is_pytree is True, "Not implemented for non-pytrees."
if values is None:
values = filter(filter_fn, jtu.tree_leaves(params, is_leaf=lambda x: isinstance(x, BaseParam)))
else:
assert params is None, "Cannot specify both 'values' and 'params'"
# We use jtu.tree_leaves to apply inject_fn to the dynamically identified leaves of the pytree.
_values_it = iter(values)
def _inject_param(x: Any):
if isinstance(x, BaseParam):
if filter_fn(x):
inject_fn(x, next(_values_it).get())
return True
else:
return False
jtu.tree_leaves(pydag, is_leaf=_inject_param)
if strict is True:
# This is to assert the user didn't mess up with the pytree structure.
try:
next(_values_it)
raise ValueError("The number of values does not match the number of leaves in the pytree.")
except StopIteration:
pass
return pydag
[docs]
def tree_ref(pydag: PyTree) -> PyTree:
"""Transforms a pydag in a pytree by replacing all duplicate BaseParams references with explicit indexing.
This effectively means that all the occurences, except the first encountered, of each unique parameter are replaced
by an integer index wrapped into a _BaseParamRef.
This is necessary as jax treats all input/output values of its transformations as pytree, which results in
unexpected behaviour when passing in pydags.
NOTE #1: ref has some usage limitations, see unref for a complete overview.
Args:
pydag (PyTree): input pydag
Returns:
PyTree: output pytree with duplicate BaseParams replaced by explicit references.
"""
# We use BaseParam and not Param to target also StaticParams and ParamDicts.
_seen = _cache()
def _ref(x):
if isinstance(x, _BaseParamRef):
# We ref a ref to keep the structure consistent and allow multiple reffing/unreffing.
return _BaseParamRef(x)
elif isinstance(x, BaseParam) and ((_ref := _seen(id(x))) is not None):
# If the parameter is already seen, we replace it with a reference.
# _ref is an integer that refers to the index of parameter in the ordered sequence of unique parameters
# encountered in pydag during flattening.
return _BaseParamRef(_ref)
return x
return jtu.tree_map(_ref, pydag, is_leaf=lambda x: isinstance(x, BaseParam))
[docs]
def tree_unref(pytree: PyTree) -> PyTree:
"""Replace explicit _BaseParamRef with the indexed BaseParam, recreating the original pydag.
The most common usage pattern would be the following:
```python
def f(pytree):
pydag = unref(pytree)
return ref(pydag)
p = pydag(...)
t = a_jax_transformation(f)
p = unref(t(ref(p)))
```
This is automatically and efficiently done when using automatic parameter tracing via pcax transformations
(i.e., passing parameters within the kwargs of a pcax transformation).
NOTE #1: Refernces work via simple indexing, which requires the underlying pydag/pytree structure to be constant
between ref and unref (i.e., unref has defined behaviour only if used on a pytree with the same structure as
the value returned by ref). For example, the following is not allowed:
```python
p = pydag()
pytree = ref(p)
a, p = unref([Param(), pytree]) # THIS IS WRONG: pytree has not shape [Param(), pytree]
```
NOTE #2: Note #1 implies that it is possible to ref an already [partially] reffed pytree, but unreffing must be
done in the same (reversed) order:
```python
# Example 1
p = unref(unref(ref(ref(p))))
# Example 2
p1 = pydag()
tree1 = ref(p1)
p2 = [pydag(), tree1]
tree2 = ref(p2)
# Here the following is NOT allowed (as the structure of p2[1] may be changed by the second reffing):
p1 = unref(p2[1]) # WRONG!
# Instead the following order must be respected:
p2 = unref(tree2)
tree1 = p2[1]
p1 = unref(tree1)
```
NOTE #3: the behaviour of NOTE #2 has not been extensively tested, so not sure which are the exact limitations of
the approach.
Args:
pytree (PyTree): input pytree
Returns:
PyTree: output pydag with resolved references.
"""
_seen = []
def _unref(x: Any) -> Any:
if isinstance(x, _BaseParamRef):
# Since a _BaseParamRef can be nested, we check if we reached the final index and, if so, unref.
x = x.get()
return _seen[x] if isinstance(x, int) else x
elif isinstance(x, BaseParam):
_seen.append(x)
return x
return jtu.tree_map(_unref, pytree, is_leaf=lambda x: isinstance(x, BaseParam))