Tutorial #2: Randomness in pcax

This short notebook shows a few details of how randomness is implemented in pcax. JAX provides its own stateless random utilities, on which pcax builds to provide a simple interface: pcax.RandomKeyGenerator. pcax offers a globally instantiated pcax.RandomKeyGenerator, px.RKG, which is used whenever no alternative is provided.

[ ]:
import pcax as px
import pcax.nn as pxnn
import jax
import jax.numpy as jnp

# By default, px.RKG is initialised with the system time.
# We set both the global and a custom rkg seed to 0 and show their usage.
px.RKG.seed(0)
custom_RKG = px.RandomKeyGenerator(0)
layer_default = pxnn.Linear(8, 8, True) # by default uses px.RKG
layer_custom = pxnn.Linear(8, 8, True, rkg=custom_RKG)

assert jnp.all(layer_default.nn.weight == layer_custom.nn.weight), "This doesn't fail since both RKGs are initialised with the seed 0."

# Note that pcax functions accept a `pcax.RandomKeyGenerator`, while jax functions require a key,
# which can be obtained as follows:
a_key = px.RKG()
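
To make the difference between a generator and a key concrete, here is a minimal sketch in plain JAX of what a stateful key generator could look like. `ToyKeyGenerator` is a hypothetical name used only for illustration and is not pcax's actual implementation; the sketch assumes that each call splits the stored key, keeps one half as the new state, and returns the other.

```python
import jax
import jax.numpy as jnp


class ToyKeyGenerator:
    """Hypothetical stand-in for a stateful key generator (not pcax's class)."""

    def __init__(self, seed: int):
        self.key = jax.random.PRNGKey(seed)

    def __call__(self):
        # Split the stored key: keep one half as the new state, return the other.
        self.key, subkey = jax.random.split(self.key)
        return subkey


rkg_a = ToyKeyGenerator(0)
rkg_b = ToyKeyGenerator(0)

# Same seed, same sequence of keys, hence identical samples.
x_a = jax.random.normal(rkg_a(), (3,))
x_b = jax.random.normal(rkg_b(), (3,))
same_seed_matches = bool(jnp.all(x_a == x_b))

# A second call hands out a fresh key, so the sample changes.
x_a2 = jax.random.normal(rkg_a(), (3,))
next_call_differs = bool(jnp.any(x_a != x_a2))
```

Under these assumptions, two generators seeded identically stay in lockstep, which mirrors why the two layers above end up with equal weights.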

Because pcax.RKG is globally accessible, it can also be used inside pcax transformations. This, however, requires its state to be transformed accordingly. pcax does this automatically by adding pcax.RKG to the transformation's keyword arguments and applying the relevant transformation to it. For example, vmap splits the state into n different states, which are mapped over the vmapped dimension. At the end of the function, the vmapped states are discarded and only one is kept, which becomes the new pcax.RKG state.

If other behaviours are necessary, it is always possible to pass your own pcax.RandomKeyGenerator via keyword arguments and apply the desired transformations.

[ ]:
import pcax.functional as pxf

@pxf.jit()
@pxf.vmap(in_axes=(0, None, None), out_axes=0)
def vsum(a, min_val, max_val):
    a = a + jax.random.uniform(px.RKG(), a.shape, minval=min_val, maxval=max_val)

    px.RKG.seed(0)

    return a

a = jnp.ones((10, 1))

a_1 = vsum(a, -1.0, 1.0)
a_2 = vsum(a, -1.0, 1.0)

assert jnp.any(a_1 != a_2), "The two arrays should be different since vsum changes the state of the RKG."

key = px.RKG.key.get()
assert jnp.all(key == 0), "The key should be 0, as set inside the vsum function"

print("All good!")
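
For readers who want to see the split-map-keep pattern without pcax, the following is a rough sketch in plain JAX. It only assumes the behaviour described in the text (split the state into n states, map them, keep one); the names `noisy`, `row_states` and `new_states` are hypothetical, and the actual pcax internals may differ.

```python
import jax
import jax.numpy as jnp


def noisy(a, key, min_val, max_val):
    # Advance this row's key state and draw from the fresh subkey.
    key, subkey = jax.random.split(key)
    out = a + jax.random.uniform(subkey, a.shape, minval=min_val, maxval=max_val)
    return out, key


state = jax.random.PRNGKey(0)  # stands in for the generator's key state
a = jnp.ones((10, 1))

# Split the single state into one state per element of the mapped dimension.
row_states = jax.random.split(state, a.shape[0])

out, new_states = jax.vmap(noisy, in_axes=(0, 0, None, None))(
    a, row_states, -1.0, 1.0
)

# Keep one of the mapped states as the new state and discard the rest.
state = new_states[0]

rows_differ = bool(jnp.any(out != out[0]))
```

Because every row receives its own key, the rows of `out` are drawn independently.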

Note how the following cell uses the same key for all the values along the vmapped dimension, since we do not vmap the custom pcax.RandomKeyGenerator.

[ ]:
@pxf.jit()
@pxf.vmap({'rkg': None}, in_axes=(0, None, None), out_axes=0)
def sum_custom(a, min_val, max_val, *, rkg):
    return a + jax.random.uniform(rkg(), a.shape, minval=min_val, maxval=max_val)

a = jnp.ones((10, 1))
a_ = sum_custom(a, -1.0, 1.0, rkg=custom_RKG)

print("All entries of a_ should be the same:")
print(a_)

# Since we use a custom rkg and we do not batch over it, the key state is shared
# over the vmap dimension and all the values produced are the same.
#
# NOTE: you will rarely need this behaviour, so think carefully before
# relying on it. For standard use cases, simply use the provided
# default RKG.
assert jnp.all(a_ == a_[0]), "All the entries in a_ should be the same."

print("All good!")
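
The same effect can be reproduced in plain JAX, which may help clarify why the rows coincide. This is only an illustrative sketch, with `noisy` as a hypothetical helper: passing the key with `in_axes=None` broadcasts one key to every mapped element, so each row draws the same value.

```python
import jax
import jax.numpy as jnp


def noisy(a, key):
    return a + jax.random.uniform(key, a.shape, minval=-1.0, maxval=1.0)


key = jax.random.PRNGKey(0)
a = jnp.ones((10, 1))

# in_axes=None for the key: every row is evaluated with the same key.
shared = jax.vmap(noisy, in_axes=(0, None))(a, key)
all_rows_equal = bool(jnp.all(shared == shared[0]))
```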

If we want to vmap a custom RKG, we need to explicitly split and merge the key state over the vmap dimension.

[ ]:
@pxf.jit()
@pxf.vmap({'rkg': 0}, in_axes=(0, None, None), out_axes=0)
def vsum_custom(a, min_val, max_val, *, rkg):
    return a + jax.random.uniform(rkg(), a.shape, minval=min_val, maxval=max_val)

a = jnp.ones((10, 1))
custom_RKG.key.set(custom_RKG.key.split(len(a)))
a_ = vsum_custom(a, -1.0, 1.0, rkg=custom_RKG)
custom_RKG.key.set(custom_RKG.key[0])

print("All entries of a_ should now look random:")
print(a_)