# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import threading
from collections import deque
from typing import Callable, Deque, Dict, Optional, Tuple
from nvflare.fuel.f3.cellnet.core_cell import CoreCell
from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey
from nvflare.fuel.f3.cellnet.registry import Callback, Registry
from nvflare.fuel.f3.comm_config import CommConfigurator
from nvflare.fuel.f3.connection import BytesAlike
from nvflare.fuel.f3.message import Message
from nvflare.fuel.f3.stats_pool import StatsPoolManager
from nvflare.fuel.f3.streaming.stream_const import (
EOS,
STREAM_ACK_TOPIC,
STREAM_CHANNEL,
STREAM_DATA_TOPIC,
StreamDataType,
StreamHeaderKey,
)
from nvflare.fuel.f3.streaming.stream_types import Stream, StreamError, StreamFuture
from nvflare.fuel.f3.streaming.stream_utils import ONE_MB, stream_stats_category, stream_thread_pool
# Module-level logger used by every class in this file
log = logging.getLogger(__name__)

# Values handed to the CommConfigurator getters in RxTask.__init__
# (presumably as the defaults when nothing is configured)
MAX_OUT_SEQ_CHUNKS = 16  # cap on buffered out-of-sequence chunks per stream
ACK_INTERVAL = 4 * 1024 * 1024  # bytes; 1/4 of the window size
READ_TIMEOUT = 300  # seconds a blocking read waits for new chunks
COMPLETED_TASK_TTL = 60.0  # seconds
RETRY_WAIT = 5.0  # seconds

# Counter name used in the received-stream counter pool
COUNTER_NAME_RECEIVED = "received"

# Read result status
RESULT_DATA, RESULT_NO_DATA, RESULT_EOS = range(3)
class RxTask:
    """Receiving task for ByteStream.

    One task tracks a single inbound stream, registered in the class-level
    ``rx_task_map`` under the key ``(origin, sid)``. It reassembles incoming
    chunks in sequence order, serves blocking reads from the reassembled
    data, and sends ACK / ERROR messages back to the origin cell.

    Locking: ``stop_lock`` guards the ``completed`` / ``failed`` / ``error``
    state and the cleanup timer, ``lock`` guards the chunk queue and the
    sequence / offset counters, ``ack_lock`` guards ``offset_ack`` /
    ``seq_ack``. ``stop_lock`` may be held while acquiring ``lock`` (see
    ``stop``); it must never be acquired while ``lock`` is held.
    """

    # All registered tasks keyed by (origin, stream id); guarded by map_lock
    rx_task_map: Dict[Tuple[str, int], "RxTask"] = {}
    map_lock = threading.Lock()

    def __init__(self, sid: int, origin: str, cell: CoreCell, reliable: bool = False):
        """Create the receiving state for one stream.

        Args:
            sid: stream ID taken from the message headers
            origin: the sender; ACK and error messages are sent to it
            cell: the cell used to send ACK and error messages
            reliable: if True, duplicates are re-ACKed and the finished task is
                kept in the map for ``completed_task_ttl`` seconds
        """
        self.sid = sid
        self.origin = origin
        self.reliable = reliable
        self.cell = cell

        # Filled in from the headers of chunk 0 by _handle_new_stream
        self.channel = None
        self.topic = None
        self.headers = None
        self.size = 0

        # The reassembled chunks in a double-ended queue
        self.chunks: Deque[Tuple[bool, BytesAlike]] = deque()
        self.chunk_offset = 0  # Start of the remaining data for partially read left-most chunk

        # Out-of-sequence chunks to be assembled
        self.out_seq_chunks: Dict[int, Tuple[bool, BytesAlike]] = {}
        self.stream_future = None

        self.next_seq = 0  # Sequence number of the next chunk to append
        self.offset = 0  # Number of bytes already returned to the reader
        self.received_offset = 0  # Number of bytes appended in sequence
        self.offset_ack = 0  # Highest offset ACKed successfully
        self.seq = -1  # Sequence number of the last chunk appended
        self.seq_ack = -1  # Highest sequence number ACKed successfully

        self.waiter = threading.Event()
        self.lock = threading.Lock()
        self.ack_lock = threading.Lock()

        self.eos = False  # The reader has consumed the final chunk
        self.completed = False  # Reliable stream fully received, kept only for re-ACKs
        self.failed = False  # stop() was called with an error
        self.error = None
        self.error_msg = None

        # Re-entrant: stop() runs _schedule_remove_task/_remove_task, which take it again
        self.stop_lock = threading.RLock()
        self.cleanup_timer = None

        config = CommConfigurator()
        self.timeout = config.get_streaming_read_timeout(READ_TIMEOUT)
        self.ack_interval = config.get_streaming_ack_interval(ACK_INTERVAL)
        self.max_out_seq = config.get_streaming_max_out_seq_chunks(MAX_OUT_SEQ_CHUNKS)
        self.completed_task_ttl = config.get_streaming_retry_timeout(
            COMPLETED_TASK_TTL
        ) + config.get_streaming_retry_wait(RETRY_WAIT)

    def __str__(self):
        return f"Rx[SID:{self.sid} from {self.origin} for {self.channel}/{self.topic} Size: {self.size}]"

    @classmethod
    def find_or_create_task(cls, message: Message, cell: CoreCell) -> Optional["RxTask"]:
        """Look up the task for the stream the message belongs to, creating it if needed.

        Args:
            message: the incoming stream message
            cell: the cell handed to a newly created task

        Returns:
            The task, or None when the message carries an error header. In that
            case an existing task is stopped (without notifying the origin) and
            an unknown stream is only logged.
        """
        sid = message.get_header(StreamHeaderKey.STREAM_ID)
        origin = message.get_header(MessageHeaderKey.ORIGIN)
        reliable = message.get_header(StreamHeaderKey.RELIABLE, False)
        error = message.get_header(StreamHeaderKey.ERROR_MSG, None)

        task_to_stop = None
        with cls.map_lock:
            task = cls.rx_task_map.get((origin, sid), None)
            if not task:
                if error:
                    log.warning(f"Received error for non-existing stream: SID {sid} from {origin}")
                    return None

                task = RxTask(sid, origin, cell, reliable)
                cls.rx_task_map[(origin, sid)] = task
            else:
                if error:
                    task_to_stop = task

        # stop() is called after map_lock is released because it takes map_lock itself
        if task_to_stop:
            task_to_stop.stop(StreamError(f"{task_to_stop} Received error from {origin}: {error}"), notify=False)
            return None

        return task

    def read(self, size: int) -> BytesAlike:
        """Read from the stream, blocking until data is available.

        Data is served from the left-most chunk only, so the result may be
        shorter than ``size``.

        Args:
            size: maximum number of bytes wanted

        Returns:
            The data read, or EOS once the final chunk has been consumed.

        Raises:
            StreamError: if no chunk arrives within ``self.timeout`` seconds, or
                if the task has failed.
        """
        count = 0
        while True:
            result_code, result = self._try_to_read(size)
            if result_code == RESULT_EOS:
                return EOS
            elif result_code == RESULT_DATA:
                return result

            # result_code == RESULT_NO_DATA Block until chunks are received
            if count > 0:
                log.warning(f"{self} Read block is unblocked multiple times: {count}")

            if not self.waiter.wait(self.timeout):
                error = StreamError(f"{self} read timed out after {self.timeout} seconds")
                self.stop(error)
                raise error

            count += 1

    def process_chunk(self, message: Message) -> bool:
        """Returns True if a new stream is created"""
        # Snapshot the terminal state; the sends below happen outside stop_lock
        with self.stop_lock:
            failed = self.failed
            error_msg = self.error_msg
            completed = self.completed

        if failed:
            if self.reliable and error_msg:
                self._send_error(error_msg)
            return False

        if completed:
            # The task is kept in the map for the retry window only to re-ACK retried chunks
            if self.reliable:
                self._send_ack(self._get_ack_offset(), self.seq)
            return False

        new_stream = False
        ack_to_send = None
        stop_error = None
        should_stop = False
        duplicate_start = False

        with self.lock:
            seq = message.get_header(StreamHeaderKey.SEQUENCE)
            if seq == 0:
                if self.stream_future:
                    log.warning(f"{self} Received duplicate chunk 0, ignored")
                    if self.reliable:
                        ack_to_send = (self._get_ack_offset(), self.seq)
                    duplicate_start = True
                else:
                    self._handle_new_stream(message)
                    new_stream = True

            if not duplicate_start:
                should_stop, ack_to_send, stop_error = self._handle_incoming_data(seq, message)

        # ACK and stop are deferred until self.lock is released: stop() takes
        # stop_lock, which must not be acquired while holding self.lock
        if ack_to_send:
            self._send_ack(*ack_to_send)

        if stop_error:
            self.stop(stop_error)
        elif should_stop:
            self.stop()

        return new_stream

    def _handle_new_stream(self, message: Message):
        """Initialize the stream metadata and the future from the headers of chunk 0."""
        self.channel = message.get_header(StreamHeaderKey.CHANNEL)
        self.topic = message.get_header(StreamHeaderKey.TOPIC)
        self.headers = message.headers
        self.size = message.get_header(StreamHeaderKey.SIZE, 0)

        # Retry settings in the headers can only extend the locally configured TTL
        retry_timeout = message.get_header(StreamHeaderKey.RETRY_TIMEOUT, None)
        retry_wait = message.get_header(StreamHeaderKey.RETRY_WAIT, None)
        if retry_timeout is not None and retry_wait is not None:
            self.completed_task_ttl = max(self.completed_task_ttl, float(retry_timeout) + float(retry_wait))

        self.stream_future = StreamFuture(self.sid, self.headers)
        self.stream_future.set_size(self.size)

    def _handle_incoming_data(
        self, seq: int, message: Message
    ) -> Tuple[bool, Optional[Tuple[int, int]], Optional[StreamError]]:
        """Append an in-sequence chunk or buffer an out-of-sequence one.

        Called by process_chunk with ``self.lock`` held. Nothing is sent from
        here; the caller acts on the returned values after releasing the lock.

        Returns:
            A tuple ``(should_stop, ack_to_send, stop_error)``: whether the
            stream is fully received, an optional ``(offset, seq)`` to ACK, and
            an optional error the task must be stopped with.
        """
        data_type = message.get_header(StreamHeaderKey.DATA_TYPE)
        last_chunk = data_type == StreamDataType.FINAL
        ack_to_send = None

        if seq < self.next_seq:
            log.debug(f"{self} Duplicate chunk ignored {seq=}")
            if self.reliable:
                ack_to_send = (self._get_ack_offset(), self.seq)
            return False, ack_to_send, None

        if seq == self.next_seq:
            self._append(seq, (last_chunk, message.payload))

            # Try to reassemble out-of-seq chunks
            while self.next_seq in self.out_seq_chunks:
                chunk = self.out_seq_chunks.pop(self.next_seq)
                self._append(self.next_seq, chunk)
        else:
            # Save out-of-seq chunks
            if len(self.out_seq_chunks) >= self.max_out_seq:
                return (
                    False,
                    None,
                    StreamError(f"{self} Too many out-of-sequence chunks: {len(self.out_seq_chunks)}"),
                )
            else:
                if seq not in self.out_seq_chunks:
                    self.out_seq_chunks[seq] = last_chunk, message.payload
                else:
                    log.warning(f"{self} Duplicate out-of-seq chunk ignored {seq=}")
                    if self.reliable:
                        ack_to_send = (self._get_ack_offset(), self.seq)

        # If all chunks are lined up and last chunk received, the task can be deleted
        should_stop = False
        if not self.out_seq_chunks and self.chunks:
            last_chunk, _ = self.chunks[-1]
            if last_chunk:
                should_stop = True

        return should_stop, ack_to_send, None

    def stop(self, error: Optional[StreamError] = None, notify=True):
        """Finish the task, either normally or with an error.

        Does nothing if the task is already completed or failed.

        Without an error: an ACK for everything received is sent if it is not
        ACKed yet. A reliable task is then marked completed and its removal is
        scheduled after ``completed_task_ttl`` seconds; any other task is
        removed from the map right away.

        With an error: the task is marked failed, the error is logged and set
        on the stream future, a blocked reader is woken up, and the task is
        removed (delayed for a reliable task).

        Args:
            error: the failure, or None for a normal end
            notify: whether to send the error to the origin
        """
        if not error:
            ack_to_send = None
            schedule_remove = False
            remove_now = False
            with self.stop_lock:
                if self.completed or self.failed:
                    return

                with self.lock:
                    ack_offset = self.received_offset
                    ack_seq = self.seq

                with self.ack_lock:
                    needs_ack = ack_seq != self.seq_ack or ack_offset > self.offset_ack

                if needs_ack:
                    ack_to_send = (ack_offset, ack_seq)

                if self.reliable:
                    self.completed = True
                    schedule_remove = True
                else:
                    remove_now = True

            if ack_to_send:
                self._send_ack(*ack_to_send)

            if schedule_remove:
                self._schedule_remove_task()
            elif remove_now:
                self._remove_task()
            return

        schedule_remove = False
        remove_now = False
        with self.stop_lock:
            if self.completed or self.failed:
                return

            # failed must be set last: _try_to_read reads it without stop_lock and
            # expects error/error_msg to be populated once failed is observed
            self.error = error
            self.error_msg = str(error)
            self.failed = True

            if self.reliable:
                schedule_remove = True
            else:
                remove_now = True

        if self.headers:
            optional = self.headers.get(StreamHeaderKey.OPTIONAL, False)
        else:
            optional = False

        # Errors of streams flagged optional are only logged at debug level
        msg = f"Stream error: {error}"
        if optional:
            log.debug(msg)
        else:
            log.error(msg)

        if self.stream_future:
            self.stream_future.set_exception(error)

        # Wake up a reader blocked in read() so it sees the failure
        if not self.waiter.is_set():
            self.waiter.set()

        if notify:
            self._send_error(str(error))

        if schedule_remove:
            self._schedule_remove_task()
        elif remove_now:
            self._remove_task()

    def _send_error(self, error_msg: str):
        """Send an ERROR message for this stream to the origin; failures are only logged."""
        message = Message()
        message.add_headers(
            {
                StreamHeaderKey.STREAM_ID: self.sid,
                StreamHeaderKey.DATA_TYPE: StreamDataType.ERROR,
                StreamHeaderKey.ERROR_MSG: error_msg,
            }
        )

        try:
            errors = self.cell.fire_and_forget(STREAM_CHANNEL, STREAM_ACK_TOPIC, self.origin, message)
        except Exception as ex:
            log.error(f"{self} failed to send error to {self.origin}: {ex}")
        else:
            # The result is treated as a mapping of target to error; None counts as empty
            errors = errors or {}
            error = errors.get(self.origin)
            if error:
                log.error(f"{self} failed to send error to {self.origin}: {error}")

    def _remove_task(self):
        """Cancel a pending cleanup timer and drop this task from the task map."""
        with self.stop_lock:
            if self.cleanup_timer:
                self.cleanup_timer.cancel()
                self.cleanup_timer = None

        with RxTask.map_lock:
            # Only remove the entry if it still refers to this task
            task = RxTask.rx_task_map.get((self.origin, self.sid))
            if task is self:
                RxTask.rx_task_map.pop((self.origin, self.sid), None)

    def _schedule_remove_task(self):
        """Start a daemon timer that removes the task after ``completed_task_ttl`` seconds.

        Does nothing if a timer is already set.
        """
        with self.stop_lock:
            if self.cleanup_timer:
                return

            self.cleanup_timer = threading.Timer(self.completed_task_ttl, self._remove_task)
            self.cleanup_timer.daemon = True
            self.cleanup_timer.start()

    def _try_to_read(self, size: int) -> Tuple[int, Optional[BytesAlike]]:
        """Make one non-blocking read attempt on the left-most chunk.

        Returns:
            ``(RESULT_EOS, None)`` once the final chunk has been consumed,
            ``(RESULT_NO_DATA, None)`` if no chunk is queued, otherwise
            ``(RESULT_DATA, data)``.

        Raises:
            StreamError: the stored error (or one built from ``error_msg``) if
                the task has failed.
        """
        ack_to_send = None

        with self.stop_lock:
            if self.failed:
                raise self.error or StreamError(self.error_msg)

        with self.lock:
            if self.eos:
                return RESULT_EOS, None

            if not self.chunks:
                self.waiter.clear()
                # stop(error) may have set failed and the waiter after the check at the top;
                # re-check after the clear so the wakeup is not lost until the read timeout.
                # stop_lock is not used here as it must not be acquired while holding self.lock.
                if self.failed:
                    raise self.error or StreamError(self.error_msg)
                return RESULT_NO_DATA, None

            # Get the left most chunk
            last_chunk, buf = self.chunks[0]
            if buf is None:
                buf = bytes(0)

            end_offset = self.chunk_offset + size
            if 0 < end_offset < len(buf):
                # Partial read
                result = buf[self.chunk_offset : end_offset]
                self.chunk_offset = end_offset
                final_chunk_consumed = False
            else:
                # Whole chunk is consumed
                if self.chunk_offset:
                    result = buf[self.chunk_offset :]
                else:
                    result = buf

                self.chunk_offset = 0
                self.chunks.popleft()
                if last_chunk:
                    self.eos = True
                final_chunk_consumed = last_chunk

            self.offset += len(result)

            # ACK when enough unacknowledged data has been read, or when the
            # final chunk is consumed and not everything is ACKed yet
            with self.ack_lock:
                ack_lag = self.offset - self.offset_ack
                final_ack_needed = final_chunk_consumed and (self.offset > self.offset_ack or self.seq > self.seq_ack)

            if ack_lag >= self.ack_interval or final_ack_needed:
                ack_to_send = (self.offset, self.seq)

            if self.stream_future:
                self.stream_future.set_progress(self.offset)

        # The ACK is sent after self.lock is released
        if ack_to_send:
            self._send_ack(*ack_to_send)

        return RESULT_DATA, result

    def _append(self, seq: int, buf: Tuple[bool, BytesAlike]):
        """Queue an in-sequence chunk, advance the counters and wake up the reader.

        Args:
            seq: sequence number of the chunk
            buf: tuple of (is the final chunk, payload); the payload may be None
        """
        if self.eos:
            log.error(f"{self} Data after EOS is ignored")
            return

        self.chunks.append(buf)
        _last_chunk, payload = buf
        self.received_offset += len(payload) if payload else 0

        # A non-increasing sequence is logged but still accepted
        if seq <= self.seq:
            log.error(f"Sequence error: {seq} <= {self.seq}")

        self.seq = seq
        self.next_seq += 1

        # Wake up blocking read()
        if not self.waiter.is_set():
            self.waiter.set()

    def _get_ack_offset(self):
        """Return the offset to ACK: bytes received once completed, otherwise bytes read."""
        return self.received_offset if self.completed else self.offset

    def _send_ack(self, offset, seq):
        """Send an ACK with the given offset and sequence number to the origin.

        Returns:
            True if the ACK was sent, in which case ``offset_ack`` and
            ``seq_ack`` are advanced; False if sending failed (logged).
        """
        message = Message()
        message.add_headers(
            {
                StreamHeaderKey.STREAM_ID: self.sid,
                StreamHeaderKey.DATA_TYPE: StreamDataType.ACK,
                StreamHeaderKey.OFFSET: offset,
                StreamHeaderKey.SEQUENCE: seq,
            }
        )

        try:
            errors = self.cell.fire_and_forget(STREAM_CHANNEL, STREAM_ACK_TOPIC, self.origin, message)
        except Exception as ex:
            log.error(f"{self} failed to ack seq {seq} to {self.origin}: {ex}")
            return False
        else:
            errors = errors or {}
            error = errors.get(self.origin)
            if error:
                log.error(f"{self} failed to ack seq {seq} to {self.origin}: {error}")
                return False

        # max() keeps the ACK state monotonic when ACKs are sent from several threads
        with self.ack_lock:
            self.offset_ack = max(self.offset_ack, offset)
            self.seq_ack = max(self.seq_ack, seq)

        return True
class RxStream(Stream):
    """A stream that's used to read streams from the streaming task"""

    def __init__(self, task: RxTask):
        """Wrap the receiving task, passing its size and headers to the base Stream."""
        super().__init__(task.size, task.headers)
        self.task = task

    def read(self, size: int) -> bytes:
        """Read from the underlying task; see RxTask.read for blocking and EOS behavior.

        Raises:
            StreamError: if the stream has been closed, or as raised by the task's read
        """
        # NOTE(review): `closed` is not assigned in __init__ here; presumably the
        # Stream base class initializes it - confirm
        if self.closed:
            raise StreamError("Read from closed stream")

        return self.task.read(size)

    def close(self):
        """Mark the stream closed, completing the future with the bytes read if it is still pending."""
        if self.task.stream_future and not self.task.stream_future.done():
            self.task.stream_future.set_result(self.task.offset)
        self.closed = True
class ByteReceiver:
    """Receives stream data messages on a cell and dispatches new streams to registered callbacks."""

    # Stats pools shared by all receivers; created once when the class is defined
    received_stream_counter_pool = StatsPoolManager.add_counter_pool(
        name="Received_Stream_Counters",
        description="Counters of received streams",
        counter_names=[COUNTER_NAME_RECEIVED],
    )

    received_stream_size_pool = StatsPoolManager.add_msg_size_pool(
        "Received_Stream_Sizes", "Sizes of streams received (MBs)"
    )

    def __init__(self, cell: CoreCell):
        """Register the data handler on the cell for STREAM_CHANNEL/STREAM_DATA_TOPIC.

        Args:
            cell: the cell that delivers the stream messages
        """
        self.cell = cell
        self.cell.register_request_cb(channel=STREAM_CHANNEL, topic=STREAM_DATA_TOPIC, cb=self._data_handler)
        self.registry = Registry()

    def register_callback(self, channel: str, topic: str, stream_cb: Callable, *args, **kwargs):
        """Register the callback invoked for new streams on the given channel/topic.

        Args:
            channel: channel of the stream
            topic: topic of the stream
            stream_cb: called as ``stream_cb(future, stream, False, *args, **kwargs)``
            *args: extra positional arguments passed to the callback
            **kwargs: extra keyword arguments passed to the callback

        Raises:
            StreamError: if ``stream_cb`` is not callable
        """
        if not callable(stream_cb):
            raise StreamError(f"specified stream_cb {type(stream_cb)} is not callable")

        self.registry.set(channel, topic, Callback(stream_cb, args, kwargs))

    def _data_handler(self, message: Message):
        """Feed an incoming chunk to its task and start the callback for a new stream."""
        task = RxTask.find_or_create_task(message, self.cell)
        if not task:
            # The message carried an error header; find_or_create_task has handled it
            return

        new_stream = task.process_chunk(message)
        if new_stream:
            # Invoke callback
            callback = self.registry.find(task.channel, task.topic)
            if not callback:
                task.stop(StreamError(f"{task} No callback is registered for {task.channel}/{task.topic}"))
                return

            fqcn = self.cell.my_info.fqcn
            ByteReceiver.received_stream_counter_pool.increment(
                category=stream_stats_category(fqcn, task.channel, task.topic, "stream"),
                counter_name=COUNTER_NAME_RECEIVED,
            )
            ByteReceiver.received_stream_size_pool.record_value(
                category=stream_stats_category(fqcn, task.channel, task.topic, "stream"),
                value=task.size / ONE_MB,
            )

            # The callback runs on the stream thread pool, not in this handler
            stream_thread_pool.submit(self._callback_wrapper, task, callback)

    @staticmethod
    def _callback_wrapper(task: RxTask, callback: Callback):
        """A wrapper to catch all exceptions in the callback"""
        try:
            stream = RxStream(task)
            return callback.cb(task.stream_future, stream, False, *callback.args, **callback.kwargs)
        except Exception as ex:
            # Log the failure and stop the task with it
            msg = f"{task} callback {callback.cb} throws exception: {ex}"
            log.error(msg)
            task.stop(StreamError(msg))