# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import time
from nvflare.apis.client_engine_spec import ClientEngineSpec, TaskAssignment
from nvflare.apis.event_type import EventType
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_constant import FLContextKey, ReservedKey, ReservedTopic, ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.private.defs import SpecialTaskName, TaskConstant
from nvflare.widgets.info_collector import GroupInfoCollector, InfoCollector
class ClientRunnerConfig(object):
    """Plain holder for the pieces of configuration a ClientRunner is built from."""

    def __init__(
        self,
        task_table: dict,  # task_name => Executor
        task_data_filters: dict,  # task_name => list of filters
        task_result_filters: dict,  # task_name => list of filters
        handlers=None,  # list of event handlers
        components=None,  # dict of extra python objects: id => object
    ):
        """To init ClientRunnerConfig.

        The arguments are stored as-is on the instance; nothing is copied or validated here.

        Args:
            task_table: task_name: Executor dict
            task_data_filters: task_name => list of data filters
            task_result_filters: task_name => list of result filters
            handlers: list of event handlers
            components: dict of extra python objects: id => object
        """
        self.task_table = task_table
        self.task_data_filters = task_data_filters
        self.task_result_filters = task_result_filters
        self.handlers = handlers
        self.components = components
class ClientRunner(FLComponent):
    """Client-side run loop.

    Repeatedly fetches a task assignment from the engine, passes it through the
    configured task-data filters, the matching executor and the task-result
    filters, and sends the reply back through the engine. Also reacts to abort
    requests (direct calls, aux messages and fatal-error events).
    """

    def __init__(
        self,
        config: ClientRunnerConfig,
        job_id,
        engine: ClientEngineSpec,
        task_fetch_interval: int = 5,  # fetch task every 5 secs
    ):
        """To init the ClientRunner.

        Registers aux message handlers on ``engine`` for the END_RUN and
        ABORT_ASK reserved topics.

        Args:
            config: ClientRunnerConfig
            job_id: job id
            engine: ClientEngine object
            task_fetch_interval: fetch task interval, in seconds (passed to ``time.sleep``)
        """
        FLComponent.__init__(self)
        self.task_table = config.task_table
        self.task_data_filters = config.task_data_filters
        self.task_result_filters = config.task_result_filters
        self.job_id = job_id
        self.engine = engine
        self.task_fetch_interval = task_fetch_interval
        self.run_abort_signal = Signal()

        # State of the task currently being executed; guarded by task_lock.
        # task_abort_signal is None when no task is running or after the
        # running task has been aborted.
        self.task_abort_signal = None
        self.current_executor = None
        self.current_task = None
        self.asked_to_stop = False
        self.task_lock = threading.Lock()

        # Ensures the end-run event sequence is fired only once per run.
        self.end_run_fired = False
        self.end_run_lock = threading.Lock()

        engine.register_aux_message_handler(topic=ReservedTopic.END_RUN, message_handle_func=self._handle_end_run)
        engine.register_aux_message_handler(topic=ReservedTopic.ABORT_ASK, message_handle_func=self._handle_abort_task)

    def _process_task(self, task: TaskAssignment, fl_ctx: FLContext) -> Shareable:
        """Validate, filter and execute one task assignment.

        Args:
            task: the assignment fetched from the server
            fl_ctx: FL context of the current fetch cycle

        Returns:
            The (filtered) executor result, or a reply built with ``make_reply``
            carrying the ReturnCode of the first failure encountered.

        Raises:
            TypeError: if the engine found in ``fl_ctx`` is not a ClientEngineSpec.
        """
        engine = fl_ctx.get_engine()
        if not isinstance(engine, ClientEngineSpec):
            raise TypeError("engine must be ClientEngineSpec, but got {}".format(type(engine)))

        if not isinstance(task.data, Shareable):
            self.log_error(
                fl_ctx, "got invalid task data in assignment: expect Shareable, but got {}".format(type(task.data))
            )
            return make_reply(ReturnCode.BAD_TASK_DATA)

        fl_ctx.set_prop(FLContextKey.TASK_DATA, value=task.data, private=True, sticky=False)
        fl_ctx.set_prop(FLContextKey.TASK_NAME, value=task.name, private=True, sticky=False)
        fl_ctx.set_prop(FLContextKey.TASK_ID, value=task.task_id, private=True, sticky=False)

        peer_ctx = fl_ctx.get_peer_context()
        if not peer_ctx:
            self.log_error(fl_ctx, "missing peer context in Server task assignment")
            return make_reply(ReturnCode.MISSING_PEER_CONTEXT)

        if not isinstance(peer_ctx, FLContext):
            self.log_error(
                fl_ctx,
                "bad peer context in Server task assignment: expects FLContext but got {}".format(type(peer_ctx)),
            )
            return make_reply(ReturnCode.BAD_PEER_CONTEXT)

        task.data.set_peer_props(peer_ctx.get_all_public_props())
        peer_job_id = peer_ctx.get_job_id()
        if peer_job_id != self.job_id:
            self.log_error(fl_ctx, "bad task assignment: not for the same job_id")
            return make_reply(ReturnCode.RUN_MISMATCH)

        executor = self.task_table.get(task.name)
        if not executor:
            self.log_error(fl_ctx, "bad task assignment: no executor available for task {}".format(task.name))
            return make_reply(ReturnCode.TASK_UNKNOWN)

        # ---- task data filtering ----
        self.log_debug(fl_ctx, "firing event EventType.BEFORE_TASK_DATA_FILTER")
        self.fire_event(EventType.BEFORE_TASK_DATA_FILTER, fl_ctx)

        filter_list = self.task_data_filters.get(task.name)
        if filter_list:
            task_data = task.data
            for f in filter_list:
                try:
                    task_data = f.process(task_data, fl_ctx)
                except BaseException:
                    self.log_exception(fl_ctx, "processing error in Task Data Filter {}".format(type(f)))
                    return make_reply(ReturnCode.TASK_DATA_FILTER_ERROR)

            if not isinstance(task_data, Shareable):
                self.log_error(
                    fl_ctx, "task data was converted to wrong type: expect Shareable but got {}".format(type(task_data))
                )
                return make_reply(ReturnCode.TASK_DATA_FILTER_ERROR)

            task.data = task_data

        self.log_debug(fl_ctx, "firing event EventType.AFTER_TASK_DATA_FILTER")
        fl_ctx.set_prop(FLContextKey.TASK_DATA, value=task.data, private=True, sticky=False)
        self.fire_event(EventType.AFTER_TASK_DATA_FILTER, fl_ctx)

        # ---- execution ----
        self.log_debug(fl_ctx, "firing event EventType.BEFORE_TASK_EXECUTION")
        fl_ctx.set_prop(FLContextKey.TASK_DATA, value=task.data, private=True, sticky=False)
        self.fire_event(EventType.BEFORE_TASK_EXECUTION, fl_ctx)

        try:
            self.log_info(fl_ctx, "invoking task executor {}".format(type(executor)))

            # Keep a local reference to the signal: _abort_current_task() may
            # reset self.task_abort_signal to None at any moment after the lock
            # is released, and the executor must still receive the real Signal
            # object (which the aborter triggers) rather than None.
            abort_signal = Signal()
            with self.task_lock:
                self.task_abort_signal = abort_signal
                self.current_executor = executor
                self.current_task = task

            try:
                reply = executor.execute(task.name, task.data, fl_ctx, abort_signal)
            finally:
                with self.task_lock:
                    # _abort_current_task() clears the attribute when it aborts
                    # the task, so None here means "aborted while running".
                    task_aborted = self.task_abort_signal is None
                    self.task_abort_signal = None
                    self.current_task = None
                    self.current_executor = None

            if task_aborted:
                return make_reply(ReturnCode.TASK_ABORTED)

            if not isinstance(reply, Shareable):
                self.log_error(
                    fl_ctx,
                    "bad result generated by executor {}: must be Shareable but got {}".format(
                        type(executor), type(reply)
                    ),
                )
                return make_reply(ReturnCode.EXECUTION_RESULT_ERROR)
        except RuntimeError as e:
            # A RuntimeError from the executor is treated as fatal for the whole
            # run: the fetch loop in _try_run() stops once asked_to_stop is set.
            self.log_exception(fl_ctx, f"Critical RuntimeError happened with Exception {e}: Aborting the RUN!")
            self.asked_to_stop = True
            return make_reply(ReturnCode.EXECUTION_RESULT_ERROR)
        except BaseException:
            self.log_exception(fl_ctx, "processing error in task executor {}".format(type(executor)))
            return make_reply(ReturnCode.EXECUTION_EXCEPTION)

        fl_ctx.set_prop(FLContextKey.TASK_RESULT, value=reply, private=True, sticky=False)
        self.log_debug(fl_ctx, "firing event EventType.AFTER_TASK_EXECUTION")
        self.fire_event(EventType.AFTER_TASK_EXECUTION, fl_ctx)

        # ---- task result filtering ----
        self.log_debug(fl_ctx, "firing event EventType.BEFORE_TASK_RESULT_FILTER")
        self.fire_event(EventType.BEFORE_TASK_RESULT_FILTER, fl_ctx)

        filter_list = self.task_result_filters.get(task.name)
        if filter_list:
            for f in filter_list:
                try:
                    reply = f.process(reply, fl_ctx)
                except BaseException:
                    self.log_exception(fl_ctx, "processing error in Task Result Filter {}".format(type(f)))
                    return make_reply(ReturnCode.TASK_RESULT_FILTER_ERROR)

            if not isinstance(reply, Shareable):
                self.log_error(
                    fl_ctx, "task result was converted to wrong type: expect Shareable but got {}".format(type(reply))
                )
                return make_reply(ReturnCode.TASK_RESULT_FILTER_ERROR)

        fl_ctx.set_prop(FLContextKey.TASK_RESULT, value=reply, private=True, sticky=False)
        self.log_debug(fl_ctx, "firing event EventType.AFTER_TASK_RESULT_FILTER")
        self.fire_event(EventType.AFTER_TASK_RESULT_FILTER, fl_ctx)
        self.log_info(fl_ctx, "finished processing task")

        # Defensive final check before handing the reply back to the caller.
        if not isinstance(reply, Shareable):
            self.log_error(
                fl_ctx, "task processing error: expects result to be Shareable, but got {}".format(type(reply))
            )
            return make_reply(ReturnCode.EXECUTION_RESULT_ERROR)

        return reply

    def _try_run(self):
        """Fetch-and-process loop; returns when asked to stop, aborted, or told END_RUN.

        Raises:
            TypeError: if the task reply or the task data is not a Shareable
                when the result is about to be sent.
        """
        task_fetch_interval = self.task_fetch_interval
        while not self.asked_to_stop:
            with self.engine.new_context() as fl_ctx:
                if self.run_abort_signal.triggered:
                    self.log_info(fl_ctx, "run abort signal received")
                    break

                time.sleep(task_fetch_interval)

                # The abort may have arrived while sleeping.
                if self.run_abort_signal.triggered:
                    self.log_info(fl_ctx, "run abort signal received")
                    break

                # reset to default fetch interval
                task_fetch_interval = self.task_fetch_interval

                self.log_debug(fl_ctx, "fetching task from server ...")
                task = self.engine.get_task_assignment(fl_ctx)
                if not task:
                    self.log_debug(fl_ctx, "no task received - will try in {} secs".format(task_fetch_interval))
                    continue

                if task.name == SpecialTaskName.END_RUN:
                    self.log_info(fl_ctx, "server asked to end the run")
                    break

                if task.name == SpecialTaskName.TRY_AGAIN:
                    # The server may suggest how long to wait before the next fetch.
                    task_data = task.data
                    if task_data and isinstance(task_data, Shareable):
                        task_fetch_interval = task_data.get(TaskConstant.WAIT_TIME, self.task_fetch_interval)
                    self.log_debug(
                        fl_ctx, "server asked to try again - will try in {} secs".format(task_fetch_interval)
                    )
                    continue

                self.log_info(fl_ctx, "got task assignment: name={}, id={}".format(task.name, task.task_id))

                # _process_task creates the per-task abort signal and runs the executor
                task_reply = self._process_task(task, fl_ctx)
                if not isinstance(task_reply, Shareable):
                    raise TypeError("task_reply must be Shareable, but got {}".format(type(task_reply)))

                self.log_debug(fl_ctx, "firing event EventType.BEFORE_SEND_TASK_RESULT")
                self.fire_event(EventType.BEFORE_SEND_TASK_RESULT, fl_ctx)

                # set the cookie in the reply!
                task_data = task.data
                if not isinstance(task_data, Shareable):
                    raise TypeError("task_data must be Shareable, but got {}".format(type(task_data)))

                cookie_jar = task_data.get_cookie_jar()
                if cookie_jar:
                    task_reply.set_cookie_jar(cookie_jar)

                reply_sent = self.engine.send_task_result(task_reply, fl_ctx)
                if reply_sent:
                    self.log_info(
                        fl_ctx, "result sent to server for task: name={}, id={}".format(task.name, task.task_id)
                    )
                else:
                    self.log_error(
                        fl_ctx,
                        "failed to send result to server for task: name={}, id={}".format(task.name, task.task_id),
                    )

                self.log_debug(fl_ctx, "firing event EventType.AFTER_SEND_TASK_RESULT")
                self.fire_event(EventType.AFTER_SEND_TASK_RESULT, fl_ctx)

    def run(self, app_root, args):
        """Run the client: fire the start events, loop on tasks, then fire the end events.

        Any exception escaping the task loop is logged, not re-raised.

        Args:
            app_root: stored in the FL context under FLContextKey.APP_ROOT
            args: stored in the FL context under FLContextKey.ARGS
        """
        with self.engine.new_context() as fl_ctx:
            self.fire_event(EventType.ABOUT_TO_START_RUN, fl_ctx)
            fl_ctx.set_prop(FLContextKey.APP_ROOT, app_root, sticky=True)
            fl_ctx.set_prop(FLContextKey.ARGS, args, sticky=True)
            fl_ctx.set_prop(ReservedKey.RUN_ABORT_SIGNAL, self.run_abort_signal, private=True, sticky=True)
            self.log_debug(fl_ctx, "firing event EventType.START_RUN")
            self.fire_event(EventType.START_RUN, fl_ctx)
            self.log_info(fl_ctx, "client runner started")

        with self.end_run_lock:
            self.end_run_fired = False

        try:
            self._try_run()
        except BaseException as e:
            with self.engine.new_context() as fl_ctx:
                self.log_exception(fl_ctx, "processing error in RUN execution: {}".format(e))
        finally:
            # in case any task is still running, abort it
            self._abort_current_task()
            self.end_run_events_sequence("run method")

    def _abort_current_task(self):
        """Trigger the abort signal of the running task, if any, and fire ABORT_TASK.

        Does nothing when no task abort signal is currently set.
        """
        # NOTE(review): the event is fired while task_lock is held, and
        # threading.Lock is not reentrant; an ABORT_TASK handler that calls back
        # into abort_task()/_abort_current_task() on the same thread would block.
        with self.task_lock:
            task_abort_signal = self.task_abort_signal
            if task_abort_signal:
                # set task_abort_signal to None to prevent triggering again
                self.task_abort_signal = None

                task_name = ""
                task_id = ""
                task = self.current_task
                if task:
                    task_name = task.name
                    task_id = task.task_id

                with self.engine.new_context() as fl_ctx:
                    fl_ctx.set_prop(FLContextKey.TASK_NAME, value=task_name, private=True, sticky=False)
                    fl_ctx.set_prop(FLContextKey.TASK_ID, value=task_id, private=True, sticky=False)
                    task_abort_signal.trigger(True)
                    self.log_info(fl_ctx, "triggered task_abort_signal to stop task '{}'".format(task_name))
                    self.fire_event(EventType.ABORT_TASK, fl_ctx)
                    self.log_info(fl_ctx, "fired ABORT_TASK event to abort current task {}".format(task_name))

    def abort_task(self, task_names=None):
        """Abort the current task if it matches the request.

        Args:
            task_names: collection of task names to abort; when empty or None,
                whatever task is currently running is aborted.
        """
        has_task_to_abort = False
        with self.engine.new_context() as fl_ctx:
            with self.task_lock:
                if self.current_task:
                    name = self.current_task.name
                    if not task_names or name in task_names:
                        has_task_to_abort = True
                    else:
                        self.log_info(
                            fl_ctx, "Ignored abort_task request since current task '{}' is not target".format(name)
                        )
                else:
                    self.log_info(fl_ctx, "Ignored abort_task request since there is no current task")

        # Must run after task_lock is released: _abort_current_task() acquires it itself.
        if has_task_to_abort:
            self._abort_current_task()

    def end_run_events_sequence(self, requester):
        """Fire ABOUT_TO_END_RUN then END_RUN, unless already fired for this run.

        Args:
            requester: label of the caller, used only in the log message
        """
        with self.engine.new_context() as fl_ctx:
            self.log_info(fl_ctx, f"{requester} requests end run events sequence")
            with self.end_run_lock:
                if not self.end_run_fired:
                    self.fire_event(EventType.ABOUT_TO_END_RUN, fl_ctx)
                    self.log_info(fl_ctx, "ABOUT_TO_END_RUN fired")
                    self.fire_event(EventType.END_RUN, fl_ctx)
                    self.log_info(fl_ctx, "END_RUN fired")
                    self.end_run_fired = True

    def abort(self):
        """To Abort the current run.

        Aborts the running task (if any), triggers the run abort signal, asks
        the fetch loop to stop and fires the end-run event sequence.

        Returns: N/A
        """
        with self.engine.new_context() as fl_ctx:
            self.log_info(fl_ctx, "ABORT (RUN) command received")

        self._abort_current_task()
        self.run_abort_signal.trigger("ABORT (RUN) triggered")
        self.asked_to_stop = True
        self.end_run_events_sequence("ABORT (RUN)")

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        """React to stats-collection and fatal-error events.

        Args:
            event_type: the type of event being fired
            fl_ctx: FL context that accompanies the event

        Raises:
            TypeError: if the stats collector in ``fl_ctx`` is not a GroupInfoCollector.
        """
        if event_type == InfoCollector.EVENT_TYPE_GET_STATS:
            collector = fl_ctx.get_prop(InfoCollector.CTX_KEY_STATS_COLLECTOR)
            if collector:
                if not isinstance(collector, GroupInfoCollector):
                    raise TypeError("collector must be GroupInfoCollector, but got {}".format(type(collector)))

                # Read the attribute once: the task thread may reset
                # self.current_task to None between a check and a later use.
                current_task = self.current_task
                current_task_name = current_task.name if current_task else "None"

                collector.set_info(
                    group_name="ClientRunner",
                    info={"job_id": self.job_id, "current_task_name": current_task_name, "status": "started"},
                )
        elif event_type == EventType.FATAL_TASK_ERROR:
            reason = fl_ctx.get_prop(key=FLContextKey.EVENT_DATA, default="")
            self.log_error(fl_ctx, "Aborting current task due to FATAL_TASK_ERROR received: {}".format(reason))
            self._abort_current_task()
        elif event_type == EventType.FATAL_SYSTEM_ERROR:
            reason = fl_ctx.get_prop(key=FLContextKey.EVENT_DATA, default="")
            self.log_error(fl_ctx, "Aborting current RUN due to FATAL_SYSTEM_ERROR received: {}".format(reason))
            self.abort()

    def _handle_end_run(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable:
        """Aux message handler for ReservedTopic.END_RUN: abort the run and reply OK."""
        self.log_info(fl_ctx, "received aux request from Server to end current RUN")
        self.abort()
        return make_reply(ReturnCode.OK)

    def _handle_abort_task(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable:
        """Aux message handler for ReservedTopic.ABORT_ASK: abort the named task(s) and reply OK.

        The optional "task_names" entry of ``request`` is forwarded to abort_task().
        """
        self.log_info(fl_ctx, "received aux request from Server to abort current task")
        task_names = request.get("task_names", None)
        self.abort_task(task_names)
        return make_reply(ReturnCode.OK)