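"""pcax.functional: jax.lax-style helpers (scan, while_loop, cond, switch) and
transform decorators (jit, vmap, value_and_grad)."""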

__all__ = [
    # Control-flow constructors that mirror the jax.lax call syntax.
    "scan",
    "while_loop",
    "cond",
    "switch",
    # Decorator factories for the transform wrappers.
    "jit",
    "vmap",
    "value_and_grad",
]

from typing import Any, Hashable, Sequence, Callable

from ._transform import _BaseTransform, Jit, Vmap, ValueAndGrad
from ._flow import Scan, WhileLoop, Cond, Switch


# Flow ###############################################################################################################


def scan(
    f: _BaseTransform | Callable,
    xs: Sequence[Any] | None = None,
    length: int | None = None,
    reverse: bool = False,
    unroll: int | bool = 1,
) -> Scan:
    """Build a :class:`~pcax.functional.Scan` using the ``jax.lax.scan`` argument names.

    Args:
        f: the function or pcax transform handed to ``Scan`` as its first argument.
        xs: forwarded to ``Scan`` as ``xs``.
        length: forwarded to ``Scan`` as ``length``.
        reverse: forwarded to ``Scan`` as ``reverse``.
        unroll: forwarded to ``Scan`` as ``unroll``.

    Returns:
        The newly constructed ``Scan`` object.
    """
    scan_options = dict(xs=xs, length=length, reverse=reverse, unroll=unroll)
    return Scan(f, **scan_options)


def while_loop(
    f: _BaseTransform | Callable,
    cond_fun: _BaseTransform | Callable,
) -> WhileLoop:
    """Build a :class:`~pcax.functional.WhileLoop` using the ``jax.lax.while_loop`` argument names.

    Args:
        f: the function or pcax transform handed to ``WhileLoop`` positionally
            (presumably the loop body, by analogy with ``jax.lax.while_loop``).
        cond_fun: the function or pcax transform passed to ``WhileLoop`` as ``cond_fun``.

    Returns:
        The newly constructed ``WhileLoop`` object.
    """
    loop = WhileLoop(f, cond_fun=cond_fun)
    return loop


def cond(
    true_fun: _BaseTransform | Callable,
    false_fun: _BaseTransform | Callable,
) -> Cond:
    """Build a :class:`~pcax.functional.Cond` using the ``jax.lax.cond`` argument order.

    Args:
        true_fun: function or pcax transform passed to ``Cond`` first.
        false_fun: function or pcax transform passed to ``Cond`` second.

    Returns:
        The newly constructed ``Cond`` object.
    """
    branch_pair = (true_fun, false_fun)
    return Cond(*branch_pair)


def switch(
    branches: Sequence[_BaseTransform | Callable],
) -> Switch:
    """Build a :class:`~pcax.functional.Switch` using the ``jax.lax.switch`` argument style.

    Args:
        branches: sequence of functions or pcax transforms, handed to ``Switch`` as-is.

    Returns:
        The newly constructed ``Switch`` object.
    """
    switch_transform = Switch(branches)
    return switch_transform


# Transform ############################################################################################################
def jit(
    static_argnums=None,
    # NOTE: ``static_argnames`` is deliberately not exposed as a named parameter
    # because it has not been tested yet.
    donate_argnums=None,
    donate_argnames=None,
    **kwargs,
):
    """Return a decorator that wraps a function or pcax transform in a :class:`Jit`.

    Args:
        static_argnums: forwarded to ``Jit`` as ``static_argnums``.
        donate_argnums: forwarded to ``Jit`` as ``donate_argnums``.
        donate_argnames: forwarded to ``Jit`` as ``donate_argnames``.
        **kwargs: any further keyword arguments, forwarded to ``Jit`` unchanged.

    Returns:
        A one-argument decorator mapping ``fn`` to ``Jit(fn, ...)``.
    """
    # NOTE(review): because ``static_argnames`` is not a named parameter, a caller
    # who passes it anyway lands in ``**kwargs`` and it is still forwarded to Jit.
    jit_options = {
        "static_argnums": static_argnums,
        "donate_argnums": donate_argnums,
        "donate_argnames": donate_argnames,
        **kwargs,
    }

    def wrap(fn: _BaseTransform | Callable):
        return Jit(fn, **jit_options)

    return wrap


[docs] def vmap( kwargs_mask: Any = {}, in_axes: Sequence[int | None] = (), out_axes: Sequence[int | None] = (), axis_name: str | None = None, ): def decorator(fn: _BaseTransform | Callable): return Vmap( fn, kwargs_mask, in_axes=in_axes, out_axes=out_axes, axis_name=axis_name ) return decorator
[docs] def value_and_grad( kwargs_mask: Any = {}, argnums: int | Sequence[int] = (), has_aux: bool = False, reduce_axes: Sequence[Hashable] = (), ): def decorator(fn: _BaseTransform | Callable): return ValueAndGrad( fn, kwargs_mask, argnums=argnums, has_aux=has_aux, reduce_axes=reduce_axes ) return decorator