Source code for pcax.core._tree

__all__ = ["tree_extract", "tree_inject", "tree_ref", "tree_unref"]


from typing import Any, Tuple, Sequence, Callable
from jaxtyping import PyTree

import jax.tree_util as jtu
import equinox as eqx

from ..core._parameter import BaseParam, DynamicParam
from ..core._static import StaticParam


########################################################################################################################
#
# TREE
#
# A set of helper functions to simplify the management of stateful pytrees (and pydags) based on 'AbstracParams'.
# See the documentation of each function for more details.
#
########################################################################################################################

# Utils ################################################################################################################


def _cache() -> Callable[[Any], int | None]:
    """Utiliy function to create a fast cache set, which keeps track of the order in which elements are seen.
    Usage:

    ```
    cache = _cache()

    if (n := cache(x)) is not None:
        print(f"{x} was already observed. It is the {n}th unique observed object")
    ```

    Returns:
        Callable[[Any], int | None]: callable cache set.
    """
    _data = {}
    _setdefault = _data.setdefault
    _n = 0

    def _add(x: Any) -> int | None:
        """Cache 'test and set' function.

        Args:
            x (Any): element to check against the cache

        Returns:
            int | None: None if x is a newly seen object; otherwise i,
                where i is the number of unique objects seen before x.
        """
        nonlocal _n
        _r = _setdefault(x, _n)
        if _r == _n:
            _n += 1

            return None
        else:
            return _r

    return _add


class _BaseParamRef(StaticParam):
    """
    Class used to replace multiple references to the same BaseParam with a static index
    to the unique BaseParam. This allows to transform pydags to pytrees (as duplicate direct
    references are replaced with indices).
    """

    def __init__(self, n: int) -> None:
        """_BaseParamRef constructor.

        Args:
            n (int): the index of the referenced parameter as leaf in the flattened pytree it belongs to.
        """
        super().__init__(n)


# Core #################################################################################################################


[docs] def tree_apply( fn: Callable[[Any], None], filter_fn: Callable[[Any], bool], tree: PyTree, recursive: bool = True ) -> None: """Executes a function on the selected nodes of the pytree. Note that pydag are supported since the structure of the pytree is preserved (i.e., the function can only modify the content of the nodes, not the nodes themselves). This, however, implies that if a duplicate reference is present in the pytree, the function will be applied to each occurrence of the reference (so multiple times on the same node), which must be taken into account when designing the function. For example: ```python p = Param(1.0) m = [p, p] def inc(p): p += 1 tree_apply(inc, lambda x: isinstance(x, Param), m) print(m) # [Param(3.0), Param(3.0)] ``` Args: fn (Callable[[Any], None]): function to apply to the selected nodes of the pytree. filter_fn (Callable[[Any], bool]): filter function to select the nodes on which to apply 'fn'. tree (PyTree): input pytree. recursive (bool, optional): whether to call 'fn' recursively or to stop after the first generation of nodes matching 'filter_fn' is encountered. Normally is set to False for performance reasons when targeting parameters (that are leaves of the pytree). """ def _wrap_fn(x): if r := filter_fn(x): fn(x) return r leaves = jtu.tree_leaves(tree, is_leaf=_wrap_fn) if recursive: for leaf in leaves: for x in eqx.tree_flatten_one_level(leaf)[0]: if x is not tree: tree_apply(fn, filter_fn, tree=x)
[docs] def tree_extract( pydag: PyTree, *rest: ..., extract_fn: Callable[[Any | Tuple[Any, ...]], Any] = lambda x: x, filter_fn: Callable[[Any], bool] = lambda x: isinstance(x, DynamicParam), is_pytree: bool = False, ) -> Sequence[Any]: """Extract an ordered sequence of values from the BaseParams of a pytree. Similarly to 'ref'/'unref', 'extract'/'inject' rely on a consistent structure of the input pytree (i.e., you can only inject into the same pytree structure you extracted from). Args: pydag (PyTree): input pydag. rest: (...): a tuple of pytrees, each of which has the same structure as tree or has tree as a prefix. extract_fn (Callable[[Any | Tuple[Any, ...]], Any], optional): function that takes 1 + len(rest) arguments, to be applied at the corresponding leaves of the pytrees. filter_fn (Callable[[Any], bool], optional): filter function to select the BaseParam on which to apply 'extract_fn'. is_pytree (bool, optional): whether the input pydag is a pytree and contains no references; used to avoid unnecessary reffing. Returns: Sequence[Any]: list of extracted values. """ assert is_pytree is True, "Not implemented for non-pytrees." # We use jtu.tree_map to apply extract_fn to the dynamically identified leaves of the pytree. _values = [] def _map_fn(x, *rest): if filter_fn(x): _values.append(extract_fn(x, *rest)) return x jtu.tree_map(_map_fn, pydag, *rest, is_leaf=lambda x: isinstance(x, BaseParam)) return tuple(_values)
[docs] def tree_inject( pydag: PyTree, *, params: PyTree = None, values: Sequence[Any] = None, inject_fn: Callable[[Tuple[Any, Any]], None] = lambda n, v: n.set(v), filter_fn: Callable[[Any], bool] = lambda x: isinstance(x, DynamicParam), is_pytree: bool = False, strict: bool = True, ) -> PyTree: """Inverse function of 'extract'. Note that it doesn't modify the pydag structure, but rather the values of its BaseParam leaves. Args: pydag (PyTree): input pydag. values (Sequence[Any]): input sequence of values to inject into pydag at the selected leaves. inject_fn (Callable[[Tuple[Any, Any]], None], optional): function that takes the target leaf and previously extracted value to inject into the leaf. Note: the return value is ignored and does not replace the original leaf as in 'jtu.tree_map'. filter_fn (Callable[[Any], bool], optional): filter function to select the leaves on which to apply 'extract_fn' is_pytree (bool, optional): whether the input pydag is a pytree and contains no references; used to avoid unnecessary reffing. strict (bool, optional): if True, the number of values must match the number of leaves in the pytree. Returns: PyTree: pytree with values injected via 'inject_fn'. """ assert is_pytree is True, "Not implemented for non-pytrees." if values is None: values = filter(filter_fn, jtu.tree_leaves(params, is_leaf=lambda x: isinstance(x, BaseParam))) else: assert params is None, "Cannot specify both 'values' and 'params'" # We use jtu.tree_leaves to apply inject_fn to the dynamically identified leaves of the pytree. _values_it = iter(values) def _inject_param(x: Any): if isinstance(x, BaseParam): if filter_fn(x): inject_fn(x, next(_values_it).get()) return True else: return False jtu.tree_leaves(pydag, is_leaf=_inject_param) if strict is True: # This is to assert the user didn't mess up with the pytree structure. try: next(_values_it) raise ValueError("The number of values does not match the number of leaves in the pytree.") except StopIteration: pass return pydag
[docs] def tree_ref(pydag: PyTree) -> PyTree: """Transforms a pydag in a pytree by replacing all duplicate BaseParams references with explicit indexing. This effectively means that all the occurences, except the first encountered, of each unique parameter are replaced by an integer index wrapped into a _BaseParamRef. This is necessary as jax treats all input/output values of its transformations as pytree, which results in unexpected behaviour when passing in pydags. NOTE #1: ref has some usage limitations, see unref for a complete overview. Args: pydag (PyTree): input pydag Returns: PyTree: output pytree with duplicate BaseParams replaced by explicit references. """ # We use BaseParam and not Param to target also StaticParams and ParamDicts. _seen = _cache() def _ref(x): if isinstance(x, _BaseParamRef): # We ref a ref to keep the structure consistent and allow multiple reffing/unreffing. return _BaseParamRef(x) elif isinstance(x, BaseParam) and ((_ref := _seen(id(x))) is not None): # If the parameter is already seen, we replace it with a reference. # _ref is an integer that refers to the index of parameter in the ordered sequence of unique parameters # encountered in pydag during flattening. return _BaseParamRef(_ref) return x return jtu.tree_map(_ref, pydag, is_leaf=lambda x: isinstance(x, BaseParam))
[docs] def tree_unref(pytree: PyTree) -> PyTree: """Replace explicit _BaseParamRef with the indexed BaseParam, recreating the original pydag. The most common usage pattern would be the following: ```python def f(pytree): pydag = unref(pytree) return ref(pydag) p = pydag(...) t = a_jax_transformation(f) p = unref(t(ref(p))) ``` This is automatically and efficiently done when using automatic parameter tracing via pcax transformations (i.e., passing parameters within the kwargs of a pcax transformation). NOTE #1: Refernces work via simple indexing, which requires the underlying pydag/pytree structure to be constant between ref and unref (i.e., unref has defined behaviour only if used on a pytree with the same structure as the value returned by ref). For example, the following is not allowed: ```python p = pydag() pytree = ref(p) a, p = unref([Param(), pytree]) # THIS IS WRONG: pytree has not shape [Param(), pytree] ``` NOTE #2: Note #1 implies that it is possible to ref an already [partially] reffed pytree, but unreffing must be done in the same (reversed) order: ```python # Example 1 p = unref(unref(ref(ref(p)))) # Example 2 p1 = pydag() tree1 = ref(p1) p2 = [pydag(), tree1] tree2 = ref(p2) # Here the following is NOT allowed (as the structure of p2[1] may be changed by the second reffing): p1 = unref(p2[1]) # WRONG! # Instead the following order must be respected: p2 = unref(tree2) tree1 = p2[1] p1 = unref(tree1) ``` NOTE #3: the behaviour of NOTE #2 has not been extensively tested, so not sure which are the exact limitations of the approach. Args: pytree (PyTree): input pytree Returns: PyTree: output pydag with resolved references. """ _seen = [] def _unref(x: Any) -> Any: if isinstance(x, _BaseParamRef): # Since a _BaseParamRef can be nested, we check if we reached the final index and, if so, unref. x = x.get() return _seen[x] if isinstance(x, int) else x elif isinstance(x, BaseParam): _seen.append(x) return x return jtu.tree_map(_unref, pytree, is_leaf=lambda x: isinstance(x, BaseParam))