nvflare.app_opt.pt.fedavg module

class PTFedAvg(*args, stop_cond: str | None = None, patience: int | None = None, task_name: str | None = 'train', save_filename: str | None = 'FL_global_model.pt', model: Module | dict | FLModel | None = None, **kwargs)

Bases: FedAvg

PyTorch FedAvg Controller with Early Stopping and Model Selection.

This is a PyTorch-specific wrapper around the generic FedAvg controller. It adds PyTorch-specific model serialization using torch.save/torch.load.

The FedAvg controller includes:
  • InTime (streaming) aggregation for memory efficiency

  • Early stopping support

  • Best model selection and saving

  • Custom aggregator support
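InTime (streaming) aggregation saves memory because each client result can be folded into a running total as soon as it arrives, rather than buffering every result until the round ends. The sketch below illustrates the idea with plain Python floats. The class name `StreamingWeightedAverage` is hypothetical, and weighting each client by a local sample count is an assumption; this is not a description of NVFlare's aggregator.

```python
from typing import Dict


class StreamingWeightedAverage:
    """Hypothetical streaming aggregator (not the NVFlare aggregator API).

    Only the running weighted sum and the total weight are kept, so each
    client's parameters can be dropped as soon as add() returns.
    """

    def __init__(self) -> None:
        self._weighted_sum: Dict[str, float] = {}
        self._total_weight = 0.0

    def add(self, params: Dict[str, float], weight: float) -> None:
        # Fold one client's parameters into the running weighted sum.
        for name, value in params.items():
            self._weighted_sum[name] = self._weighted_sum.get(name, 0.0) + weight * value
        self._total_weight += weight

    def result(self) -> Dict[str, float]:
        # Divide once at the end of the round to obtain the weighted mean.
        if self._total_weight == 0:
            raise ValueError("no client results were added")
        return {name: total / self._total_weight for name, total in self._weighted_sum.items()}


agg = StreamingWeightedAverage()
agg.add({"w": 1.0}, weight=10)  # e.g. a client with 10 local samples
agg.add({"w": 4.0}, weight=20)  # e.g. a client with 20 local samples
print(agg.result())  # {'w': 3.0}
```

Memory use stays proportional to one model regardless of how many clients report in a round.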

Parameters:
  • num_clients (int, optional) – The number of clients. Defaults to 3.

  • num_rounds (int, optional) – The total number of training rounds. Defaults to 5.

  • stop_cond (str, optional) – Early stopping condition based on metric. String literal in the format of ‘<key> <op> <value>’ (e.g. “accuracy >= 80”). If None, early stopping is disabled.

  • patience (int, optional) – The number of rounds with no improvement after which FL will be stopped. Only applies if stop_cond is set. Defaults to None.

  • task_name (str, optional) – Task name for training. Defaults to “train”.

  • save_filename (str, optional) – Filename for saving the best model. Defaults to “FL_global_model.pt”.

  • model (nn.Module | dict | FLModel, optional) – Initial PyTorch model. Can be an nn.Module (its .state_dict() is used), a dict of parameters, or an FLModel. Defaults to None.
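To make the stop_cond format and the interaction with patience concrete, here is a minimal sketch, assuming stop_cond uses one of the comparison operators >=, >, <= or < and that "improvement" means moving in the direction the operator favors. `parse_stop_cond` and `EarlyStopper` are hypothetical helpers, not PTFedAvg internals.

```python
import operator
from typing import Callable, Dict, Optional, Tuple

# Operators this sketch accepts in '<key> <op> <value>' (an assumed subset).
_OPS: Dict[str, Callable[[float, float], bool]] = {
    ">=": operator.ge,
    ">": operator.gt,
    "<=": operator.le,
    "<": operator.lt,
}


def parse_stop_cond(stop_cond: str) -> Tuple[str, Callable[[float, float], bool], float]:
    """Split e.g. 'accuracy >= 80' into ('accuracy', operator.ge, 80.0)."""
    key, op, value = stop_cond.split()
    if op not in _OPS:
        raise ValueError(f"unsupported operator: {op!r}")
    return key, _OPS[op], float(value)


class EarlyStopper:
    """Hypothetical helper combining stop_cond and patience."""

    def __init__(self, stop_cond: str, patience: Optional[int] = None) -> None:
        self.key, self.op, self.target = parse_stop_cond(stop_cond)
        self.higher_is_better = self.op in (operator.ge, operator.gt)
        self.patience = patience
        self.best: Optional[float] = None
        self.stale_rounds = 0

    def should_stop(self, metrics: Dict[str, float]) -> bool:
        value = metrics[self.key]
        # 1) Stop as soon as the condition itself is met.
        if self.op(value, self.target):
            return True
        # 2) Otherwise track improvement in the direction implied by the operator.
        improved = self.best is None or (
            value > self.best if self.higher_is_better else value < self.best
        )
        if improved:
            self.best, self.stale_rounds = value, 0
        else:
            self.stale_rounds += 1
        # 3) With patience set, stop after that many rounds without improvement.
        return self.patience is not None and self.stale_rounds >= self.patience


stopper = EarlyStopper("accuracy >= 80", patience=2)
print([stopper.should_stop({"accuracy": a}) for a in (50, 60, 55, 58)])
# [False, False, False, True]
```

Without patience, only the threshold in stop_cond can end training early.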

Example

```python
from model import Net  # your own module defining a torch.nn.Module

from nvflare import FedJob
from nvflare.app_opt.pt.fedavg import PTFedAvg

job = FedJob(name="pt_fedavg")

# Stop once global accuracy reaches 80, or after 3 rounds without improvement.
controller = PTFedAvg(
    num_clients=2,
    num_rounds=10,
    stop_cond="accuracy >= 80",
    patience=3,
    model=Net(),
)
job.to(controller, "server")
```

The base controller for the FedAvg workflow, from which PTFedAvg inherits. Note: this class is based on ModelController.

Implements [FederatedAveraging](https://arxiv.org/abs/1602.05629).

A model persistor can be configured via the persistor_id argument of the ModelController. The persistor loads the initial global model, which is sent to a list of clients. After local training, each client sends its updated weights back; these are aggregated, and the global model is updated with the result. The model persistor also saves the model after training.
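The round structure described above (load the initial model, send it to clients, aggregate their updates, update the global model, save at the end) can be sketched without NVFlare. Everything here is a hypothetical stand-in: `InMemoryPersistor` replaces a real model persistor, clients are plain functions returning updated parameters together with a local sample count, and aggregation uses the sample-weighted average from the Federated Averaging paper.

```python
from typing import Callable, Dict, List, Tuple

Params = Dict[str, float]
# A client takes the global params and returns (updated params, local sample count).
Client = Callable[[Params], Tuple[Params, int]]


class InMemoryPersistor:
    """Hypothetical stand-in for a model persistor."""

    def __init__(self, initial: Params) -> None:
        self.stored = dict(initial)

    def load(self) -> Params:
        return dict(self.stored)

    def save(self, params: Params) -> None:
        self.stored = dict(params)


def make_client(target: float, num_samples: int) -> Client:
    """Toy 'local training': move every parameter halfway toward `target`."""
    def train(params: Params) -> Tuple[Params, int]:
        return {k: v + 0.5 * (target - v) for k, v in params.items()}, num_samples
    return train


def run_fedavg(persistor: InMemoryPersistor, clients: List[Client], num_rounds: int) -> Params:
    global_model = persistor.load()  # initial global model
    for _ in range(num_rounds):
        # Send the current global model to every client and collect updates.
        results = [client(dict(global_model)) for client in clients]
        total = sum(n for _, n in results)
        # Aggregate: average client params weighted by local sample count.
        global_model = {k: sum(p[k] * n for p, n in results) / total for k in global_model}
    persistor.save(global_model)  # persist the final global model
    return global_model


persistor = InMemoryPersistor({"w": 0.0})
clients = [make_client(target=2.0, num_samples=1), make_client(target=6.0, num_samples=3)]
print(run_fedavg(persistor, clients, num_rounds=1))  # {'w': 2.5}
```

Because each toy client moves halfway toward its own target, repeated rounds converge to the sample-weighted mean of the targets (5.0 here).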

Provides default implementations for the following routines:
  • def aggregate(self, results: List[FLModel], aggregate_fn=None) -> FLModel

  • def update_model(self, aggr_result)

The run routine needs to be implemented by the derived class:

  • def run(self)
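The split above (default aggregate and update_model, with run left to subclasses) is a template-method design. A minimal sketch using Python's abc module follows; `ClientResult`, `BaseFedAvgSketch`, `median_aggregate` and `OneShotFedAvg` are hypothetical names, and the signatures are simplified relative to NVFlare's FLModel-based ones. It also shows how an aggregate_fn can replace the default weighted average.

```python
from abc import ABC, abstractmethod
from statistics import median
from typing import Callable, Dict, List, Optional

Params = Dict[str, float]


class ClientResult:
    """Hypothetical stand-in for FLModel: parameters plus a sample count."""

    def __init__(self, params: Params, num_samples: int) -> None:
        self.params = params
        self.num_samples = num_samples


AggregateFn = Callable[[List[ClientResult]], Params]


class BaseFedAvgSketch(ABC):
    def __init__(self, initial: Params) -> None:
        self.model = dict(initial)

    def aggregate(self, results: List[ClientResult], aggregate_fn: Optional[AggregateFn] = None) -> Params:
        # A caller-supplied aggregate_fn replaces the default weighted average.
        if aggregate_fn is not None:
            return aggregate_fn(results)
        total = sum(r.num_samples for r in results)
        return {k: sum(r.params[k] * r.num_samples for r in results) / total for k in self.model}

    def update_model(self, aggr_result: Params) -> None:
        # Default: the aggregate simply becomes the new global model.
        self.model = dict(aggr_result)

    @abstractmethod
    def run(self) -> None:
        """Each concrete workflow defines its own round loop."""


def median_aggregate(results: List[ClientResult]) -> Params:
    """Example custom aggregator: coordinate-wise median, ignoring sample counts."""
    return {k: median(r.params[k] for r in results) for k in results[0].params}


class OneShotFedAvg(BaseFedAvgSketch):
    """Concrete workflow: a single round over precomputed client results."""

    def __init__(self, initial: Params, results: List[ClientResult],
                 aggregate_fn: Optional[AggregateFn] = None) -> None:
        super().__init__(initial)
        self.results = results
        self.aggregate_fn = aggregate_fn

    def run(self) -> None:
        self.update_model(self.aggregate(self.results, self.aggregate_fn))


results = [ClientResult({"w": 1.0}, 1), ClientResult({"w": 2.0}, 1), ClientResult({"w": 9.0}, 2)]
default_wf = OneShotFedAvg({"w": 0.0}, results)
default_wf.run()
median_wf = OneShotFedAvg({"w": 0.0}, results, aggregate_fn=median_aggregate)
median_wf.run()
print(default_wf.model, median_wf.model)  # {'w': 5.25} {'w': 2.0}
```

Instantiating the base class directly raises TypeError, which is how abc enforces that run is implemented.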

Parameters:
  • num_clients (int, optional) – The number of clients. Defaults to 3. NOTE: this argument should not be here; it will be removed in the next release.

  • num_rounds (int, optional) – The total number of training rounds. Defaults to 5.

  • start_round (int, optional) – The starting round number.

  • memory_gc_rounds (int, optional) – Run memory cleanup (gc.collect + malloc_trim) every N rounds. Set to 0 to disable. Defaults to 0 (disabled).
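A rough sketch of what "gc.collect + malloc_trim every N rounds" can look like in plain Python. The helper names are hypothetical and the round-counting convention (cleanup after rounds N, 2N, ... counting from 1) is an assumption. malloc_trim is a glibc function, so this sketch only calls it when it can be loaded through ctypes.

```python
import ctypes
import ctypes.util
import gc


def cleanup_memory() -> None:
    """Run Python's GC, then ask glibc to return freed heap pages to the OS."""
    gc.collect()
    libc_name = ctypes.util.find_library("c")
    if libc_name is None:
        return
    try:
        libc = ctypes.CDLL(libc_name)
        libc.malloc_trim(0)  # glibc-only; missing on macOS and musl
    except (OSError, AttributeError):
        pass


def maybe_cleanup(round_num: int, memory_gc_rounds: int) -> bool:
    """Run cleanup after every memory_gc_rounds-th round (round_num is 0-based).

    Returns True if cleanup ran. A value of 0 (or less) disables cleanup.
    """
    if memory_gc_rounds <= 0:
        return False
    if (round_num + 1) % memory_gc_rounds == 0:
        cleanup_memory()
        return True
    return False


print([maybe_cleanup(r, 2) for r in range(4)])  # [False, True, False, True]
```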

load_model_file(filepath: str) -> FLModel

Load model using PyTorch’s torch.load.

Loads parameters via torch.load and FLModel metadata via FOBS.

Parameters:

filepath (str) – path to load the model from

Returns:

loaded model with params and metadata

Return type:

FLModel

run() -> None

Run FedAvg workflow with PyTorch tensor serialization support.

save_model_file(model: FLModel, filepath: str) -> None

Save model using PyTorch’s torch.save.

Saves parameters via torch.save and FLModel metadata via FOBS.

Parameters:
  • model (FLModel) – model to save

  • filepath (str) – path to save the model
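The two-part storage scheme used by save_model_file and load_model_file (parameters through torch.save/torch.load, metadata through FOBS) can be illustrated with stdlib stand-ins: pickle in place of torch.save/torch.load and JSON in place of FOBS. The ".meta.json" sidecar file and the function names `save_model_sketch` and `load_model_sketch` are assumptions for illustration only; the actual on-disk layout is not documented here.

```python
import json
import os
import pickle
import tempfile
from typing import Any, Dict, Tuple

META_SUFFIX = ".meta.json"  # assumed sidecar naming, for illustration only


def save_model_sketch(params: Dict[str, Any], meta: Dict[str, Any], filepath: str) -> None:
    # Parameters: pickle stands in for torch.save.
    with open(filepath, "wb") as f:
        pickle.dump(params, f)
    # Metadata: JSON stands in for FOBS serialization.
    with open(filepath + META_SUFFIX, "w", encoding="utf-8") as f:
        json.dump(meta, f)


def load_model_sketch(filepath: str) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    with open(filepath, "rb") as f:
        params = pickle.load(f)  # torch.load in the real controller
    with open(filepath + META_SUFFIX, encoding="utf-8") as f:
        meta = json.load(f)  # FOBS in the real controller
    return params, meta


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "FL_global_model.pt")
    save_model_sketch({"w": [1.0, 2.0]}, {"current_round": 4}, path)
    print(load_model_sketch(path))  # ({'w': [1.0, 2.0]}, {'current_round': 4})
```

Keeping tensors in PyTorch's native format means the saved parameters can still be opened with torch.load outside of FLARE.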

PTFedAvgEarlyStopping

alias of PTFedAvg