Source code for nvflare.fuel.hci.server.hci

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import socketserver
import ssl
import threading

from nvflare.fuel.hci.binary_proto import CT_BINARY, receive_all
from nvflare.fuel.hci.conn import Connection
from nvflare.fuel.hci.proto import MetaKey, MetaStatusValue, ProtoKey, make_meta, validate_proto
from nvflare.fuel.hci.security import IdentityKey, get_identity_info
from nvflare.security.logging import secure_log_traceback

from .constants import ConnProps
from .reg import ServerCommandRegister

logger = logging.getLogger(__name__)


class _MsgHandler(socketserver.BaseRequestHandler):
    """Message handler.

    Used by the AdminServer to receive admin commands, validate, then process and do command through the
    ServerCommandRegister.
    """

    def handle(self):
        conn = None
        try:
            conn = Connection(self.request, self.server)
            conn.set_prop(ConnProps.CA_CERT, self.server.ca_cert)

            if self.server.extra_conn_props:
                conn.set_props(self.server.extra_conn_props)

            if self.server.cmd_reg.conn_props:
                conn.set_props(self.server.cmd_reg.conn_props)

            if self.server.use_ssl:
                identity = get_identity_info(self.request.getpeercert())
                conn.set_prop(ConnProps.CLIENT_IDENTITY, identity)
                valid = self.server.validate_client_cn(identity[IdentityKey.NAME])
            else:
                valid = True

            if not valid:
                conn.append_error(
                    "authentication error", meta=make_meta(MetaStatusValue.NOT_AUTHENTICATED, info="invalid credential")
                )
            else:
                ct, req, extra = receive_all(self.request)
                if ct == CT_BINARY and not extra:
                    conn.append_error(
                        "no data received from client",
                        meta=make_meta(MetaStatusValue.INTERNAL_ERROR, info="no data received"),
                    )
                else:
                    req = req.strip()
                    command = None
                    req_json = validate_proto(req)
                    conn.request = req_json
                    conn.content_type = ct
                    conn.extra = extra

                    if req_json is not None:
                        meta = req_json.get(ProtoKey.META, None)
                        if meta and isinstance(meta, dict):
                            cmd_timeout = meta.get(MetaKey.CMD_TIMEOUT)
                            if cmd_timeout:
                                conn.set_prop(ConnProps.CMD_TIMEOUT, cmd_timeout)

                            custom_props = meta.get(MetaKey.CUSTOM_PROPS)
                            if custom_props:
                                conn.set_prop(ConnProps.CUSTOM_PROPS, custom_props)

                        data = req_json[ProtoKey.DATA]
                        for item in data:
                            it = item[ProtoKey.TYPE]
                            if it == ProtoKey.COMMAND:
                                command = item[ProtoKey.DATA]
                                break

                        if command is None:
                            conn.append_error(
                                "protocol violation",
                                meta=make_meta(MetaStatusValue.INTERNAL_ERROR, "protocol violation"),
                            )
                        else:
                            self.server.cmd_reg.process_command(conn, command)
                    else:
                        # not json encoded
                        conn.append_error(
                            "protocol violation", meta=make_meta(MetaStatusValue.INTERNAL_ERROR, "protocol violation")
                        )
        except:
            secure_log_traceback()
        if conn and not conn.ended:
            conn.close()


[docs]def initialize_hci(): socketserver.TCPServer.allow_reuse_address = True
[docs]class AdminServer(socketserver.ThreadingTCPServer): # faster re-binding allow_reuse_address = True # make this bigger than five request_queue_size = 10 # kick connections when we exit daemon_threads = True def __init__( self, cmd_reg: ServerCommandRegister, host, port, ca_cert=None, server_cert=None, server_key=None, accepted_client_cns=None, extra_conn_props=None, ): """Base class of FedAdminServer to create a server that can receive commands. Args: cmd_reg: CommandRegister host: the IP address of the admin server port: port number of admin server ca_cert: the root CA's cert file name server_cert: server's cert, signed by the CA server_key: server's private key file accepted_client_cns: list of accepted Common Names from client, if specified extra_conn_props: a dict of extra conn props, if specified """ if extra_conn_props is not None: assert isinstance(extra_conn_props, dict), "extra_conn_props must be dict but got {}".format( extra_conn_props ) socketserver.TCPServer.__init__(self, ("0.0.0.0", port), _MsgHandler, False) self.use_ssl = False if ca_cert and server_cert: if accepted_client_cns: assert isinstance(accepted_client_cns, list), "accepted_client_cns must be list but got {}.".format( accepted_client_cns ) ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) # This feature is only supported on 3.7+ ctx.minimum_version = ssl.TLSVersion.TLSv1_2 ctx.verify_mode = ssl.CERT_REQUIRED ctx.load_verify_locations(ca_cert) ctx.load_cert_chain(certfile=server_cert, keyfile=server_key) # replace the socket with an ssl version of itself self.socket = ctx.wrap_socket(self.socket, server_side=True) self.use_ssl = True # bind the socket and start the server self.server_bind() self.server_activate() self._thread = None self.ca_cert = ca_cert self.host = host self.port = port self.accepted_client_cns = accepted_client_cns self.extra_conn_props = extra_conn_props self.cmd_reg = cmd_reg cmd_reg.finalize()
[docs] def validate_client_cn(self, cn): if self.accepted_client_cns: return cn in self.accepted_client_cns else: return True
[docs] def stop(self): self.shutdown() self.cmd_reg.close() logger.info(f"Admin Server {self.host} on Port {self.port} shutdown!")
[docs] def set_command_registry(self, cmd_reg: ServerCommandRegister): if cmd_reg: cmd_reg.finalize() if self.cmd_reg: self.cmd_reg.close() self.cmd_reg = cmd_reg
[docs] def start(self): if self._thread is None: self._thread = threading.Thread(target=self._run, args=()) self._thread.daemon = True if not self._thread.is_alive(): self._thread.start()
def _run(self): logger.info(f"Starting Admin Server {self.host} on Port {self.port}") self.serve_forever()