Source code for pcax.nn._layer

__all__ = [
    "Layer",
    "Linear",
    "LayerNorm",
    "Conv",
    "Conv2d",
    "MaxPool2d",
    "AvgPool2d",
]


from typing import Tuple, Sequence

import jax.tree_util as jtu
import equinox as eqx

from ..core._module import Module
from ..core._random import RandomKeyGenerator, RKG
from ..core._parameter import BaseParam
from ..core._static import StaticParam
from ._parameter import LayerParam


########################################################################################################################
#
# LAYER
#
# pcax layers are a thin wrapper around equinox layers that replaces all jax.Arrays with LayerParam instances.
# In this file only stateless layers are implemented as they don't need any particular ad-hoc adaptation.
########################################################################################################################


# Core #################################################################################################################


[docs] class Layer(Module): def __init__( self, cls, *args, filter=eqx._filters.is_array, **kwargs, ): super().__init__() self.nn = jtu.tree_map( lambda w: LayerParam(w) if filter(w) else StaticParam(w), cls(*args, **kwargs), ) def __call__(self, *args, key=None, **kwargs): # Can do this, since nn is stateless _nn = jtu.tree_map( lambda w: w.get() if isinstance(w, BaseParam) else w, self.nn, is_leaf=lambda w: isinstance(w, BaseParam), ) return _nn(*args, **kwargs, key=key)
# Common Layers ########################################################################################################
[docs] class Linear(Layer): def __init__(self, in_features: int, out_features: int, bias: bool = True, rkg: RandomKeyGenerator = RKG): super().__init__(eqx.nn.Linear, in_features, out_features, bias, key=rkg())
[docs] class LayerNorm(Layer): def __init__( self, shape: Tuple[int, ...] | None = None, eps: float = 1e-05, elementwise_affine: bool = True, ): super().__init__(eqx.nn.LayerNorm, shape, eps, elementwise_affine)
[docs] class Conv(Layer): def __init__( self, num_spatial_dims: int, in_channels: int, out_channels: int, kernel_size: int | Sequence[int], stride: int | Sequence[int] = 1, padding: int | Sequence[int] | Sequence[Tuple[int, int]] = 0, dilation: int | Sequence[int] = 1, groups: int = 1, use_bias: bool = True, rkg: RandomKeyGenerator = RKG, ): super().__init__( eqx.nn.Conv, num_spatial_dims, in_channels, out_channels, kernel_size, stride, padding, dilation, groups, use_bias, key=rkg(), )
[docs] class Conv2d(Conv): def __init__( self, in_channels: int, out_channels: int, kernel_size: int | Sequence[int], stride: int | Sequence[int] = 1, padding: int | Sequence[int] | Sequence[Tuple[int, int]] = 0, dilation: int | Sequence[int] = 1, groups: int = 1, use_bias: bool = True, rkg: RandomKeyGenerator = RKG, ): super().__init__(2, in_channels, out_channels, kernel_size, stride, padding, dilation, groups, use_bias, rkg)
# Pooling ##############################################################################################################
[docs] class MaxPool2d(Layer): def __init__( self, kernel_size: int | Sequence[int], stride: int | Sequence[int] = 1, padding: int | Sequence[int] | Sequence[Tuple[int, int]] = 0, use_ceil: bool = False, **kwargs, ): super().__init__(eqx.nn.MaxPool2d, kernel_size, stride, padding, use_ceil, **kwargs)
[docs] class AvgPool2d(Layer): def __init__( self, kernel_size: int | Sequence[int], stride: int | Sequence[int] = 1, padding: int | Sequence[int] | Sequence[Tuple[int, int]] = 0, use_ceil: bool = False, **kwargs, ): super().__init__(eqx.nn.AvgPool2d, kernel_size, stride, padding, use_ceil, **kwargs)