# Source code for nvflare.edge.tools.et_fed_buff_recipe

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib.util
import json
from typing import Dict, Optional

from nvflare.edge.models.model import DeviceModel
from nvflare.edge.tools.edge_fed_buff_recipe import (
    DeviceManagerConfig,
    EdgeFedBuffRecipe,
    EvaluatorConfig,
    ModelManagerConfig,
    SimulationConfig,
)
from nvflare.edge.tools.et_job import ETJob
from nvflare.job_config.file_source import FileSource

_TRAINER_NAME = "trainer"
_DEVICE_CONFIG_FILE_NAME = "device_config.json"


class ETFedBuffRecipe(EdgeFedBuffRecipe):
    """Edge Training FedBuff Recipe for embedded/edge device training.

    This recipe extends EdgeFedBuffRecipe for edge devices with DeviceModel wrapper.

    Args:
        job_name: Name of the federated learning job.
        device_model: DeviceModel wrapping the PyTorch model for edge devices.
        input_shape: Input shape for the model.
        output_shape: Output shape for the model.
        model_manager_config: Configuration for the model manager.
        device_manager_config: Configuration for the device manager.
        initial_ckpt: Absolute path to a pre-trained checkpoint file (.pt, .pth).
            The file may not exist locally (server-side path).
        evaluator_config: Configuration for the global evaluator (optional).
        simulation_config: Configuration for simulated devices settings (optional).
        device_training_params: Training parameters for device (optional). When
            non-empty, they are passed as the ``args`` of the ``Trainer.DLTrainer``
            component in the generated device config file.
        custom_source_root: Path to custom source code (optional).
        device_wait_timeout: Timeout in seconds for waiting for sufficient devices
            to join before stopping the job. None means wait indefinitely.
            WARNING: when device_reuse=False with a finite device pool, leaving
            this as None can cause the job to hang indefinitely once the pool is
            exhausted. In that case, set an explicit timeout (e.g., 300.0 seconds).
            Default: None

    Raises:
        ImportError: If the ``executorch.extension.training`` module cannot be
            located in the current environment.
    """

    def __init__(
        self,
        job_name: str,
        device_model: DeviceModel,
        input_shape,
        output_shape,
        model_manager_config: ModelManagerConfig,
        device_manager_config: DeviceManagerConfig,
        initial_ckpt: Optional[str] = None,
        evaluator_config: Optional[EvaluatorConfig] = None,
        simulation_config: Optional[SimulationConfig] = None,
        device_training_params: Optional[Dict] = None,
        custom_source_root: Optional[str] = None,
        device_wait_timeout: Optional[float] = None,
    ):
        # find_spec() imports the parent packages of a dotted name, so it raises
        # ModuleNotFoundError (rather than returning None) when ``executorch``
        # itself is missing. Treat both outcomes as "not installed" so the user
        # always gets the installation hint below.
        try:
            executorch_spec = importlib.util.find_spec("executorch.extension.training")
        except ModuleNotFoundError:
            executorch_spec = None

        if executorch_spec is None:
            raise ImportError(
                "ETFedBuffRecipe requires executorch. \n"
                "See installation instructions: "
                "https://pytorch.org/executorch/stable/getting-started-setup.html"
            )

        # These attributes are assigned before the base initializer runs.
        # NOTE(review): presumably the base initializer calls create_job() /
        # _configure_job(), which read them - confirm against EdgeFedBuffRecipe.
        self.device_model = device_model
        self.input_shape = input_shape
        self.output_shape = output_shape
        self.device_training_params = device_training_params

        # The base recipe is given the network held by the DeviceModel wrapper.
        pt_model = device_model.net

        EdgeFedBuffRecipe.__init__(
            self,
            job_name=job_name,
            model=pt_model,
            model_manager_config=model_manager_config,
            device_manager_config=device_manager_config,
            initial_ckpt=initial_ckpt,
            evaluator_config=evaluator_config,
            simulation_config=simulation_config,
            custom_source_root=custom_source_root,
            device_wait_timeout=device_wait_timeout,
        )

    def create_job(self) -> ETJob:
        """Create the ETJob for this recipe.

        Returns:
            An ETJob built from this recipe's job name, edge method name,
            device model and input/output shapes.
        """
        return ETJob(
            name=self.job_name,
            edge_method=self.method_name,
            device_model=self.device_model,
            input_shape=self.input_shape,
            output_shape=self.output_shape,
        )

    def _configure_job(self, job) -> None:
        """Apply the base job configuration and add the device training config.

        When ``device_training_params`` is non-empty, a ``device_config.json``
        describing a ``Trainer.DLTrainer`` component is written and added to the
        server app with the "config" app folder type.

        Args:
            job: The job object to configure.
        """
        super()._configure_job(job)

        # Nothing device-specific to add when no training params were given
        # (None or an empty dict).
        if not self.device_training_params:
            return

        trainer_config = {"type": "Trainer.DLTrainer", "name": _TRAINER_NAME, "args": self.device_training_params}
        device_config = {"components": [trainer_config], "executors": {"train": f"@{_TRAINER_NAME}"}}

        # NOTE(review): the file name is relative, so the file lands in the
        # process's current working directory and recipes configured from the
        # same directory overwrite each other's file.
        with open(_DEVICE_CONFIG_FILE_NAME, "w", encoding="utf-8") as f:
            json.dump(device_config, f, indent=2)

        job.to_server(FileSource(_DEVICE_CONFIG_FILE_NAME, app_folder_type="config"))