Source code for pcax.predictive_coding._energy
__all__ = [
"zero_energy",
"se_energy",
"ce_energy"
]
import jax
from ..core._random import RKG, RandomKeyGenerator
########################################################################################################################
#
# Energy
#
# Collection of the most common energy functions used in predictive coding.
#
########################################################################################################################
# Core #################################################################################################################
[docs]
def zero_energy(vode, rkg: RandomKeyGenerator = RKG):
"""used to unconstrain the value of a vode from its prior distribution (i.e., input)."""
return jax.numpy.zeros((1,))
[docs]
def se_energy(vode, rkg: RandomKeyGenerator = RKG):
"""Squared error energy function derived from a Gaussian distribution."""
e = vode.get("h") - vode.get("u")
return 0.5 * (e * e)
[docs]
def ce_energy(vode, rkg: RandomKeyGenerator = RKG):
"""Cross entropy energy function derived from a categorical distribution."""
return -(vode.get("h") * jax.nn.log_softmax(vode.get("u")))