pcax.nn package

Module contents

class pcax.nn.AvgPool2d(kernel_size: int | Sequence[int], stride: int | Sequence[int] = 1, padding: int | Sequence[int] | Sequence[Tuple[int, int]] = 0, use_ceil: bool = False, **kwargs)

Bases: Layer

class pcax.nn.Conv(num_spatial_dims: int, in_channels: int, out_channels: int, kernel_size: int | Sequence[int], stride: int | Sequence[int] = 1, padding: int | Sequence[int] | Sequence[Tuple[int, int]] = 0, dilation: int | Sequence[int] = 1, groups: int = 1, use_bias: bool = True, rkg: RandomKeyGenerator = <RandomKeyGenerator instance>)

Bases: Layer

class pcax.nn.Conv2d(in_channels: int, out_channels: int, kernel_size: int | Sequence[int], stride: int | Sequence[int] = 1, padding: int | Sequence[int] | Sequence[Tuple[int, int]] = 0, dilation: int | Sequence[int] = 1, groups: int = 1, use_bias: bool = True, rkg: RandomKeyGenerator = <RandomKeyGenerator instance>)

Bases: Conv

class pcax.nn.Layer(cls, *args, filter=<function is_array>, **kwargs)

Bases: Module

class pcax.nn.LayerNorm(shape: Tuple[int, ...] | None = None, eps: float = 1e-05, elementwise_affine: bool = True)

Bases: Layer

class pcax.nn.LayerParam(value: Array | None = None)

Bases: Param

class pcax.nn.LayerState(value: Array | None = None)

Bases: Param

class pcax.nn.Linear(in_features: int, out_features: int, bias: bool = True, rkg: RandomKeyGenerator = <RandomKeyGenerator instance>)

Bases: Layer

class pcax.nn.MaxPool2d(kernel_size: int | Sequence[int], stride: int | Sequence[int] = 1, padding: int | Sequence[int] | Sequence[Tuple[int, int]] = 0, use_ceil: bool = False, **kwargs)

Bases: Layer

pcax.nn.shared(module: PyTree, filter: Callable[[Any], bool] | Type[BaseParam] = BaseParam) → PyTree

Creates a copy of the input pytree that shares all target parameters (those selected by `filter`) with the original. This can be used to create modules with weight sharing:

```python
from pcax.nn import Linear, shared

linear1 = Linear(10, 10)
linear2 = shared(linear1)  # new module whose parameters are those of linear1
```

Sharing can also be set up manually, which is useful when only a subset of the parameters should be shared:

```python
from pcax.nn import Linear

linear1 = Linear(10, 10)
linear2 = Linear(10, 10)

# Share only the weight; each layer keeps its own bias.
linear2.nn.weight = linear1.nn.weight
```
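A minimal pure-Python sketch of what manual sharing achieves follows. `ToyParam` and `ToyLinear` are hypothetical stand-ins, not pcax classes: the sketch only assumes that a parameter is a mutable object wrapping a value, and it stores the parameters directly on the layer rather than under `.nn`. Because both layers end up referencing the same weight object, an update made through one layer is visible through the other, while the biases stay independent.

```python
class ToyParam:
    """Hypothetical stand-in for a parameter: a mutable box around a value."""

    def __init__(self, value):
        self.value = value


class ToyLinear:
    """Hypothetical stand-in for a linear layer with a weight and a bias."""

    def __init__(self, weight, bias):
        self.weight = ToyParam(weight)
        self.bias = ToyParam(bias)


linear1 = ToyLinear(weight=1.0, bias=0.0)
linear2 = ToyLinear(weight=2.0, bias=0.5)

# Share only the weight: both layers now reference the same parameter object.
linear2.weight = linear1.weight

# Updating the shared parameter through one layer is visible through the other...
linear1.weight.value = 3.0
print(linear2.weight.value)  # 3.0

# ...while the bias parameters remain independent.
linear1.bias.value = -1.0
print(linear2.bias.value)  # 0.5
```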

NOTE #1: JAX works exclusively with pytrees, so a pytree must not contain duplicate references to the same object. pcax partially works around this by allowing multiple references to the same parameter. It is therefore not possible to share a whole module; instead, create a new module and share the required parameters with it. The following is not allowed:

```python
linear1 = Linear(10, 10)
linear2 = linear1  # WRONG if used in a pytree.
```
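The problem with a repeated module reference can be reproduced without JAX. The sketch below uses a hypothetical toy `flatten`/`unflatten` pair over nested dicts that, like JAX's pytree utilities, visits every occurrence of a node separately and builds fresh containers on unflatten. After a leaf-wise update (the pattern used by optimizer steps and `jax.tree_util.tree_map`), the two references become two independent layers whose weights can diverge. The per-occurrence gradients are made-up numbers.

```python
def flatten(tree):
    """Toy flatten: dicts are nodes, everything else is a leaf.

    Returns the leaves in traversal order and a 'treedef' (the tree with
    each leaf replaced by None) that records the structure.
    """
    if isinstance(tree, dict):
        leaves, treedef = [], {}
        for key, child in tree.items():
            child_leaves, child_def = flatten(child)
            leaves.extend(child_leaves)
            treedef[key] = child_def
        return leaves, treedef
    return [tree], None


def unflatten(treedef, leaves):
    """Toy unflatten: builds fresh dicts, consuming the leaves in order."""
    it = iter(leaves)

    def build(node):
        if isinstance(node, dict):
            return {key: build(child) for key, child in node.items()}
        return next(it)

    return build(treedef)


layer = {"weight": 1.0}
model = {"encoder": layer, "decoder": layer}  # same object referenced twice

leaves, treedef = flatten(model)
print(leaves)  # [1.0, 1.0]: the "shared" weight is visited twice

grads = [0.1, 0.2]  # hypothetical gradients, one per occurrence
updated = unflatten(treedef, [w - g for w, g in zip(leaves, grads)])
print(updated["encoder"] is updated["decoder"])  # False: two separate layers now
print(updated["encoder"]["weight"] != updated["decoder"]["weight"])  # True: weights diverged
```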

NOTE #2: flattening and then unflattening a pytree creates a copy of the tree that still references the same leaves. By changing what is considered a leaf, we can control what is copied and what is shared:

  • if `filter=BaseParam`, all parameters are shared.

  • if `filter=DynamicParam`, only dynamic parameters are shared; static parameters are not treated as leaves, so they are flattened and thus copied.
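The mechanism behind `filter` can be sketched in plain Python. In this hypothetical toy version (not pcax's implementation), modules are dicts, parameters are nodes wrapping a value, and `filter` decides which parameters count as leaves. Leaves pass through `unflatten` unchanged, so they stay shared; every other node is rebuilt, so it is copied. `ToyBaseParam`, `ToyDynamicParam`, `ToyStaticParam`, and `toy_shared` are illustrative names that only mirror the roles of `BaseParam`, `DynamicParam`, static parameters, and `shared`.

```python
class ToyBaseParam:
    """Hypothetical parameter base class: a node wrapping a single value."""

    def __init__(self, value):
        self.value = value


class ToyDynamicParam(ToyBaseParam):
    """Hypothetical dynamic (e.g. trainable) parameter."""


class ToyStaticParam(ToyBaseParam):
    """Hypothetical static parameter."""


def flatten(tree, is_leaf):
    """Toy flatten: returns the leaves and a function that rebuilds the tree."""
    if is_leaf(tree):
        return [tree], lambda it: next(it)  # selected leaf: passed through as-is
    if isinstance(tree, dict):
        parts = {k: flatten(v, is_leaf) for k, v in tree.items()}
        leaves = [leaf for child_leaves, _ in parts.values() for leaf in child_leaves]
        return leaves, lambda it: {k: rebuild(it) for k, (_, rebuild) in parts.items()}
    if isinstance(tree, ToyBaseParam):
        # A parameter that is not a leaf is a node: recurse into its value and
        # build a new parameter object of the same type on unflatten.
        leaves, rebuild = flatten(tree.value, is_leaf)
        param_type = type(tree)
        return leaves, lambda it: param_type(rebuild(it))
    return [tree], lambda it: next(it)  # any other value is a plain leaf


def unflatten(treedef, leaves):
    """Toy unflatten: feeds the leaves, in order, to the rebuild function."""
    return treedef(iter(leaves))


def toy_shared(module, filter=ToyBaseParam):
    """Copy `module`, sharing every parameter selected by `filter`."""
    is_leaf = (lambda x: isinstance(x, filter)) if isinstance(filter, type) else filter
    leaves, treedef = flatten(module, is_leaf)
    return unflatten(treedef, leaves)


linear1 = {"weight": ToyDynamicParam(1.0), "scale": ToyStaticParam(2.0)}

# filter=ToyBaseParam: every parameter is a leaf, so every parameter is shared.
linear2 = toy_shared(linear1, filter=ToyBaseParam)
print(linear2["weight"] is linear1["weight"], linear2["scale"] is linear1["scale"])  # True True

# filter=ToyDynamicParam: the static parameter is flattened and rebuilt, i.e. copied.
linear3 = toy_shared(linear1, filter=ToyDynamicParam)
print(linear3["weight"] is linear1["weight"], linear3["scale"] is linear1["scale"])  # True False
```

In both cases the module container itself is a new object; only the choice of leaves decides which parameters survive the round trip by reference.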

Parameters:
  • module (PyTree) – input pytree to copy.

  • filter (Callable[[Any], bool] | Type[BaseParam], optional) – filter function or type identifying the parameters to share. Defaults to BaseParam, i.e. all parameters are shared.

Returns:

copy of the input pytree with shared parameters.

Return type:

PyTree