# Source code for pcax.core._random
# Names exported by `from <this module> import *`.
__all__ = ["RKGState", "RKG", "RandomKeyGenerator"]
from typing import Tuple
from jaxtyping import ArrayLike
import time
import jax
from ._parameter import Param
from ._module import BaseModule
########################################################################################################################
#
# RANDOM
#
# JAX requires the caller to keep track of the random generator state explicitly: a PRNGKey must be passed to every
# random function. pcax offers a random generator class that keeps track of this state and can be used to generate
# random keys. RKG is the default random generator used by pcax.
#
########################################################################################################################
# Utils ################################################################################################################
class RKGState(Param):
    """Parameter that stores the current key of a random key generator.

    The stored key is read with ``self.get()`` and replaced with ``self.set(...)``; neither is defined in this
    class (presumably both are inherited from ``Param`` -- confirm against ``_parameter``).
    """

    def __init__(self, seed: int):
        """Build the state from an integer seed.

        Args:
            seed (int): the initial seed, turned into a key with ``jax.random.PRNGKey``.
        """
        initial_key = jax.random.PRNGKey(seed)
        super().__init__(initial_key)

    def seed(self, seed: int) -> None:
        """Replace the stored key with one derived from a new seed.

        Args:
            seed (int): the new seed, turned into a key with ``jax.random.PRNGKey``.
        """
        fresh_key = jax.random.PRNGKey(seed)
        self.set(fresh_key)

    def split(self, n: int) -> jax.typing.ArrayLike:
        """Derive n new keys from the stored key and advance the stored key.

        The stored key is split into ``n + 1`` keys: the first one becomes the new stored key, the remaining
        ``n`` are handed back to the caller.

        Args:
            n (int): the number of keys to generate.

        Returns:
            jax.typing.ArrayLike: the last n entries of the ``jax.random.split`` result (an array slice, one
                key per leading index).
        """
        current_key = self.get()
        derived = jax.random.split(current_key, n + 1)
        carry_key, handed_out = derived[0], derived[1:]

        # Keep the first key as the next internal state so later calls never reuse a key that was handed out.
        self.set(carry_key)
        return handed_out
# Random ###############################################################################################################
class RandomKeyGenerator(BaseModule):
    """Random key generator module.

    Provides a stateful interface for generating random keys: the current key is held in ``self.key`` (an
    ``RKGState``) and is advanced every time the generator is called.
    """

    def __init__(self, seed: int = 0):
        """RandomKeyGenerator constructor.

        Args:
            seed (int, optional): initial seed. Defaults to 0.
        """
        super().__init__()
        self.key = RKGState(seed)

    def seed(self, seed: int = 0) -> None:
        """Set a new seed, replacing the current generator state.

        Args:
            seed (int, optional): new seed. Defaults to 0.
        """
        self.key.seed(seed)

    def __call__(self, n: int = 1) -> Tuple[ArrayLike, ...] | ArrayLike:
        """Generate n random keys and advance the internal state.

        Args:
            n (int, optional): number of keys to generate. Defaults to 1.

        Returns:
            Tuple[ArrayLike, ...] | ArrayLike: a single key if n is 1; otherwise the n keys exactly as returned
                by ``RKGState.split`` (a slice of the ``jax.random.split`` result, indexable key by key).
        """
        keys = self.key.split(n)

        # For convenience, unwrap the result when exactly one key was requested.
        return keys[0] if n == 1 else keys
# Module-level default generator. It is constructed when this module is executed (i.e. on import) and is
# seeded with time.time_ns(), the current time in nanoseconds, so the seed is not fixed across runs.
# Call RKG.seed(seed) afterwards to reset it to a chosen seed.
"""Default random generator, globally accessible
Initialize the random generator with a seed based on the current time,
if the user wants to use a different seed, they can call RKG.seed(seed)
"""
RKG = RandomKeyGenerator(time.time_ns())