Source code for nvflare.private.fed.client.admin

# Copyright (c) 2021-2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The FedAdminAgent used by a client to communicate with the admin server."""

import threading
import time
import traceback

from nvflare.fuel.hci.server.constants import ConnProps
from nvflare.fuel.sec.audit import Auditor, AuditService
from nvflare.private.admin_defs import Message, error_reply, ok_reply


class Sender:
    """Bridge between the admin agent and the underlying messaging system.

    This base class only defines the hooks; every method here is a no-op.
    Implementations must make sure their methods are exception-proof!
    """

    def send_reply(self, reply: Message):
        """Deliver a reply to the requester.

        Args:
            reply: reply message
        """

    def retrieve_requests(self) -> [Message]:
        """Ask the Server for the requests that are pending.

        Returns:
            list of messages (this base implementation returns None).
        """

    def send_result(self, message: Message):
        """Deliver the processor results to the server.

        Args:
            message: message
        """

    def close(self):
        """Close the sender.

        Returns:
            None
        """
class RequestProcessor:
    """Base class for objects that process admin requests.

    This base class only defines the hooks; both methods here are no-ops.
    """

    def get_topics(self) -> [str]:
        """Report which topics this processor handles.

        Returns:
            list of topics (this base implementation returns None).
        """

    def process(self, req: Message, app_ctx) -> Message:
        """Process the given request.

        Args:
            req: request message
            app_ctx: application context

        Returns:
            reply message (this base implementation returns None).
        """
class FedAdminAgent(object):
    """FedAdminAgent communicates with the FedAdminServer.

    The agent runs two threads once started: one polls the Sender for pending
    requests and queues them, the other takes queued requests one at a time,
    dispatches each to the RequestProcessor registered for its topic and sends
    the reply back through the Sender.
    """

    def __init__(self, client_name, sender: Sender, app_ctx, req_poll_interval=0.5, process_poll_interval=0.1):
        """Init the FedAdminAgent.

        Args:
            client_name: client name
            sender: Sender object
            app_ctx: application context
            req_poll_interval: request polling interval, in seconds
            process_poll_interval: process polling interval, in seconds

        Raises:
            TypeError: if sender is not a Sender, or the auditor returned by
                AuditService is not an Auditor.
        """
        if not isinstance(sender, Sender):
            raise TypeError("sender must be an instance of Sender, but got {}".format(type(sender)))

        auditor = AuditService.get_auditor()
        if not isinstance(auditor, Auditor):
            raise TypeError("auditor must be an instance of Auditor, but got {}".format(type(auditor)))

        self.name = client_name
        self.sender = sender
        self.auditor = auditor
        self.app_ctx = app_ctx
        self.req_poll_interval = req_poll_interval
        self.process_poll_interval = process_poll_interval
        self.processors = {}  # topic => RequestProcessor
        self.reqs = []  # pending requests, guarded by req_lock
        self.req_lock = threading.Lock()
        self.retrieve_reqs_thread = None
        self.process_req_thread = None
        self.asked_to_stop = False

    def register_processor(self, processor: RequestProcessor):
        """To register the RequestProcessor.

        Args:
            processor: RequestProcessor

        Raises:
            TypeError: if processor is not a RequestProcessor.
            AssertionError: if one of the processor's topics is already registered.
        """
        if not isinstance(processor, RequestProcessor):
            raise TypeError("processor must be an instance of RequestProcessor, but got {}".format(type(processor)))

        topics = processor.get_topics()
        for topic in topics:
            # NOTE: kept as an assert so the exception type seen by callers does
            # not change; be aware that it is skipped under `python -O`.
            assert topic not in self.processors, "duplicate processors for topic {}".format(topic)
            self.processors[topic] = processor

    def start(self):
        """To start the FedAdminAgent.

        Creates the retriever and processor threads on first use and starts
        whichever of them is not alive.

        Raises:
            RuntimeError: from Thread.start() if a thread has already run and
                exited (e.g. after shutdown), since a thread object can only
                be started once.
        """
        if self.retrieve_reqs_thread is None:
            self.retrieve_reqs_thread = threading.Thread(target=_start_retriever, args=(self,))

        # called from the main thread
        if not self.retrieve_reqs_thread.is_alive():
            self.retrieve_reqs_thread.start()

        if self.process_req_thread is None:
            self.process_req_thread = threading.Thread(target=_start_processor, args=(self,))

        # called from the main thread
        if not self.process_req_thread.is_alive():
            self.process_req_thread.start()

    def _run_retriever(self):
        """Poll the sender for pending requests and queue them until asked to stop."""
        while not self.asked_to_stop:
            reqs = self.sender.retrieve_requests()
            # anything that is not a list (including None) is ignored
            if isinstance(reqs, list):
                with self.req_lock:
                    self.reqs.extend(reqs)

            time.sleep(self.req_poll_interval)

    def _run_processor(self):
        """Take queued requests one per polling cycle and process them until asked to stop."""
        while not self.asked_to_stop:
            with self.req_lock:
                req = self.reqs.pop(0) if self.reqs else None

            if req:
                try:
                    self._process_request(req)
                except Exception:
                    # A single bad request (or a failure while auditing/replying)
                    # must not end this thread, otherwise no further admin
                    # request would ever be processed.
                    traceback.print_exc()

            time.sleep(self.process_poll_interval)

    def _process_request(self, req):
        """Audit one request, dispatch it to its processor and send the reply.

        Args:
            req: the request taken from the queue

        Raises:
            TypeError: if req is not a Message.
        """
        if not isinstance(req, Message):
            raise TypeError("request must be Message but got {}".format(type(req)))

        topic = req.topic

        # create audit record
        if self.auditor:
            user_name = req.get_header(ConnProps.USER_NAME, "")
            ref_event_id = req.get_header(ConnProps.EVENT_ID, "")
            self.auditor.add_event(user=user_name, action=topic, ref=ref_event_id)

        processor = self.processors.get(topic)
        if processor:
            try:
                reply = processor.process(req, self.app_ctx)
                if reply is None:
                    # simply ack
                    reply = ok_reply()
                elif not isinstance(reply, Message):
                    # explicit check (not an assert) so it also holds under `python -O`
                    raise TypeError("processor for topic {} failed to produce valid reply".format(topic))
            except BaseException as e:
                # BaseException is kept from the original so that even SystemExit
                # raised inside a processor yields an error reply to the requester
                # instead of ending this thread.
                traceback.print_exc()
                reply = error_reply("exception_occurred: {}".format(e))
        else:
            # no processor registered for this topic
            reply = error_reply("invalid_request")

        reply.set_ref_id(req.id)
        self.sender.send_reply(reply)

    def shutdown(self):
        """To be called by the Client Engine to gracefully shutdown the agent.

        Signals both threads to stop, waits for each one that is still alive,
        then closes the sender.
        """
        self.asked_to_stop = True

        if self.retrieve_reqs_thread and self.retrieve_reqs_thread.is_alive():
            self.retrieve_reqs_thread.join()

        if self.process_req_thread and self.process_req_thread.is_alive():
            self.process_req_thread.join()

        self.sender.close()
def _start_retriever(agent: FedAdminAgent):
    """Thread target: run the request-retrieval loop of ``agent``."""
    agent._run_retriever()


def _start_processor(agent: FedAdminAgent):
    """Thread target: run the request-processing loop of ``agent``."""
    agent._run_processor()