nvflare.app_opt.pt.model_persistence_format_manager module

class PTModelPersistenceFormatManager(data: dict, default_train_conf=None, allow_numpy_conversion=True)

Bases: object

Manage the format for model persistence.

Parameters:
  • data (dict) – either a dictionary mapping variable names to values, or a dict of dicts.

  • default_train_conf (dict, optional) – default training configuration. Defaults to None.

  • allow_numpy_conversion (bool) – If set to True, enables conversion between PyTorch tensors and NumPy arrays. PyTorch tensors will be converted to NumPy arrays during ‘load_model’, and NumPy arrays will be converted to PyTorch tensors during ‘save_model’. Defaults to True.

Raises:

TypeError – when data is not a dictionary
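The constructor accepts two shapes of `data`. The following is a minimal sketch of how they might be told apart, assuming (this is not stated in the reference above) that a dict containing the `model` key is treated as the nested form, that anything else is treated as a flat variable dictionary, and that `default_train_conf` is used when no training configuration is stored. `split_persistence_data` is a hypothetical helper, not part of NVFlare.

```python
from typing import Any, Optional, Tuple

# Keys mirroring the PERSISTENCE_KEY_* class constants.
KEY_MODEL = "model"
KEY_META_PROPS = "meta_props"
KEY_TRAIN_CONF = "train_conf"


def split_persistence_data(
    data: Any, default_train_conf: Optional[dict] = None
) -> Tuple[dict, Optional[dict], Optional[dict]]:
    """Hypothetical helper: return (variables, meta_props, train_conf)."""
    if not isinstance(data, dict):
        # Documented behavior: non-dict input raises TypeError.
        raise TypeError(f"data must be a dict but got {type(data)}")

    if KEY_MODEL in data:
        # Assumed nested form: {"model": {...}, "meta_props": ..., "train_conf": ...}
        variables = data[KEY_MODEL]
        meta = data.get(KEY_META_PROPS)
        train_conf = data.get(KEY_TRAIN_CONF)
    else:
        # Assumed flat form: the dict itself maps variable names to values.
        variables, meta, train_conf = data, None, None

    # Fall back to the default when no (or an empty) training configuration is stored.
    if not train_conf:
        train_conf = default_train_conf
    return variables, meta, train_conf
```

With this sketch, `{"fc.weight": w}` and `{"model": {"fc.weight": w}}` yield the same variables; only the nested form can also carry meta properties and a stored training configuration.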

PERSISTENCE_KEY_META_PROPS = 'meta_props'
PERSISTENCE_KEY_MODEL = 'model'
PERSISTENCE_KEY_TRAIN_CONF = 'train_conf'
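The three `PERSISTENCE_KEY_*` constants name the top-level entries of the nested (dict of dicts) form. A minimal sketch of assembling such a dictionary follows, assuming that optional entries are omitted when empty and that conversion of each variable (for example, NumPy array to tensor) is done by a caller-supplied function. `build_persistence_dict` and its `to_tensor` parameter are hypothetical names, not NVFlare API.

```python
from collections import OrderedDict
from typing import Any, Callable, Dict, Optional

PERSISTENCE_KEY_MODEL = "model"
PERSISTENCE_KEY_META_PROPS = "meta_props"
PERSISTENCE_KEY_TRAIN_CONF = "train_conf"


def build_persistence_dict(
    weights: Dict[str, Any],
    meta: Optional[dict] = None,
    train_conf: Optional[dict] = None,
    to_tensor: Callable[[Any], Any] = lambda v: v,
) -> Dict[str, Any]:
    """Hypothetical sketch: nest weights, meta props and train config under the documented keys."""
    # Convert every variable with the supplied function, preserving order.
    model = OrderedDict((name, to_tensor(value)) for name, value in weights.items())

    persisted: Dict[str, Any] = OrderedDict()
    persisted[PERSISTENCE_KEY_MODEL] = model
    # Optional entries are only written when present (an assumption of this sketch).
    if meta:
        persisted[PERSISTENCE_KEY_META_PROPS] = meta
    if train_conf:
        persisted[PERSISTENCE_KEY_TRAIN_CONF] = train_conf
    return persisted
```

With PyTorch installed, passing `to_tensor=torch.as_tensor` would turn NumPy arrays back into tensors before the dictionary is saved, for example with `torch.save`.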
static get_persist_model_format()
to_model_learnable(exclude_vars) → ModelLearnable
to_persistence_dict() → dict
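`to_model_learnable` takes an `exclude_vars` argument. A sketch of a filter-and-convert step of this kind follows, assuming that `exclude_vars` is a compiled regular expression (or `None`) searched against each variable name, and that tensor-to-NumPy conversion is a caller-supplied function applied only when `allow_numpy_conversion` is enabled. `select_and_convert` and `to_numpy` are hypothetical; the documented method returns a `ModelLearnable`, which this sketch does not build.

```python
import re
from typing import Any, Callable, Dict, Optional


def select_and_convert(
    var_dict: Dict[str, Any],
    exclude_vars: Optional[re.Pattern] = None,
    to_numpy: Callable[[Any], Any] = lambda v: v,
    allow_numpy_conversion: bool = True,
) -> Dict[str, Any]:
    """Hypothetical sketch: drop excluded variables, convert the rest."""
    weights = {}
    for name, value in var_dict.items():
        # Skip any variable whose name matches the exclusion pattern.
        if exclude_vars is not None and exclude_vars.search(name):
            continue
        # Convert only when conversion is enabled, mirroring allow_numpy_conversion.
        weights[name] = to_numpy(value) if allow_numpy_conversion else value
    return weights
```

In PyTorch code, `to_numpy` would typically be `lambda t: t.cpu().numpy()`, and a pattern such as `re.compile(r"num_batches_tracked")` would keep batch-norm counters out of the exported weights.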
update(ml: ModelLearnable)

Update the persistence data with the learned values.

Parameters:

ml (ModelLearnable) – updated information to be merged into the existing persistence data
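The merge performed by `update` can be sketched with plain dictionaries. Assuming that learned values overwrite stored variables of the same name while variables absent from the update are kept unchanged, a hypothetical `merge_learned` helper could look like this. The documented method takes a `ModelLearnable` rather than a plain dict; this sketch covers only the weight merge.

```python
from typing import Any, Dict


def merge_learned(var_dict: Dict[str, Any], learned: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical sketch: overwrite learned variables in place, keep the rest."""
    for name, value in learned.items():
        # Learned values replace existing ones; new names are added.
        var_dict[name] = value
    return var_dict
```

Keeping untouched entries means variables that were excluded from training (for example, ones filtered out via `exclude_vars`) still survive the next save.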