# Source code for pcax.nn._shared

__all__ = ['shared']

from typing import Type, Callable, Any
from jaxtyping import PyTree
from types import UnionType

import jax.tree_util as jtu

from ..core._parameter import BaseParam
from ..core._tree import tree_ref, tree_unref


####################################################################################################
#
# SHARED
#
# Utility to simplify parameter sharing between modules.
#
####################################################################################################


def shared(
    module: PyTree,
    filter: Callable[[Any], bool] | Type[BaseParam] | UnionType | tuple[type, ...] = BaseParam,
) -> PyTree:
    """Creates a copy of the input pytree which shares all the target parameters with the original.

    It can be used to create modules with weight sharing:

    ```python
    linear1 = Linear(10, 10)
    linear2 = shared(linear1)
    ```

    The same can be achieved manually if only a subset of the parameters is to be shared:

    ```python
    linear1 = Linear(10, 10)
    linear2 = Linear(10, 10)
    linear2.nn.weight = linear1.nn.weight
    ```

    NOTE #1: jax works exclusively with pytrees, so duplicate references to the same object within
    a pytree are not allowed. pcax partially solves the problem by allowing multiple references to
    the same parameter. Thus, it is not possible to share a whole module, but only to create a new
    one and share the required parameters to it. The following is not allowed:

    ```python
    linear1 = Linear(10, 10)
    linear2 = linear1  # WRONG if used in a pytree.
    ```

    NOTE #2: flatten/unflatten creates a copy of the tree which, however, references the same
    leaves. By changing what is considered a leaf, we can control what is copied and what is
    shared:

    - if 'filter = BaseParam', all parameters are shared.
    - if 'filter = DynamicParam', only dynamic parameters are shared, static parameters are
      flattened and thus copied.

    Args:
        module (PyTree): input pytree to copy.
        filter (Callable[[Any], bool] | Type[BaseParam] | UnionType | tuple[type, ...], optional):
            identifies the parameters to share. A type, a union of types (``A | B``) or a tuple of
            types is matched with ``isinstance``; any other value is used directly as the
            ``is_leaf`` predicate. Defaults to ``BaseParam``.

    Returns:
        PyTree: copy of the input pytree with shared parameters.
    """
    # Normalise 'filter' into a predicate: anything 'isinstance' accepts as class info (a type, a
    # 'A | B' union or a tuple of types) is wrapped; everything else is assumed to already be a
    # predicate and is handed to jax untouched.
    if isinstance(filter, (type, UnionType, tuple)):
        _classinfo = filter

        def _is_shared(x: Any) -> bool:
            return isinstance(x, _classinfo)
    else:
        _is_shared = filter

    # we ref the tree to preserve its structure even in the new copy (if any parameter is copied and
    # not shared, any reference to it would get duplicated in the new tree).
    _leaves, _structure = jtu.tree_flatten(tree_ref(module), is_leaf=_is_shared)

    return tree_unref(jtu.tree_unflatten(_structure, _leaves))