# Source code for nvflare.app_opt.pt.recipes.fedavg_he

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List, Optional, Union

from pydantic import BaseModel

from nvflare.apis.dxo import DataKind
from nvflare.app_common.abstract.aggregator import Aggregator
from nvflare.app_common.workflows.scatter_and_gather import ScatterAndGather
from nvflare.app_opt.he.intime_accumulate_model_aggregator import HEInTimeAccumulateWeightedAggregator
from nvflare.app_opt.he.model_decryptor import HEModelDecryptor
from nvflare.app_opt.he.model_encryptor import HEModelEncryptor
from nvflare.app_opt.he.model_serialize_filter import HEModelSerializeFilter
from nvflare.app_opt.he.model_shareable_generator import HEModelShareableGenerator
from nvflare.app_opt.pt.file_model_persistor import PTFileModelPersistor
from nvflare.app_opt.pt.job_config.base_fed_job import BaseFedJob
from nvflare.client.config import ExchangeFormat, TransferType
from nvflare.fuel.utils.constants import FrameworkType
from nvflare.job_config.defs import FilterType
from nvflare.job_config.script_runner import ScriptRunner
from nvflare.recipe.spec import ExecEnv, Recipe

# Documentation page linked from the SimEnv error message below.
HE_CONTEXT_PROVISIONING_DOC_LINK = "https://nvflare.readthedocs.io/en/2.7/programming_guide/provisioning_system.html"

# Error text for the unsupported SimEnv case; the sentences are joined with single spaces.
HE_SIM_ENV_NOT_SUPPORTED_ERROR = " ".join(
    (
        "FedAvgRecipeWithHE does not support SimEnv.",
        "Use provisioned startup kits with nvflare.lighter.impl.he.HEBuilder and run with ProdEnv or PocEnv.",
        f"See: {HE_CONTEXT_PROVISIONING_DOC_LINK}",
    )
)


# Internal — not part of the public API
class _FedAvgRecipeWithHEValidator(BaseModel):
    """Pydantic schema for the keyword arguments of ``FedAvgRecipeWithHE.__init__``.

    The recipe constructor instantiates this model with its arguments and then
    reads the fields back from the resulting instance.
    """

    # Presumably required so that fields annotated with non-pydantic classes
    # (e.g. ``Aggregator``) are accepted -- confirm against the pydantic docs.
    model_config = {"arbitrary_types_allowed": True}

    # --- Job identity and initial model ---
    name: str
    model: Any
    initial_ckpt: Optional[str] = None

    # --- Training rounds and client script ---
    min_clients: int
    num_rounds: int
    train_script: str
    train_args: str

    # --- Aggregation (no defaults declared here; the recipe always passes both) ---
    aggregator: Optional[Aggregator]
    aggregator_data_kind: Optional[DataKind]

    # --- Client script execution and parameter exchange ---
    launch_external_process: bool = False
    command: str = "python3 -u"
    server_expected_format: ExchangeFormat = ExchangeFormat.NUMPY
    params_transfer_type: TransferType = TransferType.FULL

    # --- HE layer selection: None, a list of names, or a single string ---
    encrypt_layers: Optional[Union[List[str], str]] = None

    # --- Memory management knobs ---
    server_memory_gc_rounds: int = 1
    client_memory_gc_rounds: int = 0
    cuda_empty_cache: bool = False


class FedAvgRecipeWithHE(Recipe):
    """A recipe for implementing Federated Averaging (FedAvg) with Homomorphic Encryption (HE) in NVFlare.

    FedAvg is a fundamental federated learning algorithm that aggregates model updates from
    multiple clients by computing a weighted average based on the amount of local training data.
    This recipe adds homomorphic encryption to preserve privacy during federated learning by
    allowing computations on encrypted data.

    The recipe configures:
        - A federated job with initial model (optional)
        - Scatter-and-gather controller for coordinating training rounds
        - HE-enabled weighted aggregator for combining encrypted client model updates
        - HE shareable generator for converting between Shareable and Learnable objects
        - HE model encryptor/decryptor filters on the client side
        - HE model serialization filter on the server side
        - Script runners for client-side training execution

    Important:
        TenSEAL context files must be generated before running this recipe:
            - `server_context.tenseal` for the server startup folder
            - `client_context.tenseal` for each client startup folder

        Use NVFlare provisioning with `nvflare.lighter.impl.he.HEBuilder` so these context files
        are generated automatically into startup kits. Example project config:
        `examples/advanced/cifar10/pt/cifar10-real-world/workspaces/secure_project.yml`

        SimEnv is not supported for this HE recipe. Use ProdEnv or PocEnv with provisioned
        startup kits. For provisioning details, see:
        https://nvflare.readthedocs.io/en/2.7/programming_guide/provisioning_system.html

    Args:
        name: Name of the federated learning job. Defaults to "fedavg_he".
        model: Initial model to start federated training with. Can be:
            - nn.Module instance
            - Dict config: {"class_path": "module.ClassName", "args": {"param": value}}
            - None: no initial model
        initial_ckpt: Path to a pre-trained checkpoint file. Can be:
            - Relative path: file will be bundled into the job's custom/ directory.
            - Absolute path: treated as a server-side path, used as-is at runtime.
            Note: PyTorch requires model when using initial_ckpt (for architecture).
        min_clients: Minimum number of clients required to start a training round.
        num_rounds: Number of federated training rounds to execute. Defaults to 2.
        train_script: Path to the training script that will be executed on each client.
        train_args: Command line arguments to pass to the training script. Defaults to "".
        aggregator: Aggregator for combining client updates. If None, uses
            HEInTimeAccumulateWeightedAggregator with aggregator_data_kind.
        aggregator_data_kind: Data kind to use for the aggregator. Defaults to DataKind.WEIGHTS.
        launch_external_process (bool): Whether to launch the script in external process.
            Defaults to False.
        command (str): If launch_external_process=True, command to run script (prepended to
            script). Defaults to "python3 -u".
        server_expected_format (str): What format to exchange the parameters between server and
            client. Defaults to ExchangeFormat.NUMPY.
        params_transfer_type (str): How to transfer the parameters. FULL means the whole model
            parameters are sent. DIFF means that only the difference is sent.
            Defaults to TransferType.FULL.
        encrypt_layers: if not specified (None), all layers are being encrypted;
            if list of variable/layer names, only specified variables are encrypted;
            if string containing regular expression (e.g. "conv"), only matched variables are
            being encrypted.
        server_memory_gc_rounds: Run memory cleanup (gc.collect + malloc_trim) every N rounds on
            server. Set to 0 to disable. Defaults to 1 (every round).
        client_memory_gc_rounds: Forwarded to the client ScriptRunner as ``memory_gc_rounds``.
            Defaults to 0.
        cuda_empty_cache: Forwarded to the client ScriptRunner as ``cuda_empty_cache``.
            Defaults to False.

    Raises:
        ValueError: If ``aggregator`` is provided but is not an ``Aggregator`` instance.

    Example:
        ```python
        recipe = FedAvgRecipeWithHE(
            name="my_fedavg_he_job",
            model=pretrained_model,
            min_clients=2,
            num_rounds=10,
            train_script="client.py",
            train_args="--epochs 5 --batch_size 32"
        )
        ```

    Note:
        This recipe implements FedAvg with homomorphic encryption (HE) using TenSEAL library.
        HE allows computations to be performed on encrypted data, preserving client privacy.

        The following HE components are configured:
            - Server side: HEModelShareableGenerator, HEInTimeAccumulateWeightedAggregator,
              HEModelSerializeFilter
            - Client side: HEModelDecryptor (for incoming data), HEModelEncryptor (for outgoing
              results)

        Model updates are aggregated using weighted averaging based on the number of training
        samples provided by each client, with encryption/decryption handled transparently.

        If you want to use a custom aggregator, you can pass it in the aggregator parameter.
        The custom aggregator should support HE operations or be a subclass of
        HEInTimeAccumulateWeightedAggregator.
    """

    def __init__(
        self,
        *,
        name: str = "fedavg_he",
        model: Union[Any, dict[str, Any], None] = None,
        initial_ckpt: Optional[str] = None,
        min_clients: int,
        num_rounds: int = 2,
        train_script: str,
        train_args: str = "",
        aggregator: Optional[Aggregator] = None,
        aggregator_data_kind: Optional[DataKind] = DataKind.WEIGHTS,
        launch_external_process: bool = False,
        command: str = "python3 -u",
        server_expected_format: ExchangeFormat = ExchangeFormat.NUMPY,
        params_transfer_type: TransferType = TransferType.FULL,
        encrypt_layers: Optional[Union[List[str], str]] = None,
        server_memory_gc_rounds: int = 1,
        client_memory_gc_rounds: int = 0,
        cuda_empty_cache: bool = False,
    ):
        # Validate inputs internally
        v = _FedAvgRecipeWithHEValidator(
            name=name,
            model=model,
            initial_ckpt=initial_ckpt,
            min_clients=min_clients,
            num_rounds=num_rounds,
            train_script=train_script,
            train_args=train_args,
            aggregator=aggregator,
            aggregator_data_kind=aggregator_data_kind,
            launch_external_process=launch_external_process,
            command=command,
            server_expected_format=server_expected_format,
            params_transfer_type=params_transfer_type,
            encrypt_layers=encrypt_layers,
            server_memory_gc_rounds=server_memory_gc_rounds,
            client_memory_gc_rounds=client_memory_gc_rounds,
            cuda_empty_cache=cuda_empty_cache,
        )

        # Function-scope imports are kept local, as in the original implementation.
        from nvflare.recipe.utils import prepare_initial_ckpt, recipe_model_to_job_model, validate_ckpt

        self.name = v.name
        self.model = v.model
        self.initial_ckpt = v.initial_ckpt

        # Validate inputs using shared utilities
        validate_ckpt(self.initial_ckpt)
        if isinstance(self.model, dict):
            # Dict configs are converted to the job-level model representation.
            self.model = recipe_model_to_job_model(self.model)

        self.min_clients = v.min_clients
        self.num_rounds = v.num_rounds
        self.train_script = v.train_script
        self.train_args = v.train_args
        self.aggregator = v.aggregator
        self.aggregator_data_kind = v.aggregator_data_kind
        self.launch_external_process = v.launch_external_process
        self.command = v.command
        self.server_expected_format: ExchangeFormat = v.server_expected_format
        self.params_transfer_type: TransferType = v.params_transfer_type
        self.encrypt_layers: Optional[Union[List[str], str]] = v.encrypt_layers
        self.server_memory_gc_rounds = v.server_memory_gc_rounds
        self.client_memory_gc_rounds = v.client_memory_gc_rounds
        self.cuda_empty_cache = v.cuda_empty_cache

        # Create BaseFedJob without model first (model setup done manually below for HE)
        job = BaseFedJob(
            name=self.name,
            min_clients=self.min_clients,
        )

        # A persistor (and its serialization filter) is only needed when there is something to load.
        has_initial_model = self.model is not None or self.initial_ckpt is not None

        if has_initial_model:
            from nvflare.app_opt.pt.job_config.model import PTModel

            # Create a persistor wired to the HE serialization filter by component id.
            ckpt_path = prepare_initial_ckpt(self.initial_ckpt, job)
            model_persistor = PTFileModelPersistor(
                model=self.model,
                source_ckpt_file_full_name=ckpt_path,
                filter_id="model_serialize_filter",
            )
            pt_model = PTModel(model=self.model, persistor=model_persistor)
            job.comp_ids.update(job.to_server(pt_model))

            # Register the HE model serialization filter under the id the persistor refers to.
            # NOTE(review): the registration order (persistor first, then filter) is preserved
            # from the original code; the persistor only receives the filter's id here.
            model_serialize_filter = HEModelSerializeFilter()
            job.to_server(model_serialize_filter, id="model_serialize_filter")

        # Define the HE-specific components for the server
        if self.aggregator is None:
            self.aggregator = HEInTimeAccumulateWeightedAggregator(
                expected_data_kind=self.aggregator_data_kind,
                weigh_by_local_iter=False,  # HE: weighting happens client-side in HEModelEncryptor (train task)
            )
        elif not isinstance(self.aggregator, Aggregator):
            raise ValueError(f"Invalid aggregator type: {type(self.aggregator)}. Expected type: {Aggregator}")

        # Use HE-specific shareable generator
        shareable_generator = HEModelShareableGenerator()
        shareable_generator_id = job.to_server(shareable_generator, id="shareable_generator")
        aggregator_id = job.to_server(self.aggregator, id="aggregator")

        controller = ScatterAndGather(
            min_clients=self.min_clients,
            num_rounds=self.num_rounds,
            wait_time_after_min_received=0,
            aggregator_id=aggregator_id,
            # Without an initial model/checkpoint no persistor was registered, so pass an empty id.
            persistor_id=job.comp_ids.get("persistor_id", "") if has_initial_model else "",
            shareable_generator_id=shareable_generator_id,
            memory_gc_rounds=self.server_memory_gc_rounds,
        )
        # Send the controller to the server
        job.to_server(controller)

        # Add clients with HE filters
        executor = ScriptRunner(
            script=self.train_script,
            script_args=self.train_args,
            launch_external_process=self.launch_external_process,
            command=self.command,
            framework=FrameworkType.PYTORCH,
            server_expected_format=self.server_expected_format,
            params_transfer_type=self.params_transfer_type,
            memory_gc_rounds=self.client_memory_gc_rounds,
            cuda_empty_cache=self.cuda_empty_cache,
        )
        job.to_clients(executor)

        # Add HE model decryptor as task data filter when training or validating
        # (decrypt incoming data from server)
        job.to_clients(HEModelDecryptor(), tasks=["train", "validate"], filter_type=FilterType.TASK_DATA)

        # Add HE model encryptor as task result filter after training
        # (encrypt outgoing results to server). The validated ``self.encrypt_layers``
        # is used so the filters see the same value as the rest of the recipe.
        job.to_clients(
            HEModelEncryptor(
                encrypt_layers=self.encrypt_layers,
                weigh_by_local_iter=True,  # Client-side weighting for HE (aggregator has weigh_by_local_iter=False)
            ),
            tasks=["train"],
            filter_type=FilterType.TASK_RESULT,
        )

        # Add HE model encryptor as task result filter when submitting model
        # (encrypt outgoing results to server)
        job.to_clients(
            HEModelEncryptor(
                encrypt_layers=self.encrypt_layers,
                weigh_by_local_iter=False,  # No weighting by local iter when submitting model for evaluation
            ),
            tasks=["submit_model"],
            filter_type=FilterType.TASK_RESULT,
        )

        Recipe.__init__(self, job)

    def process_env(self, env: ExecEnv):
        """Reject the simulator environment, then defer to the base recipe.

        Args:
            env: Execution environment the recipe is about to run in.

        Raises:
            ValueError: If ``env`` is a ``SimEnv`` instance, which this HE recipe does not support.
        """
        from nvflare.recipe.sim_env import SimEnv

        if isinstance(env, SimEnv):
            raise ValueError(HE_SIM_ENV_NOT_SUPPORTED_ERROR)
        super().process_env(env)