Source code for pcax.core._module

__all__ = [
    "BaseModule",
    "Module"
]


import abc
import functools
from enum import IntEnum
from typing import Any, Tuple, Generator, TypeVar, Type

import jax
import jax.tree_util as jtu
import equinox as eqx

from ._parameter import DynamicParam
from ._static import static
from ._tree import tree_apply


T = TypeVar("T")


########################################################################################################################
#
# MODULE
#
# As in equinox, each Module is a JAX pytree and can be used as a container for other Modules and Parameters.
# Every class that inherits from Module is therefore automatically registered as a JAX pytree whose children are
# the attributes of the class. This allows the class to be flattened and unflattened as if it were a dictionary.
# For example:
#
# ```python
# class Obj(Module):
#    def __init__(self, x, y):
#       self.x = x
#       self.y = y
#
# obj = Obj(1, 2)
# leaves, structure = jtu.tree_flatten(obj)
# print(leaves)  # [1, 2]; structure contains the keys of the attributes "x" and "y"
# ```
########################################################################################################################


# Core #################################################################################################################


class _BaseModuleMeta(abc.ABCMeta):
    """
    Metaclass that registers every class it creates as a JAX pytree node.

    An instance is (un)flattened like a dictionary built from its ``__dict__``: the attribute values are the
    pytree children, and the attribute names (in insertion order) are the static auxiliary data.

    NOTE: as equinox does, it may be necessary to update this code to treat __special__ attributes differently.
    """

    def __new__(mcs, name, bases, dct):
        new_cls = super().__new__(mcs, name, bases, dct)

        # JAX calls the unflatten callback with (aux_data, children) only, so the freshly created class is
        # bound into the callback here to know which type to rebuild.
        rebuild = functools.partial(_BaseModuleMeta.unflatten_module, cls=new_cls)

        jtu.register_pytree_with_keys(
            new_cls,
            flatten_with_keys=_BaseModuleMeta.flatten_module_with_keys,
            unflatten_func=rebuild,
            flatten_func=_BaseModuleMeta.flatten_module,
        )

        return new_cls

    @staticmethod
    def flatten_module(module: 'BaseModule') -> Tuple[Tuple[Any, ...], Tuple[str, ...]]:
        """Split ``module.__dict__`` into ``(values, names)``; the names act as auxiliary data."""
        attrs = module.__dict__

        return tuple(attrs.values()), tuple(attrs)

    @staticmethod
    def flatten_module_with_keys(module: 'BaseModule') -> Tuple[Tuple[Tuple[str, Any], ...], Tuple[str, ...]]:
        """Same as ``flatten_module``, but each value is paired with a ``jtu.GetAttrKey`` for its attribute name."""
        attrs = module.__dict__
        keyed_children = tuple((jtu.GetAttrKey(attr), value) for attr, value in attrs.items())

        return keyed_children, tuple(attrs)

    @staticmethod
    def unflatten_module(
        aux_data: Tuple[str, ...], children: Tuple[Any, ...], cls: Type['BaseModule']
    ) -> 'BaseModule':
        """Rebuild an instance of ``cls`` from attribute names (``aux_data``) and values (``children``)."""
        # The instance is allocated directly, so ``cls.__init__`` is NOT executed: the attributes are restored
        # verbatim from the flattened representation.
        instance = object.__new__(cls)
        instance.__dict__ = {attr: child for attr, child in zip(aux_data, children)}

        return instance
            

class BaseModule(metaclass=_BaseModuleMeta):
    """
    BaseModule is the base class for all modules in the library.
    """

    def __call__(self):
        raise NotImplementedError

    def __repr__(self) -> str:
        """Return a header line with the class name followed by one ``<path>: <repr>`` line per leaf.

        ``DynamicParam`` instances are treated as leaves, so they are printed as a whole instead of being
        expanded into their own content.
        """

        def _stop_at_param(node: Any) -> bool:
            return isinstance(node, DynamicParam)

        rows = [
            f" {jtu.keystr(path)}: {leaf!r}"
            for path, leaf in jtu.tree_leaves_with_path(self, is_leaf=_stop_at_param)
        ]

        return "\n".join([f"({self.__class__.__name__}):", *rows])

    def submodules(self, *, cls: Type[T] | None = None) -> Generator[T, None, None]:
        """Yield the submodules of the given type found among the children of this module.

        The children of this module are flattened with instances of ``cls`` treated as leaves, so the search
        never descends into a matched submodule (i.e., it is not recursive over matches). Children that are
        not instances of ``cls`` but are themselves pytrees are traversed by ``jtu.tree_leaves``.

        Args:
            cls (Type[T] | None, optional): indicates the type of the submodules to select.
                If None, 'BaseModule' is used.

        Yields:
            Generator[T, None, None]: generator of the matched submodules.
        """
        # We target all modules by default.
        target = cls if cls else BaseModule

        def _matches(node: Any) -> bool:
            return isinstance(node, target)

        children, _ = eqx.tree_flatten_one_level(self)

        # The flattened leaves also contain plain (non-module) values, hence the second filtering pass.
        for node in jtu.tree_leaves(children, is_leaf=_matches):
            if _matches(node):
                yield node
# Module ###############################################################################################################
class Module(BaseModule):
    """
    Module represents a standard deep learning module with a train/eval mode flag that can be recursively set.
    """

    class MODE(IntEnum):
        # Values accepted by `Module.mode`.
        NONE = 0
        TRAIN = 1
        EVAL = 2

    def __init__(self) -> None:
        # The flag starts as a `static` wrapper around None (not MODE.NONE), so both `is_train` and `is_eval`
        # are False until a mode is set.
        # NOTE(review): presumably `static` keeps the flag out of the dynamic pytree leaves; confirm in `_static`.
        self._mode = static(None)

    def mode(self, value: MODE | None) -> MODE | None:
        """Get the current mode, or set a new mode on this module and its submodules.

        Args:
            value (MODE | None): mode to set. If None, nothing is changed and the current mode is returned.

        Returns:
            MODE | None: current mode if value is None, otherwise None.
        """
        if value is None:
            return self._mode.get()

        def _is_module(node):
            return isinstance(node, Module)

        def _assign_mode(module):
            return module._mode.set(value)

        # `tree_apply` receives the setter, the `Module` filter and the root to start from (self).
        tree_apply(_assign_mode, _is_module, self)

        return None

    def train(self) -> None:
        """Set the module in train mode."""
        self.mode(Module.MODE.TRAIN)

    def eval(self) -> None:
        """Set the module in eval mode."""
        self.mode(Module.MODE.EVAL)

    @property
    def is_train(self) -> bool:
        """Returns:
            bool: whether the module is in train mode.
        """
        current = self._mode.get()

        return current == Module.MODE.TRAIN

    @property
    def is_eval(self) -> bool:
        """Returns:
            bool: whether the module is in eval mode.
        """
        current = self._mode.get()

        return current == Module.MODE.EVAL