Source code for nvflare.fuel.f3.streaming.byte_streamer

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import threading
from typing import Optional

from nvflare.fuel.f3.cellnet.core_cell import CoreCell
from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey
from nvflare.fuel.f3.comm_config import CommConfigurator
from nvflare.fuel.f3.message import Message
from nvflare.fuel.f3.stats_pool import StatsPoolManager
from nvflare.fuel.f3.streaming.stream_const import (
    STREAM_ACK_TOPIC,
    STREAM_CHANNEL,
    STREAM_DATA_TOPIC,
    StreamDataType,
    StreamHeaderKey,
)
from nvflare.fuel.f3.streaming.stream_types import Stream, StreamError, StreamFuture
from nvflare.fuel.f3.streaming.stream_utils import (
    ONE_MB,
    gen_stream_id,
    stream_stats_category,
    stream_thread_pool,
    wrap_view,
)

# Chunk size in bytes (1 MiB) passed to CommConfigurator().get_streaming_chunk_size().
STREAM_CHUNK_SIZE = 1 << 20
# Flow-control window in bytes (16 chunks) passed to
# CommConfigurator().get_streaming_window_size(); compared against the number
# of sent-but-unacknowledged bytes of a stream.
STREAM_WINDOW_SIZE = STREAM_CHUNK_SIZE * 16
# ACK wait in seconds passed to CommConfigurator().get_streaming_ack_wait();
# used as the timeout while the sender waits for ACKs.
STREAM_ACK_WAIT = 60

# Stream type labels (the ``stream_type`` argument of ByteStreamer.send),
# used to build the statistics category of a sent stream.
STREAM_TYPE_BYTE = "byte"
STREAM_TYPE_BLOB = "blob"
STREAM_TYPE_FILE = "file"

# Name of the counter incremented once per sent stream.
COUNTER_NAME_SENT = "sent"

log = logging.getLogger(__name__)


class TxTask:
    """State of one outgoing stream transmission.

    ``ByteStreamer.send`` creates one of these per call. It holds the
    routing information, the staging buffer that stream reads are packed
    into, the futures, and the sequence/offset counters used for flow control.
    """

    def __init__(
        self, channel: str, topic: str, target: str, headers: dict, stream: Stream, secure: bool, optional: bool
    ):
        # Identity and routing of the stream
        self.sid = gen_stream_id()
        self.channel = channel
        self.topic = topic
        self.target = target
        self.headers = headers
        self.secure = secure
        self.optional = optional

        # Data source plus a chunk-sized staging buffer for partial reads
        self.stream = stream
        self.buffer = bytearray(ByteStreamer.get_chunk_size())
        self.buffer_size = 0
        # Optimization: a full chunk read in one piece is sent as-is,
        # without copying it into ``buffer``.
        self.direct_buf: Optional[bytes] = None

        # Future returned to the caller, and the thread-pool job future
        self.stream_future = None
        self.task_future = None

        # Message sequence number, bytes sent, bytes acknowledged, and the
        # event the sender waits on while the flow-control window is full
        self.seq = 0
        self.offset = 0
        self.offset_ack = 0
        self.ack_waiter = threading.Event()

    def __str__(self):
        route = f"{self.channel}/{self.topic}"
        return f"Tx[SID:{self.sid} to {self.target} for {route}]"
class ByteStreamer:
    """Send a :class:`Stream` to a target cell as a sequence of chunk messages.

    Each :meth:`send` call creates a :class:`TxTask` and transmits it on
    ``stream_thread_pool``. The sender pauses reading while the number of
    unacknowledged bytes exceeds the configured window. ACKs arrive on
    ``STREAM_CHANNEL``/``STREAM_ACK_TOPIC``.
    """

    def __init__(self, cell: CoreCell):
        self.cell = cell
        self.cell.register_request_cb(channel=STREAM_CHANNEL, topic=STREAM_ACK_TOPIC, cb=self._ack_handler)
        # Active tasks keyed by stream ID, guarded by map_lock. A task is
        # removed exactly once, by _stop_task, which is what makes stopping
        # idempotent and lets the send loop detect that it was stopped.
        self.tx_task_map = {}
        self.map_lock = threading.Lock()

        self.sent_stream_counter_pool = StatsPoolManager.add_counter_pool(
            name="Sent_Stream_Counters",
            description="Counters of sent streams",
            counter_names=[COUNTER_NAME_SENT],
            scope=self.cell.my_info.fqcn,
        )
        self.sent_stream_size_pool = StatsPoolManager.add_msg_size_pool(
            "Sent_Stream_Sizes", "Sizes of streams sent (MBs)", scope=self.cell.my_info.fqcn
        )

    @staticmethod
    def get_chunk_size():
        """Return the streaming chunk size (bytes) from CommConfigurator.

        ``STREAM_CHUNK_SIZE`` is passed to the configurator, presumably as the
        fallback when no size is configured.
        """
        return CommConfigurator().get_streaming_chunk_size(STREAM_CHUNK_SIZE)

    def send(
        self,
        channel: str,
        topic: str,
        target: str,
        headers: dict,
        stream: Stream,
        stream_type=STREAM_TYPE_BYTE,
        secure=False,
        optional=False,
    ) -> StreamFuture:
        """Start streaming ``stream`` to ``target`` in the background.

        Args:
            channel: Application channel, sent in each chunk's headers.
            topic: Application topic, sent in each chunk's headers.
            target: FQCN of the receiving cell.
            headers: Extra headers added to every chunk message (may be empty).
            stream: Source of the bytes; ``get_size()`` is reported to the future.
            stream_type: Label used for the statistics category.
            secure: Passed through to ``cell.fire_and_forget``.
            optional: Passed through to ``cell.fire_and_forget`` for data chunks.

        Returns:
            A StreamFuture that receives progress updates. On success its result
            is the number of bytes sent; on failure it gets a StreamError.
        """
        tx_task = TxTask(channel, topic, target, headers, stream, secure, optional)
        with self.map_lock:
            self.tx_task_map[tx_task.sid] = tx_task

        future = StreamFuture(tx_task.sid)
        future.set_size(stream.get_size())
        tx_task.stream_future = future
        tx_task.task_future = stream_thread_pool.submit(self._transmit_task, tx_task)

        category = stream_stats_category(channel, topic, stream_type)
        self.sent_stream_counter_pool.increment(category=category, counter_name=COUNTER_NAME_SENT)
        self.sent_stream_size_pool.record_value(category=category, value=stream.get_size() / ONE_MB)

        return future

    def _is_active(self, task: TxTask) -> bool:
        """Return True while ``task`` has not been stopped."""
        with self.map_lock:
            return task.sid in self.tx_task_map

    def _transmit_task(self, task: TxTask):
        """Thread-pool entry point: stream ``task`` and always settle its future.

        If anything raises (e.g. ``stream.read`` or the invalid-size check),
        the task is stopped with a StreamError so ``task.stream_future`` is
        completed and the task leaves ``tx_task_map``. The exception is then
        re-raised so ``task.task_future`` still records it, as before.
        """
        try:
            self._send_chunks(task)
        except Exception as ex:
            self._stop_task(task, StreamError(f"{task} streaming failed: {ex}"))
            raise

    def _send_chunks(self, task: TxTask):
        """Read ``task.stream`` and transmit it chunk by chunk under flow control."""
        chunk_size = self.get_chunk_size()
        while True:
            # Stop promptly if the task was stopped elsewhere, e.g. the
            # receiver reported an error through an ACK.
            if not self._is_active(task):
                return

            buf = task.stream.read(chunk_size)
            if not buf:
                # End of Stream. Report success only if the FINAL message
                # went out; on failure _transmit already stopped the task.
                if self._transmit(task, final=True):
                    self._stop_task(task)
                return

            if not self._wait_for_window(task):
                return

            size = len(buf)
            if size > chunk_size:
                raise StreamError(f"Stream returns invalid size: {size} for {task}")
            if size + task.buffer_size > chunk_size:
                # Flush first so the new data fits. If the send failed the
                # task is stopped and buffer_size was not reset, so bail out.
                if not self._transmit(task):
                    return

            if size == chunk_size:
                # Full chunk read in one piece: send it later without copying
                task.direct_buf = buf
            else:
                task.buffer[task.buffer_size : task.buffer_size + size] = buf
            task.buffer_size += size

    def _wait_for_window(self, task: TxTask) -> bool:
        """Block while the unacknowledged bytes of ``task`` exceed the window.

        Returns:
            True if sending may continue. False if the task was stopped while
            waiting: either the ACK wait timed out (the task is stopped here)
            or it was stopped by another thread.
        """
        window = task.offset - task.offset_ack
        # It may take several ACKs to clear up the window
        window_size = CommConfigurator().get_streaming_window_size(STREAM_WINDOW_SIZE)
        while window > window_size:
            log.debug(f"{task} window size {window} exceeds limit: {window_size}")
            task.ack_waiter.clear()
            # Checked after clear(): _stop_task removes the task from the map
            # before setting ack_waiter, so a concurrent stop is seen either
            # here or as a wake-up of the wait below.
            if not self._is_active(task):
                return False

            ack_wait = CommConfigurator().get_streaming_ack_wait(STREAM_ACK_WAIT)
            if not task.ack_waiter.wait(timeout=ack_wait):
                self._stop_task(task, StreamError(f"{task} ACK timeouts after {ack_wait} seconds"))
                return False

            window = task.offset - task.offset_ack
        return True

    def _transmit(self, task: TxTask, final=False) -> bool:
        """Send the currently buffered bytes of ``task`` as one message.

        Args:
            task: The task whose buffered data is sent.
            final: Mark the message as the FINAL one of the stream.

        Returns:
            True on success. On a send error the task is stopped with a
            StreamError and False is returned, so the caller stops streaming.
        """
        if task.buffer_size == 0:
            payload = bytes(0)
        elif task.buffer_size == self.get_chunk_size():
            # Full chunk: prefer the stream's own buffer to avoid a copy
            payload = task.direct_buf if task.direct_buf is not None else task.buffer
        else:
            payload = wrap_view(task.buffer)[0 : task.buffer_size]

        message = Message(None, payload)
        if task.headers:
            message.add_headers(task.headers)

        message.add_headers(
            {
                StreamHeaderKey.CHANNEL: task.channel,
                StreamHeaderKey.TOPIC: task.topic,
                StreamHeaderKey.SIZE: task.stream.get_size(),
                StreamHeaderKey.STREAM_ID: task.sid,
                StreamHeaderKey.DATA_TYPE: StreamDataType.FINAL if final else StreamDataType.CHUNK,
                StreamHeaderKey.SEQUENCE: task.seq,
                StreamHeaderKey.OFFSET: task.offset,
            }
        )

        errors = self.cell.fire_and_forget(
            STREAM_CHANNEL, STREAM_DATA_TOPIC, task.target, message, secure=task.secure, optional=task.optional
        )
        error = errors.get(task.target)
        if error:
            msg = f"Message sending error to target {task.target}: {error}"
            log.debug(msg)
            self._stop_task(task, StreamError(msg))
            return False

        # Update state
        task.seq += 1
        task.offset += task.buffer_size
        task.buffer_size = 0
        task.direct_buf = None

        # Update future
        task.stream_future.set_progress(task.offset)
        return True

    def _stop_task(self, task: TxTask, error: StreamError = None, notify=True):
        """Finish ``task`` exactly once and settle its future.

        Later calls for the same task are no-ops, so a task stopped by a
        send error, ACK timeout or receiver error is never completed twice.

        Args:
            task: Task to stop.
            error: If given, the future gets this exception. Otherwise the
                future's result is the number of bytes streamed.
            notify: When stopping with an error, also send an ERROR message
                to the receiver.
        """
        with self.map_lock:
            if self.tx_task_map.pop(task.sid, None) is None:
                # Already stopped; its future has been settled
                return

        # Wake a sender blocked on flow control so it notices the stop
        task.ack_waiter.set()

        if error:
            log.debug(f"Stream error: {error}")
            if task.stream_future:
                task.stream_future.set_exception(error)

            if notify:
                message = Message(None, None)
                if task.headers:
                    message.add_headers(task.headers)

                message.add_headers(
                    {
                        StreamHeaderKey.STREAM_ID: task.sid,
                        StreamHeaderKey.DATA_TYPE: StreamDataType.ERROR,
                        StreamHeaderKey.OFFSET: task.offset,
                        StreamHeaderKey.ERROR_MSG: str(error),
                    }
                )
                self.cell.fire_and_forget(
                    STREAM_CHANNEL, STREAM_DATA_TOPIC, task.target, message, secure=task.secure, optional=True
                )
        else:
            # Result is the number of bytes streamed
            if task.stream_future:
                task.stream_future.set_result(task.offset)

    def _ack_handler(self, message: Message):
        """Handle an ACK (or error report) from the receiver of a stream.

        An ERROR_MSG header stops the task without notifying the receiver.
        Otherwise the acknowledged offset is advanced and the waiting sender
        is woken up.
        """
        origin = message.get_header(MessageHeaderKey.ORIGIN)
        sid = message.get_header(StreamHeaderKey.STREAM_ID)
        offset = message.get_header(StreamHeaderKey.OFFSET, None)

        with self.map_lock:
            task = self.tx_task_map.get(sid, None)

        if not task:
            # Last few ACKs always arrive late so this is normal
            log.debug(f"ACK for stream {sid} received late from {origin} with offset {offset}")
            return

        error = message.get_header(StreamHeaderKey.ERROR_MSG, None)
        if error:
            self._stop_task(task, StreamError(f"Received error from {origin}: {error}"), notify=False)
            return

        # An ACK without an offset must not break the comparison below
        if offset is not None and offset > task.offset_ack:
            task.offset_ack = offset

        if not task.ack_waiter.is_set():
            task.ack_waiter.set()