nvflare.app_opt.pt.model_persistence_format_manager module
- class PTModelPersistenceFormatManager(data: dict, default_train_conf=None, allow_numpy_conversion=True)
Bases: object
Manage the format for model persistence.
- Parameters:
data (dict) – either a dictionary mapping variable names to values, or a dict of dicts in which the variables are stored under the 'model' key (PERSISTENCE_KEY_MODEL).
default_train_conf (dict, optional) – training configuration. Defaults to None.
allow_numpy_conversion (bool) – If set to True, enables conversion between PyTorch tensors and NumPy arrays. PyTorch tensors will be converted to NumPy arrays during ‘load_model’, and NumPy arrays will be converted to PyTorch tensors during ‘save_model’. Defaults to True.
- Raises:
TypeError – when data is not a dictionary
- PERSISTENCE_KEY_META_PROPS = 'meta_props'
- PERSISTENCE_KEY_MODEL = 'model'
- PERSISTENCE_KEY_TRAIN_CONF = 'train_conf'
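The two layouts accepted for data can be illustrated with a minimal, stdlib-only sketch. split_persistence_data is a hypothetical helper, not part of NVFlare. It assumes that a dict containing the 'model' key is treated as a dict of dicts (variables under 'model', with optional 'meta_props' and 'train_conf' entries alongside), that any other dict is treated as the variable mapping itself, and that default_train_conf is used only when no training configuration is stored.

```python
from typing import Optional, Tuple

# Key names taken from the class attributes documented above.
PERSISTENCE_KEY_MODEL = "model"
PERSISTENCE_KEY_TRAIN_CONF = "train_conf"
PERSISTENCE_KEY_META_PROPS = "meta_props"


def split_persistence_data(
    data: dict, default_train_conf: Optional[dict] = None
) -> Tuple[dict, Optional[dict], Optional[dict]]:
    """Hypothetical helper: return (var_dict, meta_props, train_conf) for data."""
    if not isinstance(data, dict):
        # Mirrors the documented TypeError for non-dict input.
        raise TypeError(f"data must be a dict but got {type(data)}")

    if PERSISTENCE_KEY_MODEL in data:
        # Dict of dicts: the variables live under the "model" key.
        var_dict = data[PERSISTENCE_KEY_MODEL]
        meta = data.get(PERSISTENCE_KEY_META_PROPS)
        train_conf = data.get(PERSISTENCE_KEY_TRAIN_CONF)
    else:
        # Flat form: the dict itself maps variable names to values.
        var_dict, meta, train_conf = data, None, None

    # Assumption: fall back to the default only when nothing was stored.
    if not train_conf:
        train_conf = default_train_conf
    return var_dict, meta, train_conf


flat = {"fc.weight": [0.1, 0.2]}
nested = {"model": {"fc.weight": [0.1, 0.2]}, "train_conf": {"epochs": 2}}
print(split_persistence_data(flat, default_train_conf={"epochs": 1}))
# -> ({'fc.weight': [0.1, 0.2]}, None, {'epochs': 1})
print(split_persistence_data(nested, default_train_conf={"epochs": 1}))
# -> ({'fc.weight': [0.1, 0.2]}, None, {'epochs': 2})
```

In both cases the caller ends up with the same variable mapping, which is what lets a plain state dict and a full checkpoint dict be handled uniformly.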
- to_model_learnable(exclude_vars) → ModelLearnable
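The filtering step behind to_model_learnable can be sketched in plain Python. This assumes exclude_vars is either None or a compiled regular expression, and that any variable whose name it matches (via search) is left out of the result. select_weights and its convert parameter are hypothetical; convert stands in for the tensor-to-NumPy conversion that allow_numpy_conversion enables.

```python
import re
from typing import Callable, Optional


def select_weights(
    var_dict: dict,
    exclude_vars: Optional[re.Pattern] = None,
    convert: Callable = lambda v: v,
) -> dict:
    """Hypothetical sketch of the variable filtering behind to_model_learnable."""
    weights = {}
    for name, value in var_dict.items():
        # Assumption: skip any variable whose name matches the exclusion pattern.
        if exclude_vars is not None and exclude_vars.search(name):
            continue
        # convert stands in for e.g. value.cpu().numpy() on a torch.Tensor.
        weights[name] = convert(value)
    return weights


state = {"conv1.weight": 1, "conv1.bias": 2, "bn1.running_mean": 3}
print(select_weights(state, re.compile(r"running_")))
# -> {'conv1.weight': 1, 'conv1.bias': 2}
```

Using search rather than match means the pattern can hit anywhere in a variable name, so a short fragment such as "running_" is enough to drop batch-norm statistics.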
- update(ml: ModelLearnable)
Update the persistence data with the learned values.
- Parameters:
ml (ModelLearnable) – updated information to be merged into existing ModelLearnable
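The merge performed by update can be illustrated with a minimal sketch. merge_learned is a hypothetical helper, not the NVFlare implementation. It assumes that learned values overwrite persisted variables of the same name, that persisted variables absent from the update are kept unchanged, and that convert stands in for the NumPy-to-tensor conversion (for example torch.from_numpy) applied when allow_numpy_conversion is True.

```python
from typing import Callable


def merge_learned(var_dict: dict, learned_weights: dict, convert: Callable = lambda v: v) -> dict:
    """Hypothetical sketch: merge learned values into the persisted variables in place."""
    for name, value in learned_weights.items():
        # Overwrite (or add) each learned variable, converting it first.
        var_dict[name] = convert(value)
    # Variables not present in learned_weights are left untouched.
    return var_dict


persisted = {"fc.weight": [0.0, 0.0], "bn.num_batches_tracked": 10}
merge_learned(persisted, {"fc.weight": [0.5, -0.5]})
print(persisted)
# -> {'fc.weight': [0.5, -0.5], 'bn.num_batches_tracked': 10}
```

Keeping unlearned entries matters when only part of a model is trained or shared, since the persisted checkpoint should still contain every variable.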