nvflare.app_opt.pt.he_model_reader_writer module

class HEPTModelReaderWriter

Bases: PTModelReaderWriter

Perform the actual read/write operations for PyTorch-based models when homomorphic encryption (HE) is used.

apply_model(network, multi_processes: bool, model_params: dict, fl_ctx: FLContext, options=None)

Write the global model back to the local model.

The local network is required so that the shape of each local parameter can be used to reshape the decrypted vectors.
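The reshaping step can be illustrated with a minimal sketch. It assumes that each decrypted vector arrives as a flat sequence of floats and that the local parameter is available as an array; NumPy arrays stand in for PyTorch tensors so the example runs without torch, and the helper name `reshape_decrypted` is hypothetical, not part of the NVFlare API.

```python
import numpy as np


def reshape_decrypted(flat_values, local_param):
    """Hypothetical helper: reshape a decrypted flat vector to a local parameter's shape.

    flat_values: 1-D sequence of floats (assumed output of HE decryption).
    local_param: array whose shape and dtype the result should take
                 (stand-in for a torch tensor from the local network).
    """
    flat = np.asarray(flat_values, dtype=local_param.dtype)
    if flat.size != local_param.size:
        # A size mismatch means the vector cannot be mapped onto this layer.
        raise RuntimeError(
            f"cannot reshape {flat.size} values into shape {local_param.shape}"
        )
    return flat.reshape(local_param.shape)


local_weight = np.zeros((2, 3), dtype=np.float32)
decrypted = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
restored = reshape_decrypted(decrypted, local_weight)
print(restored.shape)  # (2, 3)
```

The local shape is the only information needed because a flat vector carries the values but not the layout of the original tensor.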

Parameters:
  • network (torch.nn.Module) – network object to read/write

  • multi_processes (bool) – whether the workflow runs in a multi-process environment

  • model_params (dict) – which parameters to read/write

  • fl_ctx (FLContext) – FL system-wide context

  • options (dict, optional) – additional information on how to process read/write. Defaults to None.

Raises:

RuntimeError – if a decrypted vector cannot be reshaped to its network layer, or if the network layers and model_params do not match

Returns:

a list of the parameters that have been processed

Return type:

list
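The documented contract (write each received parameter into the local model, raise RuntimeError on a mismatch, return the processed parameters) can be sketched as follows. This is a simplified illustration under stated assumptions, not the NVFlare implementation: `apply_params` is a hypothetical name, a dict of NumPy arrays stands in for `network.state_dict()`, decryption is assumed to have already happened, and the sketch returns parameter names rather than whatever objects the real method returns.

```python
import numpy as np


def apply_params(local_state, model_params):
    """Hypothetical sketch of apply_model's contract.

    local_state: dict name -> np.ndarray (stand-in for network.state_dict()).
    model_params: dict name -> flat or shaped values from the global model.
    Returns the list of parameter names that were written.
    """
    processed = []
    for name, values in model_params.items():
        if name not in local_state:
            # Assumption: an unknown name is treated as a layer mismatch.
            raise RuntimeError(f"parameter {name!r} not found in local network")
        local = local_state[name]
        arr = np.asarray(values, dtype=local.dtype)
        if arr.size != local.size:
            raise RuntimeError(
                f"cannot reshape {arr.size} values for {name!r} into {local.shape}"
            )
        # Reshape using the local parameter's shape and overwrite it.
        local_state[name] = arr.reshape(local.shape)
        processed.append(name)
    return processed


local = {
    "fc.weight": np.zeros((2, 2), dtype=np.float32),
    "fc.bias": np.zeros(2, dtype=np.float32),
}
received = {"fc.weight": [1.0, 2.0, 3.0, 4.0], "fc.bias": [5.0, 6.0]}
done = apply_params(local, received)
print(done)  # ['fc.weight', 'fc.bias']
```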