nvflare.app_opt.pt.he_model_reader_writer module
- class HEPTModelReaderWriter
Bases: PTModelReaderWriter
Perform the actual read/write operation for PyTorch-based models.
- apply_model(network, multi_processes: bool, model_params: dict, fl_ctx: FLContext, options=None)
Writes the global model back to the local model.
The local network is needed to obtain each parameter's shape so that the decrypted (flattened) vectors can be reshaped to match it.
- Parameters:
network (torch.nn.Module) – network object to read from or write to
multi_processes (bool) – whether the workflow runs in a multi-process environment
model_params (dict) – which parameters to read/write
fl_ctx (FLContext) – FL system-wide context
options (dict, optional) – additional information on how to process read/write. Defaults to None.
- Raises:
RuntimeError – if the network layers cannot be reshaped, or if the network layers and model_params do not match
- Returns:
a list of the parameters that have been processed
- Return type:
list
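The reshaping step described for apply_model can be illustrated with a minimal sketch. It assumes that decrypted parameters arrive as flat NumPy arrays keyed by layer name and that the local shapes are available as a name-to-shape dict; the helper `reshape_decrypted` is hypothetical and is not part of NVFlare's API. It raises RuntimeError on a size mismatch and returns the names it processed, mirroring the documented behavior, but it is not the library's actual implementation.

```python
from typing import Dict, List, Tuple

import numpy as np


def reshape_decrypted(
    local_shapes: Dict[str, Tuple[int, ...]],
    model_params: Dict[str, np.ndarray],
) -> List[str]:
    """Hypothetical helper: reshape flat decrypted vectors to local shapes, in place.

    Returns the names of the parameters that were reshaped.
    """
    processed = []
    for name, values in model_params.items():
        if name not in local_shapes:
            # Assumption: parameters absent from the local network are skipped.
            continue
        shape = local_shapes[name]
        flat = np.asarray(values).ravel()
        expected = int(np.prod(shape))
        if flat.size != expected:
            # Decrypted vector length does not match the local layer size.
            raise RuntimeError(
                f"cannot reshape {name}: got {flat.size} values, "
                f"local shape {shape} needs {expected}"
            )
        model_params[name] = flat.reshape(shape)
        processed.append(name)
    return processed


# Usage: two flattened layers restored to their local shapes.
shapes = {"fc.weight": (2, 3), "fc.bias": (2,)}
params = {"fc.weight": np.arange(6.0), "fc.bias": np.array([0.5, -0.5])}
print(reshape_decrypted(shapes, params))  # ['fc.weight', 'fc.bias']
print(params["fc.weight"].shape)  # (2, 3)
```

Reading the target shape from the local network, rather than shipping it alongside the ciphertext, keeps the encrypted payload a plain vector; the trade-off is that any layer-size disagreement can only be detected at this reshape step.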