Tutorial #0: Predictive Coding Networks (PCNs)
In this notebook we will see how to create and train a simple PCN to classify the two moons dataset.
[ ]:
from typing import Callable
# These are the default import names used in tutorials and documentation.
import jax
import jax.tree_util as jtu
import jax.numpy as jnp
import equinox as eqx
import pcax as px
import pcax.predictive_coding as pxc
import pcax.nn as pxnn
import pcax.functional as pxf
import pcax.utils as pxu
# px.RKG is the default key generator of pcax, i.e., the default source of
# randomness within the library. By default it is initialised with the system time;
# here we set its seed to 0 to make runs more reproducible.
px.RKG.seed(0)
[ ]:
# We create our model, which inherits from pxc.EnergyModule so as to have access to the notion
# of energy. The constructor takes as input all the hyperparameters of the model. Being static
# values, if we intend to save them within the model we must wrap them into a 'StaticParam'.
class Model(pxc.EnergyModule):
    """Fully connected predictive coding network for classification.

    The model pairs each of its `nm_layers` linear layers with a vode: the hidden
    vodes are created without an explicit energy, while the last one uses
    `pxc.ce_energy`.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        nm_layers: int,
        act_fn: Callable[[jax.Array], jax.Array]
    ) -> None:
        """Build the layers and the vodes of the network.

        Args:
            input_dim: size of the input features.
            hidden_dim: size of every hidden layer.
            output_dim: size of the output layer.
            nm_layers: total number of linear layers; must be at least 2
                (one input layer and one output layer).
            act_fn: activation applied after every layer except the last one.

        Raises:
            ValueError: if `nm_layers` is smaller than 2.
        """
        # An input and an output layer are always created below, so fewer than two
        # layers would leave layers and vodes misaligned (the first layer would never
        # be used in '__call__').
        if nm_layers < 2:
            raise ValueError(f"nm_layers must be at least 2, got {nm_layers}")

        super().__init__()

        # 'act_fn' is a static value, so it is wrapped into a StaticParam.
        self.act_fn = px.static(act_fn)

        self.layers = [pxnn.Linear(input_dim, hidden_dim)] + [
            pxnn.Linear(hidden_dim, hidden_dim) for _ in range(nm_layers - 2)
        ] + [pxnn.Linear(hidden_dim, output_dim)]

        # the default ruleset for a Vode is: `{"STATUS.INIT": ("h, u <- u",),}` which means:
        # "if the status is set to 'STATUS.INIT', every time I set 'u', save that value not only
        # in 'u', but also in 'h'", which is exactly the behaviour of a forward pass.
        # By default if not specified, the behaviour is '* <- *', i.e., save everything passed
        # to the vode via __call__ (remember vode(a) equals to vode.set("u", a)).
        #
        # Since we are doing classification, we replace the last energy with the equivalent of
        # cross entropy loss for predictive coding.
        self.vodes = [
            pxc.Vode((hidden_dim,)) for _ in range(nm_layers - 1)
        ] + [pxc.Vode((output_dim,), pxc.ce_energy)]

        # 'frozen' is not a magic word, we define it here and use it later to distinguish between
        # vodes we want to differentiate or not.
        # NOTE: any attribute of a Param (except its value) is treated automatically as static,
        #       no need to specify it (but it's possible if you like more consistency,
        #       i.e., ...frozen = px.static(True)).
        self.vodes[-1].h.frozen = True

    def __call__(self, x, y):
        """Run a forward pass on a single sample.

        Args:
            x: input of the network.
            y: target label, or None when no target is available (e.g., during evaluation).

        Returns:
            The 'u' value of the last vode, i.e., the output of the model.
        """
        for v, l in zip(self.vodes[:-1], self.layers[:-1]):
            # remember 'x = v(a)' corresponds to v.set("u", a) followed by reading the vode
            # state back (named 'h' in the default ruleset above).
            #
            # note that 'self.act_fn' is a StaticParam, so to access it we would have to do
            # self.act_fn.get()(...), however, all standard methods such as __call__ and
            # __getitem__ are overloaded such that 'self.act_fn.__***__' becomes
            # 'self.act_fn.get().__***__'
            x = v(self.act_fn(l(x)))

        # no activation function on the output layer
        x = self.vodes[-1](self.layers[-1](x))

        if y is not None:
            # if the target label is provided (e.g., during training), we save it to the last
            # vode. Given that we 'froze' it, its value will not be updated during inference,
            # so we need to fix it only once for each new sample, usually during the init step.
            self.vodes[-1].set("h", y)

        # at least with this architecture, the input activation of the last vode is the actual
        # output of the model ('h' is fixed to the label during training or 'h = u' during eval)
        return self.vodes[-1].get("u")
[ ]:
# vmap is used to specify the batch dimension of the input data. Remember jax doesn't handle it
# implicitly but relies on the user to explicitly tell it over which dimension to parallelise the
# computation. That is, we always define a computational graph on a single sample, and then batch
# the computation over the given mini-batch. We use the jax syntax for in_axes, out_axes, axis_name,
# and then introduce a new parameter, kwargs_mask, to specify the batch information over the kwargs
# (which, just as a reminder, have the property of being automatically tracked by pcax).
# pxu.utils.mask has an in-depth explanation about how masking works. Here, we simply use the Mask
# object, which, in this case, replaces every parameter that matches any of the given types with '0',
# meaning that their value is batched over the 0th dimension (which is the case for the vode values
# and caches), and the non-matching ones with 'None' (such as the weights, which are shared across
# different samples).
# Both positional input arguments and output are batched over the 0th dimension, so we specify it.
@pxf.vmap(
    pxu.Mask(pxc.VodeParam | pxc.VodeParam.Cache, (None, 0)),
    in_axes=(0, 0),
    out_axes=0,
)
def forward(x, y, *, model: Model):
    """Call the model on one sample; the vmap decorator batches it over axis 0.

    Args:
        x: model input (batched over axis 0 by the decorator).
        y: target passed through to the model (batched over axis 0), or None.
        model: the model to call, passed as a keyword argument.

    Returns:
        Whatever `model(x, y)` returns, batched over axis 0.
    """
    prediction = model(x, y)
    return prediction
# Similarly here, we specify 'out_axes=(None, 0)' since the function returns two values, the first
# a single float storing the total energy of the model (not batched, but averaged over the batch
# dimension via 'pmean'; this is a requirement of the gradient transformation, since jax requires
# the differentiated function to return a single scalar). To follow on this, 'axis_name' is specified
# so that we can reduce over the batch dimension as required (this is standard jax syntax).
@pxf.vmap(
    pxu.Mask(pxc.VodeParam | pxc.VodeParam.Cache, (None, 0)),
    in_axes=(0,),
    out_axes=(None, 0),
    axis_name="batch",
)
def energy(x, *, model: Model):
    """Compute the energy of the model for one sample, reduced over the batch.

    Args:
        x: model input (batched over axis 0 by the decorator).
        model: the model to evaluate, passed as a keyword argument.

    Returns:
        A pair `(energy, prediction)`: the summed model energy averaged over the
        "batch" axis (not batched, see `out_axes`), and the model output (batched).
    """
    # The model is called without a target: only the input is provided here.
    prediction = model(x, None)
    # Energy of this sample, then mean across the mapped "batch" axis.
    sample_energy = model.energy().sum()
    return jax.lax.pmean(sample_energy, "batch"), prediction
[ ]:
# JIT is Just In Time compilation, which effectively compiles our code for fast CPU/GPU execution,
# removing all python overhead.
# 'T' is a hyperparameter that determines the number of inference steps (and therefore the computational flow).
# As such, it must be a static value. We can either specify it using 'static_argnums' (which however is only available
# when using 'jit'), or pass it as a static parameter, in which case we would call 'train_on_batch(px.static(T), ...)'.
#
# Remember that pcax distinguishes between positional and keyword arguments, tracking only the parameters in the latter.
# Since we don't care about tracking 'x' and 'y', we pass them as simple jax.Arrays as positional arguments. On the
# other hand, both the model and the optimizers may have parameters that are going to change and that we want to track,
# so we pass them as keyword arguments.
@pxf.jit(static_argnums=0)
def train_on_batch(
    T: int,
    x: jax.Array,
    y: jax.Array,
    *,
    model: Model,
    optim_w: pxu.Optim,
    optim_h: pxu.Optim
):
    """Train the model on one mini-batch: init pass, `T` inference steps, one weight update.

    Args:
        T: number of inference steps; static under jit (`static_argnums=0`).
        x: input batch.
        y: target batch, handed to the model during the init pass.
        model: the model to train.
        optim_w: optimizer stepped with the gradients of the layer parameters.
        optim_h: optimizer stepped with the gradients of the non-frozen vode parameters.
    """
    # This will come in handy later: it shows when the python body is actually executed.
    print("Training!")

    # This only sets an internal flag to be "train" (instead of "eval").
    model.train()

    # 'pxu.step' is a utility that does two things:
    #   - sets the status to the provided one (default is 'None')
    #     and resets it to 'None' afterwards;
    #   - clears the target parameters if 'clear_params' is specified
    #     (normally we want to clear the vode cache, such as activation and energy,
    #     after each step).
    #
    # Init step: pxc.STATUS.INIT triggers the only default vode ruleset.
    with pxu.step(model, pxc.STATUS.INIT, clear_params=pxc.VodeParam.Cache):
        forward(x, y, model=model)

    # Inference steps
    for _ in range(T):
        with pxu.step(model, clear_params=pxc.VodeParam.Cache):
            # 'pxu.m' builds a masking object with a few helper methods to compose filters.
            # Here it selects all VodeParams that are not frozen ('frozen' is a purely custom
            # attribute that we, as users, set in the model).
            #
            # As jax expects, the mask distinguishes the parameters to differentiate ('True')
            # from the rest ('False').
            free_vodes = pxu.Mask(pxu.m(pxc.VodeParam).has_not(frozen=True), [False, True])
            vode_grad_fn = pxf.value_and_grad(free_vodes, has_aux=True)(energy)

            # The first element holds the two values returned by 'energy' (unused here).
            # The gradient has the same structure as the function input: since 'argnums' is
            # not used, only the gradient with respect to the keyword arguments is returned,
            # as a dictionary. With positional gradients too, we would get
            # '(positional_grad, keyword_grad)' and the model gradient at index 1, key "model".
            (_, _), vode_grads = vode_grad_fn(x, model=model)

        # NOTE(review): the meaning of the third positional argument ('True') is not
        # visible in this file; check the signature of 'pxu.Optim.step'.
        optim_h.step(model, vode_grads["model"], True)

    # Weight update step
    with pxu.step(model, clear_params=pxc.VodeParam.Cache):
        layer_params = pxu.Mask(pxnn.LayerParam, [False, True])
        weight_grad_fn = pxf.value_and_grad(layer_params, has_aux=True)(energy)
        (_, _), weight_grads = weight_grad_fn(x, model=model)

    optim_w.step(model, weight_grads["model"])
[ ]:
import numpy as np
# Not much to say here: we use a single forward pass to compute the output of the model.
# If we were to use a different initialisation, or a more complex architecture, we would have
# to run inference to converge to some output value.
@pxf.jit()
def eval_on_batch(x: jax.Array, y: jax.Array, *, model: Model):
    """Evaluate the model on one mini-batch with a single forward pass.

    Args:
        x: input batch.
        y: integer class labels, compared against the argmax of the model output.
        model: the model to evaluate.

    Returns:
        A pair `(accuracy, predictions)`: the fraction of correctly classified
        samples in the batch, and the predicted class index of each sample.
    """
    model.eval()

    with pxu.step(model, pxc.STATUS.INIT, clear_params=pxc.VodeParam.Cache):
        outputs = forward(x, None, model=model)
        predicted = outputs.argmax(axis=-1)

    accuracy = (predicted == y).mean()
    return accuracy, predicted
# Standard training loop
def train(dl, T, *, model: Model, optim_w: pxu.Optim, optim_h: pxu.Optim, nm_classes: int = 2):
    """Train the model for one pass over the batches in `dl`.

    Args:
        dl: iterable of `(x, y)` batches, where `y` holds integer class labels.
        T: number of inference steps forwarded to `train_on_batch`.
        model: the model to train.
        optim_w: optimizer forwarded to `train_on_batch`.
        optim_h: optimizer forwarded to `train_on_batch`.
        nm_classes: number of classes used to one-hot encode the labels
            (defaults to 2, the value previously hard-coded).
    """
    for x, y in dl:
        # the labels are one-hot encoded before being handed to the model
        targets = jax.nn.one_hot(y, nm_classes)
        train_on_batch(T, x, targets, model=model, optim_w=optim_w, optim_h=optim_h)
# Standard evaluation loop
def eval(dl, *, model: Model):
    """Evaluate the model on every batch of `dl`.

    NOTE: the name shadows the builtin `eval`; it is kept because callers use it.

    Args:
        dl: iterable of `(x, y)` batches.
        model: the model to evaluate.

    Returns:
        A pair `(accuracy, predictions)`: the mean of the per-batch accuracies and
        the per-batch predictions concatenated into a single array.
    """
    batch_results = [eval_on_batch(x, y, model=model) for x, y in dl]
    accuracies = [batch_acc for batch_acc, _ in batch_results]
    predictions = [batch_pred for _, batch_pred in batch_results]
    return np.mean(accuracies), np.concatenate(predictions)
[ ]:
import optax
# Size of the mini-batches; it is reused below for the dummy forward pass, the learning
# rate of the vode optimizer and the batching of the datasets.
batch_size = 32
# Two input features and two outputs; with nm_layers=3 the model stacks three
# linear layers (2 -> 32 -> 32 -> 2) with leaky ReLU after all but the last one.
model = Model(
    input_dim=2,
    hidden_dim=32,
    output_dim=2,
    nm_layers=3,
    act_fn=jax.nn.leaky_relu
)
[ ]:
# The only thing to note here is how optimizers are created. In particular,
# we first want all the parameters of the model to exist, so that the optimizers
# can capture them for optimization. This requires performing a dummy forward pass.
# Note that the batch_size is a hyperparameter of the model and determines, among
# other things, the shape of the Vode parameters, and thus must be kept as constant
# as possible (each change would trigger recompilation of the jitted functions).
# Dummy forward pass on an all-zero batch, run under the INIT status like a regular
# forward pass, so that the parameters exist before the optimizers are created.
with pxu.step(model, pxc.STATUS.INIT, clear_params=pxc.VodeParam.Cache):
    forward(jnp.zeros((batch_size, 2)), None, model=model)

# 'pxu.Optim' takes an optax optimizer and the parameters pytree as input. pxu.Mask
# can be used to partition between target parameters and the rest: when no 'map_to' is
# provided, such as here, it acts as 'eqx.partition', using the given type as filter.
sgd_h = optax.sgd(1e-2 * batch_size)
optim_h = pxu.Optim(sgd_h, pxu.Mask(pxc.VodeParam)(model))

adamw_w = optax.adamw(1e-2)
optim_w = pxu.Optim(adamw_w, pxu.Mask(pxnn.LayerParam)(model))
[ ]:
from sklearn.datasets import make_moons
import matplotlib.pyplot as plt
# This is unrelated to pcax: we generate and display the training set.
nm_elements = 1024
# The number of samples is rounded down to a multiple of batch_size, so that the
# dataset can later be reshaped into full batches.
X, y = make_moons(n_samples=batch_size * (nm_elements // batch_size), noise=0.2, random_state=42)
# Plot the dataset, coloured by class label.
plt.figure(figsize=(6, 4))
plt.scatter(X[:, 0], X[:, 1], c=y, cmap='viridis', edgecolor='k')
plt.title("Two Moons Dataset")
plt.show()
[ ]:
# We split the dataset into training batches and do the same for the generated test set.
# 'train_dl' is a list because it is shuffled in place at every epoch by the training loop.
train_dl = list(zip(X.reshape(-1, batch_size, 2), y.reshape(-1, batch_size)))
# The test set has half as many samples as the training set and a different random_state.
X_test, y_test = make_moons(n_samples=batch_size * (nm_elements // batch_size) // 2, noise=0.2, random_state=0)
test_dl = tuple(zip(X_test.reshape(-1, batch_size, 2), y_test.reshape(-1, batch_size)))
[ ]:
import random
# The number of epochs is derived from the number of batches per epoch, so that
# training performs roughly 256 batch updates in total.
nm_batches = nm_elements // batch_size
nm_epochs = 256 // nm_batches

# Note how the text "Training!" appears only once. This is because 'train_on_batch' is executed only once,
# and then its compiled equivalent is used instead (which only cares about what happens to jax.Arrays and
# discards all python code).
#
# The vode updates do not depend on the choice of batch_size (feel free to play around with it, but
# remember to reset the notebook if you change it). This is because the learning rate of 'optim_h' is
# multiplied by batch_size: the total energy is averaged over the batch dimension (as required for the
# weight updates), so the learning rate of the vode updates must be scaled accordingly.
for e in range(nm_epochs):
    # visit the training batches in a different order at every epoch
    random.shuffle(train_dl)
    train(train_dl, T=8, model=model, optim_w=optim_w, optim_h=optim_h)

    # NOTE: 'y' is rebound here to the predictions on the test set; from now on it no
    # longer holds the training labels returned by 'make_moons'.
    a, y = eval(test_dl, model=model)

    # Report the test accuracy reached after this epoch.
    print(f"Epoch {e + 1}/{nm_epochs} - Test Accuracy: {a * 100:.2f}%")
[ ]:
# pcax.utils contains a couple of useful functions to save and load the parameters of a model.
# They allow you to define which subset of the parameters to save, and to load them back into the model.
# The default behaviour is to save all the weights (i.e., values contained in 'pxnn.LayerParam') of
# the model and ignore any 'Vode' value.
import os
# We check what is inside the model.
print(model)
# Save the parameters to disk and load them straight back into the same model.
pxu.save_params(model, "model")
pxu.load_params(model, "model")
# Remove the saved model file. The ".npz" suffix suggests 'save_params' appends that
# extension to the given name -- confirm against pcax.utils if the removal fails.
os.remove("model.npz")
[ ]:
# Here we evaluate the whole grid as a single batch with 96^2 elements. If we were to directly
# call 'forward' we would get an error, as the sizes of the batch dimension would not agree:
# 'model' contains VodeParams whose batch size is 32 as previously defined, while the function
# input X_grid has a batch size of 96^2.
# To solve the problem, we first clear all VodeParams (replacing them with None, which
# is ignored by jax) so that jax sees a single size for the batch dimension (i.e., 96^2) and works
# without any problem.
model.clear_params(pxc.VodeParam)

# Test the model on a 96x96 grid of points covering [-2.5, 2.5] x [-2.0, 2.0].
grid_x0 = np.linspace(-2.5, 2.5, 96)
grid_x1 = np.linspace(-2.0, 2.0, 96)
# one row per grid point, one column per coordinate
X_grid = jnp.stack(np.meshgrid(grid_x0, grid_x1)).reshape(2, -1).T

with pxu.step(model, pxc.STATUS.INIT, clear_params=pxc.VodeParam.Cache):
    y_grid = forward(X_grid, None, model=model).argmax(axis=-1)

plt.figure(figsize=(6, 4))
# predicted class over the grid, drawn as a faint background
plt.scatter(
    X_grid[:, 0],
    X_grid[:, 1],
    c=y_grid,
    cmap='viridis',
    s=14,
    marker='o',
    linewidths=0,
    alpha=0.2,
)
# test points; 'y' holds the test-set predictions returned by the last call to 'eval'
# in the training loop
plt.scatter(X_test[:, 0], X_test[:, 1], c=y, cmap='viridis', edgecolor='k')
plt.title("Prediction on Two Moons Dataset")
plt.show()