Tutorial #3: Flow in JAX and pcax.
Since JAX transformations such as jit trace and compile the function, we cannot use Python control flow on dynamic (traced) values. Instead, we must rely on JAX's control-flow primitives (cond, scan, while_loop, ...) and follow their constraints. As with other transformations, pcax offers custom wrappers around JAX's control-flow primitives that automatically track changes to their kwargs.
In particular, Python control flow can only be used on static values, so it is perfectly fine to have flags within a model that change the overall computation. Remember that every time a static value is updated, a recompilation is triggered.
import jax
import jax.numpy as jnp
import pcax as px
import pcax.functional as pxf
Example of static flow.
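model = {
    'x': px.Param(1.0),
    'c': px.static(True)
}

@pxf.jit()
def f(*, model):
    # This print is a Python side effect: it runs only while `f` is being traced/compiled.
    print("f is being compiled...")
    # `c` is static, so a plain Python `if` can branch on its value.
    if model['c'].get():
        model['x'] += 1.0
    else:
        model['x'] -= 1.0

# First call: `f` is compiled for c=True.
f(model=model)
print('x:', model['x'].get())
# Same static value: the already compiled function is reused.
f(model=model)
print('x:', model['x'].get())
# Updating a static value triggers a recompilation.
model['c'].set(False)
f(model=model)
print('x:', model['x'].get())
f(model=model)
print('x:', model['x'].get())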
# If 'c' is not marked as a static value, this will not work. Try it!
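The same mechanism can be reproduced in plain JAX, without pcax. The following is a minimal sketch in which the function `step`, the counter `n_traces`, and the flag argument are hypothetical names used only for illustration. With `jax.jit(..., static_argnums=1)` the flag is static, so the Python `if` works and each new flag value triggers a retrace; without it, the flag is traced and the `if` raises a concretization error (caught here via `jax.errors.ConcretizationTypeError`). The counter relies on the fact that Python side effects run only while JAX traces the function.

```python
import jax
import jax.numpy as jnp

n_traces = [0]  # Python-side counter: incremented only while JAX traces `step`

def step(x, flag):
    n_traces[0] += 1
    # Plain Python `if`: only valid when `flag` is a concrete (static) value.
    if flag:
        return x + 1.0
    return x - 1.0

# `flag` is static: each distinct value of `flag` triggers a new compilation.
step_static = jax.jit(step, static_argnums=1)
x = jnp.array(1.0)
up = step_static(x, True)     # traced and compiled for flag=True
up2 = step_static(x, True)    # cache hit: no retracing
down = step_static(x, False)  # new static value: retraced and recompiled
static_traces = n_traces[0]

# Without `static_argnums`, `flag` becomes a tracer and `if flag:` fails.
try:
    jax.jit(step)(x, True)
    traced_failed = False
except jax.errors.ConcretizationTypeError:
    traced_failed = True

print(float(up), float(down), static_traces, traced_failed)
```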
def choice_a(x: jax.Array, *, p: px.Param):
    # NOTE: `p = p - x` would be wrong: `p - x` evaluates to a plain jax.Array, so the
    # assignment would only rebind the local name `p` instead of updating the parameter.
    p -= x

def choice_b(x: jax.Array, *, p: px.Param):
    p.set(p * x)

@pxf.jit()
def f(x: jax.Array, c: bool, *, p: px.Param):
    # NOTE: `c` is not static, so JAX traces it as a dynamic value: inside the function
    # it is a 0-dim jax.Array, not a Python bool, hence the need for `pxf.cond`.
    # `choice_a` runs when `c` is True, `choice_b` otherwise.
    pxf.cond(choice_a, choice_b)(c, x, p=p)
param = px.Param(jnp.array([1.0]))
x = jnp.array([-2.0])
f(x, True, p=param)   # 1.0 - (-2.0) = 3.0
f(x, False, p=param)  # 3.0 * (-2.0) = -6.0
assert param.get().item() == -6.0
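For comparison, here is a minimal plain-JAX sketch of the same computation without pcax (the functions mirror `choice_a`/`choice_b` above but are rewritten as pure functions for illustration). Because `jax.lax.cond` is purely functional, each branch must return the updated value explicitly and the caller must rebind it; this is the bookkeeping that pcax's wrappers handle for the parameters passed as kwargs.

```python
import jax
import jax.numpy as jnp

def choice_a(x, p):
    return p - x

def choice_b(x, p):
    return p * x

@jax.jit
def f(x, c, p):
    # `c` is a traced 0-dim bool array; lax.cond selects the branch at run time.
    # Both branches must return outputs with the same structure, shape and dtype.
    return jax.lax.cond(c, choice_a, choice_b, x, p)

p = jnp.array([1.0])
x = jnp.array([-2.0])
p = f(x, True, p)   # 1.0 - (-2.0) = 3.0
p = f(x, False, p)  # 3.0 * (-2.0) = -6.0
print(p)
```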
To apply the operation multiple times, we can use scan. Note that scan needs to know the number of repetitions at compilation time, so the scanned array must have a static shape (or a fixed length must be given). Furthermore, for clarity, we moved the current element of the scanned sequence, 'i', in front of the argument list of the transformed function: its signature is f(i, *args, **kwargs) instead of jax.lax.scan's f(carry, x).
@pxf.jit()
def fix_many_f(x: jax.Array, c: jax.Array, *, p: px.Param):
    def f(i, x, *, p):
        # `i` is the current element of `c` (a 0-dim boolean array).
        pxf.cond(choice_a, choice_b)(i, x, p=p)
        # NOTE: as with `jax.lax.scan`, `f` must always return a pair made of:
        # - the updated args, excluding the scanned element 'i' (so we return 'x'; '(x,)' would also be fine);
        # - 'y', the per-iteration output we wish to save. We do not need it, so we simply return 'None'.
        return x, None

    # `c` has a static length (8), so the number of iterations is known at compile time.
    pxf.scan(f, c)(x, p=p)

param = px.Param(jnp.array([1.0]))
x = jnp.array([-2.0])
c = jnp.array([False, False, True, False, True, True, False, True])
fix_many_f(x, c, p=param)
# p after each step: -2, 4, 6, -12, -10, -8, 16, 18
assert param.get().item() == 18.0
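A minimal plain-JAX sketch of the same loop, for comparison (the names are illustrative, not pcax's API). Here the body follows jax.lax.scan's native signature body(carry, element), and the parameter has to be threaded through the carry explicitly, whereas the pcax version receives it as a kwarg.

```python
import jax
import jax.numpy as jnp

def choice_a(x, p):
    return p - x

def choice_b(x, p):
    return p * x

@jax.jit
def fix_many_f(x, c, p):
    # jax.lax.scan expects body(carry, element) -> (new_carry, y).
    def body(carry, i):
        x, p = carry
        p = jax.lax.cond(i, choice_a, choice_b, x, p)
        return (x, p), None  # y=None: nothing is stacked across iterations

    # The number of iterations is the (static) length of `c`.
    (x, p), _ = jax.lax.scan(body, (x, p), c)
    return p

x = jnp.array([-2.0])
p = jnp.array([1.0])
c = jnp.array([False, False, True, False, True, True, False, True])
p = fix_many_f(x, c, p)
print(p)
```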
To repeat a computation until an arbitrary condition is met, we can use while_loop. Note that we add a counter, so there is no risk of running the loop indefinitely.
import numpy as np

@pxf.jit()
def var_many_f(x: jax.Array, *, p: px.Param):
    def f(x, count, *, p):
        # Draw a random boolean (50% chance of True) to choose the branch.
        c = jax.random.bernoulli(px.RKG())
        pxf.cond(choice_a, choice_b)(c, x, p=p)
        # NOTE: as with `jax.lax.while_loop`, `f` must always return the updated tuple of args.
        return x, count + 1

    def loop_cond(x, count, *, p):
        # NOTE: 'loop_cond' has access to both args and kwargs.
        # We use `jnp.all` to reduce the 1-dim array to a 0-dim boolean array, since the condition must be a scalar.
        # The loop stops as soon as `p` is no longer positive, or after 3 iterations.
        return jnp.all(jnp.logical_and(p > 0.0, count < 3))

    return pxf.while_loop(f, loop_cond)(x, 0, p=p)

param = px.Param(jnp.array([1.0]))
x = jnp.array([-2.0])
_, count = var_many_f(x, p=param)
print(param.get(), "steps:", count)

# Each iteration has a 50% chance of making `p` negative (which also ends the loop), so `p`
# ends up positive only if all 3 iterations pick `choice_a`: (1/2)**3 = 1/8 of the time.
values = []
for _ in range(1024):
    param = px.Param(jnp.array([1.0]))
    var_many_f(x, p=param)
    values.append(param.get().item() > 0)
assert np.allclose(np.mean(values), 1/8, atol=0.05)
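As a plain-JAX point of comparison, the sketch below (hypothetical names, not pcax's API) threads the PRNG key, the parameter, and the counter explicitly through jax.lax.while_loop's carry. Note that jax.lax.while_loop takes the condition first, while_loop(cond_fun, body_fun, init_val), whereas pxf.while_loop(f, loop_cond) above takes the body first. Instead of a Python loop, it estimates the 1/8 probability in one compiled call by batching 1024 independent runs with jax.vmap, one key per run.

```python
import jax
import jax.numpy as jnp

MAX_STEPS = 3  # same cap as `count < 3` in the pcax example

def var_many_f(key, x, p):
    def body(state):
        key, p, count = state
        key, subkey = jax.random.split(key)
        c = jax.random.bernoulli(subkey)  # True with probability 0.5
        # True -> p - x, False -> p * x (same branches as choice_a / choice_b)
        p = jax.lax.cond(c, lambda p: p - x, lambda p: p * x, p)
        return key, p, count + 1

    def loop_cond(state):
        _, p, count = state
        # The loop condition must be a scalar boolean.
        return jnp.logical_and(jnp.all(p > 0.0), count < MAX_STEPS)

    init = (key, p, jnp.int32(0))
    _, p, count = jax.lax.while_loop(loop_cond, body, init)
    return p, count

# Run 1024 independent loops in one batched, compiled call (one PRNG key per run).
keys = jax.random.split(jax.random.PRNGKey(0), 1024)
x = jnp.array([-2.0])
p0 = jnp.array([1.0])
batched = jax.jit(jax.vmap(var_many_f, in_axes=(0, None, None)))
ps, counts = batched(keys, x, p0)

frac_positive = float(jnp.mean(ps[:, 0] > 0.0))
print("fraction positive:", frac_positive)  # should be close to 1/8
```

Since the loop stops as soon as p turns negative, a run can only end positive after three draws of True, at which point p = 1 + 2 + 2 + 2 = 7.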