__all__ = [
"STATUS",
"Ruleset",
"Vode",
]
import jax
from typing import Callable, Any, Tuple, Dict, Sequence
import re
from ..core._random import RKG, RandomKeyGenerator
from ..core._parameter import Param
from ..core._module import BaseModule
from ..core._static import static
from ._parameter import VodeParam
from ._energy_module import EnergyModule
from ._energy import se_energy
########################################################################################################################
#
# VODE
#
# Vode is the fundamental building block to transform a deep learning network into a predictive coding network by
# breaking the continous flow of information from input to output layer into independent, stateful blocks.
# The standard usage is 'x = vode(act_fn(layer(x)))'. The behaviour of a Vode can be customised by specifying the
# 'energy_fn' and its 'ruleset'.
#
########################################################################################################################
# Core #################################################################################################################
[docs]
class STATUS:
"""
List of common statuses used in predictive coding networks. Any string can be used as a status, but these
are the most common ones. The 'STATUS.INIT' status is used to forward initialise the Vode value 'h' with the
incoming activation 'u' in the default ruleset.
"""
NONE = None
ALL = ".*"
INIT = "init"
[docs]
class Ruleset(BaseModule):
"""
A Ruleset transform input and output values of a vode according to the specified rules. A set of rules can be
specified as a tuple of either input (i.e., 'target <- key:transformation') or output (i.e.,
'key -> target:transformation') rules. Each set of rules is associated to a set of statuses that matches the
current status of the Vode. The matching is determined by the regular expression pattern specified for each set
of rules (i.e., '.*': (rule1, rule2) would apply the two rules to any status). If multiple input rules match
the current status and operation, the are all executed in the order they are specified. If multiple output rules
match the current status and operation, only the first one is executed.
"""
[docs]
def __init__(
self,
rules: Dict[str, Sequence[str]],
tforms: Dict[str, Callable[["Vode", str, jax.Array | None, RandomKeyGenerator], jax.Array | None]] = {},
):
"""Ruleset constructor.
Args:
rules (Dict[str, Sequence[str]]): dictionary of set of rules, where each key is a regular expression to
match the current status of the Vode, and each value is a sequence of rules to apply.
tforms (Dict[str, Callable[['Vode', str, jax.Array | None, RandomKeyGenerator], jax.Array | None]]],
optional): custom transformations provided to the ruleset.
"""
super().__init__()
self.rules = static(rules)
self.tforms = static(tforms)
[docs]
def filter(self, status: str | None, rule_pattern: str):
"""Filter all the rules that match the current status and the given rule pattern.
Args:
status (str | None): the target status to match.
rule_pattern (str): the target rule pattern to match.
Yields:
Tuple[str, str]: the target and transformation of the rule.
"""
status = status or ""
for _pattern, _rules in self.rules.items():
if re.match(_pattern, status) is None:
continue
for _rule in _rules:
if _match := re.match(rule_pattern, _rule):
yield _match.group(1, 2)
[docs]
class Vode(EnergyModule):
"""
Base and configurable class for Vectorised Nodes. In a predictive coding network, a Vode is any element whose
state depends on a particular sample data provided to the network (in contrast with the module weights, which
are shared across all samples). In predictive coding, the most common type of Vode is a value node, usually
denoted with 'x' in the literature. A value node is characterised by an energy function, which calculates the
error between the predicted and the actual value of the node.
A Vode offers a series of methods to customise its behavioir, with its default configuration being the one used
by Gaussian value nodes. The user can define a custom energy function, a custom set of rules to update the Vode
and simply inherits from it do define even more customised behaviour."""
[docs]
def __init__(
self,
shape: Tuple[int, ...],
energy_fn: Callable[["Vode", RandomKeyGenerator], jax.Array] = se_energy,
ruleset: dict = {},
tforms: dict = {},
param_type: type[VodeParam] = VodeParam,
*param_args,
**param_kwargs,
):
"""Vode constructor.
Args:
shape (Tuple[int, ...]): shape (not including the batch dimension) of the Vode value. It should match
the input activation 'u'.
energy_fn (Callable[['Vode', RandomKeyGenerator], jax.Array], optional): function used to compute the Vode
energy.
ruleset (Ruleset, optional): ruleset specifying the Vode behaviour. The default value indicates that, with
status set to 'STATUS.INIT', the incoming activation 'u' is also saved to the value 'h', which
corresponds to forward initialisation.
param_type (type[VodeParam], optional): the parameter type of the value 'h'. Defaults to VodeParam.
*param_args, **param_kwargs: arguments passed to the 'param_type' constructor.
"""
super().__init__()
self.shape = static(shape)
self.h = param_type(*param_args, **param_kwargs)
self.cache = param_type.Cache()
self.energy_fn = static(energy_fn)
self.ruleset = Ruleset({STATUS.INIT: ("h, u <- u",), **ruleset}, tforms)
[docs]
def __call__(self, u: jax.Array | None, rkg: RandomKeyGenerator = RKG, output="h", **kwargs) -> jax.Array | Any:
"""Deep learning layers are typically implemented as callable objects, taking in input the incoming activation
and returning the transformed activation. Analogously, a Vode is implemented as a callable object, taking in
input the Vode incoming activations (e.g., 'u' and/or other values), storing them, and returning the Vode value
'h' (the output can be customised by setting the 'output' parameter to the desired value).
__call__ is equivalent to 'vode.set("u", u).get("h")'.
Args:
u (jax.Array | None): if provided, it sets the incoming activation 'u' to the given value.
rkg (RandomKeyGenerator, optional): random key generator. Defaults to RKG.
output (str, optional): Value to return. Defaults to "h". If 'None', the Vode object is returned.
**kwargs: eventual additional activations to set.
Returns:
jax.Array | 'Vode': output value corresponding to the selected output parameter.
"""
if u is not None:
self.set("u", u, rkg)
for _k, _v in kwargs.items():
self.set(_k, _v, rkg)
if output is None:
return self
else:
return self.get(output, rkg=rkg)
[docs]
def set(self, key: str, value: jax.Array | None, rkg: RandomKeyGenerator = RKG) -> "Vode":
"""Set the value of the parameter corresponding to the given key, after being processed by the Vode ruleset.
The rule syntax is 'target <- key:transformation', where 'target' is the name of the parameter to set (can be a
list of comma-separated names), and 'transformation' is a string that refers to the name of the transformation
to apply to 'value' before saving it to the target. The transformation must be a method provided to the ruleset
at construction time. A transformation signature is
'def transformation(vode: Vode, key: str, value: jax.Array | None, rkg: RandomKeyGenerator) -> jax.Array | None'
Transformations can be chained (e.g., 'h <- u:se:zero').
Args:
key (str): name of the parameter to set. If the parameter is not found, it is stored in the cache.
value (jax.Array | None): value to set.
rkg (RandomKeyGenerator, optional): random key generator. Defaults to RKG.
Returns:
Vode: returns itself to allow for chaining.
"""
_rule_pattern = f'(.*(?<!\\s))\\s*<-\\s*({key}.*)'
rules = tuple(self.ruleset.filter(self.status, _rule_pattern))
for _targets, _tform in rules:
_value = self.ruleset.apply_set_transformation(self, _tform, _tform.split(":", 1)[0], value, rkg)
for _target in _targets.split(","):
_target = _target.strip()
if hasattr(self, _target) and isinstance((_param := getattr(self, _target)), Param):
_param.set(_value)
else:
self.cache[_target] = _value
if len(rules) == 0:
if hasattr(self, key) and isinstance((_param := getattr(self, key)), Param):
_param.set(value)
else:
self.cache[key] = value
return self
[docs]
def get(self, key: str, default: Any | None = None, rkg: RandomKeyGenerator = RKG) -> jax.Array | Any | None:
"""Returns the value of the parameter corresponding to the given key, after being processed by the Vode ruleset.
The rule syntax is 'key -> target:transformation', where 'target' is the name of the parameter to get when
key is queried. NOTE: the right-hand side of the rule is also saved to the cache, so subsequent calls to the
same key will return the same value without recomputation.
Args:
key (str): name of the parameter to get.
default (Any | None, optional): default value to return if the parameter is not found.
rkg (RandomKeyGenerator, optional): random key generator. Defaults to RKG.
Returns:
jax.Array | Any | None: the value of the parameter corresponding to the given key.
"""
_rule_pattern = f"({key})\\s*->\\s*(.*)"
_rules = tuple(self.ruleset.filter(self.status, _rule_pattern))
if len(_rules) == 0:
if hasattr(self, key) and isinstance((_param := getattr(self, key)), Param):
return _param.get()
else:
return self.cache.get(key, default)
else:
# TODO: use warnings
if len(_rules) > 1:
print(f"WARNING: Multiple output rules matched for key '{key}' in status '{self.status}'.")
(_target, _tform) = _rules[0]
_value = self.ruleset.apply_get_transformation(self, _tform, _target, rkg=rkg)
if ":" in _tform:
self.cache[_tform] = _value
return _value
[docs]
def energy(self, rkg: RandomKeyGenerator = RKG) -> jax.Array:
"""Compute the Vode energy and saves it to the cache, using the key 'E'.
The energy is computed by the energy function provided at construction time.
Information about individual samples is preserved and the energy is returned as a vector
with shape (batch_size,).
Args:
rkg (RandomKeyGenerator, optional): random key generator. Defaults to RKG.
Returns:
jax.Array: Vode energy
"""
if "E" not in self.cache:
_E = self.energy_fn(self, rkg=rkg) if self.energy_fn is not None else 0.0
if self.h.shape == self.shape.get():
# if the shape is the same as the vode shape,
# '.energy' is being called from a vmapped function
# otherwise 'h' would have a an extra dimension (batch)
_E = _E.sum()
else:
# .energy is being called from a non-vmapped function
# we want to preserve the energy information of each element
_E = jax.numpy.reshape(_E, (self.h.shape[0], -1)).sum(axis=1)
self.cache["E"] = _E
return self.cache["E"]