pcax.core package

Module contents

class pcax.core.BaseModule

Bases: object

BaseModule is the base class for all modules in the library.

submodules(*, cls: Type[T] | None = None) → Generator[T, None, None]

Return the direct child submodules of the given type. The search is not recursive: only direct children of a matching type are returned.

Parameters:

cls (Type[T] | None, optional) – the type of the submodules to select. If None, BaseModule is used.

Yields:

Generator[T, None, None] – generator of the matching submodules.
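To make the non-recursive behaviour concrete, here is a minimal pure-Python sketch. It does not use pcax: `MiniModule`, `Linear`, `Block` and `Net` are hypothetical classes, and discovering children by scanning the instance attributes (`vars(self)`) is an assumption about how such a lookup could work, not a description of the library's implementation.

```python
from typing import Generator, Optional, Type, TypeVar

T = TypeVar("T")


class MiniModule:
    """Hypothetical stand-in for BaseModule (not pcax's implementation)."""

    def submodules(self, *, cls: Optional[Type[T]] = None) -> Generator[T, None, None]:
        # Default to the module base class, mirroring the documented behaviour.
        cls = MiniModule if cls is None else cls
        # Only direct attributes are inspected: nested modules are never visited.
        for value in vars(self).values():
            if isinstance(value, cls):
                yield value


class Linear(MiniModule):
    pass


class Block(MiniModule):
    def __init__(self):
        self.fc = Linear()


class Net(MiniModule):
    def __init__(self):
        self.block = Block()
        self.head = Linear()


net = Net()
print([type(m).__name__ for m in net.submodules()])            # ['Block', 'Linear']
print([type(m).__name__ for m in net.submodules(cls=Linear)])  # ['Linear']
```

Note that `net.block.fc` is never yielded: reaching it would require walking the module tree recursively.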

class pcax.core.BaseParam(value: Array | Any | None = None)

Bases: object

Abstract base class for all parameters. It is used to detect whether an object is a parameter or not.

__init__(value: Array | Any | None = None)

BaseParam constructor.

Parameters:

value – the value (usually a tensor) to wrap. pcax will treat (only!) such values as dynamic.

abstract get()

abstract set(value)

class pcax.core.Module

Bases: BaseModule

Module represents a standard deep learning module with a train/eval mode flag that can be set recursively.

class MODE(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)

Bases: IntEnum

EVAL = 2
NONE = 0
TRAIN = 1

eval() → None

Set the module in eval mode.

property is_eval: bool

Returns: bool: whether the module is in eval mode.

property is_train: bool

Returns: bool: whether the module is in train mode.

mode(value: MODE | None) → MODE | None

Recursively set the mode of the module and its submodules. If value is None, the mode is left unchanged and the current mode is returned instead.

Parameters:

value (MODE | None) – mode to set.

Returns:

the current mode if value is None, otherwise None.

Return type:

MODE | None

train() → None

Set the module in train mode.
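The mode semantics can be sketched in plain Python as follows. This is a hypothetical stand-in, not pcax code: `MiniModule`, `Dropout` and `Net` are illustrative names, the `MODE` values mirror the ones listed above, and propagating the mode through instance attributes is an assumption.

```python
from enum import IntEnum
from typing import Optional


class MODE(IntEnum):
    # Values mirror those documented for Module.MODE.
    NONE = 0
    TRAIN = 1
    EVAL = 2


class MiniModule:
    """Hypothetical stand-in for Module; children are found among its attributes."""

    def __init__(self):
        self._mode = MODE.NONE

    def mode(self, value: Optional[MODE]) -> Optional[MODE]:
        if value is None:
            # Query only: return the current mode without changing anything.
            return self._mode
        self._mode = value
        # Propagate to every direct child; each child recurses into its own children.
        for child in vars(self).values():
            if isinstance(child, MiniModule):
                child.mode(value)
        return None

    def train(self) -> None:
        self.mode(MODE.TRAIN)

    def eval(self) -> None:
        self.mode(MODE.EVAL)

    @property
    def is_train(self) -> bool:
        return self._mode == MODE.TRAIN

    @property
    def is_eval(self) -> bool:
        return self._mode == MODE.EVAL


class Dropout(MiniModule):
    pass


class Net(MiniModule):
    def __init__(self):
        super().__init__()
        self.encoder = MiniModule()
        self.encoder.dropout = Dropout()


net = Net()
net.eval()
print(net.encoder.dropout.is_eval)  # True
print(net.mode(None) is MODE.EVAL)  # True
```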

class pcax.core.Param(value: Array | Any | None = None)

Bases: DynamicParam

The base class to represent and store a dynamic value in pcax. It is mainly used to wrap jax.Array values and track them through JAX transformations.

property dtype

Data type of the wrapped value.

get() → Array

property ndim

Number of dimensions of the wrapped value.

set(value: Array) → Param

property shape

Shape of the wrapped value.

class pcax.core.ParamCache

Bases: object

A simple sentinel class used to identify all parameters used as a temporary cache.

class pcax.core.ParamDict(value: Dict[str, Array | Any | None] = None)

Bases: DynamicParam

get(key: str | None = None, default: Array | Any | None = None) → Any

set(value) → None

class pcax.core.RandomKeyGenerator(seed: int = 0)

Bases: BaseModule

Random number generator module. Provides a stateful interface, accessible in the global scope, to generate random keys.

__call__(n: int = 1) → Tuple[ArrayLike, ...] | ArrayLike

Generate n random keys.

Parameters:

n (int, optional) – number of keys to generate.

Returns:

a single key if n is 1, otherwise a tuple of keys.

Return type:

Tuple[ArrayLike, …] | ArrayLike

__init__(seed: int = 0)

RandomKeyGenerator constructor.

Parameters:

seed (int, optional) – initial seed. Defaults to 0.

seed(seed: int = 0)

Set a new seed.

Parameters:

seed (int, optional) – new seed. Defaults to 0.

pcax.core.get(x: Any | BaseParam) → Any

Return the value encapsulated in the input argument if it is a BaseParam, otherwise return the input argument itself. Used in ambiguous situations to ensure that the input is a value and not a BaseParam.

Parameters:

x (Any | BaseParam) – input argument

Returns:

value encapsulated in the input argument if it is a BaseParam, otherwise the input argument itself.

Return type:

Any

pcax.core.set(obj: Any, x: Any | BaseParam) → Any | BaseParam

Set the value of the input object and return it if it is a BaseParam; otherwise, return the new value itself. Used in ambiguous situations to ensure that the input object is correctly updated.

Returns:

the updated input object if it is a BaseParam, otherwise the new value itself.

Return type:

Any | BaseParam
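A minimal sketch of the unwrap/update pattern that get and set describe, assuming a hypothetical `Box` class in place of a real BaseParam. `get_value` and `set_value` are illustrative analogues of `pcax.core.get`/`pcax.core.set` (renamed to avoid shadowing the built-in `set`), and unwrapping `x` when it is itself a parameter is an assumption.

```python
from typing import Any


class Box:
    """Hypothetical stand-in for a BaseParam: wraps a single value."""

    def __init__(self, value: Any = None):
        self.value = value

    def get(self) -> Any:
        return self.value

    def set(self, value: Any) -> "Box":
        self.value = value
        return self


def get_value(x: Any) -> Any:
    # Unwrap parameters; pass plain values through unchanged.
    return x.get() if isinstance(x, Box) else x


def set_value(obj: Any, x: Any) -> Any:
    # Parameters are updated in place and returned. For plain values the new value
    # is returned, so the caller must rebind it (e.g. `y = set_value(y, 2.0)`).
    # Assumption: if x is itself a parameter, its wrapped value is used.
    if isinstance(obj, Box):
        return obj.set(get_value(x))
    return x


p, y = Box(1.0), 1.0
print(get_value(p), get_value(y))  # 1.0 1.0
p = set_value(p, 2.0)  # same Box object, now wrapping 2.0
y = set_value(y, 2.0)  # plain value: the new value is simply returned
print(get_value(p), y)  # 2.0 2.0
```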

pcax.core.static(x: Any | StaticParam) → StaticParam

Wrap a value into a StaticParam, making it static and thus compatible with JAX transformations.

Parameters:

x (Any | StaticParam) – value to be wrapped. If x is already a StaticParam, it is returned as is.

Returns:

the static parameter wrapping the input value.

Return type:

StaticParam

pcax.core.tree_apply(fn: Callable[[Any], None], filter_fn: Callable[[Any], bool], tree: PyTree, recursive: bool = True) → None

Execute a function on the selected nodes of the pytree. Pydags (pytrees that may contain multiple references to the same object) are supported, since the structure of the pytree is preserved (i.e., the function can only modify the content of the nodes, not the nodes themselves). However, this also implies that if a duplicate reference is present in the pytree, the function is applied to each occurrence of the reference (i.e., multiple times to the same node), which must be taken into account when designing the function. For example:

```python
from pcax.core import Param, tree_apply

p = Param(1.0)
m = [p, p]  # the same Param appears twice

def inc(p):
    p.set(p.get() + 1)  # update the wrapped value in place

tree_apply(inc, lambda x: isinstance(x, Param), m)

print(m)  # [Param(3.0), Param(3.0)]: p was incremented once per occurrence
```

Parameters:
  • fn (Callable[[Any], None]) – function to apply to the selected nodes of the pytree.

  • filter_fn (Callable[[Any], bool]) – filter function to select the nodes on which to apply ‘fn’.

  • tree (PyTree) – input pytree.

  • recursive (bool, optional) – whether to call ‘fn’ recursively or to stop after the first generation of nodes matching ‘filter_fn’ is encountered. It is usually set to False for performance reasons when targeting parameters, since they are leaves of the pytree.
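The following pure-Python sketch reproduces the duplicate-visit behaviour without pcax and shows one way to account for it. `Box`, `mini_tree_apply` (which, like `recursive=False`, stops descending at the first match) and the `once` wrapper are hypothetical illustrations, not part of the library.

```python
from typing import Any, Callable


class Box:
    """Hypothetical stand-in for a Param."""

    def __init__(self, value):
        self.value = value


def mini_tree_apply(fn: Callable[[Any], None], filter_fn: Callable[[Any], bool], tree: Any) -> None:
    # Walk lists, tuples and dict values; call fn on every matching node met on the way.
    if filter_fn(tree):
        fn(tree)
    elif isinstance(tree, (list, tuple)):
        for child in tree:
            mini_tree_apply(fn, filter_fn, child)
    elif isinstance(tree, dict):
        for child in tree.values():
            mini_tree_apply(fn, filter_fn, child)


def inc(b: Box) -> None:
    b.value += 1


p = Box(1.0)
mini_tree_apply(inc, lambda x: isinstance(x, Box), [p, p])
print(p.value)  # 3.0: the shared Box was visited twice


def once(fn: Callable[[Any], None]) -> Callable[[Any], None]:
    # Guard that applies fn at most once per object identity.
    seen = set()

    def wrapper(node: Any) -> None:
        if id(node) not in seen:
            seen.add(id(node))
            fn(node)

    return wrapper


q = Box(1.0)
mini_tree_apply(once(inc), lambda x: isinstance(x, Box), [q, q])
print(q.value)  # 2.0
```

Tracking `id()` is enough here because every node stays alive for the whole traversal.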

pcax.core.tree_extract(pydag: PyTree, *rest: ..., extract_fn: Callable[[Any | Tuple[Any, ...]], Any] = <function <lambda>>, filter_fn: Callable[[Any], bool] = <function <lambda>>, is_pytree: bool = False) → Sequence[Any]

Extract an ordered sequence of values from the BaseParams of a pytree. As with ‘ref’/’unref’, ‘extract’/’inject’ rely on a consistent structure of the input pytree (i.e., you can only inject into the same pytree structure you extracted from).

Parameters:
  • pydag (PyTree) – input pydag.

  • rest (...) – a tuple of pytrees, each of which has the same structure as pydag or has pydag as a prefix.

  • extract_fn (Callable[[Any | Tuple[Any, ...]], Any], optional) – function that takes 1 + len(rest) arguments, to be applied at the corresponding leaves of the pytrees.

  • filter_fn (Callable[[Any], bool], optional) – filter function to select the BaseParam on which to apply ‘extract_fn’.

  • is_pytree (bool, optional) – whether the input pydag is a pytree and contains no references; used to avoid unnecessary reffing.

Returns:

list of extracted values.

Return type:

Sequence[Any]

pcax.core.tree_inject(pydag: PyTree, *, params: PyTree = None, values: Sequence[Any] = None, inject_fn: Callable[[Tuple[Any, Any]], None] = <function <lambda>>, filter_fn: Callable[[Any], bool] = <function <lambda>>, is_pytree: bool = False, strict: bool = True) → PyTree

Inverse of ‘extract’. Note that it does not modify the pydag structure, only the values of its BaseParam leaves.

Parameters:
  • pydag (PyTree) – input pydag.

  • values (Sequence[Any]) – input sequence of values to inject into pydag at the selected leaves.

  • inject_fn (Callable[[Tuple[Any, Any]], None], optional) – function that takes the target leaf and the previously extracted value to inject into it. Note: the return value is ignored and does not replace the original leaf as it would in ‘jtu.tree_map’.

  • filter_fn (Callable[[Any], bool], optional) – filter function to select the leaves on which to apply ‘inject_fn’.

  • is_pytree (bool, optional) – whether the input pydag is a pytree and contains no references; used to avoid unnecessary reffing.

  • strict (bool, optional) – if True, the number of values must match the number of leaves in the pytree.

Returns:

pytree with values injected via ‘inject_fn’.

Return type:

PyTree
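The extract/inject contract (a fixed traversal order, so that values can be written back to the same leaves they came from) can be sketched in plain Python. `Box`, `leaves`, `extract` and `inject` are hypothetical, simplified stand-ins: they handle only lists, tuples and dicts, visit dict keys in sorted order (as JAX does when flattening dicts), and raise on a length mismatch in the spirit of `strict=True`.

```python
from typing import Any, Callable, List, Sequence


class Box:
    """Hypothetical stand-in for a BaseParam."""

    def __init__(self, value: Any):
        self.value = value


def leaves(tree: Any) -> List[Box]:
    # Deterministic depth-first order: this is what makes inject the inverse of extract.
    if isinstance(tree, Box):
        return [tree]
    if isinstance(tree, (list, tuple)):
        return [leaf for child in tree for leaf in leaves(child)]
    if isinstance(tree, dict):
        return [leaf for key in sorted(tree) for leaf in leaves(tree[key])]
    return []


def extract(tree: Any, filter_fn: Callable[[Box], bool] = lambda b: True) -> List[Any]:
    return [b.value for b in leaves(tree) if filter_fn(b)]


def inject(tree: Any, values: Sequence[Any], filter_fn: Callable[[Box], bool] = lambda b: True) -> Any:
    targets = [b for b in leaves(tree) if filter_fn(b)]
    # Mirrors strict=True: the number of values must match the number of selected leaves.
    if len(targets) != len(values):
        raise ValueError(f"expected {len(targets)} values, got {len(values)}")
    for box, value in zip(targets, values):
        box.value = value  # mutate in place; the tree structure is unchanged
    return tree


model = {"w": Box(1.0), "layers": [Box(2.0), Box(3.0)]}
vals = extract(model)
print(vals)  # [2.0, 3.0, 1.0]  ("layers" sorts before "w")
inject(model, [v * 10 for v in vals])
print(extract(model))  # [20.0, 30.0, 10.0]
```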

pcax.core.tree_ref(pydag: PyTree) → PyTree

Transform a pydag into a pytree by replacing all duplicate BaseParam references with explicit indexing. Concretely, every occurrence of each unique parameter, except the first one encountered, is replaced by an integer index wrapped into a _BaseParamRef. This is necessary because JAX treats all input/output values of its transformations as pytrees, which results in unexpected behaviour when pydags are passed in.

NOTE #1: ref has some usage limitations; see unref for a complete overview.

Parameters:

pydag (PyTree) – input pydag

Returns:

output pytree with duplicate BaseParams replaced by explicit references.

Return type:

PyTree

pcax.core.tree_unref(pytree: PyTree) → PyTree

Replace each explicit _BaseParamRef with the BaseParam it indexes, recreating the original pydag. The most common usage pattern is the following:

```python
def f(pytree):
    pydag = unref(pytree)  # restore the shared references inside the transformed function
    return ref(pydag)      # re-encode them before handing the result back to JAX

p = pydag(...)               # placeholder for any pydag
t = a_jax_transformation(f)  # placeholder for a JAX transformation such as jax.jit
p = unref(t(ref(p)))
```

This is done automatically and efficiently when using automatic parameter tracing via pcax transformations (i.e., passing parameters within the kwargs of a pcax transformation).

NOTE #1: References work via simple indexing, which requires the underlying pydag/pytree structure to be constant between ref and unref (i.e., unref has defined behaviour only if used on a pytree with the same structure as the value returned by ref). For example, the following is not allowed:

```python
p = pydag()
pytree = ref(p)
a, p = unref([Param(), pytree])  # THIS IS WRONG: [Param(), pytree] is not the structure returned by ref
```

NOTE #2: NOTE #1 implies that it is possible to ref an already (partially) reffed pytree, but unreffing must then be done in the reverse order:

```python
# Example 1
p = unref(unref(ref(ref(p))))

# Example 2
p1 = pydag()
tree1 = ref(p1)
p2 = [pydag(), tree1]
tree2 = ref(p2)

# Here the following is NOT allowed (the second reffing may change the structure of p2[1]):
p1 = unref(p2[1])  # WRONG!

# Instead, the following order must be respected:
p2 = unref(tree2)
tree1 = p2[1]
p1 = unref(tree1)
```

NOTE #3: the behaviour described in NOTE #2 has not been extensively tested, so its exact limitations are not yet known.

Parameters:

pytree (PyTree) – input pytree

Returns:

output pydag with resolved references.

Return type:

PyTree
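A minimal pure-Python sketch of the ref/unref idea: duplicates are replaced by the index of their first occurrence, and unref resolves the indices by replaying the same traversal. `Box` and `Ref` are hypothetical stand-ins for BaseParam and _BaseParamRef, only nested lists are handled, and no JAX transformation is involved; it illustrates the indexing scheme, not pcax's implementation.

```python
from typing import Any, Dict, List


class Box:
    """Hypothetical stand-in for a BaseParam."""

    def __init__(self, value: Any):
        self.value = value


class Ref:
    """Hypothetical stand-in for _BaseParamRef: holds an integer index."""

    def __init__(self, index: int):
        self.index = index


def ref(tree: Any) -> Any:
    seen: Dict[int, int] = {}  # id(Box) -> index of its first occurrence
    count = 0

    def walk(node: Any) -> Any:
        nonlocal count
        if isinstance(node, Box):
            if id(node) in seen:
                return Ref(seen[id(node)])  # duplicate: replace with an index
            seen[id(node)] = count
            count += 1
            return node  # the first occurrence is kept as-is
        if isinstance(node, list):
            return [walk(child) for child in node]
        return node

    return walk(tree)


def unref(tree: Any) -> Any:
    boxes: List[Box] = []

    def walk(node: Any) -> Any:
        if isinstance(node, Box):
            boxes.append(node)  # same traversal order as ref, so the indices line up
            return node
        if isinstance(node, Ref):
            return boxes[node.index]
        if isinstance(node, list):
            return [walk(child) for child in node]
        return node

    return walk(tree)


p = Box(1.0)
dag = [p, [p, Box(2.0)]]
tree = ref(dag)
print(type(tree[1][0]).__name__)      # Ref
restored = unref(tree)
print(restored[0] is restored[1][0])  # True: sharing is recovered
```

Because the indices are only meaningful relative to the traversal order, unref in this sketch also depends on receiving exactly the structure that ref produced.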