Source code for nvflare.app_common.ccwf.server_ctl

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
from datetime import datetime
from typing import List

from nvflare.apis.client import Client
from nvflare.apis.controller_spec import ClientTask, Task
from nvflare.apis.event_type import EventType
from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import Controller
from nvflare.apis.shareable import ReturnCode, Shareable
from nvflare.apis.signal import Signal
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.ccwf.common import (
    Constant,
    StatusReport,
    make_task_name,
    status_report_from_dict,
    topic_for_end_workflow,
)
from nvflare.fuel.utils.validation_utils import (
    DefaultValuePolicy,
    check_number_range,
    check_object_type,
    check_positive_int,
    check_positive_number,
    check_str,
    normalize_config_arg,
    validate_candidate,
    validate_candidates,
)
from nvflare.security.logging import secure_format_traceback


class ClientStatus:
    """Bookkeeping record the server keeps for one participating client."""

    def __init__(self):
        # Latest status report accepted for this client; starts as a blank report.
        self.status = StatusReport()

        # Number of status reports processed for this client so far.
        self.num_reports = 0

        # Stays None until the client has been successfully configured.
        self.ready_time = None

        # Both clocks start at creation time: when the client last reported at all,
        # and when its reported status last changed.
        self.last_report_time, self.last_progress_time = time.time(), time.time()
class ServerSideController(Controller):
    """Server-side controller of a ``ccwf`` workflow.

    The controller configures all participating clients, optionally asks the
    starting client to start the workflow, then periodically checks the status
    reported by the clients until the workflow is done (or is found to be in
    trouble), and finally asks every participating client to end the workflow.

    Subclasses customize behavior through the hook methods ``prepare_config``,
    ``sub_flow``, ``process_config_reply``, ``client_started`` and
    ``is_sub_flow_done``.
    """

    def __init__(
        self,
        num_rounds: int = 1,
        start_round: int = 0,
        task_name_prefix: str = "wf",
        configure_task_timeout=Constant.CONFIG_TASK_TIMEOUT,
        end_workflow_timeout=Constant.END_WORKFLOW_TIMEOUT,
        start_task_timeout=Constant.START_TASK_TIMEOUT,
        task_check_period: float = Constant.TASK_CHECK_INTERVAL,
        job_status_check_interval: float = Constant.JOB_STATUS_CHECK_INTERVAL,
        starting_client: str = "",
        starting_client_policy: str = DefaultValuePolicy.ANY,
        participating_clients=None,
        result_clients=None,
        result_clients_policy: str = DefaultValuePolicy.ALL,
        max_status_report_interval: float = Constant.PER_CLIENT_STATUS_REPORT_TIMEOUT,
        progress_timeout: float = Constant.WORKFLOW_PROGRESS_TIMEOUT,
        private_p2p: bool = True,
    ):
        """Constructor.

        Args:
            num_rounds: the number of rounds to be performed. This is a workflow config parameter.
                Defaults to 1.
            start_round: the starting round number. This is a workflow config parameter.
            task_name_prefix: the prefix for task names of this workflow.
                The workflow requires multiple tasks (e.g. config and start) between the server controller
                and the client. The full names of these tasks are <prefix>_config and <prefix>_start.
                Subclasses may send additional tasks. Naming these tasks with a common prefix can make it
                easier to configure task executors for FL clients.
            configure_task_timeout: time to wait for clients' responses to the config task before timeout.
            end_workflow_timeout: timeout for ending workflow message.
            start_task_timeout: how long to wait for the starting client to finish the "start" task.
                If timed out, the job will be aborted.
                If the starting_client is not specified, then no start task will be sent.
            task_check_period: passed to the base ``Controller`` constructor.
            job_status_check_interval: how long to sleep between two checks of the job status while
                waiting for the clients to finish the workflow.
            starting_client: name of the starting client.
            starting_client_policy: how to determine the starting client if the name is not explicitly
                specified. Possible values are:
                    ANY - any one of the participating clients (the first client)
                    RANDOM - a random client
                    EMPTY - no starting client
                    DISALLOW - does not allow implicit - starting_client must be explicitly specified
            participating_clients: the names of the clients that will participate in the job.
                None means all clients.
            result_clients: names of the clients that will receive final learning results.
                Defaults to None, which is stored as a new empty list (i.e. not explicitly specified).
            result_clients_policy: how to determine result_clients if their names are not explicitly
                specified. Possible values are:
                    ALL - all participating clients
                    ANY - any one of the participating clients
                    EMPTY - no result_clients
                    DISALLOW - does not allow implicit - result_clients must be explicitly specified
            max_status_report_interval: the maximum amount of time allowed for a client to miss a status
                report. In other words, if a client fails to report its status for this much time, the
                client will be considered in trouble and the job will be aborted.
            progress_timeout: the maximum amount of time allowed for the workflow to not make any
                progress. In other words, at least one participating client must have made progress
                during this time. Otherwise, the workflow will be considered to be in trouble and the
                job will be aborted.
            private_p2p: whether to make peer-to-peer communications private. When set to True, P2P
                communications will be encrypted. Private P2P communication is an additional level of
                protection on basic communication security (such as SSL). Each pair of peers have their
                own encryption keys to ensure that only they themselves can understand their messages,
                even if the messages may be relayed through other sites (e.g. server). Different pairs of
                peers have different keys.
                Currently, private P2P is enabled only when the system is in secure mode. This is because
                key exchange between peers requires both sides to have PKI certificates and keys, which
                requires the project to be provisioned in secure mode.

        Raises:
            ValueError: if participating_clients normalizes to None (explicitly empty), or if fewer
                than 2 participating clients are specified. The ``check_*`` validation helpers are
                also called on several arguments and are expected to raise on invalid values.
        """
        Controller.__init__(self, task_check_period)
        participating_clients = normalize_config_arg(participating_clients)
        if participating_clients is None:
            raise ValueError("participating_clients must not be empty")

        self.task_name_prefix = task_name_prefix
        self.configure_task_name = make_task_name(task_name_prefix, Constant.BASENAME_CONFIG)
        self.configure_task_timeout = configure_task_timeout
        self.start_task_name = make_task_name(task_name_prefix, Constant.BASENAME_START)
        self.start_task_timeout = start_task_timeout
        self.end_workflow_timeout = end_workflow_timeout
        self.num_rounds = num_rounds
        self.start_round = start_round
        self.max_status_report_interval = max_status_report_interval
        self.progress_timeout = progress_timeout
        self.job_status_check_interval = job_status_check_interval
        self.starting_client = starting_client
        self.starting_client_policy = starting_client_policy
        self.participating_clients = participating_clients
        # A fresh list per instance: a list literal as the parameter default would be
        # one object shared by every controller created without this argument.
        self.result_clients = [] if result_clients is None else result_clients
        self.result_clients_policy = result_clients_policy

        # make private_p2p bool
        check_object_type("private_p2p", private_p2p, bool)
        self.private_p2p = private_p2p

        self.client_statuses = {}  # client name => ClientStatus
        self.cw_started = False  # set once the starting client confirms the start task
        self.asked_to_stop = False  # set when a client reports a failure
        self.workflow_id = None  # filled in by start_controller

        check_positive_int("num_rounds", num_rounds)
        check_number_range("configure_task_timeout", configure_task_timeout, min_value=1)
        check_number_range("end_workflow_timeout", end_workflow_timeout, min_value=1)
        check_positive_number("job_status_check_interval", job_status_check_interval)
        check_number_range("max_status_report_interval", max_status_report_interval, min_value=10.0)
        check_number_range("progress_timeout", progress_timeout, min_value=5.0)
        check_str("starting_client_policy", starting_client_policy)

        if participating_clients and len(participating_clients) < 2:
            raise ValueError(f"Not enough participating_clients: must > 1, but got {participating_clients}")

    def start_controller(self, fl_ctx: FLContext):
        """Resolve the workflow ID and the client roles before the control flow runs.

        Reads the workflow ID from the FL context, validates participating_clients,
        starting_client and result_clients against the clients known to the engine,
        and creates a ClientStatus entry for every participating client.

        Args:
            fl_ctx: the FL context.

        Raises:
            RuntimeError: if the workflow ID is missing from the FL context, or the
                engine reports fewer than 2 clients.
        """
        wf_id = fl_ctx.get_prop(FLContextKey.WORKFLOW)
        self.log_debug(fl_ctx, f"starting controller for workflow {wf_id}")
        if not wf_id:
            raise RuntimeError("workflow ID is missing from FL context")
        self.workflow_id = wf_id

        all_clients = self._engine.get_clients()
        if len(all_clients) < 2:
            raise RuntimeError(f"this workflow requires at least 2 clients, but only got {all_clients}")

        all_client_names = [t.name for t in all_clients]
        self.participating_clients = validate_candidates(
            var_name="participating_clients",
            candidates=self.participating_clients,
            base=all_client_names,
            default_policy=DefaultValuePolicy.ALL,
            allow_none=False,
        )
        self.log_info(fl_ctx, f"Using participating clients: {self.participating_clients}")

        # The starting client and the result clients must be chosen among the participating clients.
        self.starting_client = validate_candidate(
            var_name="starting_client",
            candidate=self.starting_client,
            base=self.participating_clients,
            default_policy=self.starting_client_policy,
            allow_none=True,
        )
        self.log_info(fl_ctx, f"Starting client: {self.starting_client}")

        self.result_clients = validate_candidates(
            var_name="result_clients",
            candidates=self.result_clients,
            base=self.participating_clients,
            default_policy=self.result_clients_policy,
            allow_none=True,
        )

        for c in self.participating_clients:
            self.client_statuses[c] = ClientStatus()

    def prepare_config(self) -> dict:
        """Return extra workflow config entries to be merged into the config sent to clients.

        The base implementation returns an empty dict; subclasses may override.
        """
        return {}

    def sub_flow(self, abort_signal: Signal, fl_ctx: FLContext):
        """Hook for additional server-side control flow.

        Called by control_flow after the configure/start steps and before waiting for
        the clients to finish. The base implementation does nothing.

        Args:
            abort_signal: the abort signal of the control flow.
            fl_ctx: the FL context.
        """
        pass

    def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
        """Run the workflow: configure clients, start, wait for completion, end.

        The job is aborted through system_panic if any client fails to be configured,
        if the starting client fails to start the workflow, or if ending the workflow
        fails on any client.

        Args:
            abort_signal: signal checked while waiting for the clients to finish.
            fl_ctx: the FL context.
        """
        # wait for every client to become ready
        self.log_info(fl_ctx, f"Waiting for clients to be ready: {self.participating_clients}")

        # GET STARTED
        self.log_info(fl_ctx, f"Configuring clients {self.participating_clients} for workflow {self.workflow_id}")

        learn_config = {
            Constant.PRIVATE_P2P: self.private_p2p,
            Constant.TASK_NAME_PREFIX: self.task_name_prefix,
            Constant.CLIENTS: self.participating_clients,
            Constant.START_CLIENT: self.starting_client,
            Constant.RESULT_CLIENTS: self.result_clients,
            AppConstants.NUM_ROUNDS: self.num_rounds,
            Constant.START_ROUND: self.start_round,
            FLContextKey.WORKFLOW: self.workflow_id,
        }

        # Entries supplied by a subclass override the base entries of the same key.
        extra_config = self.prepare_config()
        if extra_config:
            learn_config.update(extra_config)

        self.log_info(fl_ctx, f"Workflow Config: {learn_config}")

        # configure all clients
        shareable = Shareable()
        shareable[Constant.CONFIG] = learn_config

        task = Task(
            name=self.configure_task_name,
            data=shareable,
            timeout=self.configure_task_timeout,
            result_received_cb=self._process_configure_reply,
        )

        self.log_info(fl_ctx, f"sending task {self.configure_task_name} to clients {self.participating_clients}")
        start_time = time.time()
        self.broadcast_and_wait(
            task=task,
            targets=self.participating_clients,
            min_responses=len(self.participating_clients),
            fl_ctx=fl_ctx,
            abort_signal=abort_signal,
        )

        time_taken = time.time() - start_time
        self.log_info(fl_ctx, f"client configuration took {time_taken} seconds")

        # A client is ready only if its configure reply was accepted (ready_time set).
        failed_clients = [c for c, cs in self.client_statuses.items() if not cs.ready_time]
        if failed_clients:
            self.system_panic(
                f"failed to configure clients {failed_clients}",
                fl_ctx,
            )
            return

        self.log_info(fl_ctx, f"successfully configured clients {self.participating_clients}")

        # starting the starting_client
        if self.starting_client:
            shareable = Shareable()
            task = Task(
                name=self.start_task_name,
                data=shareable,
                timeout=self.start_task_timeout,
                result_received_cb=self._process_start_reply,
            )

            self.log_info(fl_ctx, f"sending task {self.start_task_name} to client {self.starting_client}")

            self.send_and_wait(
                task=task,
                targets=[self.starting_client],
                fl_ctx=fl_ctx,
                abort_signal=abort_signal,
            )

            if not self.cw_started:
                self.system_panic(
                    f"failed to start workflow {self.workflow_id} on client {self.starting_client}",
                    fl_ctx,
                )
                return

            self.log_info(fl_ctx, f"started workflow {self.workflow_id} on client {self.starting_client}")

        # a subclass could provide additional control flow
        self.sub_flow(abort_signal, fl_ctx)

        self.log_info(fl_ctx, f"Waiting for clients to finish workflow {self.workflow_id} ...")
        while not abort_signal.triggered and not self.asked_to_stop:
            time.sleep(self.job_status_check_interval)
            done = self._check_job_status(fl_ctx)
            if done:
                break

        self.log_info(fl_ctx, f"Workflow {self.workflow_id} finished on all clients")

        # ask all clients to end the workflow
        self.log_info(fl_ctx, f"asking all clients to end workflow {self.workflow_id}")
        engine = fl_ctx.get_engine()
        end_wf_request = Shareable()
        resp = engine.send_aux_request(
            targets=self.participating_clients,
            topic=topic_for_end_workflow(self.workflow_id),
            request=end_wf_request,
            timeout=self.end_workflow_timeout,
            fl_ctx=fl_ctx,
            secure=False,
        )

        assert isinstance(resp, dict)
        num_errors = 0
        for c in self.participating_clients:
            reply = resp.get(c)
            if not reply:
                self.log_error(fl_ctx, f"not reply from client {c} for ending workflow {self.workflow_id}")
                num_errors += 1
                continue

            assert isinstance(reply, Shareable)
            rc = reply.get_return_code(ReturnCode.OK)
            if rc != ReturnCode.OK:
                self.log_error(fl_ctx, f"client {c} failed to end workflow {self.workflow_id}: {rc}")
                num_errors += 1

        if num_errors > 0:
            self.system_panic(f"failed to end workflow {self.workflow_id} on all clients", fl_ctx)

        self.log_info(fl_ctx, f"Workflow {self.workflow_id} done!")

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        """Update the sending client's status on every BEFORE_PROCESS_TASK_REQUEST event.

        Args:
            event_type: the type of the fired event.
            fl_ctx: the FL context of the event.
        """
        if event_type == EventType.BEFORE_PROCESS_TASK_REQUEST:
            self._update_client_status(fl_ctx)

    def process_config_reply(self, client_name: str, reply: Shareable, fl_ctx: FLContext) -> bool:
        """Hook called for each successful (ReturnCode.OK) reply to the configure task.

        The base implementation accepts every reply; subclasses may override.

        Args:
            client_name: name of the replying client.
            reply: the client's reply to the configure task.
            fl_ctx: the FL context.

        Returns:
            True to accept the reply. A falsy value (or a raised exception) leaves the
            client not marked as ready, which makes the configuration step fail.
        """
        return True

    def _process_configure_reply(self, client_task: ClientTask, fl_ctx: FLContext):
        """Result callback of the configure task: mark the client ready on success.

        Args:
            client_task: the client task carrying the client and its result.
            fl_ctx: the FL context.
        """
        result = client_task.result
        client_name = client_task.client.name

        rc = result.get_return_code()
        if rc != ReturnCode.OK:
            error = result.get(Constant.ERROR, "?")
            self.log_error(fl_ctx, f"client {client_task.client.name} failed to configure: {rc}: {error}")
            return

        self.log_info(fl_ctx, f"successfully configured client {client_name}")

        try:
            ok = self.process_config_reply(client_name, result, fl_ctx)
        except Exception:
            # Only ordinary errors of the hook are absorbed here; KeyboardInterrupt and
            # SystemExit are deliberately allowed to propagate.
            self.log_error(
                fl_ctx, f"exception processing config reply from client {client_name}: {secure_format_traceback()}"
            )
            return

        if not ok:
            return

        cs = self.client_statuses.get(client_name)
        if cs:
            assert isinstance(cs, ClientStatus)
            cs.ready_time = time.time()

    def client_started(self, client_task: ClientTask, fl_ctx: FLContext):
        """Hook called when the starting client replies OK to the start task.

        The base implementation accepts the reply; subclasses may override.

        Args:
            client_task: the client task carrying the client and its result.
            fl_ctx: the FL context.

        Returns:
            True to accept. A falsy value (or a raised exception) leaves the workflow
            not marked as started, which makes the start step fail.
        """
        return True

    def _process_start_reply(self, client_task: ClientTask, fl_ctx: FLContext):
        """Result callback of the start task: record that the workflow has started.

        Args:
            client_task: the client task carrying the client and its result.
            fl_ctx: the FL context.
        """
        result = client_task.result

        rc = result.get_return_code()
        if rc != ReturnCode.OK:
            error = result.get(Constant.ERROR, "?")
            self.log_error(
                fl_ctx, f"client {client_task.client.name} couldn't start workflow {self.workflow_id}: {rc}: {error}"
            )
            return

        try:
            ok = self.client_started(client_task, fl_ctx)
        except Exception:
            # Reported as an error (as for the configure reply) because a failing hook
            # prevents the workflow from being marked as started.
            self.log_error(fl_ctx, f"exception in client_started: {secure_format_traceback()}")
            return

        if not ok:
            return

        self.cw_started = True

    def is_sub_flow_done(self, fl_ctx: FLContext) -> bool:
        """Hook telling whether the server side considers the workflow done.

        Checked first on every job status check. The base implementation returns False.

        Args:
            fl_ctx: the FL context.

        Returns:
            True if the job should be treated as done.
        """
        return False

    def _check_job_status(self, fl_ctx: FLContext):
        """Decide whether the control flow should stop waiting for the clients.

        Args:
            fl_ctx: the FL context.

        Returns:
            True if the wait should end: the sub flow is done, a client reported
            all_done, a client missed status reports for longer than
            max_status_report_interval, or no client made progress for longer than
            progress_timeout (the latter two also trigger system_panic).
            False otherwise.
        """
        # see whether the server side thinks it's done
        if self.is_sub_flow_done(fl_ctx):
            return True

        now = time.time()
        overall_last_progress_time = 0.0
        for client_name, cs in self.client_statuses.items():
            assert isinstance(cs, ClientStatus)
            assert isinstance(cs.status, StatusReport)

            if cs.status.all_done:
                self.log_info(fl_ctx, f"Got ALL_DONE from client {client_name}")
                return True

            if now - cs.last_report_time > self.max_status_report_interval:
                self.system_panic(
                    f"client {client_name} didn't report status for {self.max_status_report_interval} seconds",
                    fl_ctx,
                )
                return True

            # Track the most recent progress made by any client.
            if overall_last_progress_time < cs.last_progress_time:
                overall_last_progress_time = cs.last_progress_time

        # Judge both timeouts against the same instant.
        if now - overall_last_progress_time > self.progress_timeout:
            self.system_panic(
                f"the workflow {self.workflow_id} has no progress for {self.progress_timeout} seconds",
                fl_ctx,
            )
            return True

        return False

    def _update_client_status(self, fl_ctx: FLContext):
        """Record the status report for this workflow found in the peer context.

        Does nothing if the peer sent no status reports or none for this workflow.
        A report carrying an error asks the control flow to stop and triggers
        system_panic.

        Args:
            fl_ctx: the FL context whose peer context carries the client's reports.
        """
        peer_ctx = fl_ctx.get_peer_context()
        assert isinstance(peer_ctx, FLContext)
        client_name = peer_ctx.get_identity_name()

        # see whether status is available
        reports = peer_ctx.get_prop(Constant.STATUS_REPORTS)
        if not reports:
            self.log_debug(fl_ctx, f"no status report from client {client_name}")
            return

        # Reports are keyed by workflow ID; only the one for this workflow matters.
        my_report = reports.get(self.workflow_id)
        if not my_report:
            return

        if client_name not in self.client_statuses:
            self.log_error(fl_ctx, f"received result from unknown client {client_name}!")
            return

        report = status_report_from_dict(my_report)
        cs = self.client_statuses[client_name]
        assert isinstance(cs, ClientStatus)
        now = time.time()
        cs.last_report_time = now
        cs.num_reports += 1

        if report.error:
            self.asked_to_stop = True
            self.system_panic(f"received failure report from client {client_name}: {report.error}", fl_ctx)
            return

        if cs.status != report:
            # updated: only a changed report counts as progress
            cs.status = report
            cs.last_progress_time = now
            timestamp = datetime.fromtimestamp(report.timestamp) if report.timestamp else False
            self.log_info(
                fl_ctx,
                f"updated status of client {client_name} on round {report.last_round}: "
                f"timestamp={timestamp}, action={report.action}, all_done={report.all_done}",
            )
        else:
            self.log_debug(
                fl_ctx, f"ignored status report from client {client_name} at round {report.last_round}: no change"
            )

    def process_result_of_unknown_task(
        self, client: Client, task_name: str, client_task_id: str, result: Shareable, fl_ctx: FLContext
    ):
        """Log a warning for a result of a task this controller does not know, and ignore it.

        Args:
            client: the client that sent the result.
            task_name: name of the unknown task.
            client_task_id: ID of the client task.
            result: the submitted result (unused).
            fl_ctx: the FL context.
        """
        self.log_warning(fl_ctx, f"ignored unknown task {task_name} from client {client.name}")

    def stop_controller(self, fl_ctx: FLContext):
        """Stop the controller. The base implementation has nothing to clean up.

        Args:
            fl_ctx: the FL context.
        """
        pass