# Source code for pcax.predictive_coding._vode

# Public API of this module.
__all__ = ["STATUS", "Ruleset", "Vode"]

import re
import warnings
from typing import Any, Callable, Dict, Sequence, Tuple

import jax

from ..core._random import RKG, RandomKeyGenerator
from ..core._parameter import Param
from ..core._module import BaseModule
from ..core._static import static
from ._parameter import VodeParam
from ._energy_module import EnergyModule
from ._energy import se_energy


########################################################################################################################
#
# VODE
#
# Vode is the fundamental building block to transform a deep learning network into a predictive coding network by
# breaking the continuous flow of information from input to output layer into independent, stateful blocks.
# The standard usage is 'x = vode(act_fn(layer(x)))'. The behaviour of a Vode can be customised by specifying the
# 'energy_fn' and its 'ruleset'.
#
########################################################################################################################

# Core #################################################################################################################


class STATUS:
    """Conventional Vode statuses.

    A status is just a string (or None), so any string may be used; the
    constants below are only the ones most commonly found in predictive coding
    networks. With the default Vode ruleset, 'STATUS.INIT' makes the Vode
    forward-initialise its value 'h' with the incoming activation 'u'.
    """

    # No status set.
    NONE = None
    # Regular expression that matches any status.
    ALL = ".*"
    # Status that triggers forward initialisation in the default ruleset.
    INIT = "init"
class Ruleset(BaseModule):
    """Transforms the values written to and read from a Vode according to a set of rules.

    Rules are given as a dictionary mapping a status regular expression to a
    sequence of rule strings. Each rule is either an input rule,
    'target <- key:transformation', or an output rule,
    'key -> target:transformation'. A sequence of rules is active only when its
    pattern matches the current status of the Vode. Matching uses 're.match',
    so the pattern is anchored at the start of the status only (e.g.,
    '.*': (rule1, rule2) applies the two rules to any status).

    If multiple input rules match the current status and operation, they are
    all executed in the order they are specified. If multiple output rules
    match the current status and operation, only the first one is executed.
    """

    def __init__(
        self,
        rules: Dict[str, Sequence[str]],
        tforms: Dict[str, Callable[["Vode", str, jax.Array | None, RandomKeyGenerator], jax.Array | None]]
        | None = None,
    ):
        """Ruleset constructor.

        Args:
            rules (Dict[str, Sequence[str]]): dictionary of rule sets, where each
                key is a regular expression matched against the current status of
                the Vode, and each value is the sequence of rules to apply.
            tforms (Dict[str, Callable[['Vode', str, jax.Array | None, RandomKeyGenerator], jax.Array | None]], optional):
                custom transformations, referenced by name in the rules.
                Defaults to no custom transformations.
        """
        super().__init__()
        self.rules = static(rules)
        # A fresh dict per instance: a '{}' default would be shared by every Ruleset.
        self.tforms = static({} if tforms is None else tforms)

    def filter(self, status: str | None, rule_pattern: str):
        """Filter all the rules that match the current status and the given rule pattern.

        Args:
            status (str | None): the target status to match. 'None' is treated as
                the empty string.
            rule_pattern (str): regular expression with (at least) two capturing
                groups, matched with 're.match' against every rule of every
                active rule set.

        Yields:
            Tuple[str, str]: the first two capturing groups of each matching rule,
                in the order the rule sets and rules are specified.
        """
        status = status or ""
        for _pattern, _rules in self.rules.items():
            if re.match(_pattern, status) is None:
                continue
            for _rule in _rules:
                if _match := re.match(rule_pattern, _rule):
                    yield _match.group(1, 2)

    def apply_set_transformation(
        self, node: "Vode", tform: str, key: str, value: Any | None = None, rkg: RandomKeyGenerator = RKG
    ) -> Any | None:
        """Recursively apply the transformations named in 'tform' to the given value.

        'tform' has the form 'key:t1:t2:...'. The leading key part is not a
        transformation. Each following name is looked up in 'tforms', and the
        transformations are applied left to right, starting from 'value'
        ('t1' first, then 't2', ...).

        Args:
            node (Vode): target vode.
            tform (str): sequence of ":"-separated transformations to apply to the value.
            key (str): input key, forwarded to every transformation.
            value (Any | None, optional): input value.
            rkg (RandomKeyGenerator, optional): random key generator. Defaults to RKG.

        Returns:
            Any | None: transformed value ('value' itself if 'tform' contains no ":").
        """
        if ":" in tform:
            # Peel off the last transformation; the inner ones are applied first.
            tform, _t = tform.rsplit(":", 1)
            value = self.tforms[_t](node, key, self.apply_set_transformation(node, tform, key, value, rkg), rkg)
        return value

    def apply_get_transformation(
        self, node: "Vode", tform: str, key: str, rkg: RandomKeyGenerator = RKG
    ) -> Any | None:
        """Recursively resolve the value described by 'tform' from the given vode.

        The vode is first queried with 'node.get(tform, None)', so already
        computed values are reused. If that yields None and 'tform' contains
        ":", the last transformation is stripped. The remaining expression is
        resolved recursively, and the transformation is applied to the result.

        Args:
            node (Vode): target vode.
            tform (str): expression 'name:t1:t2:...' describing the value to retrieve.
            key (str): queried key, forwarded to every transformation.
            rkg (RandomKeyGenerator, optional): random key generator. Defaults to RKG.

        Returns:
            Any | None: the retrieved (and possibly transformed) value.
        """
        _value = node.get(tform, None)
        if _value is None and ":" in tform:
            tform, _t = tform.rsplit(":", 1)
            _value = self.tforms[_t](node, key, self.apply_get_transformation(node, tform, key, rkg), rkg)
        return _value
class Vode(EnergyModule):
    """Base and configurable class for Vectorised Nodes.

    In a predictive coding network, a Vode is any element whose state depends on
    the particular sample data provided to the network (in contrast with the
    module weights, which are shared across all samples). In predictive coding,
    the most common type of Vode is a value node, usually denoted with 'x' in
    the literature. A value node is characterised by an energy function, which
    calculates the error between the predicted and the actual value of the node.

    A Vode offers a series of methods to customise its behaviour, with its
    default configuration being the one used by Gaussian value nodes. The user
    can define a custom energy function and a custom set of rules to update the
    Vode, or simply inherit from it to define even more customised behaviour.
    """

    def __init__(
        self,
        shape: Tuple[int, ...],
        energy_fn: Callable[["Vode", RandomKeyGenerator], jax.Array] = se_energy,
        ruleset: dict | None = None,
        tforms: dict | None = None,
        param_type: type[VodeParam] = VodeParam,
        *param_args,
        **param_kwargs,
    ):
        """Vode constructor.

        Args:
            shape (Tuple[int, ...]): shape (not including the batch dimension) of
                the Vode value. It should match the input activation 'u'.
            energy_fn (Callable[['Vode', RandomKeyGenerator], jax.Array], optional):
                function used to compute the Vode energy. Defaults to 'se_energy'.
            ruleset (dict | None, optional): additional rules, mapping a status
                regular expression to a sequence of rules (see 'Ruleset'). They
                are merged with the default rule, which states that, with status
                set to 'STATUS.INIT', the incoming activation 'u' is also saved
                to the value 'h' (forward initialisation). An entry whose key is
                'STATUS.INIT' replaces that default rule.
            tforms (dict | None, optional): custom transformations, by name, that
                the rules can reference. Defaults to none.
            param_type (type[VodeParam], optional): the parameter type of the
                value 'h'. Defaults to VodeParam.
            *param_args, **param_kwargs: arguments passed to the 'param_type'
                constructor.
        """
        super().__init__()
        self.shape = static(shape)
        self.h = param_type(*param_args, **param_kwargs)
        self.cache = param_type.Cache()
        self.energy_fn = static(energy_fn)
        # 'None' defaults instead of '{}' so no mutable dict is shared across Vodes.
        self.ruleset = Ruleset(
            {STATUS.INIT: ("h, u <- u",), **(ruleset or {})},
            {} if tforms is None else tforms,
        )

    def __call__(self, u: jax.Array | None, rkg: RandomKeyGenerator = RKG, output="h", **kwargs) -> jax.Array | Any:
        """Set the incoming activations and return the requested value.

        Deep learning layers are typically implemented as callable objects,
        taking the incoming activation as input and returning the transformed
        activation. Analogously, a Vode is a callable object. It takes the Vode
        incoming activations (e.g., 'u' and/or other values), stores them, and
        returns the Vode value 'h'. The output can be customised by setting the
        'output' parameter to the desired value. __call__ is equivalent to
        'vode.set("u", u).get("h")'.

        Args:
            u (jax.Array | None): if provided, it sets the incoming activation 'u' to the given value.
            rkg (RandomKeyGenerator, optional): random key generator. Defaults to RKG.
            output (str, optional): Value to return. Defaults to "h". If 'None', the Vode object is returned.
            **kwargs: optional additional activations to set.

        Returns:
            jax.Array | 'Vode': output value corresponding to the selected output parameter.
        """
        if u is not None:
            self.set("u", u, rkg)
        for _k, _v in kwargs.items():
            self.set(_k, _v, rkg)

        if output is None:
            return self
        return self.get(output, rkg=rkg)

    def _store(self, name: str, value: jax.Array | None) -> None:
        """Write 'value' to the Param attribute 'name', or to the cache if no such Param exists."""
        if hasattr(self, name) and isinstance((_param := getattr(self, name)), Param):
            _param.set(value)
        else:
            self.cache[name] = value

    def set(self, key: str, value: jax.Array | None, rkg: RandomKeyGenerator = RKG) -> "Vode":
        """Set the parameter corresponding to the given key, after processing it with the Vode ruleset.

        The rule syntax is 'target <- key:transformation'. 'target' is the name
        of the parameter to set (it can be a list of comma-separated names).
        'transformation' names the transformation to apply to 'value' before
        saving it to the target; it must be provided to the ruleset at
        construction time. A transformation signature is
        'def transformation(vode: Vode, key: str, value: jax.Array | None, rkg: RandomKeyGenerator) -> jax.Array | None'.
        Transformations can be chained (e.g., 'h <- u:se:zero').
        A rule is selected only if its input is exactly 'key', optionally
        followed by a transformation chain (e.g., setting 'u' does not trigger
        rules whose input is 'uv').

        Args:
            key (str): name of the parameter to set. If the parameter is not found, it is stored in the cache.
            value (jax.Array | None): value to set.
            rkg (RandomKeyGenerator, optional): random key generator. Defaults to RKG.

        Returns:
            Vode: returns itself to allow for chaining.
        """
        # Group 1: comma-separated targets (without trailing spaces).
        # Group 2: the exact key, optionally followed by ':'-separated transformations.
        _rule_pattern = rf"(.*(?<!\s))\s*<-\s*({re.escape(key)}(?::.*?)?)\s*$"
        rules = tuple(self.ruleset.filter(self.status, _rule_pattern))
        for _targets, _tform in rules:
            _value = self.ruleset.apply_set_transformation(self, _tform, _tform.split(":", 1)[0], value, rkg)
            for _target in _targets.split(","):
                self._store(_target.strip(), _value)

        if not rules:
            self._store(key, value)

        return self

    def get(self, key: str, default: Any | None = None, rkg: RandomKeyGenerator = RKG) -> jax.Array | Any | None:
        """Return the value corresponding to the given key, after processing it with the Vode ruleset.

        The rule syntax is 'key -> target:transformation', where 'target' is the
        name of the parameter to get when 'key' is queried. If several output
        rules match, only the first one is used and a warning is emitted.
        NOTE: when the right-hand side of the rule contains a transformation, its
        result is also saved to the cache, so subsequent calls to the same key
        return the same value without recomputation.

        Args:
            key (str): name of the parameter to get.
            default (Any | None, optional): default value to return if the parameter is not found.
            rkg (RandomKeyGenerator, optional): random key generator. Defaults to RKG.

        Returns:
            jax.Array | Any | None: the value of the parameter corresponding to the given key.
        """
        _rule_pattern = rf"({re.escape(key)})\s*->\s*(.*)"
        _rules = tuple(self.ruleset.filter(self.status, _rule_pattern))

        if not _rules:
            if hasattr(self, key) and isinstance((_param := getattr(self, key)), Param):
                return _param.get()
            return self.cache.get(key, default)

        if len(_rules) > 1:
            warnings.warn(
                f"Multiple output rules matched for key '{key}' in status '{self.status}'.",
                stacklevel=2,
            )

        (_target, _tform) = _rules[0]
        _value = self.ruleset.apply_get_transformation(self, _tform, _target, rkg=rkg)
        if ":" in _tform:
            self.cache[_tform] = _value

        # Honour 'default' also when the matched rule resolves to nothing.
        return default if _value is None else _value

    def energy(self, rkg: RandomKeyGenerator = RKG) -> jax.Array:
        """Compute the Vode energy and save it to the cache, using the key 'E'.

        The energy is computed by the energy function provided at construction
        time. Once cached, subsequent calls return the cached value. When called
        on a batched (non-vmapped) Vode, per-sample information is preserved and
        the energy is returned as a vector with shape (batch_size,); when called
        from a vmapped function, it is a scalar.

        Args:
            rkg (RandomKeyGenerator, optional): random key generator. Defaults to RKG.

        Returns:
            jax.Array: Vode energy
        """
        if "E" not in self.cache:
            if self.energy_fn is not None:
                _E = self.energy_fn(self, rkg=rkg)
            else:
                # No energy function: zero energy. An array shaped like 'h' keeps
                # the reductions below valid (a Python float has no '.sum()' and
                # cannot be reshaped to the batch size).
                _E = jax.numpy.zeros(self.h.shape)

            if self.h.shape == self.shape.get():
                # If the shape is the same as the vode shape, '.energy' is being
                # called from a vmapped function; otherwise 'h' would have an
                # extra (batch) dimension.
                _E = _E.sum()
            else:
                # '.energy' is being called from a non-vmapped function:
                # preserve the energy information of each element.
                _E = jax.numpy.reshape(_E, (self.h.shape[0], -1)).sum(axis=1)

            self.cache["E"] = _E

        return self.cache["E"]