pcax.functional package
Module contents
- class pcax.functional._BaseTransform(fn: _BaseTransform | Callable | Sequence[_BaseTransform | Callable])
Bases: ABC
Base class for all pcax transformations that wrap a jax transformation. __call__ keeps track of the parameters. Each derived class only needs to define a _t function with signature _t(*args, **kwargs) -> *, params. _t must wrap the target jax transformation and perform any rearrangement of input and output arguments needed to use it.
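The pattern described above (a base __call__ that delegates to a subclass-defined _t returning the output together with the parameters) can be sketched as follows. This is a minimal illustration under stated assumptions, not pcax's implementation: SketchTransform, SketchJit, step and the explicit params dictionary are all hypothetical names introduced for the example.

```python
import abc
from typing import Any, Callable, Tuple

import jax
import jax.numpy as jnp


class SketchTransform(abc.ABC):
    """Hypothetical stand-in for a transform base class (not pcax's code)."""

    def __init__(self, fn: Callable):
        self.fn = fn

    def __call__(self, *args: Any, params: dict, **kwargs: Any) -> Any:
        # Delegate to the subclass hook, then record the updated parameters.
        out, new_params = self._t(*args, params=params, **kwargs)
        params.update(new_params)
        return out

    @abc.abstractmethod
    def _t(self, *args: Any, params: dict, **kwargs: Any) -> Tuple[Any, dict]:
        """Wrap the target jax transformation and return (output, params)."""


class SketchJit(SketchTransform):
    """Derived class: only _t (plus setup) is needed to wrap jax.jit."""

    def __init__(self, fn: Callable):
        super().__init__(fn)
        self._jitted = jax.jit(fn)

    def _t(self, *args, params, **kwargs):
        # The wrapped function is pure: parameters go in, updated ones come out.
        return self._jitted(*args, params=params, **kwargs)


def step(x, *, params):
    # Assumed example function: doubles x and increments a counter parameter.
    return x * 2.0, {"count": params["count"] + 1}


params = {"count": jnp.asarray(0)}
double = SketchJit(step)
y = double(jnp.asarray(3.0), params=params)  # y is 6.0, params["count"] is 1
```

Threading the parameters through the return value keeps the wrapped function pure, which is what jax transformations expect; the bookkeeping then happens outside the traced code.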
- __call__(*args, _is_root: bool = True, **kwargs: Any) -> Any
Call the transformed function.
- Parameters:
_is_root (bool, reserved) – used to distinguish the root call from recursive calls.
- Returns:
the transformed output of the wrapped function.
- Return type:
Any
- __init__(fn: _BaseTransform | Callable | Sequence[_BaseTransform | Callable]) -> None
_BaseTransform constructor.
- Parameters:
fn (_BaseTransform | Callable | Sequence[_BaseTransform | Callable]) – the function (or sequence of functions) to which the transformation is applied. Since transformations can be composed, fn can itself be a _BaseTransform.
- static _process_mask(mask: PyTree, kwargs: PyTree, rkg_mask=None) -> PyTree
Applies the mask to the given kwargs. If the mask keys are tuples, they are expanded into individual keys. This utility is provided as several jax transformations require a mask to know which jax.Arrays to target.
- Parameters:
mask (PyTree) – a pytree with the same structure as kwargs, or a valid prefix of it, whose leaves are either callable objects or masked values. If a leaf is callable, it is applied to the corresponding kwarg subtree and the result is used as the mask. Otherwise, the leaf itself is used as the mask.
kwargs (PyTree) – keyword arguments to which the mask is applied.
- Returns:
masked keyword arguments.
- Return type:
PyTree
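The masking rules described above (tuple keys expanded into individual keys, callable leaves applied to the matching kwarg subtree, other leaves used as-is) can be sketched for a single-level dictionary as follows. This is an illustrative sketch, not the library's code: expand_tuple_keys, process_mask_sketch and the default argument are hypothetical, and nested prefixes and rkg_mask are deliberately left out.

```python
from typing import Any, Dict


def expand_tuple_keys(mask: Dict[Any, Any]) -> Dict[Any, Any]:
    """Turn {("x", "y"): v} into {"x": v, "y": v}; other keys are kept."""
    expanded = {}
    for key, value in mask.items():
        keys = key if isinstance(key, tuple) else (key,)
        for k in keys:
            expanded[k] = value
    return expanded


def process_mask_sketch(mask: Dict[Any, Any], kwargs: Dict[str, Any], default=None):
    """Build one mask entry per keyword argument (top level only)."""
    expanded = expand_tuple_keys(mask)
    result = {}
    for name, value in kwargs.items():
        entry = expanded.get(name, default)
        # Callable leaves compute the mask from the kwarg subtree;
        # anything else is used directly as the masked value.
        result[name] = entry(value) if callable(entry) else entry
    return result


kwargs = {"x": [1.0, 2.0], "y": [3.0], "model": {"w": 0.5}}
mask = {("x", "y"): 0, "model": lambda tree: {k: None for k in tree}}
masked = process_mask_sketch(mask, kwargs)
# masked == {"x": 0, "y": 0, "model": {"w": None}}
```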
- pcax.functional.cond(true_fun: _BaseTransform | Callable, false_fun: _BaseTransform | Callable) -> Cond
Utility function to use the jax.lax.cond syntax for the Cond transformation.
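For reference, the underlying jax.lax.cond syntax looks like the sketch below. It uses plain JAX only and does not call the pcax wrapper, whose calling convention is not documented here; double, negate and branch are hypothetical names.

```python
import jax
import jax.numpy as jnp


def double(x):
    return x * 2.0


def negate(x):
    return -x


@jax.jit
def branch(pred, x):
    # Both branches must return outputs of the same shape and dtype.
    return jax.lax.cond(pred, double, negate, x)


a = branch(True, jnp.asarray(3.0))   # 6.0
b = branch(False, jnp.asarray(3.0))  # -3.0
```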
- pcax.functional.jit(static_argnums=None, donate_argnums=None, donate_argnames=None, **kwargs)
- pcax.functional.scan(f: _BaseTransform | Callable, xs: Sequence[Any] | None = None, length: int | None = None, reverse: bool = False, unroll: int | bool = 1) -> Scan
Utility function to use the jax.lax.scan syntax for the Scan transformation.
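As a reminder of the jax.lax.scan syntax this helper mirrors, here is a running sum written in plain JAX. It does not use the pcax wrapper; step, total and running are names chosen for the example.

```python
import jax
import jax.numpy as jnp


def step(carry, x):
    # scan expects (carry, x) -> (new_carry, per_step_output).
    new_carry = carry + x
    return new_carry, new_carry


xs = jnp.arange(1.0, 5.0)  # [1., 2., 3., 4.]
total, running = jax.lax.scan(step, jnp.asarray(0.0), xs)
# total is 10.0, running is [1., 3., 6., 10.]
```

The xs, length, reverse and unroll arguments in the signature above match the keyword arguments of jax.lax.scan.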
- pcax.functional.switch(branches: Sequence[_BaseTransform | Callable]) -> Switch
Utility function to use the jax.lax.switch syntax for the Switch transformation.
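The jax.lax.switch syntax selects one of several branches by integer index. The sketch below is plain JAX, not the pcax wrapper; branches and apply_branch are hypothetical names.

```python
import jax
import jax.numpy as jnp

# Each branch takes the same operand and returns the same shape and dtype.
branches = [
    lambda x: x + 1.0,
    lambda x: x * 10.0,
    lambda x: -x,
]


def apply_branch(index, x):
    # index is clamped to the valid range [0, len(branches) - 1].
    return jax.lax.switch(index, branches, x)


out = [float(apply_branch(i, jnp.asarray(2.0))) for i in range(3)]
# out == [3.0, 20.0, -2.0]
```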
- pcax.functional.value_and_grad(kwargs_mask: Any = {}, argnums: int | Sequence[int] = (), has_aux: bool = False, reduce_axes: Sequence[Hashable] = ())
- pcax.functional.vmap(kwargs_mask: Any = {}, in_axes: Sequence[int | None] = (), out_axes: Sequence[int | None] = (), axis_name: str | None = None)
- pcax.functional.while_loop(f: _BaseTransform | Callable, cond_fun: _BaseTransform | Callable) -> WhileLoop
Utility function to use the jax.lax.while_loop syntax for the WhileLoop transformation.
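For comparison, the plain jax.lax.while_loop syntax is sketched below. Note that jax.lax.while_loop takes cond_fun first and the body function second, whereas the signature above lists f before cond_fun. The example does not call the pcax wrapper; cond_fun, body_fun and the state layout are chosen for illustration.

```python
import jax
import jax.numpy as jnp


def cond_fun(state):
    i, _ = state
    return i < 5


def body_fun(state):
    # The loop state must keep the same structure, shapes and dtypes.
    i, acc = state
    return i + 1, acc + i


init = (jnp.asarray(0), jnp.asarray(0))
i_final, acc = jax.lax.while_loop(cond_fun, body_fun, init)
# i_final is 5, acc is 0 + 1 + 2 + 3 + 4 = 10
```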