nvflare.app_opt.pt.fedavg_early_stopping module
- class PTFedAvgEarlyStopping(*args, stop_cond: str | None = None, patience: int | None = None, task_to_optimize: str | None = 'train', save_filename: str | None = 'FL_global_model.pt', initial_model: FLModel | None = None, **kwargs)
Bases: BaseFedAvg
Controller for FedAvg Workflow with Early Stopping and Model Selection.
- Parameters:
num_clients (int, optional) – The number of clients. Defaults to 3.
num_rounds (int, optional) – The total number of training rounds. Defaults to 5.
stop_cond (str, optional) – Early stopping condition based on a metric, given as a string literal in the format ‘<key> <op> <value>’ (e.g. “accuracy >= 80”). Defaults to None.
patience (int, optional) – The number of checks with no improvement after which FL training is stopped. If set to None, this parameter is disabled. If stop_cond is None, patience does not apply.
task_to_optimize (str, optional) – Specifies whether to optimize the target metric on the training or validation task. Defaults to ‘train’.
save_filename (str, optional) – Filename for saving the global model. Defaults to ‘FL_global_model.pt’.
initial_model (FLModel, optional) – Initial model to start training from. Defaults to None.
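The ‘<key> <op> <value>’ format of stop_cond can be illustrated with a small, self-contained parser. This is a sketch under assumptions, not NVFlare's implementation: the helpers parse_stop_cond and should_stop are hypothetical, the set of supported operators is assumed, and the metrics are passed as a plain dict rather than read from an FLModel.

```python
import operator

# Hypothetical operator table: which operators NVFlare actually accepts
# in stop_cond is an assumption here.
_OPS = {
    ">=": operator.ge,
    "<=": operator.le,
    ">": operator.gt,
    "<": operator.lt,
    "==": operator.eq,
}


def parse_stop_cond(stop_cond: str):
    """Split a '<key> <op> <value>' string into (key, op_fn, target)."""
    parts = stop_cond.split()
    if len(parts) != 3:
        raise ValueError(f"expected '<key> <op> <value>', got {stop_cond!r}")
    key, op, value = parts
    if op not in _OPS:
        raise ValueError(f"unsupported operator {op!r}")
    return key, _OPS[op], float(value)


def should_stop(metrics: dict, stop_cond: str) -> bool:
    """Return True if the metric named in stop_cond satisfies the condition."""
    key, op_fn, target = parse_stop_cond(stop_cond)
    if key not in metrics:
        return False  # metric not reported: condition cannot be met yet
    return op_fn(metrics[key], target)


if __name__ == "__main__":
    print(should_stop({"accuracy": 82.5}, "accuracy >= 80"))  # True
    print(should_stop({"accuracy": 71.0}, "accuracy >= 80"))  # False
```

Note that the sketch requires whitespace around the operator, since it splits on spaces, as in the documented example “accuracy >= 80”.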
- is_curr_model_better(curr_model: FLModel) → bool
Checks if the new model is better than the current best model.
- Parameters:
curr_model (FLModel) – the new model to evaluate.
- Returns:
True if the new model is better than the current best model, False otherwise.
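One plausible way to combine a best-model check with the patience counter can be sketched as follows. The class BestModelTracker and its methods are hypothetical and not part of NVFlare. The sketch assumes that only a strictly better metric value counts as an improvement (ties do not), that each call counts as one check, and that the metric is read from a plain dict instead of the FLModel passed to is_curr_model_better.

```python
import operator
from typing import Optional


class BestModelTracker:
    """Hypothetical helper: tracks the best value of one metric and counts
    consecutive checks without improvement (not NVFlare's API)."""

    def __init__(self, key: str, higher_is_better: bool = True,
                 patience: Optional[int] = None):
        self.key = key
        # Strict comparison: equal values do not count as an improvement.
        self.improves = operator.gt if higher_is_better else operator.lt
        self.patience = patience  # None disables patience-based stopping
        self.best: Optional[float] = None
        self.stale_checks = 0

    def is_better(self, metrics: dict) -> bool:
        """Return True and record a new best if metrics beat the best so far."""
        value = metrics.get(self.key)
        if value is None:
            return False  # metric missing: not counted as a check
        if self.best is None or self.improves(value, self.best):
            self.best = value
            self.stale_checks = 0
            return True
        self.stale_checks += 1
        return False

    def patience_exhausted(self) -> bool:
        """True once `patience` consecutive checks brought no improvement."""
        return self.patience is not None and self.stale_checks >= self.patience


if __name__ == "__main__":
    tracker = BestModelTracker("accuracy", higher_is_better=True, patience=2)
    for acc in [70.0, 75.0, 74.0, 75.0]:
        print(acc, tracker.is_better({"accuracy": acc}), tracker.patience_exhausted())
```

With patience=2, the run above reports improvements at 70.0 and 75.0, then two stale checks (74.0 and the tie at 75.0), after which patience_exhausted() returns True. Because the comparison is strict, a plateau at the best value counts against patience; whether NVFlare treats ties the same way is not stated in this reference.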
- load_model(filepath: str | None = '') → FLModel
Loads a model from the provided file path.
- Parameters:
filepath (str, optional) – Location of the saved model to load.