Source code for nvflare.edge.device.sdk_spec

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import time
from typing import List, Optional, Tuple

from nvflare.apis.dxo import DXO, from_dict
from nvflare.apis.signal import Signal

from .config import process_train_config
from .defs import Context, ContextKey, DataSource, EventType, Executor, Filter


class FlareRunner:
    """Run jobs obtained from a Flare host.

    The runner repeatedly asks the host for a job. For each job it builds the
    training configuration, then loops: fetch a task, filter the task data, run
    the matching executor, filter the output and report the result.

    The host-communication methods (``_get_job``, ``_get_task`` and
    ``_report_result``) are empty in this class; presumably concrete runners
    override them - confirm against the SDK implementations.
    """

    def __init__(
        self,
        job_name: str,
        data_source: DataSource,
        device_info: dict,
        user_info: dict,
        job_timeout: float,
        in_filters: Optional[List[Filter]] = None,
        out_filters: Optional[List[Filter]] = None,
        resolver_registry: Optional[dict] = None,
    ):
        """Constructor of FlareRunner

        Args:
            job_name: name of the job. Used for matching Flare job on host.
            data_source: data source for the training
            device_info: device info
            user_info: info of the device user
            job_timeout: timeout for getting a job from Flare host
            in_filters: app provided filters for input model
            out_filters: app provided filters for output model
            resolver_registry: app provided resolvers

        Raises:
            ValueError: if a non-empty ``resolver_registry`` is not a dict.

        Note: app provided filters apply to all jobs and are invoked before configured job filters!
        """
        self.job_name = job_name
        self.resolver_registry = {}
        self.data_source = data_source
        self.device_info = device_info
        self.user_info = user_info
        self.job_timeout = job_timeout
        self.app_in_filters = in_filters
        self.app_out_filters = out_filters
        self.abort_signal = Signal()
        self.job_id = None
        self.cookie = None

        # add built-in resolvers first ...
        self.add_builtin_resolvers()

        # ... then app-provided resolvers, which can override builtin resolvers!
        if resolver_registry:
            if not isinstance(resolver_registry, dict):
                raise ValueError(f"resolver_registry must be dict but got {type(resolver_registry)}")
            self.resolver_registry.update(resolver_registry)

    def add_builtin_resolvers(self):
        """Add resolvers for Flare's builtin components.

        Called by the constructor before app-provided resolvers are merged in.
        The implementation in this class registers nothing.
        """
        pass

    def run(self):
        """Process jobs one after another until the session is done."""
        while True:
            if self._do_one_job():
                return

    def stop(self):
        """Trigger the abort signal so that job processing stops."""
        if self.abort_signal:
            self.abort_signal.trigger(True)

    def _get_job(self, ctx: Context, abort_signal: Signal) -> Optional[dict]:
        """Repeatedly try to get job from host

        Args:
            ctx: the job-level context
            abort_signal: signal to check for abort requests

        Returns: a job or None if the host says DONE or timed out.

        The implementation in this class does nothing and returns None.
        """
        pass

    def _get_task(self, ctx: Context, abort_signal: Signal) -> Tuple[Optional[dict], bool]:
        """Repeatedly try to get a task from the host

        Args:
            ctx: the job-level context
            abort_signal: signal to check for abort requests

        Returns: a tuple of (task, session_done). When task is empty, the caller
            ends the current job and uses session_done to decide whether the
            whole session is over.

        The implementation in this class does nothing and returns None, which the
        caller cannot unpack.
        """
        pass

    def _report_result(self, result: dict, ctx: Context, abort_signal: Signal) -> bool:
        """Report a task result to the host

        Args:
            result: the task output in dict form (produced by ``DXO.to_dict()``)
            ctx: the task-level context
            abort_signal: signal to check for abort requests

        Returns: a value the caller treats as "session is done" when truthy.

        The implementation in this class does nothing and returns None.
        """
        pass

    @staticmethod
    def _merge_filters(app_filters, job_filters) -> list:
        """Return a new list holding app-provided filters followed by job filters."""
        merged = []
        if app_filters:
            merged.extend(app_filters)
        if job_filters:
            merged.extend(job_filters)
        return merged

    def _do_filtering(self, data: DXO, filters: Optional[List[Filter]], ctx) -> DXO:
        """Pass the data through the filters in order

        Args:
            data: the data to be filtered
            filters: filters to apply; None or empty means no filtering
            ctx: context passed to each filter

        Returns: the output of the last filter that ran; filtering stops early
            once the abort signal is triggered.

        Raises:
            RuntimeError: if an entry in filters is not a Filter.
        """
        for f in filters or ():
            # explicit check instead of assert: asserts are removed under "python -O"
            if not isinstance(f, Filter):
                raise RuntimeError(f"bad filter: expect Filter but got {type(f)}")

            data = f.filter(data, ctx, self.abort_signal)
            if self.abort_signal.triggered:
                break
        return data

    def _do_one_job(self) -> bool:
        """Work with the host to do one job

        Returns: whether whole session is done.

        Raises:
            RuntimeError: if a task is not a dict, no valid executor is found for
                a task, or data is not a DXO after filtering or execution.
        """
        ctx = Context()
        ctx[ContextKey.RUNNER] = self
        ctx[ContextKey.DATA_SOURCE] = self.data_source

        # try to get job
        job = self._get_job(ctx, self.abort_signal)
        if not job:
            # No job for me.
            return True

        self.job_name = job.get("job_name")
        self.job_id = job.get("job_id")

        # the job_data in the job contains trainer config!
        train_config = process_train_config(job.get("job_data"), self.resolver_registry)
        ctx[ContextKey.COMPONENTS] = train_config.objects
        ctx[ContextKey.EVENT_HANDLERS] = train_config.event_handlers

        # app provided filters are invoked before the filters configured in the job
        in_filters = self._merge_filters(self.app_in_filters, train_config.in_filters)
        out_filters = self._merge_filters(self.app_out_filters, train_config.out_filters)

        while True:
            task, sess_done = self._get_task(ctx, self.abort_signal)
            if self.abort_signal.triggered:
                return True

            if not task:
                # no more work for this job
                return sess_done

            # create a new context for each task (shallow copy of the job context)
            task_ctx = copy.copy(ctx)

            # explicit check instead of assert: asserts are removed under "python -O"
            if not isinstance(task, dict):
                raise RuntimeError(f"bad task: expect dict but got {type(task)}")

            self.cookie = task.get("cookie")
            task_name = task.get("task_name")

            # task data is DXO format
            task_data = task.get("task_data")
            task_dxo = from_dict(task_data)

            # find the right executor
            executor = train_config.find_executor(task_name)
            if not executor:
                raise RuntimeError(f"cannot find executor for task {task_name}")

            if not isinstance(executor, Executor):
                raise RuntimeError(f"bad executor for task {task_name}: expect Executor but got {type(executor)}")

            task_ctx[ContextKey.TASK_ID] = task.get("task_id")
            task_ctx[ContextKey.TASK_NAME] = task_name
            task_ctx[ContextKey.TASK_DATA] = task_data
            task_ctx[ContextKey.EXECUTOR] = executor

            # filter the input
            task_dxo = self._do_filtering(task_dxo, in_filters, task_ctx)
            if not isinstance(task_dxo, DXO):
                raise RuntimeError(f"task data after filtering is not valid DXO: {type(task_dxo)}")

            if self.abort_signal.triggered:
                return True

            task_ctx.fire_event(EventType.BEFORE_TRAIN, time.time(), self.abort_signal)
            output = executor.execute(task_dxo, task_ctx, self.abort_signal)

            # output must follow DXO format
            if not isinstance(output, DXO):
                raise RuntimeError(f"output from {type(executor)} is not a valid DXO: {type(output)}")

            task_ctx.fire_event(EventType.AFTER_TRAIN, (time.time(), output), self.abort_signal)

            if self.abort_signal.triggered:
                return True

            # filter the output
            output = self._do_filtering(output, out_filters, task_ctx)
            if not isinstance(output, DXO):
                raise RuntimeError(f"output after filtering for task {task_name} is not a valid DXO: {type(output)}")

            if self.abort_signal.triggered:
                return True

            sess_done = self._report_result(output.to_dict(), task_ctx, self.abort_signal)
            if sess_done:
                return sess_done

            if self.abort_signal.triggered:
                return True