# Source code for nvflare.edge.web.views.feg_views

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union

from flask import Blueprint, request

from nvflare.edge.constants import EdgeApiStatus, HttpHeaderKey
from nvflare.edge.web.models.api_error import ApiError
from nvflare.edge.web.models.base_model import EdgeProtoKey
from nvflare.edge.web.models.device_info import DeviceInfo
from nvflare.edge.web.models.job_request import JobRequest
from nvflare.edge.web.models.result_report import ResultReport
from nvflare.edge.web.models.selection_request import SelectionRequest
from nvflare.edge.web.models.task_request import TaskRequest
from nvflare.edge.web.models.user_info import UserInfo
from nvflare.edge.web.service.query import Query


class APIQuery:
    """Callable holder for a ``Query`` instance and the file names used to build it.

    The LCP mapping file and CA cert file are recorded through the setters;
    the underlying ``Query`` is only constructed when ``start`` is called.
    """

    def __init__(self):
        """Start with no files configured and no ``Query`` created."""
        self.lcp_mapping_file = None
        self.ca_cert_file = None
        self.query = None

    def set_lcp_mapping(self, file_name: str):
        """Record the LCP mapping file name passed to ``Query`` by ``start``."""
        self.lcp_mapping_file = file_name

    def set_ca_cert(self, file_name: str):
        """Record the CA cert file name passed to ``Query`` by ``start``."""
        self.ca_cert_file = file_name

    def start(self):
        """Create the ``Query`` object from the currently recorded file names."""
        query_kwargs = {
            "lcp_mapping_file": self.lcp_mapping_file,
            "ca_cert_file": self.ca_cert_file,
        }
        self.query = Query(**query_kwargs)

    def __call__(self, req: Union[TaskRequest, JobRequest, SelectionRequest, ResultReport]):
        """Forward ``req`` to the ``Query`` created by ``start`` and return its result."""
        handler = self.query
        return handler(req)
feg_bp = Blueprint("feg", __name__)
api_query = APIQuery()


def _process_headers() -> dict:
    """Build the initial request dict from the headers of the current Flask request.

    Returns:
        A dict containing a ``DeviceInfo`` under ``EdgeProtoKey.DEVICE_INFO`` and,
        when the user-info header is present and the parsed ``UserInfo`` is truthy,
        that object under ``EdgeProtoKey.USER_INFO``.

    Raises:
        ApiError: (400, ``EdgeApiStatus.INVALID_REQUEST``) when the device ID
            header is missing or empty.
    """
    headers = request.headers
    device_id = headers.get(HttpHeaderKey.DEVICE_ID, None)
    if not device_id:
        raise ApiError(400, EdgeApiStatus.INVALID_REQUEST, "Device ID missing")

    d = {}

    # The device info always exists (built from the device ID); the optional
    # header only adds to it.
    device_info = DeviceInfo(device_id)
    device_info_header = headers.get(HttpHeaderKey.DEVICE_INFO, None)
    if device_info_header:
        device_info.from_query_string(device_info_header)
    d[EdgeProtoKey.DEVICE_INFO] = device_info

    user_info_header = headers.get(HttpHeaderKey.USER_INFO, None)
    if user_info_header:
        user_info = UserInfo()
        user_info.from_query_string(user_info_header)
    else:
        user_info = None

    if user_info:
        d[EdgeProtoKey.USER_INFO] = user_info

    return d


def _update_body(d: dict):
    """Merge the JSON body of the current request into ``d`` (body keys win).

    Args:
        d: dict to update in place.

    Raises:
        ApiError: (400, ``EdgeApiStatus.INVALID_REQUEST``) when the body parses
            to something other than a JSON object.
    """
    data = request.get_json()
    # get_json() returns whatever JSON value was sent (null, number, string,
    # list, ...). Only an object can be merged; anything else is a client
    # error and must not surface as an unhandled TypeError/ValueError.
    if not isinstance(data, dict):
        raise ApiError(400, EdgeApiStatus.INVALID_REQUEST, "Request body must be a JSON object")
    d.update(data)


def _update_args(d: dict, keys: dict):
    """Copy selected query-string arguments of the current request into ``d``.

    Args:
        d: dict to update in place.
        keys: maps each query argument name to the default used when the
            argument is absent. A key whose argument is absent and whose
            default is None is not added to ``d``.
    """
    for k, default_value in keys.items():
        arg_value = request.args.get(k)
        if arg_value is None:
            arg_value = default_value
        if arg_value is not None:
            d[k] = arg_value


def _do_query(req):
    """Run ``req`` through the module-level ``api_query`` and return the response.

    Raises:
        ApiError: (400, ``EdgeApiStatus.INVALID_REQUEST``) when the query
            returns a falsy response.
    """
    resp = api_query(req)
    if not resp:
        raise ApiError(400, EdgeApiStatus.INVALID_REQUEST, "unknown request type")
    return resp
@feg_bp.route("/job", methods=["POST"])
def job_view():
    """Handle ``POST /job``.

    Combines the header-derived fields with the JSON body, converts the result
    to a ``JobRequest`` and returns the query response. A conversion error is
    raised as ``ApiError`` (400, ``EdgeApiStatus.INVALID_REQUEST``).
    """
    payload = _process_headers()
    _update_body(payload)

    err, job_request = JobRequest.from_dict(payload)
    if not err:
        return _do_query(job_request)
    raise ApiError(400, EdgeApiStatus.INVALID_REQUEST, err)
@feg_bp.route("/task", methods=["POST"])
def task_view():
    """Handle ``POST /task``.

    Combines the header-derived fields, the optional job ID query argument and
    the JSON body (applied last), converts the result to a ``TaskRequest`` and
    returns the query response. A conversion error is raised as ``ApiError``
    (400, ``EdgeApiStatus.INVALID_REQUEST``).
    """
    payload = _process_headers()
    query_defaults = {EdgeProtoKey.JOB_ID: None}
    _update_args(payload, query_defaults)
    _update_body(payload)

    err, task_request = TaskRequest.from_dict(payload)
    if not err:
        return _do_query(task_request)
    raise ApiError(400, EdgeApiStatus.INVALID_REQUEST, err)
@feg_bp.route("/result", methods=["POST"])
def result_view():
    """Handle ``POST /result``.

    Combines the header-derived fields, the job ID / task ID / task name /
    status query arguments (status defaults to ``EdgeApiStatus.OK``) and the
    JSON body (applied last), converts the result to a ``ResultReport`` and
    returns the query response. A conversion error is raised as ``ApiError``
    (400, ``EdgeApiStatus.INVALID_REQUEST``).
    """
    payload = _process_headers()
    query_defaults = {
        EdgeProtoKey.JOB_ID: None,
        EdgeProtoKey.TASK_ID: None,
        EdgeProtoKey.TASK_NAME: None,
        EdgeProtoKey.STATUS: EdgeApiStatus.OK,
    }
    _update_args(payload, query_defaults)
    _update_body(payload)

    err, report = ResultReport.from_dict(payload)
    if not err:
        return _do_query(report)
    raise ApiError(400, EdgeApiStatus.INVALID_REQUEST, err)
@feg_bp.route("/selection", methods=["POST"])
def selection_view():
    """Handle ``POST /selection``.

    Combines the header-derived fields, the optional job ID query argument and
    the JSON body (applied last), converts the result to a ``SelectionRequest``
    and returns the query response. A conversion error is raised as
    ``ApiError`` (400, ``EdgeApiStatus.INVALID_REQUEST``).
    """
    payload = _process_headers()
    query_defaults = {EdgeProtoKey.JOB_ID: None}
    _update_args(payload, query_defaults)
    _update_body(payload)

    err, selection_request = SelectionRequest.from_dict(payload)
    if not err:
        return _do_query(selection_request)
    raise ApiError(400, EdgeApiStatus.INVALID_REQUEST, err)