Source code for nvflare.edge.web.service.query

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os.path
from typing import Tuple, Union

from nvflare.apis.fl_constant import ConnectionSecurity
from nvflare.edge.constants import EdgeApiStatus
from nvflare.edge.web.models.job_request import JobRequest
from nvflare.edge.web.models.job_response import JobResponse
from nvflare.edge.web.models.result_report import ResultReport
from nvflare.edge.web.models.result_response import ResultResponse
from nvflare.edge.web.models.selection_request import SelectionRequest
from nvflare.edge.web.models.selection_response import SelectionResponse
from nvflare.edge.web.models.task_request import TaskRequest
from nvflare.edge.web.models.task_response import TaskResponse
from nvflare.fuel.f3.drivers.driver_params import DriverParams
from nvflare.fuel.f3.drivers.grpc.utils import get_grpc_client_credentials
from nvflare.fuel.utils.hash_utils import UniformHash
from nvflare.fuel.utils.log_utils import get_obj_logger
from nvflare.security.logging import secure_format_exception

from .client import EdgeApiClient
from .utils import (
    grpc_reply_to_job_response,
    grpc_reply_to_result_response,
    grpc_reply_to_selection_response,
    grpc_reply_to_task_response,
    job_request_to_grpc_request,
    result_report_to_grpc_request,
    selection_request_to_grpc_request,
    task_request_to_grpc_request,
)


class Query:
    """Callable that forwards edge API requests to an LCP over gRPC.

    The instance keeps a list of ``(name, address)`` pairs describing the
    configured LCPs. Each incoming request is mapped to one of them by hashing
    the request's device id, converted to a gRPC request, sent through an
    ``EdgeApiClient``, and the reply is converted back to the matching
    response model.
    """

    def __init__(self, lcp_mapping_file: Union[str, None] = None, ca_cert_file: Union[str, None] = None):
        """Create the gRPC client and optionally load the LCP mapping.

        Args:
            lcp_mapping_file: optional path to a JSON file describing the LCPs;
                when given it is loaded via ``load_lcp_map``.
            ca_cert_file: optional path to a CA certificate file. When given,
                client credentials are built with TLS connection security and
                passed to the ``EdgeApiClient``.

        Raises:
            ValueError: if ``ca_cert_file`` is specified but is not an existing file.
        """
        ssl_credentials = None
        if ca_cert_file:
            if not os.path.isfile(ca_cert_file):
                raise ValueError(f"specified ca_cert_file {ca_cert_file} does not exist or is not a file")

            params = {
                DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.TLS,
                DriverParams.CA_CERT.value: ca_cert_file,
            }
            ssl_credentials = get_grpc_client_credentials(params)

        self.lcp_list = []
        self.client = EdgeApiClient(ssl_credentials=ssl_credentials)
        self.logger = get_obj_logger(self)

        if lcp_mapping_file:
            self.load_lcp_map(lcp_mapping_file)

    def _add_lcp(self, name: str, addr: str):
        """Register one LCP as a ``(name, addr)`` pair."""
        self.lcp_list.append((name, addr))

    def _map(self, device_id: str) -> Tuple[str, str]:
        """Return the ``(name, addr)`` pair of the LCP selected for ``device_id``.

        The index is obtained from ``UniformHash`` sized to the current number
        of configured LCPs. Callers must make sure ``lcp_list`` is not empty.
        """
        uniform_hash = UniformHash(len(self.lcp_list))
        index = uniform_hash.hash(device_id)
        return self.lcp_list[index]

    def load_lcp_map(self, mapping_file: str):
        """Load LCP definitions from a JSON file and append them to ``lcp_list``.

        The file must contain a JSON object mapping each LCP name to an object
        with ``host`` and ``port`` entries; the address is ``"<host>:<port>"``.

        Args:
            mapping_file: path to the JSON mapping file.

        Raises:
            OSError: if the file cannot be opened.
            json.JSONDecodeError: if the file is not valid JSON.
            KeyError: if an entry lacks ``host`` or ``port``.
        """
        # JSON text is UTF-8; do not depend on the platform's default encoding.
        with open(mapping_file, "r", encoding="utf-8") as f:
            mapping = json.load(f)

        # Resolve every entry before registering any of them, so a malformed
        # entry cannot leave the LCP list partially populated.
        entries = [(name, f"{config['host']}:{config['port']}") for name, config in mapping.items()]
        for name, addr in entries:
            self._add_lcp(name, addr)

    def _query(
        self,
        request: Union[TaskRequest, JobRequest, SelectionRequest, ResultReport],
        to_grpc_f,
        from_grpc_f,
        default_response,
    ):
        """Send ``request`` to the LCP mapped to its device and return the response.

        Args:
            request: the edge API request to forward.
            to_grpc_f: callable converting ``request`` to a gRPC request.
            from_grpc_f: callable converting the gRPC reply to a response model.
            default_response: value returned when no LCP is configured, when
                the converted reply is falsy, or when the gRPC call or reply
                conversion raises.

        Returns:
            The converted response, or ``default_response`` as described above.
            Errors raised while converting the request or resolving its device
            id happen before the guarded section and are not caught here.
        """
        if not self.lcp_list:
            self.logger.error("No LCP configured")
            return default_response

        grpc_req = to_grpc_f(request)
        device_id = request.get_device_id()
        name, addr = self._map(device_id)
        self.logger.debug(f"sending request {type(request)} to {name} at {addr}")

        try:
            grpc_reply = self.client.query(addr, grpc_req)
            resp = from_grpc_f(grpc_reply)
            return resp or default_response
        except Exception as ex:
            self.logger.error(f"exception querying grpc service: {secure_format_exception(ex)}")
            return default_response

    def __call__(self, request: Union[TaskRequest, JobRequest, SelectionRequest, ResultReport]):
        """Dispatch ``request`` by type and return the corresponding response.

        Returns:
            The response produced by ``_query`` (whose fallback is a response of
            the matching type built with ``EdgeApiStatus.RETRY``), or ``None``
            when ``request`` is not one of the supported request types.
        """
        # Checked in order; each row is
        # (request type, request -> gRPC converter, gRPC -> response converter, response type).
        dispatch = (
            (JobRequest, job_request_to_grpc_request, grpc_reply_to_job_response, JobResponse),
            (TaskRequest, task_request_to_grpc_request, grpc_reply_to_task_response, TaskResponse),
            (
                SelectionRequest,
                selection_request_to_grpc_request,
                grpc_reply_to_selection_response,
                SelectionResponse,
            ),
            (ResultReport, result_report_to_grpc_request, grpc_reply_to_result_response, ResultResponse),
        )

        for request_type, to_grpc_f, from_grpc_f, response_type in dispatch:
            if isinstance(request, request_type):
                # The fallback response is only built for the matching type.
                return self._query(request, to_grpc_f, from_grpc_f, response_type(EdgeApiStatus.RETRY))

        self.logger.error(f"received invalid request type: {type(request)}")
        return None