# Names that make up this module's public API (what a star-import binds).
__all__ = [
    # Control-flow wrappers
    "scan",
    "while_loop",
    "cond",
    "switch",
    # Function transformations
    "jit",
    "vmap",
    "value_and_grad",
]
from typing import Any, Hashable, Sequence, Callable
from ._transform import _BaseTransform, Jit, Vmap, ValueAndGrad
from ._flow import Scan, WhileLoop, Cond, Switch
# Flow ###############################################################################################################
def scan(
    f: _BaseTransform | Callable,
    xs: Sequence[Any] | None = None,
    length: int | None = None,
    reverse: bool = False,
    unroll: int | bool = 1,
) -> Scan:
    """Build a :class:`~pcax.functional.Scan` transformation using a ``jax.lax.scan``-style call.

    Args:
        f: Function or transformation passed to ``Scan`` as its first positional argument.
        xs: Forwarded unchanged as the ``xs`` keyword of ``Scan``.
        length: Forwarded unchanged as the ``length`` keyword of ``Scan``.
        reverse: Forwarded unchanged as the ``reverse`` keyword of ``Scan``.
        unroll: Forwarded unchanged as the ``unroll`` keyword of ``Scan``.

    Returns:
        The newly constructed ``Scan`` instance.

    Note:
        The options presumably carry the same meaning as in ``jax.lax.scan``;
        confirm against the ``Scan`` implementation.
    """
    scan_options = {
        "xs": xs,
        "length": length,
        "reverse": reverse,
        "unroll": unroll,
    }
    return Scan(f, **scan_options)
[docs]
def while_loop(
    f: _BaseTransform | Callable,
    cond_fun: _BaseTransform | Callable,
) -> WhileLoop:
    """Build a :class:`~pcax.functional.WhileLoop` transformation using a ``jax.lax.while_loop``-style call.

    Args:
        f: Function or transformation passed to ``WhileLoop`` as its first positional argument.
        cond_fun: Function or transformation forwarded as the ``cond_fun`` keyword of ``WhileLoop``.

    Returns:
        The newly constructed ``WhileLoop`` instance.
    """
    loop = WhileLoop(f, cond_fun=cond_fun)
    return loop
[docs]
def cond(
    true_fun: _BaseTransform | Callable,
    false_fun: _BaseTransform | Callable,
) -> Cond:
    """Build a :class:`~pcax.functional.Cond` transformation using a ``jax.lax.cond``-style call.

    Args:
        true_fun: Function or transformation passed to ``Cond`` as its first positional argument.
        false_fun: Function or transformation passed to ``Cond`` as its second positional argument.

    Returns:
        The newly constructed ``Cond`` instance.
    """
    # Positional order matters: the "true" branch first, then the "false" branch.
    branch_pair = (true_fun, false_fun)
    return Cond(*branch_pair)
[docs]
def switch(
    branches: Sequence[_BaseTransform | Callable],
) -> Switch:
    """Build a :class:`~pcax.functional.Switch` transformation using a ``jax.lax.switch``-style call.

    Args:
        branches: Sequence of functions or transformations, handed to ``Switch`` as-is.

    Returns:
        The newly constructed ``Switch`` instance.
    """
    switch_transform = Switch(branches)
    return switch_transform
# Transform ############################################################################################################
def jit(
    static_argnums=None,
    # NOTE: ``static_argnames`` is deliberately not exposed yet because it has not been tested.
    donate_argnums=None,
    donate_argnames=None,
    **kwargs,
):
    """Return a decorator that wraps a function in a ``Jit`` transformation.

    Args:
        static_argnums: Forwarded as the ``static_argnums`` keyword of ``Jit``.
        donate_argnums: Forwarded as the ``donate_argnums`` keyword of ``Jit``.
        donate_argnames: Forwarded as the ``donate_argnames`` keyword of ``Jit``.
        **kwargs: Any additional keyword arguments, forwarded to ``Jit`` unchanged.

    Returns:
        A one-argument decorator; calling it with a function or transformation
        returns ``Jit(fn, ...)`` configured with the options above.
    """
    # Gather every option once; the decorator below only has to add ``fn``.
    jit_options = {
        "static_argnums": static_argnums,
        "donate_argnums": donate_argnums,
        "donate_argnames": donate_argnames,
        **kwargs,
    }

    def decorator(fn: _BaseTransform | Callable):
        return Jit(fn, **jit_options)

    return decorator
[docs]
def vmap(
kwargs_mask: Any = {},
in_axes: Sequence[int | None] = (),
out_axes: Sequence[int | None] = (),
axis_name: str | None = None,
):
def decorator(fn: _BaseTransform | Callable):
return Vmap(
fn, kwargs_mask, in_axes=in_axes, out_axes=out_axes, axis_name=axis_name
)
return decorator
[docs]
def value_and_grad(
kwargs_mask: Any = {},
argnums: int | Sequence[int] = (),
has_aux: bool = False,
reduce_axes: Sequence[Hashable] = (),
):
def decorator(fn: _BaseTransform | Callable):
return ValueAndGrad(
fn, kwargs_mask, argnums=argnums, has_aux=has_aux, reduce_axes=reduce_axes
)
return decorator