pcax.predictive_coding package

Module contents

class pcax.predictive_coding.EnergyModule

Bases: Module

EnergyModule inherits from core.Module. Its extra ‘_status’ attribute can be used to configure the behavior of its methods; in particular, it is used by ‘predictive_coding.Vode’.

__init__() → None

Module constructor.

clear_params(filter: Callable[[Any], bool] | Type) → None

Set the selected parameters to None. This is especially useful for clearing cached parameters when needed. Note that, since pcax is an imperative library, the change is made in place and no updated module is returned.

Parameters:

filter (Callable[[Any], bool] | Type) – filter function or type identifying the parameters to clear.
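The in-place behaviour of clear_params, and the fact that filter can be either a predicate or a type, can be illustrated with a small standalone sketch. Param, CacheParam and ToyModule are hypothetical stand-ins, not pcax classes; the assumption is that a type filter selects parameters by isinstance.

```python
from typing import Any, Callable, Dict, Type, Union


class Param:
    """Hypothetical stand-in for a parameter holding a single value."""

    def __init__(self, value: Any = None):
        self.value = value


class CacheParam(Param):
    """Hypothetical parameter subclass used to mark cached values."""


class ToyModule:
    """Hypothetical module storing named parameters in a dict."""

    def __init__(self, params: Dict[str, Param]):
        self.params = params

    def clear_params(self, filter: Union[Callable[[Any], bool], Type]) -> None:
        # A type becomes an isinstance check; a callable is used as-is.
        if isinstance(filter, type):
            cls = filter
            predicate = lambda p: isinstance(p, cls)
        else:
            predicate = filter
        # Mutate in place: nothing is returned, mirroring the imperative style.
        for param in self.params.values():
            if predicate(param):
                param.value = None


m = ToyModule({"w": Param(1.0), "E": CacheParam(3.5)})
m.clear_params(CacheParam)
print(m.params["w"].value, m.params["E"].value)  # 1.0 None
```

Because the module is mutated in place, the caller keeps using the same object after the call.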

energy() → Array

Return the total energy of the module as the recursive sum of the energies of all its submodules. Note that, unlike in Vodes, the energy is not cached.

Returns:

total energy of the module.

Return type:

jax.Array
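A minimal sketch of the recursive, uncached energy sum, assuming a plain tree of modules where leaves (playing the role of Vodes) report their own energy. ToyEnergyModule and ToyLeaf are hypothetical names, and energies are plain floats rather than jax.Array values.

```python
from typing import List, Optional


class ToyEnergyModule:
    """Hypothetical module whose energy is the recursive sum over submodules."""

    def __init__(self, submodules: Optional[List["ToyEnergyModule"]] = None):
        self.submodules = submodules or []

    def energy(self) -> float:
        # No caching: the sum is recomputed on every call.
        return sum(m.energy() for m in self.submodules)


class ToyLeaf(ToyEnergyModule):
    """Hypothetical leaf with a fixed energy value."""

    def __init__(self, e: float):
        super().__init__()
        self.e = e

    def energy(self) -> float:
        return self.e


net = ToyEnergyModule([ToyLeaf(1.5), ToyEnergyModule([ToyLeaf(2.0), ToyLeaf(0.5)])])
print(net.energy())  # 4.0
```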

property status: Any

class pcax.predictive_coding.Ruleset(rules: Dict[str, Sequence[str]], tforms: Dict[str, Callable[[Vode, str, Array | None, RandomKeyGenerator], Array | None]] = {})

Bases: BaseModule

A Ruleset transforms the input and output values of a Vode according to the specified rules. A set of rules is specified as a tuple in which each rule is either an input rule (i.e., ‘target <- key:transformation’) or an output rule (i.e., ‘key -> target:transformation’). Each set of rules is associated with a regular expression pattern, and the rules are applied when the pattern matches the current status of the Vode (e.g., ‘.*’: (rule1, rule2) applies the two rules to any status). If multiple input rules match the current status and operation, they are all executed in the order in which they are specified. If multiple output rules match the current status and operation, only the first one is executed.
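The selection logic described above can be sketched as follows. This is a hypothetical, standalone illustration, not the pcax parser: RULES, parse and select are illustrative names, whitespace around the arrows is assumed optional, and status patterns are assumed to be matched against the whole status string (re.fullmatch).

```python
import re
from typing import Dict, List, Sequence, Tuple

# Hypothetical rule sets: input rules use "<-", output rules use "->".
RULES: Dict[str, Sequence[str]] = {
    "init": ("h <- u",),          # applies only when the status is "init"
    ".*": ("u <- u", "h -> h"),   # applies to any status
}


def parse(rule: str) -> Tuple[str, str, str, List[str]]:
    """Split a rule into (kind, target, key, transformations)."""
    if "<-" in rule:
        target, rhs = (s.strip() for s in rule.split("<-"))
        key, *tforms = rhs.split(":")
        return "input", target, key, tforms
    key, rhs = (s.strip() for s in rule.split("->"))
    target, *tforms = rhs.split(":")
    return "output", target, key, tforms


def select(status: str, kind: str, key: str) -> List[Tuple[str, str, str, List[str]]]:
    """All rules of the given kind for `key` whose status pattern matches, in order."""
    matches = []
    for pattern, rules in RULES.items():
        if re.fullmatch(pattern, status):  # assumption: full-string match
            for rule in rules:
                parsed = parse(rule)
                if parsed[0] == kind and parsed[2] == key:
                    matches.append(parsed)
    return matches


# Input rules: every match is executed, in order.
print([r[1] for r in select("init", "input", "u")])  # ['h', 'u']
# Output rules: only the first match is executed.
print(select("init", "output", "h")[:1])  # [('output', 'h', 'h', [])]
```

With status "init" both input rules for ‘u’ fire (so ‘u’ is also copied into ‘h’), whereas any other status only triggers the catch-all ‘.*’ rules.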

__init__(rules: Dict[str, Sequence[str]], tforms: Dict[str, Callable[[Vode, str, Array | None, RandomKeyGenerator], Array | None]] = {})

Ruleset constructor.

Parameters:
  • rules (Dict[str, Sequence[str]]) – dictionary of rule sets, where each key is a regular expression matched against the current status of the Vode, and each value is a sequence of rules to apply.

  • tforms (Dict[str, Callable[[Vode, str, jax.Array | None, RandomKeyGenerator], jax.Array | None]], optional) – custom transformations provided to the ruleset.

apply_get_transformation(node: Vode, tform: str, key: str, rkg: RandomKeyGenerator = RKG) → Any | None

Recursively apply the transformation specified by the given tform to the given value.

Parameters:
  • node (Vode) – target vode.

  • tform (str) – sequence of “:”-separated transformations to apply to the value.

  • key (str) – input key.

  • rkg (RandomKeyGenerator, optional) – random key generator. Defaults to RKG.

Returns:

transformed value.

Return type:

Any | None

apply_set_transformation(node: Vode, tform: str, key: str, value: Any | None = None, rkg: RandomKeyGenerator = RKG) → Any | None

Recursively apply the transformation specified by the given tform to the given value.

Parameters:
  • node (Vode) – target vode.

  • tform (str) – sequence of “:”-separated transformations to apply to the value.

  • key (str) – input key.

  • value (Any | None, optional) – input value.

  • rkg (RandomKeyGenerator, optional) – random key generator. Defaults to RKG.

Returns:

transformed value.

Return type:

Any | None
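Both apply_*_transformation methods process a “:”-separated chain of transformation names. A minimal recursive sketch, assuming the chain is applied left to right and that each name resolves to a callable with the (vode, key, value, rkg) shape documented for custom transformations. TFORMS, apply_chain and the "double"/"neg" transformations are hypothetical, and plain floats stand in for jax.Array values.

```python
from typing import Any, Callable, Dict, Optional

# Hypothetical transformations with the (vode, key, value, rkg) -> value shape;
# the vode and rkg arguments are unused in this toy example.
Tform = Callable[[Any, str, Optional[float], Any], Optional[float]]

TFORMS: Dict[str, Tform] = {
    "double": lambda vode, key, value, rkg: None if value is None else 2 * value,
    "neg": lambda vode, key, value, rkg: None if value is None else -value,
}


def apply_chain(vode: Any, tform: str, key: str,
                value: Optional[float] = None, rkg: Any = None) -> Optional[float]:
    """Recursively apply a ':'-separated chain (assumed order: left to right)."""
    if not tform:
        return value
    head, _, rest = tform.partition(":")
    return apply_chain(vode, rest, key, TFORMS[head](vode, key, value, rkg), rkg)


print(apply_chain(None, "double:neg", "u", 3.0))  # -6.0
```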

filter(status: str | None, rule_pattern: str)

Filter all the rules that match the current status and the given rule pattern.

Parameters:
  • status (str | None) – the target status to match.

  • rule_pattern (str) – the target rule pattern to match.

Yields:

Tuple[str, str] – the target and transformation of the rule.
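One plausible reading of filter is a generator that keeps the rule sets whose status pattern matches, then matches each rule against rule_pattern and yields the captured target and transformation. The sketch below assumes exactly that: rule_pattern is a regular expression with two capture groups, a None status is treated as the empty string, and filter_rules and RULES are hypothetical names, not the pcax implementation.

```python
import re
from typing import Dict, Iterator, Optional, Sequence, Tuple

RULES: Dict[str, Sequence[str]] = {
    "init": ("h <- u",),
    ".*": ("u <- u", "h -> h:relu"),
}


def filter_rules(status: Optional[str], rule_pattern: str) -> Iterator[Tuple[str, str]]:
    """Yield (target, transformation) for every rule matching status and rule_pattern."""
    status_str = "" if status is None else status  # assumption: None matches like ""
    for status_pattern, rules in RULES.items():
        if not re.fullmatch(status_pattern, status_str):
            continue
        for rule in rules:
            m = re.fullmatch(rule_pattern, rule)
            if m:
                yield m.group(1), m.group(2) or ""


INPUT_U = r"(\S+)\s*<-\s*u(?::(\S+))?"    # input rules reading key "u"
OUTPUT_H = r"h\s*->\s*(\S+?)(?::(\S+))?"  # output rules for key "h"

print(list(filter_rules("init", INPUT_U)))   # [('h', ''), ('u', '')]
print(list(filter_rules("init", OUTPUT_H)))  # [('h', 'relu')]
```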

class pcax.predictive_coding.STATUS

Bases: object

List of common statuses used in predictive coding networks. Any string can be used as a status, but these are the most common ones. In the default ruleset, the ‘STATUS.INIT’ status is used to forward-initialise the Vode value ‘h’ with the incoming activation ‘u’.

ALL = '.*'
INIT = 'init'
NONE = None

class pcax.predictive_coding.Vode(shape: Tuple[int, ...], energy_fn: Callable[[Vode, RandomKeyGenerator], jax.Array] = se_energy, ruleset: dict = {}, tforms: dict = {}, param_type: type[VodeParam] = VodeParam, *param_args, **param_kwargs)

Bases: EnergyModule

Base, configurable class for Vectorised Nodes. In a predictive coding network, a Vode is any element whose state depends on the particular data sample provided to the network (in contrast with the module weights, which are shared across all samples). In predictive coding, the most common type of Vode is a value node, usually denoted by ‘x’ in the literature. A value node is characterised by an energy function, which calculates the error between the predicted and the actual value of the node.

A Vode offers a series of methods to customise its behaviour, with its default configuration being the one used by Gaussian value nodes. The user can define a custom energy function or a custom set of rules to update the Vode, or simply inherit from it to define even more customised behaviour.

__call__(u: jax.Array | None, rkg: RandomKeyGenerator = RKG, output='h', **kwargs) → Array | Any

Deep learning layers are typically implemented as callable objects that take the incoming activation as input and return the transformed activation. Analogously, a Vode is implemented as a callable object that takes the Vode's incoming activations as input (e.g., ‘u’ and/or other values), stores them, and returns the Vode value ‘h’ (the returned value can be customised by setting the ‘output’ parameter). __call__ is equivalent to ‘vode.set(“u”, u).get(“h”)’.

Parameters:
  • u (jax.Array | None) – if provided, it sets the incoming activation ‘u’ to the given value.

  • rkg (RandomKeyGenerator, optional) – random key generator. Defaults to RKG.

  • output (str, optional) – Value to return. Defaults to “h”. If ‘None’, the Vode object is returned.

  • **kwargs – any additional activations to set.

Returns:

output value corresponding to the selected output parameter.

Return type:

jax.Array | ‘Vode’
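The equivalence between __call__ and ‘vode.set(“u”, u).get(“h”)’, together with the forward initialisation triggered by ‘STATUS.INIT’, can be sketched with a heavily simplified, hypothetical ToyVode. It stores values in a plain dict, hard-codes the INIT behaviour instead of using a Ruleset, and omits the rkg argument; it is not the pcax class.

```python
from typing import Any, Dict, Optional


class ToyVode:
    """Hypothetical, heavily simplified Vode: a dict of values plus a status."""

    def __init__(self, status: Optional[str] = None):
        self.status = status
        self.values: Dict[str, Any] = {}

    def set(self, key: str, value: Any) -> "ToyVode":
        self.values[key] = value
        # Mimics the default INIT behaviour: 'u' is also copied into 'h'.
        if key == "u" and self.status == "init":
            self.values["h"] = value
        return self  # returning self enables chaining

    def get(self, key: str, default: Any = None) -> Any:
        return self.values.get(key, default)

    def __call__(self, u: Any, output: Optional[str] = "h", **kwargs) -> Any:
        if u is not None:
            self.set("u", u)
        for k, v in kwargs.items():
            self.set(k, v)
        return self if output is None else self.get(output)


v = ToyVode(status="init")
print(v([1.0, 2.0]))           # [1.0, 2.0]: 'h' forward-initialised from 'u'
v.status = None
v([5.0, 5.0])                  # only 'u' is updated now
print(v.get("h"), v.get("u"))  # [1.0, 2.0] [5.0, 5.0]
```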

__init__(shape: Tuple[int, ...], energy_fn: Callable[[Vode, RandomKeyGenerator], jax.Array] = se_energy, ruleset: dict = {}, tforms: dict = {}, param_type: type[VodeParam] = VodeParam, *param_args, **param_kwargs)

Vode constructor.

Parameters:
  • shape (Tuple[int, ...]) – shape (not including the batch dimension) of the Vode value. It should match the shape of the incoming activation ‘u’.

  • energy_fn (Callable[['Vode', RandomKeyGenerator], jax.Array], optional) – function used to compute the Vode energy.

  • ruleset (Ruleset, optional) – ruleset specifying the Vode behaviour. The default value indicates that, with status set to ‘STATUS.INIT’, the incoming activation ‘u’ is also saved to the value ‘h’, which corresponds to forward initialisation.

  • param_type (type[VodeParam], optional) – the parameter type of the value ‘h’. Defaults to VodeParam.

  • *param_args – arguments passed to the ‘param_type’ constructor.

  • **param_kwargs – keyword arguments passed to the ‘param_type’ constructor.

energy(rkg: RandomKeyGenerator = RKG) → Array

Compute the Vode energy and save it to the cache under the key ‘E’. The energy is computed by the energy function provided at construction time. Information about individual samples is preserved: the energy is returned as a vector with shape (batch_size,).

Parameters:

rkg (RandomKeyGenerator, optional) – random key generator. Defaults to RKG.

Returns:

Vode energy

Return type:

jax.Array
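To make the (batch_size,) shape concrete, here is a small sketch that reduces an elementwise energy over all non-batch entries and stores the result under ‘E’ in a cache dict. It assumes the per-sample reduction is a sum and uses a squared error with a 1/2 factor as the elementwise term; per_sample_energy and the nested-list batches are hypothetical stand-ins for the pcax internals and jax.Array values.

```python
from typing import Callable, Dict, List

Batch = List[List[float]]  # batch_size x features


def per_sample_energy(elementwise: Callable[[float, float], float],
                      h: Batch, u: Batch) -> List[float]:
    """Sum an elementwise energy over the non-batch entries, one value per sample."""
    return [sum(elementwise(hi, ui) for hi, ui in zip(hs, us)) for hs, us in zip(h, u)]


cache: Dict[str, List[float]] = {}
h = [[1.0, 2.0], [0.0, 0.0]]  # batch_size = 2
u = [[1.0, 0.0], [1.0, 1.0]]
cache["E"] = per_sample_energy(lambda a, b: 0.5 * (a - b) ** 2, h, u)
print(cache["E"])  # [2.0, 1.0]
```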

get(key: str, default: Any | None = None, rkg: RandomKeyGenerator = RKG) → Array | Any | None

Return the value of the parameter corresponding to the given key, after processing by the Vode ruleset. The rule syntax is ‘key -> target:transformation’, where ‘target’ is the name of the parameter to get when key is queried. NOTE: the right-hand side of the rule is also saved to the cache, so subsequent calls with the same key return the same value without recomputation.

Parameters:
  • key (str) – name of the parameter to get.

  • default (Any | None, optional) – default value to return if the parameter is not found.

  • rkg (RandomKeyGenerator, optional) – random key generator. Defaults to RKG.

Returns:

the value of the parameter corresponding to the given key.

Return type:

jax.Array | Any | None
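The caching note can be illustrated with a hypothetical ToyGetter that resolves a key through an output rule, applies the transformation once, and serves later queries from the cache keyed by the rule's right-hand side. To keep it short, rules are stored as a key to right-hand-side mapping and transformations take only the value, which is a simplification of the documented signature; none of these names come from pcax.

```python
from typing import Any, Callable, Dict, Optional


class ToyGetter:
    """Hypothetical sketch of rule-based `get` with result caching."""

    def __init__(self, values: Dict[str, Any], rules: Dict[str, str],
                 tforms: Dict[str, Callable[[Any], Any]]):
        self.values = values    # stored parameters, e.g. {"h": ...}
        self.rules = rules      # {"a": "h:relu"} encodes the rule 'a -> h:relu'
        self.tforms = tforms
        self.cache: Dict[str, Any] = {}
        self.calls = 0          # how many times a transformation actually ran

    def get(self, key: str, default: Optional[Any] = None) -> Any:
        rhs = self.rules.get(key, key)  # no rule: read the parameter directly
        if rhs in self.cache:
            return self.cache[rhs]
        target, _, tform = rhs.partition(":")
        if target not in self.values:
            return default
        value = self.values[target]
        if tform:
            self.calls += 1
            value = self.tforms[tform](value)
        self.cache[rhs] = value
        return value


g = ToyGetter({"h": [-1.0, 2.0]}, {"a": "h:relu"},
              {"relu": lambda xs: [max(0.0, x) for x in xs]})
print(g.get("a"), g.get("a"), g.calls)  # [0.0, 2.0] [0.0, 2.0] 1
```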

set(key: str, value: jax.Array | None, rkg: RandomKeyGenerator = RKG) → Vode

Set the value of the parameter corresponding to the given key, after processing by the Vode ruleset. The rule syntax is ‘target <- key:transformation’, where ‘target’ is the name of the parameter to set (it can be a list of comma-separated names), and ‘transformation’ is a string referring to the name of the transformation to apply to ‘value’ before saving it to the target. The transformation must be a callable provided to the ruleset at construction time, with signature ‘def transformation(vode: Vode, key: str, value: jax.Array | None, rkg: RandomKeyGenerator) -> jax.Array | None’. Transformations can be chained (e.g., ‘h <- u:se:zero’).

Parameters:
  • key (str) – name of the parameter to set. If the parameter is not found, it is stored in the cache.

  • value (jax.Array | None) – value to set.

  • rkg (RandomKeyGenerator, optional) – random key generator. Defaults to RKG.

Returns:

returns itself to allow for chaining.

Return type:

Vode
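A standalone sketch of how one input rule could be applied during set: the rule fires only if its key matches, the chained transformations run in sequence (assumed left to right), and the result is written to every comma-separated target. apply_set_rule, the store dict and the "half" transformation are hypothetical, and plain floats stand in for jax.Array values.

```python
from typing import Any, Callable, Dict, Optional

# Hypothetical transformation type with the (vode, key, value, rkg) -> value shape.
Tform = Callable[[Any, str, Optional[float], Any], Optional[float]]


def apply_set_rule(store: Dict[str, Any], rule: str, key: str, value: Optional[float],
                   tforms: Dict[str, Tform], vode: Any = None, rkg: Any = None) -> bool:
    """Apply one rule 'targets <- key:t1:t2' if it refers to `key`."""
    lhs, rhs = (s.strip() for s in rule.split("<-"))
    rule_key, *chain = rhs.split(":")
    if rule_key != key:
        return False
    for name in chain:  # assumption: chained transformations run left to right
        value = tforms[name](vode, key, value, rkg)
    for target in (t.strip() for t in lhs.split(",")):
        store[target] = value
    return True


tforms: Dict[str, Tform] = {
    "half": lambda vode, key, value, rkg: None if value is None else value / 2,
}
store: Dict[str, Any] = {}
apply_set_rule(store, "h, u <- u", "u", 8.0, tforms)         # two targets
apply_set_rule(store, "h <- u:half:half", "u", 8.0, tforms)  # chained transforms
print(store)  # {'h': 2.0, 'u': 8.0}
```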

class pcax.predictive_coding.VodeParam(value: Array | None = None)

Bases: Param

class Cache(params: Dict[str, Array] = None)

Bases: ParamDict, ParamCache

pcax.predictive_coding.ce_energy(vode, rkg: RandomKeyGenerator = RKG)

Cross entropy energy function derived from a categorical distribution.

pcax.predictive_coding.se_energy(vode, rkg: RandomKeyGenerator = RKG)

Squared error energy function derived from a Gaussian distribution.

pcax.predictive_coding.zero_energy(vode, rkg: RandomKeyGenerator = RKG)

Zero energy function, used to unconstrain the value of a Vode from its prior distribution (i.e., its input).
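The three energy functions can be sketched for a single sample as plain functions of the value ‘h’ and the incoming prediction ‘u’. The exact pcax definitions are not reproduced here: the 1/2 factor in the squared error, and the roles in the cross entropy (‘h’ as target probabilities, ‘u’ as logits), are assumptions, and the real functions take a Vode and a RandomKeyGenerator rather than two lists.

```python
import math
from typing import List


def se_energy(h: List[float], u: List[float]) -> float:
    """Squared error; Gaussian negative log-likelihood up to constants (1/2 factor assumed)."""
    return 0.5 * sum((hi - ui) ** 2 for hi, ui in zip(h, u))


def ce_energy(h: List[float], u: List[float]) -> float:
    """Cross entropy between target probabilities h and softmax(u) (assumed roles)."""
    m = max(u)  # subtract the max for a numerically stable log-sum-exp
    log_z = m + math.log(sum(math.exp(x - m) for x in u))
    return -sum(hi * (ui - log_z) for hi, ui in zip(h, u))


def zero_energy(h: List[float], u: List[float]) -> float:
    """Always zero, so the value h is not pulled towards its prior u."""
    return 0.0


h, u = [1.0, 0.0], [0.0, 0.0]
print(se_energy(h, u), zero_energy(h, u))  # 0.5 0.0
print(round(ce_energy(h, u), 4))           # 0.6931 (log 2)
```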