# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Union
from pydantic import BaseModel, field_validator
from nvflare.app_common.workflows.model_controller import ModelController
from nvflare.client.config import ExchangeFormat
from nvflare.job_config.base_fed_job import BaseFedJob
from nvflare.job_config.script_runner import FrameworkType, ScriptRunner
from nvflare.recipe.spec import Recipe
from nvflare.recipe.utils import validate_ckpt
# Internal validator
class _FedEvalValidator(BaseModel):
    """Internal pydantic model that validates ``FedEvalRecipe`` inputs.

    Constructing an instance runs the field validator below; the recipe
    instantiates this class purely for that side effect and discards the result.
    """

    model_config = {"arbitrary_types_allowed": True}

    # Checkpoint whose weights are to be evaluated. No default, so it is mandatory.
    eval_ckpt: str

    @field_validator("eval_ckpt")
    @classmethod
    def validate_eval_ckpt(cls, v: str) -> str:
        """Check ``eval_ckpt`` with ``validate_ckpt`` and return it unchanged.

        Any error raised by ``validate_ckpt`` propagates through pydantic's
        validation machinery.
        """
        # The actual rules live in nvflare.recipe.utils.validate_ckpt; this hook
        # only delegates to it.
        validate_ckpt(v)
        return v
# Server-side evaluation workflow
class EvalController(ModelController):
    """Workflow controller that performs one federated evaluation pass.

    The controller loads a model, sends it out under the ``"validate"`` task,
    waits for the responses and logs the metrics carried by each one.

    Args:
        persistor_id: ID of the persistor component, forwarded to ``ModelController``.
        timeout: Timeout passed to ``send_model_and_wait`` for the ``"validate"`` task.
    """

    def __init__(self, persistor_id: str, timeout: int):
        super().__init__(persistor_id=persistor_id)
        self.timeout = timeout

    def run(self):
        """Load the model, send it for validation and log every returned metric set."""
        global_model = self.load_model()

        self.info("Sending model for evaluation")
        # targets=None is passed through as-is; presumably this means
        # "all connected clients" -- confirm against ModelController.
        responses = self.send_model_and_wait(
            targets=None,
            data=global_model,
            task_name="validate",
            timeout=self.timeout,
        )

        self.info(f"Got {len(responses)} results")
        for response in responses:
            self.info(f"Metrics: {response.metrics}")
# Public recipe
class FedEvalRecipe(Recipe):
    """A recipe for federated evaluation of a PyTorch model across multiple sites.

    This recipe sets up a federated evaluation workflow where a global model
    from the server is sent to multiple clients for evaluation. Each client evaluates
    the model on their local data and reports metrics back to the server.

    The recipe configures:
        - A federated job with an initial model to evaluate
        - EvalController for coordinating federated evaluation across clients
        - Script runners for client-side evaluation execution

    Args:
        name: Name of the federated evaluation job. Defaults to "eval".
        model: Model structure to evaluate. Can be:
            - An instantiated nn.Module (e.g., Net())
            - A dict config: {"class_path": "module.ClassName", "args": {...}}
        eval_ckpt: Absolute path to pre-trained checkpoint file (.pt, .pth, etc.).
            Required for evaluation - specifies which weights to evaluate.
            The file may not exist locally (server-side path).
        min_clients: Minimum number of clients required to start evaluation.
        eval_script: Path to the evaluation script that will be executed on each client.
        eval_args: Command line arguments to pass to the evaluation script. Defaults to "".
        launch_external_process: Whether to launch the script in external process. Defaults to False.
        command: If launch_external_process=True, command to run script (prepended to script).
            Defaults to "python3 -u".
        server_expected_format: What format to exchange the parameters between server and client.
            Defaults to ExchangeFormat.NUMPY.
        validation_timeout: Timeout for evaluation task in seconds. Defaults to 6000.
        per_site_config: Per-site configuration for the evaluation job. Dictionary mapping
            site names to configuration dicts. Each config dict can contain optional overrides:
            eval_script, eval_args, launch_external_process, command, server_expected_format.
            A site mapped to None (or to an empty dict) uses the recipe-level settings unchanged.
            If not provided, the same configuration will be used for all clients. Defaults to None.
        client_memory_gc_rounds: Passed to every client ScriptRunner as ``memory_gc_rounds``.
            Presumably the number of rounds between client-side garbage collection runs --
            confirm against ScriptRunner. Defaults to 0.
        cuda_empty_cache: Passed to every client ScriptRunner as ``cuda_empty_cache``.
            Presumably controls whether the CUDA cache is emptied on the client --
            confirm against ScriptRunner. Defaults to False.

    Raises:
        ValueError: If ``model`` is neither an nn.Module nor a dict config, or if no
            persistor id is available after the model has been added to the server.

    Example:
        Basic usage with model instance:

        ```python
        from nvflare.app_opt.pt.recipes.fedeval import FedEvalRecipe
        from model import Net

        recipe = FedEvalRecipe(
            name="eval_job",
            model=Net(),
            eval_ckpt="/path/to/pretrained_model.pt",
            min_clients=2,
            eval_script="client.py",
            eval_args="--batch_size 32",
        )
        ```

        Using dict config:

        ```python
        recipe = FedEvalRecipe(
            name="eval_job",
            model={"class_path": "my_module.Net", "args": {"num_classes": 10}},
            eval_ckpt="/path/to/pretrained_model.pt",
            min_clients=2,
            eval_script="client.py",
        )
        ```
    """

    def __init__(
        self,
        *,
        name: str = "eval",
        model: Union[Any, Dict[str, Any]],
        eval_ckpt: str,
        min_clients: int,
        eval_script: str,
        eval_args: str = "",
        launch_external_process: bool = False,
        command: str = "python3 -u",
        server_expected_format: ExchangeFormat = ExchangeFormat.NUMPY,
        validation_timeout: int = 6000,
        per_site_config: Optional[Dict[str, Dict]] = None,
        client_memory_gc_rounds: int = 0,
        cuda_empty_cache: bool = False,
    ):
        # Validate eval_ckpt up front; the validator instance itself is not needed.
        _FedEvalValidator(eval_ckpt=eval_ckpt)

        self.name = name
        self.model = model
        self.eval_ckpt = eval_ckpt
        self.min_clients = min_clients
        self.eval_script = eval_script
        self.eval_args = eval_args
        self.launch_external_process = launch_external_process
        self.command = command
        self.server_expected_format = server_expected_format
        self.validation_timeout = validation_timeout
        self.per_site_config = per_site_config
        self.client_memory_gc_rounds = client_memory_gc_rounds
        self.cuda_empty_cache = cuda_empty_cache

        # Create BaseFedJob
        job = BaseFedJob(
            name=self.name,
            min_clients=self.min_clients,
        )

        # Setup model and persistor using PTModel (handles both nn.Module and dict config).
        # torch and the PyTorch job helpers are imported here rather than at module level,
        # so they are only loaded when a recipe is actually constructed.
        import torch.nn as nn

        from nvflare.app_opt.pt.job_config.model import PTModel

        # Validate model type and normalize dict (class_path -> path for job API)
        if not isinstance(self.model, (nn.Module, dict)):
            raise ValueError(f"model must be nn.Module or dict config, got {type(self.model)}")
        if isinstance(self.model, dict):
            from nvflare.recipe.utils import recipe_model_to_job_model

            self.model = recipe_model_to_job_model(self.model)

        # PTModel handles both nn.Module and dict config uniformly
        from nvflare.recipe.utils import prepare_initial_ckpt

        ckpt_path = prepare_initial_ckpt(self.eval_ckpt, job)
        pt_model = PTModel(model=self.model, initial_ckpt=ckpt_path)
        result = job.to_server(pt_model)
        job.comp_ids.update(result)

        persistor_id = job.comp_ids.get("persistor_id")
        if not persistor_id:
            raise ValueError("Failed to obtain persistor_id from PTModel configuration")

        # Simple controller
        controller = EvalController(persistor_id=persistor_id, timeout=self.validation_timeout)
        job.to_server(controller)

        # Add client executors
        if self.per_site_config is not None:
            for site_name, site_config in self.per_site_config.items():
                # A site mapped to None means "no overrides" instead of crashing on .get().
                job.to(self._make_script_runner(site_config or {}), site_name)
        else:
            job.to_clients(self._make_script_runner())

        Recipe.__init__(self, job)

    def _make_script_runner(self, overrides: Optional[Dict[str, Any]] = None) -> ScriptRunner:
        """Build a PyTorch ScriptRunner from the recipe defaults plus optional site overrides.

        Args:
            overrides: Per-site settings. Recognized keys are eval_script, eval_args,
                launch_external_process, command and server_expected_format; any key
                that is absent falls back to the recipe-level value. None or an empty
                dict yields a runner built purely from the recipe-level values.

        Returns:
            A ScriptRunner configured for client-side evaluation.
        """
        if overrides is None:
            overrides = {}
        return ScriptRunner(
            script=overrides.get("eval_script", self.eval_script),
            script_args=overrides.get("eval_args", self.eval_args),
            launch_external_process=overrides.get("launch_external_process", self.launch_external_process),
            command=overrides.get("command", self.command),
            framework=FrameworkType.PYTORCH,
            server_expected_format=overrides.get("server_expected_format", self.server_expected_format),
            # The memory-management settings are recipe-wide; they cannot be overridden per site.
            memory_gc_rounds=self.client_memory_gc_rounds,
            cuda_empty_cache=self.cuda_empty_cache,
        )