# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import threading
import time
from abc import abstractmethod
from nvflare.apis.event_type import EventType
from nvflare.apis.executor import Executor
from nvflare.apis.fl_constant import FLContextKey, ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.task_controller import Task, TaskController
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.app_common.abstract.learnable import Learnable
from nvflare.app_common.abstract.learnable_persistor import LearnablePersistor
from nvflare.app_common.abstract.shareable_generator import ShareableGenerator
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.app_event_type import AppEventType
from nvflare.app_common.ccwf.common import Constant, ResultType, StatusReport, make_task_name, topic_for_end_workflow
from nvflare.fuel.utils.validation_utils import check_non_empty_str, check_number_range, check_positive_number
from nvflare.security.logging import secure_format_traceback
class _LearnTask:
def __init__(self, task_name: str, task_data: Shareable, fl_ctx: FLContext):
self.task_name = task_name
self.task_data = task_data
self.fl_ctx = fl_ctx
self.abort_signal = Signal()
[docs]class ClientSideController(Executor, TaskController):
def __init__(
self,
task_name_prefix: str,
learn_task_name=AppConstants.TASK_TRAIN,
persistor_id=AppConstants.DEFAULT_PERSISTOR_ID,
shareable_generator_id=AppConstants.DEFAULT_SHAREABLE_GENERATOR_ID,
learn_task_check_interval=Constant.LEARN_TASK_CHECK_INTERVAL,
learn_task_ack_timeout=Constant.LEARN_TASK_ACK_TIMEOUT,
learn_task_abort_timeout=Constant.LEARN_TASK_ABORT_TIMEOUT,
final_result_ack_timeout=Constant.FINAL_RESULT_ACK_TIMEOUT,
allow_busy_task: bool = False,
):
"""
Constructor of a ClientSideController object.
Args:
task_name_prefix: prefix of task names. All CCWF task names are prefixed with this.
learn_task_name: name for the Learning Task (LT)
persistor_id: ID of the persistor component
shareable_generator_id: ID of the shareable generator component
learn_task_check_interval: interval for checking incoming Learning Task (LT)
learn_task_ack_timeout: timeout for sending the LT to other client(s)
final_result_ack_timeout: timeout for sending final result to participating clients
learn_task_abort_timeout: time to wait for the LT to become stopped after aborting it
allow_busy_task: whether a new learn task is allowed when working on current learn task
"""
check_non_empty_str("task_name_prefix", task_name_prefix)
check_positive_number("learn_task_check_interval", learn_task_check_interval)
check_number_range("learn_task_ack_timeout", learn_task_ack_timeout, min_value=1.0)
check_positive_number("learn_task_abort_timeout", learn_task_abort_timeout)
check_number_range("final_result_ack_timeout", final_result_ack_timeout, min_value=1.0)
Executor.__init__(self)
TaskController.__init__(self)
self.task_name_prefix = task_name_prefix
self.start_task_name = make_task_name(task_name_prefix, Constant.BASENAME_START)
self.configure_task_name = make_task_name(task_name_prefix, Constant.BASENAME_CONFIG)
self.do_learn_task_name = make_task_name(task_name_prefix, Constant.BASENAME_LEARN)
self.report_final_result_task_name = make_task_name(task_name_prefix, Constant.BASENAME_REPORT_FINAL_RESULT)
self.learn_task_name = learn_task_name
self.learn_task_abort_timeout = learn_task_abort_timeout
self.learn_task_check_interval = learn_task_check_interval
self.learn_task_ack_timeout = learn_task_ack_timeout
self.final_result_ack_timeout = final_result_ack_timeout
self.allow_busy_task = allow_busy_task
self.persistor_id = persistor_id
self.shareable_generator_id = shareable_generator_id
self.persistor = None
self.shareable_generator = None
self.current_status = StatusReport()
self.last_status_report_time = time.time() # time of last status report to server
self.config = None
self.workflow_id = None
self.finalize_lock = threading.Lock()
self.learn_thread = threading.Thread(target=self._do_learn)
self.learn_thread.daemon = True
self.learn_task = None
self.current_task = None
self.learn_executor = None
self.learn_task_lock = threading.Lock()
self.asked_to_stop = False
self.status_lock = threading.Lock()
self.engine = None
self.me = None
self.is_starting_client = False
self.last_result = None
self.last_round = None
self.best_result = None
self.best_metric = None
self.best_round = 0
self.workflow_done = False
[docs] def get_config_prop(self, name: str, default=None):
"""
Get a specified config property.
Args:
name: name of the property
default: default value to return if the property is not defined.
Returns:
"""
if not self.config:
return default
return self.config.get(name, default)
[docs] def start_run(self, fl_ctx: FLContext):
self.start_controller(fl_ctx)
self.engine = fl_ctx.get_engine()
if not self.engine:
self.system_panic("no engine", fl_ctx)
return
runner = fl_ctx.get_prop(FLContextKey.RUNNER)
if not runner:
self.system_panic("no client runner", fl_ctx)
return
self.me = fl_ctx.get_identity_name()
if self.learn_task_name:
self.learn_executor = runner.find_executor(self.learn_task_name)
if not self.learn_executor:
self.system_panic(f"no executor for task {self.learn_task_name}", fl_ctx)
return
self.persistor = self.engine.get_component(self.persistor_id)
if not isinstance(self.persistor, LearnablePersistor):
self.system_panic(
f"Persistor {self.persistor_id} must be a Persistor instance, but got {type(self.persistor)}",
fl_ctx,
)
return
if self.shareable_generator_id:
self.shareable_generator = self.engine.get_component(self.shareable_generator_id)
if not isinstance(self.shareable_generator, ShareableGenerator):
self.system_panic(
f"Shareable generator {self.shareable_generator_id} must be a Shareable Generator instance, "
f"but got {type(self.shareable_generator)}",
fl_ctx,
)
return
self.initialize(fl_ctx)
if self.learn_task_name:
self.log_info(fl_ctx, "Started learn thread")
self.learn_thread.start()
[docs] def handle_event(self, event_type: str, fl_ctx: FLContext):
if event_type == EventType.START_RUN:
self.start_run(fl_ctx)
elif event_type == EventType.BEFORE_PULL_TASK:
# add my status to fl_ctx
if not self.workflow_id:
return
reports = fl_ctx.get_prop(Constant.STATUS_REPORTS)
if reports:
reports.pop(self.workflow_id, None)
if self.workflow_done:
return
report = self._get_status_report()
if not report:
self.log_debug(fl_ctx, "nothing to report this time")
return
self._add_status_report(report, fl_ctx)
self.last_status_report_time = report.timestamp
elif event_type in [EventType.ABORT_TASK, EventType.END_RUN]:
if not self.asked_to_stop and not self.workflow_done:
self.asked_to_stop = True
self._abort_current_task(fl_ctx)
self.finalize(fl_ctx)
def _add_status_report(self, report: StatusReport, fl_ctx: FLContext):
reports = fl_ctx.get_prop(Constant.STATUS_REPORTS)
if not reports:
reports = {}
# set the prop as public, so it will be sent to the peer in peer_context
fl_ctx.set_prop(Constant.STATUS_REPORTS, reports, sticky=False, private=False)
reports[self.workflow_id] = report.to_dict()
[docs] def initialize(self, fl_ctx: FLContext):
"""Called to initialize the executor.
Args:
fl_ctx: The FL Context
Returns: None
"""
fl_ctx.set_prop(Constant.EXECUTOR, self, private=True, sticky=False)
self.fire_event(Constant.EXECUTOR_INITIALIZED, fl_ctx)
[docs] def finalize(self, fl_ctx: FLContext):
"""Called to finalize the executor.
Args:
fl_ctx: the FL Context
Returns: None
"""
with self.finalize_lock:
if self.workflow_done:
return
fl_ctx.set_prop(Constant.EXECUTOR, self, private=True, sticky=False)
fl_ctx.set_prop(FLContextKey.WORKFLOW, self.workflow_id, private=True, sticky=False)
self.fire_event(Constant.EXECUTOR_FINALIZED, fl_ctx)
self.workflow_done = True
[docs] def process_config(self, fl_ctx: FLContext):
"""This is called to allow the subclass to process config props.
Returns: None
"""
pass
[docs] def topic_for_my_workflow(self, base_topic: str):
return f"{base_topic}.{self.workflow_id}"
[docs] def broadcast_final_result(
self, fl_ctx: FLContext, result_type: str, result: Learnable, metric=None, round_num=None
):
error = None
targets = self.get_config_prop(Constant.RESULT_CLIENTS)
if not targets:
self.log_info(fl_ctx, f"no clients configured to receive final {result_type} result")
else:
try:
num_errors = self._try_broadcast_final_result(fl_ctx, result_type, result, metric, round_num)
if num_errors > 0:
error = ReturnCode.EXECUTION_EXCEPTION
except:
self.log_error(fl_ctx, f"exception broadcast final {result_type} result {secure_format_traceback()}")
error = ReturnCode.EXECUTION_EXCEPTION
if result_type == ResultType.BEST:
action = "finished_broadcast_best_result"
all_done = False
else:
action = "finished_broadcast_last_result"
all_done = True
self.update_status(action=action, error=error, all_done=all_done)
def _try_broadcast_final_result(
self, fl_ctx: FLContext, result_type: str, result: Learnable, metric=None, round_num=None
):
targets = self.get_config_prop(Constant.RESULT_CLIENTS)
assert isinstance(targets, list)
if self.me in targets:
targets.remove(self.me)
if len(targets) == 0:
# no targets to receive the result!
self.log_info(fl_ctx, f"no targets to receive {result_type} result")
return 0
shareable = Shareable()
shareable.set_header(Constant.RESULT_TYPE, result_type)
if metric is not None:
shareable.set_header(Constant.METRIC, metric)
if round_num is not None:
shareable.set_header(Constant.ROUND, round_num)
shareable[Constant.RESULT] = result
self.log_info(
fl_ctx, f"broadcasting {result_type} result with metric {metric} at round {round_num} to clients {targets}"
)
self.update_status(action=f"broadcast_{result_type}_result")
task = Task(
name=self.report_final_result_task_name,
data=shareable,
timeout=int(self.final_result_ack_timeout),
secure=self.is_task_secure(fl_ctx),
)
resp = self.broadcast_and_wait(
task=task,
targets=targets,
min_responses=len(targets),
fl_ctx=fl_ctx,
)
assert isinstance(resp, dict)
num_errors = 0
for t in targets:
reply = resp.get(t)
if not isinstance(reply, Shareable):
self.log_error(
fl_ctx,
f"bad response for {result_type} result from client {t}: "
f"reply must be Shareable but got {type(reply)}",
)
num_errors += 1
continue
rc = reply.get_return_code(ReturnCode.OK)
if rc != ReturnCode.OK:
self.log_error(fl_ctx, f"bad response for {result_type} result from client {t}: {rc}")
num_errors += 1
if num_errors == 0:
self.log_info(fl_ctx, f"successfully broadcast {result_type} result to {targets}")
return num_errors
[docs] def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
if task_name == self.configure_task_name:
self.config = shareable[Constant.CONFIG]
my_wf_id = self.get_config_prop(FLContextKey.WORKFLOW)
if not my_wf_id:
self.log_error(fl_ctx, "missing workflow id in configuration!")
return make_reply(ReturnCode.BAD_REQUEST_DATA)
self.log_info(fl_ctx, f"got my workflow id {my_wf_id}")
self.workflow_id = my_wf_id
reply = self.process_config(fl_ctx)
self.engine.register_aux_message_handler(
topic=topic_for_end_workflow(my_wf_id),
message_handle_func=self._process_end_workflow,
)
learnable = self.persistor.load(fl_ctx)
fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, learnable, private=True, sticky=True)
if not reply:
reply = make_reply(ReturnCode.OK)
return reply
elif task_name == self.start_task_name:
self.is_starting_client = True
learnable = fl_ctx.get_prop(AppConstants.GLOBAL_MODEL)
initial_model = self.shareable_generator.learnable_to_shareable(learnable, fl_ctx)
return self.start_workflow(initial_model, fl_ctx, abort_signal)
elif task_name == self.do_learn_task_name:
return self._process_learn_request(shareable, fl_ctx)
elif task_name == self.report_final_result_task_name:
return self._process_final_result(shareable, fl_ctx)
else:
self.log_error(fl_ctx, f"Could not handle task: {task_name}")
return make_reply(ReturnCode.TASK_UNKNOWN)
[docs] @abstractmethod
def start_workflow(self, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
"""
This is called for the subclass to start the workflow.
This only happens on the starting_client.
Args:
shareable: the initial task data (e.g. initial model weights)
fl_ctx: FL context
abort_signal: abort signal for task execution
Returns:
"""
pass
def _get_status_report(self):
with self.status_lock:
status = self.current_status
must_report = False
if status.error:
must_report = True
elif status.timestamp:
must_report = True
if not must_report:
return None
# do status report
report = copy.copy(status)
return report
def _abort_current_task(self, fl_ctx: FLContext):
current_task = self.learn_task
if not current_task:
return
current_task.abort_signal.trigger(True)
fl_ctx.set_prop(FLContextKey.TASK_NAME, current_task.task_name)
self.fire_event(EventType.ABORT_TASK, fl_ctx)
[docs] def set_learn_task(self, task_data: Shareable, fl_ctx: FLContext) -> bool:
with self.learn_task_lock:
task_data.set_header(AppConstants.NUM_ROUNDS, self.get_config_prop(AppConstants.NUM_ROUNDS))
task = _LearnTask(self.learn_task_name, task_data, fl_ctx)
current_task = self.learn_task
if not current_task:
self.learn_task = task
return True
if not self.allow_busy_task:
return False
# already has a task!
self.log_warning(fl_ctx, "already running a task: aborting it")
self._abort_current_task(fl_ctx)
# monitor until the task is done
start = time.time()
while self.learn_task:
if time.time() - start > self.learn_task_abort_timeout:
self.log_error(
fl_ctx, f"failed to stop the running task after {self.learn_task_abort_timeout} seconds"
)
return False
time.sleep(0.1)
self.learn_task = task
return True
def _do_learn(self):
while not self.asked_to_stop:
if self.learn_task:
t = self.learn_task
assert isinstance(t, _LearnTask)
self.logger.info(f"Got a Learn task {t.task_name}")
try:
self.do_learn_task(t.task_name, t.task_data, t.fl_ctx, t.abort_signal)
except:
self.logger.log(f"exception from do_learn_task: {secure_format_traceback()}")
self.learn_task = None
time.sleep(self.learn_task_check_interval)
[docs] def update_status(self, last_round=None, action=None, error=None, all_done=False):
with self.status_lock:
status = self.current_status
status.timestamp = time.time()
if all_done:
# once marked all_done, always all_done!
status.all_done = True
if error:
status.error = error
if action:
status.action = action
if status.last_round is None:
status.last_round = last_round
elif last_round is not None and last_round > status.last_round:
status.last_round = last_round
status_dict = status.to_dict()
self.logger.info(f"updated my last status: {status_dict}")
[docs] @abstractmethod
def do_learn_task(self, name: str, data: Shareable, fl_ctx: FLContext, abort_signal: Signal):
"""This is called to do a Learn Task.
Subclass must implement this method.
Args:
name: task name
data: task data
fl_ctx: FL context of the task
abort_signal: abort signal for the task execution
Returns:
"""
pass
def _process_final_result(self, request: Shareable, fl_ctx: FLContext) -> Shareable:
peer_ctx = fl_ctx.get_peer_context()
assert isinstance(peer_ctx, FLContext)
client_name = peer_ctx.get_identity_name()
result = request.get(Constant.RESULT)
metric = request.get_header(Constant.METRIC)
round_num = request.get_header(Constant.ROUND)
result_type = request.get_header(Constant.RESULT_TYPE)
if result_type not in [ResultType.BEST, ResultType.LAST]:
self.log_error(fl_ctx, f"Bad request from client {client_name}: invalid result type {result_type}")
return make_reply(ReturnCode.BAD_REQUEST_DATA)
if not result:
self.log_error(fl_ctx, f"Bad request from client {client_name}: no result")
return make_reply(ReturnCode.BAD_REQUEST_DATA)
if not isinstance(result, Learnable):
self.log_error(fl_ctx, f"Bad result from client {client_name}: expect Learnable but got {type(result)}")
return make_reply(ReturnCode.BAD_REQUEST_DATA)
self.log_info(fl_ctx, f"Got {result_type} from client {client_name} with metric {metric} at round {round_num}")
fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, result, private=True, sticky=True)
if result_type == ResultType.BEST:
fl_ctx.set_prop(Constant.ROUND, round_num, private=True, sticky=False)
fl_ctx.set_prop(Constant.CLIENT, client_name, private=True, sticky=False)
fl_ctx.set_prop(AppConstants.VALIDATION_RESULT, metric, private=True, sticky=False)
self.fire_event(AppEventType.GLOBAL_BEST_MODEL_AVAILABLE, fl_ctx)
else:
# last model
assert isinstance(self.persistor, LearnablePersistor)
self.persistor.save(result, fl_ctx)
return make_reply(ReturnCode.OK)
def _process_end_workflow(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable:
self.log_info(fl_ctx, f"ending workflow {self.get_config_prop(FLContextKey.WORKFLOW)}")
self.asked_to_stop = True
self._abort_current_task(fl_ctx)
self.finalize(fl_ctx)
return make_reply(ReturnCode.OK)
def _process_learn_request(self, request: Shareable, fl_ctx: FLContext) -> Shareable:
try:
return self._try_process_learn_request(request, fl_ctx)
except Exception as ex:
self.log_exception(fl_ctx, f"exception: {ex}")
self.update_status(action="process_learn_request", error=ReturnCode.EXECUTION_EXCEPTION)
return make_reply(ReturnCode.EXECUTION_EXCEPTION)
def _try_process_learn_request(self, request: Shareable, fl_ctx: FLContext) -> Shareable:
peer_ctx = fl_ctx.get_peer_context()
assert isinstance(peer_ctx, FLContext)
sender = peer_ctx.get_identity_name()
# process request from prev client
self.log_info(fl_ctx, f"Got Learn request from {sender}")
if self.learn_task and not self.allow_busy_task:
# should never happen!
self.log_error(fl_ctx, f"got Learn request from {sender} while I'm still busy!")
self.update_status(action="process_learn_request", error=ReturnCode.EXECUTION_EXCEPTION)
return make_reply(ReturnCode.EXECUTION_EXCEPTION)
self.log_info(fl_ctx, f"accepted learn request from {sender}")
self.set_learn_task(task_data=request, fl_ctx=fl_ctx)
return make_reply(ReturnCode.OK)
[docs] def send_learn_task(self, targets: list, request: Shareable, fl_ctx: FLContext) -> bool:
self.log_info(fl_ctx, f"sending learn task to clients {targets}")
request.set_header(AppConstants.NUM_ROUNDS, self.get_config_prop(AppConstants.NUM_ROUNDS))
task = Task(
name=self.do_learn_task_name,
data=request,
timeout=int(self.learn_task_ack_timeout),
secure=self.is_task_secure(fl_ctx),
)
resp = self.broadcast_and_wait(
task=task,
targets=targets,
min_responses=len(targets),
fl_ctx=fl_ctx,
)
assert isinstance(resp, dict)
for t in targets:
reply = resp.get(t)
if not isinstance(reply, Shareable):
self.log_error(fl_ctx, f"failed to send learn request to client {t}")
self.log_error(fl_ctx, f"reply must be Shareable but got {type(reply)}")
self.update_status(action="send_learn_task", error=ReturnCode.EXECUTION_EXCEPTION)
return False
rc = reply.get_return_code(ReturnCode.OK)
if rc != ReturnCode.OK:
self.log_error(fl_ctx, f"bad response for learn request from client {t}: {rc}")
self.update_status(action="send_learn_task", error=rc)
return False
return True
[docs] def execute_learn_task(self, data: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
current_round = data.get_header(AppConstants.CURRENT_ROUND)
self.log_info(fl_ctx, f"started training round {current_round}")
try:
result = self.learn_executor.execute(self.learn_task_name, data, fl_ctx, abort_signal)
except:
self.log_exception(fl_ctx, f"trainer exception: {secure_format_traceback()}")
result = make_reply(ReturnCode.EXECUTION_EXCEPTION)
self.log_info(fl_ctx, f"finished training round {current_round}")
# make sure to include cookies in result
cookie_jar = data.get_cookie_jar()
result.set_cookie_jar(cookie_jar)
result.set_header(AppConstants.CURRENT_ROUND, current_round)
result.add_cookie(AppConstants.CONTRIBUTION_ROUND, current_round) # to make model selector happy
return result
[docs] def record_last_result(
self,
fl_ctx: FLContext,
round_num: int,
result: Learnable,
):
if not isinstance(result, Learnable):
self.log_error(fl_ctx, f"result must be Learnable but got {type(result)}")
return
self.last_result = result
self.last_round = round_num
fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, result, private=True, sticky=True)
if self.persistor:
self.log_info(fl_ctx, f"Saving result of round {round_num}")
self.persistor.save(result, fl_ctx)
[docs] def is_task_secure(self, fl_ctx: FLContext) -> bool:
"""
Determine whether the task should be secure. A secure task requires encrypted communication between the peers.
The task is secure only when the training is in secure mode AND private_p2p is set to True.
"""
private_p2p = self.get_config_prop(Constant.PRIVATE_P2P)
secure_train = fl_ctx.get_prop(FLContextKey.SECURE_MODE, False)
return private_p2p and secure_train