nvflare.app_opt.pt.file_model_persistor module

class PTFileModelPersistor(exclude_vars=None, model=None, global_model_file_name='FL_global_model.pt', best_global_model_file_name='best_FL_global_model.pt', source_ckpt_file_full_name=None, filter_id: str | None = None)

Bases: ModelPersistor

Persist a PyTorch-based model to and from the file system.

This Model Persistor tries to load PT model data in the following three ways:

  1. Load from a specified source checkpoint file

  2. Load from a location from the app folder

  3. Load from a torch model object

The persistor tries method 1 first if source_ckpt_file_full_name is specified. If source_ckpt_file_full_name is not specified, it tries method 2. If no checkpoint location is specified in the app folder either, it tries method 3.
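The priority order above can be summarized as a small selection function. This is a minimal sketch, assuming that "specified" simply means a non-empty value; `choose_load_method` and its arguments are hypothetical names for illustration, not part of the NVFlare API.

```python
from typing import Optional


def choose_load_method(source_ckpt_file_full_name: Optional[str],
                       app_ckpt: Optional[str],
                       model: object) -> int:
    """Return which load method (1, 2 or 3) would be tried.

    Hypothetical helper that mirrors the documented priority order;
    it is not part of NVFlare.
    """
    if source_ckpt_file_full_name:
        return 1  # an explicit source checkpoint file takes precedence
    if app_ckpt:
        return 2  # checkpoint location configured via environments.json
    if model is not None:
        return 3  # fall back to the torch model object or component ID
    # The real persistor calls system_panic() here; this sketch raises instead.
    raise RuntimeError("no model source available")
```

For example, `choose_load_method(None, "pretrained_model.pt", None)` returns `2`, because no source checkpoint is given but the app folder names one.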

Method 2 - Load from a location from the app folder

The app folder is assumed to contain an environments.json file. Among other things, this JSON file specifies where to find the checkpoint file, using two JSON elements:

  • APP_CKPT_DIR: specifies the folder (within the app) where the checkpoint file resides.

  • APP_CKPT: specifies the base file name of the checkpoint.

Here is an example of the environments.json content:

{
    "APP_CKPT_DIR": "model",
    "APP_CKPT": "pretrained_model.pt"
}

In this example, the checkpoint file is located in the “model” folder within the app and is named pretrained_model.pt.
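To make the lookup concrete, the sketch below joins the app root with APP_CKPT_DIR and APP_CKPT read from environments.json. It is a hypothetical helper (`app_ckpt_path` is not an NVFlare function), and treating a missing APP_CKPT_DIR as "the app root itself" is an assumption of this sketch.

```python
import json
import os


def app_ckpt_path(app_root: str, env_json: str) -> str:
    """Resolve the checkpoint path described by environments.json.

    Hypothetical helper for illustration; NVFlare's own resolution
    logic may differ in its details.
    """
    env = json.loads(env_json)
    # Assumption: a missing APP_CKPT_DIR means the file sits in the app root.
    ckpt_dir = env.get("APP_CKPT_DIR", "")
    return os.path.join(app_root, ckpt_dir, env["APP_CKPT"])


env_text = '{"APP_CKPT_DIR": "model", "APP_CKPT": "pretrained_model.pt"}'
print(app_ckpt_path("/workspace/app", env_text))
```

On a POSIX system this prints `/workspace/app/model/pretrained_model.pt`; the app root `/workspace/app` is an invented example path.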

Method 3 - Load from a torch model object

In this case, the model arg must be either a valid torch model or the component ID of a valid torch model included in the “components” section of your config_fed_server.json.
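Because the model arg accepts either an object or a component ID, the persistor has to resolve a string into the configured component. A minimal sketch, assuming the components are available as a plain `id -> object` mapping; `resolve_model` and the `components` dict are hypothetical stand-ins for NVFlare's component lookup.

```python
from typing import Any, Dict


def resolve_model(model: Any, components: Dict[str, Any]) -> Any:
    """Return a model object from either an object or a component ID.

    Hypothetical helper: `components` stands in for the components that
    NVFlare builds from the "components" section of config_fed_server.json.
    """
    if isinstance(model, str):
        if model not in components:
            raise ValueError(f"no component with id {model!r}")
        return components[model]  # look the model up by its component ID
    return model  # already a model object; use it as-is
```

Treating any string as an ID and anything else as the model itself keeps the two accepted forms unambiguous.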

If all three methods fail, system_panic() is called.

If a checkpoint folder name is specified, the global model and the best global model are saved to it; otherwise, they are saved directly in the app folder.
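The save-location rule can be sketched as a path computation. This assumes the checkpoint folder is relative to the app root and reuses the default file names from the class signature; `save_paths` is a hypothetical helper, not an NVFlare method.

```python
import os
from typing import Optional, Tuple


def save_paths(app_root: str,
               ckpt_dir: Optional[str],
               global_model_file_name: str = "FL_global_model.pt",
               best_global_model_file_name: str = "best_FL_global_model.pt",
               ) -> Tuple[str, str]:
    """Return (global model path, best global model path).

    Hypothetical helper: files go into ckpt_dir when it is given,
    otherwise directly into the app folder.
    """
    base = os.path.join(app_root, ckpt_dir) if ckpt_dir else app_root
    return (os.path.join(base, global_model_file_name),
            os.path.join(base, best_global_model_file_name))
```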

The model is saved inside a dict whose layout depends on the persistor used. Because the file holds additional meta information alongside the model weights, you may need to load it with model.load_state_dict(torch.load(path_to_model)["model"]).
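When a script may receive either a persistor checkpoint or a bare state dict, a small helper can pick out the weights. This sketch assumes the layout described above (weights under the "model" key, other keys holding meta information); `extract_state_dict` is hypothetical, and plain dicts stand in for loaded checkpoints so the example runs without PyTorch.

```python
from typing import Any, Mapping


def extract_state_dict(checkpoint: Mapping[str, Any]) -> Mapping[str, Any]:
    """Return the weights from a loaded checkpoint.

    Assumes weights are stored under "model" with meta info in other keys.
    A bare state dict (tensor values, no nested "model" mapping) is
    returned unchanged.
    """
    inner = checkpoint.get("model")
    if isinstance(inner, Mapping):
        return inner  # persistor layout: unwrap the weights
    return checkpoint  # already a plain state dict
```

With PyTorch installed, usage would look like `model.load_state_dict(extract_state_dict(torch.load(path_to_model)))`.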

Parameters:
  • exclude_vars (str, optional) – regular expression specifying weight vars to be excluded from training. Defaults to None.

  • model (torch.nn.Module or str, optional) – torch model object or component ID of the model object. Defaults to None.

  • global_model_file_name (str, optional) – file name for saving the global model. Defaults to DefaultCheckpointFileName.GLOBAL_MODEL ('FL_global_model.pt').

  • best_global_model_file_name (str, optional) – file name for saving the best global model. Defaults to DefaultCheckpointFileName.BEST_GLOBAL_MODEL ('best_FL_global_model.pt').

  • source_ckpt_file_full_name (str, optional) – full file name for source model checkpoint file. Defaults to None.

  • filter_id (str, optional) – ID of a filter component that is applied to prepare the model for saving, e.g. for serialization of custom Python objects. Defaults to None.
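The exclude_vars parameter takes a regular expression over weight variable names. The sketch below shows the idea of dropping matching names from a weights dict; it assumes `re.search` semantics (a match anywhere in the name), which is an assumption about the persistor rather than documented behavior, and `drop_excluded` is a hypothetical helper.

```python
import re
from typing import Dict, Optional


def drop_excluded(weights: Dict[str, object],
                  exclude_vars: Optional[str]) -> Dict[str, object]:
    """Return a copy of weights without vars whose names match exclude_vars.

    Hypothetical helper; assumes search (not full-match) semantics.
    """
    if not exclude_vars:
        return dict(weights)  # nothing to exclude
    pattern = re.compile(exclude_vars)
    return {name: value for name, value in weights.items()
            if not pattern.search(name)}
```

For instance, `exclude_vars=r"^fc\."` would keep `conv1.weight` and drop both `fc.weight` and `fc.bias`.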

Raises:

ValueError – if the file named by source_ckpt_file_full_name does not exist
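The documented ValueError amounts to an existence check on the configured source file. A minimal sketch of such a check; `check_source_ckpt` is hypothetical, and the error message is illustrative rather than NVFlare's actual wording.

```python
import os
from typing import Optional


def check_source_ckpt(source_ckpt_file_full_name: Optional[str]) -> None:
    """Raise ValueError if a source checkpoint is configured but missing.

    Hypothetical helper; an unset value is allowed and skips the check.
    """
    if source_ckpt_file_full_name and not os.path.exists(source_ckpt_file_full_name):
        raise ValueError(
            f"source checkpoint file {source_ckpt_file_full_name} does not exist")
```

Validating up front surfaces a misconfigured path immediately instead of falling through to another load method.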

get_model(model_file: str, fl_ctx: FLContext) → ModelLearnable
get_model_from_location(location, fl_ctx)
get_model_inventory(fl_ctx: FLContext) → Dict[str, ModelDescriptor]

Get the model inventory of the ModelPersistor.

Parameters:

fl_ctx – FLContext

Returns: { model_kind: ModelDescriptor }
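The return value maps each model kind to a descriptor. The sketch below illustrates that shape only: `Descriptor` is a hypothetical stand-in for ModelDescriptor (its fields are illustrative, not the real class definition), and the kind names in the usage are invented.

```python
from dataclasses import dataclass, field
from typing import Dict


@dataclass
class Descriptor:
    # Hypothetical stand-in for ModelDescriptor; fields are illustrative.
    name: str
    location: str
    props: dict = field(default_factory=dict)


def build_inventory(paths: Dict[str, str]) -> Dict[str, Descriptor]:
    """Build a {model_kind: descriptor} mapping from {model_kind: path}."""
    return {kind: Descriptor(name=kind, location=path)
            for kind, path in paths.items()}


inventory = build_inventory({"global_model": "/app/FL_global_model.pt"})
print(inventory["global_model"].location)
```

This prints `/app/FL_global_model.pt`.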

handle_event(event: str, fl_ctx: FLContext)

Handles events.

Parameters:
  • event (str) – event type fired by the workflow.

  • fl_ctx (FLContext) – FLContext information.

load_model(fl_ctx: FLContext) → ModelLearnable

Convert the initialised model into Learnable/Model format.

Parameters:

fl_ctx (FLContext) – FL Context delivered by workflow

Returns:

a Learnable/Model object

Return type:

ModelLearnable
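Conceptually, the conversion wraps the model's weights together with meta information in a dict-like container. This sketch assumes a layout with "weights" and "meta" keys; that layout is an assumption made for illustration, and `to_learnable` is a hypothetical helper rather than the NVFlare implementation.

```python
from typing import Any, Dict, Optional


def to_learnable(state_dict: Dict[str, Any],
                 meta: Optional[Dict[str, Any]] = None) -> Dict[str, Dict[str, Any]]:
    """Wrap weights and meta info in a dict-shaped learnable.

    Hypothetical layout ("weights"/"meta" keys) used only for illustration.
    """
    return {"weights": dict(state_dict), "meta": dict(meta or {})}
```

Copying both inputs keeps the returned container independent of the caller's dicts.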

save_model(ml: ModelLearnable, fl_ctx: FLContext)

Persist the model object.

Parameters:
  • ml (ModelLearnable) – model learnable to be saved

  • fl_ctx – FLContext

save_model_file(save_path: str)