# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import time
from threading import Lock
from typing import List, Optional, Tuple, Union
from nvflare.apis.client import Client
from nvflare.apis.controller_spec import ClientTask, SendOrder, Task, TaskCompletionStatus
from nvflare.apis.event_type import EventType
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_constant import ConfigVarName, FLContextKey, SystemConfigs
from nvflare.apis.fl_context import FLContext
from nvflare.apis.job_def import job_from_meta
from nvflare.apis.shareable import ReservedHeaderKey, Shareable, make_copy
from nvflare.apis.signal import Signal
from nvflare.apis.wf_comm_spec import WFCommSpec
from nvflare.fuel.utils.config_service import ConfigService
from nvflare.security.logging import secure_format_exception
from nvflare.widgets.info_collector import GroupInfoCollector, InfoCollector
from .any_relay_manager import AnyRelayTaskManager
from .bcast_manager import BcastForeverTaskManager, BcastTaskManager
from .send_manager import SendTaskManager
from .seq_relay_manager import SequentialRelayTaskManager
from .task_manager import TaskCheckStatus, TaskManager
_TASK_KEY_ENGINE = "___engine"
_TASK_KEY_MANAGER = "___mgr"
_TASK_KEY_DONE = "___done"
def _check_positive_int(name, value):
if not isinstance(value, int):
raise TypeError("{} must be an instance of int, but got {}.".format(name, type(name)))
if value < 0:
raise ValueError("{} must >= 0.".format(name))
def _check_inputs(task: Task, fl_ctx: FLContext, targets: Union[List[Client], List[str], None]):
if not isinstance(task, Task):
raise TypeError("task must be an instance of Task, but got {}".format(type(task)))
if not isinstance(fl_ctx, FLContext):
raise TypeError("fl_ctx must be an instance of FLContext, but got {}".format(type(fl_ctx)))
if targets is not None:
if not isinstance(targets, list):
raise TypeError("targets must be a list of Client or string, but got {}".format(type(targets)))
for t in targets:
if not isinstance(t, (Client, str)):
raise TypeError(
"targets must be a list of Client or string, but got element of type {}".format(type(t))
)
def _get_client_task(target, task: Task):
for ct in task.client_tasks:
if target == ct.client.name:
return ct
return None
class _DeadClientStatus:
def __init__(self):
self.report_time = time.time()
self.disconnect_time = None
[docs]class WFCommServer(FLComponent, WFCommSpec):
def __init__(self, task_check_period=0.2):
"""Manage life cycles of tasks and their destinations.
Args:
task_check_period (float, optional): interval for checking status of tasks. Defaults to 0.2.
"""
super().__init__()
self.controller = None
self._engine = None
self._tasks = [] # list of standing tasks
self._client_task_map = {} # client_task_id => client_task
self._all_done = False
self._task_lock = Lock()
self._task_monitor = threading.Thread(target=self._monitor_tasks, args=(), daemon=True)
self._task_check_period = task_check_period
self._dead_client_grace = 60.0
self._dead_clients = {} # clients reported dead: name => _DeadClientStatus
self._dead_clients_lock = Lock() # need lock since dead_clients can be modified from different threads
# make sure check_tasks, process_task_request, process_submission does not interfere with each other
self._controller_lock = Lock()
[docs] def initialize_run(self, fl_ctx: FLContext):
"""Called by runners to initialize controller with information in fl_ctx.
.. attention::
Controller subclasses must not overwrite this method.
Args:
fl_ctx (FLContext): FLContext information
"""
engine = fl_ctx.get_engine()
if not engine:
self.system_panic(f"Engine not found. {self.name} exiting.", fl_ctx)
return
self._engine = engine
self._dead_client_grace = ConfigService.get_float_var(
name=ConfigVarName.DEAD_CLIENT_GRACE_PERIOD, conf=SystemConfigs.APPLICATION_CONF, default=60.0
)
self._task_monitor.start()
def _try_again(self) -> Tuple[str, str, Optional[Shareable]]:
# TODO: how to tell client no shareable available now?
return "", "", None
def _set_stats(self, fl_ctx: FLContext):
"""Called to set stats into InfoCollector.
Args:
fl_ctx (FLContext): info collector is retrieved from fl_ctx with InfoCollector.CTX_KEY_STATS_COLLECTOR key
"""
collector = fl_ctx.get_prop(InfoCollector.CTX_KEY_STATS_COLLECTOR, None)
if collector:
if not isinstance(collector, GroupInfoCollector):
raise TypeError(
"collector must be an instance of GroupInfoCollector, but got {}".format(type(collector))
)
collector.add_info(
group_name=self.controller.name,
info={
"tasks": {t.name: [ct.client.name for ct in t.client_tasks] for t in self._tasks},
},
)
[docs] def handle_event(self, event_type: str, fl_ctx: FLContext):
"""Called when events are fired.
Args:
event_type (str): all event types, including AppEventType and EventType
fl_ctx (FLContext): FLContext information with current event type
"""
if event_type == InfoCollector.EVENT_TYPE_GET_STATS:
self._set_stats(fl_ctx)
[docs] def process_dead_client_report(self, client_name: str, fl_ctx: FLContext):
with self._dead_clients_lock:
self.log_warning(fl_ctx, f"received dead job report for client {client_name}")
if not self._dead_clients.get(client_name):
self.log_warning(fl_ctx, f"client {client_name} is placed on dead client watch list")
self._dead_clients[client_name] = _DeadClientStatus()
else:
self.log_warning(fl_ctx, f"discarded dead client report {client_name=}: already on watch list")
[docs] def process_task_request(self, client: Client, fl_ctx: FLContext) -> Tuple[str, str, Shareable]:
"""Called by runner when a client asks for a task.
.. note::
This is called in a separate thread.
Args:
client (Client): The record of one client requesting tasks
fl_ctx (FLContext): The FLContext associated with this request
Raises:
TypeError: when client is not an instance of Client
TypeError: when fl_ctx is not an instance of FLContext
TypeError: when any standing task containing an invalid client_task
Returns:
Tuple[str, str, Shareable]: task_name, an id for the client_task, and the data for this request
"""
with self._controller_lock:
return self._do_process_task_request(client, fl_ctx)
def _do_process_task_request(self, client: Client, fl_ctx: FLContext) -> Tuple[str, str, Shareable]:
if not isinstance(client, Client):
raise TypeError("client must be an instance of Client, but got {}".format(type(client)))
if not isinstance(fl_ctx, FLContext):
raise TypeError("fl_ctx must be an instance of FLContext, but got {}".format(type(fl_ctx)))
client_task_to_send = None
with self._task_lock:
self.logger.debug("self._tasks: {}".format(self._tasks))
for task in self._tasks:
if task.completion_status is not None:
# this task is finished (and waiting for the monitor to exit it)
continue
# do we need to send this task to this client?
# note: the task could be sent to a client multiple times (e.g. in relay)
# we only check the last ClientTask sent to the client
client_task_to_check = task.last_client_task_map.get(client.name, None)
self.logger.debug("client_task_to_check: {}".format(client_task_to_check))
resend_task = False
if client_task_to_check is not None:
# this client has been sent the task already
if client_task_to_check.result_received_time is None:
# controller has not received result from client
# something wrong happens when client working on this task, so resend the task
resend_task = True
client_task_to_send = client_task_to_check
fl_ctx.set_prop(FLContextKey.IS_CLIENT_TASK_RESEND, True, sticky=False)
if not resend_task:
# check with the task manager whether to send
manager = task.props[_TASK_KEY_MANAGER]
if client_task_to_check is None:
client_task_to_check = ClientTask(task=task, client=client)
check_status = manager.check_task_send(client_task_to_check, fl_ctx)
self.logger.debug(
"Checking client task: {}, task.client.name: {}".format(
client_task_to_check, client_task_to_check.client.name
)
)
self.logger.debug("Check task send get check_status: {}".format(check_status))
if check_status == TaskCheckStatus.BLOCK:
# do not send this task, and do not check other tasks
return self._try_again()
elif check_status == TaskCheckStatus.NO_BLOCK:
# do not send this task, but continue to check next task
continue
else:
# creates the client_task to be checked for sending
client_task_to_send = ClientTask(client, task)
break
# NOTE: move task sending process outside the task lock
# This is to minimize the locking time and to avoid potential deadlock:
# the CB could schedule another task, which requires lock
self.logger.debug("Determining based on client_task_to_send: {}".format(client_task_to_send))
if client_task_to_send is None:
# no task available for this client
return self._try_again()
# try to send the task
can_send_task = True
task = client_task_to_send.task
with task.cb_lock:
# Note: must guarantee the after_task_sent_cb is always called
# regardless whether the task is sent successfully.
# This is so that the app could clear up things in after_task_sent_cb.
if task.before_task_sent_cb is not None:
try:
task.before_task_sent_cb(client_task=client_task_to_send, fl_ctx=fl_ctx)
except Exception as e:
self.log_exception(
fl_ctx,
"processing error in before_task_sent_cb on task {} ({}): {}".format(
client_task_to_send.task.name, client_task_to_send.id, secure_format_exception(e)
),
)
# this task cannot proceed anymore
task.completion_status = TaskCompletionStatus.ERROR
task.exception = e
self.logger.debug("before_task_sent_cb done on client_task_to_send: {}".format(client_task_to_send))
self.logger.debug(f"task completion status is {task.completion_status}")
if task.completion_status is not None:
can_send_task = False
# remember the task name and data to be sent to the client
# since task.data could be reset by the after_task_sent_cb
task_name = task.name
task_data = task.data
operator = task.operator
if task.after_task_sent_cb is not None:
try:
task.after_task_sent_cb(client_task=client_task_to_send, fl_ctx=fl_ctx)
except Exception as e:
self.log_exception(
fl_ctx,
"processing error in after_task_sent_cb on task {} ({}): {}".format(
client_task_to_send.task.name, client_task_to_send.id, secure_format_exception(e)
),
)
task.completion_status = TaskCompletionStatus.ERROR
task.exception = e
if task.completion_status is not None:
# NOTE: the CB could cancel the task
can_send_task = False
if not can_send_task:
return self._try_again()
self.logger.debug("after_task_sent_cb done on client_task_to_send: {}".format(client_task_to_send))
with self._task_lock:
# sent the ClientTask and remember it
now = time.time()
client_task_to_send.task_sent_time = now
client_task_to_send.task_send_count += 1
# add task operator to task_data shareable
if operator:
task_data.set_header(key=ReservedHeaderKey.TASK_OPERATOR, value=operator)
if not resend_task:
task.last_client_task_map[client.name] = client_task_to_send
task.client_tasks.append(client_task_to_send)
self._client_task_map[client_task_to_send.id] = client_task_to_send
task_data.set_header(ReservedHeaderKey.TASK_ID, client_task_to_send.id)
return task_name, client_task_to_send.id, make_copy(task_data)
[docs] def handle_exception(self, task_id: str, fl_ctx: FLContext) -> None:
"""Called to cancel one task as its client_task is causing exception at upper level.
Args:
task_id (str): an id to the failing client_task
fl_ctx (FLContext): FLContext associated with this client_task
"""
with self._task_lock:
# task_id is the uuid associated with the client_task
client_task = self._client_task_map.get(task_id, None)
self.logger.debug("Handle exception on client_task {} with id {}".format(client_task, task_id))
if client_task is None:
# cannot find a standing task on the exception
return
task = client_task.task
self.cancel_task(task=task, fl_ctx=fl_ctx)
self.log_error(fl_ctx, "task {} is cancelled due to exception".format(task.name))
[docs] def process_task_check(self, task_id: str, fl_ctx: FLContext):
with self._task_lock:
# task_id is the uuid associated with the client_task
return self._client_task_map.get(task_id, None)
[docs] def process_submission(self, client: Client, task_name: str, task_id: str, result: Shareable, fl_ctx: FLContext):
"""Called to process a submission from one client.
.. note::
This method is called by a separate thread.
Args:
client (Client): the client that submitted this task
task_name (str): the task name associated this submission
task_id (str): the id associated with the client_task
result (Shareable): the actual submitted data from the client
fl_ctx (FLContext): the FLContext associated with this submission
Raises:
TypeError: when client is not an instance of Client
TypeError: when fl_ctx is not an instance of FLContext
TypeError: when result is not an instance of Shareable
ValueError: task_name is not found in the client_task
"""
with self._controller_lock:
self._do_process_submission(client, task_name, task_id, result, fl_ctx)
def _do_process_submission(
self, client: Client, task_name: str, task_id: str, result: Shareable, fl_ctx: FLContext
):
if not isinstance(client, Client):
raise TypeError("client must be an instance of Client, but got {}".format(type(client)))
if not isinstance(fl_ctx, FLContext):
raise TypeError("fl_ctx must be an instance of FLContext, but got {}".format(type(fl_ctx)))
if not isinstance(result, Shareable):
raise TypeError("result must be an instance of Shareable, but got {}".format(type(result)))
with self._task_lock:
# task_id is the uuid associated with the client_task
client_task = self._client_task_map.get(task_id, None)
self.log_debug(fl_ctx, "Get submission from client task={} id={}".format(client_task, task_id))
if client_task is None:
# cannot find a standing task for the submission
self.log_debug(fl_ctx, "no standing task found for {}:{}".format(task_name, task_id))
self.log_debug(fl_ctx, "firing event EventType.BEFORE_PROCESS_RESULT_OF_UNKNOWN_TASK")
self.fire_event(EventType.BEFORE_PROCESS_RESULT_OF_UNKNOWN_TASK, fl_ctx)
self.controller.process_result_of_unknown_task(client, task_name, task_id, result, fl_ctx)
self.log_debug(fl_ctx, "firing event EventType.AFTER_PROCESS_RESULT_OF_UNKNOWN_TASK")
self.fire_event(EventType.AFTER_PROCESS_RESULT_OF_UNKNOWN_TASK, fl_ctx)
return
task = client_task.task
with task.cb_lock:
if task.name != task_name:
raise ValueError("client specified task name {} doesn't match {}".format(task_name, task.name))
if task.completion_status is not None:
# the task is already finished - drop the result
self.log_info(fl_ctx, "task is already finished - submission dropped")
return
# do client task CB processing outside the lock
# this is because the CB could schedule another task, which requires the lock
client_task.result = result
manager = task.props[_TASK_KEY_MANAGER]
manager.check_task_result(result, client_task, fl_ctx)
if task.result_received_cb is not None:
try:
self.log_debug(fl_ctx, "invoking result_received_cb ...")
task.result_received_cb(client_task=client_task, fl_ctx=fl_ctx)
except Exception as e:
# this task cannot proceed anymore
self.log_exception(
fl_ctx,
"processing error in result_received_cb on task {}({}): {}".format(
task_name, task_id, secure_format_exception(e)
),
)
task.completion_status = TaskCompletionStatus.ERROR
task.exception = e
else:
self.log_debug(fl_ctx, "no result_received_cb")
client_task.result_received_time = time.time()
def _schedule_task(
self,
task: Task,
fl_ctx: FLContext,
manager: TaskManager,
targets: Union[List[Client], List[str], None],
allow_dup_targets: bool = False,
):
if task.schedule_time is not None:
# this task was scheduled before
# we do not allow a task object to be reused
self.logger.debug("task.schedule_time: {}".format(task.schedule_time))
raise ValueError("Task was already used. Please create a new task object.")
# task.targets = targets
target_names = list()
if targets is None:
for client in self._engine.get_clients():
target_names.append(client.name)
else:
if not isinstance(targets, list):
raise ValueError("task targets must be a list, but got {}".format(type(targets)))
for t in targets:
if isinstance(t, str):
name = t
elif isinstance(t, Client):
name = t.name
else:
raise ValueError("element in targets must be string or Client type, but got {}".format(type(t)))
if allow_dup_targets or (name not in target_names):
target_names.append(name)
task.targets = target_names
task.props[_TASK_KEY_MANAGER] = manager
task.props[_TASK_KEY_ENGINE] = self._engine
task.is_standing = True
task.schedule_time = time.time()
with self._task_lock:
self._tasks.append(task)
self.log_info(fl_ctx, "scheduled task {}".format(task.name))
[docs] def broadcast(
self,
task: Task,
fl_ctx: FLContext,
targets: Union[List[Client], List[str], None] = None,
min_responses: int = 1,
wait_time_after_min_received: int = 0,
):
"""Schedule a broadcast task. This is a non-blocking call.
The task is scheduled into a task list.
Clients can request tasks and controller will dispatch the task to eligible clients.
Args:
task (Task): the task to be scheduled
fl_ctx (FLContext): FLContext associated with this task
targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names
or None (all clients). Defaults to None.
min_responses (int, optional): the condition to mark this task as completed because enough clients
respond with submission. Defaults to 1.
wait_time_after_min_received (int, optional): a grace period for late clients to contribute their
submission. 0 means no grace period.
Submission of late clients in the grace period are still collected as valid submission. Defaults to 0.
Raises:
ValueError: min_responses is greater than the length of targets since this condition will make the task,
if allowed to be scheduled, never exit.
"""
_check_inputs(task=task, fl_ctx=fl_ctx, targets=targets)
_check_positive_int("min_responses", min_responses)
_check_positive_int("wait_time_after_min_received", wait_time_after_min_received)
if targets and min_responses > len(targets):
raise ValueError(
"min_responses ({}) must be less than length of targets ({}).".format(min_responses, len(targets))
)
manager = BcastTaskManager(
task=task, min_responses=min_responses, wait_time_after_min_received=wait_time_after_min_received
)
self._schedule_task(task=task, fl_ctx=fl_ctx, manager=manager, targets=targets)
[docs] def broadcast_and_wait(
self,
task: Task,
fl_ctx: FLContext,
targets: Union[List[Client], List[str], None] = None,
min_responses: int = 1,
wait_time_after_min_received: int = 0,
abort_signal: Optional[Signal] = None,
):
"""Schedule a broadcast task. This is a blocking call.
The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task
to eligible clients.
Args:
task (Task): the task to be scheduled
fl_ctx (FLContext): FLContext associated with this task
targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names
or None (all clients). Defaults to None.
min_responses (int, optional): the condition to mark this task as completed because enough clients
respond with submission. Defaults to 1.
wait_time_after_min_received (int, optional): a grace period for late clients to contribute their
submission. 0 means no grace period.
Submission of late clients in the grace period are still collected as valid submission. Defaults to 0.
abort_signal (Optional[Signal], optional): as this is a blocking call, this abort_signal informs
this method to return. Defaults to None.
"""
self.broadcast(
task=task,
fl_ctx=fl_ctx,
targets=targets,
min_responses=min_responses,
wait_time_after_min_received=wait_time_after_min_received,
)
self.wait_for_task(task, abort_signal)
[docs] def broadcast_forever(self, task: Task, fl_ctx: FLContext, targets: Union[List[Client], List[str], None] = None):
"""Schedule a broadcast task. This is a non-blocking call.
The task is scheduled into a task list. Clients can request tasks and controller will dispatch
the task to eligible clients.
This broadcast will not end.
Args:
task (Task): the task to be scheduled
fl_ctx (FLContext): FLContext associated with this task
targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names
or None (all clients). Defaults to None.
"""
_check_inputs(task=task, fl_ctx=fl_ctx, targets=targets)
manager = BcastForeverTaskManager()
self._schedule_task(task=task, fl_ctx=fl_ctx, manager=manager, targets=targets)
[docs] def send(
self,
task: Task,
fl_ctx: FLContext,
targets: Union[List[Client], List[str], None] = None,
send_order: SendOrder = SendOrder.SEQUENTIAL,
task_assignment_timeout: int = 0,
):
"""Schedule a single task to targets. This is a non-blocking call.
The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task
to eligible clients based on the send_order.
Args:
task (Task): the task to be scheduled
fl_ctx (FLContext): FLContext associated with this task
targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names
or None (all clients). Defaults to None.
send_order (SendOrder, optional): the order for clients to become eligible.
SEQUENTIAL means the order in targets is enforced.
ANY means clients in targets and haven't received task are eligible for task.
Defaults to SendOrder.SEQUENTIAL.
task_assignment_timeout (int, optional): how long to wait for one client to pick the task. Defaults to 0.
Raises:
ValueError: when task_assignment_timeout is greater than task's timeout.
TypeError: send_order is not defined in SendOrder
ValueError: targets is None or an empty list
"""
_check_inputs(
task=task,
fl_ctx=fl_ctx,
targets=targets,
)
_check_positive_int("task_assignment_timeout", task_assignment_timeout)
if task.timeout and task_assignment_timeout and task_assignment_timeout > task.timeout:
raise ValueError(
"task_assignment_timeout ({}) needs to be less than or equal to task.timeout ({}).".format(
task_assignment_timeout, task.timeout
)
)
if not isinstance(send_order, SendOrder):
raise TypeError("send_order must be in Enum SendOrder, but got {}".format(type(send_order)))
# targets must be provided
if targets is None or len(targets) == 0:
raise ValueError("Targets must be provided for send.")
manager = SendTaskManager(task, send_order, task_assignment_timeout)
self._schedule_task(
task=task,
fl_ctx=fl_ctx,
manager=manager,
targets=targets,
)
[docs] def send_and_wait(
self,
task: Task,
fl_ctx: FLContext,
targets: Union[List[Client], List[str], None] = None,
send_order: SendOrder = SendOrder.SEQUENTIAL,
task_assignment_timeout: int = 0,
abort_signal: Signal = None,
):
"""Schedule a single task to targets. This is a blocking call.
The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task
to eligible clients based on the send_order.
Args:
task (Task): the task to be scheduled
fl_ctx (FLContext): FLContext associated with this task
targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names
or None (all clients). Defaults to None.
send_order (SendOrder, optional): the order for clients to become eligible.
SEQUENTIAL means the order in targets is enforced.
ANY means clients in targets and haven't received task are eligible for task.
Defaults to SendOrder.SEQUENTIAL.
task_assignment_timeout (int, optional): how long to wait for one client to pick the task. Defaults to 0.
abort_signal (Optional[Signal], optional): as this is a blocking call, this abort_signal informs this
method to return. Defaults to None.
"""
self.send(
task=task,
fl_ctx=fl_ctx,
targets=targets,
send_order=send_order,
task_assignment_timeout=task_assignment_timeout,
)
self.wait_for_task(task, abort_signal)
[docs] def get_num_standing_tasks(self) -> int:
"""Get the number of tasks that are currently standing.
Returns:
int: length of the list of standing tasks
"""
return len(self._tasks)
[docs] def cancel_task(
self, task: Task, completion_status=TaskCompletionStatus.CANCELLED, fl_ctx: Optional[FLContext] = None
):
"""Cancel the specified task.
Change the task completion_status, which will inform task monitor to clean up this task
note::
We only mark the task as completed and leave it to the task monitor to clean up.
This is to avoid potential deadlock of task_lock.
Args:
task (Task): the task to be cancelled
completion_status (str, optional): the completion status for this cancellation.
Defaults to TaskCompletionStatus.CANCELLED.
fl_ctx (Optional[FLContext], optional): FLContext associated with this cancellation. Defaults to None.
"""
task.completion_status = completion_status
[docs] def cancel_all_tasks(self, completion_status=TaskCompletionStatus.CANCELLED, fl_ctx: Optional[FLContext] = None):
"""Cancel all standing tasks in this controller.
Args:
completion_status (str, optional): the completion status for this cancellation.
Defaults to TaskCompletionStatus.CANCELLED.
fl_ctx (Optional[FLContext], optional): FLContext associated with this cancellation. Defaults to None.
"""
with self._task_lock:
for t in self._tasks:
t.completion_status = completion_status
[docs] def finalize_run(self, fl_ctx: FLContext):
"""Do cleanup of the coordinator implementation.
.. attention::
Subclass controllers should not overwrite finalize_run.
Args:
fl_ctx (FLContext): FLContext associated with this action
"""
self.cancel_all_tasks() # unconditionally cancel all tasks
self._all_done = True
[docs] def relay(
self,
task: Task,
fl_ctx: FLContext,
targets: Union[List[Client], List[str], None] = None,
send_order: SendOrder = SendOrder.SEQUENTIAL,
task_assignment_timeout: int = 0,
task_result_timeout: int = 0,
dynamic_targets: bool = True,
):
"""Schedule a single task to targets in one-after-another style. This is a non-blocking call.
The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task
to eligible clients based on the send_order.
Args:
task (Task): the task to be scheduled
fl_ctx (FLContext): FLContext associated with this task
targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names
or None (all clients). Defaults to None.
send_order (SendOrder, optional): the order for clients to become eligible.
SEQUENTIAL means the order in targets is enforced.
ANY means any clients that are inside the targets and haven't received the task are eligible.
Defaults to SendOrder.SEQUENTIAL.
task_assignment_timeout (int, optional): how long to wait for one client to pick the task. Defaults to 0.
task_result_timeout (int, optional): how long to wait for current working client to reply its result.
Defaults to 0.
dynamic_targets (bool, optional): allow clients not in targets to join at the end of targets list.
Defaults to True.
Raises:
ValueError: when task_assignment_timeout is greater than task's timeout
ValueError: when task_result_timeout is greater than task's timeout
TypeError: send_order is not defined in SendOrder
TypeError: when dynamic_targets is not a boolean variable
ValueError: targets is None or an empty list but dynamic_targets is False
"""
_check_inputs(
task=task,
fl_ctx=fl_ctx,
targets=targets,
)
_check_positive_int("task_assignment_timeout", task_assignment_timeout)
_check_positive_int("task_result_timeout", task_result_timeout)
if task.timeout and task_assignment_timeout and task_assignment_timeout > task.timeout:
raise ValueError(
"task_assignment_timeout ({}) needs to be less than or equal to task.timeout ({}).".format(
task_assignment_timeout, task.timeout
)
)
if task.timeout and task_result_timeout and task_result_timeout > task.timeout:
raise ValueError(
"task_result_timeout ({}) needs to be less than or equal to task.timeout ({}).".format(
task_result_timeout, task.timeout
)
)
if not isinstance(send_order, SendOrder):
raise TypeError("send_order must be in Enum SendOrder, but got {}".format(type(send_order)))
if not isinstance(dynamic_targets, bool):
raise TypeError("dynamic_targets must be an instance of bool, but got {}".format(type(dynamic_targets)))
if targets is None and dynamic_targets is False:
raise ValueError("Need to provide targets when dynamic_targets is set to False.")
if send_order == SendOrder.SEQUENTIAL:
manager = SequentialRelayTaskManager(
task=task,
task_assignment_timeout=task_assignment_timeout,
task_result_timeout=task_result_timeout,
dynamic_targets=dynamic_targets,
)
else:
manager = AnyRelayTaskManager(
task=task, task_result_timeout=task_result_timeout, dynamic_targets=dynamic_targets
)
self._schedule_task(
task=task,
fl_ctx=fl_ctx,
manager=manager,
targets=targets,
allow_dup_targets=True,
)
[docs] def relay_and_wait(
self,
task: Task,
fl_ctx: FLContext,
targets: Union[List[Client], List[str], None] = None,
send_order=SendOrder.SEQUENTIAL,
task_assignment_timeout: int = 0,
task_result_timeout: int = 0,
dynamic_targets: bool = True,
abort_signal: Optional[Signal] = None,
):
"""Schedule a single task to targets in one-after-another style. This is a blocking call.
The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task
to eligible clients based on the send_order.
Args:
task (Task): the task to be scheduled
fl_ctx (FLContext): FLContext associated with this task
targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names
or None (all clients). Defaults to None.
send_order (SendOrder, optional): the order for clients to become eligible.
SEQUENTIAL means the order in targets is enforced.
ANY means clients in targets and haven't received task are eligible for task.
Defaults to SendOrder.SEQUENTIAL.
task_assignment_timeout (int, optional): how long to wait for one client to pick the task. Defaults to 0.
task_result_timeout (int, optional): how long to wait for current working client to reply its result.
Defaults to 0.
dynamic_targets (bool, optional): allow clients not in targets to join at the end of targets list.
Defaults to True.
abort_signal (Optional[Signal], optional): as this is a blocking call, this abort_signal informs
this method to return. Defaults to None.
"""
self.relay(
task=task,
fl_ctx=fl_ctx,
targets=targets,
send_order=send_order,
task_assignment_timeout=task_assignment_timeout,
task_result_timeout=task_result_timeout,
dynamic_targets=dynamic_targets,
)
self.wait_for_task(task, abort_signal)
def _check_dead_clients(self):
if not self._dead_clients:
return
now = time.time()
with self._dead_clients_lock:
for client_name, status in self._dead_clients.items():
if status.disconnect_time:
# already disconnected
continue
if now - status.report_time < self._dead_client_grace:
# this report is still fresh - consider the client to be still alive
continue
# consider client disconnected
status.disconnect_time = now
self.logger.error(f"Client {client_name} is deemed disconnected!")
with self._engine.new_context() as fl_ctx:
fl_ctx.set_prop(FLContextKey.DISCONNECTED_CLIENT_NAME, client_name)
self.fire_event(EventType.CLIENT_DISCONNECTED, fl_ctx)
def _monitor_tasks(self):
while not self._all_done:
# determine clients are still active or not
self._check_dead_clients()
should_abort_job = self._job_policy_violated()
if not should_abort_job:
self.check_tasks()
else:
with self._engine.new_context() as fl_ctx:
self.system_panic("Aborting job due to deployment policy violation", fl_ctx)
return
time.sleep(self._task_check_period)
[docs] def check_tasks(self):
with self._controller_lock:
self._do_check_tasks()
def _do_check_tasks(self):
exit_tasks = []
with self._task_lock:
for task in self._tasks:
if task.completion_status is not None:
exit_tasks.append(task)
continue
# check the task-specific exit condition
manager = task.props[_TASK_KEY_MANAGER]
if manager is not None:
if not isinstance(manager, TaskManager):
raise TypeError(
"manager in task must be an instance of TaskManager, but got {}".format(manager)
)
should_exit, exit_status = manager.check_task_exit(task)
self.logger.debug("should_exit: {}, exit_status: {}".format(should_exit, exit_status))
if should_exit:
task.completion_status = exit_status
exit_tasks.append(task)
continue
# check if task timeout
if task.timeout and time.time() - task.schedule_time >= task.timeout:
task.completion_status = TaskCompletionStatus.TIMEOUT
exit_tasks.append(task)
continue
# check whether clients that the task is waiting are all dead
dead_clients = self._get_task_dead_clients(task)
if dead_clients:
self.logger.info(f"client {dead_clients} is dead - set task {task.name} to TIMEOUT")
task.completion_status = TaskCompletionStatus.CLIENT_DEAD
exit_tasks.append(task)
continue
for exit_task in exit_tasks:
exit_task.is_standing = False
self.logger.debug(
"Removing task={}, completion_status={}".format(exit_task, exit_task.completion_status)
)
self._tasks.remove(exit_task)
for client_task in exit_task.client_tasks:
self.logger.debug("Removing client_task with id={}".format(client_task.id))
self._client_task_map.pop(client_task.id)
# do the task exit processing outside the lock to minimize the locking time
# and to avoid potential deadlock since the CB could schedule another task
if len(exit_tasks) <= 0:
return
with self._engine.new_context() as fl_ctx:
for exit_task in exit_tasks:
with exit_task.cb_lock:
self.log_info(
fl_ctx, "task {} exit with status {}".format(exit_task.name, exit_task.completion_status)
)
if exit_task.task_done_cb is not None:
try:
exit_task.task_done_cb(task=exit_task, fl_ctx=fl_ctx)
except Exception as e:
self.log_exception(
fl_ctx,
"processing error in task_done_cb error on task {}: {}".format(
exit_task.name, secure_format_exception(e)
),
)
exit_task.completion_status = TaskCompletionStatus.ERROR
exit_task.exception = e
def _get_task_dead_clients(self, task: Task):
"""
See whether the task is only waiting for response from a dead client
"""
now = time.time()
lead_time = ConfigService.get_float_var(
name=ConfigVarName.DEAD_CLIENT_CHECK_LEAD_TIME, conf=SystemConfigs.APPLICATION_CONF, default=30.0
)
if now - task.schedule_time < lead_time:
# due to potential race conditions, we'll wait for at least 1 minute after the task
# is started before checking dead clients.
return None
dead_clients = []
for target in task.targets:
ct = _get_client_task(target, task)
if ct is not None and ct.result_received_time:
# response has been received from this client
continue
# either we have not sent the task to this client or we have not received response
# is the client already dead?
if self.get_client_disconnect_time(target):
# this client is dead - remember it
dead_clients.append(target)
else:
# this client is still alive
# we let the task continue its course since we still have live clients
return None
return dead_clients
@staticmethod
def _process_finished_task(task, func):
def wrap(*args, **kwargs):
if func:
func(*args, **kwargs)
task.props[_TASK_KEY_DONE] = True
return wrap
[docs] def wait_for_task(self, task: Task, abort_signal: Signal):
task.props[_TASK_KEY_DONE] = False
task.task_done_cb = self._process_finished_task(task=task, func=task.task_done_cb)
while True:
if task.completion_status is not None:
break
if abort_signal and abort_signal.triggered:
self.cancel_task(task, fl_ctx=None, completion_status=TaskCompletionStatus.ABORTED)
break
task_done = task.props[_TASK_KEY_DONE]
if task_done:
break
time.sleep(self._task_check_period)
def _job_policy_violated(self):
if not self._engine:
return False
with self._engine.new_context() as fl_ctx:
clients = self._engine.get_clients()
alive_clients = []
dead_clients = []
for client in clients:
if self.get_client_disconnect_time(client.name):
dead_clients.append(client.name)
else:
alive_clients.append(client.name)
if not dead_clients:
return False
if not alive_clients:
self.log_error(fl_ctx, f"All clients are dead: {dead_clients}")
return True
job_meta = fl_ctx.get_prop(FLContextKey.JOB_META)
job = job_from_meta(job_meta)
if len(alive_clients) < job.min_sites:
self.log_error(fl_ctx, f"Alive clients {len(alive_clients)} < required min {job.min_sites}")
return True
# check required clients:
if dead_clients and job.required_sites:
dead_required_clients = [c for c in dead_clients if c in job.required_sites]
if dead_required_clients:
self.log_error(fl_ctx, f"Required client(s) dead: {dead_required_clients}")
return True
return False
[docs] def client_is_active(self, client_name: str, reason: str, fl_ctx: FLContext):
with self._dead_clients_lock:
self.log_debug(fl_ctx, f"client {client_name} is active: {reason}")
if client_name in self._dead_clients:
self.log_info(fl_ctx, f"Client {client_name} is removed from watch list: {reason}")
status = self._dead_clients.pop(client_name)
if status.disconnect_time:
self.log_info(fl_ctx, f"Client {client_name} is reconnected")
fl_ctx.set_prop(FLContextKey.RECONNECTED_CLIENT_NAME, client_name)
self.fire_event(EventType.CLIENT_RECONNECTED, fl_ctx)
[docs] def get_client_disconnect_time(self, client_name: str):
"""Get the time that the client was deemed disconnected
Args:
client_name: name of the client
Returns: time at which the client was deemed disconnected; or None if the client is not disconnected
"""
status = self._dead_clients.get(client_name)
if status:
return status.disconnect_time
return None