Source code for nvflare.app_common.tie.controller

# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import time
from abc import ABC, abstractmethod

from nvflare.apis.client import Client
from nvflare.apis.controller_spec import ClientTask, Task
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import Controller
from nvflare.apis.shareable import ReturnCode, Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.apis.utils.reliable_message import ReliableMessage
from nvflare.app_common.tie.connector import Connector
from nvflare.fuel.utils.validation_utils import check_number_range, check_positive_number
from nvflare.security.logging import secure_format_exception

from .applet import Applet
from .defs import Constant


class _ClientStatus:
    """Per-client bookkeeping kept by the controller while a job runs.

    One instance is created for each participating FL client.
    """

    def __init__(self):
        # The most recent app operation requested by this client, and when it was requested.
        # The timestamp is initialized to construction time rather than left unset.
        self.last_op = None
        self.last_op_time = time.time()

        # Timestamps recorded when the client's reply to the config / start task arrives with
        # return code OK. They remain None if the client does not reply or replies with a
        # non-OK return code.
        self.configured_time = None
        self.started_time = None

        # Whether the app process on this client has finished.
        self.app_done = False


class TieController(Controller, ABC):
    """Base server-side controller that drives an app across FL clients via a Connector.

    The workflow implemented in ``control_flow`` is:
      1. send the config task to all participating clients and wait for every one to reply OK;
      2. configure and start the server-side connector;
      3. send the start task to all participating clients and wait for every one to reply OK;
      4. periodically check client statuses until all clients are done, the abort signal is
         triggered, or a client/workflow inactivity timeout is detected.

    Subclasses must implement ``get_client_config_params``, ``get_connector_config_params``,
    ``get_connector`` and ``get_applet``.
    """

    def __init__(
        self,
        configure_task_name=Constant.CONFIG_TASK_NAME,
        configure_task_timeout=Constant.CONFIG_TASK_TIMEOUT,
        start_task_name=Constant.START_TASK_NAME,
        start_task_timeout=Constant.START_TASK_TIMEOUT,
        job_status_check_interval: float = Constant.JOB_STATUS_CHECK_INTERVAL,
        max_client_op_interval: float = Constant.MAX_CLIENT_OP_INTERVAL,
        progress_timeout: float = Constant.WORKFLOW_PROGRESS_TIMEOUT,
    ):
        """Constructor

        Args:
            configure_task_name: name of the config task
            configure_task_timeout: time to wait for clients' responses to the config task before
                timeout. Must be >= 1.
            start_task_name: name of the start task
            start_task_timeout: time to wait for clients' responses to the start task before
                timeout. Must be >= 1.
            job_status_check_interval: how often to check client statuses of the job.
                Must be positive.
            max_client_op_interval: max amount of time allowed between app ops from a client.
                Must be >= 10.
            progress_timeout: the maximum amount of time allowed for the workflow to not make any
                progress. In other words, at least one participating client must have made progress
                during this time. Otherwise, the workflow will be considered to be in trouble and
                the job will be aborted. Must be >= 5.
        """
        Controller.__init__(self)
        self.configure_task_name = configure_task_name
        self.start_task_name = start_task_name
        self.start_task_timeout = start_task_timeout
        self.configure_task_timeout = configure_task_timeout
        self.max_client_op_interval = max_client_op_interval
        self.progress_timeout = progress_timeout
        self.job_status_check_interval = job_status_check_interval

        # set in start_controller once a valid connector has been obtained and initialized
        self.connector = None

        # names of all clients known to the engine when the controller starts
        self.participating_clients = None

        # guards updates made from incoming client messages
        self.status_lock = threading.Lock()
        self.client_statuses = {}  # client name => _ClientStatus

        # set in control_flow; shared with the connector
        self.abort_signal = None

        check_number_range("configure_task_timeout", configure_task_timeout, min_value=1)
        check_number_range("start_task_timeout", start_task_timeout, min_value=1)
        check_positive_number("job_status_check_interval", job_status_check_interval)
        check_number_range("max_client_op_interval", max_client_op_interval, min_value=10.0)
        check_number_range("progress_timeout", progress_timeout, min_value=5.0)

    @abstractmethod
    def get_client_config_params(self, fl_ctx: FLContext) -> dict:
        """Called by the TieController to get config parameters to be sent to FL clients.
        Subclass of TieController must implement this method.

        Args:
            fl_ctx: FL context

        Returns: a dict of config params

        """
        pass

    @abstractmethod
    def get_connector_config_params(self, fl_ctx: FLContext) -> dict:
        """Called by the TieController to get config parameters for configuring the connector.
        Subclass of TieController must implement this method.

        Args:
            fl_ctx: FL context

        Returns: a dict of config params

        """
        pass

    @abstractmethod
    def get_connector(self, fl_ctx: FLContext) -> Connector:
        """Called by the TieController to get the Connector to be used with the controller.
        Subclass of TieController must implement this method.

        Args:
            fl_ctx: FL context

        Returns: a Connector object

        """
        pass

    @abstractmethod
    def get_applet(self, fl_ctx: FLContext) -> Applet:
        """Called by the TieController to get the Applet to be used with the controller.
        Subclass of TieController must implement this method.

        Args:
            fl_ctx: FL context

        Returns: an Applet object

        """
        pass

    def start_controller(self, fl_ctx: FLContext):
        """Start the controller.

        It first tries to get the connector and applet to be used.
        It then initializes the applet, sets the applet to the connector, and initializes the connector.
        It finally registers message handlers for APP_REQUEST and CLIENT_DONE.
        If an error occurs in any step, the job is stopped via system_panic.

        Note: if a subclass overwrites this method, it must call super().start_controller()!

        Args:
            fl_ctx: the FL context

        Returns: None

        """
        all_clients = self._engine.get_clients()
        self.participating_clients = [t.name for t in all_clients]

        for c in self.participating_clients:
            self.client_statuses[c] = _ClientStatus()

        connector = self.get_connector(fl_ctx)
        if not connector:
            self.system_panic("cannot get connector", fl_ctx)
            return

        if not isinstance(connector, Connector):
            self.system_panic(
                f"invalid connector: expect Connector but got {type(connector)}",
                fl_ctx,
            )
            return

        applet = self.get_applet(fl_ctx)
        if not applet:
            self.system_panic("cannot get applet", fl_ctx)
            return

        if not isinstance(applet, Applet):
            self.system_panic(
                f"invalid applet: expect Applet but got {type(applet)}",
                fl_ctx,
            )
            return

        applet.initialize(fl_ctx)
        connector.set_applet(applet)
        connector.initialize(fl_ctx)
        self.connector = connector

        engine = fl_ctx.get_engine()
        engine.register_aux_message_handler(
            topic=Constant.TOPIC_CLIENT_DONE,
            message_handle_func=self._process_client_done,
        )
        ReliableMessage.register_request_handler(
            topic=Constant.TOPIC_APP_REQUEST,
            handler_f=self._handle_app_request,
            fl_ctx=fl_ctx,
        )

    def _trigger_stop(self, fl_ctx: FLContext, error=None):
        """Trigger the abort signal and, if an error is given, system_panic the job.

        Args:
            fl_ctx: FL context
            error: optional error reason; when set, the job is terminated via system_panic.
        """
        # first trigger the abort_signal to tell all components (mainly the controller's control_flow and connector)
        # that check this signal to abort.
        if self.abort_signal:
            self.abort_signal.trigger(value=True)

        # if there is error, call system_panic to terminate the job with proper status.
        # if no error, the job will end normally.
        if error:
            self.system_panic(reason=error, fl_ctx=fl_ctx)

    def _is_stopped(self) -> bool:
        """Return whether the abort signal has been set and triggered."""
        return bool(self.abort_signal and self.abort_signal.triggered)

    def _update_client_status(self, fl_ctx: FLContext, op=None, client_done=False):
        """Update the status of the requesting client.

        The client is identified by the identity name of the peer context in fl_ctx.
        If the client cannot be identified or has no status record, an error is logged
        and nothing is updated.

        Args:
            fl_ctx: FL context
            op: the app operation requested
            client_done: whether the client is done

        Returns: None

        """
        with self.status_lock:
            peer_ctx = fl_ctx.get_peer_context()
            if not peer_ctx:
                self.log_error(fl_ctx, "missing peer_ctx from fl_ctx")
                return
            if not isinstance(peer_ctx, FLContext):
                self.log_error(fl_ctx, f"expect peer_ctx to be FLContext but got {type(peer_ctx)}")
                return

            client_name = peer_ctx.get_identity_name()
            if not client_name:
                self.log_error(fl_ctx, "missing identity from peer_ctx")
                return

            status = self.client_statuses.get(client_name)
            if not status:
                # Unknown client: log and bail out. Falling through would hit the assert below
                # (or an AttributeError on None when asserts are disabled).
                self.log_error(fl_ctx, f"no status record for client {client_name}")
                return

            assert isinstance(status, _ClientStatus)
            if op:
                status.last_op = op
            if client_done:
                status.app_done = client_done
            status.last_op_time = time.time()

    def _process_client_done(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable:
        """Process the ClientDone report for a client

        Args:
            topic: topic of the message
            request: request to be processed
            fl_ctx: the FL context

        Returns: reply to the client

        """
        self.log_debug(fl_ctx, f"_process_client_done {topic}")
        exit_code = request.get(Constant.MSG_KEY_EXIT_CODE)

        if exit_code == 0:
            self.log_info(fl_ctx, f"app client is done with exit code {exit_code}")
        elif exit_code == Constant.EXIT_CODE_CANT_START:
            self.log_error(fl_ctx, f"app client failed to start (exit code {exit_code})")
            self.system_panic("app client failed to start", fl_ctx)
        else:
            # Should we stop here?
            # Problem is that even if the exit_code is not 0, we can't say the job failed.
            self.log_warning(fl_ctx, f"app client is done with exit code {exit_code}")

        self._update_client_status(fl_ctx, client_done=True)
        return make_reply(ReturnCode.OK)

    def _handle_app_request(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable:
        """Handle app request from applets on other sites

        It calls the connector to process the app request. If the connector fails to process the request,
        the job will be stopped.

        Args:
            topic: message topic
            request: the request data
            fl_ctx: FL context

        Returns: processing result as a Shareable object

        """
        self.log_debug(fl_ctx, f"_handle_app_request {topic}")
        op = request.get_header(Constant.MSG_KEY_OP)

        if self._is_stopped():
            self.log_warning(fl_ctx, f"dropped app request ({op=}) since server is already stopped")
            return make_reply(ReturnCode.SERVICE_UNAVAILABLE)

        # we assume app protocol to be very strict, we'll stop the control flow when any error occurs
        process_error = "app request process error"
        self._update_client_status(fl_ctx, op=op)
        try:
            reply = self.connector.process_app_request(op, request, fl_ctx, self.abort_signal)
        except Exception as ex:
            self.log_exception(fl_ctx, f"exception processing app request {op=}: {secure_format_exception(ex)}")
            self._trigger_stop(fl_ctx, process_error)
            return make_reply(ReturnCode.EXECUTION_EXCEPTION)

        self.log_info(fl_ctx, f"received reply for app request '{op=}'")
        reply.set_header(Constant.MSG_KEY_OP, op)
        return reply

    def _configure_clients(self, abort_signal: Signal, fl_ctx: FLContext):
        """Broadcast the config task to all participating clients and wait for their replies.

        Args:
            abort_signal: abort signal passed to broadcast_and_wait
            fl_ctx: FL context

        Returns: True if every client replied OK; False otherwise (the job has been system_panic'ed).
        """
        self.log_info(fl_ctx, f"Configuring clients {self.participating_clients}")

        try:
            config = self.get_client_config_params(fl_ctx)
        except Exception as ex:
            self.system_panic(f"exception get_client_config_params: {secure_format_exception(ex)}", fl_ctx)
            return False

        if config is None:
            self.system_panic("no config data is returned", fl_ctx)
            return False

        shareable = Shareable()
        shareable[Constant.MSG_KEY_CONFIG] = config

        task = Task(
            name=self.configure_task_name,
            data=shareable,
            timeout=self.configure_task_timeout,
            result_received_cb=self._process_configure_reply,
        )

        self.log_info(fl_ctx, f"sending task {self.configure_task_name} to clients {self.participating_clients}")
        start_time = time.time()
        self.broadcast_and_wait(
            task=task,
            targets=self.participating_clients,
            min_responses=len(self.participating_clients),
            fl_ctx=fl_ctx,
            abort_signal=abort_signal,
        )

        time_taken = time.time() - start_time
        self.log_info(fl_ctx, f"client configuration took {time_taken} seconds")

        # configured_time is only set by _process_configure_reply on an OK reply
        failed_clients = [c for c, cs in self.client_statuses.items() if not cs.configured_time]

        # if any client failed to configure, terminate the job
        if failed_clients:
            self.system_panic(f"failed to configure clients {failed_clients}", fl_ctx)
            return False

        self.log_info(fl_ctx, f"successfully configured clients {self.participating_clients}")
        return True

    def _start_clients(self, abort_signal: Signal, fl_ctx: FLContext):
        """Broadcast the start task to all participating clients and wait for their replies.

        Args:
            abort_signal: abort signal passed to broadcast_and_wait
            fl_ctx: FL context

        Returns: True if every client replied OK; False otherwise (the job has been system_panic'ed).
        """
        self.log_info(fl_ctx, f"Starting clients {self.participating_clients}")

        task = Task(
            name=self.start_task_name,
            data=Shareable(),
            timeout=self.start_task_timeout,
            result_received_cb=self._process_start_reply,
        )

        self.log_info(fl_ctx, f"sending task {self.start_task_name} to clients {self.participating_clients}")
        start_time = time.time()
        self.broadcast_and_wait(
            task=task,
            targets=self.participating_clients,
            min_responses=len(self.participating_clients),
            fl_ctx=fl_ctx,
            abort_signal=abort_signal,
        )

        time_taken = time.time() - start_time
        self.log_info(fl_ctx, f"client starting took {time_taken} seconds")

        # started_time is only set by _process_start_reply on an OK reply
        failed_clients = [c for c, cs in self.client_statuses.items() if not cs.started_time]

        # if any client failed to start, terminate the job
        if failed_clients:
            self.system_panic(f"failed to start clients {failed_clients}", fl_ctx)
            return False

        self.log_info(fl_ctx, f"successfully started clients {self.participating_clients}")
        return True

    def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
        """
        To ensure smooth app execution:
        - ensure that all clients are online and ready to go before starting server
        - ensure that server is started and ready to take requests before asking clients to start operation
        - monitor the health of the clients
        - if anything goes wrong, terminate the job

        Args:
            abort_signal: abort signal that is used to notify components to abort
            fl_ctx: FL context

        Returns: None

        """
        if not self.connector:
            # start_controller returns early (after system_panic) without setting the connector
            # when the connector or applet could not be obtained; don't dereference None below.
            self.system_panic("no connector available - controller was not started properly", fl_ctx)
            return

        self.abort_signal = abort_signal

        # the connector uses the same abort signal!
        self.connector.set_abort_signal(abort_signal)

        # wait for every client to become online and properly configured
        self.log_info(fl_ctx, f"Waiting for clients to be ready: {self.participating_clients}")

        # configure all clients
        if not self._configure_clients(abort_signal, fl_ctx):
            self.system_panic("failed to configure all clients", fl_ctx)
            return

        # configure and start the connector
        try:
            config = self.get_connector_config_params(fl_ctx)
            self.connector.configure(config, fl_ctx)
            self.log_info(fl_ctx, "starting connector ...")
            self.connector.start(fl_ctx)
        except Exception as ex:
            error = f"failed to start connector: {secure_format_exception(ex)}"
            self.log_error(fl_ctx, error)
            self.system_panic(error, fl_ctx)
            return

        self.connector.monitor(fl_ctx, self._app_stopped)

        # start all clients
        if not self._start_clients(abort_signal, fl_ctx):
            self.system_panic("failed to start all clients", fl_ctx)
            return

        # monitor client health
        # we periodically check job status until all clients are done or the system is stopped
        self.log_info(fl_ctx, "Waiting for clients to finish ...")
        while not self._is_stopped():
            done = self._check_job_status(fl_ctx)
            if done:
                break
            time.sleep(self.job_status_check_interval)

    def _app_stopped(self, rc, fl_ctx: FLContext):
        """Callback passed to connector.monitor; called when the app server stops.

        Args:
            rc: the app server's return code; non-zero is treated as an abnormal stop.
            fl_ctx: FL context
        """
        error = None
        if rc != 0:
            self.log_error(fl_ctx, f"App Server stopped abnormally with code {rc}")
            error = "App server abnormal stop"

        # the app server could stop at any moment, we trigger the abort_signal in case it is checked by any
        # other components
        self._trigger_stop(fl_ctx, error)

    def _process_configure_reply(self, client_task: ClientTask, fl_ctx: FLContext):
        """Record the configured time of a client whose config reply has return code OK."""
        result = client_task.result
        client_name = client_task.client.name

        rc = result.get_return_code()
        if rc == ReturnCode.OK:
            self.log_info(fl_ctx, f"successfully configured client {client_name}")
            cs = self.client_statuses.get(client_name)
            if cs:
                assert isinstance(cs, _ClientStatus)
                cs.configured_time = time.time()
        else:
            self.log_error(fl_ctx, f"client {client_name} failed to configure: {rc}")

    def _process_start_reply(self, client_task: ClientTask, fl_ctx: FLContext):
        """Record the started time of a client whose start reply has return code OK."""
        result = client_task.result
        client_name = client_task.client.name

        rc = result.get_return_code()
        if rc == ReturnCode.OK:
            self.log_info(fl_ctx, f"successfully started client {client_name}")
            cs = self.client_statuses.get(client_name)
            if cs:
                assert isinstance(cs, _ClientStatus)
                cs.started_time = time.time()
        else:
            # include the return code, consistent with _process_configure_reply
            self.log_error(fl_ctx, f"client {client_name} failed to start: {rc}")

    def _check_job_status(self, fl_ctx: FLContext) -> bool:
        """Check job status and determine whether the job is done.

        The job is considered done when all clients report done, or (after system_panic) when a
        not-done client has been inactive longer than max_client_op_interval, or when no client has
        made progress for longer than progress_timeout.

        Args:
            fl_ctx: FL context

        Returns: whether the job is considered done.

        """
        now = time.time()

        # overall_last_progress_time is the latest time that any client made progress.
        overall_last_progress_time = 0.0
        clients_done = 0
        for client_name, cs in self.client_statuses.items():
            assert isinstance(cs, _ClientStatus)

            if cs.app_done:
                self.log_info(fl_ctx, f"client {client_name} is Done")
                clients_done += 1
            elif now - cs.last_op_time > self.max_client_op_interval:
                self.system_panic(
                    f"client {client_name} didn't have any activity for {self.max_client_op_interval} seconds",
                    fl_ctx,
                )
                return True

            if overall_last_progress_time < cs.last_op_time:
                overall_last_progress_time = cs.last_op_time

        if clients_done == len(self.client_statuses):
            # all clients are done - the job is considered done
            return True
        elif now - overall_last_progress_time > self.progress_timeout:
            # there has been no progress from any client for too long.
            # this could be because the clients got stuck.
            # consider the job done and abort the job.
            self.system_panic(f"the job has no progress for {self.progress_timeout} seconds", fl_ctx)
            return True
        return False

    def process_result_of_unknown_task(
        self, client: Client, task_name: str, client_task_id: str, result: Shareable, fl_ctx: FLContext
    ):
        """Log and ignore a result for a task this controller does not know about."""
        self.log_warning(fl_ctx, f"ignored unknown task {task_name} from client {client.name}")

    def stop_controller(self, fl_ctx: FLContext):
        """This is called by base controller to stop.
        If a subclass overwrites this method, it must call super().stop_controller(fl_ctx).

        Args:
            fl_ctx: FL context

        Returns: None

        """
        if self.connector:
            self.log_info(fl_ctx, "Stopping server connector ...")
            self.connector.stop(fl_ctx)
            self.log_info(fl_ctx, "Server connector stopped")