Source code for nvflare.apis.impl.any_relay_manager

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
from typing import Tuple

from nvflare.apis.controller_spec import ClientTask, Task, TaskCompletionStatus
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import ReservedHeaderKey, Shareable

from .task_manager import TaskCheckStatus, TaskManager

# Keys under which AnyRelayTaskManager keeps its bookkeeping in ``task.props``.
# Whether a requesting client not yet in ``task.targets`` is added to it.
_KEY_DYNAMIC_TARGETS = "__dynamic_targets"
# Timeout on the reply of the pending client; a falsy value disables the timeout check.
_KEY_TASK_RESULT_TIMEOUT = "__task_result_timeout"
# Dict mapping target name => number of times the task has been sent to it.
_KEY_SEND_TARGET_COUNTS = "__sent_target_count"
# Name of the client the task was most recently sent to, or None if never sent.
_KEY_PENDING_CLIENT = "__pending_client"


class AnyRelayTaskManager(TaskManager):
    """Task manager for relay controller on SendOrder.ANY.

    The task is handed to whichever eligible client asks for it, with at most
    one client holding an unanswered copy at a time. All bookkeeping is stored
    in ``task.props`` under the module-level ``_KEY_*`` keys.
    """

    def __init__(self, task: Task, task_result_timeout, dynamic_targets):
        """Initialize the relay bookkeeping on the task.

        Args:
            task (Task): an instance of Task
            task_result_timeout (int): timeout value on reply of one client;
                None is treated as 0, which disables the timeout check
            dynamic_targets (bool): allow clients to join after this task starts
        """
        super().__init__()
        if task_result_timeout is None:
            task_result_timeout = 0

        task.props[_KEY_DYNAMIC_TARGETS] = dynamic_targets
        task.props[_KEY_TASK_RESULT_TIMEOUT] = task_result_timeout
        task.props[_KEY_SEND_TARGET_COUNTS] = {}  # target name => times sent
        task.props[_KEY_PENDING_CLIENT] = None

    def check_task_send(self, client_task: ClientTask, fl_ctx: FLContext) -> TaskCheckStatus:
        """Determine whether the task should be sent to the client.

        Args:
            client_task (ClientTask): the task processing state of the client
            fl_ctx (FLContext): fl context that comes with the task request

        Raises:
            RuntimeError: when a client asking for a task while the same client_task has already been dispatched to it

        Returns:
            TaskCheckStatus: NO_BLOCK for not sending the task, BLOCK for waiting, SEND for OK to send
        """
        client_name = client_task.client.name
        task = client_task.task

        if task.props[_KEY_DYNAMIC_TARGETS]:
            # With dynamic targets, any asking client is added as a target (once).
            if task.targets is None:
                task.targets = []
            if client_name not in task.targets:
                task.targets.append(client_name)

        # is this client eligible?
        # NOTE(review): assumes task.targets is not None when dynamic targets are
        # disabled (check_task_exit tolerates None) - confirm the caller guarantees it.
        if client_name not in task.targets:
            # this client is not a target
            return TaskCheckStatus.NO_BLOCK

        # A client may be listed several times; it gets the task once per listing.
        client_occurrences = task.targets.count(client_name)
        sent_target_count = task.props[_KEY_SEND_TARGET_COUNTS]
        send_count = sent_target_count.get(client_name, 0)
        if send_count >= client_occurrences:
            # already sent enough times to this client
            return TaskCheckStatus.NO_BLOCK

        # only allow one pending task. Is there a client pending result?
        pending_client_name = task.props[_KEY_PENDING_CLIENT]
        task_result_timeout = task.props[_KEY_TASK_RESULT_TIMEOUT]
        if pending_client_name is not None:
            # see whether the result has been received
            pending_task = task.last_client_task_map[pending_client_name]
            if pending_task.result_received_time is None:
                # result has not been received
                # Note: in this case, the pending client and the asking client must not be the
                # same, because this would be a resend case already taken care of by the controller.
                if pending_client_name == client_name:
                    raise RuntimeError(f"Logic Error: must not be here for client {client_name}")

                # should the pending client time out? (a falsy timeout never expires)
                timed_out = task_result_timeout and (
                    time.time() - pending_task.task_sent_time > task_result_timeout
                )
                if not timed_out:
                    # continue to wait
                    return TaskCheckStatus.BLOCK

                # timeout! give up on the pending task and move to the next target:
                # undo the send recorded for the client that did not answer in time
                sent_target_count[pending_client_name] -= 1

        # can send
        task.props[_KEY_PENDING_CLIENT] = client_name
        sent_target_count[client_name] = send_count + 1
        return TaskCheckStatus.SEND

    def check_task_exit(self, task: Task) -> Tuple[bool, TaskCompletionStatus]:
        """Determine whether the task should exit.

        Args:
            task (Task): an instance of Task

        Returns:
            Tuple[bool, TaskCompletionStatus]:
                first entry in the tuple means whether to exit the task or not. If it's True, the task should exit.
                second entry in the tuple indicates the TaskCompletionStatus.
        """
        # are we waiting for any client?
        num_targets = 0 if task.targets is None else len(task.targets)
        if num_targets == 0:
            # nothing has been sent
            return False, TaskCompletionStatus.IGNORED

        # see whether all targets are sent
        total_sent = sum(task.props[_KEY_SEND_TARGET_COUNTS].values())
        if total_sent < num_targets:
            return False, TaskCompletionStatus.IGNORED

        # client_tasks might have not been added to task
        if len(task.client_tasks) < num_targets:
            return False, TaskCompletionStatus.IGNORED

        # every dispatched client task must have reported a result
        if any(c_t.result_received_time is None for c_t in task.client_tasks):
            return False, TaskCompletionStatus.IGNORED

        return True, TaskCompletionStatus.OK

    def check_task_result(self, result: Shareable, client_task: ClientTask, fl_ctx: FLContext):
        """Check the result received from the client.

        See whether the client_task is the last one in the task's list
        If not, then it is a late response and ReservedHeaderKey.REPLY_IS_LATE is
        set to True in result's header.

        Args:
            result (Shareable): an instance of Shareable
            client_task (ClientTask): the task processing state of the client
            fl_ctx (FLContext): fl context that comes with the task request

        Returns:
            None. The result is only modified in place (header set) when late.
        """
        task = client_task.task
        if client_task != task.client_tasks[-1]:
            result.set_header(key=ReservedHeaderKey.REPLY_IS_LATE, value=True)