nvflare.app_opt.pt.model_persistence_format_manager module

class PTModelPersistenceFormatManager(data: dict, default_train_conf=None, allow_numpy_conversion=True)

Bases: object

Manages the format used to persist PyTorch models, converting between the persisted checkpoint dictionary and a ModelLearnable.

Parameters:
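  • data (dict) – either a flat dictionary mapping variable names to values, or a dict of dicts that stores the weights under the 'model' key (PERSISTENCE_KEY_MODEL), optionally alongside 'train_conf' and 'meta_props' entries.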

  • default_train_conf (dict, optional) – default training configuration. Defaults to None.

  • allow_numpy_conversion (bool) – If set to True, enables conversion between PyTorch tensors and NumPy arrays. PyTorch tensors will be converted to NumPy arrays during ‘load_model’, and NumPy arrays will be converted to PyTorch tensors during ‘save_model’. Defaults to True.

Raises:

TypeError – when data is not a dictionary

PERSISTENCE_KEY_META_PROPS = 'meta_props'
PERSISTENCE_KEY_MODEL = 'model'
PERSISTENCE_KEY_TRAIN_CONF = 'train_conf'
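The three PERSISTENCE_KEY_* constants name the entries of the dict-of-dicts form of data. The following is a simplified, standalone sketch of how the two accepted input shapes could be told apart; it does not import NVFlare. The function name split_persistence_data and its return tuple are hypothetical, not part of NVFlare's API. It assumes that the presence of the 'model' key is what marks the nested form, and that default_train_conf is used only when the data carries no training configuration.

```python
from typing import Any, Optional, Tuple

# Key names copied from the class constants documented above.
PERSISTENCE_KEY_MODEL = "model"
PERSISTENCE_KEY_TRAIN_CONF = "train_conf"
PERSISTENCE_KEY_META_PROPS = "meta_props"


def split_persistence_data(
    data: Any, default_train_conf: Optional[dict] = None
) -> Tuple[dict, Optional[dict], Optional[dict]]:
    """Hypothetical helper: return (var_dict, train_conf, meta_props)."""
    if not isinstance(data, dict):
        # Mirrors the documented TypeError for non-dict input.
        raise TypeError(f"data must be a dict, got {type(data).__name__}")

    if PERSISTENCE_KEY_MODEL in data:
        # Dict-of-dicts form: the weights live under the "model" key.
        var_dict = data[PERSISTENCE_KEY_MODEL]
        train_conf = data.get(PERSISTENCE_KEY_TRAIN_CONF)
        meta_props = data.get(PERSISTENCE_KEY_META_PROPS)
    else:
        # Flat form: the whole dict maps variable names to values.
        var_dict, train_conf, meta_props = data, None, None

    # Assumption: fall back to the default only when no training config was stored.
    if not train_conf:
        train_conf = default_train_conf
    return var_dict, train_conf, meta_props
```

For example, passing {"fc.bias": [0.0]} directly treats it as the weights themselves, while {"model": {"fc.bias": [0.0]}, "train_conf": {...}} unpacks the weights from the 'model' entry and keeps the stored training configuration.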
static get_persist_model_format()
to_model_learnable(exclude_vars) → ModelLearnable
to_persistence_dict() → dict
update(ml: ModelLearnable)

Update the persistence data with the learned values.

Parameters:

ml (ModelLearnable) – learned values to merge into the existing persisted data

Raises:

ValueError – if any of the following holds:
  • the incoming learnable is invalid;
  • a matching key has a shape mismatch;
  • a non-empty update has zero compatible matches with the persisted checkpoint;
  • once the checkpoint schema has been initialized, the update would introduce keys that do not already exist in the checkpoint.

Notes

The persisted checkpoint serves as the server-side schema for client updates. Partial updates are supported: the learned weights need only cover the subset of checkpoint keys that the client actually trained. Client keys that are not in the server schema are rejected. If no persisted checkpoint exists yet, the first non-empty learnable initializes it.
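The update rules listed under Raises and Notes can be illustrated with a simplified merge over plain Python dicts. This is a hedged sketch, not NVFlare's implementation: the names merge_update and shape_of, the nested-list weight representation, and the error messages are all hypothetical, and "invalid learnable" is reduced here to "not a dict". The order of the checks is also an assumption.

```python
from typing import Any, Dict, Tuple


def shape_of(value: Any) -> Tuple[int, ...]:
    """Hypothetical helper: shape of a rectangular nested list (scalars have shape ())."""
    shape = []
    while isinstance(value, list):
        shape.append(len(value))
        value = value[0] if value else None
    return tuple(shape)


def merge_update(persisted: Dict[str, Any], learned: Any) -> Dict[str, Any]:
    """Hypothetical sketch of the documented update() rules; returns a merged copy."""
    if not isinstance(learned, dict):
        raise ValueError("invalid learnable: expected a dict of weights")
    if not learned:
        # An empty update leaves the checkpoint unchanged.
        return dict(persisted)
    if not persisted:
        # No checkpoint schema yet: the first non-empty update initializes it.
        return dict(learned)

    matched = [k for k in learned if k in persisted]
    if not matched:
        raise ValueError("non-empty update has no keys in common with the checkpoint")

    unknown = sorted(k for k in learned if k not in persisted)
    if unknown:
        # The persisted checkpoint is the schema; new client keys are rejected.
        raise ValueError(f"update introduces keys outside the schema: {unknown}")

    for k in matched:
        if shape_of(learned[k]) != shape_of(persisted[k]):
            raise ValueError(
                f"shape mismatch for {k!r}: {shape_of(learned[k])} vs {shape_of(persisted[k])}"
            )

    # Partial update: keys the client did not send keep their persisted values.
    merged = dict(persisted)
    merged.update(learned)
    return merged
```

The sketch returns a new dict to keep it easy to test; the documented update() instead updates the manager's own persisted data, which this example does not model.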