# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Union
from pydantic import BaseModel
from nvflare.apis.dxo import DataKind
from nvflare.app_common.abstract.aggregator import Aggregator
from nvflare.app_common.aggregators import InTimeAccumulateWeightedAggregator
from nvflare.app_common.workflows.scatter_and_gather import ScatterAndGather
from nvflare.app_opt.pt import PTFileModelPersistor
from nvflare.app_opt.pt.fedopt import PTFedOptModelShareableGenerator
from nvflare.app_opt.pt.file_model_locator import PTFileModelLocator
from nvflare.app_opt.pt.job_config.base_fed_job import BaseFedJob
from nvflare.client.config import ExchangeFormat, TransferType
from nvflare.fuel.utils.constants import FrameworkType
from nvflare.job_config.script_runner import ScriptRunner
from nvflare.recipe.spec import Recipe
# Internal — not part of the public API
class _FedOptValidator(BaseModel):
    # Pydantic schema that FedOptRecipe uses to type-check its constructor
    # arguments before any job component is built.

    # Permit non-pydantic classes (e.g. Aggregator) as field types.
    model_config = dict(arbitrary_types_allowed=True)

    # --- job identity and model source ---
    name: str
    model: Any
    initial_ckpt: Optional[str] = None

    # --- federated training schedule ---
    min_clients: int
    num_rounds: int

    # --- client training script ---
    train_script: str
    train_args: str

    # No default: the caller must pass a value explicitly, but None is accepted.
    aggregator: Optional[Aggregator]

    # --- client execution / parameter exchange ---
    launch_external_process: bool = False
    command: str = "python3 -u"
    server_expected_format: ExchangeFormat = ExchangeFormat.NUMPY
    device: Optional[str] = None

    # --- memory management knobs ---
    server_memory_gc_rounds: int = 1
    client_memory_gc_rounds: int = 0
    cuda_empty_cache: bool = False
class FedOptRecipe(Recipe):
    """A recipe for implementing Federated Optimization (FedOpt) in NVFlare.

    FedOpt is a federated learning algorithm that optimizes the global model using a
    server-side optimizer and learning rate scheduler. After each round, the global
    model is updated using the specified optimizer and learning rate scheduler.

    The algorithm is proposed in Reddi et al. "Adaptive Federated Optimization."
    arXiv preprint arXiv:2003.00295 (2020).

    Note: FedOpt is only implemented for params_transfer_type == TransferType.DIFF and
    DataKind.WEIGHT_DIFF in the aggregator.

    Args:
        name: Name of the federated learning job. Defaults to "fedopt".
        model: Initial model to start federated training with (REQUIRED). Can be:
            - nn.Module instance
            - Dict config: {"class_path": "module.ClassName", "args": {"param": value}}
            Note: FedOpt requires a model for the server-side optimizer to work. The
            parameter defaults to None only for signature compatibility; passing None
            raises ValueError.
        initial_ckpt: Absolute path to a pre-trained checkpoint file. The file may not
            exist locally as it could be on the server. Used to load initial weights.
            Note: PyTorch requires model when using initial_ckpt (for architecture).
        min_clients: Minimum number of clients required to start a training round.
        num_rounds: Number of federated training rounds to execute. Defaults to 2.
        train_script: Path to the training script that will be executed on each client.
        train_args: Command line arguments to pass to the training script. Defaults to "".
        aggregator: Aggregator for combining client updates. If None,
            uses InTimeAccumulateWeightedAggregator with expected_data_kind=DataKind.WEIGHT_DIFF.
        launch_external_process (bool): Whether to launch the script in external process.
            Defaults to False.
        command (str): If launch_external_process=True, command to run script (prepended
            to script). Defaults to "python3 -u".
        server_expected_format (ExchangeFormat): What format to exchange the parameters
            between server and client. Defaults to ExchangeFormat.NUMPY.
        device (str): Device to use for server-side optimization, e.g. "cpu" or "cuda:0".
            Defaults to None; will default to cuda if available and no device is specified.
        source_model (str): ID of the source model component. Defaults to "model".
        optimizer_args (dict): Configuration for server-side optimizer with keys:
            - path: Fully qualified optimizer class (e.g., "torch.optim.SGD").
              "class_path" is also accepted.
            - args: Dictionary of optimizer arguments (e.g., {"lr": 1.0, "momentum": 0.6})
            - config_type: Optional; if omitted, set to "dict" so the config is not
              instantiated at load time.
        lr_scheduler_args (dict): Optional configuration for learning rate scheduler with keys:
            - path: Fully qualified scheduler class
              (e.g., "torch.optim.lr_scheduler.CosineAnnealingLR"). "class_path" is also accepted.
            - args: Dictionary of scheduler arguments (e.g., {"T_max": 100, "eta_min": 0.9}).
              A "T_max" value equal to the literal string "{num_rounds}" is replaced
              with num_rounds.
            - config_type: Optional; if omitted, set to "dict" so the config is not
              instantiated at load time.
        server_memory_gc_rounds: Run memory cleanup (gc.collect + malloc_trim) every N rounds
            on server. Set to 0 to disable. Defaults to 1 (every round).
        client_memory_gc_rounds: Forwarded to the client-side ScriptRunner as its
            memory_gc_rounds option. Defaults to 0.
        cuda_empty_cache: Forwarded to the client-side ScriptRunner as its
            cuda_empty_cache option. Defaults to False.

    Raises:
        ValueError: If model is None, or if aggregator is neither None nor an
            Aggregator instance.
        RuntimeError: If model is a dict config and the model class cannot be instantiated.

    Argument types are additionally checked by the internal pydantic validator, which
    raises its own validation error on a type mismatch.

    Example:
        ```python
        recipe = FedOptRecipe(
            name="my_fedopt_job",
            model=pretrained_model,
            min_clients=2,
            num_rounds=10,
            train_script="client.py",
            train_args="--epochs 5 --batch_size 32",
            device="cpu",
            source_model="model",
            optimizer_args={
                "path": "torch.optim.SGD",
                "args": {"lr": 1.0, "momentum": 0.6},
                "config_type": "dict"
            },
            lr_scheduler_args={
                "path": "torch.optim.lr_scheduler.CosineAnnealingLR",
                "args": {"T_max": "{num_rounds}", "eta_min": 0.9},
                "config_type": "dict"
            }
        )
        ```
    """

    def __init__(
        self,
        *,
        name: str = "fedopt",
        model: Union[Any, dict[str, Any], None] = None,
        initial_ckpt: Optional[str] = None,
        min_clients: int,
        num_rounds: int = 2,
        train_script: str,
        train_args: str = "",
        aggregator: Optional[Aggregator] = None,
        launch_external_process: bool = False,
        command: str = "python3 -u",
        server_expected_format: ExchangeFormat = ExchangeFormat.NUMPY,
        device: Optional[str] = None,
        source_model: str = "model",
        optimizer_args: Optional[dict] = None,
        lr_scheduler_args: Optional[dict] = None,
        server_memory_gc_rounds: int = 1,
        client_memory_gc_rounds: int = 0,
        cuda_empty_cache: bool = False,
    ):
        # Validate inputs internally
        v = _FedOptValidator(
            name=name,
            model=model,
            initial_ckpt=initial_ckpt,
            min_clients=min_clients,
            num_rounds=num_rounds,
            train_script=train_script,
            train_args=train_args,
            aggregator=aggregator,
            launch_external_process=launch_external_process,
            command=command,
            server_expected_format=server_expected_format,
            device=device,
            server_memory_gc_rounds=server_memory_gc_rounds,
            client_memory_gc_rounds=client_memory_gc_rounds,
            cuda_empty_cache=cuda_empty_cache,
        )

        # Shared recipe utilities are imported at call time, as in the rest of this method.
        from nvflare.recipe.utils import (
            ensure_config_type_dict,
            prepare_initial_ckpt,
            recipe_model_to_job_model,
            validate_ckpt,
        )

        self.name = v.name
        self.model = v.model
        self.initial_ckpt = v.initial_ckpt

        validate_ckpt(self.initial_ckpt)
        if isinstance(self.model, dict):
            self.model = recipe_model_to_job_model(self.model)

        self.min_clients = v.min_clients
        self.num_rounds = v.num_rounds
        self.train_script = v.train_script
        self.train_args = v.train_args
        self.aggregator = v.aggregator
        self.launch_external_process = v.launch_external_process
        self.command = v.command
        self.server_expected_format: ExchangeFormat = v.server_expected_format
        # Use the validated value, consistent with every other validated field.
        self.device = v.device
        self.source_model = source_model

        # Ensure config_type "dict" so the component builder does not try to instantiate
        # optimizer/scheduler at config load time (params/optimizer are set at runtime).
        self.optimizer_args = ensure_config_type_dict(optimizer_args)
        self.lr_scheduler_args = ensure_config_type_dict(lr_scheduler_args)

        self.server_memory_gc_rounds = v.server_memory_gc_rounds
        self.client_memory_gc_rounds = v.client_memory_gc_rounds
        self.cuda_empty_cache = v.cuda_empty_cache

        # Replace {num_rounds} placeholder if present in lr_scheduler_args
        processed_lr_scheduler_args = self._resolve_lr_scheduler_args(self.lr_scheduler_args, self.num_rounds)

        # FedOpt requires a model: PTFedOptModelShareableGenerator needs the source_model
        # component to exist, and a checkpoint alone does not carry the architecture.
        # Fail before any job object is built.
        if self.model is None:
            raise ValueError(
                "FedOpt requires model. Provide either:\n"
                "  - nn.Module instance\n"
                "  - Dict config: {'class_path': 'module.ClassName', 'args': {...}}\n"
                "Note: initial_ckpt alone is not sufficient for PyTorch (model architecture needed)."
            )

        # Create BaseFedJob without an initial model; the model is registered below
        # as its own server component.
        job = BaseFedJob(
            initial_model=None,
            name=self.name,
            min_clients=self.min_clients,
        )

        # Handle dict config: instantiate model before registering as component
        # PTFileModelPersistor expects component ID to resolve to nn.Module, not dict
        model_to_register = self.model
        if isinstance(self.model, dict):
            from nvflare.fuel.utils.class_utils import instantiate_class

            class_path = self.model.get("path")
            class_args = self.model.get("args", {})
            try:
                model_to_register = instantiate_class(class_path, class_args)
            except Exception as e:
                # Chain the original error so the root cause stays visible in the traceback.
                raise RuntimeError(f"Failed to instantiate model from dict config: {e}") from e

        # Add initial model as a separate component
        job.to_server(model_to_register, id=self.source_model)

        # Add the persisted model to the job with checkpoint support
        ckpt_path = prepare_initial_ckpt(self.initial_ckpt, job)
        persistor = PTFileModelPersistor(
            model=self.source_model,
            source_ckpt_file_full_name=ckpt_path,
        )
        persistor_id = job.to_server(persistor, id="persistor")

        locator = PTFileModelLocator(pt_persistor_id=persistor_id)
        job.to_server(locator, id="locator")

        # Pick the aggregator; FedOpt only supports DataKind.WEIGHT_DIFF
        if self.aggregator is None:
            self.aggregator = InTimeAccumulateWeightedAggregator(expected_data_kind=DataKind.WEIGHT_DIFF)
        elif not isinstance(self.aggregator, Aggregator):
            raise ValueError(f"Invalid aggregator type: {type(self.aggregator)}. Expected type: {Aggregator}")

        # Define the shareable generator with fedopt parameters
        shareable_generator = PTFedOptModelShareableGenerator(
            optimizer_args=self.optimizer_args,
            lr_scheduler_args=processed_lr_scheduler_args,
            source_model=self.source_model,
            device=self.device,
        )
        shareable_generator_id = job.to_server(shareable_generator, id="shareable_generator")
        aggregator_id = job.to_server(self.aggregator, id="aggregator")

        # Define the controller and send it to the server
        controller = ScatterAndGather(
            min_clients=self.min_clients,
            num_rounds=self.num_rounds,
            wait_time_after_min_received=0,
            aggregator_id=aggregator_id,
            persistor_id="persistor",
            shareable_generator_id=shareable_generator_id,
            memory_gc_rounds=self.server_memory_gc_rounds,
        )
        job.to_server(controller)

        # Add clients; FedOpt needs weight differences, hence TransferType.DIFF
        executor = ScriptRunner(
            script=self.train_script,
            script_args=self.train_args,
            launch_external_process=self.launch_external_process,
            command=self.command,
            framework=FrameworkType.PYTORCH,
            server_expected_format=self.server_expected_format,
            params_transfer_type=TransferType.DIFF,
            memory_gc_rounds=self.client_memory_gc_rounds,
            cuda_empty_cache=self.cuda_empty_cache,
        )
        job.to_clients(executor)

        Recipe.__init__(self, job)

    @staticmethod
    def _resolve_lr_scheduler_args(lr_scheduler_args: Optional[dict], num_rounds: int) -> Optional[dict]:
        """Return a copy of lr_scheduler_args with the "{num_rounds}" T_max placeholder filled in.

        The input dict and its nested "args" dict are copied, never mutated. None is
        returned unchanged.
        """
        if lr_scheduler_args is None:
            return None

        resolved = lr_scheduler_args.copy()
        if "args" in resolved:
            scheduler_kwargs = resolved["args"].copy()
            if "T_max" in scheduler_kwargs and scheduler_kwargs["T_max"] == "{num_rounds}":
                scheduler_kwargs["T_max"] = num_rounds
            resolved["args"] = scheduler_kwargs
        return resolved