# Source code for nvflare.edge.web.service.utils

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Optional

from nvflare.edge.constants import EdgeApiStatus
from nvflare.edge.web.models.job_request import JobRequest
from nvflare.edge.web.models.job_response import JobResponse
from nvflare.edge.web.models.result_report import ResultReport
from nvflare.edge.web.models.result_response import ResultResponse
from nvflare.edge.web.models.selection_request import SelectionRequest
from nvflare.edge.web.models.selection_response import SelectionResponse
from nvflare.edge.web.models.task_request import TaskRequest
from nvflare.edge.web.models.task_response import TaskResponse

from .constants import NONE_DATA, QueryType
from .edge_api_pb2 import Reply, Request


def to_bytes(data: Optional[dict]) -> bytes:
    """Serialize an optional dict into bytes suitable for a gRPC payload field.

    Falsy input (``None`` or an empty dict) maps to the ``NONE_DATA`` sentinel.
    Any other value is JSON-encoded and then UTF-8 encoded.
    """
    if data:
        return json.dumps(data).encode("utf-8")
    return NONE_DATA
def make_reply(status: str, payload: Optional[dict] = None):
    """Build a gRPC ``Reply`` carrying *status* and the serialized *payload*.

    The payload is encoded with ``to_bytes``, so a missing or empty dict
    becomes ``NONE_DATA``.
    """
    encoded_payload = to_bytes(payload)
    return Reply(status=status, payload=encoded_payload)
def _request_to_grpc(query_type: str, method: str, req) -> Request:
    """Wrap an edge request object in a gRPC ``Request``.

    Args:
        query_type: value stored in ``Request.type``.
        method: verb string stored in ``Request.method`` (callers pass "GET"/"POST").
        req: a mapping, or an iterable of key/value pairs, whose items are copied
            into the JSON payload.

    Returns:
        A ``Request`` whose header is ``NONE_DATA`` and whose payload is the
        ``to_bytes`` encoding of the copied items (``NONE_DATA`` if there are none).
    """
    # Shallow-copy the request's items into a plain dict before serializing.
    payload = dict(req)
    return Request(type=query_type, method=method, header=NONE_DATA, payload=to_bytes(payload))


def _grpc_reply_to_response(reply: Reply, clazz):
    """Convert a gRPC ``Reply`` into an instance of the response class *clazz*.

    Args:
        reply: the gRPC reply to convert.
        clazz: response class, constructed with a status and then ``update``-d
            with the decoded payload.

    Returns:
        - ``clazz(status=reply.status)`` when the status is not ``EdgeApiStatus.OK``;
        - ``clazz(EdgeApiStatus.OK)`` updated with the JSON-decoded payload when
          the status is OK and a payload is present;
        - ``None`` when the status is OK but the payload is ``NONE_DATA`` or empty.

    Raises:
        json.JSONDecodeError: if a non-empty payload is not valid JSON.
    """
    if reply.status != EdgeApiStatus.OK:
        return clazz(status=reply.status)

    # An unset protobuf ``bytes`` field reads as b"". Treat it like NONE_DATA
    # instead of letting json.loads(b"") raise JSONDecodeError.
    if not reply.payload or reply.payload == NONE_DATA:
        return None

    data = json.loads(reply.payload)
    resp = clazz(EdgeApiStatus.OK)
    resp.update(data)
    return resp
def job_request_to_grpc_request(request: JobRequest) -> Request:
    """Package a ``JobRequest`` as a gRPC ``Request`` (``JOB_REQUEST`` query, ``POST``)."""
    return _request_to_grpc(query_type=QueryType.JOB_REQUEST, method="POST", req=request)
def grpc_reply_to_job_response(reply: Reply) -> Optional[JobResponse]:
    """Convert a gRPC ``Reply`` into a ``JobResponse``.

    Returns ``None`` when the reply status is OK but it carries no payload
    (``NONE_DATA``), because ``_grpc_reply_to_response`` returns ``None`` in that case.
    """
    return _grpc_reply_to_response(reply, JobResponse)
def task_request_to_grpc_request(request: TaskRequest) -> Request:
    """Package a ``TaskRequest`` as a gRPC ``Request`` (``TASK_REQUEST`` query, ``GET``)."""
    return _request_to_grpc(query_type=QueryType.TASK_REQUEST, method="GET", req=request)
def grpc_reply_to_task_response(reply: Reply) -> Optional[TaskResponse]:
    """Convert a gRPC ``Reply`` into a ``TaskResponse``.

    Returns ``None`` when the reply status is OK but it carries no payload
    (``NONE_DATA``), because ``_grpc_reply_to_response`` returns ``None`` in that case.
    """
    return _grpc_reply_to_response(reply, TaskResponse)
def selection_request_to_grpc_request(request: SelectionRequest) -> Request:
    """Package a ``SelectionRequest`` as a gRPC ``Request`` (``SELECTION_REQUEST`` query, ``GET``)."""
    return _request_to_grpc(query_type=QueryType.SELECTION_REQUEST, method="GET", req=request)
def grpc_reply_to_selection_response(reply: Reply) -> Optional[SelectionResponse]:
    """Convert a gRPC ``Reply`` into a ``SelectionResponse``.

    Returns ``None`` when the reply status is OK but it carries no payload
    (``NONE_DATA``), because ``_grpc_reply_to_response`` returns ``None`` in that case.
    """
    return _grpc_reply_to_response(reply, SelectionResponse)
def result_report_to_grpc_request(request: ResultReport) -> Request:
    """Package a ``ResultReport`` as a gRPC ``Request`` (``RESULT_REPORT`` query, ``POST``)."""
    return _request_to_grpc(query_type=QueryType.RESULT_REPORT, method="POST", req=request)
def grpc_reply_to_result_response(reply: Reply) -> Optional[ResultResponse]:
    """Convert a gRPC ``Reply`` into a ``ResultResponse``.

    Returns ``None`` when the reply status is OK but it carries no payload
    (``NONE_DATA``), because ``_grpc_reply_to_response`` returns ``None`` in that case.
    """
    return _grpc_reply_to_response(reply, ResultResponse)