Source code for nvflare.app_opt.pt.fedavg_early_stopping

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Any, Dict, List, Optional

import torch

from nvflare.app_common.abstract.fl_model import FLModel
from nvflare.app_common.utils.math_utils import parse_compare_criteria
from nvflare.app_common.workflows.base_fedavg import BaseFedAvg
from nvflare.app_opt.pt.decomposers import TensorDecomposer
from nvflare.fuel.utils import fobs


[docs] class PTFedAvgEarlyStopping(BaseFedAvg): def __init__( self, *args, stop_cond: Optional[str] = None, patience: Optional[int] = None, task_to_optimize: Optional[str] = "train", save_filename: Optional[str] = "FL_global_model.pt", initial_model: Optional[FLModel] = None, **kwargs, ) -> None: """Controller for FedAvg Workflow with Early Stopping and Model Selection. Args: num_clients (int, optional): The number of clients. Defaults to 3. num_rounds (int, optional): The total number of training rounds. Defaults to 5. stop_cond (str, optional): early stopping condition based on metric. String literal in the format of '\\<key\\> \\<op\\> \\<value\\>' (e.g. "accuracy >= 80") patience (int, optional): The number of checks with no improvement after which the FL will be stopped. If set to `None`, this parameter is disabled. If stop_condition is None, patience does not apply task_to_optimize (str, optional): Specifies whether to optimize the target metric on the training or validation task. Defaults is train. save_filename (str, optional): filename for saving model initial_model (nn.Module, optional): initial PyTorch model """ super().__init__(*args, **kwargs) self.patience = patience self.task_to_optimize = task_to_optimize self.num_fl_rounds_without_improvement: int = 0 self.stop_cond = stop_cond if self.stop_cond: self.stop_condition = parse_compare_criteria(stop_cond) else: self.stop_condition = None self.save_filename = save_filename self.initial_model: FLModel = initial_model self.best_target_metric_value: Any = None
[docs] def run(self) -> None: self.info("Start FedAvg.") if self.initial_model: # Use FOBS for serializing/deserializing PyTorch tensors (self.initial_model) fobs.register(TensorDecomposer) # PyTorch weights initial_weights = self.initial_model.state_dict() else: initial_weights = {} model = FLModel(params=initial_weights) model.start_round = self.start_round model.total_rounds = self.num_rounds for self.current_round in range(self.start_round, self.start_round + self.num_rounds): self.info(f"Round {self.current_round} started.") model.current_round = self.current_round clients = self.sample_clients(self.num_clients) results: List[FLModel] = self.send_model_and_wait( task_name=self.task_to_optimize, targets=clients, data=model ) # using default aggregate_fn with `WeightedAggregationHelper`. # Can overwrite self.aggregate_fn with signature Callable[List[FLModel], FLModel] aggregate_results = self.aggregate(results, aggregate_fn=self.aggregate_fn) model = self.update_model(model, aggregate_results) self.info(f"Round {self.current_round} global metrics: {model.metrics}") if self.is_curr_model_better(model): self.info("New best model found") self.save_model(model, os.path.join(os.getcwd(), self.save_filename)) else: if self.patience: self.info( f"No metric improvement, num of FL rounds without improvement: " f"{self.num_fl_rounds_without_improvement}" ) if self.should_stop(model.metrics): self.info(f"Stopping at round={self.current_round} out of total_rounds={self.num_rounds}.") break self.info("Finished FedAvg.")
[docs] def should_stop(self, metrics: Optional[Dict] = None) -> bool: """Checks whether the current FL experiment should stop. Args: metrics (Dict, optional): experiment metrics. Returns: True if the experiment should stop, False otherwise. """ if self.stop_condition is None or metrics is None: return False if self.patience and (self.patience <= self.num_fl_rounds_without_improvement): self.info(f"Exceeded the number of FL rounds ({self.patience}) without improvements") return True key, target, op_fn = self.stop_condition value = metrics.get(key, None) if value is None: raise RuntimeError(f"stop criteria key '{key}' doesn't exists in metrics") if op_fn(value, target): self.info(f"Early stop condition satisfied: {self.stop_condition}") return True return False
[docs] def is_curr_model_better(self, curr_model: FLModel) -> bool: """Checks if the new model is better than the current best model. Args: curr_model (FLModel): the new model to evaluate. Returns: True if the new model is better than the current best model, False otherwise """ if self.stop_condition is None: return True curr_metrics = curr_model.metrics if curr_metrics is None: return False target_metric, _, op_fn = self.stop_condition curr_target_metric = curr_metrics.get(target_metric, None) if curr_target_metric is None: return False if self.best_target_metric_value is None or op_fn(curr_target_metric, self.best_target_metric_value): if self.patience and self.best_target_metric_value == curr_target_metric: self.num_fl_rounds_without_improvement += 1 return False else: self.best_target_metric_value = curr_target_metric self.num_fl_rounds_without_improvement = 0 return True self.num_fl_rounds_without_improvement += 1 return False
[docs] def save_model(self, model: FLModel, filepath: Optional[str] = "") -> None: """Saves the model to the specified file path. Args: model (FLModel): model to save filepath (str, optional): location where the model will be saved """ params = model.params # PyTorch save torch.save(params, filepath) # save FLModel metadata model.params = {} fobs.dumpf(model, f"{filepath}.metadata") model.params = params
[docs] def load_model(self, filepath: Optional[str] = "") -> FLModel: """Loads a model from the provided file path. Args: filepath (str, optional): location of the saved model to load """ # PyTorch load params = torch.load(filepath) # load FLModel metadata model: FLModel = fobs.loadf(f"{filepath}.metadata") model.params = params return model