Source code for nvflare.fuel.f3.streaming.byte_receiver

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import threading
from collections import deque
from typing import Callable, Deque, Dict, Optional, Tuple

from nvflare.fuel.f3.cellnet.core_cell import CoreCell
from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey
from nvflare.fuel.f3.cellnet.registry import Callback, Registry
from nvflare.fuel.f3.comm_config import CommConfigurator
from nvflare.fuel.f3.connection import BytesAlike
from nvflare.fuel.f3.message import Message
from nvflare.fuel.f3.stats_pool import StatsPoolManager
from nvflare.fuel.f3.streaming.stream_const import (
    EOS,
    STREAM_ACK_TOPIC,
    STREAM_CHANNEL,
    STREAM_DATA_TOPIC,
    StreamDataType,
    StreamHeaderKey,
)
from nvflare.fuel.f3.streaming.stream_types import Stream, StreamError, StreamFuture
from nvflare.fuel.f3.streaming.stream_utils import ONE_MB, stream_stats_category, stream_thread_pool

log = logging.getLogger(__name__)

MAX_OUT_SEQ_CHUNKS = 16
# 1/4 of the window size
ACK_INTERVAL = 1024 * 1024 * 4
READ_TIMEOUT = 300
COUNTER_NAME_RECEIVED = "received"

# Read result status
RESULT_DATA = 0
RESULT_NO_DATA = 1
RESULT_EOS = 2


[docs] class RxTask: """Receiving task for ByteStream""" rx_task_map = {} map_lock = threading.Lock() def __init__(self, sid: int, origin: str, cell: CoreCell): self.sid = sid self.origin = origin self.cell = cell self.channel = None self.topic = None self.headers = None self.size = 0 # The reassembled chunks in a double-ended queue self.chunks: Deque[Tuple[bool, BytesAlike]] = deque() self.chunk_offset = 0 # Start of the remaining data for partially read left-most chunk # Out-of-sequence chunks to be assembled self.out_seq_chunks: Dict[int, Tuple[bool, BytesAlike]] = {} self.stream_future = None self.next_seq = 0 self.offset = 0 self.offset_ack = 0 self.waiter = threading.Event() self.lock = threading.Lock() self.eos = False self.last_chunk_received = False self.timeout = CommConfigurator().get_streaming_read_timeout(READ_TIMEOUT) self.ack_interval = CommConfigurator().get_streaming_ack_interval(ACK_INTERVAL) self.max_out_seq = CommConfigurator().get_streaming_max_out_seq_chunks(MAX_OUT_SEQ_CHUNKS) def __str__(self): return f"Rx[SID:{self.sid} from {self.origin} for {self.channel}/{self.topic} Size: {self.size}]"
[docs] @classmethod def find_or_create_task(cls, message: Message, cell: CoreCell) -> Optional["RxTask"]: sid = message.get_header(StreamHeaderKey.STREAM_ID) origin = message.get_header(MessageHeaderKey.ORIGIN) error = message.get_header(StreamHeaderKey.ERROR_MSG, None) with cls.map_lock: task = cls.rx_task_map.get(sid, None) if not task: if error: log.warning(f"Received error for non-existing stream: SID {sid} from {origin}") return None task = RxTask(sid, origin, cell) cls.rx_task_map[sid] = task else: if error: task.stop(StreamError(f"{task} Received error from {origin}: {error}"), notify=False) return None return task
[docs] def read(self, size: int) -> BytesAlike: count = 0 while True: result_code, result = self._try_to_read(size) if result_code == RESULT_EOS: return EOS elif result_code == RESULT_DATA: return result # result_code == RESULT_NO_DATA Block until chunks are received if count > 0: log.warning(f"{self} Read block is unblocked multiple times: {count}") if not self.waiter.wait(self.timeout): error = StreamError(f"{self} read timed out after {self.timeout} seconds") self.stop(error) raise error count += 1
[docs] def process_chunk(self, message: Message) -> bool: """Returns True if a new stream is created""" new_stream = False with self.lock: seq = message.get_header(StreamHeaderKey.SEQUENCE) if seq == 0: if self.stream_future: log.warning(f"{self} Received duplicate chunk 0, ignored") return new_stream self._handle_new_stream(message) new_stream = True self._handle_incoming_data(seq, message) return new_stream
def _handle_new_stream(self, message: Message): self.channel = message.get_header(StreamHeaderKey.CHANNEL) self.topic = message.get_header(StreamHeaderKey.TOPIC) self.headers = message.headers self.size = message.get_header(StreamHeaderKey.SIZE, 0) self.stream_future = StreamFuture(self.sid, self.headers) self.stream_future.set_size(self.size) def _handle_incoming_data(self, seq: int, message: Message): data_type = message.get_header(StreamHeaderKey.DATA_TYPE) last_chunk = data_type == StreamDataType.FINAL if last_chunk: self.last_chunk_received = True if seq < self.next_seq: log.warning(f"{self} Duplicate chunk ignored {seq=}") return if seq == self.next_seq: self._append((last_chunk, message.payload)) # Try to reassemble out-of-seq chunks while self.next_seq in self.out_seq_chunks: chunk = self.out_seq_chunks.pop(self.next_seq) self._append(chunk) else: # Save out-of-seq chunks if len(self.out_seq_chunks) >= self.max_out_seq: self.stop(StreamError(f"{self} Too many out-of-sequence chunks: {len(self.out_seq_chunks)}")) return else: if seq not in self.out_seq_chunks: self.out_seq_chunks[seq] = last_chunk, message.payload else: log.warning(f"{self} Duplicate out-of-seq chunk ignored {seq=}") # If all chunks are lined up and last chunk received, the task can be deleted if not self.out_seq_chunks and self.chunks: last_chunk, _ = self.chunks[-1] if last_chunk: self.stop()
[docs] def stop(self, error: StreamError = None, notify=True): with RxTask.map_lock: RxTask.rx_task_map.pop(self.sid, None) if not error: return if self.headers: optional = self.headers.get(StreamHeaderKey.OPTIONAL, False) else: optional = False msg = f"Stream error: {error}" if optional: log.debug(msg) else: log.error(msg) self.stream_future.set_exception(error) if notify: message = Message() message.add_headers( { StreamHeaderKey.STREAM_ID: self.sid, StreamHeaderKey.DATA_TYPE: StreamDataType.ERROR, StreamHeaderKey.ERROR_MSG: str(error), } ) self.cell.fire_and_forget(STREAM_CHANNEL, STREAM_ACK_TOPIC, self.origin, message)
def _try_to_read(self, size: int) -> Tuple[int, Optional[BytesAlike]]: with self.lock: if self.eos: return RESULT_EOS, None if not self.chunks: self.waiter.clear() return RESULT_NO_DATA, None # Get the left most chunk last_chunk, buf = self.chunks[0] if buf is None: buf = bytes(0) end_offset = self.chunk_offset + size if 0 < end_offset < len(buf): # Partial read result = buf[self.chunk_offset : end_offset] self.chunk_offset = end_offset else: # Whole chunk is consumed if self.chunk_offset: result = buf[self.chunk_offset :] else: result = buf self.chunk_offset = 0 self.chunks.popleft() if last_chunk: self.eos = True self.offset += len(result) if not self.last_chunk_received and (self.offset - self.offset_ack > self.ack_interval): # Send ACK message = Message() message.add_headers( { StreamHeaderKey.STREAM_ID: self.sid, StreamHeaderKey.DATA_TYPE: StreamDataType.ACK, StreamHeaderKey.OFFSET: self.offset, } ) self.cell.fire_and_forget(STREAM_CHANNEL, STREAM_ACK_TOPIC, self.origin, message) self.offset_ack = self.offset self.stream_future.set_progress(self.offset) return RESULT_DATA, result def _append(self, buf: Tuple[bool, BytesAlike]): if self.eos: log.error(f"{self} Data after EOS is ignored") return self.chunks.append(buf) self.next_seq += 1 # Wake up blocking read() if not self.waiter.is_set(): self.waiter.set()
[docs] class RxStream(Stream): """A stream that's used to read streams from the streaming task""" def __init__(self, task: RxTask): super().__init__(task.size, task.headers) self.task = task
[docs] def read(self, size: int) -> bytes: if self.closed: raise StreamError("Read from closed stream") return self.task.read(size)
[docs] def close(self): if not self.task.stream_future.done(): self.task.stream_future.set_result(self.task.offset) self.closed = True
[docs] class ByteReceiver: def __init__(self, cell: CoreCell): self.cell = cell self.cell.register_request_cb(channel=STREAM_CHANNEL, topic=STREAM_DATA_TOPIC, cb=self._data_handler) self.registry = Registry() self.received_stream_counter_pool = StatsPoolManager.add_counter_pool( name="Received_Stream_Counters", description="Counters of received streams", counter_names=[COUNTER_NAME_RECEIVED], scope=self.cell.my_info.fqcn, ) self.received_stream_size_pool = StatsPoolManager.add_msg_size_pool( "Received_Stream_Sizes", "Sizes of streams received (MBs)", scope=self.cell.my_info.fqcn )
[docs] def register_callback(self, channel: str, topic: str, stream_cb: Callable, *args, **kwargs): if not callable(stream_cb): raise StreamError(f"specified stream_cb {type(stream_cb)} is not callable") self.registry.set(channel, topic, Callback(stream_cb, args, kwargs))
def _data_handler(self, message: Message): task = RxTask.find_or_create_task(message, self.cell) if not task: return new_stream = task.process_chunk(message) if new_stream: # Invoke callback callback = self.registry.find(task.channel, task.topic) if not callback: task.stop(StreamError(f"{task} No callback is registered for {task.channel}/{task.topic}")) return self.received_stream_counter_pool.increment( category=stream_stats_category(task.channel, task.topic, "stream"), counter_name=COUNTER_NAME_RECEIVED, ) self.received_stream_size_pool.record_value( category=stream_stats_category(task.channel, task.topic, "stream"), value=task.size / ONE_MB ) stream_thread_pool.submit(self._callback_wrapper, task, callback) @staticmethod def _callback_wrapper(task: RxTask, callback: Callback): """A wrapper to catch all exceptions in the callback""" try: stream = RxStream(task) return callback.cb(task.stream_future, stream, False, *callback.args, **callback.kwargs) except Exception as ex: msg = f"{task} callback {callback.cb} throws exception: {ex}" log.error(msg) task.stop(StreamError(msg))