Source code for nvflare.fuel.f3.streaming.blob_streamer

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import threading
from typing import Callable, Optional

from nvflare.fuel.f3.comm_config import CommConfigurator
from nvflare.fuel.f3.connection import BytesAlike
from nvflare.fuel.f3.message import Message
from nvflare.fuel.f3.streaming.byte_receiver import ByteReceiver
from nvflare.fuel.f3.streaming.byte_streamer import STREAM_CHUNK_SIZE, STREAM_TYPE_BLOB, ByteStreamer
from nvflare.fuel.f3.streaming.stream_const import EOS
from nvflare.fuel.f3.streaming.stream_types import Stream, StreamError, StreamFuture
from nvflare.fuel.f3.streaming.stream_utils import FastBuffer, callback_thread_pool, stream_thread_pool, wrap_view
from nvflare.fuel.utils.buffer_list import BufferList
from nvflare.security.logging import secure_format_traceback

log = logging.getLogger(__name__)


class BlobStream(Stream):
    """Stream that serves chunks out of an in-memory blob.

    The blob is either a single bytes-like object or a list of bytes-like
    objects; a list is read as if its items were concatenated.
    """

    def __init__(self, blob: BytesAlike, headers: Optional[dict]):
        """Wrap ``blob`` for chunked reading.

        Args:
            blob: A bytes-like object, or a list of bytes-like objects.
            headers: Headers handed to the ``Stream`` base class as-is.
        """
        super().__init__(self.buffer_len(blob), headers)

        if isinstance(blob, list):
            # Multi-part blob: wrap every part and let BufferList serve ranges across them.
            self.blob_view = [wrap_view(part) for part in blob]
            self.buffer_list = BufferList(self.blob_view)
        else:
            self.blob_view = wrap_view(blob)
            self.buffer_list = None

    def read(self, chunk_size: int) -> BytesAlike:
        """Return the next chunk of at most ``chunk_size`` bytes.

        Args:
            chunk_size: Upper bound on the number of bytes to return.

        Returns:
            The data between the current position and the new position, or
            ``EOS`` once the position has reached the stream size.
        """
        total = self.get_size()
        start = self.pos
        if start >= total:
            return EOS

        # Clamp the end of the chunk to the end of the blob.
        stop = min(start + chunk_size, total)

        if self.buffer_list:
            chunk = self.buffer_list.read(start, stop)
        else:
            chunk = self.blob_view[start:stop]

        self.pos = stop
        return chunk

    @staticmethod
    def buffer_len(buffer: BytesAlike):
        """Return ``len(buffer)``, or the summed lengths of its items when ``buffer`` is a list."""
        if isinstance(buffer, list):
            return sum(map(len, buffer))
        return len(buffer)
class BlobTask:
    """Per-stream receive state: the future, the source stream and the destination buffer."""

    def __init__(self, future: StreamFuture, stream: Stream, max_size: int = 0):
        """Validate the declared size and set up the receive buffer.

        Args:
            future: Future associated with this transfer.
            stream: Stream to read from; ``stream.get_size()`` is the declared size.
            max_size: Largest accepted declared size; the limit is only applied when positive.

        Raises:
            StreamError: If the declared size is negative or exceeds a positive ``max_size``.
        """
        self.future = future
        self.stream = stream
        self.size = stream.get_size()
        self.max_size = max_size

        if self.size < 0:
            raise StreamError(f"Declared blob size cannot be negative: {self.size}")

        if 0 < self.max_size < self.size:
            raise StreamError(f"Declared blob size {self.size} exceeds configured limit {self.max_size}")

        # A positive declared size gets a fixed buffer allocated up front;
        # otherwise data is accumulated in a FastBuffer.
        self.pre_allocated = self.size > 0
        self.buffer = wrap_view(bytearray(self.size)) if self.pre_allocated else FastBuffer()

    def __str__(self):
        """Return a short label with the stream ID and declared size."""
        return f"Blob[SID:{self.future.get_stream_id()} Size:{self.size}]"
class BlobHandler:
    """Assemble an incoming stream into a single blob and hand it to ``blob_cb``.

    ``handle_blob_cb`` is the entry point registered with the byte receiver.
    It schedules the stream reader on ``stream_thread_pool`` and the user
    callback on ``callback_thread_pool``.
    """

    def __init__(self, blob_cb: Callable):
        """Store the callback and read chunk size / max blob size from ``CommConfigurator``.

        Args:
            blob_cb: Callable invoked as ``blob_cb(future, *args, **kwargs)``.
        """
        self.blob_cb = blob_cb
        config = CommConfigurator()
        self.chunk_size = config.get_streaming_chunk_size(STREAM_CHUNK_SIZE)
        self.max_blob_size = config.get_streaming_max_blob_size()

    @staticmethod
    def _fail(stream: Stream, future: StreamFuture, error: StreamError):
        """Report ``error``: stop ``stream.task`` if the stream has one, else fail the future directly."""
        if hasattr(stream, "task"):
            stream.task.stop(error)
        else:
            future.set_exception(error)

    def _store_chunk(self, blob_task: BlobTask, buf: BytesAlike, buf_size: int, thread_id: int) -> bool:
        """Store one chunk in the task's buffer.

        Args:
            blob_task: Receive state for the current blob.
            buf: The chunk just read from the stream.
            buf_size: Number of bytes stored before this chunk.
            thread_id: Reader thread ID, used only in log messages.

        Returns:
            True if the chunk was stored; False if the transfer was failed.
        """
        length = len(buf)
        if blob_task.pre_allocated:
            return self._store_pre_allocated_chunk(blob_task, buf, buf_size, length, thread_id)
        return self._append_dynamic_chunk(blob_task, buf, buf_size, length, thread_id)

    def _store_pre_allocated_chunk(
        self, blob_task: BlobTask, buf: BytesAlike, buf_size: int, length: int, thread_id: int
    ) -> bool:
        """Copy ``buf`` into the fixed-size buffer at offset ``buf_size``.

        Fails the transfer and returns False when the chunk does not fit into
        the space left in the buffer.
        """
        remaining = len(blob_task.buffer) - buf_size
        if length > remaining:
            log.error(f"{blob_task} Buffer overrun: {thread_id=} {remaining=} {length=} {buf_size=}")
            self._fail(
                blob_task.stream,
                blob_task.future,
                StreamError(f"Buffer overrun: stream produced more data than declared size {blob_task.size}"),
            )
            return False

        blob_task.buffer[buf_size : buf_size + length] = buf
        return True

    def _append_dynamic_chunk(
        self, blob_task: BlobTask, buf: BytesAlike, buf_size: int, length: int, thread_id: int
    ) -> bool:
        """Append ``buf`` to the growable buffer, enforcing ``blob_task.max_size`` when it is positive.

        Fails the transfer and returns False when the running total would exceed the limit.
        """
        next_size = buf_size + length
        # read() already pulled this chunk, so rejection can overshoot by one chunk at most.
        if blob_task.max_size > 0 and next_size > blob_task.max_size:
            log.error(f"{blob_task} Size limit exceeded: {thread_id=} {next_size=} limit={blob_task.max_size}")
            error = StreamError(
                f"Blob received more data than configured limit {blob_task.max_size}: "
                f"received at least {next_size} bytes"
            )
            self._fail(blob_task.stream, blob_task.future, error)
            return False

        blob_task.buffer.append(buf)
        return True

    def handle_blob_cb(self, future: StreamFuture, stream: Stream, resume: bool, *args, **kwargs) -> int:
        """Start receiving a blob from ``stream``.

        Args:
            future: Future that receives the assembled blob or the failure.
            stream: Incoming stream to read.
            resume: Resume request; not supported, only logged.
            *args: Extra positional arguments forwarded to ``blob_cb``.
            **kwargs: Extra keyword arguments forwarded to ``blob_cb``.

        Returns:
            Always 0.
        """
        if resume:
            log.warning("Resume is not supported, ignored")

        try:
            blob_task = BlobTask(future, stream, self.max_blob_size)
        except StreamError as ex:
            self._fail(stream, future, ex)
            return 0
        except (MemoryError, OverflowError) as ex:
            # bytearray(n) raises OverflowError rather than MemoryError when n does not
            # fit in the platform's index type, so both mean the buffer cannot be allocated.
            error = StreamError(f"Unable to allocate buffer for declared blob size {stream.get_size()}")
            error.__cause__ = ex
            self._fail(stream, future, error)
            return 0

        stream_thread_pool.submit(self._read_stream, blob_task)
        callback_thread_pool.submit(self._run_blob_cb, future, stream, args, kwargs)

        return 0

    def _run_blob_cb(self, future: StreamFuture, stream: Stream, args: tuple, kwargs: dict):
        """Run blob_cb on the callback pool; preserve exception handling (log + task.stop) as in ByteReceiver."""
        try:
            self.blob_cb(future, *args, **kwargs)
        except Exception as ex:
            # Suppress only when blob_cb is surfacing an already-recorded stream
            # failure (for example by calling future.result()). If blob_cb fails
            # after the future completed successfully, we still need to stop the
            # task so the sender receives the callback error.
            with future.lock:
                already_failed = future.error is not None

            if already_failed:
                kind = "StreamError" if isinstance(ex, StreamError) else "Exception"
                log.debug(f"{kind} from blob_cb suppressed; future already failed: {ex}")
            else:
                log.error(f"blob_cb threw: {ex}\n{secure_format_traceback()}")
                if hasattr(stream, "task"):
                    stream.task.stop(StreamError(f"blob_cb threw {type(ex).__name__}: {ex}"))

    def _read_stream(self, blob_task: BlobTask):
        """Read the stream to its end and complete the future.

        On success the future's result is the pre-allocated buffer, or the
        bytes of the dynamic buffer. Any exception raised while reading or
        storing is logged and set on the future.
        """
        try:
            # It's most efficient to read the whole chunk
            size = self.chunk_size
            thread_id = threading.get_native_id()
            buf_size = 0

            while True:
                buf = blob_task.stream.read(size)
                if not buf:
                    break

                try:
                    if not self._store_chunk(blob_task, buf, buf_size, thread_id):
                        # The failure has already been reported by the store helper.
                        return
                except Exception as ex:
                    length = len(buf)
                    log.error(
                        f"{blob_task} memoryview error: {ex} Debug info: "
                        f"{thread_id=} {length=} {buf_size=} {type(buf)=}"
                    )
                    # Bare raise keeps the original traceback for the outer handler.
                    raise

                buf_size += len(buf)

            # A declared size of 0 is not checked against the received byte count.
            if blob_task.size and blob_task.size != buf_size:
                blob_task.future.set_exception(
                    StreamError(f"Size mismatch: declared {blob_task.size} but received {buf_size} bytes")
                )
                return

            if blob_task.pre_allocated:
                result = blob_task.buffer
            else:
                result = blob_task.buffer.to_bytes()

            blob_task.future.set_result(result)
        except Exception as ex:
            log.error(f"Stream {blob_task} Read error: {ex}")
            log.error(secure_format_traceback())
            blob_task.future.set_exception(ex)
class BlobStreamer:
    """Send and receive blobs on top of a ``ByteStreamer`` / ``ByteReceiver`` pair."""

    def __init__(self, byte_streamer: ByteStreamer, byte_receiver: ByteReceiver):
        """Keep references to the streamer used for sending and the receiver used for callbacks."""
        self.byte_streamer = byte_streamer
        self.byte_receiver = byte_receiver

    def send(
        self, channel: str, topic: str, target: str, message: Message, secure: bool, optional: bool
    ) -> StreamFuture:
        """Stream ``message.payload`` to ``target`` as a blob.

        A ``None`` payload is replaced on the message with empty bytes before sending.

        Args:
            channel: Channel passed through to the byte streamer.
            topic: Topic passed through to the byte streamer.
            target: Destination passed through to the byte streamer.
            message: Message whose payload and headers are sent.
            secure: Passed through to the byte streamer.
            optional: Passed through to the byte streamer.

        Returns:
            The future returned by ``ByteStreamer.send``.

        Raises:
            StreamError: If the payload is not bytes, bytearray, memoryview or list.
        """
        if message.payload is None:
            message.payload = bytes(0)

        accepted_types = (bytes, bytearray, memoryview, list)
        payload = message.payload
        if not isinstance(payload, accepted_types):
            raise StreamError(f"BLOB is invalid type: {type(message.payload)}")

        headers = message.headers
        blob_stream = BlobStream(payload, headers)
        return self.byte_streamer.send(
            channel, topic, target, headers, blob_stream, STREAM_TYPE_BLOB, secure, optional
        )

    def register_blob_callback(self, channel, topic, blob_cb: Callable, *args, **kwargs):
        """Register ``blob_cb`` for blobs arriving on ``channel``/``topic``.

        A new ``BlobHandler`` wraps the callback; extra ``args``/``kwargs``
        are passed on to ``ByteReceiver.register_callback``.
        """
        blob_handler = BlobHandler(blob_cb)
        on_stream = blob_handler.handle_blob_cb
        self.byte_receiver.register_callback(channel, topic, on_stream, *args, **kwargs)