Source code for nvflare.fuel.f3.streaming.byte_receiver

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import threading
from collections import deque
from typing import Callable, Deque, Dict, Optional, Tuple

from nvflare.fuel.f3.cellnet.core_cell import CoreCell
from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey
from nvflare.fuel.f3.cellnet.registry import Callback, Registry
from nvflare.fuel.f3.comm_config import CommConfigurator
from nvflare.fuel.f3.connection import BytesAlike
from nvflare.fuel.f3.message import Message
from nvflare.fuel.f3.stats_pool import StatsPoolManager
from nvflare.fuel.f3.streaming.stream_const import (
    EOS,
    STREAM_ACK_TOPIC,
    STREAM_CHANNEL,
    STREAM_DATA_TOPIC,
    StreamDataType,
    StreamHeaderKey,
)
from nvflare.fuel.f3.streaming.stream_types import Stream, StreamError, StreamFuture
from nvflare.fuel.f3.streaming.stream_utils import ONE_MB, stream_stats_category, stream_thread_pool

# Module-level logger shared by every class in this file
log = logging.getLogger(__name__)

# Default cap on buffered out-of-sequence chunks per stream; passed as the default to
# CommConfigurator.get_streaming_max_out_seq_chunks() and compared with len(out_seq_chunks)
MAX_OUT_SEQ_CHUNKS = 16
# 1/4 of the window size
# Default ACK interval (4 MB); an ACK is sent once the read offset is at least this far
# ahead of the last acknowledged offset
ACK_INTERVAL = 1024 * 1024 * 4
# Default time, in seconds, a blocking read() waits for new chunks before failing the stream
READ_TIMEOUT = 300
# Defaults passed to CommConfigurator.get_streaming_retry_timeout() and
# get_streaming_retry_wait(); their sum is the delay (seconds) before a finished
# reliable task is removed from the task map
COMPLETED_TASK_TTL = 60.0
RETRY_WAIT = 5.0
# Counter name used in the "Received_Stream_Counters" stats pool
COUNTER_NAME_RECEIVED = "received"

# Read result status
# (first element of the tuple returned by RxTask._try_to_read)
RESULT_DATA = 0
RESULT_NO_DATA = 1
RESULT_EOS = 2


class RxTask:
    """Receiving task for ByteStream

    One RxTask tracks the receive-side state of a single incoming stream: the
    queue of in-order chunks, chunks that arrived out of sequence, the read and
    ACK offsets, and the completed/failed status. Tasks are registered in the
    class-level ``rx_task_map`` keyed by ``(origin, sid)``.

    Locking:
        * ``map_lock`` (class-level) guards ``rx_task_map``.
        * ``stop_lock`` is held for every write to ``completed``, ``failed``,
          ``error``, ``error_msg`` and ``cleanup_timer``.
        * ``lock`` is held while chunks are appended to or consumed from the queue.
        * ``ack_lock`` guards ``offset_ack`` and ``seq_ack``.

    ``stop_lock`` may be held while acquiring ``lock`` (see ``stop``); ``stop_lock``
    must never be acquired while ``lock`` is held.
    """

    rx_task_map: Dict[Tuple[str, int], "RxTask"] = {}
    map_lock = threading.Lock()

    def __init__(self, sid: int, origin: str, cell: CoreCell, reliable: bool = False):
        """Create the receive state for one stream.

        Args:
            sid: stream ID taken from the message header
            origin: origin of the stream; ACKs and errors are sent back to it
            cell: cell used to send ACK and error messages
            reliable: when True, duplicates are re-ACKed and a finished task stays
                in the task map for ``completed_task_ttl`` seconds
        """
        self.sid = sid
        self.origin = origin
        self.reliable = reliable
        self.cell = cell

        # Filled in from the headers of chunk 0 (see _handle_new_stream)
        self.channel = None
        self.topic = None
        self.headers = None
        self.size = 0

        # The reassembled chunks in a double-ended queue
        self.chunks: Deque[Tuple[bool, BytesAlike]] = deque()
        self.chunk_offset = 0  # Start of the remaining data for partially read left-most chunk

        # Out-of-sequence chunks to be assembled
        self.out_seq_chunks: Dict[int, Tuple[bool, BytesAlike]] = {}
        self.stream_future = None
        self.next_seq = 0  # Sequence number of the next in-order chunk
        self.offset = 0  # Number of bytes handed out by read()
        self.received_offset = 0  # Number of bytes appended to the chunk queue
        self.offset_ack = 0  # Highest offset successfully ACKed
        self.seq = -1  # Sequence number of the last chunk appended to the queue
        self.seq_ack = -1  # Highest sequence number successfully ACKed
        self.waiter = threading.Event()
        self.lock = threading.Lock()
        self.ack_lock = threading.Lock()
        self.eos = False  # Set once the final chunk has been fully read
        self.completed = False
        self.failed = False
        self.error = None
        self.error_msg = None
        self.stop_lock = threading.RLock()
        self.cleanup_timer = None

        config = CommConfigurator()
        self.timeout = config.get_streaming_read_timeout(READ_TIMEOUT)
        self.ack_interval = config.get_streaming_ack_interval(ACK_INTERVAL)
        self.max_out_seq = config.get_streaming_max_out_seq_chunks(MAX_OUT_SEQ_CHUNKS)
        # Delay (seconds) before a finished reliable task is removed from the task map
        self.completed_task_ttl = config.get_streaming_retry_timeout(
            COMPLETED_TASK_TTL
        ) + config.get_streaming_retry_wait(RETRY_WAIT)

    def __str__(self):
        return f"Rx[SID:{self.sid} from {self.origin} for {self.channel}/{self.topic} Size: {self.size}]"

    @classmethod
    def find_or_create_task(cls, message: Message, cell: CoreCell) -> Optional["RxTask"]:
        """Find the task the message belongs to, creating it if it does not exist.

        Args:
            message: incoming stream message; STREAM_ID, ORIGIN, RELIABLE and
                ERROR_MSG headers are read from it
            cell: cell passed to a newly created task

        Returns:
            The task, or None when the message carries an error header. In that
            case an existing task is stopped with a StreamError and an error for
            an unknown stream is only logged.
        """
        sid = message.get_header(StreamHeaderKey.STREAM_ID)
        origin = message.get_header(MessageHeaderKey.ORIGIN)
        reliable = message.get_header(StreamHeaderKey.RELIABLE, False)
        error = message.get_header(StreamHeaderKey.ERROR_MSG, None)

        task_to_stop = None
        with cls.map_lock:
            task = cls.rx_task_map.get((origin, sid), None)
            if not task:
                if error:
                    log.warning(f"Received error for non-existing stream: SID {sid} from {origin}")
                    return None

                task = RxTask(sid, origin, cell, reliable)
                cls.rx_task_map[(origin, sid)] = task
            else:
                if error:
                    task_to_stop = task

        # stop() is called after map_lock is released because it may call
        # _remove_task(), which acquires the non-reentrant map_lock itself
        if task_to_stop:
            task_to_stop.stop(StreamError(f"{task_to_stop} Received error from {origin}: {error}"), notify=False)
            return None

        return task

    def read(self, size: int) -> BytesAlike:
        """Read data from the stream, blocking until some is available.

        A single call never returns data from more than one chunk.

        Args:
            size: maximum number of bytes to return

        Returns:
            The data read, or EOS once the final chunk has been consumed.

        Raises:
            StreamError: if no chunk arrives within ``self.timeout`` seconds (the
                task is stopped with that error first), or if the task has failed.
        """
        count = 0
        while True:
            result_code, result = self._try_to_read(size)
            if result_code == RESULT_EOS:
                return EOS
            elif result_code == RESULT_DATA:
                return result

            # result_code == RESULT_NO_DATA Block until chunks are received
            if count > 0:
                log.warning(f"{self} Read block is unblocked multiple times: {count}")

            if not self.waiter.wait(self.timeout):
                error = StreamError(f"{self} read timed out after {self.timeout} seconds")
                self.stop(error)
                raise error

            count += 1

    def process_chunk(self, message: Message) -> bool:
        """Returns True if a new stream is created

        Handles one incoming data message. If the task has already failed, the
        stored error is re-sent (reliable mode only); if it has completed, the ACK
        is re-sent (reliable mode only). Otherwise the chunk is queued or buffered
        under ``self.lock``. ACKs are sent and stop() is called only after the
        lock is released.
        """
        # Snapshot the terminal state under stop_lock; it is then used without the lock
        with self.stop_lock:
            failed = self.failed
            error_msg = self.error_msg
            completed = self.completed

        if failed:
            if self.reliable and error_msg:
                self._send_error(error_msg)
            return False

        if completed:
            # The task is kept in the map for the retry window only to re-ACK retried chunks
            if self.reliable:
                self._send_ack(self._get_ack_offset(), self.seq)
            return False

        new_stream = False
        ack_to_send = None
        stop_error = None
        should_stop = False
        duplicate_start = False
        with self.lock:
            seq = message.get_header(StreamHeaderKey.SEQUENCE)
            if seq == 0:
                if self.stream_future:
                    # stream_future is created by the first chunk 0, so this one is a repeat
                    log.warning(f"{self} Received duplicate chunk 0, ignored")
                    if self.reliable:
                        ack_to_send = (self._get_ack_offset(), self.seq)
                    duplicate_start = True
                else:
                    self._handle_new_stream(message)
                    new_stream = True

            if not duplicate_start:
                should_stop, ack_to_send, stop_error = self._handle_incoming_data(seq, message)

        if ack_to_send:
            self._send_ack(*ack_to_send)

        if stop_error:
            self.stop(stop_error)
        elif should_stop:
            self.stop()

        return new_stream

    def _handle_new_stream(self, message: Message):
        """Initialize the stream metadata and the StreamFuture from the headers of chunk 0.

        Called by process_chunk with ``self.lock`` held.
        """
        self.channel = message.get_header(StreamHeaderKey.CHANNEL)
        self.topic = message.get_header(StreamHeaderKey.TOPIC)
        self.headers = message.headers
        self.size = message.get_header(StreamHeaderKey.SIZE, 0)

        # Retry settings carried in the message headers can only lengthen the cleanup delay
        retry_timeout = message.get_header(StreamHeaderKey.RETRY_TIMEOUT, None)
        retry_wait = message.get_header(StreamHeaderKey.RETRY_WAIT, None)
        if retry_timeout is not None and retry_wait is not None:
            self.completed_task_ttl = max(self.completed_task_ttl, float(retry_timeout) + float(retry_wait))

        self.stream_future = StreamFuture(self.sid, self.headers)
        self.stream_future.set_size(self.size)

    def _handle_incoming_data(
        self, seq: int, message: Message
    ) -> Tuple[bool, Optional[Tuple[int, int]], Optional[StreamError]]:
        """Queue an in-order chunk or buffer an out-of-sequence one.

        Called by process_chunk with ``self.lock`` held. Nothing is sent from
        here; the caller acts on the returned values after releasing the lock.

        Args:
            seq: sequence number of the chunk
            message: message carrying the chunk; DATA_TYPE header and payload are read

        Returns:
            A tuple ``(should_stop, ack_to_send, stop_error)``:
            ``should_stop`` is True when the final chunk is queued and nothing is
            left in the out-of-sequence buffer; ``ack_to_send`` is an
            ``(offset, seq)`` pair to re-ACK a duplicate in reliable mode, else
            None; ``stop_error`` is the error to stop the task with, else None.
        """
        data_type = message.get_header(StreamHeaderKey.DATA_TYPE)
        last_chunk = data_type == StreamDataType.FINAL
        ack_to_send = None

        if seq < self.next_seq:
            log.debug(f"{self} Duplicate chunk ignored {seq=}")
            if self.reliable:
                ack_to_send = (self._get_ack_offset(), self.seq)
            return False, ack_to_send, None

        if seq == self.next_seq:
            self._append(seq, (last_chunk, message.payload))

            # Try to reassemble out-of-seq chunks
            while self.next_seq in self.out_seq_chunks:
                chunk = self.out_seq_chunks.pop(self.next_seq)
                self._append(self.next_seq, chunk)
        else:
            # Save out-of-seq chunks
            if seq in self.out_seq_chunks:
                # Check for a duplicate before the capacity limit: a chunk that is
                # already buffered needs no extra space, so a retransmission must
                # not fail the stream just because the buffer happens to be full
                log.warning(f"{self} Duplicate out-of-seq chunk ignored {seq=}")
                if self.reliable:
                    ack_to_send = (self._get_ack_offset(), self.seq)
            elif len(self.out_seq_chunks) >= self.max_out_seq:
                return (
                    False,
                    None,
                    StreamError(f"{self} Too many out-of-sequence chunks: {len(self.out_seq_chunks)}"),
                )
            else:
                self.out_seq_chunks[seq] = last_chunk, message.payload

        # If all chunks are lined up and last chunk received, the task can be deleted
        should_stop = False
        if not self.out_seq_chunks and self.chunks:
            last_chunk, _ = self.chunks[-1]
            if last_chunk:
                should_stop = True

        return should_stop, ack_to_send, None

    def stop(self, error: Optional[StreamError] = None, notify=True):
        """Finish the task, either normally or with an error.

        Does nothing if the task is already completed or failed.

        Without ``error``: sends a final ACK covering everything received if that
        has not been acknowledged yet. In reliable mode the task is marked
        completed and its removal from the task map is scheduled; otherwise it is
        removed immediately.

        With ``error``: records the error, marks the task failed, sets the
        exception on the stream future, wakes up a blocked read(), optionally
        notifies the origin, then schedules (reliable) or performs the removal.

        Args:
            error: the error that terminated the stream, or None for a normal stop
            notify: whether to send the error to the origin
        """
        if not error:
            ack_to_send = None
            schedule_remove = False
            remove_now = False
            with self.stop_lock:
                if self.completed or self.failed:
                    return

                with self.lock:
                    ack_offset = self.received_offset
                    ack_seq = self.seq

                with self.ack_lock:
                    needs_ack = ack_seq != self.seq_ack or ack_offset > self.offset_ack

                if needs_ack:
                    ack_to_send = (ack_offset, ack_seq)

                if self.reliable:
                    self.completed = True
                    schedule_remove = True
                else:
                    remove_now = True

            # Messages are sent and the map is updated only after stop_lock is released
            if ack_to_send:
                self._send_ack(*ack_to_send)

            if schedule_remove:
                self._schedule_remove_task()
            elif remove_now:
                self._remove_task()

            return

        schedule_remove = False
        remove_now = False
        with self.stop_lock:
            if self.completed or self.failed:
                return

            # failed must be set last: _try_to_read reads it without stop_lock and
            # expects error/error_msg to be populated once failed is observed
            self.error = error
            self.error_msg = str(error)
            self.failed = True

            if self.reliable:
                schedule_remove = True
            else:
                remove_now = True

        if self.headers:
            optional = self.headers.get(StreamHeaderKey.OPTIONAL, False)
        else:
            optional = False

        # Errors of streams flagged OPTIONAL in their headers are logged at debug level only
        msg = f"Stream error: {error}"
        if optional:
            log.debug(msg)
        else:
            log.error(msg)

        if self.stream_future:
            self.stream_future.set_exception(error)

        # Wake up a blocked read() so that it sees the failure
        if not self.waiter.is_set():
            self.waiter.set()

        if notify:
            self._send_error(str(error))

        if schedule_remove:
            self._schedule_remove_task()
        elif remove_now:
            self._remove_task()

    def _send_error(self, error_msg: str):
        """Send an ERROR message for this stream to the origin; failures are only logged."""
        message = Message()
        message.add_headers(
            {
                StreamHeaderKey.STREAM_ID: self.sid,
                StreamHeaderKey.DATA_TYPE: StreamDataType.ERROR,
                StreamHeaderKey.ERROR_MSG: error_msg,
            }
        )
        try:
            errors = self.cell.fire_and_forget(STREAM_CHANNEL, STREAM_ACK_TOPIC, self.origin, message)
        except Exception as ex:
            log.error(f"{self} failed to send error to {self.origin}: {ex}")
        else:
            errors = errors or {}
            error = errors.get(self.origin)
            if error:
                log.error(f"{self} failed to send error to {self.origin}: {error}")

    def _remove_task(self):
        """Cancel any pending cleanup timer and remove this task from the task map."""
        with self.stop_lock:
            if self.cleanup_timer:
                self.cleanup_timer.cancel()
                self.cleanup_timer = None

        with RxTask.map_lock:
            # Only remove the entry if it still refers to this task
            task = RxTask.rx_task_map.get((self.origin, self.sid))
            if task is self:
                RxTask.rx_task_map.pop((self.origin, self.sid), None)

    def _schedule_remove_task(self):
        """Start a daemon timer that removes the task after ``completed_task_ttl`` seconds.

        Does nothing if a timer is already set.
        """
        with self.stop_lock:
            if self.cleanup_timer:
                return

            self.cleanup_timer = threading.Timer(self.completed_task_ttl, self._remove_task)
            self.cleanup_timer.daemon = True
            self.cleanup_timer.start()

    def _try_to_read(self, size: int) -> Tuple[int, Optional[BytesAlike]]:
        """Try to read up to ``size`` bytes from the left-most chunk without blocking.

        Returns:
            ``(RESULT_EOS, None)`` once the final chunk has been consumed,
            ``(RESULT_NO_DATA, None)`` when no chunk is queued (the waiter is
            cleared so that the caller can block on it), or ``(RESULT_DATA, data)``.

        Raises:
            StreamError: if the task has failed.
        """
        ack_to_send = None
        with self.stop_lock:
            if self.failed:
                raise self.error or StreamError(self.error_msg)

        with self.lock:
            if self.eos:
                return RESULT_EOS, None

            if not self.chunks:
                self.waiter.clear()
                # stop(error) may have set failed and the waiter after the check at the top;
                # re-check after the clear so the wakeup is not lost until the read timeout.
                # stop_lock is not used here as it must not be acquired while holding self.lock.
                if self.failed:
                    raise self.error or StreamError(self.error_msg)
                return RESULT_NO_DATA, None

            # Get the left most chunk
            last_chunk, buf = self.chunks[0]
            if buf is None:
                buf = bytes(0)

            end_offset = self.chunk_offset + size
            if 0 < end_offset < len(buf):
                # Partial read
                result = buf[self.chunk_offset : end_offset]
                self.chunk_offset = end_offset
                final_chunk_consumed = False
            else:
                # Whole chunk is consumed
                if self.chunk_offset:
                    result = buf[self.chunk_offset :]
                else:
                    result = buf

                self.chunk_offset = 0
                self.chunks.popleft()
                if last_chunk:
                    self.eos = True
                final_chunk_consumed = last_chunk

            self.offset += len(result)

            # ACK when enough data has been read since the last ACK, or when the
            # final chunk has just been consumed and not everything is ACKed yet
            with self.ack_lock:
                ack_lag = self.offset - self.offset_ack
                final_ack_needed = final_chunk_consumed and (
                    self.offset > self.offset_ack or self.seq > self.seq_ack
                )

            if ack_lag >= self.ack_interval or final_ack_needed:
                ack_to_send = (self.offset, self.seq)

            if self.stream_future:
                self.stream_future.set_progress(self.offset)

        # The ACK is sent after self.lock is released
        if ack_to_send:
            self._send_ack(*ack_to_send)

        return RESULT_DATA, result

    def _append(self, seq: int, buf: Tuple[bool, BytesAlike]):
        """Append an in-order chunk to the queue and wake up a blocked read().

        Args:
            seq: sequence number of the chunk
            buf: ``(is_last_chunk, payload)`` pair
        """
        if self.eos:
            log.error(f"{self} Data after EOS is ignored")
            return

        self.chunks.append(buf)
        _last_chunk, payload = buf
        self.received_offset += len(payload) if payload else 0
        if seq <= self.seq:
            log.error(f"Sequence error: {seq} <= {self.seq}")

        self.seq = seq
        self.next_seq += 1

        # Wake up blocking read()
        if not self.waiter.is_set():
            self.waiter.set()

    def _get_ack_offset(self):
        """Return the offset to acknowledge: everything received once completed, else what has been read."""
        return self.received_offset if self.completed else self.offset

    def _send_ack(self, offset, seq):
        """Send an ACK for ``offset``/``seq`` to the origin.

        Returns:
            True if the ACK was sent and the acknowledged offset/sequence were
            advanced; False if sending failed (the failure is logged).
        """
        message = Message()
        message.add_headers(
            {
                StreamHeaderKey.STREAM_ID: self.sid,
                StreamHeaderKey.DATA_TYPE: StreamDataType.ACK,
                StreamHeaderKey.OFFSET: offset,
                StreamHeaderKey.SEQUENCE: seq,
            }
        )
        try:
            errors = self.cell.fire_and_forget(STREAM_CHANNEL, STREAM_ACK_TOPIC, self.origin, message)
        except Exception as ex:
            log.error(f"{self} failed to ack seq {seq} to {self.origin}: {ex}")
            return False
        else:
            errors = errors or {}
            error = errors.get(self.origin)
            if error:
                log.error(f"{self} failed to ack seq {seq} to {self.origin}: {error}")
                return False

        # max() keeps the acknowledged position from moving backwards
        with self.ack_lock:
            self.offset_ack = max(self.offset_ack, offset)
            self.seq_ack = max(self.seq_ack, seq)
        return True
class RxStream(Stream):
    """Read-only stream view over an RxTask; all data comes from the task."""

    def __init__(self, task: RxTask):
        """Initialize the base stream with the task's size and headers and keep the task."""
        super().__init__(task.size, task.headers)
        self.task = task

    def read(self, size: int) -> bytes:
        """Read up to ``size`` bytes by delegating to the task.

        Raises:
            StreamError: if the stream has been closed.
        """
        if not self.closed:
            return self.task.read(size)

        raise StreamError("Read from closed stream")

    def close(self):
        """Mark the stream closed, resolving the task's future with its read offset if still pending."""
        future = self.task.stream_future
        pending = future and not future.done()
        if pending:
            future.set_result(self.task.offset)

        self.closed = True
class ByteReceiver:
    """Receives byte streams on a cell and hands each new stream to its registered callback.

    Stream data messages arriving on ``STREAM_CHANNEL``/``STREAM_DATA_TOPIC`` are
    routed to the matching RxTask. When a message starts a new stream, the
    callback registered for the stream's channel/topic is run on
    ``stream_thread_pool``.
    """

    received_stream_counter_pool = StatsPoolManager.add_counter_pool(
        name="Received_Stream_Counters",
        description="Counters of received streams",
        counter_names=[COUNTER_NAME_RECEIVED],
    )

    received_stream_size_pool = StatsPoolManager.add_msg_size_pool(
        "Received_Stream_Sizes", "Sizes of streams received (MBs)"
    )

    def __init__(self, cell: CoreCell):
        """Register the stream data handler on ``cell``.

        Args:
            cell: the cell on which stream data messages are received
        """
        self.cell = cell
        # Create the registry before registering the handler: _data_handler reads
        # self.registry, so it must exist before the cell can dispatch a message
        self.registry = Registry()
        self.cell.register_request_cb(channel=STREAM_CHANNEL, topic=STREAM_DATA_TOPIC, cb=self._data_handler)

    def register_callback(self, channel: str, topic: str, stream_cb: Callable, *args, **kwargs):
        """Register the callback invoked for new streams on ``channel``/``topic``.

        Args:
            channel: stream channel
            topic: stream topic
            stream_cb: callable invoked as ``stream_cb(future, stream, False, *args, **kwargs)``
            *args: extra positional arguments passed to the callback
            **kwargs: extra keyword arguments passed to the callback

        Raises:
            StreamError: if ``stream_cb`` is not callable.
        """
        if not callable(stream_cb):
            raise StreamError(f"specified stream_cb {type(stream_cb)} is not callable")

        self.registry.set(channel, topic, Callback(stream_cb, args, kwargs))

    def _data_handler(self, message: Message):
        """Process one stream data message; start the callback if it opens a new stream."""
        task = RxTask.find_or_create_task(message, self.cell)
        if not task:
            return

        new_stream = task.process_chunk(message)
        if not new_stream:
            return

        # Invoke callback
        callback = self.registry.find(task.channel, task.topic)
        if not callback:
            task.stop(StreamError(f"{task} No callback is registered for {task.channel}/{task.topic}"))
            return

        # Record the new stream in the stats pools; the size is recorded in MB
        fqcn = self.cell.my_info.fqcn
        category = stream_stats_category(fqcn, task.channel, task.topic, "stream")
        ByteReceiver.received_stream_counter_pool.increment(
            category=category,
            counter_name=COUNTER_NAME_RECEIVED,
        )
        ByteReceiver.received_stream_size_pool.record_value(
            category=category,
            value=task.size / ONE_MB,
        )

        stream_thread_pool.submit(self._callback_wrapper, task, callback)

    @staticmethod
    def _callback_wrapper(task: RxTask, callback: Callback):
        """A wrapper to catch all exceptions in the callback

        Any exception raised by the callback is logged and the task is stopped
        with a StreamError describing it.
        """
        try:
            stream = RxStream(task)
            return callback.cb(task.stream_future, stream, False, *callback.args, **callback.kwargs)
        except Exception as ex:
            msg = f"{task} callback {callback.cb} throws exception: {ex}"
            log.error(msg)
            task.stop(StreamError(msg))