Source code for nvflare.app_common.workflows.fedavg

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import tempfile
import time
from typing import Any, Dict, Optional, Set, Union

from nvflare.apis.fl_constant import FLMetaKey
from nvflare.app_common.abstract.fl_model import FLModel
from nvflare.app_common.aggregators.model_aggregator import ModelAggregator
from nvflare.app_common.aggregators.weighted_aggregation_helper import (
    WeightedAggregationHelper,
    filter_aggregatable_metrics,
)
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.app_event_type import AppEventType
from nvflare.app_common.utils.math_utils import parse_compare_criteria
from nvflare.app_common.utils.tensor_disk_offload_context import (
    apply_enable_tensor_disk_offload,
    restore_enable_tensor_disk_offload,
)
from nvflare.fuel.utils import fobs
from nvflare.fuel.utils.log_utils import center_message

from .base_fedavg import BaseFedAvg


class FedAvg(BaseFedAvg):
    """Controller for FedAvg Workflow with optional Early Stopping and Model Selection.

    *Note*: This class is based on the `ModelController`.
    Implements [FederatedAveraging](https://arxiv.org/abs/1602.05629).

    Uses InTime (streaming) aggregation for memory efficiency - each client result is
    aggregated immediately upon receipt rather than collecting all results first.
    Supports custom aggregators via the ModelAggregator interface.

    Provides the implementations for the `run` routine, controlling the main workflow:
        - def run(self)

    The parent classes provide the default implementations for other routines.

    For simple model persistence without complex ModelPersistor setup, you can:
        1. Pass `model` (dict of params) and `save_filename`
        2. Override `save_model_file()` and `load_model_file()` for framework-specific serialization

    Args:
        num_clients (int, optional): The number of clients. Defaults to 3.
        num_rounds (int, optional): The total number of training rounds. Defaults to 5.
        start_round (int, optional): The starting round number.
        persistor_id (str, optional): ID of the persistor component. Defaults to "persistor".
            If empty and model is provided, uses simple save_model/load_model methods.
        model (dict or FLModel, optional): Initial model parameters. If provided, this is used
            instead of loading from persistor. Defaults to None.
        save_filename (str, optional): Filename for saving the global model (the best model when
            `stop_cond` is set, otherwise the latest model of every round).
            Defaults to "FL_global_model.pt". Only used when no persistor is configured.
        aggregator (ModelAggregator, optional): Custom aggregator for combining client model updates.
            Must implement accept_model(), aggregate_model(), reset_stats().
            If None, uses built-in weighted averaging (memory-efficient). Defaults to None.
        stop_cond (str, optional): Early stopping condition based on metric. String literal in the
            format of '<key> <op> <value>' (e.g. "accuracy >= 80"). If None, early stopping is
            disabled. Defaults to None.
        patience (int, optional): The number of rounds with no improvement after which FL will be
            stopped. Only applies if stop_cond is set. Defaults to None.
        task_name (str, optional): Task name for training. Defaults to "train".
        exclude_vars (str, optional): Regex pattern for variables to exclude from aggregation.
            Defaults to None. Only used when no custom aggregator is provided.
        aggregation_weights (dict, optional): Per-client aggregation weights. Defaults to None
            (weight 1.0 for every client). The effective weight of a client is this value
            multiplied by the client's reported NUM_STEPS_CURRENT_ROUND (or 1 if not reported).
            Only used when no custom aggregator is provided.
        enable_tensor_disk_offload (bool, optional): Download tensors to disk during FOBS streaming
            instead of deserializing into memory. Reduces peak server memory from ~N× to ~1× model
            size during aggregation. When used with a custom aggregator, lazy refs are passed
            through directly and must be handled by that aggregator. Defaults to False.
    """

    def __init__(
        self,
        *args,
        model: Optional[Union[Dict, FLModel]] = None,
        save_filename: Optional[str] = "FL_global_model.pt",
        aggregator: Optional[ModelAggregator] = None,
        stop_cond: Optional[str] = None,
        patience: Optional[int] = None,
        task_name: Optional[str] = "train",
        exclude_vars: Optional[str] = None,
        aggregation_weights: Optional[Dict[str, float]] = None,
        enable_tensor_disk_offload: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)

        # Simple model persistence (alternative to persistor)
        self.model = model
        self.save_filename = save_filename

        # Custom aggregator (optional)
        self.aggregator = aggregator

        # Early stopping configuration
        self.stop_cond = stop_cond
        self.patience = patience
        self.task_name = task_name

        # Aggregation configuration (used only when no custom aggregator)
        self.exclude_vars = exclude_vars
        self.aggregation_weights = aggregation_weights or {}
        self.enable_tensor_disk_offload = enable_tensor_disk_offload

        # Parse stop condition into a (key, target, op_fn) tuple, or None when disabled.
        if self.stop_cond:
            self.stop_condition = parse_compare_criteria(stop_cond)
        else:
            self.stop_condition = None

        # Early stopping state
        self.num_fl_rounds_without_improvement: int = 0
        self.best_target_metric_value: Any = None

        # InTime aggregation helpers (reset each round, used only when no custom aggregator)
        self._aggr_helper: Optional[WeightedAggregationHelper] = None
        self._aggr_metrics_helper: Optional[WeightedAggregationHelper] = None
        self._all_metrics: bool = True
        self._warned_metric_keys: Set[str] = set()  # warn at most once per key (across clients/rounds)
        self._received_count: int = 0
        self._expected_count: int = 0
        self._params_type = None  # Only store params_type, not full result

    def run(self) -> None:
        """Execute the FedAvg workflow.

        For each round: sample clients, send them the current global model, aggregate each
        result as it arrives (see `_aggregate_one_result`), update the global model and save it
        (every round, or only when it improves if `stop_cond` is set). Tensor disk offload is
        applied for the duration of the run and restored afterwards; the temporary offload
        directory is removed on exit.
        """
        disk_offload_root_dir = None
        previous_disk_offload = None
        try:
            if self.enable_tensor_disk_offload:
                disk_offload_root_dir = tempfile.mkdtemp(
                    prefix=f"nvflare_tensor_offload_{self.fl_ctx.get_job_id('job')}_"
                )
            previous_disk_offload, disk_offload_applied = apply_enable_tensor_disk_offload(
                engine=getattr(self, "engine", None),
                enabled=self.enable_tensor_disk_offload,
                root_dir=disk_offload_root_dir,
            )
            if self.enable_tensor_disk_offload and not disk_offload_applied:
                self.warning(
                    "enable_tensor_disk_offload=True but no active cell is available; "
                    "falling back to in-memory tensor download"
                )

            self.info(center_message("Start FedAvg."))

            # Set NUM_ROUNDS in FL context for persistor and other components.
            self.fl_ctx.set_prop(AppConstants.NUM_ROUNDS, self.num_rounds, private=True, sticky=False)

            # Load initial model - prefer model if provided, else use persistor
            if self.model is not None:
                if isinstance(self.model, FLModel):
                    model = self.model
                else:
                    # Assume dict of params
                    model = FLModel(params=self.model)
                self.info("Using provided model")
            else:
                model = self.load_model()

            # load_model() may return None when loading fails; stop cleanly instead of
            # failing below with an AttributeError (the finally block still runs).
            if model is None:
                self.warning("Initial model could not be loaded. Finishing FedAvg.")
                return

            model.start_round = self.start_round
            model.total_rounds = self.num_rounds

            for self.current_round in range(self.start_round, self.start_round + self.num_rounds):
                self.info(center_message(message=f"Round {self.current_round} started.", boarder_str="-"))
                model.current_round = self.current_round

                self.fl_ctx.set_prop(AppConstants.CURRENT_ROUND, self.current_round, private=True, sticky=False)
                # Custom aggregators are duck-typed (only accept/aggregate/reset are required),
                # so do not assume they expose an `fl_ctx` attribute.
                aggregator_ctx = getattr(self.aggregator, "fl_ctx", None) if self.aggregator else None
                if aggregator_ctx:
                    aggregator_ctx.set_prop(
                        AppConstants.CURRENT_ROUND, self.current_round, private=True, sticky=False
                    )
                self.event(AppEventType.ROUND_STARTED)

                clients = self.sample_clients(self.num_clients)

                # Reset aggregation state for this round
                if self.aggregator:
                    # Use custom aggregator
                    self.aggregator.reset_stats()
                else:
                    # Use built-in InTime aggregation
                    self._aggr_helper = WeightedAggregationHelper(exclude_vars=self.exclude_vars)
                    self._aggr_metrics_helper = WeightedAggregationHelper()
                    self._all_metrics = True  # Only used by built-in aggregation

                # Shared state for both aggregator types
                self._received_count = 0
                self._expected_count = len(clients)
                self._params_type = None

                # Non-blocking send with callback for streaming aggregation
                self.send_model(
                    task_name=self.task_name,
                    targets=clients,
                    data=model,
                    callback=self._aggregate_one_result,
                )

                # Wait for all results to be processed
                while self.get_num_standing_tasks():
                    if self.abort_signal.triggered:
                        self.info("Abort signal triggered. Finishing FedAvg.")
                        return
                    time.sleep(self._task_check_period)

                self.event(AppEventType.BEFORE_AGGREGATION)

                # Get final aggregated result
                aggregate_results = self._get_aggregated_result()
                model = self.update_model(model, aggregate_results)

                if self.stop_condition:
                    # Early stopping: save only when the current model improves on the best one.
                    self.info(f"Round {self.current_round} global metrics: {model.metrics}")

                    if self.is_curr_model_better(model):
                        self.info("New best model found.")
                        self.save_model(model)
                    elif self.patience:
                        self.info(
                            f"No metric improvement, num of FL rounds without improvement: "
                            f"{self.num_fl_rounds_without_improvement}"
                        )

                    # Check if we should stop early
                    if self.should_stop(model.metrics):
                        self.info(f"Stopping at round={self.current_round} out of total_rounds={self.num_rounds}.")
                        break
                else:
                    # No early stopping: save model every round
                    self.save_model(model)

                # Memory cleanup at end of round (if configured)
                self._maybe_cleanup_memory()

            self.info(center_message("Finished FedAvg."))
        finally:
            restore_enable_tensor_disk_offload(
                engine=getattr(self, "engine", None),
                previous_value=previous_disk_offload,
                root_dir=disk_offload_root_dir,
            )
            if disk_offload_root_dir:
                shutil.rmtree(disk_offload_root_dir, ignore_errors=True)

    def _aggregate_one_result(self, result: FLModel) -> None:
        """Callback: aggregate ONE client result immediately (InTime aggregation).

        Results with empty params are skipped with a warning and are not counted as received.

        Args:
            result (FLModel): the result returned by one client.
        """
        client_name = result.meta.get("client_name", AppConstants.CLIENT_UNKNOWN)

        if not result.params:
            self.warning(f"Empty result from client {client_name}, skipping.")
            return

        # Store only params_type from first result (not the full model)
        if self._params_type is None:
            self._params_type = result.params_type

        if self.aggregator:
            # Use custom aggregator
            self.aggregator.accept_model(result)
        else:
            # Built-in InTime aggregation: add() materializes lazy refs on-demand.
            # Cleanup relies on lazy ref object lifetime / GC.

            # Weight = per-client aggregation weight (default 1.0) * NUM_STEPS of this round.
            aggregation_weight = (self.aggregation_weights or {}).get(client_name, 1.0)
            n_iter = result.meta.get(FLMetaKey.NUM_STEPS_CURRENT_ROUND, None)
            # Handle None case (e.g., first round of some algorithms like K-Means)
            if n_iter is None:
                n_iter = 1.0
            weight = aggregation_weight * float(n_iter)

            self._aggr_helper.add(
                data=result.params,
                weight=weight,
                contributor_name=client_name,
                contribution_round=self.current_round,
            )

            # Add to metrics aggregation if available (only aggregatable values;
            # non-aggregatable metrics like dicts are still in result.metrics for collection).
            # If a client omits metrics entirely (None), disable round-level metrics
            # aggregation instead of mixing present/absent metric coverage.
            if result.metrics is None:
                self._all_metrics = False
            if self._all_metrics and result.metrics:
                # Non-empty metric dicts are treated as "present"; unsupported values are
                # filtered per key while allowing other aggregatable keys to contribute.
                aggregatable = filter_aggregatable_metrics(
                    result.metrics,
                    warn_skipped=lambda k, tn: self.warning(f"Metric '{k}' ({tn}) skipped for aggregation."),
                    warned_metric_keys=self._warned_metric_keys,
                )
                if aggregatable:
                    self._aggr_metrics_helper.add(
                        data=aggregatable,
                        weight=weight,
                        contributor_name=client_name,
                        contribution_round=self.current_round,
                    )

        self._received_count += 1
        self.info(f"Aggregated {self._received_count}/{self._expected_count} results")

    def _get_aggregated_result(self) -> FLModel:
        """Get the final aggregated result after all clients have responded.

        Returns:
            FLModel: the aggregated model; its meta carries "nr_aggregated" and "current_round".
        """
        if self.aggregator:
            # Use custom aggregator
            result: FLModel = self.aggregator.aggregate_model()
            result.meta = result.meta or {}
            result.meta["nr_aggregated"] = self._received_count
            result.meta["current_round"] = self.current_round
            return result

        # Use built-in InTime aggregation
        aggr_params = self._aggr_helper.get_result()
        aggr_metrics = self._aggr_metrics_helper.get_result() if self._all_metrics else None
        # Normalize an empty metrics dict to None.
        aggr_metrics = aggr_metrics or None

        return FLModel(
            params=aggr_params,
            params_type=self._params_type,
            metrics=aggr_metrics,
            meta={"nr_aggregated": self._received_count, "current_round": self.current_round},
        )

    def should_stop(self, metrics: Optional[Dict] = None) -> bool:
        """Checks whether the current FL experiment should stop.

        Args:
            metrics (Dict, optional): experiment metrics.

        Returns:
            True if the experiment should stop, False otherwise.
        """
        if self.stop_condition is None or metrics is None:
            return False

        # Check patience
        if self.patience and (self.patience <= self.num_fl_rounds_without_improvement):
            self.info(f"Exceeded the number of FL rounds ({self.patience}) without improvements")
            return True

        # Check stop condition
        key, target, op_fn = self.stop_condition
        value = metrics.get(key, None)

        if value is None:
            self.warning(f"Stop criteria key '{key}' doesn't exist in metrics: {list(metrics.keys())}")
            return False

        if op_fn(value, target):
            self.info(f"Early stop condition satisfied: {self.stop_cond}")
            return True

        return False

    def is_curr_model_better(self, curr_model: FLModel) -> bool:
        """Checks if the new model is better than the current best model.

        Also updates the best metric value and the no-improvement round counter.

        Args:
            curr_model (FLModel): the new model to evaluate.

        Returns:
            True if the new model is better than the current best model, False otherwise
        """
        if self.stop_condition is None:
            return True

        curr_metrics = curr_model.metrics
        if curr_metrics is None:
            return False

        target_metric, _, op_fn = self.stop_condition
        curr_target_metric = curr_metrics.get(target_metric, None)
        if curr_target_metric is None:
            return False

        if self.best_target_metric_value is None or op_fn(curr_target_metric, self.best_target_metric_value):
            # With patience enabled, an equal value (which passes inclusive ops such as >=)
            # is not counted as an improvement.
            if self.patience and self.best_target_metric_value == curr_target_metric:
                self.num_fl_rounds_without_improvement += 1
                return False
            self.best_target_metric_value = curr_target_metric
            self.num_fl_rounds_without_improvement = 0
            return True

        self.num_fl_rounds_without_improvement += 1
        return False

    def load_model(self) -> FLModel:
        """Load model.

        Uses persistor if available, otherwise uses load_model_file.
        Override `load_model_file` for framework-specific deserialization (e.g., torch.load).

        Returns:
            FLModel: loaded model. With a persistor this delegates to the parent implementation,
            which may return None if loading fails. Without a persistor, an empty FLModel is
            returned when no saved file exists or no save_filename is configured.
        """
        if self.persistor:
            # Use persistor (parent class behavior)
            return super().load_model()

        if not self.save_filename:
            self.warning("No persistor or save_filename configured")
            return FLModel(params={})

        # Try to load from file
        filepath = os.path.join(self.get_run_dir(), self.save_filename)
        if os.path.exists(filepath):
            self.info(f"Loading model from {filepath}")
            return self.load_model_file(filepath)

        self.info(f"No saved model found at {filepath}, starting fresh")
        return FLModel(params={})

    def save_model(self, model: FLModel) -> None:
        """Save model.

        Uses persistor if available, otherwise uses save_model_file.
        Override `save_model_file` for framework-specific serialization (e.g., torch.save).

        Args:
            model (FLModel): model to save
        """
        if self.persistor:
            # Use persistor (parent class behavior)
            super().save_model(model)
        elif self.save_filename:
            # Use simple file-based saving under the run directory
            filepath = os.path.join(self.get_run_dir(), self.save_filename)
            self.save_model_file(model, filepath)
            self.info(f"Model saved to {filepath}")
        else:
            self.warning("No persistor or save_filename configured, model not saved")

    def save_model_file(self, model: FLModel, filepath: str) -> None:
        """Save model to file.

        Override this for framework-specific serialization.
        Default implementation uses FOBS (pickle-compatible).

        For PyTorch, override with:
            torch.save(model.params, filepath)

        Args:
            model (FLModel): model to save
            filepath (str): path to save the model
        """
        # Default: use FOBS to save entire FLModel
        fobs.dumpf(model, filepath)

    def load_model_file(self, filepath: str) -> FLModel:
        """Load model from file.

        Override this for framework-specific deserialization.
        Default implementation uses FOBS (pickle-compatible).

        For PyTorch, override with:
            FLModel(params=torch.load(filepath))

        Args:
            filepath (str): path to load the model from

        Returns:
            FLModel: loaded model
        """
        # Default: use FOBS to load entire FLModel
        return fobs.loadf(filepath)