# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import time
from nvflare.apis.client import Client
from nvflare.apis.event_type import EventType
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_constant import FilterKey, FLContextKey, ReservedKey, ReservedTopic, ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.server_engine_spec import ServerEngineSpec
from nvflare.apis.shareable import ReservedHeaderKey, Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.apis.utils.fl_context_utils import add_job_audit_event
from nvflare.apis.utils.reliable_message import ReliableMessage
from nvflare.apis.utils.task_utils import apply_filters
from nvflare.private.defs import SpecialTaskName, TaskConstant
from nvflare.private.fed.tbi import TBI
from nvflare.private.privacy_manager import Scope
from nvflare.security.logging import secure_format_exception
from nvflare.widgets.info_collector import GroupInfoCollector, InfoCollector
[docs]class ServerRunnerConfig(object):
def __init__(
self,
heartbeat_timeout: int,
task_request_interval: float,
workflows: [],
task_data_filters: dict,
task_result_filters: dict,
handlers=None,
components=None,
):
"""Configuration for ServerRunner.
Args:
heartbeat_timeout (int): Client heartbeat timeout in seconds
task_request_interval (float): Task request interval in seconds
workflows (list): A list of workflow
task_data_filters (dict): A dict of {task_name: list of filters apply to data (pre-process)}
task_result_filters (dict): A dict of {task_name: list of filters apply to result (post-process)}
handlers (list, optional): A list of event handlers
components (dict, optional): A dict of extra python objects {id: object}
"""
self.heartbeat_timeout = heartbeat_timeout
self.task_request_interval = task_request_interval
self.workflows = workflows
self.task_data_filters = task_data_filters
self.task_result_filters = task_result_filters
self.handlers = handlers
self.components = components
[docs] def add_component(self, comp_id: str, component: object):
if not isinstance(comp_id, str):
raise TypeError(f"component id must be str but got {type(comp_id)}")
if comp_id in self.components:
raise ValueError(f"duplicate component id {comp_id}")
self.components[comp_id] = component
if isinstance(component, FLComponent):
self.handlers.append(component)
[docs]class ServerRunner(TBI):
ABORT_RETURN_CODES = [
ReturnCode.RUN_MISMATCH,
ReturnCode.TASK_UNKNOWN,
ReturnCode.UNSAFE_JOB,
]
def __init__(self, config: ServerRunnerConfig, job_id: str, engine: ServerEngineSpec):
"""Server runner class.
Args:
config (ServerRunnerConfig): configuration of server runner
job_id (str): The number to distinguish each experiment
engine (ServerEngineSpec): server engine
"""
TBI.__init__(self)
self.job_id = job_id
self.config = config
self.engine = engine
self.abort_signal = Signal()
self.wf_lock = threading.Lock()
self.current_wf = None
self.current_wf_index = 0
self.status = "init"
self.turn_to_cold = False
self._register_aux_message_handler(engine)
def _register_aux_message_handler(self, engine):
engine.register_aux_message_handler(
topic=ReservedTopic.SYNC_RUNNER, message_handle_func=self._handle_sync_runner
)
engine.register_aux_message_handler(
topic=ReservedTopic.JOB_HEART_BEAT, message_handle_func=self._handle_job_heartbeat
)
engine.register_aux_message_handler(topic=ReservedTopic.TASK_CHECK, message_handle_func=self._handle_task_check)
def _handle_sync_runner(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable:
# simply ack
self._report_client_active("syncRunner", fl_ctx)
return make_reply(ReturnCode.OK)
def _execute_run(self):
while self.current_wf_index < len(self.config.workflows):
wf = self.config.workflows[self.current_wf_index]
try:
with self.engine.new_context() as fl_ctx:
self.log_info(fl_ctx, "starting workflow {} ({}) ...".format(wf.id, type(wf.controller)))
fl_ctx.set_prop(FLContextKey.WORKFLOW, wf.id, sticky=True)
wf.controller.communicator.initialize_run(fl_ctx)
wf.controller.initialize(fl_ctx)
self.log_info(fl_ctx, "Workflow {} ({}) started".format(wf.id, type(wf.controller)))
self.log_debug(fl_ctx, "firing event EventType.START_WORKFLOW")
self.fire_event(EventType.START_WORKFLOW, fl_ctx)
# use the wf_lock to ensure state integrity between workflow change and message processing
with self.wf_lock:
# we only set self.current_wf to open for business after successful initialize_run!
self.current_wf = wf
with self.engine.new_context() as fl_ctx:
wf.controller.control_flow(self.abort_signal, fl_ctx)
except Exception as e:
with self.engine.new_context() as fl_ctx:
self.log_exception(fl_ctx, "Exception in workflow {}: {}".format(wf.id, secure_format_exception(e)))
self.system_panic("Exception in workflow {}: {}".format(wf.id, secure_format_exception(e)), fl_ctx)
finally:
with self.engine.new_context() as fl_ctx:
# do not execute finalize_run() until the wf_lock is acquired
with self.wf_lock:
# unset current_wf to prevent message processing
# then we can release the lock - no need to delay message processing
# during finalization!
# Note: WF finalization may take time since it needs to wait for
# the job monitor to join.
self.current_wf = None
self.log_info(fl_ctx, f"Workflow: {wf.id} finalizing ...")
try:
wf.controller.stop_controller(fl_ctx)
wf.controller.communicator.finalize_run(fl_ctx)
except Exception as e:
self.log_exception(
fl_ctx, "Error finalizing workflow {}: {}".format(wf.id, secure_format_exception(e))
)
self.log_debug(fl_ctx, "firing event EventType.END_WORKFLOW")
self.fire_event(EventType.END_WORKFLOW, fl_ctx)
# Stopped the server runner from the current responder, not continue the following responders.
if self.abort_signal.triggered:
break
self.current_wf_index += 1
[docs] def run(self):
with self.engine.new_context() as fl_ctx:
ReliableMessage.enable(fl_ctx)
self.log_info(fl_ctx, "Server runner starting ...")
self.log_debug(fl_ctx, "firing event EventType.START_RUN")
fl_ctx.set_prop(ReservedKey.RUN_ABORT_SIGNAL, self.abort_signal, private=True, sticky=True)
self.fire_event(EventType.START_RUN, fl_ctx)
self.engine.persist_components(fl_ctx, completed=False)
self.status = "started"
try:
self._execute_run()
except Exception as e:
with self.engine.new_context() as fl_ctx:
self.log_exception(fl_ctx, f"Error executing RUN: {secure_format_exception(e)}")
finally:
# use wf_lock to ensure state of current_wf!
self.status = "done"
with self.wf_lock:
with self.engine.new_context() as fl_ctx:
self.fire_event(EventType.ABOUT_TO_END_RUN, fl_ctx)
self.log_info(fl_ctx, "ABOUT_TO_END_RUN fired")
if not self.turn_to_cold:
# ask all clients to end run!
self.engine.send_aux_request(
targets=None,
topic=ReservedTopic.END_RUN,
request=Shareable(),
timeout=0.0,
fl_ctx=fl_ctx,
optional=True,
secure=False,
)
self.engine.persist_components(fl_ctx, completed=True)
self.check_end_run_readiness(fl_ctx)
# Now ready to end the run!
self.fire_event(EventType.END_RUN, fl_ctx)
self.log_info(fl_ctx, "END_RUN fired")
ReliableMessage.shutdown()
self.log_info(fl_ctx, "Server runner finished.")
[docs] def handle_event(self, event_type: str, fl_ctx: FLContext):
if event_type == InfoCollector.EVENT_TYPE_GET_STATS:
collector = fl_ctx.get_prop(InfoCollector.CTX_KEY_STATS_COLLECTOR)
if collector:
if not isinstance(collector, GroupInfoCollector):
raise TypeError("collector must be GroupInfoCollect but got {}".format(type(collector)))
with self.wf_lock:
if self.current_wf:
collector.set_info(
group_name="ServerRunner",
info={"job_id": self.job_id, "status": self.status, "workflow": self.current_wf.id},
)
elif event_type == EventType.FATAL_SYSTEM_ERROR:
fl_ctx.set_prop(key=FLContextKey.FATAL_SYSTEM_ERROR, value=True, private=True, sticky=True)
reason = fl_ctx.get_prop(key=FLContextKey.EVENT_DATA, default="")
self.log_error(fl_ctx, "Aborting current RUN due to FATAL_SYSTEM_ERROR received: {}".format(reason))
self.abort(fl_ctx)
def _task_try_again(self) -> (str, str, Shareable):
task_data = Shareable()
task_data.set_header(TaskConstant.WAIT_TIME, self.config.task_request_interval)
return SpecialTaskName.TRY_AGAIN, "", task_data
[docs] def process_task_request(self, client: Client, fl_ctx: FLContext) -> (str, str, Shareable):
"""Process task request from a client.
NOTE: the Engine will create a new fl_ctx and call this method:
with engine.new_context() as fl_ctx:
name, id, data = runner.process_task_request(client, fl_ctx)
...
Args:
client (Client): client object
fl_ctx (FLContext): FL context
Returns:
A tuple of (task name, task id, and task data)
"""
engine = fl_ctx.get_engine()
if not isinstance(engine, ServerEngineSpec):
raise TypeError("engine must be ServerEngineSpec but got {}".format(type(engine)))
self.log_debug(fl_ctx, "process task request from client")
if self.status == "init":
self.log_debug(fl_ctx, "server runner still initializing - asked client to try again later")
return self._task_try_again()
if self.status == "done":
self.log_info(fl_ctx, "server runner is finalizing - asked client to end the run")
return SpecialTaskName.END_RUN, "", None
peer_ctx = fl_ctx.get_peer_context()
if not isinstance(peer_ctx, FLContext):
self.log_debug(fl_ctx, "invalid task request: no peer context - asked client to try again later")
return self._task_try_again()
self._report_client_active("getTask", fl_ctx)
peer_job_id = peer_ctx.get_job_id()
if not peer_job_id or peer_job_id != self.job_id:
# the client is in a different RUN
self.log_info(fl_ctx, "invalid task request: not the same job_id - asked client to end the run")
return SpecialTaskName.END_RUN, "", None
try:
task_name, task_id, task_data = self._try_to_get_task(
# client, fl_ctx, self.config.task_request_timeout, self.config.task_retry_interval
client,
fl_ctx,
)
if not task_name or task_name == SpecialTaskName.TRY_AGAIN:
return self._task_try_again()
# filter task data
self.log_debug(fl_ctx, "firing event EventType.BEFORE_TASK_DATA_FILTER")
self.fire_event(EventType.BEFORE_TASK_DATA_FILTER, fl_ctx)
try:
filter_name = Scope.TASK_DATA_FILTERS_NAME
task_data = apply_filters(
filter_name, task_data, fl_ctx, self.config.task_data_filters, task_name, FilterKey.OUT
)
except Exception as e:
self.log_exception(
fl_ctx,
"processing error in task data filter {}; "
"asked client to try again later".format(secure_format_exception(e)),
)
with self.wf_lock:
if self.current_wf:
self.current_wf.controller.communicator.handle_exception(task_id, fl_ctx)
return self._task_try_again()
self.log_debug(fl_ctx, "firing event EventType.AFTER_TASK_DATA_FILTER")
self.fire_event(EventType.AFTER_TASK_DATA_FILTER, fl_ctx)
self.log_info(fl_ctx, f"sent task assignment to client. client_name:{client.name} task_id:{task_id}")
audit_event_id = add_job_audit_event(fl_ctx=fl_ctx, msg=f'sent task to client "{client.name}"')
task_data.set_header(ReservedHeaderKey.AUDIT_EVENT_ID, audit_event_id)
task_data.set_header(TaskConstant.WAIT_TIME, self.config.task_request_interval)
return task_name, task_id, task_data
except Exception as e:
self.log_exception(
fl_ctx,
f"Error processing client task request: {secure_format_exception(e)}; asked client to try again later",
)
return self._task_try_again()
def _try_to_get_task(self, client, fl_ctx, timeout=None, retry_interval=0.005):
start = time.time()
while True:
with self.wf_lock:
if self.current_wf is None:
self.log_debug(fl_ctx, "no current workflow - asked client to try again later")
return "", "", None
self.log_debug(fl_ctx, "firing event EventType.BEFORE_PROCESS_TASK_REQUEST")
self.fire_event(EventType.BEFORE_PROCESS_TASK_REQUEST, fl_ctx)
task_name, task_id, task_data = self.current_wf.controller.communicator.process_task_request(
client, fl_ctx
)
self.log_debug(fl_ctx, "firing event EventType.AFTER_PROCESS_TASK_REQUEST")
self.fire_event(EventType.AFTER_PROCESS_TASK_REQUEST, fl_ctx)
if task_name and task_name != SpecialTaskName.TRY_AGAIN:
if task_data:
if not isinstance(task_data, Shareable):
self.log_error(
fl_ctx,
"bad task data generated by workflow {}: must be Shareable but got {}".format(
self.current_wf.id, type(task_data)
),
)
return "", "", None
else:
task_data = Shareable()
task_data.set_header(ReservedHeaderKey.TASK_ID, task_id)
task_data.set_header(ReservedHeaderKey.TASK_NAME, task_name)
task_data.add_cookie(ReservedHeaderKey.WORKFLOW, self.current_wf.id)
fl_ctx.set_prop(FLContextKey.TASK_NAME, value=task_name, private=True, sticky=False)
fl_ctx.set_prop(FLContextKey.TASK_ID, value=task_id, private=True, sticky=False)
fl_ctx.set_prop(FLContextKey.TASK_DATA, value=task_data, private=True, sticky=False)
self.log_info(fl_ctx, f"assigned task to client {client.name}: name={task_name}, id={task_id}")
return task_name, task_id, task_data
if timeout is None or time.time() - start > timeout:
break
time.sleep(retry_interval)
# ask client to retry
return "", "", None
[docs] def handle_dead_job(self, client_name: str, fl_ctx: FLContext):
with self.wf_lock:
try:
if self.current_wf is None:
return
if self.current_wf.controller:
self.current_wf.controller.communicator.process_dead_client_report(client_name, fl_ctx)
except Exception as e:
self.log_exception(
fl_ctx, f"Error processing dead job by workflow {self.current_wf.id}: {secure_format_exception(e)}"
)
[docs] def process_submission(self, client: Client, task_name: str, task_id: str, result: Shareable, fl_ctx: FLContext):
"""Process task result submitted from a client.
NOTE: the Engine will create a new fl_ctx and call this method:
with engine.new_context() as fl_ctx:
name, id, data = runner.process_submission(client, fl_ctx)
Args:
client: Client object
task_name: task name
task_id: task id
result: task result
fl_ctx: FLContext
"""
self.log_info(fl_ctx, f"got result from client {client.name} for task: name={task_name}, id={task_id}")
self._report_client_active("submitTaskResult", fl_ctx)
if not isinstance(result, Shareable):
self.log_error(fl_ctx, "invalid result submission: must be Shareable but got {}".format(type(result)))
return
# set the reply prop so log msg context could include RC from it
fl_ctx.set_prop(FLContextKey.REPLY, result, private=True, sticky=False)
fl_ctx.set_prop(FLContextKey.TASK_NAME, value=task_name, private=True, sticky=False)
fl_ctx.set_prop(FLContextKey.TASK_RESULT, value=result, private=True, sticky=False)
fl_ctx.set_prop(FLContextKey.TASK_ID, value=task_id, private=True, sticky=False)
client_audit_event_id = result.get_header(ReservedHeaderKey.AUDIT_EVENT_ID, "")
add_job_audit_event(
fl_ctx=fl_ctx, ref=client_audit_event_id, msg=f"received result from client '{client.name}'"
)
if self.status != "started":
self.log_info(fl_ctx, "ignored result submission since server runner's status is {}".format(self.status))
return
peer_ctx = fl_ctx.get_peer_context()
if not isinstance(peer_ctx, FLContext):
self.log_info(fl_ctx, "invalid result submission: no peer context - dropped")
return
peer_job_id = peer_ctx.get_job_id()
if not peer_job_id or peer_job_id != self.job_id:
# the client is on a different RUN
self.log_info(fl_ctx, "invalid result submission: not the same job id - dropped")
return
rc = result.get_return_code(default=ReturnCode.OK)
if rc in self.ABORT_RETURN_CODES:
self.log_error(fl_ctx, f"aborting ServerRunner due to fatal return code {rc} from client {client.name}")
self.system_panic(
reason=f"Aborted job {self.job_id} due to fatal return code {rc} from client {client.name}",
fl_ctx=fl_ctx,
)
return
result.set_header(ReservedHeaderKey.TASK_NAME, task_name)
result.set_header(ReservedHeaderKey.TASK_ID, task_id)
result.set_peer_props(peer_ctx.get_all_public_props())
with self.wf_lock:
try:
if self.current_wf is None:
self.log_info(fl_ctx, "no current workflow - dropped submission.")
return
wf_id = result.get_cookie(ReservedHeaderKey.WORKFLOW, None)
if wf_id is not None and wf_id != self.current_wf.id:
self.log_info(
fl_ctx,
"Got result for workflow {}, but we are running {} - dropped submission.".format(
wf_id, self.current_wf.id
),
)
return
# filter task result
self.log_debug(fl_ctx, "firing event EventType.BEFORE_TASK_RESULT_FILTER")
self.fire_event(EventType.BEFORE_TASK_RESULT_FILTER, fl_ctx)
try:
filter_name = Scope.TASK_RESULT_FILTERS_NAME
result = apply_filters(
filter_name, result, fl_ctx, self.config.task_result_filters, task_name, FilterKey.IN
)
except Exception as e:
self.log_exception(
fl_ctx,
"processing error in task result filter {}; ".format(secure_format_exception(e)),
)
result = make_reply(ReturnCode.TASK_RESULT_FILTER_ERROR)
self.log_debug(fl_ctx, "firing event EventType.AFTER_TASK_RESULT_FILTER")
self.fire_event(EventType.AFTER_TASK_RESULT_FILTER, fl_ctx)
self.log_debug(fl_ctx, "firing event EventType.BEFORE_PROCESS_SUBMISSION")
self.fire_event(EventType.BEFORE_PROCESS_SUBMISSION, fl_ctx)
self.current_wf.controller.communicator.process_submission(
client=client, task_name=task_name, task_id=task_id, result=result, fl_ctx=fl_ctx
)
self.log_info(fl_ctx, "finished processing client result by {}".format(self.current_wf.id))
self.log_debug(fl_ctx, "firing event EventType.AFTER_PROCESS_SUBMISSION")
self.fire_event(EventType.AFTER_PROCESS_SUBMISSION, fl_ctx)
except Exception as e:
self.log_exception(
fl_ctx,
"Error processing client result by {}: {}".format(self.current_wf.id, secure_format_exception(e)),
)
def _report_client_active(self, reason: str, fl_ctx: FLContext):
with self.wf_lock:
if self.current_wf and self.current_wf.controller:
peer_ctx = fl_ctx.get_peer_context()
assert isinstance(peer_ctx, FLContext)
client_name = peer_ctx.get_identity_name()
self.current_wf.controller.communicator.client_is_active(client_name, reason, fl_ctx)
def _handle_job_heartbeat(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable:
self.log_debug(fl_ctx, "received client job_heartbeat")
self._report_client_active("jobHeartbeat", fl_ctx)
return make_reply(ReturnCode.OK)
def _handle_task_check(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable:
self._report_client_active("taskCheck", fl_ctx)
task_id = request.get_header(ReservedHeaderKey.TASK_ID)
if not task_id:
self.log_error(fl_ctx, f"missing {ReservedHeaderKey.TASK_ID} in task_check request")
return make_reply(ReturnCode.BAD_REQUEST_DATA)
self.log_debug(fl_ctx, f"received task_check on task {task_id}")
with self.wf_lock:
if self.current_wf is None or self.current_wf.controller is None:
self.log_info(fl_ctx, "no current workflow - dropped task_check.")
return make_reply(ReturnCode.TASK_UNKNOWN)
task = self.current_wf.controller.communicator.process_task_check(task_id=task_id, fl_ctx=fl_ctx)
if task:
self.log_debug(fl_ctx, f"task {task_id} is still good")
return make_reply(ReturnCode.OK)
else:
self.log_info(fl_ctx, f"task {task_id} is not found")
return make_reply(ReturnCode.TASK_UNKNOWN)
[docs] def abort(self, fl_ctx: FLContext, turn_to_cold: bool = False):
self.status = "done"
self.abort_signal.trigger(value=True)
self.turn_to_cold = turn_to_cold
self.log_info(fl_ctx, "asked to abort - triggered abort_signal to stop the RUN")
[docs] def get_persist_state(self, fl_ctx: FLContext) -> dict:
return {"job_id": str(self.job_id), "current_wf_index": self.current_wf_index}
[docs] def restore(self, state_data: dict, fl_ctx: FLContext):
self.job_id = state_data.get("job_id")
self.current_wf_index = int(state_data.get("current_wf_index", 0))