Source code for pcax.utils._serialisation

# Public API of this module: the two model (de)serialisation helpers defined below.
__all__ = ["save_params", "load_params"]

from types import UnionType
from typing import Any, Callable, Type
from jaxtyping import PyTree
import numpy as np

import jax
import jax.tree_util as jtu

from ..core._parameter import BaseParam
from ..core._tree import _cache
from ..nn._parameter import LayerParam


########################################################################################################################
#
# SERIALIZATION
#
# Utilities to save/load a model.
#
########################################################################################################################


def save_params(
    model: PyTree,
    path: str,
    filter: Callable[[Any], bool] | Type[BaseParam] = LayerParam,
) -> None:
    """Function to save the parameters of a model to a file.

    The '.npz' extension is automatically added to the file name if it is not
    already present (this is done by `np.savez_compressed`).

    Each selected parameter is stored under the key `jtu.keystr(path_to_param)`.
    When the same parameter object is reachable through several paths, only its
    first occurrence stores the actual value; later occurrences store a `None`
    placeholder.

    Args:
        model (PyTree): the model to dump to disk.
        path (str): the path to the file where to save the model. If the file already
            exists, it will be overwritten.
        filter (Callable[[Any], bool] | Type[BaseParam], optional): filter function or
            type identifying the parameters to save. The default value 'LayerParam'
            selects all the weights of the layers in the model.

    Raises:
        TypeError: a leaf selected by `filter` is not a `BaseParam`.
    """
    # A type (or union of types) is turned into an isinstance-based predicate.
    _filter_fn = (
        filter
        if not isinstance(filter, type | UnionType)
        else lambda x: isinstance(x, filter)
    )

    # Passing the filter as `is_leaf` stops flattening at the selected parameters,
    # so they are returned whole (with their path) instead of being recursed into.
    _params = jtu.tree_flatten_with_path(model, is_leaf=_filter_fn)[0]

    # Cache to check for duplicate parameters.
    # NOTE(review): assumes `_seen(k)` returns None the first time `k` is seen — confirm
    # against `core._tree._cache`.
    _seen = _cache()
    _data = {}
    for key, param in _params:
        if not _filter_fn(param):
            continue

        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not isinstance(param, BaseParam):
            raise TypeError("Only parameters can be serialized.")

        if _seen(id(param)) is None:
            _data[jtu.keystr(key)] = param.get()
        else:
            # Placeholder for an already-saved (shared) parameter object.
            _data[jtu.keystr(key)] = None

    np.savez_compressed(path, **_data)
def load_params(
    model: PyTree,
    path: str,
    filter: Callable[[Any], bool] | Type[BaseParam] = LayerParam,
) -> None:
    """Function to load the parameters of a model from a file.

    The '.npz' extension is automatically added to the file name if it is not
    already present. The model must have the same structure as the one used to save
    the parameters and must already be initialized:

    ```python
    model = Model()
    load_params(model, "model.npz")
    ```

    A parameter object reachable through several paths in the model (a shared
    parameter) is loaded only once, from the first path at which it appears; its
    other occurrences are skipped, since setting the first one already updates it.

    Args:
        model (PyTree): target model.
        path (str): the path to the file containing the model parameters to load.
        filter (Callable[[Any], bool] | Type[BaseParam], optional): filter function or
            type identifying the parameters to load. The default value 'LayerParam'
            selects all the weights of the layers in the model.

    Raises:
        KeyError: the file does not contain all the parameters required by the model.
    """
    path = path if path.endswith(".npz") else f"{path}.npz"

    # A type (or union of types) is turned into an isinstance-based predicate.
    _filter_fn = (
        filter
        if not isinstance(filter, type | UnionType)
        else lambda x: isinstance(x, filter)
    )

    # Same traversal as `save_params`, so the first occurrence of each parameter
    # object is met at the same key that holds its saved value.
    _params = jtu.tree_flatten_with_path(model, is_leaf=_filter_fn)[0]

    # Identities of the parameter objects already set. `save_params` writes `None`
    # for repeated parameter objects, which numpy stores as a pickled object array;
    # reading such an entry with the default `allow_pickle=False` raises ValueError.
    # Skipping repeats means those placeholder entries are never read.
    _loaded_ids = set()

    # Context manager: the file is closed even when a KeyError is raised below.
    with np.load(path) as _loaded_values:
        for _key, _param in _params:
            if not _filter_fn(_param) or id(_param) in _loaded_ids:
                continue
            _loaded_ids.add(id(_param))

            _key = jtu.keystr(_key)
            if _key not in _loaded_values:
                raise KeyError(f"Parameter '{_key}' not found in the file '{path}'.")
            _param.set(jax.numpy.array(_loaded_values[_key]))