Source code for nvflare.app_opt.flower.applet

# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import threading
import time

from nvflare.apis.fl_context import FLContext
from nvflare.apis.workspace import Workspace
from nvflare.app_common.tie.applet import Applet
from nvflare.app_common.tie.cli_applet import CLIApplet
from nvflare.app_common.tie.defs import Constant as TieConstant
from nvflare.app_common.tie.process_mgr import CommandDescriptor, ProcessManager, StopMethod, run_command, start_process
from nvflare.app_opt.flower.defs import Constant
from nvflare.fuel.utils.grpc_utils import create_channel
from nvflare.security.logging import secure_format_exception


[docs] def get_partition_id(fl_ctx: FLContext): """Get the partition id for the current client based on the sorted list of all client names.""" engine = fl_ctx.get_engine() all_client_names = sorted([client.name for client in engine.get_clients()]) for id, client_name in enumerate(all_client_names): if client_name == fl_ctx.get_identity_name(): return id return -1
[docs] def get_num_partitions(fl_ctx: FLContext): """Get the number of partitions based on the number of clients.""" engine = fl_ctx.get_engine() return len(engine.get_clients())
[docs] class FlowerClientApplet(CLIApplet): def __init__(self, extra_env: dict = None): """Constructor of FlowerClientApplet, which extends CLIApplet.""" CLIApplet.__init__(self, stop_method="term") self.extra_env = extra_env
[docs] def get_command(self, ctx: dict) -> CommandDescriptor: """Implementation of the get_command method required by the super class CLIApplet. It returns the CLI command for starting Flower's client app, as well as the full path of the log file for the client app. Args: ctx: the applet run context Returns: CLI command for starting client app and name of log file. """ superlink_addr = ctx.get(Constant.APP_CTX_SUPERLINK_ADDR) clientapp_api_addr = ctx.get(Constant.APP_CTX_CLIENTAPP_API_ADDR) fl_ctx = ctx.get(Constant.APP_CTX_FL_CONTEXT) if not isinstance(fl_ctx, FLContext): self.logger.error(f"expect APP_CTX_FL_CONTEXT to be FLContext but got {type(fl_ctx)}") raise RuntimeError("invalid FLContext") engine = fl_ctx.get_engine() ws = engine.get_workspace() if not isinstance(ws, Workspace): self.logger.error(f"expect workspace to be Workspace but got {type(ws)}") raise RuntimeError("invalid workspace") job_id = fl_ctx.get_job_id() app_dir = ws.get_app_dir(job_id) """ Example: flower-supernode --insecure --grpc-adapter --superlink 127.0.0.1:9092 --clientappio-api-address 127.0.0.1:9094 --node-config ... """ cmd = ( f"flower-supernode --insecure --grpc-adapter " f"--superlink {superlink_addr} " f"--clientappio-api-address {clientapp_api_addr}" ) # add node config node_config_str = self._get_node_config(fl_ctx) if node_config_str: cmd += node_config_str # use app_dir as the cwd for flower's client app. # this is necessary for client_api to be used with the flower client app for metrics logging # client_api expects config info from the "config" folder in the cwd! self.logger.info(f"starting flower client app: {cmd}") return CommandDescriptor( cmd=cmd, cwd=app_dir, env=self.extra_env, log_file_name="client_app_log.txt", stdout_msg_prefix="FLWR-CA", stop_method=StopMethod.TERMINATE, )
def _get_node_config(self, fl_ctx: FLContext): """Get the node config for the flower client app.""" try: cmd = f' client-name="{fl_ctx.get_identity_name()}"' partition_id = get_partition_id(fl_ctx) if partition_id != -1: cmd += f" partition-id={partition_id}" cmd += f" num-partitions={get_num_partitions(fl_ctx)}" return f" --node-config '{cmd}'" except Exception as ex: self.log_error(fl_ctx, f"Exception getting node configuration from fl_ctx: {secure_format_exception(ex)}") return None
[docs] class FlowerServerApplet(Applet): def __init__( self, database: str, superlink_ready_timeout: float, superlink_grace_period=1.0, superlink_min_query_interval=10.0, ): """Constructor of FlowerServerApplet. Args: database: database spec to be used by the server app superlink_ready_timeout: how long to wait for the superlink process to become ready superlink_grace_period: how long to wait for superlink to gracefully shutdown superlink_min_query_interval: minimal interval for querying superlink for status """ Applet.__init__(self) self._superlink_process_mgr = None self.database = database self.superlink_ready_timeout = superlink_ready_timeout self.superlink_grace_period = superlink_grace_period self.superlink_min_query_interval = superlink_min_query_interval self.run_id = None self.last_query_time = None self.last_check_status = None self.last_check_stopped = False self.flower_app_dir = None self.exec_api_addr = None self.flower_run_finished = False self.flwr_stop_called = False # have we called 'flwr stop'? self.flower_run_rc = None self._start_error = False self.stop_lock = threading.Lock() def _start_process(self, name: str, cmd_desc: CommandDescriptor, fl_ctx: FLContext) -> ProcessManager: self.logger.info(f"starting {name}: {cmd_desc.cmd}") try: return start_process(cmd_desc, fl_ctx, stop_method="term") except Exception as ex: self.logger.error(f"exception starting applet: {secure_format_exception(ex)}") self._start_error = True
[docs] def start(self, app_ctx: dict): """Start the applet. Flower requires two processes for server application: superlink: this process is responsible for client communication server_app: this process performs server side of training. We start the superlink first, and wait for it to become ready, then start the server app. Each process will have its own log file in the job's run dir. The superlink's log file is named "superlink_log.txt". The server app's log file is named "server_app_log.txt". Args: app_ctx: the run context of the applet. Returns: """ # try to start superlink first serverapp_api_addr = app_ctx.get(Constant.APP_CTX_SERVERAPP_API_ADDR) fleet_api_addr = app_ctx.get(Constant.APP_CTX_FLEET_API_ADDR) exec_api_addr = app_ctx.get(Constant.APP_CTX_EXEC_API_ADDR) fl_ctx = app_ctx.get(Constant.APP_CTX_FL_CONTEXT) if not isinstance(fl_ctx, FLContext): self.logger.error(f"expect APP_CTX_FL_CONTEXT to be FLContext but got {type(fl_ctx)}") raise RuntimeError("invalid FLContext") engine = fl_ctx.get_engine() ws = engine.get_workspace() if not isinstance(ws, Workspace): self.logger.error(f"expect workspace to be Workspace but got {type(ws)}") raise RuntimeError("invalid workspace") custom_dir = ws.get_app_custom_dir(fl_ctx.get_job_id()) self.flower_app_dir = custom_dir self.exec_api_addr = exec_api_addr db_arg = "" if self.database: db_arg = f"--database {self.database}" """ Example: flower-superlink --insecure --fleet-api-type grpc-adapter --serverappio-api-address 127.0.0.1:9091 --fleet-api-address 127.0.0.1:9092 --exec-api-address 127.0.0.1:9093 """ superlink_cmd = ( f"flower-superlink --insecure --fleet-api-type grpc-adapter {db_arg} " f"--serverappio-api-address {serverapp_api_addr} " f"--fleet-api-address {fleet_api_addr} " f"--exec-api-address {exec_api_addr}" ) cmd_desc = CommandDescriptor( cmd=superlink_cmd, log_file_name="superlink_log.txt", stdout_msg_prefix="FLWR-SL", stop_method=StopMethod.TERMINATE, ) self._superlink_process_mgr = self._start_process(name="superlink", cmd_desc=cmd_desc, fl_ctx=fl_ctx) if not self._superlink_process_mgr: raise RuntimeError("cannot start superlink process") # wait until superlink's fleet_api_addr is ready before starting server app # fleet_api_addr is superlink's gRPC server address for Flare to connect to. start_time = time.time() create_channel( server_addr=fleet_api_addr, grpc_options=None, ready_timeout=self.superlink_ready_timeout, test_only=True, ) self.logger.info(f"superlink is ready for server app in {time.time() - start_time} seconds") # submitting the server app using "flwr run" command flwr_run_cmd = self._flower_command("run") run_info = self._run_flower_command(flwr_run_cmd) run_id = run_info.get("run-id") if not run_id: raise RuntimeError(f"invalid result from command '{flwr_run_cmd}': missing run-id") self.logger.info(f"submitted Flower App and got run id {run_id}") self.run_id = run_id
def _flower_command(self, cmd_name: str, cmd_args=""): return ( f"flwr {cmd_name} --format json --federation-config 'address=\"{self.exec_api_addr}\"' " f"{cmd_args} {self.flower_app_dir}" ) def _run_flower_command(self, command: str): self.logger.debug(f"running flower command: {command}") cmd_desc = CommandDescriptor(cmd=command) reply = run_command(cmd_desc) if not isinstance(reply, str): raise RuntimeError(f"failed to run command '{command}': expect reply to be str but got {type(reply)}") self.logger.debug(f"flower command {command}: {reply=}") # the reply must be a json str try: result = json.loads(reply) except Exception as ex: err = f"invalid result from command '{command}': {secure_format_exception(ex)}" self.logger.error(err) raise RuntimeError(err) if not isinstance(result, dict): err = f"invalid result from command '{command}': expect dict but got {type(result)}" self.logger.error(err) raise RuntimeError(err) success = result.get("success", False) if not success: err = f"failed command '{command}': {success=} {result=}" self.logger.error(err) raise RuntimeError(err) self.logger.debug(f"result of {command}: {result}") return result @staticmethod def _stop_process(p: ProcessManager) -> int: if not p: # nothing to stop return 0 else: return p.stop()
[docs] def stop(self, timeout=0.0) -> int: """Stop the server applet's superlink. Args: timeout: how long to wait before forcefully stopping (kill) the process. Note: we always stop the process immediately - do not wait for the process to stop itself. Returns: """ with self.stop_lock: if self.run_id and not self.flwr_stop_called and not self.flower_run_finished: # stop the server app # we may not be able to issue 'flwr stop' more than once! self.flwr_stop_called = True flwr_stop_cmd = self._flower_command("stop", self.run_id) self.logger.info(f"Issued command to stop Flower App: {flwr_stop_cmd}") try: self._run_flower_command(flwr_stop_cmd) except Exception as ex: # ignore exception self.logger.error(f"exception running '{flwr_stop_cmd}': {secure_format_exception(ex)}") # wait a while to let superlink and supernodes gracefully stop the app time.sleep(self.superlink_grace_period) # stop the superlink self._stop_process(self._superlink_process_mgr) self._superlink_process_mgr = None return 0
@staticmethod def _is_process_stopped(p: ProcessManager): if p: return_code = p.poll() if return_code is None: return False, 0 else: return True, return_code else: return True, 0 def _check_flower_run_status(self): if not self.last_query_time or time.time() - self.last_query_time > self.superlink_min_query_interval: # time to query self.last_query_time = time.time() self.last_check_stopped, self.last_check_status = self._query_for_run_status() return self.last_check_stopped, self.last_check_status def _query_for_run_status(self): # check whether the app is finished flwr_ls_cmd = self._flower_command("ls") try: run_info = self._run_flower_command(flwr_ls_cmd) except Exception as ex: self.logger.error(f"exception running '{flwr_ls_cmd}': {secure_format_exception(ex)}") return True, TieConstant.EXIT_CODE_FATAL_ERROR runs = run_info.get("runs") if not runs: # the app is no longer there self.logger.error(f"invalid result from command '{flwr_ls_cmd}': missing run info") return True, TieConstant.EXIT_CODE_FATAL_ERROR if not isinstance(runs, list): self.logger.error(f"invalid result from command '{flwr_ls_cmd}': expect run list but got {type(runs)}") return True, TieConstant.EXIT_CODE_FATAL_ERROR run = runs[0] if not isinstance(run, dict): self.logger.error(f"invalid result from command '{flwr_ls_cmd}': expect run to be dict but got {type(run)}") return True, TieConstant.EXIT_CODE_FATAL_ERROR status = run.get("status") if not status: self.logger.error(f"invalid result from command '{flwr_ls_cmd}': missing status from {run}") return True, TieConstant.EXIT_CODE_FATAL_ERROR if not isinstance(status, str): self.logger.error(f"invalid result from command '{flwr_ls_cmd}': bad status value '{status}'") return True, TieConstant.EXIT_CODE_FATAL_ERROR if status.startswith("finished"): self.logger.info(f"Flower Run {self.run_id} finished: {status=}") self.flower_run_finished = True if status.endswith("completed"): rc = 0 else: rc = TieConstant.EXIT_CODE_FAILED self.flower_run_rc = rc return True, rc else: return False, 0
[docs] def is_stopped(self) -> (bool, int): """Check whether the server applet is already stopped Returns: a tuple of: whether the applet is stopped, exit code if stopped. Note: if either superlink or server app is stopped, we treat the applet as stopped. """ if self._start_error: return True, TieConstant.EXIT_CODE_CANT_START superlink_stopped, superlink_rc = self._is_process_stopped(self._superlink_process_mgr) if superlink_stopped: self._superlink_process_mgr = None return True, superlink_rc if self.flower_run_finished: return True, self.flower_run_rc if self.last_check_stopped: return self.last_check_stopped, self.last_check_status with self.stop_lock: self.last_check_stopped, self.last_check_status = self._check_flower_run_status() return self.last_check_stopped, self.last_check_status