# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Union
from pydantic import BaseModel
from nvflare.app_common.workflows.scaffold import Scaffold
from nvflare.app_opt.pt.job_config.base_fed_job import BaseFedJob
from nvflare.app_opt.pt.job_config.model import PTModel
from nvflare.client.config import ExchangeFormat, TransferType
from nvflare.fuel.utils.constants import FrameworkType
from nvflare.job_config.script_runner import ScriptRunner
from nvflare.recipe.spec import Recipe
# Internal — not part of the public API
class _ScaffoldValidator(BaseModel):
    """Pydantic schema used to type-check the arguments of ``ScaffoldRecipe``.

    Each field mirrors one keyword argument of ``ScaffoldRecipe.__init__``.
    Fields declared without a default (``name``, ``model``, ``min_clients``,
    ``num_rounds``, ``train_script``, ``train_args``) must always be supplied
    when the validator is constructed.
    """

    # ``model`` is annotated as ``Any`` and may hold non-pydantic objects,
    # so arbitrary types are explicitly allowed.
    model_config = {"arbitrary_types_allowed": True}

    # --- Job identity and initial model -------------------------------
    name: str
    model: Any
    initial_ckpt: Union[str, None] = None

    # --- Federated training schedule ----------------------------------
    min_clients: int
    num_rounds: int

    # --- Client-side training script ----------------------------------
    train_script: str
    train_args: str
    launch_external_process: bool = False
    command: str = "python3 -u"

    # --- Parameter exchange between server and clients ----------------
    server_expected_format: ExchangeFormat = ExchangeFormat.NUMPY
    params_transfer_type: TransferType = TransferType.FULL

    # --- Memory management knobs (0 / False as the defaults) ----------
    server_memory_gc_rounds: int = 0
    client_memory_gc_rounds: int = 0
    cuda_empty_cache: bool = False
[docs]
class ScaffoldRecipe(Recipe):
    """A recipe for implementing Scaffold in NVFlare.

    Implements the training algorithm proposed in
    Karimireddy et al. "SCAFFOLD: Stochastic Controlled Averaging for Federated Learning"
    (https://arxiv.org/abs/1910.06378).

    **Client script requirement**: Unlike FedAvgRecipe, the client script *must* use
    `PTScaffoldHelper` (nvflare.app_opt.pt.scaffold): call init(model), model_update()
    during training, terms_update() after training, and include
    ``meta[AlgorithmConstants.SCAFFOLD_CTRL_DIFF] = scaffold_helper.get_delta_controls()``
    in the FLModel sent back to the server. A standard flare.receive/send loop without
    PTScaffoldHelper will cause server-side aggregation to fail.

    This recipe sets up a complete federated learning workflow with Scaffold controller:
    a ``BaseFedJob``, an optional ``PTModel`` persistor on the server, the ``Scaffold``
    controller, and a ``ScriptRunner`` executor sent to the clients.

    All arguments are keyword-only.

    Args:
        name: Name of the federated learning job. Defaults to "scaffold".
        model: Initial model to start federated training with. Can be:
            - nn.Module instance
            - Dict config: {"class_path": "module.ClassName", "args": {"param": value}}
            - None: no initial model
            Defaults to None.
        initial_ckpt: Absolute path to a pre-trained checkpoint file. The file may not
            exist locally as it could be on the server. Used to load initial weights.
            Note: PyTorch requires model when using initial_ckpt (for architecture).
            Defaults to None.
        min_clients: Minimum number of clients required to start a training round.
            Required (no default). The same value is passed to the Scaffold controller
            as ``num_clients``.
        num_rounds: Number of federated training rounds to execute. Defaults to 2.
        train_script: Path to the training script that will be executed on each client.
            Required (no default).
        train_args: Command line arguments to pass to the training script. Defaults to "".
        launch_external_process: Forwarded to ``ScriptRunner`` as
            ``launch_external_process``. Defaults to False.
        command: Forwarded to ``ScriptRunner`` as ``command``. Defaults to "python3 -u".
        server_expected_format: Forwarded to ``ScriptRunner`` as
            ``server_expected_format``. Defaults to ``ExchangeFormat.NUMPY``.
        params_transfer_type: Forwarded to ``ScriptRunner`` as ``params_transfer_type``.
            Defaults to ``TransferType.FULL``.
        server_memory_gc_rounds: Run memory cleanup (gc.collect + malloc_trim) every N rounds on server.
            Set to 0 to disable. Defaults to 0.
        client_memory_gc_rounds: Forwarded to ``ScriptRunner`` as ``memory_gc_rounds``.
            Defaults to 0.
        cuda_empty_cache: Forwarded to ``ScriptRunner`` as ``cuda_empty_cache``.
            Defaults to False.

    Example:
        ```python
        recipe = ScaffoldRecipe(
            name="my_scaffold_job",
            model=pretrained_model,
            min_clients=2,
            num_rounds=10,
            train_script="client.py",
            train_args="--epochs 5 --batch_size 32"
        )
        ```
    """

    def __init__(
        self,
        *,
        name: str = "scaffold",
        model: Union[Any, dict[str, Any], None] = None,
        initial_ckpt: Optional[str] = None,
        min_clients: int,
        num_rounds: int = 2,
        train_script: str,
        train_args: str = "",
        launch_external_process: bool = False,
        command: str = "python3 -u",
        server_expected_format: ExchangeFormat = ExchangeFormat.NUMPY,
        params_transfer_type: TransferType = TransferType.FULL,
        server_memory_gc_rounds: int = 0,
        client_memory_gc_rounds: int = 0,
        cuda_empty_cache: bool = False,
    ):
        """Validate the arguments, assemble the Scaffold job, and initialise the base Recipe."""
        # Function-scoped import of the shared recipe helpers (presumably this avoids a
        # circular import with nvflare.recipe -- confirm before hoisting to module level).
        from nvflare.recipe.utils import prepare_initial_ckpt, recipe_model_to_job_model, validate_ckpt

        # Type-check every argument in one place through the pydantic schema.
        validated = _ScaffoldValidator(
            name=name,
            model=model,
            initial_ckpt=initial_ckpt,
            min_clients=min_clients,
            num_rounds=num_rounds,
            train_script=train_script,
            train_args=train_args,
            launch_external_process=launch_external_process,
            command=command,
            server_expected_format=server_expected_format,
            params_transfer_type=params_transfer_type,
            server_memory_gc_rounds=server_memory_gc_rounds,
            client_memory_gc_rounds=client_memory_gc_rounds,
            cuda_empty_cache=cuda_empty_cache,
        )

        self.name = validated.name
        self.model = validated.model
        self.initial_ckpt = validated.initial_ckpt

        # Validate the checkpoint path using the shared utilities.
        validate_ckpt(self.initial_ckpt)

        # A dict model config is converted to the job-level model representation.
        if isinstance(self.model, dict):
            self.model = recipe_model_to_job_model(self.model)

        self.min_clients = validated.min_clients
        self.num_rounds = validated.num_rounds
        self.train_script = validated.train_script
        self.train_args = validated.train_args
        self.launch_external_process = validated.launch_external_process
        self.command = validated.command
        self.server_expected_format: ExchangeFormat = validated.server_expected_format
        self.params_transfer_type: TransferType = validated.params_transfer_type
        self.server_memory_gc_rounds = validated.server_memory_gc_rounds
        self.client_memory_gc_rounds = validated.client_memory_gc_rounds
        self.cuda_empty_cache = validated.cuda_empty_cache

        # Create the job without an initial model; the model persistor is set up below.
        job = BaseFedJob(
            initial_model=None,
            name=self.name,
            min_clients=self.min_clients,
        )

        # Set up the model persistor using PTModel. When neither a model nor a checkpoint
        # is given, no persistor is registered and the controller gets an empty id.
        persistor_id = ""
        if self.model is not None or self.initial_ckpt is not None:
            ckpt_path = prepare_initial_ckpt(self.initial_ckpt, job)
            pt_model = PTModel(model=self.model, initial_ckpt=ckpt_path)
            result = job.to_server(pt_model, id="persistor")
            persistor_id = result["persistor_id"]

        # Define the controller and send it to the server.
        # The Scaffold controller requires the number of clients to be the same as min_clients.
        controller = Scaffold(
            num_clients=self.min_clients,
            num_rounds=self.num_rounds,
            persistor_id=persistor_id,
            memory_gc_rounds=self.server_memory_gc_rounds,
        )
        job.to_server(controller)

        # Add the client-side executor that runs the training script.
        executor = ScriptRunner(
            script=self.train_script,
            script_args=self.train_args,
            launch_external_process=self.launch_external_process,
            command=self.command,
            framework=FrameworkType.PYTORCH,
            server_expected_format=self.server_expected_format,
            params_transfer_type=self.params_transfer_type,
            memory_gc_rounds=self.client_memory_gc_rounds,
            cuda_empty_cache=self.cuda_empty_cache,
        )
        job.to_clients(executor)

        super().__init__(job)