nvflare.app_opt.pt.model_reader_writer module

class PTModelReaderWriter

Bases: ModelProcessor

Perform the actual read/write operation for PyTorch-based models.
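The read/write contract described here can be pictured as an interface with two methods: one reads weights out of a network and one writes weights into it. The sketch below shows that shape as an abstract base class with a trivial dict-backed implementation and a round trip from a "global" network to a "local" one. `ModelProcessorSketch`, `DictNetwork`, and `DictReaderWriter` are hypothetical stand-ins, not NVFlare classes. `fl_ctx` is passed as `None` because the sketch never uses it.

```python
from abc import ABC, abstractmethod


class ModelProcessorSketch(ABC):
    """Hypothetical stand-in for a ModelProcessor-style interface (not the NVFlare class)."""

    @abstractmethod
    def extract_model(self, network, multi_processes: bool, model_vars: dict, fl_ctx) -> dict:
        """Read the current weights out of the network."""

    @abstractmethod
    def apply_model(self, network, multi_processes: bool, model_params: dict, fl_ctx, options=None):
        """Write the given weights into the network."""


class DictNetwork:
    """Hypothetical network whose weights live in a plain dict."""

    def __init__(self, weights):
        self.weights = dict(weights)


class DictReaderWriter(ModelProcessorSketch):
    def extract_model(self, network, multi_processes, model_vars, fl_ctx):
        # Return a copy so callers cannot mutate the network's weights by accident.
        return dict(network.weights)

    def apply_model(self, network, multi_processes, model_params, fl_ctx, options=None):
        network.weights.update(model_params)
        # Report which entries were written, loosely mirroring a "list of ops" return value.
        return sorted(model_params)


# Round trip: copy weights from a "global" network into a "local" one.
rw = DictReaderWriter()
global_net = DictNetwork({"w": 1.0, "b": 0.5})
local_net = DictNetwork({"w": 0.0, "b": 0.0})
params = rw.extract_model(global_net, False, {}, None)
ops = rw.apply_model(local_net, False, params, None)
print(ops, local_net.weights)  # ['b', 'w'] {'w': 1.0, 'b': 0.5}
```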

apply_model(network, multi_processes: bool, model_params: dict, fl_ctx: FLContext, options=None)

Set the local model according to model_params.

Parameters:
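  • network – training network

  • multi_processes – boolean indicating whether training runs in multiple processes

  • model_params – model parameters to apply to the network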

  • fl_ctx (FLContext) – FL Context delivered by workflow

  • options – optional settings. Defaults to None.

Raises:

RuntimeError – Raised when model_params cannot be applied to the network

Returns:

a list of the ops applied to the model
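A minimal sketch of what an apply step like this can look like appears below. It assumes that `multi_processes=True` means the network is wrapped the way `torch.nn.DataParallel` and `torch.nn.parallel.DistributedDataParallel` wrap a model, exposing the underlying model as `.module`. To stay runnable without PyTorch, `ToyModule` and `ToyWrapper` are hypothetical stand-ins that only imitate the `state_dict()` / `load_state_dict()` interface. `apply_model_sketch` is likewise hypothetical. Its "ops" are simply the names of the entries it overwrote, and any failure is re-raised as `RuntimeError`, matching the documented behavior.

```python
class ToyModule:
    """Hypothetical stand-in exposing PyTorch-style state_dict()/load_state_dict()."""

    def __init__(self, **params):
        self._params = dict(params)

    def state_dict(self):
        # Return a fresh dict so edits do not touch the module until loaded.
        return dict(self._params)

    def load_state_dict(self, state):
        self._params = dict(state)


class ToyWrapper:
    """Hypothetical stand-in for a multi-process wrapper that holds the model in .module."""

    def __init__(self, module):
        self.module = module


def apply_model_sketch(network, multi_processes, model_params):
    """Hypothetical apply step: merge model_params into the network's current weights."""
    try:
        # Unwrap to reach the real model when it sits inside a multi-process wrapper.
        net = network.module if multi_processes else network
        state = net.state_dict()
        ops = []
        for name, value in model_params.items():
            if name not in state:
                raise KeyError(f"unknown parameter {name!r}")
            state[name] = value
            ops.append(name)
        # Load only after every entry validated, so a failure leaves the model unchanged.
        net.load_state_dict(state)
        return ops
    except Exception as e:
        raise RuntimeError(f"unable to apply model_params: {e}") from e


wrapped = ToyWrapper(ToyModule(weight=0.0, bias=0.0))
print(apply_model_sketch(wrapped, True, {"weight": 2.0}))  # ['weight']
print(wrapped.module.state_dict())  # {'weight': 2.0, 'bias': 0.0}
```

Building the full state first and loading it once is a design choice of this sketch. It means a bad key aborts the whole update instead of leaving the network half-updated.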

extract_model(network, multi_processes: bool, model_vars: dict, fl_ctx: FLContext) → dict

Extract the current model from the training network.

Parameters:
  • network – training network

  • multi_processes – boolean indicating whether training runs in multiple processes

  • model_vars – global model dict

  • fl_ctx – FLContext

Returns:

a dictionary representing the model
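The extract direction can be sketched the same way. This assumes that `multi_processes=True` means the model sits inside a wrapper exposing it as `.module`, as `torch.nn.DataParallel` and `DistributedDataParallel` do. `ToyModule`, `ToyWrapper`, and `extract_model_sketch` are hypothetical stand-ins written so the example runs without PyTorch. The sketch accepts `model_vars` only to keep the same argument order and does not use it.

```python
class ToyModule:
    """Hypothetical stand-in exposing a PyTorch-style state_dict()."""

    def __init__(self, **params):
        self._params = dict(params)

    def state_dict(self):
        # Return a fresh dict so callers cannot mutate the module's weights.
        return dict(self._params)


class ToyWrapper:
    """Hypothetical stand-in for a multi-process wrapper that holds the model in .module."""

    def __init__(self, module):
        self.module = module


def extract_model_sketch(network, multi_processes, model_vars):
    """Hypothetical extract step: return the unwrapped network's weights as a dict."""
    net = network.module if multi_processes else network
    # model_vars (the global model dict) is accepted for signature parity but unused here.
    return net.state_dict()


inner = ToyModule(weight=1.5, bias=-0.5)
print(extract_model_sketch(ToyWrapper(inner), True, {}))  # {'weight': 1.5, 'bias': -0.5}
print(extract_model_sketch(inner, False, {}))  # {'weight': 1.5, 'bias': -0.5}
```

Unwrapping before reading matters with real PyTorch wrappers. The `state_dict()` of a `DataParallel` or `DistributedDataParallel` object prefixes every key with `module.`, so reading from the wrapper directly would produce keys that do not match an unwrapped model.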