Source code for nvflare.edge.web.service.server

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import concurrent.futures.thread
import threading

import grpc

from nvflare.edge.constants import EdgeApiStatus
from nvflare.fuel.f3.drivers.aio_context import AioContext
from nvflare.fuel.utils.log_utils import get_obj_logger
from nvflare.security.logging import secure_format_exception

from .edge_api_pb2 import Reply, Request
from .edge_api_pb2_grpc import EdgeApiServicer, add_EdgeApiServicer_to_server
from .query_handler import QueryHandler
from .utils import make_reply


class Servicer(EdgeApiServicer):
    """Edge API gRPC servicer that hands every query to a ``QueryHandler``.

    The handler's ``handle_query`` is invoked through a thread pool owned by this
    servicer, via the event loop obtained from the supplied ``AioContext``.
    """

    def __init__(self, handler: QueryHandler, aio_ctx: AioContext, max_workers=100):
        """Create the servicer and its worker thread pool.

        Args:
            handler: object whose ``handle_query`` method processes each request.
            aio_ctx: context that supplies the event loop used to dispatch work.
            max_workers: size of the thread pool that runs ``handle_query``.
        """
        self.logger = get_obj_logger(self)
        self.handler = handler
        self.aio_ctx = aio_ctx
        self.worker_pool = concurrent.futures.thread.ThreadPoolExecutor(max_workers=max_workers)

    async def Query(self, request: Request, context) -> Reply:
        """Process one Edge API query.

        Args:
            request: the incoming request message.
            context: the gRPC call context (not used here).

        Returns:
            The reply produced by the handler. If the handler raises or produces a
            falsy result, the error is logged and a reply built with
            ``EdgeApiStatus.ERROR`` is returned instead.
        """
        try:
            event_loop = self.aio_ctx.get_event_loop()
            # Run the handler in the worker pool rather than directly in this coroutine.
            result = await event_loop.run_in_executor(
                self.worker_pool,
                self.handler.handle_query,
                request,
            )
            if result:
                return result
            raise RuntimeError("no result from QueryHandler.")
        except Exception as ex:
            self.logger.error(f"error processing request: {secure_format_exception(ex)}")
            return make_reply(EdgeApiStatus.ERROR)
class EdgeApiServer:
    """Asynchronous gRPC server that exposes the Edge API.

    The server is created and run inside the event loop of the global
    ``AioContext``. ``start`` blocks the calling thread until ``shutdown`` is
    called.
    """

    def __init__(
        self,
        handler: QueryHandler,
        address: str,
        grpc_options=None,
        max_workers=100,
        ssl_credentials=None,
    ):
        """Store the server configuration; the gRPC server itself is created in ``_start``.

        Args:
            handler: query handler passed on to the ``Servicer``.
            address: address given to the gRPC server when adding its port.
            grpc_options: options passed to ``grpc.aio.server``.
            max_workers: worker count passed on to the ``Servicer``.
            ssl_credentials: server credentials; when truthy a secure port is added,
                otherwise an insecure port is added.
        """
        self.aio_ctx = AioContext.get_global_context()
        self.logger = get_obj_logger(self)
        self.handler = handler
        self.address = address
        self.grpc_options = grpc_options
        self.max_workers = max_workers
        self.grpc_server = None
        # Grace period (seconds) given to in-flight RPCs when stopping the server.
        self.grpc_server_stop_grace = 0.5
        # Set by shutdown() to release the thread blocked in start().
        self.waiter = threading.Event()
        # Initialized to None and not used elsewhere in this class.
        self.root_cert = None
        self.cert_chain = None
        self.private_key = None
        self.ssl_credentials = ssl_credentials

    async def _start(self):
        """Create the gRPC server, bind its port, start it and wait for termination."""
        # Note: the AIO grpc server must be created in this coro, because it has to be created in the thread
        # that runs the event loop!
        self.logger.info("starting Edge API Server ...")
        self.grpc_server = grpc.aio.server(options=self.grpc_options)
        servicer = Servicer(self.handler, self.aio_ctx, self.max_workers)
        add_EdgeApiServicer_to_server(servicer, self.grpc_server)

        if self.ssl_credentials:
            # one-way SSL
            self.logger.info(f"adding secure port at {self.address} for 1-way ssl")
            self.grpc_server.add_secure_port(self.address, server_credentials=self.ssl_credentials)
            self.logger.info(f"added secure port at {self.address}")
        else:
            self.grpc_server.add_insecure_port(self.address)
            self.logger.info(f"added insecure port at {self.address}")

        self.logger.info("starting server engine")
        await self.grpc_server.start()
        self.logger.info("started server and wait for termination")
        await self.grpc_server.wait_for_termination()

    async def _shutdown(self):
        """Stop the gRPC server if one exists; errors are logged, never raised."""
        server = self.grpc_server
        if server is None:
            # The server was never created or has already been stopped. Handle this
            # explicitly instead of relying on an AttributeError being swallowed below.
            self.logger.debug("Server is not running; nothing to stop")
            return

        try:
            await server.stop(grace=self.grpc_server_stop_grace)

            # Note that self.grpc_server.stop returns immediately. Since we gave 0.5 grace time for RPCs to end,
            # we wait here until RPCs are done or aborted.
            # Without this, we may run into "excepthook" error at the end of the program since the GRPC server isn't
            # properly shutdown.
            await asyncio.sleep(self.grpc_server_stop_grace)
            self.grpc_server = None
            self.logger.debug("Server is stopped!")
        except Exception as ex:
            self.logger.debug(f"exception shutdown server: {secure_format_exception(ex)}")

    def start(self):
        """Submit the server coroutine to the AIO context and block until ``shutdown`` is called."""
        self.aio_ctx.run_coro(self._start())
        self.logger.info("waiting for server to finish")
        self.waiter.wait()
        self.logger.info("server is done")

    def shutdown(self):
        """Submit the shutdown coroutine to the AIO context and release the thread blocked in ``start``."""
        self.aio_ctx.run_coro(self._shutdown())
        self.waiter.set()
        self.logger.info("Shutting Down Server")