# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.util
import json
from typing import Dict, Optional
from nvflare.edge.models.model import DeviceModel
from nvflare.edge.tools.edge_fed_buff_recipe import (
DeviceManagerConfig,
EdgeFedBuffRecipe,
EvaluatorConfig,
ModelManagerConfig,
SimulationConfig,
)
from nvflare.edge.tools.et_job import ETJob
from nvflare.job_config.file_source import FileSource
# Name given to the trainer component in the generated device config; the
# "train" executor entry refers to the component as "@<name>".
_TRAINER_NAME = "trainer"
# File name of the generated device config JSON; the file is written under this
# name and registered with the server as a "config" app-folder file.
_DEVICE_CONFIG_FILE_NAME = "device_config.json"
class ETFedBuffRecipe(EdgeFedBuffRecipe):
    """Edge Training FedBuff Recipe for embedded/edge device training.

    This recipe extends EdgeFedBuffRecipe for edge devices with DeviceModel wrapper.

    Args:
        job_name: Name of the federated learning job.
        device_model: DeviceModel wrapping the PyTorch model for edge devices.
        input_shape: Input shape for the model.
        output_shape: Output shape for the model.
        model_manager_config: Configuration for the model manager.
        device_manager_config: Configuration for the device manager.
        initial_ckpt: Absolute path to a pre-trained checkpoint file (.pt, .pth).
            The file may not exist locally (server-side path).
        evaluator_config: Configuration for the global evaluator (optional).
        simulation_config: Configuration for simulated devices settings (optional).
        device_training_params: Training parameters for device (optional).
        custom_source_root: Path to custom source code (optional).
        device_wait_timeout: Timeout in seconds for waiting for sufficient devices
            to join before stopping the job. None means wait indefinitely.
            WARNING: when device_reuse=False with a finite device pool, leaving this
            as None can cause the job to hang indefinitely once the pool is exhausted.
            In that case, set an explicit timeout (e.g., 300.0 seconds).
            Default: None

    Raises:
        ImportError: If the ``executorch.extension.training`` module cannot be
            located in the current environment.
    """

    def __init__(
        self,
        job_name: str,
        device_model: DeviceModel,
        input_shape,
        output_shape,
        model_manager_config: ModelManagerConfig,
        device_manager_config: DeviceManagerConfig,
        initial_ckpt: Optional[str] = None,
        evaluator_config: Optional[EvaluatorConfig] = None,
        simulation_config: Optional[SimulationConfig] = None,
        device_training_params: Optional[Dict] = None,
        custom_source_root: Optional[str] = None,
        device_wait_timeout: Optional[float] = None,
    ) -> None:
        # find_spec() imports the parent packages of a dotted name, so a missing
        # top-level "executorch" package surfaces as ModuleNotFoundError instead
        # of a None result. Treat both outcomes as "not available" so the user
        # always gets the installation hint below.
        probe_error = None
        try:
            executorch_found = importlib.util.find_spec("executorch.extension.training") is not None
        except ImportError as err:
            executorch_found = False
            probe_error = err
        if not executorch_found:
            raise ImportError(
                "ETFedBuffRecipe requires executorch. "
                "See installation instructions: "
                "https://pytorch.org/executorch/stable/getting-started-setup.html"
            ) from probe_error

        # These attributes are assigned before the parent initializer runs.
        # NOTE(review): presumably the parent __init__ invokes create_job() /
        # _configure_job(), which read them -- confirm against EdgeFedBuffRecipe.
        self.device_model = device_model
        self.input_shape = input_shape
        self.output_shape = output_shape
        self.device_training_params = device_training_params

        # The parent recipe receives the network held by the DeviceModel wrapper.
        super().__init__(
            job_name=job_name,
            model=device_model.net,
            model_manager_config=model_manager_config,
            device_manager_config=device_manager_config,
            initial_ckpt=initial_ckpt,
            evaluator_config=evaluator_config,
            simulation_config=simulation_config,
            custom_source_root=custom_source_root,
            device_wait_timeout=device_wait_timeout,
        )

    def create_job(self) -> ETJob:
        """Build the ETJob for this recipe.

        Returns:
            An ETJob constructed from this recipe's job name, method name,
            device model, and input/output shapes.
        """
        return ETJob(
            name=self.job_name,
            edge_method=self.method_name,
            device_model=self.device_model,
            input_shape=self.input_shape,
            output_shape=self.output_shape,
        )

    def _configure_job(self, job) -> None:
        """Apply the parent configuration, then attach the device config if needed.

        When ``device_training_params`` is set (and non-empty), a device config
        JSON file is written to the current working directory and registered
        with the server as a "config" file source.

        Args:
            job: The job object to configure.
        """
        super()._configure_job(job)

        # Nothing to add when no device training parameters were supplied.
        if not self.device_training_params:
            return

        trainer_config = {
            "type": "Trainer.DLTrainer",
            "name": _TRAINER_NAME,
            "args": self.device_training_params,
        }
        device_config = {
            "components": [trainer_config],
            "executors": {"train": f"@{_TRAINER_NAME}"},
        }

        # The file is written relative to the current working directory and
        # handed to FileSource by that same relative name.
        with open(_DEVICE_CONFIG_FILE_NAME, "w", encoding="utf-8") as f:
            json.dump(device_config, f, indent=2)

        job.to_server(FileSource(_DEVICE_CONFIG_FILE_NAME, app_folder_type="config"))