pcax.nn package
Module contents
- class pcax.nn.AvgPool2d(kernel_size: int | Sequence[int], stride: int | Sequence[int] = 1, padding: int | Sequence[int] | Sequence[Tuple[int, int]] = 0, use_ceil: bool = False, **kwargs)
Bases: Layer
- class pcax.nn.Conv(num_spatial_dims: int, in_channels: int, out_channels: int, kernel_size: int | Sequence[int], stride: int | Sequence[int] = 1, padding: int | Sequence[int] | Sequence[Tuple[int, int]] = 0, dilation: int | Sequence[int] = 1, groups: int = 1, use_bias: bool = True, rkg: pcax.core._random.RandomKeyGenerator = (RandomKeyGenerator): .key: RKGState([2], uint32))
Bases: Layer
- class pcax.nn.Conv2d(in_channels: int, out_channels: int, kernel_size: int | Sequence[int], stride: int | Sequence[int] = 1, padding: int | Sequence[int] | Sequence[Tuple[int, int]] = 0, dilation: int | Sequence[int] = 1, groups: int = 1, use_bias: bool = True, rkg: pcax.core._random.RandomKeyGenerator = (RandomKeyGenerator): .key: RKGState([2], uint32))
Bases: Conv
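The way `kernel_size`, `stride`, `padding` and `dilation` interact is easiest to see through the output size they produce. The helper below is a minimal sketch of the standard convolution arithmetic along one spatial axis. `conv_output_size` is a hypothetical name, not part of pcax, and the sketch assumes pcax follows the usual floor-based convention, which the signatures above do not confirm.

```python
def conv_output_size(size, kernel_size, stride=1, padding=(0, 0), dilation=1):
    """Output length along one spatial axis (standard convolution arithmetic).

    Hypothetical helper for illustration; not a pcax function.
    """
    pad_lo, pad_hi = padding
    # A dilated kernel spans dilation * (kernel_size - 1) + 1 input positions.
    effective_kernel = dilation * (kernel_size - 1) + 1
    padded = size + pad_lo + pad_hi
    if padded < effective_kernel:
        raise ValueError("kernel does not fit in the padded input")
    # Number of positions the kernel can take when moving `stride` at a time.
    return (padded - effective_kernel) // stride + 1


# 3-wide kernel with padding 1 on both sides keeps the size at stride 1.
print(conv_output_size(32, 3, stride=1, padding=(1, 1)))  # 32
# The same kernel at stride 2 halves the size.
print(conv_output_size(32, 3, stride=2, padding=(1, 1)))  # 16
# Dilation 2 widens the kernel's footprint from 3 to 5 positions.
print(conv_output_size(32, 3, dilation=2))  # 28
```

For a 2D layer the same calculation would be applied independently to the height and width axes.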
- class pcax.nn.LayerNorm(shape: Tuple[int, ...] | None = None, eps: float = 1e-05, elementwise_affine: bool = True)
Bases: Layer
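To make the roles of `eps` and `elementwise_affine` concrete, here is a minimal pure-Python sketch of standard layer normalization over a flat list of values. It is an illustration of the general technique, not pcax's implementation. The function name `layer_norm` and its `weight` and `bias` arguments are assumptions introduced for this example.

```python
import math


def layer_norm(x, eps=1e-5, weight=None, bias=None):
    """Normalize `x` to zero mean and (approximately) unit variance.

    Hypothetical helper for illustration; not a pcax function.
    """
    n = len(x)
    mean = sum(x) / n
    # Biased variance (divide by n), as is usual for layer normalization.
    var = sum((v - mean) ** 2 for v in x) / n
    # eps keeps the division stable when the variance is close to zero.
    inv_std = 1.0 / math.sqrt(var + eps)
    y = [(v - mean) * inv_std for v in x]
    # Optional element-wise affine transform (scale, then shift).
    if weight is not None:
        y = [w * v for w, v in zip(weight, y)]
    if bias is not None:
        y = [v + b for v, b in zip(y, bias)]
    return y


normalized = layer_norm([1.0, 2.0, 3.0, 4.0])
print(normalized)
```

With `weight` set to all ones and `bias` to all zeros, the affine step leaves the normalized values unchanged, which is the conventional starting point for learned scale and shift parameters.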
- class pcax.nn.Linear(in_features: int, out_features: int, bias: bool = True, rkg: pcax.core._random.RandomKeyGenerator = (RandomKeyGenerator): .key: RKGState([2], uint32))
Bases: Layer
- class pcax.nn.MaxPool2d(kernel_size: int | Sequence[int], stride: int | Sequence[int] = 1, padding: int | Sequence[int] | Sequence[Tuple[int, int]] = 0, use_ceil: bool = False, **kwargs)
Bases: Layer
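The `use_ceil` flag of the pooling layers is not described above. A common meaning in other libraries is that the number of output windows is rounded up instead of down, so a trailing partial window is kept. The one-dimensional sketch below illustrates that interpretation only; whether pcax behaves exactly this way is an assumption. `max_pool_1d` is a hypothetical helper, and padding is left out for brevity.

```python
import math


def max_pool_1d(xs, kernel_size, stride=1, use_ceil=False):
    """1D max pooling over a list, without padding.

    Hypothetical helper for illustration; not a pcax function.
    """
    span = len(xs) - kernel_size
    if span < 0:
        raise ValueError("kernel_size is larger than the input")
    # floor drops a trailing partial window, ceil keeps it.
    rounding = math.ceil if use_ceil else math.floor
    n_windows = rounding(span / stride) + 1
    out = []
    for i in range(n_windows):
        # Under ceil rounding the last slice may be shorter than kernel_size.
        window = xs[i * stride : i * stride + kernel_size]
        if not window:
            break  # window would start past the end of the input
        out.append(max(window))
    return out


xs = [1, 5, 2, 4, 3, 9, 7]
print(max_pool_1d(xs, kernel_size=2, stride=2))                 # [5, 4, 9]
print(max_pool_1d(xs, kernel_size=2, stride=2, use_ceil=True))  # [5, 4, 9, 7]
```

Average pooling would follow the same windowing and replace `max(window)` with the mean of the window.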
Creates a copy of the input pytree that shares all the target parameters with the original. It can be used to create modules with weight sharing:
```python
linear1 = Linear(10, 10)
linear2 = shared(linear1)
```
The same can be achieved manually if only a subset of the parameters is to be shared:
```python
linear1 = Linear(10, 10)
linear2 = Linear(10, 10)
linear2.nn.weight = linear1.nn.weight
```
NOTE #1: JAX works exclusively with pytrees, so duplicate references to the same object within a pytree are not allowed. pcax partially solves the problem by allowing multiple references to the same parameter. Thus, it is not possible to share a whole module; instead, create a new module and share the required parameters with it. The following is not allowed:
```python
linear1 = Linear(10, 10)
linear2 = linear1  # WRONG if used in a pytree.
```
NOTE #2: flatten/unflatten creates a copy of the tree which, however, references the same leaves. By changing what is considered a leaf, we can control what is copied and what is shared:
- if `filter = BaseParam`, all parameters are shared.
- if `filter = DynamicParam`, only dynamic parameters are shared; static parameters are flattened and thus copied.
- Parameters:
module (PyTree) – input pytree to copy.
filter (Callable[[Any], bool] | Type[BaseParam], optional) – filter function or type identifying the parameters to share.
- Returns:
copy of the input pytree with shared parameters.
- Return type:
PyTree
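The leaf-based sharing described in NOTE #2 can be modelled without JAX. The sketch below is a toy illustration under stated assumptions: `Param`, `DynamicParam` and `StaticParam` are hypothetical stand-ins for pcax's parameter types, and `copy_tree` is a hand-rolled equivalent of flattening a tree down to its leaves and rebuilding it. It is not how pcax implements this function.

```python
class Param:
    """Toy stand-in for a parameter container (hypothetical, not pcax's class)."""

    def __init__(self, value):
        self.value = value


class DynamicParam(Param):
    pass


class StaticParam(Param):
    pass


def copy_tree(tree, is_leaf):
    """Rebuild `tree`, keeping every node matched by `is_leaf` by reference."""
    if is_leaf(tree):
        return tree  # leaf: the very same object ends up in the copy (shared)
    if isinstance(tree, dict):
        return {k: copy_tree(v, is_leaf) for k, v in tree.items()}
    if isinstance(tree, Param):
        # Not a leaf under this filter: the container is rebuilt, i.e. copied.
        return type(tree)(tree.value)
    return tree


module = {"weight": DynamicParam([1.0, 2.0]), "shape": StaticParam((2,))}

# Filter on the base type: every parameter is a leaf, so all are shared.
all_shared = copy_tree(module, lambda n: isinstance(n, Param))
print(all_shared["weight"] is module["weight"])  # True
print(all_shared["shape"] is module["shape"])    # True

# Filter on the dynamic type: only dynamic parameters are shared.
dyn_shared = copy_tree(module, lambda n: isinstance(n, DynamicParam))
print(dyn_shared["weight"] is module["weight"])  # True
print(dyn_shared["shape"] is module["shape"])    # False
```

In both cases the outer container is a new object, which is why the result can sit next to the original in a larger tree while still pointing at the same parameter objects.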