# Source code for nvflare.fuel.hci.server.hci

# Copyright (c) 2021-2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import socketserver
import ssl
import threading

from nvflare.fuel.hci.conn import Connection, receive_til_end
from nvflare.fuel.hci.proto import validate_proto
from nvflare.fuel.hci.security import get_certificate_common_name

from .reg import ServerCommandRegister

MAX_ADMIN_CONNECTIONS = 16


class _MsgHandler(socketserver.BaseRequestHandler):
    """Message handler.

    Used by the AdminServer to receive admin commands, validate, then process and do command through the
    ServerCommandRegister.

    One handler instance is created per accepted connection. A class-level counter, guarded by
    ``lock``, tracks how many handlers are currently inside ``handle()``; a connection that pushes
    the count above ``MAX_ADMIN_CONNECTIONS`` is refused.
    """

    # number of handlers currently inside handle(); shared by all instances
    connections = 0
    # guards every update of ``connections``
    lock = threading.Lock()

    def __init__(self, request, client_address, server):
        # handle() is called in the constructor so logger must be initialized first
        self.logger = logging.getLogger(self.__class__.__name__)
        super().__init__(request, client_address, server)

    @staticmethod
    def _find_command(req_json):
        """Return the ``data`` of the first item of type ``"command"`` in ``req_json["data"]``, or None."""
        for item in req_json["data"]:
            if item["type"] == "command":
                return item["data"]
        return None

    def handle(self):
        """Serve one admin connection: enforce the connection limit, authenticate, run one command.

        Any exception is caught and logged (with a traceback only when the logger is at DEBUG
        level) instead of propagating. The active-connection counter is always decremented on exit.
        """
        try:
            with _MsgHandler.lock:
                _MsgHandler.connections += 1
                # Snapshot under the lock: re-reading the shared counter after releasing it could
                # include other handlers' increments and wrongly refuse this connection too.
                active = _MsgHandler.connections

            self.logger.debug(f"Concurrent admin connections: {active}")
            if active > MAX_ADMIN_CONNECTIONS:
                raise ConnectionRefusedError(f"Admin connection limit ({MAX_ADMIN_CONNECTIONS}) reached")

            conn = Connection(self.request, self.server)

            if self.server.use_ssl:
                # identify the client by the Common Name of its peer certificate
                cn = get_certificate_common_name(self.request.getpeercert())
                conn.set_prop("_client_cn", cn)
                valid = self.server.validate_client_cn(cn)
            else:
                valid = True

            if not valid:
                conn.append_error("authentication error")
            else:
                req = receive_til_end(self.request).strip()
                req_json = validate_proto(req)
                conn.request = req_json
                if req_json is None:
                    # not json encoded
                    conn.append_error("protocol violation")
                else:
                    command = self._find_command(req_json)
                    if command is None:
                        conn.append_error("protocol violation")
                    else:
                        self.server.cmd_reg.process_command(conn, command)

            if not conn.ended:
                conn.close()
        except BaseException as exc:
            self.logger.error(f"Admin connection terminated due to exception: {str(exc)}")
            if self.logger.getEffectiveLevel() <= logging.DEBUG:
                self.logger.exception("Admin connection error")
        finally:
            with _MsgHandler.lock:
                _MsgHandler.connections -= 1


def initialize_hci():
    """Enable SO_REUSEADDR for every ``socketserver.TCPServer`` (and subclasses) created afterwards.

    This sets the class attribute process-wide, so a restarted server can re-bind its port
    without waiting for lingering sockets from the previous run.
    """
    socketserver.TCPServer.allow_reuse_address = True
class AdminServer(socketserver.ThreadingTCPServer):
    """Threaded TCP server that accepts admin connections and dispatches their commands.

    Each accepted connection is handled by ``_MsgHandler`` in its own thread, and commands are
    executed through the registry held in ``cmd_reg``.

    When both ``ca_cert`` and ``server_cert`` are given, the listening socket is wrapped with TLS
    and clients must present a certificate that verifies against ``ca_cert``.
    """

    # faster re-binding
    allow_reuse_address = True

    # make this bigger than five
    request_queue_size = 10

    # kick connections when we exit
    daemon_threads = True

    def __init__(
        self,
        cmd_reg: ServerCommandRegister,
        host,
        port,
        ca_cert=None,
        server_cert=None,
        server_key=None,
        accepted_client_cns=None,
    ):
        """Base class of FedAdminServer to create a server that can receive commands.

        Args:
            cmd_reg: CommandRegister
            host: the IP address of the admin server
            port: port number of admin server
            ca_cert: the root CA's cert file name
            server_cert: server's cert, signed by the CA
            server_key: server's private key file
            accepted_client_cns: list of accepted Common Names from client, if specified

        Raises:
            AssertionError: if SSL is enabled and accepted_client_cns is given but is not a list.

        Any exception raised while setting up SSL or binding/activating the socket is re-raised
        after the listening socket has been closed.
        """
        # create the socket without binding it, so it can be wrapped with SSL first
        socketserver.TCPServer.__init__(self, (host, port), _MsgHandler, False)
        try:
            self.use_ssl = False
            if ca_cert and server_cert:
                if accepted_client_cns:
                    assert isinstance(accepted_client_cns, list), "accepted_client_cns must be list but got {}.".format(
                        accepted_client_cns
                    )
                ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
                # clients must present a certificate that verifies against ca_cert
                ctx.verify_mode = ssl.CERT_REQUIRED
                ctx.load_verify_locations(ca_cert)
                ctx.load_cert_chain(certfile=server_cert, keyfile=server_key)
                # replace the socket with an SSL version of itself
                self.socket = ctx.wrap_socket(self.socket, server_side=True)
                self.use_ssl = True

            # bind the socket and start the server
            self.server_bind()
            self.server_activate()
        except BaseException:
            # Same cleanup TCPServer.__init__ performs when it binds itself: don't leak the socket.
            self.server_close()
            raise

        self._thread = None
        self.host = host
        self.port = port
        self.accepted_client_cns = accepted_client_cns
        self.cmd_reg = cmd_reg
        cmd_reg.finalize()
        self.logger = logging.getLogger(self.__class__.__name__)

    def validate_client_cn(self, cn):
        """Check whether a client's certificate Common Name is accepted.

        Args:
            cn: Common Name taken from the client certificate.

        Returns:
            True when no accepted-CN list was configured; otherwise whether cn is in that list.
        """
        if self.accepted_client_cns:
            return cn in self.accepted_client_cns
        return True

    def stop(self):
        """Stop serving, close the command registry and release the listening socket.

        Safe to call even if start() was never called. The server cannot be started again afterwards.
        """
        thread = self._thread
        if thread is not None and thread.is_alive():
            # BaseServer.shutdown() blocks until serve_forever() exits, so it must only be
            # called while the serving thread is running; otherwise it would wait forever.
            self.shutdown()
            thread.join()
        if self.cmd_reg:
            self.cmd_reg.close()
        # release the listening socket (previously leaked)
        self.server_close()
        self.logger.info(f"Admin Server {self.host} on Port {self.port} shutdown!")

    def set_command_registry(self, cmd_reg: ServerCommandRegister):
        """Replace the command registry used to process commands.

        The new registry (if any) is finalized. The previous registry is closed unless it is the
        same object being installed.
        """
        if cmd_reg:
            cmd_reg.finalize()
        # don't close the registry we are about to install
        if self.cmd_reg and self.cmd_reg is not cmd_reg:
            self.cmd_reg.close()
        self.cmd_reg = cmd_reg

    def start(self):
        """Start serving in a background thread; does nothing if the server thread is already running."""
        if self._thread is None:
            self._thread = threading.Thread(target=self._run, args=())
        if not self._thread.is_alive():
            self._thread.start()

    def _run(self):
        # body of the server thread: serve requests until shutdown() is called
        self.logger.info(f"Starting Admin Server {self.host} on Port {self.port}")
        self.serve_forever()