# Source code for nvflare.edge.web.routing_proxy

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
from typing import Any

from flask import Flask, jsonify
from flask.json.provider import DefaultJSONProvider

from nvflare.edge.web.models.api_error import ApiError
from nvflare.edge.web.views.feg_views import api_query, feg_bp

log = logging.getLogger(__name__)
app = Flask(__name__)


def clean_dict(value: Any):
    """Recursively remove ``None``-valued entries from dicts.

    Dicts are cleaned at every nesting depth, including dicts that appear
    inside lists or tuples. The containers keep their type (list stays list,
    tuple stays tuple) and their elements are never dropped; only dict
    entries whose value is ``None`` are removed. Any other value is returned
    unchanged.

    Args:
        value: The value to clean, typically a JSON-serializable structure.

    Returns:
        A cleaned copy for dict/list/tuple inputs, otherwise ``value`` itself.
    """
    if isinstance(value, dict):
        return {k: clean_dict(v) for k, v in value.items() if v is not None}
    # Recurse into sequences so dicts nested in lists/tuples are cleaned too;
    # previously e.g. a list of records kept their None fields.
    if isinstance(value, list):
        return [clean_dict(item) for item in value]
    if isinstance(value, tuple):
        return tuple(clean_dict(item) for item in value)
    return value
class FilteredJSONProvider(DefaultJSONProvider):
    """Flask JSON provider that omits ``None``-valued dict entries from output.

    Objects are passed through ``clean_dict`` before serialization, and key
    order is preserved (``sort_keys = False``) rather than sorted.
    """

    # Emit keys in insertion order instead of sorting them.
    sort_keys = False

    def dumps(self, obj: Any, **kwargs: Any) -> str:
        """Serialize ``obj`` to a JSON string after stripping ``None`` entries.

        Args:
            obj: The value to serialize; cleaned with ``clean_dict`` first.
            **kwargs: Forwarded unchanged to ``DefaultJSONProvider.dumps``
                (e.g. ``indent``).

        Returns:
            The JSON-encoded string.
        """
        # Forward kwargs: previously they were accepted but dropped, silently
        # discarding caller-supplied serialization options.
        return super().dumps(clean_dict(obj), **kwargs)
@app.errorhandler(ApiError)
def handle_api_error(error: ApiError):
    """Turn an ``ApiError`` raised by a view into a JSON HTTP response.

    The response body is ``error.to_dict()`` serialized with ``jsonify``, and
    the HTTP status is taken from ``error.status_code``.
    """
    body = error.to_dict()
    http_response = jsonify(body)
    http_response.status_code = error.status_code
    return http_response
def parse_args():
    """Parse and validate the proxy server's command-line arguments.

    Positional arguments: ``port`` (int), ``lcp_mapping_file`` and
    ``ca_cert_file``. Optional ``--ssl-cert`` / ``--ssl-key`` must be supplied
    together; supplying only one exits via ``ArgumentParser.error``.

    Returns:
        An ``argparse.Namespace`` with ``port``, ``lcp_mapping_file``,
        ``ca_cert_file``, ``ssl_cert`` and ``ssl_key`` attributes.
    """
    cli = argparse.ArgumentParser(
        description="Run proxy server with specified port, mapping file, and CA cert file."
    )

    # Required positional arguments: (name, type, help).
    positional_specs = (
        ("port", int, "Port number to run the proxy server on."),
        ("lcp_mapping_file", str, "Path to the mapping file."),
        ("ca_cert_file", str, "Path to the CA certificate file."),
    )
    for dest, arg_type, help_text in positional_specs:
        cli.add_argument(dest, type=arg_type, help=help_text)

    # Optional SSL certificate / private key pair.
    cli.add_argument(
        "--ssl-cert",
        type=str,
        default=None,
        help="Path to SSL certificate file (optional, self-signed or CA-signed).",
    )
    cli.add_argument(
        "--ssl-key",
        type=str,
        default=None,
        help="Path to SSL private key file (optional).",
    )

    parsed = cli.parse_args()

    # Exactly one of the pair being set is an error; both or neither is fine.
    if bool(parsed.ssl_cert) != bool(parsed.ssl_key):
        cli.error("Both --ssl-cert and --ssl-key must be provided together")
    return parsed
if __name__ == "__main__":
    # Script entry point: configure logging, parse CLI args, wire up the
    # query backend, then start the Flask development server.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        handlers=[logging.StreamHandler()],
    )

    args = parse_args()
    proxy_port = args.port
    lcp_mapping_file = args.lcp_mapping_file
    ca_cert_file = args.ca_cert_file

    # parse_args guarantees --ssl-cert and --ssl-key arrive together; when
    # present they are handed to app.run as a (cert, key) tuple.
    ssl_context = None
    if args.ssl_cert and args.ssl_key:
        # Use the module logger (configured above) instead of print so these
        # messages share the configured format and destination.
        log.info("Using SSL cert: %s", args.ssl_cert)
        log.info("Using SSL key: %s", args.ssl_key)
        ssl_context = (args.ssl_cert, args.ssl_key)
    else:
        log.info("No SSL cert/key provided, running without SSL")

    api_query.set_lcp_mapping(lcp_mapping_file)
    api_query.set_ca_cert(ca_cert_file)
    api_query.start()

    # Install the None-stripping JSON provider, then register the routes.
    app.json = FilteredJSONProvider(app)
    app.register_blueprint(feg_bp)

    # Listen on all interfaces.
    app.run(host="0.0.0.0", port=proxy_port, debug=False, ssl_context=ssl_context)