# Source code for nvflare.app_opt.pt.fedavg

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union

import torch

from nvflare.app_common.abstract.fl_model import FLModel
from nvflare.app_common.workflows.fedavg import FedAvg
from nvflare.app_opt.pt.decomposers import TensorDecomposer
from nvflare.fuel.utils import fobs


class PTFedAvg(FedAvg):
    """PyTorch FedAvg Controller with Early Stopping and Model Selection.

    This is a PyTorch-specific wrapper around the generic FedAvg controller.
    It adds PyTorch-specific model serialization using torch.save/torch.load.

    The FedAvg controller includes:
        - InTime (streaming) aggregation for memory efficiency
        - Early stopping support
        - Best model selection and saving
        - Custom aggregator support

    Args:
        num_clients (int, optional): The number of clients. Defaults to 3.
            Not a named parameter here; forwarded to ``FedAvg`` via
            ``*args``/``**kwargs``.
        num_rounds (int, optional): The total number of training rounds.
            Defaults to 5. Not a named parameter here; forwarded to ``FedAvg``
            via ``*args``/``**kwargs``.
        stop_cond (str, optional): Early stopping condition based on metric.
            String literal in the format of '<key> <op> <value>'
            (e.g. "accuracy >= 80"). If None, early stopping is disabled.
        patience (int, optional): The number of rounds with no improvement
            after which FL will be stopped. Only applies if stop_cond is set.
            Defaults to None.
        task_name (str, optional): Task name for training. Defaults to "train".
        save_filename (str, optional): Filename for saving the best model.
            Defaults to "FL_global_model.pt".
        model (nn.Module | dict | FLModel, optional): Initial model. An
            ``nn.Module`` is converted with ``.state_dict()``; a dict of
            parameters or an ``FLModel`` is passed to ``FedAvg`` unchanged.
            Defaults to None (no initial model).

    Raises:
        TypeError: If ``model`` is not an ``nn.Module``, dict, ``FLModel``,
            or None.

    Example:
        ```python
        from model import Net
        from nvflare import FedJob
        from nvflare.app_opt.pt.fedavg import PTFedAvg

        job = FedJob(name="pt_fedavg")
        controller = PTFedAvg(
            num_clients=2,
            num_rounds=10,
            stop_cond="accuracy >= 80",
            patience=3,
            model=Net(),
        )
        job.to(controller, "server")
        ```
    """

    def __init__(
        self,
        *args,
        stop_cond: Optional[str] = None,
        patience: Optional[int] = None,
        task_name: Optional[str] = "train",
        save_filename: Optional[str] = "FL_global_model.pt",
        model: Optional[Union[torch.nn.Module, dict, FLModel]] = None,
        **kwargs,
    ) -> None:
        # Normalize the initial model into what FedAvg receives as `model`:
        # an nn.Module is reduced to its state_dict. None, plain dicts and
        # FLModel instances are forwarded unchanged.
        if model is None:
            initial_model_params = None
        elif isinstance(model, torch.nn.Module):
            initial_model_params = model.state_dict()
        elif isinstance(model, (dict, FLModel)):
            initial_model_params = model
        else:
            raise TypeError(
                f"model must be torch.nn.Module, dict, FLModel, or None, "
                f"but got {type(model).__name__}"
            )

        super().__init__(
            *args,
            model=initial_model_params,
            save_filename=save_filename,
            stop_cond=stop_cond,
            patience=patience,
            task_name=task_name,
            **kwargs,
        )

    def run(self) -> None:
        """Run FedAvg workflow with PyTorch tensor serialization support.

        Registers ``TensorDecomposer`` with FOBS before delegating to
        ``FedAvg.run`` so PyTorch tensors can be serialized by FOBS.
        """
        fobs.register(TensorDecomposer)
        super().run()

    def save_model_file(self, model: FLModel, filepath: str) -> None:
        """Save model using PyTorch's torch.save.

        Two files are written:
            - ``filepath``: ``model.params`` via ``torch.save``.
            - ``f"{filepath}.metadata"``: the rest of the ``FLModel``
              (metrics, params_type, etc.) via ``fobs.dumpf``, with ``params``
              replaced by an empty dict so the weights are not stored twice.

        Note:
            ``model.params`` is temporarily set to ``{}`` on the caller's
            object while the metadata is written, and restored in a
            ``finally`` block, including when ``fobs.dumpf`` raises.

        Args:
            model (FLModel): model to save
            filepath (str): path to save the model
        """
        torch.save(model.params, filepath)

        params = model.params
        try:
            # Blank out params so only the metadata goes into the sidecar file.
            model.params = {}
            fobs.dumpf(model, f"{filepath}.metadata")
        finally:
            model.params = params

    def load_model_file(self, filepath: str) -> FLModel:
        """Load model using PyTorch's torch.load.

        Loads parameters from ``filepath`` via ``torch.load`` (with
        ``weights_only=True``). If the sidecar ``f"{filepath}.metadata"``
        written by ``save_model_file`` exists, it is loaded via FOBS and the
        parameters are attached to it. Otherwise a bare ``FLModel`` that
        carries only the parameters is returned.

        Args:
            filepath (str): path to load the model from

        Returns:
            FLModel: loaded model with params and, when available, metadata
        """
        import os

        # weights_only=True restricts torch.load's unpickler to tensors and
        # basic containers instead of arbitrary pickled objects.
        params = torch.load(filepath, weights_only=True)

        metadata_path = f"{filepath}.metadata"
        if os.path.exists(metadata_path):
            model: FLModel = fobs.loadf(metadata_path)
            model.params = params
        else:
            # No metadata sidecar: return the parameters wrapped in an FLModel.
            model = FLModel(params=params)
        return model
# Backward-compatible alias: code that still imports PTFedAvgEarlyStopping
# gets the very same class object as PTFedAvg.
PTFedAvgEarlyStopping = PTFedAvg