nvflare.app_opt.pt.fedavg_early_stopping module

class PTFedAvgEarlyStopping(*args, stop_cond: str | None = None, save_filename: str = 'FL_global_model.pt', initial_model=None, **kwargs)

Bases: BaseFedAvg

Controller for FedAvg Workflow with Early Stopping and Model Selection.

Parameters:
  • num_clients (int, optional) – The number of clients. Defaults to 3.

  • num_rounds (int, optional) – The total number of training rounds. Defaults to 5.

  • stop_cond (str, optional) – Early stopping condition based on a metric, given as a string literal in the format “<key> <op> <value>” (e.g. “accuracy >= 80”). Defaults to None.

  • save_filename (str, optional) – Filename for saving the model. Defaults to 'FL_global_model.pt'.

  • initial_model (nn.Module, optional) – Initial PyTorch model. Defaults to None.
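
To illustrate the “<key> <op> <value>” format of stop_cond, here is a minimal sketch of how such a string could be parsed and checked against a metrics dictionary. The helper names parse_stop_cond and condition_met and the operator table are assumptions made for this example; they are not the controller's own code, and the set of operators it accepts may differ.

```python
import operator
from typing import Callable, Dict, Optional, Tuple

# Hypothetical operator table; the operators the controller accepts may differ.
_OPS: Dict[str, Callable[[float, float], bool]] = {
    ">=": operator.ge,
    "<=": operator.le,
    ">": operator.gt,
    "<": operator.lt,
    "==": operator.eq,
}


def parse_stop_cond(stop_cond: str) -> Tuple[str, Callable[[float, float], bool], float]:
    """Split '<key> <op> <value>' into (key, comparison function, threshold)."""
    parts = stop_cond.split()
    if len(parts) != 3:
        raise ValueError(f"expected '<key> <op> <value>', got {stop_cond!r}")
    key, op, value = parts
    if op not in _OPS:
        raise ValueError(f"unsupported operator {op!r}")
    return key, _OPS[op], float(value)


def condition_met(metrics: Optional[Dict[str, float]], stop_cond: Optional[str]) -> bool:
    """Return True when the metric named in stop_cond satisfies the condition."""
    if not stop_cond or not metrics:
        return False  # nothing to check, so never stop early
    key, op_fn, threshold = parse_stop_cond(stop_cond)
    if key not in metrics:
        return False  # the metric was not reported
    return op_fn(metrics[key], threshold)
```

With this sketch, condition_met({"accuracy": 85.0}, "accuracy >= 80") evaluates to True, while an accuracy of 75.0 leaves the condition unmet.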

The base controller for FedAvg Workflow. Note: This class is based on the ModelController.

Implements [FederatedAveraging](https://arxiv.org/abs/1602.05629).

A model persistor can be configured via the persistor_id argument of the ModelController. The model persistor loads the initial global model, which is sent to a list of clients. Each client sends its updated weights after local training, and these are aggregated. Next, the global model is updated. The model_persistor will also save the model after training.
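
The aggregation step described above can be sketched as a sample-weighted average, as in the FederatedAveraging paper. This is a simplified illustration using plain Python lists; the function name fedavg and the (weights, num_samples) tuple layout are assumptions for this example and do not reflect how the controller represents client results.

```python
from typing import Dict, List, Tuple

Weights = Dict[str, List[float]]


def fedavg(updates: List[Tuple[Weights, int]]) -> Weights:
    """Average client weights, weighting each client by its sample count.

    `updates` holds one (weights, num_samples) pair per client
    (a hypothetical layout chosen for this sketch).
    """
    total = sum(n for _, n in updates)
    if total <= 0:
        raise ValueError("total number of samples must be positive")
    first_weights = updates[0][0]
    averaged: Weights = {}
    for key, values in first_weights.items():
        # Weighted mean, computed element by element for each parameter.
        averaged[key] = [
            sum(w[key][i] * n for w, n in updates) / total
            for i in range(len(values))
        ]
    return averaged
```

For example, averaging {"w": [1.0, 2.0]} from a client with 1 sample and {"w": [3.0, 4.0]} from a client with 3 samples gives {"w": [2.5, 3.5]}.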

Provides the default implementations for the following routines:
  • def aggregate(self, results: List[FLModel], aggregate_fn=None) -> FLModel

  • def update_model(self, aggr_result)

The run routine needs to be implemented by the derived class:

  • def run(self)
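
The division of labor between aggregate, update_model, and a derived run can be sketched with a toy class. ToyFedAvg below is a hypothetical stand-in written for this example, not the NVFlare base class: the model is a single scalar, clients are plain callables, and aggregation is an unweighted mean.

```python
from typing import Callable, List


class ToyFedAvg:
    """Hypothetical stand-in for a FedAvg-style controller (not NVFlare code)."""

    def __init__(
        self,
        clients: List[Callable[[float], float]],
        num_rounds: int = 5,
        start_round: int = 0,
    ):
        self.clients = clients
        self.num_rounds = num_rounds
        self.start_round = start_round
        self.model = 0.0  # a single scalar "weight" keeps the sketch small
        self.history: List[float] = []

    def aggregate(self, results: List[float]) -> float:
        # Unweighted mean of the client results.
        return sum(results) / len(results)

    def update_model(self, aggr_result: float) -> None:
        # Replace the global model with the aggregated result.
        self.model = aggr_result

    def run(self) -> None:
        # One iteration per round: train locally, aggregate, update.
        for _ in range(self.start_round, self.start_round + self.num_rounds):
            results = [train(self.model) for train in self.clients]
            self.update_model(self.aggregate(results))
            self.history.append(self.model)
```

Running ToyFedAvg([lambda w: w + 1.0, lambda w: w + 3.0], num_rounds=3).run() moves the scalar model through 2.0, 4.0, and 6.0, one value per round.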

Parameters:
  • num_clients (int, optional) – The number of clients. Defaults to 3.

  • num_rounds (int, optional) – The total number of training rounds. Defaults to 5.

  • start_round (int, optional) – The starting round number.

is_curr_model_better(best_model: FLModel, curr_model: FLModel, target_metric: str, op_fn: Callable) → bool

load_model(filepath='')

Load initial model from persistor. If persistor is not configured, returns empty FLModel.

Returns:

FLModel

run() → None

Main run routine for the controller workflow.

save_model(model, filepath='')

Saves model with persistor. If persistor is not configured, does not save.

Parameters:

model (FLModel) – model to save.

Returns:

None

select_best_model(curr_model: FLModel)

should_stop(metrics: Dict | None = None, stop_condition: str | None = None)
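
The signatures of is_curr_model_better and select_best_model suggest that models are compared on one target metric using a comparison callable (op_fn). The sketch below shows one way such a comparison could work. ToyModel, is_better, and select_best are hypothetical names for this example, and the handling of missing metrics and ties is an assumption, since the source does not describe it.

```python
import operator
from dataclasses import dataclass, field
from typing import Callable, Dict, Iterable, Optional


@dataclass
class ToyModel:
    """Hypothetical stand-in for FLModel, carrying only a metrics dict."""
    name: str
    metrics: Dict[str, float] = field(default_factory=dict)


def is_better(
    best: Optional[ToyModel],
    curr: ToyModel,
    target_metric: str,
    op_fn: Callable[[float, float], bool],
) -> bool:
    """Return True when curr should replace best on target_metric."""
    curr_value = curr.metrics.get(target_metric)
    if curr_value is None:
        return False  # assumption: a model without the metric never wins
    if best is None or target_metric not in best.metrics:
        return True  # assumption: any scored model beats an unscored one
    return op_fn(curr_value, best.metrics[target_metric])


def select_best(
    models: Iterable[ToyModel],
    target_metric: str,
    op_fn: Callable[[float, float], bool],
) -> Optional[ToyModel]:
    """Keep the best model seen so far while iterating over rounds."""
    best: Optional[ToyModel] = None
    for model in models:
        if is_better(best, model, target_metric, op_fn):
            best = model
    return best
```

Passing operator.ge selects the model with the highest value (suitable for accuracy), while operator.le selects the lowest (suitable for loss).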