Source code for nvflare.private.fed.server.admin

# Copyright (c) 2021-2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import threading
import time
from typing import List, Optional

from nvflare.fuel.hci.conn import Connection
from nvflare.fuel.hci.reg import CommandModule
from nvflare.fuel.hci.server.audit import CommandAudit
from nvflare.fuel.hci.server.authz import AuthorizationService, AuthzCommandModule, AuthzFilter
from nvflare.fuel.hci.server.builtin import new_command_register_with_builtin_module
from nvflare.fuel.hci.server.constants import ConnProps
from nvflare.fuel.hci.server.file_transfer import FileTransferModule
from nvflare.fuel.hci.server.hci import AdminServer
from nvflare.fuel.hci.server.login import LoginModule, SessionManager, SimpleAuthenticator
from nvflare.fuel.sec.audit import Auditor, AuditService
from nvflare.private.admin_defs import Message
from nvflare.private.defs import ERROR_MSG_PREFIX

from .app_authz import AppAuthzService


def new_message(conn: Connection, topic, body) -> Message:
    """Build a Message and stamp it with identifying properties of the connection.

    Args:
        conn: the admin connection the message originates from
        topic: topic of the new message
        body: body of the new message

    Returns:
        A Message whose headers carry the connection's event id and user name,
        each one only when the connection has a truthy value for it.
    """
    msg = Message(topic=topic, body=body)

    # Each property is copied into a header of the same name, event id first.
    for prop_name in (ConnProps.EVENT_ID, ConnProps.USER_NAME):
        prop_value = conn.get_prop(prop_name, default=None)
        if prop_value:
            msg.set_header(prop_name, prop_value)

    return msg
class _Waiter(object):
    """Pairs one outstanding request with the reply to it, once a reply arrives."""

    def __init__(self, req: Message):
        self.req = req
        self.reply = None
        # Stays None until a reply has been accepted; pollers test this field.
        self.reply_time = None


class _Client(object):
    """Per-client bookkeeping: queued outgoing requests and pending reply waiters."""

    def __init__(self, token):
        self.token = token
        self.last_heard_time = None
        self.outgoing_reqs = []
        self.fnf_reqs = []  # fire-and-forget requests
        self.waiters = {}  # ref => waiter
        self.req_lock = threading.Lock()  # guards outgoing_reqs and fnf_reqs
        self.waiter_lock = threading.Lock()  # guards waiters

    def send(self, req: Message):
        """Queue a request that expects a reply.

        Args:
            req: the request to queue; its ``id`` is the key the reply is matched on

        Returns:
            The _Waiter that will receive the reply.
        """
        waiter = _Waiter(req)

        # Register the waiter BEFORE the request becomes visible in the outgoing
        # queue. Otherwise a reply arriving right after the request is picked up
        # could find no waiter and be dropped, and the caller would time out.
        with self.waiter_lock:
            self.waiters[req.id] = waiter

        with self.req_lock:
            self.outgoing_reqs.append(req)

        return waiter

    def fire_and_forget(self, reqs: List[Message]):
        """Queue requests for which no reply is tracked.

        Args:
            reqs: the requests to queue
        """
        with self.req_lock:
            self.fnf_reqs.extend(reqs)

    def get_outgoing_requests(self, max_num_reqs=0) -> List[Message]:
        """Remove and return queued requests, oldest first.

        Regular requests have priority: fire-and-forget requests are only
        returned by a call that finds the regular queue empty. A single call
        never mixes the two kinds.

        Args:
            max_num_reqs: max number of requests to return; 0 or less means no limit

        Returns:
            A list of the dequeued requests (possibly empty).
        """
        with self.req_lock:
            # reqs in outgoing_reqs have higher priority
            q = self.outgoing_reqs if self.outgoing_reqs else self.fnf_reqs

            if max_num_reqs <= 0:
                num_reqs = len(q)
            else:
                num_reqs = min(max_num_reqs, len(q))

            # One slice + one delete instead of repeated pop(0), which would
            # shift the remaining elements on every pop.
            result = q[:num_reqs]
            del q[:num_reqs]

        return result

    def accept_reply(self, reply: Message):
        """Deliver a reply to the waiter of the request it refers to.

        A reply whose ref id matches no pending waiter is ignored, apart from
        refreshing ``last_heard_time``.

        Args:
            reply: the reply message
        """
        self.last_heard_time = time.time()
        ref_id = reply.get_ref_id()

        with self.waiter_lock:
            w = self.waiters.pop(ref_id, None)

        if w:
            # Set the reply before reply_time: pollers treat a non-None
            # reply_time as "reply is ready to read".
            w.reply = reply
            w.reply_time = time.time()

    def cancel_waiter(self, waiter: _Waiter):
        """Stop waiting for a reply, and unqueue the request if it is still queued.

        Does nothing when the waiter is no longer registered.

        Args:
            waiter: the waiter returned by send()
        """
        req = waiter.req
        with self.waiter_lock:
            registered = self.waiters.pop(req.id, None)
            if registered:
                with self.req_lock:
                    try:
                        self.outgoing_reqs.remove(req)
                    except ValueError:
                        # Already handed out by get_outgoing_requests().
                        pass
class ClientReply(object):
    """A request sent to one client, together with the reply received for it.

    Attributes are ``client_token``, ``request`` and ``reply``.
    """

    def __init__(self, client_token: str, req: Message, reply: Message):
        """Client reply.

        Args:
            client_token (str): client token
            req (Message): request
            reply (Message): reply
        """
        # Note the constructor argument ``req`` is exposed as ``request``.
        self.client_token, self.request, self.reply = client_token, req, reply
class _ClientReq(object):
    """Binds a request to its target client and, once sent, to its reply waiter."""

    def __init__(self, client, req: Message):
        # No waiter exists until the request has actually been sent.
        self.waiter = None
        self.req = req
        self.client = client
def check_client_replies(replies: List[ClientReply], client_sites: List[str], command: str):
    """Verify that every client site replied to a command without reporting an error.

    Replies are matched to sites by position.

    Args:
        replies: the replies collected from the clients
        client_sites: names of the client sites the command was sent to
        command: name of the command, used in the error messages

    Raises:
        RuntimeError: if there are no replies, if the number of replies differs
            from the number of sites, or if any reply body contains the error
            message prefix.
    """
    display_sites = ", ".join(client_sites)

    if not replies:
        raise RuntimeError(f"Failed to {command} to the clients {display_sites}: no replies.")

    if len(replies) != len(client_sites):
        raise RuntimeError(f"Failed to {command} to the clients {display_sites}: not enough replies.")

    # One line per client whose reply body carries the error prefix.
    failures = [
        f"\t{client_name}: {r.reply.body}\n"
        for r, client_name in zip(replies, client_sites)
        if r.reply and ERROR_MSG_PREFIX in r.reply.body
    ]

    if failures:
        error_msg = "".join(failures)
        raise RuntimeError(f"Failed to {command} to the following clients: \n{error_msg}")
class FedAdminServer(AdminServer):
    """Admin server that also relays admin requests to FL clients and collects replies."""

    def __init__(
        self,
        fed_admin_interface,
        users,
        cmd_modules,
        file_upload_dir,
        file_download_dir,
        allowed_shell_cmds,
        host,
        port,
        ca_cert_file_name,
        server_cert_file_name,
        server_key_file_name,
        accepted_client_cns=None,
        app_validator=None,
        download_job_url=None,
    ):
        """The FedAdminServer is the framework for developing admin commands.

        Args:
            fed_admin_interface: the server's federated admin interface
            users: a dict of {username: pwd hash}
            cmd_modules: a list of CommandModules
            file_upload_dir: the directory for uploaded files
            file_download_dir: the directory for files to be downloaded
            allowed_shell_cmds: list of shell commands allowed. If not specified, all allowed.
            host: the IP address of the admin server
            port: port number of admin server
            ca_cert_file_name: the root CA's cert file name
            server_cert_file_name: server's cert, signed by the CA
            server_key_file_name: server's private key file
            accepted_client_cns: list of accepted Common Names from client, if specified
            app_validator: Application folder validator.
            download_job_url: passed through to the FileTransferModule

        Raises:
            TypeError: if the auditor is not an Auditor, if cmd_modules is not a list,
                or if an element of cmd_modules is not a CommandModule
        """
        cmd_reg = new_command_register_with_builtin_module(app_ctx=fed_admin_interface)
        self.sai = fed_admin_interface
        self.allowed_shell_cmds = allowed_shell_cmds

        authenticator = SimpleAuthenticator(users)
        sess_mgr = SessionManager()
        login_module = LoginModule(authenticator, sess_mgr)
        cmd_reg.register_module(login_module)

        # register filters - order is important!
        # login_module is also a filter that determines if user is authenticated
        cmd_reg.add_filter(login_module)

        # next is the authorization filter and command module
        authorizer = AuthorizationService.get_authorizer()
        authz_filter = AuthzFilter(authorizer=authorizer)
        cmd_reg.add_filter(authz_filter)
        authz_cmd_module = AuthzCommandModule(authorizer=authorizer)
        cmd_reg.register_module(authz_cmd_module)

        # audit filter records commands to audit trail
        auditor = AuditService.get_auditor()
        # TODO:: clean this up
        if not isinstance(auditor, Auditor):
            raise TypeError("auditor must be Auditor but got {}".format(type(auditor)))
        audit_filter = CommandAudit(auditor)
        cmd_reg.add_filter(audit_filter)

        self.file_upload_dir = file_upload_dir
        self.file_download_dir = file_download_dir

        AppAuthzService.initialize(app_validator)
        cmd_reg.register_module(
            FileTransferModule(
                upload_dir=file_upload_dir,
                download_dir=file_download_dir,
                upload_folder_authz_func=AppAuthzService.authorize_upload,
                download_job_url=download_job_url,
            )
        )

        cmd_reg.register_module(sess_mgr)

        # Caller-supplied modules are registered last, after all built-in ones.
        if cmd_modules:
            if not isinstance(cmd_modules, list):
                raise TypeError("cmd_modules must be list but got {}".format(type(cmd_modules)))

            for m in cmd_modules:
                if not isinstance(m, CommandModule):
                    raise TypeError("cmd_modules must contain CommandModule but got element of type {}".format(type(m)))
                cmd_reg.register_module(m)

        AdminServer.__init__(
            self,
            cmd_reg=cmd_reg,
            host=host,
            port=port,
            ca_cert=ca_cert_file_name,
            server_cert=server_cert_file_name,
            server_key=server_key_file_name,
            accepted_client_cns=accepted_client_cns,
        )

        self.clients = {}  # token => _Client
        self.client_lock = threading.Lock()  # guards self.clients
        # NOTE(review): not read by any method of this class - confirm where it is used.
        self.timeout = 10.0

    def client_heartbeat(self, token):
        """Receive client heartbeat.

        Registers the client if its token is not known yet, and refreshes the
        time it was last heard from.

        Args:
            token: the session token of the client

        Returns:
            Client.
        """
        with self.client_lock:
            client = self.clients.get(token)
            if not client:
                client = _Client(token)
                self.clients[token] = client
            client.last_heard_time = time.time()
            return client

    def client_dead(self, token):
        """Remove dead client.

        Unknown tokens are ignored.

        Args:
            token: the session token of the client
        """
        with self.client_lock:
            self.clients.pop(token, None)

    def get_client_tokens(self) -> List[str]:
        """Get tokens of existing clients."""
        with self.client_lock:
            # Copy under the lock so the caller gets a stable snapshot.
            return list(self.clients)

    def send_request_to_client(self, req: Message, client_token: str, timeout_secs=2.0) -> Optional[ClientReply]:
        """Send one request to one client and wait for its reply.

        Args:
            req: the request to send
            client_token: session token of the target client
            timeout_secs: how long to wait for the reply; 0 or less means fire-and-forget

        Returns:
            The ClientReply, or None when no reply entry was produced (the client
            is unknown, or the request was fire-and-forget).

        Raises:
            TypeError: if req is not a Message
        """
        if not isinstance(req, Message):
            raise TypeError("request must be Message but got {}".format(type(req)))

        replies = self.send_requests({client_token: req}, timeout_secs)
        return replies[0] if replies else None

    def send_request_to_clients(
        self, req: Message, client_tokens: List[str], timeout_secs=2.0
    ) -> List[ClientReply]:
        """Send the same request to several clients and wait for their replies.

        Args:
            req: the request to send
            client_tokens: session tokens of the target clients
            timeout_secs: how long to wait for replies; 0 or less means fire-and-forget

        Returns:
            A list of ClientReply

        Raises:
            TypeError: if req is not a Message
        """
        if not isinstance(req, Message):
            raise TypeError("request must be Message but got {}".format(type(req)))

        reqs = {token: req for token in client_tokens}
        return self.send_requests(reqs, timeout_secs)

    def send_requests(self, requests: dict, timeout_secs=2.0) -> List[ClientReply]:
        """Send requests to clients.

        NOTE::

            This method is to be used by a Command Handler to send requests to Clients.
            Hence, it is run in the Command Handler's handling thread.
            This is a blocking call - returned only after all responses are received or timeout.

        Requests addressed to tokens of unknown clients are silently skipped.

        Args:
            requests: A dict of requests: {client token: request or list of requests}
            timeout_secs: how long to wait for reply before timeout. 0 or less means
                fire-and-forget: requests are queued and no replies are collected.

        Returns:
            A list of ClientReply

        Raises:
            TypeError: if requests is not a dict, or (for non fire-and-forget calls)
                if a request is not a Message
        """
        if not isinstance(requests, dict):
            raise TypeError("requests must be a dict but got {}".format(type(requests)))

        if len(requests) <= 0:
            return []

        # Clients are looked up under client_lock for both kinds of request.
        matched = self._match_requests_to_clients(requests)

        if timeout_secs <= 0.0:
            # this is fire-and-forget!
            for client, reqs in matched:
                client.fire_and_forget(reqs)

            # No replies
            return []

        # Regular requests: validate everything before anything is sent.
        client_reqs = []
        for client, reqs in matched:
            for req in reqs:
                if not isinstance(req, Message):
                    raise TypeError("request must be a Message but got {}".format(type(req)))
                client_reqs.append(_ClientReq(client, req))

        return self._send_client_reqs(client_reqs, timeout_secs)

    def _match_requests_to_clients(self, requests: dict) -> list:
        """Pair each known client with the list of requests addressed to it.

        Args:
            requests: {client token: request or list of requests}

        Returns:
            A list of (client, list of requests) tuples, in the dict's order,
            without entries for unknown tokens.
        """
        matched = []
        with self.client_lock:
            for token, r in requests.items():
                client = self.clients.get(token)
                if not client:
                    continue
                matched.append((client, r if isinstance(r, list) else [r]))
        return matched

    def _send_client_reqs(self, client_reqs, timeout_secs) -> List[ClientReply]:
        """Queue the requests, then poll until all are answered or the timeout expires.

        Args:
            client_reqs: list of _ClientReq to send
            timeout_secs: max time to wait for all replies

        Returns:
            One ClientReply per request, in the same order. The reply of a request
            that was not answered in time is None.
        """
        if not client_reqs:
            return []

        for cr in client_reqs:
            cr.waiter = cr.client.send(cr.req)

        # A monotonic clock keeps the timeout unaffected by wall-clock adjustments.
        start_time = time.monotonic()
        while True:
            if all(cr.waiter.reply_time is not None for cr in client_reqs):
                break

            if time.monotonic() - start_time > timeout_secs:
                # timeout
                break

            time.sleep(0.1)

        result = []
        for cr in client_reqs:
            result.append(ClientReply(client_token=cr.client.token, req=cr.waiter.req, reply=cr.waiter.reply))

            if cr.waiter.reply_time is None:
                # this client timed out
                cr.client.cancel_waiter(cr.waiter)

        return result

    def accept_reply(self, client_token, reply: Message):
        """Accept client reply.

        NOTE::
            This method is to be called by the FL Engine after a client's reply is received.
            Hence, it is called from the FL Engine's message processing thread.

        Receiving a reply also counts as a heartbeat from the client.

        Args:
            client_token: session token of the client
            reply: the reply message

        Raises:
            AssertionError: if the reply carries no ref id
        """
        client = self.client_heartbeat(client_token)

        # Raised explicitly rather than with an assert statement, so the check
        # is still performed when Python runs with -O.
        if reply.get_ref_id() is None:
            raise AssertionError("protocol error: missing ref_id in reply from client {}".format(client_token))

        client.accept_reply(reply)

    def get_outgoing_requests(self, client_token, max_reqs=0):
        """Get outgoing request from a client.

        NOTE::
            This method is called by FL Engine to get outgoing messages to the client, so it
            can send them to the client.

        Args:
            client_token: session token of the client
            max_reqs: max number of requests. 0 means unlimited.

        Returns:
            outgoing requests. A list of Message. Empty if the client is unknown.
        """
        with self.client_lock:
            client = self.clients.get(client_token)
            if client:
                return client.get_outgoing_requests(max_reqs)
            else:
                return []

    def stop(self):
        """Stop the admin server, then close the federated admin interface."""
        super().stop()
        self.sai.close()