# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import time
from dataclasses import dataclass
from typing import Optional
from nvflare.apis.event_type import EventType
from nvflare.apis.executor import Executor
from nvflare.apis.fl_constant import FLContextKey, FLMetaKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import ReturnCode, Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.app_common.app_constant import AppConstants
from nvflare.fuel.f3.streaming.download_service import DownloadService
from nvflare.fuel.f3.streaming.transfer_progress import (
DEFAULT_STREAMING_IDLE_TIMEOUT,
DIRECTION_TASK_PAYLOAD_DOWNLOAD,
STREAM_PROGRESS_COMPLETION_ACK_GRACE,
TransferProgressState,
TransferProgressTracker,
)
from nvflare.fuel.utils.constants import PipeChannelName
from nvflare.fuel.utils.pipe.cell_pipe import CellPipe
from nvflare.fuel.utils.pipe.pipe import Message, Pipe, Topic
from nvflare.fuel.utils.pipe.pipe_handler import PipeHandler
from nvflare.fuel.utils.validation_utils import (
check_non_negative_int,
check_non_negative_number,
check_positive_number,
check_str,
)
from nvflare.security.logging import secure_format_exception
# Accepted key names for each logical field of a stream-progress payload.
# Aliases are listed in lookup order: TaskExchanger._get_progress_event_value
# returns the value of the first key that is present in the payload.
STREAM_PROGRESS_TASK_ID_KEYS = ("task_id", "req_id", "request_id", "msg_id")
STREAM_PROGRESS_JOB_ID_KEYS = ("job_id",)
STREAM_PROGRESS_TRANSFER_ID_KEYS = ("transfer_id", "ref_id", "stream_id")
STREAM_PROGRESS_TRANSFER_ID_KIND_KEYS = ("transfer_id_kind", "stream_id_kind")
STREAM_PROGRESS_DIRECTION_KEYS = ("direction",)
STREAM_PROGRESS_RECEIVER_ID_KEYS = ("receiver_id", "requester_id", "requester_fqcn")
STREAM_PROGRESS_SEQUENCE_KEYS = ("sequence", "seq")
STREAM_PROGRESS_BYTES_KEYS = ("bytes_done", "progress", "bytes", "bytes_read", "bytes_received")
STREAM_PROGRESS_ITEM_KEYS = ("items_done", "items", "item_count")
STREAM_PROGRESS_STATE_KEYS = ("state", "status", "event", "event_type")
# Lower-cased status values that mark the (re)start of a transfer. A start event is
# what allows a previously FAILED/ABORTED record to be dropped and tracked afresh.
STREAM_PROGRESS_START_STATUSES = ("start", "started")
# Fallback idle timeout used wherever streaming_idle_timeout is None or 0.
_DEFAULT_STREAMING_IDLE_TIMEOUT_SECS = DEFAULT_STREAMING_IDLE_TIMEOUT
# Cap on tracked progress records per direction; a value <= 0 disables the cap.
STREAM_PROGRESS_MAX_TRACKED_RECORDS = 4096
# Match the DownloadService finished-ref tombstone window so late EOF/completion
# replies after clean transfer completion can still find the progress record.
STREAM_PROGRESS_TERMINAL_RECORD_TTL = DownloadService.FINISHED_REFS_TTL
# Maps lower-cased status strings found in progress payloads to tracker state names.
# Statuses missing from this table are treated as TransferProgressState.ACTIVE by
# TaskExchanger._normalize_tracker_state.
# NOTE(review): the literal values are assumed to equal the TransferProgressState
# constants (ACTIVE/COMPLETED/FAILED/ABORTED) - confirm against transfer_progress.
STREAM_PROGRESS_STATE_ALIASES = {
"active": "active",
"progress": "active",
"in_progress": "active",
"running": "active",
"start": "active",
"started": "active",
"completed": "completed",
"complete": "completed",
"done": "completed",
"success": "completed",
"failed": "failed",
"fail": "failed",
"failure": "failed",
"error": "failed",
"exception": "failed",
"aborted": "aborted",
"abort": "aborted",
"cancelled": "aborted",
}
@dataclass(frozen=True)
class _StreamProgressRecordSnapshot:
    """Immutable copy of the fields of one progress-tracker record.

    Built by ``TaskExchanger._get_active_task_payload_records`` while the
    stream-progress lock is held, so the values can be inspected afterwards
    without touching the live tracker record.
    """

    # Scope of the transfer this record belongs to.
    job_id: str
    task_id: str
    transfer_id: str
    direction: str
    # Progress counters as last accepted by the tracker.
    sequence: int
    bytes_done: int
    items_done: Optional[int]
    # Timestamps copied from the tracker record (compared against time.time()).
    started_time: float
    last_progress_time: float
    # Tracker state name; see TransferProgressState.
    state: str
    transfer_id_kind: Optional[str]

    @property
    def terminal(self) -> bool:
        """Whether ``state`` is one of ``TransferProgressState.TERMINAL_STATES``."""
        return self.state in TransferProgressState.TERMINAL_STATES
@dataclass(frozen=True)
class _StreamingTimeoutSnapshot:
    """Immutable bundle of the three timeout settings read together by the send loop."""

    # True when peer_read_timeout was supplied as an explicit user override.
    peer_read_timeout_explicit: bool
    # Seconds to wait for the peer to read a sent message; None disables the bound.
    peer_read_timeout: Optional[float]
    # Idle window (seconds) within which stream progress keeps a send alive.
    streaming_idle_timeout: Optional[float]
[docs]
class TaskExchanger(Executor):
def __init__(
    self,
    pipe_id: str,
    read_interval: float = 0.5,
    heartbeat_interval: float = 5.0,
    heartbeat_timeout: Optional[float] = 60.0,
    resend_interval: float = 2.0,
    max_resends: Optional[int] = None,
    peer_read_timeout: Optional[float] = 60.0,
    peer_read_timeout_explicit: bool = False,
    streaming_idle_timeout: Optional[float] = _DEFAULT_STREAMING_IDLE_TIMEOUT_SECS,
    task_wait_time: Optional[float] = None,
    result_poll_interval: float = 0.5,
    pipe_channel_name=PipeChannelName.TASK,
):
    """Constructor of TaskExchanger.

    Args:
        pipe_id (str): component id of the pipe used to talk to the peer.
        read_interval (float): how often to read from the pipe.
        heartbeat_interval (float): how often to send a heartbeat to the peer.
        heartbeat_timeout (float, optional): how long to wait for a heartbeat from the
            peer before treating the peer as dead. 0 means DO NOT check for heartbeat.
            None passes validation here, but creating the pipe handler later raises
            ValueError when the value is still None.
        resend_interval (float): how often to resend a message if sending fails. If the
            pipe does not support resending, no resend happens.
            NOTE(review): earlier documentation said "None means no resend", but the
            value is validated with check_positive_number - confirm None is accepted.
        max_resends (int, optional): max number of resends. None means no limit.
            Defaults to None.
        peer_read_timeout (float, optional): time to wait for the peer to accept a sent
            message.
        peer_read_timeout_explicit (bool): whether peer_read_timeout came from an
            explicit user override. When it is lower than streaming_idle_timeout, this
            preserves fast-fail behavior instead of progress-extending.
        streaming_idle_timeout (float, optional): when a task-send peer-read times out,
            continue waiting only while the exact transfer has made monotonic stream
            progress within this many seconds.
        task_wait_time (float, optional): how long to wait for a task to complete.
            None means waiting forever. Defaults to None.
        result_poll_interval (float): how often to poll for the task result.
            Defaults to 0.5.
        pipe_channel_name: the channel name for sending task requests.
            Defaults to "task".
    """
    Executor.__init__(self)

    # Validation runs in parameter order so the first bad argument is the one reported.
    check_str("pipe_id", pipe_id)
    check_positive_number("read_interval", read_interval)
    check_positive_number("heartbeat_interval", heartbeat_interval)
    if heartbeat_timeout is not None:
        check_non_negative_number("heartbeat_timeout", heartbeat_timeout)
    check_positive_number("resend_interval", resend_interval)
    if max_resends is not None:
        check_non_negative_int("max_resends", max_resends)
    for arg_name, arg_value in (
        ("peer_read_timeout", peer_read_timeout),
        ("streaming_idle_timeout", streaming_idle_timeout),
        ("task_wait_time", task_wait_time),
    ):
        if arg_value is not None:
            check_positive_number(arg_name, arg_value)
    check_positive_number("result_poll_interval", result_poll_interval)
    check_str("pipe_channel_name", pipe_channel_name)

    # Pipe / handler wiring; both objects are resolved later from fired events.
    self.pipe_id = pipe_id
    self.pipe_channel_name = pipe_channel_name
    self.pipe = None
    self.pipe_handler = None

    # PipeHandler tuning.
    self.read_interval = read_interval
    self.heartbeat_interval = heartbeat_interval
    self.heartbeat_timeout = heartbeat_timeout
    self.resend_interval = resend_interval
    self.max_resends = max_resends

    # Timeout values are immutable after construction. Send-loop readers use snapshots to keep
    # related values consistent and to make future runtime reconfiguration explicit.
    self.peer_read_timeout = peer_read_timeout
    self.peer_read_timeout_explicit = peer_read_timeout_explicit
    self.streaming_idle_timeout = streaming_idle_timeout

    # Result waiting.
    self.task_wait_time = task_wait_time
    self.result_poll_interval = result_poll_interval

    # _executing is set while execute() runs; the lock guards reading/setting the flag.
    self._executing = threading.Event()
    self._executing_lock = threading.Lock()

    # Stream-progress bookkeeping (the tracker reads streaming_idle_timeout set above).
    self._stream_progress_lock = threading.Lock()
    self._stream_progress_tracker = self._make_stream_progress_tracker()

    # One-shot log guards and the lock that protects them.
    self._explicit_peer_read_timeout_warned = False
    self._task_send_startup_budget_info_logged = False
    self._peer_read_timeout_once_lock = threading.Lock()
[docs]
def handle_event(self, event_type: str, fl_ctx: FLContext):
    """React to run-lifecycle events by managing the pipe and its handler.

    * START_RUN: resolve the pipe component and open it on the configured channel.
    * BEFORE_TASK_EXECUTION: replace the pipe handler with a fresh one, unless
      execute() is currently running or the pipe is not usable.
    * ABOUT_TO_END_RUN: abort tracked transfers, stop the handler and close the pipe.

    Args:
        event_type: the fired event type.
        fl_ctx: FLContext of the event.
    """
    if event_type == EventType.START_RUN:
        self.pipe = fl_ctx.get_engine().get_component(self.pipe_id)
        if not isinstance(self.pipe, Pipe):
            self.system_panic(f"component of {self.pipe_id} must be Pipe but got {type(self.pipe)}", fl_ctx)
            return
        self.pipe.open(self.pipe_channel_name)
        return

    if event_type == EventType.BEFORE_TASK_EXECUTION:
        # Only the flag read happens under the lock; the reset itself runs unlocked.
        with self._executing_lock:
            skip = self._executing.is_set()
        if skip:
            self.log_debug(fl_ctx, "skipping pipe handler reset: execute() is in progress")
            return

        # Ensure pipe is initialized and is a Pipe before operating on it.
        if not isinstance(self.pipe, Pipe):
            self.log_debug(fl_ctx, "pipe not initialized or not a Pipe; skipping pipe handler reset")
            return

        if self.pipe_handler:
            # Keep the pipe open: the replacement handler reuses it.
            self.pipe_handler.stop(close_pipe=False)
        self.pipe.clear()
        self._create_pipe_handler()
        self.pipe_handler.start()
        return

    if event_type == EventType.ABOUT_TO_END_RUN:
        self.log_debug(fl_ctx, "Stopping pipe handler")
        self._mark_all_stream_progress_terminal(TransferProgressState.ABORTED)
        if self.pipe_handler:
            self.pipe_handler.notify_end("end_of_job")
            self.pipe_handler.stop(close_pipe=False)
        if self.pipe:
            self.pipe.close()
def _create_pipe_handler(self):
    """Build a PipeHandler for self.pipe, install it as self.pipe_handler and return it.

    Both callbacks are bound to the handler they were created for and first check
    that this handler is still the current one, so a late PEER_GONE (or any late
    message) from a previous handler cannot affect the current one. The status
    callback stops the handler with close_pipe=False because CellPipe.close() is
    irreversible.

    Returns:
        The newly created PipeHandler.

    Raises:
        ValueError: if heartbeat_timeout is None.
    """
    if self.heartbeat_timeout is None:
        raise ValueError(
            "heartbeat_timeout is None. Set heartbeat_timeout to 0 to disable heartbeat checking, "
            "or to a non-negative timeout value."
        )

    handler = PipeHandler(
        pipe=self.pipe,
        read_interval=self.read_interval,
        heartbeat_interval=self.heartbeat_interval,
        heartbeat_timeout=self.heartbeat_timeout,
        resend_interval=self.resend_interval,
        max_resends=self.max_resends,
    )

    def _on_status(msg, _owner=handler):
        if self.pipe_handler is not _owner:
            self.logger.debug(f"Ignoring late {msg.topic} from a previous pipe handler")
            return
        self.logger.info(f"pipe status changed to {msg.topic}: {msg.data}")
        # Any status change ends the handler, so in-flight transfers are aborted first.
        self._mark_all_stream_progress_terminal(TransferProgressState.ABORTED)
        _owner.stop(close_pipe=False)

    def _on_message(msg, _owner=handler):
        if self.pipe_handler is not _owner:
            self.logger.debug(f"Ignoring late {msg.topic} from a previous pipe handler")
            return
        if msg.topic != Topic.STREAM_PROGRESS:
            # Ordinary messages are queued on the handler for get_next().
            with _owner.lock:
                _owner.messages.append(msg)
            return
        # Progress messages are consumed here; a bad one must never break the reader.
        try:
            self._handle_stream_progress_message(msg)
        except Exception as ex:
            self.logger.warning(f"ignored stream progress after handler error: {secure_format_exception(ex)}")

    handler.set_status_cb(_on_status)
    handler.set_message_cb(_on_message)
    self.pipe_handler = handler
    return handler
def _make_stream_progress_tracker(self):
    """Create the TransferProgressTracker used for task-payload progress records.

    The tracker's idle timeout is streaming_idle_timeout, or the module default when
    that setting is None or 0.
    """
    return TransferProgressTracker(
        idle_timeout=self.streaming_idle_timeout or _DEFAULT_STREAMING_IDLE_TIMEOUT_SECS
    )
@staticmethod
def _get_progress_event_value(data, keys):
for key in keys:
if isinstance(data, dict):
if key in data:
return data.get(key)
elif hasattr(data, key):
return getattr(data, key)
return None
@staticmethod
def _normalize_progress_status(status) -> str:
if status is None:
return ""
return str(status).lower()
@staticmethod
def _normalize_tracker_state(status: str) -> str:
    """Translate a normalized status string into a tracker state.

    Statuses that have no entry in STREAM_PROGRESS_STATE_ALIASES (including the
    empty string) are treated as TransferProgressState.ACTIVE.
    """
    try:
        return STREAM_PROGRESS_STATE_ALIASES[status]
    except KeyError:
        return TransferProgressState.ACTIVE
def _handle_stream_progress_message(self, msg: Message):
    """Validate one STREAM_PROGRESS message and apply it to the progress tracker.

    The payload must be a dict carrying a direction equal to
    DIRECTION_TASK_PAYLOAD_DOWNLOAD plus job, task and transfer ids. Anything else
    is logged and dropped. Accepted payloads update (or create) the tracker record
    for that job/task/transfer/direction under the stream-progress lock.

    Args:
        msg: the pipe message whose ``data`` holds the progress payload.
    """
    data = msg.data
    if not isinstance(data, dict):
        self.logger.warning(f"ignored stream progress with invalid payload: {data}")
        return

    lookup = self._get_progress_event_value
    task_id = lookup(data, STREAM_PROGRESS_TASK_ID_KEYS)
    transfer_id = lookup(data, STREAM_PROGRESS_TRANSFER_ID_KEYS)
    job_id = lookup(data, STREAM_PROGRESS_JOB_ID_KEYS)
    direction = lookup(data, STREAM_PROGRESS_DIRECTION_KEYS)

    # Direction is checked first: other directions are not this component's concern.
    if direction is None:
        self.logger.warning(f"ignored stream progress without direction: {data}")
        return
    direction = str(direction)
    if direction != DIRECTION_TASK_PAYLOAD_DOWNLOAD:
        self.logger.debug(f"ignored stream progress for unsupported direction {direction}: {data}")
        return
    if any(value is None for value in (job_id, task_id, transfer_id)):
        self.logger.warning(f"ignored unscoped task_payload_download stream progress: {data}")
        return

    receiver_id = lookup(data, STREAM_PROGRESS_RECEIVER_ID_KEYS)
    transfer_id_kind = lookup(data, STREAM_PROGRESS_TRANSFER_ID_KIND_KEYS)
    status = self._normalize_progress_status(lookup(data, STREAM_PROGRESS_STATE_KEYS))
    bytes_done = lookup(data, STREAM_PROGRESS_BYTES_KEYS)
    items_done = lookup(data, STREAM_PROGRESS_ITEM_KEYS)
    sequence = lookup(data, STREAM_PROGRESS_SEQUENCE_KEYS)

    # Missing counters are allowed (0 / None); malformed counters reject the event.
    try:
        bytes_done_value = 0 if bytes_done is None else int(bytes_done)
    except (TypeError, ValueError):
        self.logger.warning(f"ignored stream progress with invalid bytes_done value: {data}")
        return
    try:
        items_done_value = None if items_done is None else int(items_done)
    except (TypeError, ValueError):
        self.logger.warning(f"ignored stream progress with invalid items_done value: {data}")
        return

    job_id = str(job_id)
    task_id = str(task_id)
    transfer_id = str(transfer_id)
    # execute() normalizes a missing FLContext job id to an empty string and stamps the same
    # value into the task header. Treat that empty string as a valid progress scope here.
    if not (task_id and transfer_id):
        self.logger.warning(f"ignored unscoped task_payload_download stream progress: {data}")
        return

    if transfer_id_kind is not None:
        transfer_id_kind = str(transfer_id_kind)
    if receiver_id is not None:
        receiver_id = str(receiver_id)
    state = self._normalize_tracker_state(status)

    # Forward task payload aggregation is task/transfer scoped. receiver_id is tolerated for schema
    # compatibility but intentionally not part of the forward-path tracker key.
    scope = {
        "job_id": job_id,
        "task_id": task_id,
        "transfer_id": transfer_id,
        "direction": direction,
    }
    failure_states = (TransferProgressState.FAILED, TransferProgressState.ABORTED)

    with self._stream_progress_lock:
        tracker = self._stream_progress_tracker
        record = tracker.get_record(**scope)

        # A start event on a failed/aborted transfer discards the old record so the
        # transfer is tracked afresh.
        if record and record.state in failure_states and status in STREAM_PROGRESS_START_STATUSES:
            tracker.remove(**scope)
            record = None

        # Only brand-new records count against the capacity limit.
        if record is None:
            has_capacity, record_count = self._stream_progress_tracker_capacity_locked(direction)
            if not has_capacity:
                self.logger.warning(
                    f"ignored stream progress for task={task_id} transfer={transfer_id} direction={direction}: "
                    f"progress tracker is at capacity records={record_count} max={STREAM_PROGRESS_MAX_TRACKED_RECORDS}"
                )
                return

        def _next_sequence():
            # Implicit sequence: one past the stored record, or 0 for a new record.
            return record.sequence + 1 if record else 0

        if sequence is None:
            sequence_value = _next_sequence()
        else:
            try:
                sequence_value = int(sequence)
            except (TypeError, ValueError):
                sequence_value = _next_sequence()

        update = tracker.update(
            sequence=sequence_value,
            bytes_done=bytes_done_value,
            items_done=items_done_value,
            state=state,
            transfer_id_kind=transfer_id_kind,
            **scope,
        )
        if update.accepted and update.record and update.record.terminal:
            self._prune_terminal_stream_progress_records_locked()

    if not update.accepted:
        self.logger.debug(
            f"ignored stream progress for task={task_id} transfer={transfer_id} direction={direction}: "
            f"{update.reason}"
        )
        return

    if status in STREAM_PROGRESS_START_STATUSES:
        event_kind = "start"
    elif state == TransferProgressState.COMPLETED:
        event_kind = "completion"
    elif state in failure_states:
        event_kind = "failure"
    else:
        event_kind = "active"
    self.logger.info(
        f"accepted stream progress {event_kind} for task={task_id} transfer={transfer_id} direction={direction} "
        f"receiver_id={receiver_id} state={state} sequence={sequence_value} bytes_done={bytes_done_value} "
        f"items_done={items_done_value} progressed={update.progressed}"
    )
def _get_active_task_payload_records(self, task_id: str, job_id: Optional[str] = None):
    """Snapshot the task-payload-download progress records of one task.

    Args:
        task_id: id of the task whose records are wanted (converted with str()).
        job_id: job id of the task; None is looked up as the empty string.

    Returns:
        A tuple ``(all_snapshots, non_terminal_snapshots)`` of
        _StreamProgressRecordSnapshot lists. The second list is the subset of the
        first whose ``terminal`` property is False.
    """
    normalized_job_id = "" if job_id is None else str(job_id)
    snapshots = []
    with self._stream_progress_lock:
        matching = self._stream_progress_tracker.records(
            job_id=normalized_job_id,
            task_id=str(task_id),
            direction=DIRECTION_TASK_PAYLOAD_DOWNLOAD,
        )
        # Copy field by field so callers never hold references to live tracker records.
        for record in matching:
            snapshots.append(
                _StreamProgressRecordSnapshot(
                    job_id=record.job_id,
                    task_id=record.task_id,
                    transfer_id=record.transfer_id,
                    direction=record.direction,
                    sequence=record.sequence,
                    bytes_done=record.bytes_done,
                    items_done=record.items_done,
                    started_time=record.started_time,
                    last_progress_time=record.last_progress_time,
                    state=record.state,
                    transfer_id_kind=record.transfer_id_kind,
                )
            )
    still_active = [snapshot for snapshot in snapshots if not snapshot.terminal]
    return snapshots, still_active
def _prune_terminal_stream_progress_records_locked(self):
    """Ask the tracker to prune records older than STREAM_PROGRESS_TERMINAL_RECORD_TTL.

    The ``_locked`` suffix signals that callers invoke this while holding the
    stream-progress lock; this method does not acquire it.
    """
    cutoff = time.time() - STREAM_PROGRESS_TERMINAL_RECORD_TTL
    self._stream_progress_tracker.prune(before_time=cutoff)
def _stream_progress_tracker_capacity_locked(self, direction: str) -> tuple[bool, int]:
    """Check whether a new record for ``direction`` fits, pruning if needed.

    Pruning escalates only while the tracker is full: first records older than
    STREAM_PROGRESS_TERMINAL_RECORD_TTL, then (with ``include_active=True``)
    records older than the streaming idle timeout.

    Args:
        direction: the transfer direction whose records are counted and pruned.

    Returns:
        ``(has_capacity, record_count)``. When the limit is disabled (<= 0) the
        result is ``(True, 0)`` and nothing is counted.
    """
    limit = STREAM_PROGRESS_MAX_TRACKED_RECORDS
    if limit <= 0:
        return True, 0

    tracker = self._stream_progress_tracker
    record_count = len(tracker.records(direction=direction))

    if record_count >= limit:
        record_count -= tracker.prune(
            before_time=time.time() - STREAM_PROGRESS_TERMINAL_RECORD_TTL,
            direction=direction,
        )

    if record_count >= limit:
        stale_after = self.streaming_idle_timeout or _DEFAULT_STREAMING_IDLE_TIMEOUT_SECS
        record_count -= tracker.prune(
            before_time=time.time() - stale_after,
            include_active=True,
            direction=direction,
        )

    return record_count < limit, record_count
def _recent_completed_records_hold_wait(
self,
records,
now: float,
fl_ctx: FLContext,
task_name: str,
completed_ack_budget: Optional[float],
) -> bool:
if not records:
return False
completed_records = [record for record in records if record.state == TransferProgressState.COMPLETED]
if len(completed_records) != len(records):
return False
latest_record = max(completed_records, key=lambda record: record.last_progress_time)
elapsed = now - latest_record.last_progress_time
if completed_ack_budget is not None and elapsed >= completed_ack_budget:
return False
completed_ack_budget_text = "unbounded" if completed_ack_budget is None else f"{completed_ack_budget}s"
self.log_info(
fl_ctx,
f"peer has not ACKed task '{task_name}' yet, but stream transfer "
f"'{latest_record.transfer_id}' completed {elapsed:.2f} secs ago; continuing to wait "
f"until task_send_completed_ack_budget={completed_ack_budget_text}",
)
return True
@staticmethod
def _get_task_send_startup_budget(
streaming_idle_timeout: float,
peer_read_timeout: Optional[float] = None,
) -> Optional[float]:
if peer_read_timeout is None:
return None
peer_read_budget = peer_read_timeout
return min(streaming_idle_timeout, max(peer_read_budget, STREAM_PROGRESS_COMPLETION_ACK_GRACE))
@staticmethod
def _get_task_send_completed_ack_budget(
streaming_idle_timeout: float,
peer_read_timeout: Optional[float] = None,
) -> Optional[float]:
if peer_read_timeout is None:
return None
peer_read_budget = peer_read_timeout
return min(streaming_idle_timeout, max(peer_read_budget, STREAM_PROGRESS_COMPLETION_ACK_GRACE))
@staticmethod
def _is_explicit_peer_read_timeout_fast_fail(
peer_read_timeout_explicit: bool,
peer_read_timeout: Optional[float],
streaming_idle_timeout: Optional[float],
) -> bool:
return (
peer_read_timeout_explicit
and peer_read_timeout is not None
and streaming_idle_timeout is not None
and peer_read_timeout < streaming_idle_timeout
)
def _get_streaming_timeout_snapshot(self):
    """Read the three timeout settings together under the stream-progress lock.

    Returns:
        A _StreamingTimeoutSnapshot holding peer_read_timeout_explicit,
        peer_read_timeout and streaming_idle_timeout.
    """
    with self._stream_progress_lock:
        explicit = self.peer_read_timeout_explicit
        peer_read = self.peer_read_timeout
        idle = self.streaming_idle_timeout
        return _StreamingTimeoutSnapshot(
            peer_read_timeout_explicit=explicit,
            peer_read_timeout=peer_read,
            streaming_idle_timeout=idle,
        )
def _should_continue_task_send_waiting(
self,
task_name: str,
task_id: str,
job_id: Optional[str],
send_start_time: float,
fl_ctx: FLContext,
) -> bool:
timeout_snapshot = self._get_streaming_timeout_snapshot()
peer_read_timeout = timeout_snapshot.peer_read_timeout
streaming_idle_timeout = timeout_snapshot.streaming_idle_timeout
if not streaming_idle_timeout:
return False
now = time.time()
records, active_records = self._get_active_task_payload_records(task_id, job_id)
records = [
record
for record in records
if not (
record.state in (TransferProgressState.FAILED, TransferProgressState.ABORTED)
and record.last_progress_time < send_start_time
)
]
if not records:
elapsed = now - send_start_time
wait_budget = self._get_task_send_startup_budget(streaming_idle_timeout, peer_read_timeout)
if wait_budget is not None and elapsed >= wait_budget:
return False
wait_budget_text = "unbounded" if wait_budget is None else f"{wait_budget}s"
self.log_info(
fl_ctx,
f"peer has not read task '{task_name}' after {elapsed} secs and no stream progress record "
f"exists yet; continuing to wait until task_send_wait_budget={wait_budget_text}",
)
return True
if not active_records:
completed_ack_budget = self._get_task_send_completed_ack_budget(streaming_idle_timeout, peer_read_timeout)
if self._recent_completed_records_hold_wait(records, now, fl_ctx, task_name, completed_ack_budget):
return True
return False
terminal_failure_records = [
record
for record in records
if record.state in (TransferProgressState.FAILED, TransferProgressState.ABORTED)
]
if terminal_failure_records:
return False
elapsed = now - send_start_time
recent_records = [
record for record in active_records if now - record.last_progress_time < streaming_idle_timeout
]
if not recent_records:
return False
record = max(recent_records, key=lambda item: item.last_progress_time)
self.log_info(
fl_ctx,
f"peer has not read task '{task_name}' after {elapsed:.2f} secs, "
f"but stream transfer '{record.transfer_id}' has recent activity "
f"(state={record.state}, bytes_done={record.bytes_done}, items_done={record.items_done}); "
"continuing to wait",
)
return True
def _mark_task_stream_progress_terminal(self, task_id: str, state: str, job_id: Optional[str] = None):
    """Move every non-terminal task-payload record of one task into ``state``.

    Does nothing when ``task_id`` is falsy. After marking, old terminal records
    are pruned. All tracker access happens under the stream-progress lock.

    Args:
        task_id: id of the task whose records are finalized.
        state: the terminal state to record.
        job_id: job id of the task; None is looked up as the empty string.
    """
    if not task_id:
        return

    with self._stream_progress_lock:
        tracker = self._stream_progress_tracker
        # Materialize first: marking may change what the tracker would return.
        task_records = tuple(
            tracker.records(
                job_id="" if job_id is None else str(job_id),
                task_id=str(task_id),
                direction=DIRECTION_TASK_PAYLOAD_DOWNLOAD,
            )
        )
        for record in task_records:
            if record.terminal:
                continue
            tracker.mark_terminal(
                job_id=record.job_id,
                task_id=record.task_id,
                transfer_id=record.transfer_id,
                direction=record.direction,
                state=state,
            )
        self._prune_terminal_stream_progress_records_locked()
def _mark_all_stream_progress_terminal(self, state: str):
    """Move every non-terminal task-payload record, of any task, into ``state``.

    After marking, old terminal records are pruned. All tracker access happens
    under the stream-progress lock.

    Args:
        state: the terminal state to record.
    """
    with self._stream_progress_lock:
        tracker = self._stream_progress_tracker
        # Materialize first: marking may change what the tracker would return.
        all_records = tuple(tracker.records(direction=DIRECTION_TASK_PAYLOAD_DOWNLOAD))
        for record in all_records:
            if record.terminal:
                continue
            tracker.mark_terminal(
                job_id=record.job_id,
                task_id=record.task_id,
                transfer_id=record.transfer_id,
                direction=record.direction,
                state=state,
            )
        self._prune_terminal_stream_progress_records_locked()
def _should_honor_explicit_peer_read_timeout(self):
    """Return whether the current settings call for explicit-timeout fast-fail.

    Takes one snapshot of the timeout settings and evaluates
    _is_explicit_peer_read_timeout_fast_fail on it.
    """
    snap = self._get_streaming_timeout_snapshot()
    explicit = snap.peer_read_timeout_explicit
    peer_read = snap.peer_read_timeout
    idle = snap.streaming_idle_timeout
    return self._is_explicit_peer_read_timeout_fast_fail(explicit, peer_read, idle)
def _log_explicit_peer_read_timeout_warning_once(self, fl_ctx: FLContext):
timeout_snapshot = self._get_streaming_timeout_snapshot()
with self._peer_read_timeout_once_lock:
if self._explicit_peer_read_timeout_warned:
return
self._explicit_peer_read_timeout_warned = True
self.log_warning(
fl_ctx,
f"explicit peer_read_timeout ({timeout_snapshot.peer_read_timeout}s) is lower than "
f"streaming_idle_timeout ({timeout_snapshot.streaming_idle_timeout}s); honoring fast-fail behavior "
"instead of extending the wait on stream progress",
)
def _should_log_clamped_task_send_startup_budget(self):
    """Tell whether the "startup budget clamped" notice applies to the current settings.

    It applies when peer_read_timeout was set explicitly, both timeouts are
    configured, peer_read_timeout is strictly higher than streaming_idle_timeout,
    and the pipe is either not resolved yet or a CellPipe.
    """
    snap = self._get_streaming_timeout_snapshot()
    peer_read = snap.peer_read_timeout
    idle = snap.streaming_idle_timeout
    both_configured = peer_read is not None and idle is not None
    pipe_qualifies = self.pipe is None or isinstance(self.pipe, CellPipe)
    # The comparison only runs when both values are set, thanks to short-circuiting.
    return snap.peer_read_timeout_explicit and both_configured and peer_read > idle and pipe_qualifies
def _log_clamped_task_send_startup_budget_once(self, fl_ctx: FLContext):
if not self._should_log_clamped_task_send_startup_budget():
return
timeout_snapshot = self._get_streaming_timeout_snapshot()
with self._peer_read_timeout_once_lock:
if self._task_send_startup_budget_info_logged:
return
self._task_send_startup_budget_info_logged = True
self.log_info(
fl_ctx,
f"explicit peer_read_timeout ({timeout_snapshot.peer_read_timeout}s) is higher than streaming_idle_timeout "
f"({timeout_snapshot.streaming_idle_timeout}s); using streaming_idle_timeout as the no-progress task-send "
"startup budget",
)
def _get_task_send_peer_read_timeout(self):
    """Return the per-send wait timeout used by PipeHandler.

    CellPipe can extend bounded waits through progress callbacks. Other pipe types do not honor
    those callbacks, so an explicitly disabled peer timeout stays `None` and defers to
    PipeHandler.default_request_timeout.

    Returns:
        * peer_read_timeout is None: a short slice
          (min of streaming_idle_timeout and STREAM_PROGRESS_COMPLETION_ACK_GRACE) when an
          idle timeout is set and the pipe is unresolved or a CellPipe; otherwise None.
        * peer_read_timeout unchanged when there is no idle timeout, when explicit
          fast-fail applies, or when the pipe is a non-CellPipe.
        * otherwise the minimum of peer_read_timeout, streaming_idle_timeout and
          STREAM_PROGRESS_COMPLETION_ACK_GRACE.
    """
    snap = self._get_streaming_timeout_snapshot()
    peer_read_timeout = snap.peer_read_timeout
    streaming_idle_timeout = snap.streaming_idle_timeout
    pipe_extends_wait = self.pipe is None or isinstance(self.pipe, CellPipe)

    if peer_read_timeout is None:
        if streaming_idle_timeout and pipe_extends_wait:
            return min(streaming_idle_timeout, STREAM_PROGRESS_COMPLETION_ACK_GRACE)
        return None

    if not streaming_idle_timeout:
        return peer_read_timeout
    if self._is_explicit_peer_read_timeout_fast_fail(
        snap.peer_read_timeout_explicit, peer_read_timeout, streaming_idle_timeout
    ):
        return peer_read_timeout
    if not pipe_extends_wait:
        return peer_read_timeout

    return min(peer_read_timeout, streaming_idle_timeout, STREAM_PROGRESS_COMPLETION_ACK_GRACE)
def _unread_task_send_is_failure(self):
    """Tell whether a task send the peer never read must be treated as a failure.

    Returns:
        True when peer_read_timeout is set, or when the pipe is resolved and is
        not a CellPipe. Otherwise (pipe unresolved or CellPipe, no peer read
        timeout) True exactly when a streaming idle timeout is configured.
    """
    snap = self._get_streaming_timeout_snapshot()
    if snap.peer_read_timeout is not None:
        return True

    pipe_is_cell_or_unset = self.pipe is None or isinstance(self.pipe, CellPipe)
    if not pipe_is_cell_or_unset:
        return True

    # The pipe condition already holds here, so only the idle timeout decides.
    return bool(snap.streaming_idle_timeout)
def _send_task_to_peer(self, req: Message, fl_ctx: FLContext, abort_signal: Signal) -> bool:
    """Send the task request through the pipe handler with a progress-aware wait.

    A callback is attached to the request as ``_progress_wait_cb``. It answers
    "keep waiting?" by first honoring explicit fast-fail and otherwise consulting
    the stream-progress records of this task.

    Args:
        req: the task request message; its data may expose ``get_header``.
        fl_ctx: FLContext used for logging.
        abort_signal: signal forwarded to the pipe handler.

    Returns:
        Whatever ``pipe_handler.send_to_peer`` returns (used by the caller as
        "the peer has read the message").
    """
    # The job id comes from the request header when the payload offers one.
    get_header = getattr(req.data, "get_header", None)
    header_job_id = get_header(FLMetaKey.JOB_ID) if callable(get_header) else None
    job_id = "" if header_job_id is None else str(header_job_id)
    send_start_time = time.time()

    def _progress_wait_cb():
        if self._should_honor_explicit_peer_read_timeout():
            self._log_explicit_peer_read_timeout_warning_once(fl_ctx)
            return False
        return self._should_continue_task_send_waiting(
            task_name=req.topic,
            task_id=req.msg_id,
            job_id=job_id,
            send_start_time=send_start_time,
            fl_ctx=fl_ctx,
        )

    req._progress_wait_cb = _progress_wait_cb
    self._log_clamped_task_send_startup_budget_once(fl_ctx)
    send_timeout = self._get_task_send_peer_read_timeout()
    return self.pipe_handler.send_to_peer(req, timeout=send_timeout, abort_signal=abort_signal)
[docs]
def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
"""
The TaskExchanger always sends the Shareable to the peer, and expects to receive a Shareable object from the
peer. The peer can convert the Shareable object to whatever format that is best for its applications (e.g.
DXO or FLModel object). Similarly, when submitting result, the peer must convert its result object to a
Shareable object before sending it back to the TaskExchanger.
This "late-binding" (binding of the Shareable object to an application-friendly object) strategy makes the
TaskExchanger generic and can be reused for any applications (e.g. Shareable based, DXO based, or any custom
data based).
"""
with self._executing_lock:
acquired = not self._executing.is_set()
if acquired:
self._executing.set()
try:
return self._do_execute(task_name, shareable, fl_ctx, abort_signal)
finally:
if acquired:
self._executing.clear()
def _do_execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
    """Run one task exchange: validate, send to the peer, then poll for the reply.

    Args:
        task_name: name of the task; also used as the request topic.
        shareable: the task data; job id and site name headers are stamped onto it.
        fl_ctx: FLContext of the task.
        abort_signal: checked on every poll; also triggered by this method when the
            pipe stops or the task times out.

    Returns:
        The peer's result Shareable on success, otherwise a reply made with
        BAD_TASK_DATA, EXECUTION_EXCEPTION or TASK_ABORTED.
    """
    if not self.check_input_shareable(task_name, shareable, fl_ctx):
        self.log_error(fl_ctx, "bad input task shareable")
        return make_reply(ReturnCode.BAD_TASK_DATA)

    job_id = fl_ctx.get_job_id()
    job_id = "" if job_id is None else str(job_id)
    shareable.set_header(FLMetaKey.JOB_ID, job_id)
    shareable.set_header(FLMetaKey.SITE_NAME, fl_ctx.get_identity_name())
    task_id = shareable.get_header(key=FLContextKey.TASK_ID)

    def _close_stream_progress(final_state):
        # Every exit path finalizes this task's progress records with some state.
        self._mark_task_stream_progress_terminal(task_id, final_state, job_id=job_id)

    # send to peer
    self.log_info(fl_ctx, f"sending task to peer {self.peer_read_timeout=}")
    req = Message.new_request(topic=task_name, data=shareable, msg_id=task_id)
    start_time = time.time()
    has_been_read = self._send_task_to_peer(req, fl_ctx, abort_signal)
    if has_been_read:
        self.log_info(fl_ctx, f"task {task_name} sent to peer in {time.time() - start_time} secs")
    elif self._unread_task_send_is_failure():
        _close_stream_progress(TransferProgressState.ABORTED)
        self.log_error(
            fl_ctx,
            f"peer does not accept task '{task_name}' in {time.time() - start_time} secs - aborting task!",
        )
        return make_reply(ReturnCode.EXECUTION_EXCEPTION)
    else:
        self.log_info(
            fl_ctx,
            f"peer did not confirm reading task '{task_name}' in {time.time() - start_time} secs; "
            "continuing because peer_read_timeout is disabled",
        )

    # wait for result
    self.log_debug(fl_ctx, "Waiting for result from peer")
    start = time.time()
    while True:
        if abort_signal.triggered:
            # notify peer that the task is aborted
            self.log_debug(fl_ctx, f"task '{task_name}' is aborted.")
            _close_stream_progress(TransferProgressState.ABORTED)
            self.pipe_handler.notify_abort(task_id)
            self.pipe_handler.stop()
            return make_reply(ReturnCode.TASK_ABORTED)

        if self.pipe_handler.asked_to_stop:
            self.log_info(fl_ctx, "task pipe stopped! aborting task.")
            _close_stream_progress(TransferProgressState.ABORTED)
            self.pipe_handler.notify_abort(task_id)
            abort_signal.trigger("task pipe stopped!")
            return make_reply(ReturnCode.TASK_ABORTED)

        reply: Optional[Message] = self.pipe_handler.get_next()
        if reply is None:
            if self.task_wait_time and time.time() - start > self.task_wait_time:
                # timed out
                self.log_error(fl_ctx, f"task '{task_name}' timeout after {self.task_wait_time} secs")
                # also tell peer to abort the task
                _close_stream_progress(TransferProgressState.ABORTED)
                self.pipe_handler.notify_abort(task_id)
                abort_signal.trigger(f"task '{task_name}' timeout after {self.task_wait_time} secs")
                return make_reply(ReturnCode.EXECUTION_EXCEPTION)
        elif reply.msg_type != Message.REPLY:
            self.log_warning(
                fl_ctx, f"ignored reply: '{reply}' (wrong message type) while waiting for the result of {task_name}"
            )
        elif req.topic != reply.topic:
            # ignore wrong topic
            self.log_warning(
                fl_ctx,
                f"ignored reply: '{reply}' (reply topic does not match req: '{req}') while waiting for the result of {task_name}",
            )
        elif req.msg_id != reply.req_id:
            self.log_warning(
                fl_ctx,
                f"ignored reply: '{reply}' (reply req_id does not match req msg_id: '{req}') while waiting for the result of {task_name}",
            )
        else:
            self.log_info(fl_ctx, f"got result '{reply}' for task '{task_name}'")
            try:
                result = reply.data
                if not isinstance(result, Shareable):
                    _close_stream_progress(TransferProgressState.FAILED)
                    self.log_error(fl_ctx, f"bad task result from peer: expect Shareable but got {type(result)}")
                    return make_reply(ReturnCode.EXECUTION_EXCEPTION)

                # Carry the round number over from the request when it has one.
                current_round = shareable.get_header(AppConstants.CURRENT_ROUND)
                if current_round is not None:
                    result.set_header(AppConstants.CURRENT_ROUND, current_round)

                if not self.check_output_shareable(task_name, result, fl_ctx):
                    _close_stream_progress(TransferProgressState.FAILED)
                    self.log_error(fl_ctx, "bad task result from peer")
                    return make_reply(ReturnCode.EXECUTION_EXCEPTION)

                self.log_info(fl_ctx, f"received result of {task_name} from peer in {time.time() - start} secs")
                _close_stream_progress(TransferProgressState.COMPLETED)
                return result
            except Exception as ex:
                _close_stream_progress(TransferProgressState.FAILED)
                self.log_error(fl_ctx, f"Failed to convert result: {secure_format_exception(ex)}")
                return make_reply(ReturnCode.EXECUTION_EXCEPTION)

        # Nothing usable this round: pause before polling again.
        time.sleep(self.result_poll_interval)
[docs]
def check_output_shareable(self, task_name: str, shareable: Shareable, fl_ctx: FLContext) -> bool:
"""Checks output shareable after execute.
Returns:
True, if output shareable looks good; False, otherwise.
"""
return True
[docs]
def ask_peer_to_end(self, fl_ctx: FLContext) -> bool:
    """Send an END request to the peer and report whether it was read in time.

    Args:
        fl_ctx: FLContext used for logging.

    Returns:
        False (after logging a warning) when peer_read_timeout is set and the
        peer did not read the message; True otherwise.
    """
    end_request = Message.new_request(topic=Topic.END, data="END")
    has_been_read = self.pipe_handler.send_to_peer(end_request, timeout=self.peer_read_timeout)
    if has_been_read or not self.peer_read_timeout:
        return True

    self.log_warning(
        fl_ctx,
        f"3rd party does not read END msg in {self.peer_read_timeout} secs!",
    )
    return False
[docs]
def peer_is_up_or_dead(self) -> bool:
    """Report whether the pipe handler has flagged the peer as up or dead.

    Returns:
        The state of the handler's ``peer_is_up_or_dead`` event. False when no
        pipe handler has been created yet; this method used to raise
        AttributeError in that case, unlike pause_pipe_handler and
        resume_pipe_handler, which already tolerate a missing handler.
    """
    handler = self.pipe_handler
    if handler is None:
        return False
    return handler.peer_is_up_or_dead.is_set()
[docs]
def reset_peer_is_up_or_dead(self):
    """Clear the pipe handler's ``peer_is_up_or_dead`` event.

    Does nothing when no pipe handler has been created yet; this method used to
    raise AttributeError in that case, unlike pause_pipe_handler and
    resume_pipe_handler, which already tolerate a missing handler.
    """
    handler = self.pipe_handler
    if handler is None:
        return
    handler.peer_is_up_or_dead.clear()
[docs]
def pause_pipe_handler(self):
    """Stops pipe_handler heartbeat; a no-op when there is no pipe handler."""
    if not self.pipe_handler:
        return
    self.pipe_handler.pause()
[docs]
def resume_pipe_handler(self):
    """Resumes pipe_handler heartbeat; a no-op when there is no pipe handler."""
    if not self.pipe_handler:
        return
    self.pipe_handler.resume()
[docs]
def get_pipe(self):
    """Gets pipe.

    Returns:
        The current value of ``self.pipe``, which is None until a pipe component
        has been resolved.
    """
    pipe = self.pipe
    return pipe
[docs]
def get_pipe_channel_name(self):
    """Gets pipe_channel_name.

    Returns:
        The channel name this exchanger was constructed with.
    """
    channel_name = self.pipe_channel_name
    return channel_name