# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import threading
import time
from concurrent.futures import TimeoutError, as_completed
from typing import Callable, Optional
from nvflare.fuel.f3.cellnet.core_cell import CoreCell
from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey
from nvflare.fuel.f3.comm_config import CommConfigurator
from nvflare.fuel.f3.message import Message
from nvflare.fuel.f3.mpm import MainProcessMonitor
from nvflare.fuel.f3.stats_pool import StatsPoolManager
from nvflare.fuel.f3.streaming.stream_const import (
STREAM_ACK_TOPIC,
STREAM_CHANNEL,
STREAM_DATA_TOPIC,
StreamDataType,
StreamHeaderKey,
)
from nvflare.fuel.f3.streaming.stream_types import Stream, StreamError, StreamFuture, StreamTaskSpec
from nvflare.fuel.f3.streaming.stream_utils import (
ONE_MB,
CheckedExecutor,
gen_stream_id,
stream_stats_category,
stream_thread_pool,
wrap_view,
)
STREAM_CHUNK_SIZE = 1024 * 1024
STREAM_WINDOW_SIZE = 16 * STREAM_CHUNK_SIZE
STREAM_ACK_WAIT = 300
STREAM_RETRY_WAIT = 5.0
STREAM_RETRY_TIMEOUT = 60.0
STREAM_RETRY_WORKERS = 32
STREAM_RETRY_RESULT_TIMEOUT = 1.0
STREAM_TYPE_BYTE = "byte"
STREAM_TYPE_BLOB = "blob"
STREAM_TYPE_FILE = "file"
COUNTER_NAME_SENT = "sent"
log = logging.getLogger(__name__)
def _payload_size(payload) -> int:
if payload is None:
return 0
if isinstance(payload, list):
return sum(len(item) for item in payload)
return len(payload)
def _snapshot_payload(payload):
if payload is None:
return None
if isinstance(payload, list):
return [bytes(item) for item in payload]
return bytes(payload)
[docs]
class ReliableRetryScheduler:
def __init__(self):
self.tasks = {}
self.cv = threading.Condition()
self.thread = None
self.stopped = False
self.generation = 0
self.retry_task_pool = CheckedExecutor(STREAM_RETRY_WORKERS, "stm_retry")
# task -> dispatch timestamp, used to detect retry dispatches stuck in transport sends
self.inflight_tasks = {}
self.stalled_tasks = set()
[docs]
def register(self, task):
with self.cv:
if self.stopped:
return
self.tasks[task.sid] = task
self.generation += 1
if not self.thread or not self.thread.is_alive():
self.thread = threading.Thread(target=self._run, name="stm_retry", daemon=True)
self.thread.start()
self.cv.notify()
[docs]
def unregister(self, task):
with self.cv:
registered = self.tasks.get(task.sid)
if registered is task:
self.tasks.pop(task.sid, None)
self.inflight_tasks.pop(task, None)
self.stalled_tasks.discard(task)
self.generation += 1
self.cv.notify()
[docs]
def wakeup(self):
with self.cv:
self.generation += 1
self.cv.notify()
[docs]
def shutdown(self):
with self.cv:
self.stopped = True
self.generation += 1
self.cv.notify()
thread = self.thread
if thread and thread.is_alive() and thread is not threading.current_thread():
thread.join(timeout=1.0)
self.retry_task_pool.shutdown(wait=False)
def _finish_inflight(self, task):
with self.cv:
self.inflight_tasks.pop(task, None)
self.stalled_tasks.discard(task)
self.cv.notify()
def _run(self):
while True:
with self.cv:
if self.stopped:
return
now = time.monotonic()
tasks = [task for task in self.tasks.values() if task not in self.inflight_tasks]
for task in tasks:
self.inflight_tasks[task] = now
stalled = [
(task, now - start)
for task, start in self.inflight_tasks.items()
if task not in self.stalled_tasks and now - start > task.retry_timeout
]
self.stalled_tasks.update(task for task, _elapsed in stalled)
generation = self.generation
for task, elapsed in stalled:
log.error(f"{task} retry dispatch has not returned after {elapsed:.1f} seconds, retries are stalled")
next_wait = None
futures = {}
completed_futures = set()
for task in tasks:
future = self.retry_task_pool.submit(task.retry_task)
if future is None:
self._finish_inflight(task)
continue
futures[future] = task
try:
for future in as_completed(futures, timeout=STREAM_RETRY_RESULT_TIMEOUT):
completed_futures.add(future)
task = futures[future]
self._finish_inflight(task)
wait_time = future.result()
if wait_time is not None:
next_wait = wait_time if next_wait is None else min(next_wait, wait_time)
except TimeoutError:
next_wait = (
STREAM_RETRY_RESULT_TIMEOUT if next_wait is None else min(next_wait, STREAM_RETRY_RESULT_TIMEOUT)
)
for future, task in futures.items():
if future not in completed_futures:
future.add_done_callback(lambda _future, retry_task=task: self._finish_inflight(retry_task))
with self.cv:
if self.stopped:
return
if self.generation == generation:
self.cv.wait(timeout=next_wait)
reliable_retry_scheduler = ReliableRetryScheduler()
MainProcessMonitor.add_cleanup_cb(reliable_retry_scheduler.shutdown)
[docs]
class TxTask(StreamTaskSpec):
def __init__(
self,
cell: CoreCell,
chunk_size: int,
channel: str,
topic: str,
target: str,
headers: dict,
stream: Stream,
reliable: Optional[bool],
secure: bool,
optional: bool,
):
self.cell = cell
self.chunk_size = chunk_size
self.sid = gen_stream_id()
self.buffer = wrap_view(bytearray(chunk_size))
# Optimization to send the original buffer without copying
self.direct_buf: Optional[bytes] = None
self.buffer_size = 0
self.channel = channel
self.topic = topic
self.target = target
self.headers = headers
self.stream = stream
self.stream_future = None
self.task_future = None
self.ack_waiter = threading.Event()
self.seq = 0
self.seq_ack = -1
self.offset = 0
self.offset_ack = 0
self.secure = secure
self.optional = optional
self.stopped = False
self.stopping = False
self.send_lock = threading.RLock()
self.stream_future = StreamFuture(self.sid, task_handle=self)
self.stream_future.set_size(stream.get_size())
config = CommConfigurator()
self.reliable = config.get_streaming_reliable(False) if reliable is None else reliable
self.window_size = config.get_streaming_window_size(STREAM_WINDOW_SIZE)
self.ack_wait = config.get_streaming_ack_wait(STREAM_ACK_WAIT)
self.ack_progress_timeout = config.get_streaming_ack_progress_timeout(60.0)
# Guard against zero/negative config to avoid wait(0) busy-spin loops.
self.ack_progress_check_interval = max(0.01, config.get_streaming_ack_progress_check_interval(5.0))
self.last_ack_progress_ts = time.monotonic()
self.retry_wait = max(0.01, config.get_streaming_retry_wait(STREAM_RETRY_WAIT))
self.retry_timeout = max(0.01, config.get_streaming_retry_timeout(STREAM_RETRY_TIMEOUT))
self.retry_max_pending_bytes = config.get_streaming_retry_max_pending_bytes(2 * self.window_size)
if self.reliable:
self.pending_messages = {}
self.pending_message_bytes = 0
self.retry_lock = threading.RLock()
reliable_retry_scheduler.register(self)
else:
self.pending_messages = None
self.pending_message_bytes = 0
self.retry_lock = None
def __str__(self):
return f"Tx[SID:{self.sid} to {self.target} for {self.channel}/{self.topic}]"
[docs]
def send_loop(self):
"""Read/send loop to transmit the whole stream with flow control"""
while not self.stopped:
buf = self.stream.read(self.chunk_size)
if not buf:
# End of Stream
if not self.send_pending_buffer(final=True):
return
self.stop()
return
# Flow control
window = self.offset - self.offset_ack
# It may take several ACKs to clear up the window
while window > self.window_size:
log.debug(f"{self} window size {window} exceeds limit: {self.window_size}")
wait_start = time.monotonic()
while window > self.window_size:
if self.stopped:
return
now = time.monotonic()
if now - self.last_ack_progress_ts >= self.ack_progress_timeout:
self.stop(StreamError(f"{self} ACK made no progress for {self.ack_progress_timeout} seconds"))
return
elapsed = now - wait_start
if elapsed >= self.ack_wait:
self.stop(StreamError(f"{self} ACK timeouts after {self.ack_wait} seconds"))
return
self.ack_waiter.clear()
wait_timeout = min(self.ack_progress_check_interval, self.ack_wait - elapsed)
self.ack_waiter.wait(timeout=wait_timeout)
window = self.offset - self.offset_ack
size = len(buf)
if size > self.chunk_size:
raise StreamError(f"{self} Stream returns invalid size: {size}")
# Don't push out chunk when it's equal, wait till next round to detect EOS
# For example, if the stream size is chunk size (1M), this avoids sending two chunks.
if size + self.buffer_size > self.chunk_size:
if not self.send_pending_buffer():
return
if size == self.chunk_size:
self.direct_buf = buf
else:
self.buffer[self.buffer_size : self.buffer_size + size] = buf
self.buffer_size += size
[docs]
def send_pending_buffer(self, final=False):
if self.buffer_size == 0:
payload = bytes(0)
elif self.buffer_size == self.chunk_size:
if self.direct_buf:
payload = self.direct_buf
else:
payload = self.buffer
else:
payload = self.buffer[0 : self.buffer_size]
if self.reliable:
payload = _snapshot_payload(payload)
message = Message(None, payload)
if self.headers:
message.add_headers(self.headers)
stream_headers = {
StreamHeaderKey.CHANNEL: self.channel,
StreamHeaderKey.TOPIC: self.topic,
StreamHeaderKey.SIZE: self.stream.get_size(),
StreamHeaderKey.STREAM_ID: self.sid,
StreamHeaderKey.DATA_TYPE: StreamDataType.FINAL if final else StreamDataType.CHUNK,
StreamHeaderKey.SEQUENCE: self.seq,
StreamHeaderKey.OFFSET: self.offset,
StreamHeaderKey.RELIABLE: self.reliable,
StreamHeaderKey.OPTIONAL: self.optional,
}
if self.reliable and self.seq == 0:
stream_headers[StreamHeaderKey.RETRY_WAIT] = self.retry_wait
stream_headers[StreamHeaderKey.RETRY_TIMEOUT] = self.retry_timeout
message.add_headers(stream_headers)
if self.reliable:
errors = None
over_limit_error = None
with self.send_lock:
curr_time = time.monotonic()
with self.retry_lock:
if self.stopped:
return False
pending_message_size = _payload_size(message.payload)
self.pending_messages[self.seq] = None, curr_time, message
self.pending_message_bytes += pending_message_size
if self.retry_max_pending_bytes > 0 and self.pending_message_bytes > self.retry_max_pending_bytes:
self.pending_messages.pop(self.seq, None)
self.pending_message_bytes -= pending_message_size
msg = (
f"{self} has too many retry messages "
f"({self.pending_message_bytes + pending_message_size} > {self.retry_max_pending_bytes})"
)
over_limit_error = StreamError(msg)
if not over_limit_error:
reliable_retry_scheduler.wakeup()
errors = self.cell.fire_and_forget(
STREAM_CHANNEL,
STREAM_DATA_TOPIC,
self.target,
message,
secure=self.secure,
optional=self.optional,
)
if over_limit_error:
log.error(str(over_limit_error))
self.stop(over_limit_error)
return False
else:
errors = self.cell.fire_and_forget(
STREAM_CHANNEL, STREAM_DATA_TOPIC, self.target, message, secure=self.secure, optional=self.optional
)
errors = errors or {}
error = errors.get(self.target)
if error:
msg = f"{self} Message sending error to target {self.target}: {error}"
if self.reliable:
log.error(f"{msg}, will retry in {self.retry_wait} seconds")
else:
self.stop(StreamError(msg))
return False
# Update state
self.seq += 1
self.offset += self.buffer_size
self.buffer_size = 0
self.direct_buf = None
# Update future
self.stream_future.set_progress(self.offset)
return True
[docs]
def stop(self, error: Optional[StreamError] = None, notify=True):
if self.reliable:
if error:
with self.send_lock:
if not self._prepare_reliable_stop(error):
return
elif not self._prepare_reliable_stop(error):
return
reliable_retry_scheduler.unregister(self)
else:
if self.stopped:
return
self.stopped = True
self.remove_task()
if not self.ack_waiter.is_set():
self.ack_waiter.set()
if self.task_future:
self.task_future.cancel()
if not error:
# Result is the number of bytes streamed
if self.stream_future:
self.stream_future.set_result(self.offset)
return
# Error handling
log.debug(f"{self} Stream error: {error}")
if self.stream_future:
self.stream_future.set_exception(error)
if notify:
message = Message(None, None)
if self.headers:
message.add_headers(self.headers)
message.add_headers(
{
StreamHeaderKey.STREAM_ID: self.sid,
StreamHeaderKey.DATA_TYPE: StreamDataType.ERROR,
StreamHeaderKey.OFFSET: self.offset,
StreamHeaderKey.ERROR_MSG: str(error),
}
)
try:
self.cell.fire_and_forget(
STREAM_CHANNEL, STREAM_DATA_TOPIC, self.target, message, secure=self.secure, optional=True
)
except Exception as ex:
log.error(f"{self} failed to notify stream error to target {self.target}: {ex}")
def _prepare_reliable_stop(self, error: Optional[StreamError]) -> bool:
with self.retry_lock:
if self.stopped:
return False
if not error and self.pending_messages:
self.stopping = True
reliable_retry_scheduler.wakeup()
if not self.ack_waiter.is_set():
self.ack_waiter.set()
return False
self.stopped = True
self.stopping = False
if error:
self.pending_messages.clear()
self.pending_message_bytes = 0
return True
[docs]
def handle_ack(self, message: Message):
origin = message.get_header(MessageHeaderKey.ORIGIN)
ack_seq = message.get_header(StreamHeaderKey.SEQUENCE, None)
offset = message.get_header(StreamHeaderKey.OFFSET, None)
error = message.get_header(StreamHeaderKey.ERROR_MSG, None)
if error:
self.stop(StreamError(f"{self} Received error from {origin}: {error}"), notify=False)
return
if self.reliable and ack_seq is None:
self.stop(StreamError(f"{self} receiving end at {origin} doesn't support reliable streaming"), notify=True)
return
if self.reliable:
should_stop = False
ack_progressed = False
with self.retry_lock:
if offset is not None and offset > self.offset_ack:
self.offset_ack = offset
ack_progressed = True
if ack_seq is not None and ack_seq > self.seq_ack:
self.seq_ack = ack_seq
ack_progressed = True
if ack_progressed:
self.last_ack_progress_ts = time.monotonic()
if self.pending_messages and ack_seq is not None:
for seq in list(self.pending_messages):
if seq <= ack_seq:
_retry_start_time, _last_retry, message = self.pending_messages.pop(seq)
self.pending_message_bytes -= _payload_size(message.payload)
should_stop = self.stopping and not self.pending_messages
if should_stop:
self.stop()
elif offset is not None and offset > self.offset_ack:
self.offset_ack = offset
self.last_ack_progress_ts = time.monotonic()
if not self.ack_waiter.is_set():
self.ack_waiter.set()
[docs]
def start_task_thread(self, task_handler: Callable):
self.task_future = stream_thread_pool.submit(task_handler, self)
[docs]
def cancel(self):
self.stop(error=StreamError("cancelled"))
[docs]
def retry_task(self) -> Optional[float]:
try:
return self._retry_task()
except Exception as ex:
msg = f"{self} retry thread ended due to error: {ex}"
log.error(msg)
self.stop(StreamError(msg), notify=True)
return None
def _retry_task(self) -> Optional[float]:
should_stop = False
next_wait = None
messages_to_retry = []
retry_next_wait = None
retry_error = None
with self.retry_lock:
if self.stopped:
return None
if not self.pending_messages:
should_stop = self.stopping
else:
curr_time = time.monotonic()
for seq, value in list(self.pending_messages.items()):
retry_start_time, last_retry, message = value
wait_time = self.retry_wait - (curr_time - last_retry)
remaining_retry_timeout = self.retry_timeout
if retry_start_time is not None:
retry_time = curr_time - retry_start_time
if retry_time > self.retry_timeout:
msg = f"{self} seq {seq} retry failed after {retry_time:.2f} seconds from first retry"
log.error(msg)
retry_error = StreamError(msg)
break
remaining_retry_timeout = self.retry_timeout - retry_time
wait_time = min(wait_time, remaining_retry_timeout)
if wait_time <= 0:
retry_start_time = curr_time if retry_start_time is None else retry_start_time
messages_to_retry.append((seq, message))
self.pending_messages[seq] = retry_start_time, curr_time, message
after_retry_wait = min(self.retry_wait, remaining_retry_timeout)
retry_next_wait = (
after_retry_wait if retry_next_wait is None else min(retry_next_wait, after_retry_wait)
)
else:
next_wait = wait_time if next_wait is None else min(next_wait, wait_time)
if retry_error:
self.stop(error=retry_error)
return None
if should_stop:
self.stop()
return None
if messages_to_retry:
# Hold send_lock so stop(error) cannot clear pending state and notify the receiver
# while a retry send is still in flight, which would deliver a ghost chunk.
with self.send_lock:
with self.retry_lock:
if self.stopped:
return None
for seq, message in messages_to_retry:
errors = self.cell.fire_and_forget(
STREAM_CHANNEL,
STREAM_DATA_TOPIC,
self.target,
message,
secure=self.secure,
optional=self.optional,
)
errors = errors or {}
error = errors.get(self.target)
if error:
log.error(
f"{self} message retry error for target {self.target} seq {seq}: "
f"{error}, will retry again in {self.retry_wait} seconds"
)
next_wait = retry_next_wait if next_wait is None else min(next_wait, retry_next_wait)
return next_wait
[docs]
def remove_task(self):
with ByteStreamer.map_lock:
ByteStreamer.tx_task_map.pop(self.sid, None)
log.debug(f"{self} is removed")
[docs]
class ByteStreamer:
tx_task_map = {}
map_lock = threading.Lock()
sent_stream_counter_pool = StatsPoolManager.add_counter_pool(
name="Sent_Stream_Counters",
description="Counters of sent streams",
counter_names=[COUNTER_NAME_SENT],
)
sent_stream_size_pool = StatsPoolManager.add_msg_size_pool("Sent_Stream_Sizes", "Sizes of streams sent (MBs)")
def __init__(self, cell: CoreCell):
self.cell = cell
self.cell.register_request_cb(channel=STREAM_CHANNEL, topic=STREAM_ACK_TOPIC, cb=self._ack_handler)
self.chunk_size = CommConfigurator().get_streaming_chunk_size(STREAM_CHUNK_SIZE)
[docs]
def get_chunk_size(self):
return self.chunk_size
[docs]
def send(
self,
channel: str,
topic: str,
target: str,
headers: dict,
stream: Stream,
stream_type=STREAM_TYPE_BYTE,
secure=False,
optional=False,
reliable: Optional[bool] = None,
) -> StreamFuture:
tx_task = TxTask(
self.cell, self.chunk_size, channel, topic, target, headers, stream, reliable, secure, optional
)
with ByteStreamer.map_lock:
ByteStreamer.tx_task_map[tx_task.sid] = tx_task
tx_task.start_task_thread(self._transmit_task)
fqcn = self.cell.my_info.fqcn
ByteStreamer.sent_stream_counter_pool.increment(
category=stream_stats_category(fqcn, channel, topic, stream_type), counter_name=COUNTER_NAME_SENT
)
ByteStreamer.sent_stream_size_pool.record_value(
category=stream_stats_category(fqcn, channel, topic, stream_type), value=stream.get_size() / ONE_MB
)
return tx_task.stream_future
@staticmethod
def _transmit_task(task: TxTask):
try:
task.send_loop()
except Exception as ex:
msg = f"{task} Error while sending: {ex}"
if task.optional:
log.debug(msg)
else:
log.error(msg)
task.stop(StreamError(msg), True)
@staticmethod
def _ack_handler(message: Message):
sid = message.get_header(StreamHeaderKey.STREAM_ID)
with ByteStreamer.map_lock:
tx_task = ByteStreamer.tx_task_map.get(sid, None)
if not tx_task:
origin = message.get_header(MessageHeaderKey.ORIGIN)
offset = message.get_header(StreamHeaderKey.OFFSET, None)
seq = message.get_header(StreamHeaderKey.SEQUENCE, None)
# Last few ACKs always arrive late so this is normal
log.debug(f"ACK for stream {sid} received late from {origin} with offset {offset} seq {seq}")
return
tx_task.handle_ack(message)