Source code for pcax.nn._parameter

__all__ = [
    "LayerParam",
    "LayerState",
]


from typing import Optional
import jax

from ..core._parameter import Param


########################################################################################################################
#
# PARAMETER
#
# We introduce different types of parameters to be used in the layers. This allow the user to distinguish between them.
#
########################################################################################################################


# Core #################################################################################################################


[docs] class LayerParam(Param): def __init__( self, value: Optional[jax.Array] = None ): super().__init__(value)
[docs] class LayerState(Param): def __init__( self, value: Optional[jax.Array] = None ): super().__init__(value)