nvflare.app_opt.pt.model_reader_writer module
- class PTModelReaderWriter
Bases: ModelProcessor
Perform the actual read/write operation for PyTorch-based models.
- apply_model(network, multi_processes: bool, model_params: dict, fl_ctx: FLContext, options=None)
Set the local model according to model_params.
- Parameters:
network – training network
multi_processes – whether the network is running in multi-process mode
model_params – model data information
fl_ctx (FLContext) – FL Context delivered by workflow
options – optional. Defaults to None.
- Raises:
RuntimeError – Raised when model_params cannot be applied to the network
- Returns:
a list of ops applied to the model
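The following is a minimal sketch, not the NVFlare implementation, of how a PyTorch module's weights might be updated from a dict of NumPy arrays. The helper name `apply_weights` is hypothetical. It assumes that `multi_processes=True` means the model is wrapped (for example by `torch.nn.DataParallel`) so that the underlying module lives at `network.module`. It omits `fl_ctx` and `options`, and it returns the names of the updated entries as a stand-in for the list of ops.

```python
import numpy as np
import torch
from torch import nn


def apply_weights(network: nn.Module, multi_processes: bool, model_params: dict) -> list:
    """Hypothetical helper: copy NumPy weights from model_params into network.

    Assumption: multi_processes=True means the model is wrapped and the
    underlying module is available as network.module.
    Returns the names of the state_dict entries that were updated.
    """
    net = network.module if multi_processes else network
    try:
        state = net.state_dict()
        updated = []
        for name, current in state.items():
            if name in model_params:
                # Reshape the incoming array to the local tensor's shape,
                # then convert it to a tensor with the local dtype.
                weights = np.reshape(model_params[name], current.shape)
                state[name] = torch.as_tensor(weights, dtype=current.dtype)
                updated.append(name)
        # Entries not present in model_params keep their current values.
        net.load_state_dict(state)
    except Exception as e:
        # Mirror the documented behavior: failures surface as RuntimeError.
        raise RuntimeError(f"unable to apply model_params to the network: {e}") from e
    return updated
```

For example, `apply_weights(nn.Linear(3, 2), False, {"weight": np.ones(6, dtype=np.float32)})` reshapes the flat array to the layer's `(2, 3)` weight shape and leaves the bias untouched; a size mismatch raises `RuntimeError`.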
- extract_model(network, multi_processes: bool, model_vars: dict, fl_ctx: FLContext) → dict
Extract the current model from the training network.
- Parameters:
network – training network
multi_processes – whether the network is running in multi-process mode
model_vars – global model dict
fl_ctx – FLContext
- Returns:
a dictionary representing the model
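A matching sketch for the extraction direction, under the same assumption about `multi_processes` and again not the NVFlare implementation: the hypothetical `extract_weights` helper converts each `state_dict` entry into a NumPy array on the CPU. It omits `model_vars` and `fl_ctx`, whose role in the real method is not described here.

```python
import numpy as np
import torch
from torch import nn


def extract_weights(network: nn.Module, multi_processes: bool) -> dict:
    """Hypothetical helper: return the network's state_dict as NumPy arrays.

    Assumption: multi_processes=True means the model is wrapped and the
    underlying module is available as network.module.
    """
    net = network.module if multi_processes else network
    # Detach, move to CPU, and copy so the returned arrays do not share
    # memory with the live model parameters.
    return {name: t.detach().cpu().numpy().copy() for name, t in net.state_dict().items()}
```

Copying the arrays is a design choice: on CPU, `Tensor.numpy()` shares memory with the tensor, so without the copy, edits to the returned dict would silently modify the model.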