nvflare.app_opt.pt.model_persistence_format_manager module
- class PTModelPersistenceFormatManager(data: dict, default_train_conf=None, allow_numpy_conversion=True)[source]
Bases: object
Manage the format for model persistence.
- Parameters:
data (dict) – either a dictionary mapping variable names to values, or a dict of dicts.
default_train_conf (dict, optional) – training configuration. Defaults to None.
allow_numpy_conversion (bool) – If set to True, enables conversion between PyTorch tensors and NumPy arrays. PyTorch tensors will be converted to NumPy arrays during ‘load_model’, and NumPy arrays will be converted to PyTorch tensors during ‘save_model’. Defaults to True.
- Raises:
TypeError – when data is not a dictionary
- PERSISTENCE_KEY_META_PROPS = 'meta_props'
- PERSISTENCE_KEY_MODEL = 'model'
- PERSISTENCE_KEY_TRAIN_CONF = 'train_conf'
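The two accepted shapes of data can be illustrated with a small sketch. It assumes the nested "dict of dict" form uses the three persistence keys above as top-level keys; the page does not confirm this. The helper split_persistence_data is hypothetical and is not part of NVFlare.

```python
# Key names copied from the class constants listed above.
PERSISTENCE_KEY_MODEL = "model"
PERSISTENCE_KEY_TRAIN_CONF = "train_conf"
PERSISTENCE_KEY_META_PROPS = "meta_props"


def split_persistence_data(data, default_train_conf=None):
    """Hypothetical helper: return (weights, train_conf, meta_props).

    Assumption: a nested input stores the weights under 'model' and the
    optional extras under 'train_conf' and 'meta_props'.
    """
    if not isinstance(data, dict):
        raise TypeError(f"data must be a dict but got {type(data).__name__}")

    model = data.get(PERSISTENCE_KEY_MODEL)
    if isinstance(model, dict):
        # Nested layout: weights live under the 'model' key.
        train_conf = data.get(PERSISTENCE_KEY_TRAIN_CONF, default_train_conf)
        meta_props = data.get(PERSISTENCE_KEY_META_PROPS, {})
        return model, train_conf, meta_props

    # Flat layout: the whole dict maps variable names to values.
    return data, default_train_conf, {}


flat = {"fc.weight": [0.1, 0.2], "fc.bias": [0.0]}
nested = {"model": flat, "train_conf": {"lr": 0.01}}

weights, train_conf, meta_props = split_persistence_data(nested)
```

In this sketch both layouts yield the same weights dictionary; only the nested one carries its own training configuration.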
- to_model_learnable(exclude_vars) → ModelLearnable[source]
- update(ml: ModelLearnable)[source]
Update the persistence data with the learned values.
- Parameters:
ml (ModelLearnable) – the learnable whose updated values are merged into the existing persistence data
- Raises:
ValueError – if the incoming learnable is invalid; if any matching key has a shape mismatch; if a non-empty update has no compatible matches with the persisted checkpoint; or if, once the checkpoint schema has been initialized, the update would introduce keys that do not already exist in the checkpoint.
Notes
The persisted checkpoint is the server schema for client updates. Partial updates are supported: learned weights only need to cover a subset of checkpoint keys that the client actually trained. New client keys outside the server schema are rejected. If no persisted checkpoint exists yet, the first non-empty learnable initializes it.
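The merge rules in the Notes can be sketched with plain dictionaries. This illustrates the described behavior only and is not the NVFlare implementation: merge_update and shape_of are hypothetical names, and weights are represented as nested lists (any object with a .shape attribute, such as a NumPy array, would also work).

```python
def shape_of(value):
    """Return a shape tuple for arrays (via .shape) or nested lists."""
    if hasattr(value, "shape"):
        return tuple(value.shape)
    shape = []
    while isinstance(value, (list, tuple)):
        shape.append(len(value))
        value = value[0] if value else None
    return tuple(shape)


def merge_update(checkpoint, update):
    """Hypothetical stand-in for the update rules described above."""
    if not isinstance(update, dict):
        raise ValueError("update must be a dict of name -> weights")

    if not checkpoint:
        # No persisted checkpoint yet: the first learnable initializes it.
        return dict(update)

    # The checkpoint is the schema: keys outside it are rejected. A non-empty
    # update with zero matching keys is therefore rejected here as well.
    unknown = sorted(k for k in update if k not in checkpoint)
    if unknown:
        raise ValueError(f"keys not in checkpoint: {unknown}")

    merged = dict(checkpoint)
    for key, value in update.items():
        if shape_of(value) != shape_of(checkpoint[key]):
            raise ValueError(f"shape mismatch for {key!r}")
        merged[key] = value  # partial update: untouched keys keep old values
    return merged


checkpoint = {"encoder.w": [[0.0, 0.0], [0.0, 0.0]], "head.b": [0.0, 0.0, 0.0]}

# A client that trained only the head sends a partial update.
merged = merge_update(checkpoint, {"head.b": [0.1, 0.2, 0.3]})
```

Here merged keeps the original encoder.w and takes the new head.b, whereas an update containing a key such as "extra.w", or a head.b of a different length, would raise ValueError.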