nvflare.app_opt.pt.file_model_persistor module
- class PTFileModelPersistor(exclude_vars: str | None = None, model: Module | str | Dict[str, Any] | None = None, global_model_file_name: str = 'FL_global_model.pt', best_global_model_file_name: str = 'best_FL_global_model.pt', source_ckpt_file_full_name: str | None = None, filter_id: str | None = None, load_weights_only: bool = True, allow_numpy_conversion: bool = True)
Bases: ModelPersistor
Persist a PyTorch-based model to and from the file system.
This model persistor tries to load PyTorch model data in one of the following three ways:
1. Load from a specified source checkpoint file.
2. Load from a checkpoint location in the app folder.
3. Load from a torch model object.
The persistor tries method 1 first if source_ckpt_file_full_name is specified. If source_ckpt_file_full_name is not specified, it tries method 2. If no checkpoint location is specified in the app folder, it tries method 3.
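The priority order above can be sketched as a small decision function. This is a minimal illustration, not NVFlare code. `choose_load_method` and its return labels are hypothetical. The sketch only decides which method applies based on what is configured, and it raises `RuntimeError` where the real persistor calls `system_panic()`.

```python
import os
from typing import Any, Optional


def choose_load_method(
    source_ckpt_file_full_name: Optional[str],
    app_ckpt_path: Optional[str],
    model: Any,
) -> str:
    """Hypothetical helper mirroring the documented load priority."""
    # Method 1: an explicit source checkpoint takes priority; a missing file
    # is an error (the persistor documents a ValueError for this case).
    if source_ckpt_file_full_name:
        if not os.path.exists(source_ckpt_file_full_name):
            raise ValueError(f"{source_ckpt_file_full_name} does not exist")
        return "source_ckpt"
    # Method 2: a checkpoint location configured in the app folder.
    if app_ckpt_path:
        return "app_ckpt"
    # Method 3: fall back to the model object (or its component ID).
    if model is not None:
        return "model"
    # Stand-in for system_panic(): nothing left to load from.
    raise RuntimeError("no model source available")
```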
Method 2 - Load from a location in the app folder
This method assumes that the app folder contains an environments.json file. Among other settings, this JSON file specifies where to find the checkpoint file, using two JSON elements:
APP_CKPT_DIR: specifies the folder (within the app) where the checkpoint file resides.
APP_CKPT: specifies the base file name of the checkpoint
Here is an example of the environments.json content:
{ "APP_CKPT_DIR": "model", "APP_CKPT": "pretrained_model.pt" }
In this example, the checkpoint file is located in the “model” folder within the app and is named pretrained_model.pt.
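The following sketch shows how the two elements combine into a checkpoint path. It is an illustration, not the persistor's actual code. The helper name `app_checkpoint_path` and the decision to take the JSON content as a string (rather than reading it from a particular location in the app) are assumptions.

```python
import json
import os
from typing import Optional


def app_checkpoint_path(app_root: str, env_json: str) -> Optional[str]:
    """Hypothetical helper: build the checkpoint path from environments.json content."""
    env = json.loads(env_json)
    ckpt_name = env.get("APP_CKPT")
    if not ckpt_name:
        # No checkpoint configured in the app folder: method 3 would apply.
        return None
    # APP_CKPT_DIR is relative to the app folder; an empty value means the app root.
    ckpt_dir = env.get("APP_CKPT_DIR", "")
    return os.path.join(app_root, ckpt_dir, ckpt_name)
```

With the example content, `app_checkpoint_path("/app", '{"APP_CKPT_DIR": "model", "APP_CKPT": "pretrained_model.pt"}')` yields `/app/model/pretrained_model.pt` on POSIX systems.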
Method 3 - Load from a torch model object. In this case, the ‘model’ arg must be a valid torch model, or the component ID of a valid torch model included in the “components” section of your config_fed_server.json.
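The accepted forms of the `model` argument (an instance, a component ID, or a `{"path": ..., "args": ...}` dict as listed under the model parameter) can be resolved roughly as below. This is a simplified stand-in, not NVFlare's resolution logic. `resolve_model` and the `components` dict (standing in for the "components" section of the config) are hypothetical. The real persistor also requires the result to be a torch model, and this sketch skips that check.

```python
import importlib
from typing import Any, Dict


def resolve_model(model: Any, components: Dict[str, Any]) -> Any:
    """Hypothetical helper: turn the 'model' argument into a model object."""
    if isinstance(model, str):
        # Component ID: look it up among the configured components.
        if model not in components:
            raise ValueError(f"no component with id {model!r}")
        return components[model]
    if isinstance(model, dict):
        # {"path": "pkg.module.Class", "args": {...}}: import the class and
        # instantiate it with the given keyword arguments.
        module_name, _, class_name = model["path"].rpartition(".")
        cls = getattr(importlib.import_module(module_name), class_name)
        return cls(**model.get("args", {}))
    # Anything else is assumed to already be a model instance.
    return model
```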
If all 3 methods fail, system_panic() is called.
If a checkpoint folder (APP_CKPT_DIR) is specified, the global model and the best global model are saved to it; otherwise, they are saved directly in the app folder.
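A sketch of the save-location rule, using the default file names from the constructor signature. The function `save_paths` is hypothetical. It only computes paths and does not write anything.

```python
import os
from typing import Optional, Tuple

# Defaults taken from the PTFileModelPersistor signature.
GLOBAL_MODEL = "FL_global_model.pt"
BEST_GLOBAL_MODEL = "best_FL_global_model.pt"


def save_paths(
    app_root: str,
    ckpt_dir: Optional[str] = None,
    global_name: str = GLOBAL_MODEL,
    best_name: str = BEST_GLOBAL_MODEL,
) -> Tuple[str, str]:
    """Hypothetical helper: where the global and best global models would be saved."""
    # Use the checkpoint folder inside the app if one is configured,
    # otherwise save directly in the app folder.
    base = os.path.join(app_root, ckpt_dir) if ckpt_dir else app_root
    return os.path.join(base, global_name), os.path.join(base, best_name)
```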
The model is saved inside a dict whose layout depends on the persistor you used. Because the dict holds additional meta information alongside the model weights, you might need to access the weights with
model.load_state_dict(torch.load(path_to_model)["model"])
- Parameters:
exclude_vars (str, optional) – regular expression specifying weight variables to be excluded from training. Defaults to None.
model – Model input. Can be one of:
  - torch.nn.Module: a direct model instance
  - str: the component ID of a model registered in the config
  - dict: {“path”: “fully.qualified.Class”, “args”: {…}} for dynamic instantiation
  Defaults to None.
global_model_file_name (str, optional) – file name for saving the global model. Defaults to DefaultCheckpointFileName.GLOBAL_MODEL ('FL_global_model.pt').
best_global_model_file_name (str, optional) – file name for saving the best global model. Defaults to DefaultCheckpointFileName.BEST_GLOBAL_MODEL ('best_FL_global_model.pt').
source_ckpt_file_full_name (str, optional) – full file name for source model checkpoint file. Defaults to None.
filter_id (str, optional) – ID of a filter component that is applied to prepare the model for saving, e.g. for serializing custom Python objects. Defaults to None.
load_weights_only (bool) – Whether torch’s unpickler should be restricted to loading only tensors, primitive types, dictionaries, and any types added via torch.serialization.add_safe_globals(). Defaults to True (safe mode). Set to False for legacy checkpoints that require full unpickling (pre-PyTorch 2.6 behavior).
allow_numpy_conversion (bool) – If True, enables conversion between PyTorch tensors and NumPy arrays: tensors are converted to NumPy arrays during load_model, and NumPy arrays are converted to tensors during save_model. Defaults to True.
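The `exclude_vars` pattern filters weights by variable name. A minimal sketch of that filtering on a plain dict, assuming `re.search` semantics (a match anywhere in the name). The persistor's exact matching rule may differ, and `drop_excluded` is a hypothetical name.

```python
import re
from typing import Dict, Optional, TypeVar

V = TypeVar("V")


def drop_excluded(weights: Dict[str, V], exclude_vars: Optional[str]) -> Dict[str, V]:
    """Hypothetical helper: remove weight entries whose names match exclude_vars."""
    if not exclude_vars:
        # No pattern configured: keep every variable.
        return dict(weights)
    pattern = re.compile(exclude_vars)
    # Assumes re.search semantics: the pattern may match anywhere in the name.
    return {name: w for name, w in weights.items() if not pattern.search(name)}
```

For example, the pattern `r"^fc\."` drops the `fc.weight` and `fc.bias` entries of a state dict and keeps `conv1.weight`.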
- Raises:
ValueError – when source_ckpt_file_full_name does not exist
- get_model(model_file: str, fl_ctx: FLContext) → ModelLearnable
Retrieve a specific model by file name.
This method is called by get() to load a specific model from the inventory. Persistors that support multiple models (e.g., for cross-site evaluation) should override this method to load the specified model file.
Simple persistors that only work with a single model (like NPModelPersistor) do not need to implement this method.
- Parameters:
model_file – Name or path of the model file to retrieve
fl_ctx – FLContext
- Returns:
ModelLearnable object containing the model data
- Raises:
NotImplementedError – If the persistor doesn’t support retrieving specific models
- get_model_inventory(fl_ctx: FLContext) → Dict[str, ModelDescriptor]
Get the model inventory of the ModelPersistor.
- Parameters:
fl_ctx – FLContext
- Returns:
A dict mapping each model name to its ModelDescriptor
- Return type:
Dict[str, ModelDescriptor]
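A toy illustration of the inventory contract: the inventory maps model names to descriptors, and `get_model` looks a model up by file name. `Descriptor` is a stand-in that keeps only a name and a location. It is not NVFlare's `ModelDescriptor`. `ToyInventory` is hypothetical and returns a path instead of a `ModelLearnable`.

```python
import os
from dataclasses import dataclass
from typing import Dict, Iterable


@dataclass
class Descriptor:
    """Stand-in for ModelDescriptor: just a name and a file location."""
    name: str
    location: str


class ToyInventory:
    """Hypothetical persistor-like class exposing an inventory and lookup."""

    def __init__(self, root: str, file_names: Iterable[str]):
        self.root = root
        self.file_names = list(file_names)

    def get_model_inventory(self) -> Dict[str, Descriptor]:
        # One entry per known model file, keyed by file name.
        return {n: Descriptor(n, os.path.join(self.root, n)) for n in self.file_names}

    def get_model(self, model_file: str) -> str:
        # Look the requested model up in the inventory; unknown names are errors.
        inventory = self.get_model_inventory()
        if model_file not in inventory:
            raise ValueError(f"unknown model {model_file!r}")
        return inventory[model_file].location
```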
- handle_event(event: str, fl_ctx: FLContext)
Handles events.
- Parameters:
event (str) – event type fired by the workflow.
fl_ctx (FLContext) – FLContext information.
- load_model(fl_ctx: FLContext) → ModelLearnable
Convert the initialised model into Learnable/Model format.
- Parameters:
fl_ctx (FLContext) – FL Context delivered by workflow
- Returns:
a Learnable/Model object
- Return type:
ModelLearnable
- save_model(ml: ModelLearnable, fl_ctx: FLContext)
Persist the model object.
- Parameters:
ml (ModelLearnable) – model object to be saved
fl_ctx – FLContext