# Source code for nvflare.private.fed.client.client_runner

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import time

from nvflare.apis.event_type import EventType
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_constant import FLContextKey, ReservedKey, ReservedTopic, ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.fl_exception import UnsafeJobError
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.apis.utils.fl_context_utils import add_job_audit_event
from nvflare.fuel.f3.cellnet.fqcn import FQCN
from nvflare.private.defs import SpecialTaskName, TaskConstant
from nvflare.private.fed.client.client_engine_executor_spec import ClientEngineExecutorSpec, TaskAssignment
from nvflare.private.privacy_manager import Scope
from nvflare.security.logging import secure_format_exception
from nvflare.widgets.info_collector import GroupInfoCollector, InfoCollector

_TASK_CHECK_RESULT_OK = 0
_TASK_CHECK_RESULT_TRY_AGAIN = 1
_TASK_CHECK_RESULT_TASK_GONE = 2


class ClientRunnerConfig(object):
    """Bundle of settings handed to a ClientRunner.

    Holds the task-name => executor table, the per-task data/result filter
    tables, the event handlers, and any extra components registered by id.
    """

    def __init__(
        self,
        task_table: dict,  # task_name => Executor
        task_data_filters: dict,  # task_name => list of filters
        task_result_filters: dict,  # task_name => list of filters
        handlers=None,  # list of event handlers
        components=None,  # dict of extra python objects: id => object
        default_task_fetch_interval: float = 0.5,
    ):
        """Store the runner configuration.

        Args:
            task_table: dict of task_name => Executor
            task_data_filters: dict of task_name => list of data filters
            task_result_filters: dict of task_name => list of result filters
            handlers: list of event handlers; a falsy value is replaced by a new empty list
            components: dict of id => extra python object; a falsy value is replaced by a new empty dict
            default_task_fetch_interval: task fetch interval used until the server supplies
                the actual value. Defaults to 0.5.
        """
        self.task_table = task_table
        self.task_data_filters = task_data_filters
        self.task_result_filters = task_result_filters
        self.default_task_fetch_interval = default_task_fetch_interval
        # Falsy inputs (None, empty containers) are swapped for fresh containers.
        self.handlers = handlers if handlers else []
        self.components = components if components else {}

    def add_component(self, comp_id: str, component: object):
        """Register an extra component under a unique id.

        An FLComponent is additionally appended to the handler list.

        Args:
            comp_id: id of the component; must be a str not yet registered
            component: the object to register

        Raises:
            TypeError: if comp_id is not a str
            ValueError: if comp_id is already registered
        """
        id_is_str = isinstance(comp_id, str)
        if not id_is_str:
            raise TypeError(f"component id must be str but got {type(comp_id)}")

        registry = self.components
        if comp_id in registry:
            raise ValueError(f"duplicate component id {comp_id}")

        registry[comp_id] = component
        if isinstance(component, FLComponent):
            self.handlers.append(component)
class ClientRunner(FLComponent):
    """Client-side job loop: fetches tasks from the server, runs them, and sends results back.

    For each task the runner applies the configured task-data filters, invokes the
    executor registered for the task name, applies the task-result filters and sends
    the result to the server. It also handles abort requests for the current task or
    the whole run, and fires the corresponding events.
    """

    def __init__(
        self,
        config: ClientRunnerConfig,
        job_id,
        engine: ClientEngineExecutorSpec,
    ):
        """Initializes the ClientRunner.

        Args:
            config: ClientRunnerConfig
            job_id: job id
            engine: ClientEngine object
        """
        FLComponent.__init__(self)
        self.task_table = config.task_table
        self.task_data_filters = config.task_data_filters
        self.task_result_filters = config.task_result_filters
        self.default_task_fetch_interval = config.default_task_fetch_interval

        self.job_id = job_id
        self.engine = engine
        self.run_abort_signal = Signal()
        # task_abort_signal / current_executor / current_task describe the task being
        # executed; they are None while no task is running and are updated under task_lock.
        self.task_abort_signal = None
        self.current_executor = None
        self.current_task = None
        self.asked_to_stop = False
        self.task_lock = threading.Lock()
        self.end_run_fired = False
        self.end_run_lock = threading.Lock()

        self.task_check_timeout = 5.0
        self.task_check_interval = 5.0
        self._register_aux_message_handler(engine)

    def _register_aux_message_handler(self, engine):
        """Registers the aux message handlers for the END_RUN and ABORT_ASK topics."""
        engine.register_aux_message_handler(topic=ReservedTopic.END_RUN, message_handle_func=self._handle_end_run)
        engine.register_aux_message_handler(topic=ReservedTopic.ABORT_ASK, message_handle_func=self._handle_abort_task)

    @staticmethod
    def _reply_and_audit(reply: Shareable, ref, msg, fl_ctx: FLContext) -> Shareable:
        """Adds a job audit event and stamps its id into the reply header.

        Args:
            reply: the Shareable to be returned to the server
            ref: audit event id this event refers to
            msg: audit message
            fl_ctx: FLContext

        Returns:
            The same reply object, with the AUDIT_EVENT_ID header set.
        """
        audit_event_id = add_job_audit_event(fl_ctx=fl_ctx, ref=ref, msg=msg)
        reply.set_header(ReservedKey.AUDIT_EVENT_ID, audit_event_id)
        return reply

    def _error_reply(self, return_code, ref, fl_ctx: FLContext) -> Shareable:
        """Builds an audited reply that carries only the given return code."""
        return self._reply_and_audit(
            reply=make_reply(return_code),
            ref=ref,
            fl_ctx=fl_ctx,
            msg=f"submit result: {return_code}",
        )

    def _process_task(self, task: TaskAssignment, fl_ctx: FLContext) -> Shareable:
        """Runs one assigned task through filters and executor.

        Steps: validate the assignment, apply task data filters, execute the task,
        apply task result filters. Events are fired around each step.

        Args:
            task: the task assignment received from the server
            fl_ctx: FLContext

        Returns:
            A Shareable to be sent to the server: the (filtered) executor result on
            success, otherwise a reply carrying the ReturnCode describing the failure.
        """
        if not isinstance(task.data, Shareable):
            self.log_error(
                fl_ctx, "got invalid task data in assignment: expect Shareable, but got {}".format(type(task.data))
            )
            return make_reply(ReturnCode.BAD_TASK_DATA)

        fl_ctx.set_prop(FLContextKey.TASK_DATA, value=task.data, private=True, sticky=False)
        fl_ctx.set_prop(FLContextKey.TASK_NAME, value=task.name, private=True, sticky=False)
        fl_ctx.set_prop(FLContextKey.TASK_ID, value=task.task_id, private=True, sticky=False)

        server_audit_event_id = task.data.get_header(ReservedKey.AUDIT_EVENT_ID, "")
        add_job_audit_event(fl_ctx=fl_ctx, ref=server_audit_event_id, msg="received task from server")

        peer_ctx = fl_ctx.get_peer_context()
        if not peer_ctx:
            self.log_error(fl_ctx, "missing peer context in Server task assignment")
            return self._error_reply(ReturnCode.MISSING_PEER_CONTEXT, server_audit_event_id, fl_ctx)

        if not isinstance(peer_ctx, FLContext):
            self.log_error(
                fl_ctx,
                "bad peer context in Server task assignment: expects FLContext but got {}".format(type(peer_ctx)),
            )
            return self._error_reply(ReturnCode.BAD_PEER_CONTEXT, server_audit_event_id, fl_ctx)

        task.data.set_peer_props(peer_ctx.get_all_public_props())
        peer_job_id = peer_ctx.get_job_id()
        if peer_job_id != self.job_id:
            self.log_error(fl_ctx, "bad task assignment: not for the same job_id")
            return self._error_reply(ReturnCode.RUN_MISMATCH, server_audit_event_id, fl_ctx)

        executor = self.task_table.get(task.name)
        if not executor:
            self.log_error(fl_ctx, f"bad task assignment: no executor available for task {task.name}")
            return self._error_reply(ReturnCode.TASK_UNKNOWN, server_audit_event_id, fl_ctx)

        executor_name = executor.__class__.__name__

        self.log_debug(fl_ctx, "firing event EventType.BEFORE_TASK_DATA_FILTER")
        self.fire_event(EventType.BEFORE_TASK_DATA_FILTER, fl_ctx)

        # first apply privacy-defined filters
        scope_object = fl_ctx.get_prop(FLContextKey.SCOPE_OBJECT)
        filter_list = []
        if scope_object:
            assert isinstance(scope_object, Scope)
            if scope_object.task_data_filters:
                filter_list.extend(scope_object.task_data_filters)

        task_filter_list = self.task_data_filters.get(task.name)
        if task_filter_list:
            filter_list.extend(task_filter_list)

        if filter_list:
            task_data = task.data
            for f in filter_list:
                filter_name = f.__class__.__name__
                try:
                    task_data = f.process(task_data, fl_ctx)
                except UnsafeJobError:
                    self.log_exception(fl_ctx, f"UnsafeJobError from Task Data Filter {filter_name}")
                    executor.unsafe = True
                    fl_ctx.set_job_is_unsafe()
                    return self._error_reply(ReturnCode.UNSAFE_JOB, server_audit_event_id, fl_ctx)
                except Exception as e:
                    self.log_exception(
                        fl_ctx, f"Processing error from Task Data Filter {filter_name}: {secure_format_exception(e)}"
                    )
                    return self._error_reply(ReturnCode.TASK_DATA_FILTER_ERROR, server_audit_event_id, fl_ctx)

            if not isinstance(task_data, Shareable):
                self.log_error(
                    fl_ctx, "task data was converted to wrong type: expect Shareable but got {}".format(type(task_data))
                )
                return self._error_reply(ReturnCode.TASK_DATA_FILTER_ERROR, server_audit_event_id, fl_ctx)

            task.data = task_data

        self.log_debug(fl_ctx, "firing event EventType.AFTER_TASK_DATA_FILTER")
        fl_ctx.set_prop(FLContextKey.TASK_DATA, value=task.data, private=True, sticky=False)
        self.fire_event(EventType.AFTER_TASK_DATA_FILTER, fl_ctx)

        self.log_debug(fl_ctx, "firing event EventType.BEFORE_TASK_EXECUTION")
        fl_ctx.set_prop(FLContextKey.TASK_DATA, value=task.data, private=True, sticky=False)
        self.fire_event(EventType.BEFORE_TASK_EXECUTION, fl_ctx)
        try:
            self.log_info(fl_ctx, f"invoking task executor {executor_name}")
            add_job_audit_event(fl_ctx=fl_ctx, msg=f"invoked executor {executor_name}")

            # Keep a local reference to the new signal: _abort_current_task() may reset
            # self.task_abort_signal to None right after the lock is released, and the
            # executor must still receive the (then already triggered) Signal, not None.
            abort_signal = Signal()
            with self.task_lock:
                self.task_abort_signal = abort_signal
                self.current_executor = executor
                self.current_task = task

            try:
                reply = executor.execute(task.name, task.data, fl_ctx, abort_signal)
            finally:
                with self.task_lock:
                    # _abort_current_task() clears the attribute when it aborts the task
                    task_aborted = self.task_abort_signal is None
                    self.task_abort_signal = None
                    self.current_task = None
                    self.current_executor = None

            if task_aborted:
                return self._error_reply(ReturnCode.TASK_ABORTED, server_audit_event_id, fl_ctx)

            if not isinstance(reply, Shareable):
                self.log_error(
                    fl_ctx, f"bad result generated by executor {executor_name}: must be Shareable but got {type(reply)}"
                )
                return self._error_reply(ReturnCode.EXECUTION_RESULT_ERROR, server_audit_event_id, fl_ctx)

        except RuntimeError as e:
            self.log_exception(
                fl_ctx, f"RuntimeError from executor {executor_name}: {secure_format_exception(e)}: Aborting the job!"
            )
            # a RuntimeError from the executor stops the whole run loop
            self.asked_to_stop = True
            return self._error_reply(ReturnCode.EXECUTION_RESULT_ERROR, server_audit_event_id, fl_ctx)
        except UnsafeJobError:
            self.log_exception(fl_ctx, f"UnsafeJobError from executor {executor_name}")
            executor.unsafe = True
            fl_ctx.set_job_is_unsafe()
            return self._error_reply(ReturnCode.UNSAFE_JOB, server_audit_event_id, fl_ctx)
        except Exception as e:
            self.log_exception(fl_ctx, f"Processing error from executor {executor_name}: {secure_format_exception(e)}")
            return self._error_reply(ReturnCode.EXECUTION_EXCEPTION, server_audit_event_id, fl_ctx)

        fl_ctx.set_prop(FLContextKey.TASK_RESULT, value=reply, private=True, sticky=False)

        self.log_debug(fl_ctx, "firing event EventType.AFTER_TASK_EXECUTION")
        self.fire_event(EventType.AFTER_TASK_EXECUTION, fl_ctx)

        self.log_debug(fl_ctx, "firing event EventType.BEFORE_TASK_RESULT_FILTER")
        self.fire_event(EventType.BEFORE_TASK_RESULT_FILTER, fl_ctx)

        # privacy-defined result filters first, then the task-specific ones
        filter_list = []
        if scope_object and scope_object.task_result_filters:
            filter_list.extend(scope_object.task_result_filters)

        task_filter_list = self.task_result_filters.get(task.name)
        if task_filter_list:
            filter_list.extend(task_filter_list)

        if filter_list:
            for f in filter_list:
                filter_name = f.__class__.__name__
                try:
                    reply = f.process(reply, fl_ctx)
                except UnsafeJobError:
                    self.log_exception(fl_ctx, f"UnsafeJobError from Task Result Filter {filter_name}")
                    executor.unsafe = True
                    fl_ctx.set_job_is_unsafe()
                    return self._error_reply(ReturnCode.UNSAFE_JOB, server_audit_event_id, fl_ctx)
                except Exception as e:
                    self.log_exception(
                        fl_ctx, f"Processing error in Task Result Filter {filter_name}: {secure_format_exception(e)}"
                    )
                    return self._error_reply(ReturnCode.TASK_RESULT_FILTER_ERROR, server_audit_event_id, fl_ctx)

            if not isinstance(reply, Shareable):
                self.log_error(
                    fl_ctx, "task result was converted to wrong type: expect Shareable but got {}".format(type(reply))
                )
                return self._error_reply(ReturnCode.TASK_RESULT_FILTER_ERROR, server_audit_event_id, fl_ctx)

        fl_ctx.set_prop(FLContextKey.TASK_RESULT, value=reply, private=True, sticky=False)

        self.log_debug(fl_ctx, "firing event EventType.AFTER_TASK_RESULT_FILTER")
        self.fire_event(EventType.AFTER_TASK_RESULT_FILTER, fl_ctx)
        self.log_info(fl_ctx, "finished processing task")

        if not isinstance(reply, Shareable):
            self.log_error(
                fl_ctx, "task processing error: expects result to be Shareable, but got {}".format(type(reply))
            )
            return self._error_reply(ReturnCode.EXECUTION_RESULT_ERROR, server_audit_event_id, fl_ctx)

        return self._reply_and_audit(reply=reply, ref=server_audit_event_id, fl_ctx=fl_ctx, msg="submit result OK")

    def _check_stop_conditions(self, fl_ctx: FLContext) -> bool:
        """Returns True if the run loop must stop: job marked unsafe, or run abort signal triggered."""
        if fl_ctx.is_job_unsafe():
            self.log_info(fl_ctx, "stopped unsafe job!")
            return True

        if self.run_abort_signal.triggered:
            self.log_info(fl_ctx, "run abort signal received")
            return True

        return False

    def _try_run(self):
        """Starts the heartbeat thread, then fetches and runs tasks until asked to stop."""
        heartbeat_thread = threading.Thread(target=self.send_job_heartbeat, args=[], daemon=True)
        heartbeat_thread.start()

        while not self.asked_to_stop:
            with self.engine.new_context() as fl_ctx:
                if self._check_stop_conditions(fl_ctx):
                    break

                task_fetch_interval, _ = self.fetch_and_run_one_task(fl_ctx)

                # the task just processed may have made the job unsafe or aborted the run
                if self._check_stop_conditions(fl_ctx):
                    break

                time.sleep(task_fetch_interval)

    def send_job_heartbeat(self, interval=30.0):
        """Sends JOB_HEART_BEAT aux requests to the root server until asked to stop.

        Args:
            interval: nominal number of seconds between two heartbeats. The wait is done
                as int(interval / 2) sleeps of 2 seconds each, checking asked_to_stop
                after every sleep.
        """
        wait_times = int(interval / 2)
        request = Shareable()
        while not self.asked_to_stop:
            with self.engine.new_context() as fl_ctx:
                self.engine.send_aux_request(
                    targets=[FQCN.ROOT_SERVER],
                    topic=ReservedTopic.JOB_HEART_BEAT,
                    request=request,
                    timeout=0,
                    fl_ctx=fl_ctx,
                    optional=True,
                )

                for _ in range(wait_times):
                    time.sleep(2)
                    if self.asked_to_stop:
                        break

    def fetch_and_run_one_task(self, fl_ctx) -> (float, bool):
        """Fetches and runs a task.

        Args:
            fl_ctx: FLContext

        Returns:
            A tuple of (task_fetch_interval, task_processed).

        Raises:
            TypeError: if the task data or the task reply is not a Shareable
        """
        default_task_fetch_interval = self.default_task_fetch_interval
        self.log_debug(fl_ctx, "fetching task from server ...")
        task = self.engine.get_task_assignment(fl_ctx)

        if not task:
            self.log_debug(fl_ctx, "no task received - will try in {} secs".format(default_task_fetch_interval))
            return default_task_fetch_interval, False

        if task.name == SpecialTaskName.END_RUN:
            self.log_info(fl_ctx, "server asked to end the run")
            self.asked_to_stop = True
            return default_task_fetch_interval, False

        if task.name == SpecialTaskName.TRY_AGAIN:
            task_data = task.data
            task_fetch_interval = default_task_fetch_interval
            if task_data and isinstance(task_data, Shareable):
                task_fetch_interval = task_data.get_header(TaskConstant.WAIT_TIME, task_fetch_interval)
            self.log_debug(fl_ctx, "server asked to try again - will try in {} secs".format(task_fetch_interval))
            return task_fetch_interval, False

        # a regular task: the special task names have all returned above
        self.log_info(fl_ctx, "got task assignment: name={}, id={}".format(task.name, task.task_id))

        task_data = task.data
        if not isinstance(task_data, Shareable):
            raise TypeError("task_data must be Shareable, but got {}".format(type(task_data)))

        task_fetch_interval = task_data.get_header(TaskConstant.WAIT_TIME, default_task_fetch_interval)

        task_reply = self._process_task(task, fl_ctx)

        if not isinstance(task_reply, Shareable):
            raise TypeError("task_reply must be Shareable, but got {}".format(type(task_reply)))

        self.log_debug(fl_ctx, "firing event EventType.BEFORE_SEND_TASK_RESULT")
        self.fire_event(EventType.BEFORE_SEND_TASK_RESULT, fl_ctx)

        # set the cookie in the reply!
        cookie_jar = task_data.get_cookie_jar()
        if cookie_jar:
            task_reply.set_cookie_jar(cookie_jar)

        self._send_task_result(task_reply, task.task_id, fl_ctx)
        self.log_debug(fl_ctx, "firing event EventType.AFTER_SEND_TASK_RESULT")
        self.fire_event(EventType.AFTER_SEND_TASK_RESULT, fl_ctx)

        return task_fetch_interval, True

    def _send_task_result(self, result: Shareable, task_id: str, fl_ctx: FLContext) -> bool:
        """Keeps trying to send the task result until it is sent, the task is gone, or the run stops.

        Args:
            result: the result to send
            task_id: id of the task the result belongs to
            fl_ctx: FLContext

        Returns:
            True if the result was sent; False if the run was asked to stop or the
            task no longer exists on the server.
        """
        try_count = 1
        while True:
            self.log_info(fl_ctx, f"try #{try_count}: sending task result to server")

            if self.asked_to_stop:
                self.log_info(fl_ctx, "job aborted: stopped trying to send result")
                return False

            try_count += 1
            rc = self._try_send_result_once(result, task_id, fl_ctx)

            if rc == _TASK_CHECK_RESULT_OK:
                return True
            elif rc == _TASK_CHECK_RESULT_TASK_GONE:
                return False
            else:
                # retry
                time.sleep(self.task_check_interval)

    def _try_send_result_once(self, result: Shareable, task_id: str, fl_ctx: FLContext):
        """Waits until the task check succeeds, then makes one attempt to send the result.

        Args:
            result: the result to send
            task_id: id of the task the result belongs to
            fl_ctx: FLContext

        Returns:
            _TASK_CHECK_RESULT_OK if the result was sent, _TASK_CHECK_RESULT_TASK_GONE if
            the run was asked to stop or the server no longer knows the task, and
            _TASK_CHECK_RESULT_TRY_AGAIN if sending failed.
        """
        # wait until server is ready to receive
        while True:
            if self.asked_to_stop:
                return _TASK_CHECK_RESULT_TASK_GONE

            rc = self._check_task_once(task_id, fl_ctx)
            if rc == _TASK_CHECK_RESULT_OK:
                break
            elif rc == _TASK_CHECK_RESULT_TASK_GONE:
                return rc
            else:
                # try again
                time.sleep(self.task_check_interval)

        # try to send the result
        self.log_info(fl_ctx, "start to send task result to server")
        reply_sent = self.engine.send_task_result(result, fl_ctx)
        if reply_sent:
            self.log_info(fl_ctx, "task result sent to server")
            return _TASK_CHECK_RESULT_OK
        else:
            self.log_error(fl_ctx, "failed to send task result to server - will try again")
            return _TASK_CHECK_RESULT_TRY_AGAIN

    def _check_task_once(self, task_id: str, fl_ctx: FLContext) -> int:
        """This method checks whether the server is still waiting for the specified task.

        The real reason for this method is to fight against unstable network connections.
        We try to make sure that when we send task result to the server, the connection is available.
        If the task check succeeds, then the network connection is likely to be available.
        Otherwise, we keep retrying until task check succeeds or the server tells us that the task is gone (timed out).

        Args:
            task_id: id of the task to check
            fl_ctx: FLContext

        Returns:
            _TASK_CHECK_RESULT_OK if the server replied OK (or with an unexpected return code),
            _TASK_CHECK_RESULT_TASK_GONE if the server replied TASK_UNKNOWN,
            _TASK_CHECK_RESULT_TRY_AGAIN for a bad or missing reply, COMMUNICATION_ERROR
            or SERVER_NOT_READY.
        """
        self.log_info(fl_ctx, "checking task ...")
        task_check_req = Shareable()
        task_check_req.set_header(ReservedKey.TASK_ID, task_id)
        resp = self.engine.send_aux_request(
            targets=[FQCN.ROOT_SERVER],
            topic=ReservedTopic.TASK_CHECK,
            request=task_check_req,
            timeout=self.task_check_timeout,
            fl_ctx=fl_ctx,
            optional=True,
        )
        if resp and isinstance(resp, dict):
            reply = resp.get(FQCN.ROOT_SERVER)
            if not isinstance(reply, Shareable):
                self.log_error(fl_ctx, f"bad task_check reply from server: expect Shareable but got {type(reply)}")
                return _TASK_CHECK_RESULT_TRY_AGAIN

            rc = reply.get_return_code()
            if rc == ReturnCode.OK:
                return _TASK_CHECK_RESULT_OK
            elif rc == ReturnCode.COMMUNICATION_ERROR:
                self.log_error(fl_ctx, f"failed task_check: {rc}")
                return _TASK_CHECK_RESULT_TRY_AGAIN
            elif rc == ReturnCode.SERVER_NOT_READY:
                self.log_error(fl_ctx, f"server rejected task_check: {rc}")
                return _TASK_CHECK_RESULT_TRY_AGAIN
            elif rc == ReturnCode.TASK_UNKNOWN:
                self.log_error(fl_ctx, f"task no longer exists on server: {rc}")
                return _TASK_CHECK_RESULT_TASK_GONE
            else:
                # this should never happen
                self.log_error(fl_ctx, f"programming error: received {rc} from server")
                return _TASK_CHECK_RESULT_OK  # try to push the result regardless
        else:
            self.log_error(fl_ctx, f"bad task_check reply from server: invalid resp {type(resp)}")
            return _TASK_CHECK_RESULT_TRY_AGAIN

    def run(self, app_root, args):
        """Runs the job: initializes the run, loops over tasks, then fires the end-run events.

        Exceptions raised by the task loop are logged, not propagated.

        Args:
            app_root: stored in the FLContext as APP_ROOT
            args: stored in the FLContext as ARGS
        """
        self.init_run(app_root, args)

        try:
            self._try_run()
        except Exception as e:
            with self.engine.new_context() as fl_ctx:
                self.log_exception(fl_ctx, f"processing error in RUN execution: {secure_format_exception(e)}")
        finally:
            # in case any task is still running, abort it
            self._abort_current_task()
            self.end_run_events_sequence("run method")

    def init_run(self, app_root, args):
        """Fires the start-of-run events and stores the run properties in the FLContext.

        Args:
            app_root: stored as sticky prop APP_ROOT
            args: stored as sticky prop ARGS
        """
        with self.engine.new_context() as fl_ctx:
            self.fire_event(EventType.ABOUT_TO_START_RUN, fl_ctx)
            fl_ctx.set_prop(FLContextKey.APP_ROOT, app_root, sticky=True)
            fl_ctx.set_prop(FLContextKey.ARGS, args, sticky=True)
            fl_ctx.set_prop(ReservedKey.RUN_ABORT_SIGNAL, self.run_abort_signal, private=True, sticky=True)
            self.log_debug(fl_ctx, "firing event EventType.START_RUN")
            self.fire_event(EventType.START_RUN, fl_ctx)
            self.log_info(fl_ctx, "client runner started")
            with self.end_run_lock:
                self.end_run_fired = False

    def _abort_current_task(self):
        """Triggers the abort signal of the running task, if any, and fires ABORT_TASK.

        Does nothing when no task abort signal is set. Must not be called while
        holding task_lock, which is a non-reentrant lock acquired here.
        """
        with self.task_lock:
            task_abort_signal = self.task_abort_signal
            if task_abort_signal:
                # set task_abort_signal to None to prevent triggering again
                self.task_abort_signal = None

                task_name = ""
                task_id = ""
                task = self.current_task
                if task:
                    task_name = task.name
                    task_id = task.task_id

                with self.engine.new_context() as fl_ctx:
                    fl_ctx.set_prop(FLContextKey.TASK_NAME, value=task_name, private=True, sticky=False)
                    fl_ctx.set_prop(FLContextKey.TASK_ID, value=task_id, private=True, sticky=False)
                    task_abort_signal.trigger(True)
                    self.log_info(fl_ctx, "triggered task_abort_signal to stop task '{}'".format(task_name))
                    self.fire_event(EventType.ABORT_TASK, fl_ctx)
                    self.log_info(fl_ctx, "fired ABORT_TASK event to abort current task {}".format(task_name))

    def abort_task(self, task_names=None):
        """Aborts the current task if it matches the request.

        Args:
            task_names: names of the tasks to abort; when empty or None the current
                task is aborted whatever its name
        """
        has_task_to_abort = False
        with self.engine.new_context() as fl_ctx:
            with self.task_lock:
                if self.current_task:
                    name = self.current_task.name
                    if not task_names or name in task_names:
                        has_task_to_abort = True
                    else:
                        self.log_info(
                            fl_ctx, "Ignored abort_task request since current task '{}' is not target".format(name)
                        )
                else:
                    self.log_info(fl_ctx, "Ignored abort_task request since there is no current task")

        # task_lock is released here: _abort_current_task acquires it itself
        if has_task_to_abort:
            self._abort_current_task()

    def end_run_events_sequence(self, requester):
        """Fires ABOUT_TO_END_RUN and END_RUN, at most once per run.

        Args:
            requester: description of who asked for the end of run, used for logging
        """
        with self.engine.new_context() as fl_ctx:
            self.log_info(fl_ctx, f"{requester} requests end run events sequence")
            with self.end_run_lock:
                if not self.end_run_fired:
                    self.fire_event(EventType.ABOUT_TO_END_RUN, fl_ctx)
                    self.log_info(fl_ctx, "ABOUT_TO_END_RUN fired")
                    self.fire_event(EventType.END_RUN, fl_ctx)
                    self.log_info(fl_ctx, "END_RUN fired")
                    self.end_run_fired = True

    def abort(self):
        """To Abort the current run.

        Aborts the current task, triggers the run abort signal, asks the run loop
        to stop and fires the end-run events.

        Returns: N/A

        """
        with self.engine.new_context() as fl_ctx:
            self.log_info(fl_ctx, "ABORT (RUN) command received")
        self._abort_current_task()
        self.run_abort_signal.trigger("ABORT (RUN) triggered")
        self.asked_to_stop = True
        self.end_run_events_sequence("ABORT (RUN)")

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        """Handles stats collection and fatal error events.

        - EVENT_TYPE_GET_STATS: reports job id and current task name to the collector
        - FATAL_TASK_ERROR: aborts the current task
        - FATAL_SYSTEM_ERROR: aborts the run

        Args:
            event_type: type of the event
            fl_ctx: FLContext

        Raises:
            TypeError: if the stats collector in fl_ctx is not a GroupInfoCollector
        """
        if event_type == InfoCollector.EVENT_TYPE_GET_STATS:
            collector = fl_ctx.get_prop(InfoCollector.CTX_KEY_STATS_COLLECTOR)
            if collector:
                if not isinstance(collector, GroupInfoCollector):
                    raise TypeError("collector must be GroupInfoCollector, but got {}".format(type(collector)))

                # read the attribute once: the task thread may reset it to None at any time
                task = self.current_task
                current_task_name = task.name if task else "None"

                collector.set_info(
                    group_name="ClientRunner",
                    info={"job_id": self.job_id, "current_task_name": current_task_name, "status": "started"},
                )
        elif event_type == EventType.FATAL_TASK_ERROR:
            reason = fl_ctx.get_prop(key=FLContextKey.EVENT_DATA, default="")
            self.log_error(fl_ctx, "Aborting current task due to FATAL_TASK_ERROR received: {}".format(reason))
            self._abort_current_task()
        elif event_type == EventType.FATAL_SYSTEM_ERROR:
            reason = fl_ctx.get_prop(key=FLContextKey.EVENT_DATA, default="")
            self.log_error(fl_ctx, "Aborting current RUN due to FATAL_SYSTEM_ERROR received: {}".format(reason))
            self.abort()

    def _handle_end_run(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable:
        """Aux handler for END_RUN: aborts the run and replies OK."""
        self.log_info(fl_ctx, "received aux request from Server to end current RUN")
        self.abort()
        return make_reply(ReturnCode.OK)

    def _handle_abort_task(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable:
        """Aux handler for ABORT_ASK: aborts the current task if named in the request, replies OK."""
        self.log_info(fl_ctx, "received aux request from Server to abort current task")
        task_names = request.get("task_names", None)
        self.abort_task(task_names)
        return make_reply(ReturnCode.OK)