# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import fnmatch
import threading
import time
from nvflare.apis.event_type import EventType
from nvflare.apis.executor import Executor
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_constant import ConfigVarName, FilterKey, FLContextKey, ReservedKey, ReservedTopic, ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.fl_exception import UnsafeJobError
from nvflare.apis.shareable import ReservedHeaderKey, Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.apis.utils.fl_context_utils import add_job_audit_event
from nvflare.apis.utils.reliable_message import ReliableMessage
from nvflare.apis.utils.task_utils import apply_filters
from nvflare.fuel.f3.cellnet.fqcn import FQCN
from nvflare.private.defs import SpecialTaskName, TaskConstant
from nvflare.private.fed.client.client_engine_executor_spec import ClientEngineExecutorSpec, TaskAssignment
from nvflare.private.fed.tbi import TBI
from nvflare.private.json_configer import ConfigError
from nvflare.private.privacy_manager import Scope
from nvflare.security.logging import secure_format_exception
from nvflare.widgets.info_collector import GroupInfoCollector, InfoCollector
# Outcomes of a task check / result-send attempt (used by ClientRunner._check_task_once and friends).
_TASK_CHECK_RESULT_OK: int = 0  # server is ready for (or has accepted) the task result
_TASK_CHECK_RESULT_TRY_AGAIN: int = 1  # transient failure; the caller sleeps and retries
_TASK_CHECK_RESULT_TASK_GONE: int = 2  # task unknown to the server (or run aborted); stop trying
class TaskRouter:
    """Maps task names to the executors that handle them.

    A registered task name may be an exact name or an fnmatch-style pattern containing ``*``.
    Exact-name lookup is tried first; otherwise patterns are tried in registration order.
    """

    def __init__(self):
        self.task_table = {}  # task name or pattern => executor
        self.patterns = []  # (pattern, executor) pairs, in registration order

    @staticmethod
    def _is_pattern(p: str):
        """Returns True if ``p`` is treated as a wildcard pattern (i.e. it contains ``*``)."""
        return "*" in p

    def add_executor(self, tasks: list, executor: Executor):
        """Registers ``executor`` for every task name / pattern in ``tasks``.

        Args:
            tasks: list of task names; names containing ``*`` are also registered as patterns
            executor: the executor that handles these tasks

        Raises:
            TypeError: if an entry of ``tasks`` is not a str
            ConfigError: if an entry already has an executor registered
        """
        for t in tasks:
            # explicit check instead of `assert`, which is stripped when running with -O
            if not isinstance(t, str):
                raise TypeError(f"task name must be str but got {type(t)}")
            if t in self.task_table:
                raise ConfigError(f'multiple executors defined for task "{t}"')
            self.task_table[t] = executor
            if self._is_pattern(t):
                self.patterns.append((t, executor))

    def route(self, task_name: str):
        """Finds the executor for ``task_name``.

        Args:
            task_name: name of the task to route

        Returns:
            The executor registered under the exact name, else the executor of the first
            registered pattern that matches, else None.
        """
        e = self.task_table.get(task_name)
        if e is not None:
            return e

        # no exact match: check patterns in the order they were added
        for p, e in self.patterns:
            if fnmatch.fnmatch(task_name, p):
                return e
        return None
class ClientRunnerConfig:
    """Bundle of the client-side settings a ClientRunner is constructed from."""

    def __init__(
        self,
        task_router: TaskRouter,
        task_data_filters: dict,  # task_name => list of filters
        task_result_filters: dict,  # task_name => list of filters
        handlers=None,  # list of event handlers
        components=None,  # dict of extra python objects: id => object
        default_task_fetch_interval: float = 0.5,
    ):
        """To init ClientRunnerConfig.

        Args:
            task_router: TaskRouter object to find executor for a task
            task_data_filters: task_name => list of data filters
            task_result_filters: task_name => list of result filters
            handlers: list of event handlers; a falsy value is replaced by a new empty list
            components: dict of extra python objects: id => object; a falsy value is replaced by a new empty dict
            default_task_fetch_interval: default task fetch interval before getting the correct value from server.
                default is set to 0.5.
        """
        self.task_router = task_router
        self.task_data_filters = task_data_filters
        self.task_result_filters = task_result_filters
        self.handlers = handlers if handlers else []
        self.components = components if components else {}
        self.default_task_fetch_interval = default_task_fetch_interval

    def add_component(self, comp_id: str, component: object):
        """Adds a named component; FLComponent instances are also appended to the handler list.

        Args:
            comp_id: unique component id
            component: the component object

        Raises:
            TypeError: if ``comp_id`` is not a str
            ValueError: if ``comp_id`` is already used
        """
        if not isinstance(comp_id, str):
            raise TypeError(f"component id must be str but got {type(comp_id)}")
        if comp_id in self.components:
            raise ValueError(f"duplicate component id {comp_id}")

        self.components[comp_id] = component
        is_event_handler = isinstance(component, FLComponent)
        if is_event_handler:
            self.handlers.append(component)
class ClientRunner(TBI):
    """Client-side job runner: fetches tasks from the server, executes them and sends back the results.

    Tasks arrive either by polling the server (``fetch_and_run_one_task``) or through a DO_TASK
    aux request (``_handle_do_task``). Each task passes through the task data filters, the executor
    chosen by the TaskRouter and the task result filters, with events fired around every stage.
    """

    def __init__(
        self,
        config: ClientRunnerConfig,
        job_id,
        engine: ClientEngineExecutorSpec,
    ):
        """Initializes the ClientRunner.

        Args:
            config: ClientRunnerConfig
            job_id: job id
            engine: ClientEngine object
        """
        TBI.__init__(self)
        self.task_router = config.task_router
        self.task_data_filters = config.task_data_filters
        self.task_result_filters = config.task_result_filters
        self.default_task_fetch_interval = config.default_task_fetch_interval
        self.job_id = job_id
        self.engine = engine
        self.run_abort_signal = Signal()
        self.task_lock = threading.Lock()  # guards running_tasks
        self.running_tasks = {}  # task_id => TaskAssignment

        self.task_check_timeout = self.get_positive_float_var(ConfigVarName.TASK_CHECK_TIMEOUT, 5.0)
        self.task_check_interval = self.get_positive_float_var(ConfigVarName.TASK_CHECK_INTERVAL, 5.0)  # seconds
        self.job_heartbeat_interval = self.get_positive_float_var(ConfigVarName.JOB_HEARTBEAT_INTERVAL, 10.0)  # seconds
        self.get_task_timeout = self.get_positive_float_var(ConfigVarName.GET_TASK_TIMEOUT, None)
        self.submit_task_result_timeout = self.get_positive_float_var(ConfigVarName.SUBMIT_TASK_RESULT_TIMEOUT, None)
        self._register_aux_message_handlers(engine)

    def find_executor(self, task_name):
        """Returns the executor routed for ``task_name``, or None if no executor matches."""
        return self.task_router.route(task_name)

    def _register_aux_message_handlers(self, engine):
        """Registers this runner's handlers for the END_RUN and DO_TASK aux topics."""
        engine.register_aux_message_handler(topic=ReservedTopic.END_RUN, message_handle_func=self._handle_end_run)
        engine.register_aux_message_handler(topic=ReservedTopic.DO_TASK, message_handle_func=self._handle_do_task)

    @staticmethod
    def _reply_and_audit(reply: Shareable, ref, msg, fl_ctx: FLContext) -> Shareable:
        """Records a job audit event and stamps its id into the reply's AUDIT_EVENT_ID header.

        Args:
            reply: the reply to stamp
            ref: audit event id this event refers to (the server's task audit id)
            msg: audit message
            fl_ctx: FLContext

        Returns:
            The same ``reply`` object.
        """
        audit_event_id = add_job_audit_event(fl_ctx=fl_ctx, ref=ref, msg=msg)
        reply.set_header(ReservedKey.AUDIT_EVENT_ID, audit_event_id)
        return reply

    def _audited_error_reply(self, rc, ref, fl_ctx: FLContext) -> Shareable:
        """Makes a reply with return code ``rc`` and audits it as ``submit result: <rc>``."""
        return self._reply_and_audit(
            reply=make_reply(rc),
            ref=ref,
            fl_ctx=fl_ctx,
            msg=f"submit result: {rc}",
        )

    def _process_task(self, task: TaskAssignment, fl_ctx: FLContext) -> Shareable:
        """Runs one task while tracking it in ``running_tasks``.

        Args:
            task: the task assignment
            fl_ctx: FLContext for the task

        Returns:
            A Shareable reply that always carries the TASK_NAME header (except for the early
            UNSAFE_JOB reply) and the task data's cookie jar when there is one.
        """
        if fl_ctx.is_job_unsafe():
            return make_reply(ReturnCode.UNSAFE_JOB)

        with self.task_lock:
            self.running_tasks[task.task_id] = task

        abort_signal = Signal(parent=self.run_abort_signal)
        try:
            reply = self._do_process_task(task, fl_ctx, abort_signal)
        except Exception as ex:
            self.log_exception(fl_ctx, secure_format_exception(ex))
            reply = make_reply(ReturnCode.EXECUTION_EXCEPTION)
        finally:
            # always unregister, even if a BaseException escapes, so that stats and
            # end_run_events_sequence() never see a stale task
            with self.task_lock:
                self.running_tasks.pop(task.task_id, None)

        if not isinstance(reply, Shareable):
            self.log_error(fl_ctx, f"task reply must be Shareable, but got {type(reply)}")
            reply = make_reply(ReturnCode.EXECUTION_EXCEPTION)

        # task.data is not a Shareable when _do_process_task rejected it with BAD_TASK_DATA
        if isinstance(task.data, Shareable):
            cookie_jar = task.data.get_cookie_jar()
            if cookie_jar:
                reply.set_cookie_jar(cookie_jar)

        reply.set_header(ReservedHeaderKey.TASK_NAME, task.name)
        return reply

    def _do_process_task(self, task: TaskAssignment, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
        """Validates a task, applies data filters, executes it and applies result filters.

        Args:
            task: the task assignment; ``task.data`` is replaced by the filtered data
            fl_ctx: FLContext for the task
            abort_signal: signal passed to the executor; a child of the run abort signal

        Returns:
            The filtered result, or an error reply whose return code describes the failure.
            Every reply except BAD_TASK_DATA is audited with a reference to the server's audit event id.
        """
        if not isinstance(task.data, Shareable):
            self.log_error(fl_ctx, f"got invalid task data in assignment: expect Shareable, but got {type(task.data)}")
            return make_reply(ReturnCode.BAD_TASK_DATA)

        fl_ctx.set_prop(FLContextKey.TASK_DATA, value=task.data, private=True, sticky=False)
        fl_ctx.set_prop(FLContextKey.TASK_NAME, value=task.name, private=True, sticky=False)
        fl_ctx.set_prop(FLContextKey.TASK_ID, value=task.task_id, private=True, sticky=False)

        server_audit_event_id = task.data.get_header(ReservedKey.AUDIT_EVENT_ID, "")
        add_job_audit_event(fl_ctx=fl_ctx, ref=server_audit_event_id, msg="received task from server")

        # --- validate the assignment against the peer (server) context ---
        peer_ctx = fl_ctx.get_peer_context()
        if not peer_ctx:
            self.log_error(fl_ctx, "missing peer context in Server task assignment")
            return self._audited_error_reply(ReturnCode.MISSING_PEER_CONTEXT, server_audit_event_id, fl_ctx)

        if not isinstance(peer_ctx, FLContext):
            self.log_error(
                fl_ctx,
                "bad peer context in Server task assignment: expects FLContext but got {}".format(type(peer_ctx)),
            )
            return self._audited_error_reply(ReturnCode.BAD_PEER_CONTEXT, server_audit_event_id, fl_ctx)

        task.data.set_peer_props(peer_ctx.get_all_public_props())

        peer_job_id = peer_ctx.get_job_id()
        if peer_job_id != self.job_id:
            self.log_error(fl_ctx, "bad task assignment: not for the same job_id")
            return self._audited_error_reply(ReturnCode.RUN_MISMATCH, server_audit_event_id, fl_ctx)

        executor = self.find_executor(task.name)
        if not executor:
            self.log_error(fl_ctx, f"bad task assignment: no executor available for task {task.name}")
            return self._audited_error_reply(ReturnCode.TASK_UNKNOWN, server_audit_event_id, fl_ctx)

        executor_name = executor.__class__.__name__

        # --- task data filters ---
        self.log_debug(fl_ctx, "firing event EventType.BEFORE_TASK_DATA_FILTER")
        self.fire_event(EventType.BEFORE_TASK_DATA_FILTER, fl_ctx)
        task_data = task.data
        try:
            filter_name = Scope.TASK_DATA_FILTERS_NAME
            task_data = apply_filters(filter_name, task_data, fl_ctx, self.task_data_filters, task.name, FilterKey.IN)
        except UnsafeJobError:
            self.log_exception(fl_ctx, "UnsafeJobError from Task Data Filters")
            executor.unsafe = True
            fl_ctx.set_job_is_unsafe()
            self.run_abort_signal.trigger(True)
            return self._audited_error_reply(ReturnCode.UNSAFE_JOB, server_audit_event_id, fl_ctx)
        except Exception as e:
            self.log_exception(fl_ctx, f"Processing error from Task Data Filters : {secure_format_exception(e)}")
            return self._audited_error_reply(ReturnCode.TASK_DATA_FILTER_ERROR, server_audit_event_id, fl_ctx)

        if not isinstance(task_data, Shareable):
            self.log_error(
                fl_ctx, "task data was converted to wrong type: expect Shareable but got {}".format(type(task_data))
            )
            return self._audited_error_reply(ReturnCode.TASK_DATA_FILTER_ERROR, server_audit_event_id, fl_ctx)

        task.data = task_data

        self.log_debug(fl_ctx, "firing event EventType.AFTER_TASK_DATA_FILTER")
        fl_ctx.set_prop(FLContextKey.TASK_DATA, value=task.data, private=True, sticky=False)
        self.fire_event(EventType.AFTER_TASK_DATA_FILTER, fl_ctx)

        self.log_debug(fl_ctx, "firing event EventType.BEFORE_TASK_EXECUTION")
        # set again before execution (presumably so AFTER_TASK_DATA_FILTER handlers cannot leave
        # a different TASK_DATA prop behind)
        fl_ctx.set_prop(FLContextKey.TASK_DATA, value=task.data, private=True, sticky=False)
        self.fire_event(EventType.BEFORE_TASK_EXECUTION, fl_ctx)

        # --- execution ---
        try:
            self.log_info(fl_ctx, f"invoking task executor {executor_name}")
            add_job_audit_event(fl_ctx=fl_ctx, msg=f"invoked executor {executor_name}")
            try:
                reply = executor.execute(task.name, task.data, fl_ctx, abort_signal)
            except Exception:
                # an abort takes precedence over whatever error the executor raised
                # (handled here rather than via `return` in `finally`, which would also
                # swallow KeyboardInterrupt/SystemExit)
                if abort_signal.triggered:
                    return self._audited_error_reply(ReturnCode.TASK_ABORTED, server_audit_event_id, fl_ctx)
                raise

            if abort_signal.triggered:
                return self._audited_error_reply(ReturnCode.TASK_ABORTED, server_audit_event_id, fl_ctx)

            if not isinstance(reply, Shareable):
                self.log_error(
                    fl_ctx, f"bad result generated by executor {executor_name}: must be Shareable but got {type(reply)}"
                )
                return self._audited_error_reply(ReturnCode.EXECUTION_RESULT_ERROR, server_audit_event_id, fl_ctx)

        except RuntimeError as e:
            # NOTE(review): the message says the job is aborted, but nothing here triggers
            # run_abort_signal - confirm which behavior is intended.
            self.log_exception(
                fl_ctx, f"RuntimeError from executor {executor_name}: {secure_format_exception(e)}: Aborting the job!"
            )
            return self._audited_error_reply(ReturnCode.EXECUTION_RESULT_ERROR, server_audit_event_id, fl_ctx)
        except UnsafeJobError:
            # NOTE(review): unlike the task-data-filter path, this does not trigger run_abort_signal
            self.log_exception(fl_ctx, f"UnsafeJobError from executor {executor_name}")
            executor.unsafe = True
            fl_ctx.set_job_is_unsafe()
            return self._audited_error_reply(ReturnCode.UNSAFE_JOB, server_audit_event_id, fl_ctx)
        except Exception as e:
            self.log_exception(fl_ctx, f"Processing error from executor {executor_name}: {secure_format_exception(e)}")
            return self._audited_error_reply(ReturnCode.EXECUTION_EXCEPTION, server_audit_event_id, fl_ctx)

        fl_ctx.set_prop(FLContextKey.TASK_RESULT, value=reply, private=True, sticky=False)

        self.log_debug(fl_ctx, "firing event EventType.AFTER_TASK_EXECUTION")
        self.fire_event(EventType.AFTER_TASK_EXECUTION, fl_ctx)

        # --- task result filters ---
        self.log_debug(fl_ctx, "firing event EventType.BEFORE_TASK_RESULT_FILTER")
        self.fire_event(EventType.BEFORE_TASK_RESULT_FILTER, fl_ctx)
        try:
            filter_name = Scope.TASK_RESULT_FILTERS_NAME
            reply = apply_filters(filter_name, reply, fl_ctx, self.task_result_filters, task.name, FilterKey.OUT)
        except UnsafeJobError:
            self.log_exception(fl_ctx, "UnsafeJobError from Task Result Filters")
            executor.unsafe = True
            fl_ctx.set_job_is_unsafe()
            return self._audited_error_reply(ReturnCode.UNSAFE_JOB, server_audit_event_id, fl_ctx)
        except Exception as e:
            self.log_exception(fl_ctx, f"Processing error in Task Result Filter : {secure_format_exception(e)}")
            return self._audited_error_reply(ReturnCode.TASK_RESULT_FILTER_ERROR, server_audit_event_id, fl_ctx)

        if not isinstance(reply, Shareable):
            self.log_error(
                fl_ctx, "task result was converted to wrong type: expect Shareable but got {}".format(type(reply))
            )
            return self._audited_error_reply(ReturnCode.TASK_RESULT_FILTER_ERROR, server_audit_event_id, fl_ctx)

        fl_ctx.set_prop(FLContextKey.TASK_RESULT, value=reply, private=True, sticky=False)

        self.log_debug(fl_ctx, "firing event EventType.AFTER_TASK_RESULT_FILTER")
        self.fire_event(EventType.AFTER_TASK_RESULT_FILTER, fl_ctx)
        self.log_info(fl_ctx, "finished processing task")

        return self._reply_and_audit(reply=reply, ref=server_audit_event_id, fl_ctx=fl_ctx, msg="submit result OK")

    def _try_run(self):
        """Starts the job heartbeat thread and fetches/runs tasks until the run abort signal fires."""
        heartbeat_thread = threading.Thread(target=self._send_job_heartbeat, args=[], daemon=True)
        heartbeat_thread.start()

        while not self.run_abort_signal.triggered:
            with self.engine.new_context() as fl_ctx:
                task_fetch_interval, _ = self.fetch_and_run_one_task(fl_ctx)
            time.sleep(task_fetch_interval)

    def _send_job_heartbeat(self):
        """Sends a JOB_HEART_BEAT aux message to the server every ``job_heartbeat_interval`` seconds until aborted."""
        request = Shareable()
        last_heartbeat_sent_time = 0.0
        while not self.run_abort_signal.triggered:
            if time.time() - last_heartbeat_sent_time > self.job_heartbeat_interval:
                with self.engine.new_context() as fl_ctx:
                    self.engine.send_aux_request(
                        targets=[FQCN.ROOT_SERVER],
                        topic=ReservedTopic.JOB_HEART_BEAT,
                        request=request,
                        timeout=0,
                        fl_ctx=fl_ctx,
                        optional=True,
                    )
                last_heartbeat_sent_time = time.time()
            # sleep very short time so that we can check stop condition (e.g. abort signal)
            time.sleep(0.2)

    def fetch_and_run_one_task(self, fl_ctx) -> (float, bool):
        """Fetches and runs a task.

        The special END_RUN task triggers the run abort signal; TRY_AGAIN returns the wait time
        requested by the server (WAIT_TIME header) when present.

        Args:
            fl_ctx: FLContext used to fetch and run the task

        Returns:
            A tuple of (task_fetch_interval, task_processed).

        Raises:
            TypeError: if a regular task's data is not a Shareable
        """
        default_task_fetch_interval = self.default_task_fetch_interval
        self.log_debug(fl_ctx, "fetching task from server ...")
        task = self.engine.get_task_assignment(fl_ctx, self.get_task_timeout)

        if not task:
            self.log_debug(fl_ctx, "no task received - will try in {} secs".format(default_task_fetch_interval))
            return default_task_fetch_interval, False
        elif task.name == SpecialTaskName.END_RUN:
            self.log_info(fl_ctx, "server asked to end the run")
            self.run_abort_signal.trigger(True)
            return default_task_fetch_interval, False
        elif task.name == SpecialTaskName.TRY_AGAIN:
            task_data = task.data
            task_fetch_interval = default_task_fetch_interval
            if task_data and isinstance(task_data, Shareable):
                task_fetch_interval = task_data.get_header(TaskConstant.WAIT_TIME, task_fetch_interval)
            self.log_debug(fl_ctx, "server asked to try again - will try in {} secs".format(task_fetch_interval))
            return task_fetch_interval, False

        self.log_info(fl_ctx, f"got task assignment: name={task.name}, id={task.task_id}")

        task_data = task.data
        if not isinstance(task_data, Shareable):
            raise TypeError("task_data must be Shareable, but got {}".format(type(task_data)))

        task_fetch_interval = task_data.get_header(TaskConstant.WAIT_TIME, default_task_fetch_interval)

        task_reply = self._process_task(task, fl_ctx)

        self.log_debug(fl_ctx, "firing event EventType.BEFORE_SEND_TASK_RESULT")
        self.fire_event(EventType.BEFORE_SEND_TASK_RESULT, fl_ctx)

        self._send_task_result(task_reply, task.task_id, fl_ctx)
        self.log_debug(fl_ctx, "firing event EventType.AFTER_SEND_TASK_RESULT")
        self.fire_event(EventType.AFTER_SEND_TASK_RESULT, fl_ctx)

        return task_fetch_interval, True

    def _send_task_result(self, result: Shareable, task_id: str, fl_ctx: FLContext):
        """Keeps trying to deliver ``result`` until it is sent, the task is gone, or the run is aborted.

        Returns:
            True if the result was sent; False otherwise.
        """
        try_count = 1
        while True:
            self.log_info(fl_ctx, f"try #{try_count}: sending task result to server")

            if self.run_abort_signal.triggered:
                self.log_info(fl_ctx, "job aborted: stopped trying to send result")
                return False

            try_count += 1
            rc = self._try_send_result_once(result, task_id, fl_ctx)

            if rc == _TASK_CHECK_RESULT_OK:
                return True
            elif rc == _TASK_CHECK_RESULT_TASK_GONE:
                return False
            else:
                # retry
                time.sleep(self.task_check_interval)

    def _try_send_result_once(self, result: Shareable, task_id: str, fl_ctx: FLContext):
        """Waits until the server confirms it still expects the task, then sends the result once.

        Returns:
            One of the _TASK_CHECK_RESULT_* codes.
        """
        # wait until server is ready to receive
        while True:
            if self.run_abort_signal.triggered:
                return _TASK_CHECK_RESULT_TASK_GONE

            rc = self._check_task_once(task_id, fl_ctx)
            if rc == _TASK_CHECK_RESULT_OK:
                break
            elif rc == _TASK_CHECK_RESULT_TASK_GONE:
                return rc
            else:
                # try again
                time.sleep(self.task_check_interval)

        # try to send the result
        self.log_info(fl_ctx, "start to send task result to server")
        reply_sent = self.engine.send_task_result(result, fl_ctx, timeout=self.submit_task_result_timeout)
        if reply_sent:
            self.log_info(fl_ctx, "task result sent to server")
            return _TASK_CHECK_RESULT_OK
        else:
            self.log_error(fl_ctx, "failed to send task result to server - will try again")
            return _TASK_CHECK_RESULT_TRY_AGAIN

    def _check_task_once(self, task_id: str, fl_ctx: FLContext) -> int:
        """This method checks whether the server is still waiting for the specified task.

        The real reason for this method is to fight against unstable network connections.
        We try to make sure that when we send task result to the server, the connection is available.
        If the task check succeeds, then the network connection is likely to be available.
        Otherwise, we keep retrying until task check succeeds or the server tells us that the task is gone (timed out).

        Args:
            task_id: id of the task whose result is about to be sent
            fl_ctx: FLContext used for the aux request

        Returns:
            _TASK_CHECK_RESULT_OK if the server replies OK (or with an unexpected return code);
            _TASK_CHECK_RESULT_TRY_AGAIN for a missing/bad reply, COMMUNICATION_ERROR or SERVER_NOT_READY;
            _TASK_CHECK_RESULT_TASK_GONE if the server replies TASK_UNKNOWN.
        """
        self.log_info(fl_ctx, "checking task ...")
        task_check_req = Shareable()
        task_check_req.set_header(ReservedKey.TASK_ID, task_id)
        resp = self.engine.send_aux_request(
            targets=[FQCN.ROOT_SERVER],
            topic=ReservedTopic.TASK_CHECK,
            request=task_check_req,
            timeout=self.task_check_timeout,
            fl_ctx=fl_ctx,
            optional=True,
        )
        if resp and isinstance(resp, dict):
            reply = resp.get(FQCN.ROOT_SERVER)
            if not isinstance(reply, Shareable):
                self.log_error(fl_ctx, f"bad task_check reply from server: expect Shareable but got {type(reply)}")
                return _TASK_CHECK_RESULT_TRY_AGAIN

            rc = reply.get_return_code()
            if rc == ReturnCode.OK:
                return _TASK_CHECK_RESULT_OK
            elif rc == ReturnCode.COMMUNICATION_ERROR:
                self.log_error(fl_ctx, f"failed task_check: {rc}")
                return _TASK_CHECK_RESULT_TRY_AGAIN
            elif rc == ReturnCode.SERVER_NOT_READY:
                self.log_error(fl_ctx, f"server rejected task_check: {rc}")
                return _TASK_CHECK_RESULT_TRY_AGAIN
            elif rc == ReturnCode.TASK_UNKNOWN:
                self.log_debug(fl_ctx, f"task no longer exists on server: {rc}")
                return _TASK_CHECK_RESULT_TASK_GONE
            else:
                # this should never happen
                self.log_error(fl_ctx, f"programming error: received {rc} from server")
                return _TASK_CHECK_RESULT_OK  # try to push the result regardless
        else:
            self.log_error(fl_ctx, f"bad task_check reply from server: invalid resp {type(resp)}")
            return _TASK_CHECK_RESULT_TRY_AGAIN

    def run(self, app_root, args):
        """Syncs with the server, runs the task loop until aborted, then fires the end-run events.

        Args:
            app_root: application root, stored in the FLContext as APP_ROOT
            args: run arguments, stored in the FLContext as ARGS
        """
        self.init_run(app_root, args)

        try:
            self._try_run()
        except Exception as e:
            with self.engine.new_context() as fl_ctx:
                self.log_exception(fl_ctx, f"processing error in RUN execution: {secure_format_exception(e)}")
        finally:
            try:
                self.end_run_events_sequence()
            finally:
                # cleanup must happen even if an end-run event handler raises
                ReliableMessage.shutdown()
                with self.task_lock:
                    self.running_tasks = {}

    def init_run(self, app_root, args):
        """Syncs with the Server Runner, enables ReliableMessage and fires the start-run events.

        Args:
            app_root: application root, stored sticky in the FLContext
            args: run arguments, stored sticky in the FLContext

        Raises:
            RuntimeError: if no OK sync reply is received within MAX_RUNNER_SYNC_TRIES attempts
            TypeError: if the server's sync reply is not a Shareable
        """
        sync_timeout = self.get_positive_float_var(
            var_name=ConfigVarName.RUNNER_SYNC_TIMEOUT,
            default=2.0,
        )
        max_sync_tries = self.get_positive_int_var(
            var_name=ConfigVarName.MAX_RUNNER_SYNC_TRIES,
            default=30,
        )
        target = "server"
        synced = False
        sync_start = time.time()
        with self.engine.new_context() as fl_ctx:
            for _ in range(max_sync_tries):
                # sync with server runner before starting
                time.sleep(0.5)
                resp = self.engine.send_aux_request(
                    targets=[target],
                    topic=ReservedTopic.SYNC_RUNNER,
                    request=Shareable(),
                    timeout=sync_timeout,
                    fl_ctx=fl_ctx,
                    optional=True,
                    secure=False,
                )

                if not resp:
                    continue

                reply = resp.get(target)
                if not reply:
                    continue

                # explicit check instead of `assert`, which is stripped when running with -O
                if not isinstance(reply, Shareable):
                    raise TypeError(f"bad sync reply from server: expect Shareable but got {type(reply)}")

                rc = reply.get_return_code()
                if rc == ReturnCode.OK:
                    synced = True
                    break

            if not synced:
                raise RuntimeError(f"cannot sync with Server Runner after {max_sync_tries} tries")

            self.log_info(fl_ctx, f"synced to Server Runner in {time.time()-sync_start} seconds")
            ReliableMessage.enable(fl_ctx)
            self.fire_event(EventType.ABOUT_TO_START_RUN, fl_ctx)
            fl_ctx.set_prop(FLContextKey.APP_ROOT, app_root, sticky=True)
            fl_ctx.set_prop(FLContextKey.ARGS, args, sticky=True)
            fl_ctx.set_prop(ReservedKey.RUN_ABORT_SIGNAL, self.run_abort_signal, private=True, sticky=True)
            self.log_debug(fl_ctx, "firing event EventType.START_RUN")
            self.fire_event(EventType.START_RUN, fl_ctx)
            self.log_info(fl_ctx, "client runner started")

    def end_run_events_sequence(self):
        """Fires ABORT_TASK (if tasks are still running), ABOUT_TO_END_RUN and END_RUN, in that order."""
        with self.engine.new_context() as fl_ctx:
            self.log_info(fl_ctx, "started end-run events sequence")
            with self.task_lock:
                num_running_tasks = len(self.running_tasks)

            if num_running_tasks > 0:
                self.fire_event(EventType.ABORT_TASK, fl_ctx)
                self.log_info(fl_ctx, "fired ABORT_TASK event to abort all running tasks")

            self.fire_event(EventType.ABOUT_TO_END_RUN, fl_ctx)
            self.log_info(fl_ctx, "ABOUT_TO_END_RUN fired")

            self.check_end_run_readiness(fl_ctx)

            # now ready to end run
            self.fire_event(EventType.END_RUN, fl_ctx)
            self.log_info(fl_ctx, "END_RUN fired")

    def abort(self, msg: str = ""):
        """To Abort the current run.

        Args:
            msg: message to log; defaults to "Client is stopping ..."

        Returns: N/A
        """
        # This is called when:
        #  1. abort_job command is issued by the user
        #  2. when the job is ended by the server when error conditions occur
        #  3. when the job is ended normally at the end of the workflow
        if not msg:
            msg = "Client is stopping ..."

        with self.engine.new_context() as fl_ctx:
            self.log_info(fl_ctx, msg)
            self.run_abort_signal.trigger(True)

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        """Reports running tasks on GET_STATS and aborts the run on FATAL_SYSTEM_ERROR.

        Raises:
            TypeError: if the stats collector in the context is not a GroupInfoCollector
        """
        if event_type == InfoCollector.EVENT_TYPE_GET_STATS:
            collector = fl_ctx.get_prop(InfoCollector.CTX_KEY_STATS_COLLECTOR)
            if collector:
                if not isinstance(collector, GroupInfoCollector):
                    raise TypeError("collector must be GroupInfoCollector, but got {}".format(type(collector)))

                with self.task_lock:
                    current_tasks = [task.name for task in self.running_tasks.values()]

                collector.set_info(
                    group_name="ClientRunner",
                    info={"job_id": self.job_id, "current_tasks": current_tasks},
                )
        elif event_type == EventType.FATAL_SYSTEM_ERROR:
            # This happens when a task calls system_panic().
            reason = fl_ctx.get_prop(key=FLContextKey.EVENT_DATA, default="")
            self.log_error(fl_ctx, "Stopped ClientRunner due to FATAL_SYSTEM_ERROR: {}".format(reason))
            self.run_abort_signal.trigger(True)

    def _handle_end_run(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable:
        """Handles the END_RUN aux request by triggering the run abort signal."""
        # This happens when the controller on server asks the client to end the job.
        # Usually at the end of the workflow.
        self.log_info(fl_ctx, "received request from Server to end current RUN")
        self.run_abort_signal.trigger(True)
        return make_reply(ReturnCode.OK)

    def _handle_do_task(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable:
        """Handles a DO_TASK aux request by running ``request`` as a task and returning its reply."""
        self.log_info(fl_ctx, "received aux request to do task")
        task_name = request.get_header(ReservedHeaderKey.TASK_NAME)
        task_id = request.get_header(ReservedHeaderKey.TASK_ID)
        task = TaskAssignment(name=task_name, task_id=task_id, data=request)
        reply = self._process_task(task, fl_ctx)
        return reply