Source code for nvflare.edge.tools.edge_job
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os.path
from typing import Optional
from nvflare.edge.assessor import Assessor
from nvflare.edge.controllers.sage import ScatterAndGatherForEdge
from nvflare.edge.executors.edge_model_executor import EdgeModelExecutor
from nvflare.edge.simulation.device_task_processor import DeviceTaskProcessor
from nvflare.edge.updaters.emd import AggregatorFactory
from nvflare.edge.widgets.etr import EdgeTaskReceiver
from nvflare.edge.widgets.tp_runner import TPRunner
from nvflare.edge.widgets.tpo_runner import TPORunner
from nvflare.fuel.utils.validation_utils import check_object_type, check_positive_int, check_positive_number, check_str
from nvflare.job_config.api import FedJob
from nvflare.job_config.file_source import FileSource
[docs]
class EdgeJob(FedJob):
    """A FedJob with helpers for configuring an edge job.

    The job is assembled through three one-shot steps, each of which may be
    performed at most once per job:

    - ``configure_server``: adds the assessor and the ScatterAndGatherForEdge controller.
    - ``configure_client``: adds the task receiver, aggregator factory and executor.
    - ``configure_simulation`` / ``configure_simulation_with_file``: adds a simulation
      runner to the clients (optional; the two are mutually exclusive).
    """

    def __init__(
        self,
        name: str,
        edge_method: str,
        min_clients: int = 1,
    ):
        """Constructor of EdgeJob

        Args:
            name: name of the job.
            edge_method: method for matching job request. Goes to the job's meta.
            min_clients: min number of clients required for the job.
        """
        check_str("edge_method", edge_method)
        FedJob.__init__(self, name=name, min_clients=min_clients, meta_props={"edge_method": edge_method})

        # One-shot guards: each flag is flipped to True by the corresponding
        # configure_* method once it has completed successfully.
        self.server_config_added = False
        self.client_config_added = False
        self.simulation_set = False

    def configure_server(
        self,
        assessor: Assessor,
        num_rounds: int = 1,
        task_name: str = "train",
        assess_interval: float = 0.5,
        update_interval: float = 1.0,
    ) -> None:
        """Set up server config.

        Adds the assessor to the server under id "wf_assessor", then adds a
        ScatterAndGatherForEdge controller (id "sage") that refers to the
        assessor by the id returned from ``to_server``.

        Args:
            assessor: The Assessor object for assessing workflow progress.
            num_rounds: number of rounds.
            task_name: name of the task.
            assess_interval: how often to perform assessment.
            update_interval: how often the clients should send updates.

        Returns: None

        Raises:
            RuntimeError: if the server config has already been added.
        """
        if self.server_config_added:
            raise RuntimeError("server config is already added")

        check_object_type("assessor", assessor, Assessor)
        check_positive_int("num_rounds", num_rounds)
        check_str("task_name", task_name)

        assessor_id = self.to_server(assessor, id="wf_assessor")
        controller = ScatterAndGatherForEdge(
            assessor_id=assessor_id,
            num_rounds=num_rounds,
            task_name=task_name,
            task_check_period=0.5,  # fixed here; not exposed as a parameter
            assess_interval=assess_interval,
            update_interval=update_interval,
        )
        self.to_server(controller, id="sage")
        self.server_config_added = True

    def configure_client(
        self,
        aggregator_factory: AggregatorFactory,
        max_model_versions: Optional[int] = None,
        update_timeout: float = 5.0,
        executor_task_name: str = "train",
        simulation_config_file: Optional[str] = None,
    ) -> None:
        """Set up client config.

        Adds an EdgeTaskReceiver (id "edge_task_receiver"), the aggregator
        factory (id "aggr_factory") and the executor (id "executor") to the
        clients. If a simulation config file is given, simulation is configured
        from it as well.

        All arguments, including the simulation config file, are validated
        before any component is added, so a failed call leaves the job untouched.

        Args:
            aggregator_factory: an AggregatorFactory object to create aggregators when needed.
            max_model_versions: max number of model versions to keep. Must be a positive
                int when specified; None skips the check.
            update_timeout: timeout for status update messages.
            executor_task_name: task name for executor.
            simulation_config_file: config file for local simulation (optional).

        Returns: None

        Raises:
            RuntimeError: if the client config has already been added, or if a simulation
                config file is given but simulation is already configured.
            ValueError: if the simulation config file does not exist or is not valid JSON.
        """
        if self.client_config_added:
            raise RuntimeError("client config is already added")

        # Compare against None rather than relying on truthiness: an explicit 0
        # is falsy and would otherwise bypass the positive-int check.
        if max_model_versions is not None:
            check_positive_int("max_model_versions", max_model_versions)

        check_object_type("aggregator_factory", aggregator_factory, AggregatorFactory)
        check_positive_number("update_timeout", update_timeout)
        check_str("executor_task_name", executor_task_name)

        if simulation_config_file:
            check_str("simulation_config_file", simulation_config_file)
            # Run the same checks configure_simulation_with_file performs, but
            # up front, so that a bad file cannot fail the call after the
            # client components below have already been added.
            if self.simulation_set:
                raise RuntimeError("simulation is already configured")
            self._check_simulation_config_file(simulation_config_file)

        self.to_clients(EdgeTaskReceiver(), id="edge_task_receiver")
        aggr_factory_id = self.to_clients(aggregator_factory, id="aggr_factory")

        executor = self._configure_executor(
            aggr_factory_id=aggr_factory_id, max_model_versions=max_model_versions, update_timeout=update_timeout
        )
        self.to_clients(executor, id="executor", tasks=[executor_task_name])

        if simulation_config_file:
            self.configure_simulation_with_file(simulation_config_file)

        self.client_config_added = True

    def _configure_executor(self, aggr_factory_id, max_model_versions, update_timeout):
        """Create the executor that configure_client adds to the clients.

        NOTE(review): factored out of configure_client, presumably so that a
        subclass can supply a different executor - confirm against subclasses.

        Args:
            aggr_factory_id: component id of the aggregator factory.
            max_model_versions: max number of model versions to keep (may be None).
            update_timeout: timeout for status update messages.

        Returns: an EdgeModelExecutor built from the given arguments.
        """
        return EdgeModelExecutor(
            aggr_factory_id=aggr_factory_id, max_model_versions=max_model_versions, update_timeout=update_timeout
        )

    @staticmethod
    def _check_simulation_config_file(simulation_config_file: str) -> None:
        """Verify that the simulation config file exists and parses as JSON.

        The parsed content is discarded; the file is only loaded to prove that
        it is well-formed.

        Args:
            simulation_config_file: path to the simulation config file.

        Returns: None

        Raises:
            ValueError: if the path is not an existing regular file, or the file cannot be
                read or parsed as JSON (the underlying error is chained as the cause).
        """
        if not os.path.isfile(simulation_config_file):
            raise ValueError(f"file {simulation_config_file} does not exist or is not a valid file")

        try:
            with open(simulation_config_file, "r") as f:
                json.load(f)
        except Exception as ex:
            # Deliberately broad: any read/parse failure is reported to the
            # caller as ValueError, with the original error kept as __cause__.
            raise ValueError(f"file {simulation_config_file} is not a valid JSON file: {ex}") from ex

    def configure_simulation_with_file(self, simulation_config_file: str) -> None:
        """Configure simulation with a config file.

        Adds the file to the clients as a "config" FileSource and adds a
        TPRunner (id "tp_runner") that is pointed at the file by its base name.

        Args:
            simulation_config_file: the simulation config file.

        Returns: None

        Raises:
            RuntimeError: if simulation has already been configured.
            ValueError: if the file does not exist or is not a valid JSON file.
        """
        if self.simulation_set:
            raise RuntimeError("simulation is already configured")

        self._check_simulation_config_file(simulation_config_file)

        self.to_clients(FileSource(simulation_config_file, app_folder_type="config"))

        # "{JOB_CONFIG_DIR}" is passed through literally (it is not an f-string
        # field) - presumably a placeholder resolved when the job runs; confirm.
        base_name = os.path.basename(simulation_config_file)
        conf_file = "{JOB_CONFIG_DIR}/" + base_name
        self.to_clients(TPRunner(conf_file), id="tp_runner")
        self.simulation_set = True

    def configure_simulation(
        self,
        task_processor: DeviceTaskProcessor,
        job_timeout: float = 60.0,
        num_devices: int = 1000,
        num_workers: int = 10,
    ) -> None:
        """Configure simulation with a DeviceTaskProcessor.

        Adds the task processor to the clients (id "task_processor") and a
        TPORunner (id "tpo_runner") that refers to it by the id returned from
        ``to_clients``. Arguments are validated before anything is added.

        Args:
            task_processor: the DeviceTaskProcessor object to be used for processing tasks.
            job_timeout: timeout for trying to get job. Must be a positive number.
            num_devices: number of devices to simulate. Must be a positive int.
            num_workers: number of workers for executing tasks. Must be a positive int.

        Returns: None

        Raises:
            RuntimeError: if simulation has already been configured.
        """
        if self.simulation_set:
            raise RuntimeError("simulation is already configured")

        # Validate up front, like the other configure_* methods, so that bad
        # arguments are rejected before the job is modified.
        check_object_type("task_processor", task_processor, DeviceTaskProcessor)
        check_positive_number("job_timeout", job_timeout)
        check_positive_int("num_devices", num_devices)
        check_positive_int("num_workers", num_workers)

        tp_id = self.to_clients(task_processor, id="task_processor")
        runner = TPORunner(tp_id, job_timeout, num_devices, num_workers)
        self.to_clients(runner, id="tpo_runner")
        self.simulation_set = True