Source code for nvflare.private.fed.server.admin

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import time
from typing import List, Optional

from nvflare.apis.event_type import EventType
from nvflare.apis.fl_constant import ServerCommandKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.utils.fl_context_utils import gen_new_peer_ctx
from nvflare.fuel.f3.cellnet.cell import Cell
from nvflare.fuel.f3.cellnet.net_agent import NetAgent
from nvflare.fuel.f3.cellnet.net_manager import NetManager
from nvflare.fuel.f3.mpm import MainProcessMonitor as mpm
from nvflare.fuel.hci.conn import Connection
from nvflare.fuel.hci.reg import CommandModule
from nvflare.fuel.hci.server.audit import CommandAudit
from nvflare.fuel.hci.server.authz import AuthzFilter
from nvflare.fuel.hci.server.builtin import new_command_register_with_builtin_module
from nvflare.fuel.hci.server.constants import ConnProps
from nvflare.fuel.hci.server.hci import AdminServer
from nvflare.fuel.hci.server.login import LoginModule, SessionManager, SimpleAuthenticator
from nvflare.fuel.sec.audit import Auditor, AuditService
from nvflare.private.admin_defs import Message
from nvflare.private.defs import ERROR_MSG_PREFIX, RequestHeader
from nvflare.private.fed.server.message_send import ClientReply, send_requests


def new_message(conn: Connection, topic, body, require_authz: bool) -> Message:
    """Create a Message whose headers are populated from an admin connection.

    Args:
        conn: the admin connection whose props supply the header values
        topic: topic of the new message
        body: body of the new message
        require_authz: written to the REQUIRE_AUTHZ header as the lowercased
            ``str()`` of this value ("true" / "false" for a bool)

    Returns:
        The newly created Message.
    """
    message = Message(topic=topic, body=body)

    # The ADMIN_COMMAND header is only set when the connection carries a command entry.
    command = conn.get_prop(ConnProps.CMD_ENTRY)
    if command:
        message.set_header(RequestHeader.ADMIN_COMMAND, command.name)

    authz_flag = str(require_authz).lower()
    message.set_header(RequestHeader.REQUIRE_AUTHZ, authz_flag)

    # These connection props are forwarded as headers under the same key.
    forwarded_props = (
        ConnProps.EVENT_ID,
        ConnProps.USER_NAME,
        ConnProps.USER_ROLE,
        ConnProps.USER_ORG,
        ConnProps.SUBMITTER_NAME,
        ConnProps.SUBMITTER_ORG,
        ConnProps.SUBMITTER_ROLE,
    )
    for key in forwarded_props:
        value = conn.get_prop(key, default=None)
        if not value:
            # Missing or falsy props are not copied.
            continue
        message.set_header(key, value)

    return message
class _Client:
    """Record of a client tracked by the admin server, identified by its token."""

    def __init__(self, token, name):
        self.token = token
        self.name = name
        # Starts unset; this class itself never updates it.
        self.last_heard_time = None


class _ClientReq:
    """Pairs a client with the request Message addressed to it."""

    def __init__(self, client, req: Message):
        self.client = client
        self.req = req
def check_client_replies(replies: List[ClientReply], client_sites: List[str], command: str):
    """Verify that every client site replied and that no reply reports an error.

    Replies are paired with ``client_sites`` by position.

    Args:
        replies: the replies received from the clients
        client_sites: names of the client sites the command was sent to
        command: description of the command, used in the error messages

    Raises:
        RuntimeError: if there are no replies, if the number of replies differs
            from the number of client sites, or if any reply body contains
            ERROR_MSG_PREFIX.
    """
    display_sites = ", ".join(client_sites)

    if not replies:
        raise RuntimeError(f"Failed to {command} to the clients {display_sites}: no replies.")

    if len(replies) != len(client_sites):
        raise RuntimeError(f"Failed to {command} to the clients {display_sites}: not enough replies.")

    # One line per failing client; entries whose reply is empty/None are skipped.
    failures = [
        f"\t{client_name}: {r.reply.body}\n"
        for r, client_name in zip(replies, client_sites)
        if r.reply and ERROR_MSG_PREFIX in r.reply.body
    ]
    if failures:
        error_msg = "".join(failures)
        raise RuntimeError(f"Failed to {command} to the following clients: \n{error_msg}")
class FedAdminServer(AdminServer):
    """Admin server that registers admin command modules and relays admin requests to clients."""

    def __init__(
        self,
        cell: Cell,
        fed_admin_interface,
        users,
        cmd_modules,
        file_upload_dir,
        file_download_dir,
        host,
        port,
        ca_cert_file_name,
        server_cert_file_name,
        server_key_file_name,
        accepted_client_cns=None,
        download_job_url="",
    ):
        """The FedAdminServer is the framework for developing admin commands.

        Args:
            cell: the Cell used to create the NetAgent and to send requests to clients
            fed_admin_interface: the server's federated admin interface
            users: a dict of {username: pwd hash}
            cmd_modules: a list of CommandModules
            file_upload_dir: the directory for uploaded files
            file_download_dir: the directory for files to be downloaded
            host: the IP address of the admin server
            port: port number of admin server
            ca_cert_file_name: the root CA's cert file name
            server_cert_file_name: server's cert, signed by the CA
            server_key_file_name: server's private key file
            accepted_client_cns: list of accepted Common Names from client, if specified
            download_job_url: download job url

        Raises:
            TypeError: if the auditor is not an Auditor, if cmd_modules is not a list,
                or if cmd_modules contains an element that is not a CommandModule.
        """
        cmd_reg = new_command_register_with_builtin_module(app_ctx=fed_admin_interface)
        self.sai = fed_admin_interface
        self.cell = cell
        self.client_lock = threading.Lock()

        authenticator = SimpleAuthenticator(users)
        sess_mgr = SessionManager()
        login_module = LoginModule(authenticator, sess_mgr)
        cmd_reg.register_module(login_module)

        # register filters - order is important!
        # login_module is also a filter that determines if user is authenticated
        cmd_reg.add_filter(login_module)

        # next is the authorization filter and command module
        authz_filter = AuthzFilter()
        cmd_reg.add_filter(authz_filter)

        # audit filter records commands to audit trail
        auditor = AuditService.get_auditor()
        # TODO:: clean this up
        if not isinstance(auditor, Auditor):
            raise TypeError(f"auditor must be Auditor but got {type(auditor)}")

        audit_filter = CommandAudit(auditor)
        cmd_reg.add_filter(audit_filter)

        self.file_upload_dir = file_upload_dir
        self.file_download_dir = file_download_dir

        cmd_reg.register_module(sess_mgr)
        # NOTE(review): sess_mgr.shutdown is not registered as a cleanup callback
        # (the original call was left commented out) - confirm this is intended.
        # mpm.add_cleanup_cb(sess_mgr.shutdown)

        agent = NetAgent(self.cell)
        net_mgr = NetManager(agent)
        cmd_reg.register_module(net_mgr)

        mpm.add_cleanup_cb(net_mgr.close)
        mpm.add_cleanup_cb(agent.close)

        if cmd_modules:
            if not isinstance(cmd_modules, list):
                raise TypeError(f"cmd_modules must be list but got {type(cmd_modules)}")

            for m in cmd_modules:
                if not isinstance(m, CommandModule):
                    raise TypeError(f"cmd_modules must contain CommandModule but got element of type {type(m)}")
                cmd_reg.register_module(m)

        AdminServer.__init__(
            self,
            cmd_reg=cmd_reg,
            host=host,
            port=port,
            ca_cert=ca_cert_file_name,
            server_cert=server_cert_file_name,
            server_key=server_key_file_name,
            accepted_client_cns=accepted_client_cns,
            extra_conn_props={
                ConnProps.DOWNLOAD_DIR: file_download_dir,
                ConnProps.UPLOAD_DIR: file_upload_dir,
                ConnProps.DOWNLOAD_JOB_URL: download_job_url,
            },
        )

        self.clients = {}  # token => _Client
        self.timeout = 10.0

    def client_heartbeat(self, token, name: str):
        """Receive client heartbeat.

        A client seen for the first time is added to the client table; its
        last-heard time is refreshed on every call.

        Args:
            token: the session token of the client
            name: client name

        Returns:
            The _Client record for the token.
        """
        with self.client_lock:
            client = self.clients.get(token)
            if not client:
                client = _Client(token, name)
                self.clients[token] = client
            client.last_heard_time = time.time()
            return client

    def client_dead(self, token):
        """Remove dead client.

        Unknown tokens are ignored.

        Args:
            token: the session token of the client
        """
        with self.client_lock:
            self.clients.pop(token, None)

    def get_client_tokens(self) -> List[str]:
        """Get tokens of existing clients.

        Returns:
            A new list holding the tokens currently in the client table.
        """
        with self.client_lock:
            return list(self.clients)

    def send_request_to_client(self, req: Message, client_token: str, timeout_secs=2.0) -> Optional[ClientReply]:
        """Send a single request to one client.

        Args:
            req: the request to send
            client_token: the session token of the target client
            timeout_secs: how long to wait for reply before timeout

        Returns:
            The first ClientReply, or None if no replies were returned.

        Raises:
            TypeError: if req is not a Message.
        """
        if not isinstance(req, Message):
            raise TypeError(f"request must be Message but got {type(req)}")

        with self.sai.new_context() as fl_ctx:
            replies = self.send_requests({client_token: req}, fl_ctx, timeout_secs=timeout_secs)

        return replies[0] if replies else None

    def send_requests_and_get_reply_dict(self, requests: dict, timeout_secs=2.0) -> dict:
        """Send requests to clients

        Args:
            requests: A dict of requests: {client token: Message}
            timeout_secs: how long to wait for reply before timeout

        Returns:
            A dict of {client token: reply}, where reply is a Message or None (no reply received).
            An empty dict is returned when there are no requests.
        """
        if not requests:
            return {}

        # Every requested token starts with no reply.
        result = dict.fromkeys(requests)
        with self.sai.new_context() as fl_ctx:
            replies = self.send_requests(requests, fl_ctx, timeout_secs=timeout_secs)
            # Treat a None result as "no replies", as send_request_to_client does.
            for r in replies or []:
                result[r.client_token] = r.reply
        return result

    def send_requests(self, requests: dict, fl_ctx: FLContext, timeout_secs=2.0, optional=False) -> List[ClientReply]:
        """Send requests to clients.

        NOTE:: This method is to be used by a Command Handler to send requests to Clients.
        Hence, it is run in the Command Handler's handling thread.
        This is a blocking call - returned only after all responses are received or timeout.

        Args:
            requests: A dict of requests: {client token: Message}. Each value has its
                PEER_FL_CONTEXT header set before sending.
            fl_ctx: FLContext
            timeout_secs: how long to wait for reply before timeout
            optional: whether the requests are optional

        Returns:
            A list of ClientReply
        """
        # The event is fired and a new peer context generated once per request.
        for request in requests.values():
            self.sai.fire_event(EventType.BEFORE_SEND_ADMIN_COMMAND, fl_ctx)
            shared_fl_ctx = gen_new_peer_ctx(fl_ctx)
            request.set_header(ServerCommandKey.PEER_FL_CONTEXT, shared_fl_ctx)

        return send_requests(
            cell=self.cell,
            command="admin",
            requests=requests,
            clients=self.clients,
            timeout_secs=timeout_secs,
            optional=optional,
        )

    def stop(self):
        """Stop the admin server, then close the federated admin interface."""
        super().stop()
        self.sai.close()