Source code for nvflare.edge.widgets.api_service

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading

from nvflare.apis.event_type import EventType
from nvflare.apis.fl_constant import ConnectionSecurity, FLContextKey, SecureTrainConst
from nvflare.apis.fl_context import FLContext
from nvflare.edge.constants import EdgeContextKey, EdgeEventType
from nvflare.edge.web.models.job_request import JobRequest
from nvflare.edge.web.models.job_response import JobResponse
from nvflare.edge.web.models.result_report import ResultReport
from nvflare.edge.web.models.result_response import ResultResponse
from nvflare.edge.web.models.selection_request import SelectionRequest
from nvflare.edge.web.models.selection_response import SelectionResponse
from nvflare.edge.web.models.task_request import TaskRequest
from nvflare.edge.web.models.task_response import TaskResponse
from nvflare.edge.web.service.query_handler import QueryHandler
from nvflare.edge.web.service.server import EdgeApiServer
from nvflare.fuel.f3.drivers.driver_params import DriverParams
from nvflare.fuel.f3.drivers.grpc.utils import get_grpc_server_credentials
from nvflare.fuel.f3.drivers.net_utils import enhance_credential_info
from nvflare.widgets.widget import Widget


[docs] class ApiService(Widget, QueryHandler): def __init__(self, host: str, port: int, max_workers=100): Widget.__init__(self) QueryHandler.__init__(self) self.max_workers = max_workers self.address = f"{host}:{port}" self.engine = None self.server = None self.register_event_handler(EventType.SYSTEM_START, self._startup) self.register_event_handler(EventType.SYSTEM_END, self._shutdown) def _handle_all_request(self, request, event_type: str): with self.engine.new_context() as fl_ctx: assert isinstance(fl_ctx, FLContext) fl_ctx.set_prop(EdgeContextKey.REQUEST_FROM_EDGE, request, sticky=False, private=True) self.fire_event(event_type, fl_ctx) result = fl_ctx.get_prop(EdgeContextKey.REPLY_TO_EDGE) if not result: self.logger.warning(f"no result from ETD for event {event_type}") return result
[docs] def handle_job_request(self, request: JobRequest) -> JobResponse: return self._handle_all_request(request, EdgeEventType.EDGE_JOB_REQUEST_RECEIVED)
[docs] def handle_task_request(self, request: TaskRequest) -> TaskResponse: return self._handle_all_request(request, EdgeEventType.EDGE_TASK_REQUEST_RECEIVED)
[docs] def handle_selection_request(self, request: SelectionRequest) -> SelectionResponse: return self._handle_all_request(request, EdgeEventType.EDGE_SELECTION_REQUEST_RECEIVED)
[docs] def handle_result_report(self, request: ResultReport) -> ResultResponse: return self._handle_all_request(request, EdgeEventType.EDGE_RESULT_REPORT_RECEIVED)
def _startup(self, _event_type: str, fl_ctx: FLContext): client_config = fl_ctx.get_prop(FLContextKey.CLIENT_CONFIG) root_cert_path = client_config.get(SecureTrainConst.SSL_ROOT_CERT) params = { DriverParams.CA_CERT.value: root_cert_path, DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.TLS, } enhance_credential_info(params) ssl_credentials = None ca_cert_file = root_cert_path server_cert_file = params.get(DriverParams.SERVER_CERT.value) server_key_file = params.get(DriverParams.SERVER_KEY.value) if ca_cert_file and server_cert_file and server_key_file: ssl_credentials = get_grpc_server_credentials(params) self.engine = fl_ctx.get_engine() # TODO: add ssl support self.server = EdgeApiServer( handler=self, address=self.address, max_workers=self.max_workers, ssl_credentials=ssl_credentials, ) t = threading.Thread(target=self.server.start, daemon=True) t.start() self.log_info(fl_ctx, f"Edge API GRPC Service is started on address {self.address}") def _shutdown(self, _event_type: str, fl_ctx: FLContext): self.log_info(fl_ctx, f"Edge API GRPC Service on address {self.address} is shutting down") if self.server: self.server.shutdown()