# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import threading
from typing import Callable, Optional
from nvflare.fuel.f3.comm_config import CommConfigurator
from nvflare.fuel.f3.connection import BytesAlike
from nvflare.fuel.f3.message import Message
from nvflare.fuel.f3.streaming.byte_receiver import ByteReceiver
from nvflare.fuel.f3.streaming.byte_streamer import STREAM_CHUNK_SIZE, STREAM_TYPE_BLOB, ByteStreamer
from nvflare.fuel.f3.streaming.stream_const import EOS
from nvflare.fuel.f3.streaming.stream_types import Stream, StreamError, StreamFuture
from nvflare.fuel.f3.streaming.stream_utils import FastBuffer, callback_thread_pool, stream_thread_pool, wrap_view
from nvflare.fuel.utils.buffer_list import BufferList
from nvflare.security.logging import secure_format_traceback
# Logger for this module, named after the module itself.
log: logging.Logger = logging.getLogger(__name__)
class BlobStream(Stream):
    """Stream that serves an in-memory blob in slices.

    The blob is either a single bytes-like object or a list of bytes-like
    objects. A list is presented as one logical byte sequence through a
    ``BufferList`` built over views of its parts.
    """

    def __init__(self, blob: BytesAlike, headers: Optional[dict]):
        super().__init__(self.buffer_len(blob), headers)
        if isinstance(blob, list):
            # Wrap each fragment so BufferList can slice across fragment boundaries.
            self.blob_view = [wrap_view(part) for part in blob]
            self.buffer_list = BufferList(self.blob_view)
        else:
            self.blob_view = wrap_view(blob)
            self.buffer_list = None

    def read(self, chunk_size: int) -> BytesAlike:
        """Return the next slice of at most ``chunk_size`` bytes.

        Returns ``EOS`` once the read position has reached the stream size.
        """
        total = self.get_size()
        start = self.pos
        if start >= total:
            return EOS
        # Clamp the slice end so the final read never runs past the blob.
        end = min(start + chunk_size, total)
        chunk = self.buffer_list.read(start, end) if self.buffer_list else self.blob_view[start:end]
        self.pos = end
        return chunk

    @staticmethod
    def buffer_len(buffer: BytesAlike):
        """Return the byte length of ``buffer``, summing the parts when it is a list."""
        if isinstance(buffer, list):
            return sum(map(len, buffer))
        return len(buffer)
class BlobTask:
    """Receive-side state for one incoming blob.

    Holds the future to complete, the source stream and the destination
    buffer. The declared stream size is validated here: it must be
    non-negative and, when ``max_size`` is positive, no larger than
    ``max_size``.
    """

    def __init__(self, future: StreamFuture, stream: Stream, max_size: int = 0):
        self.future = future
        self.stream = stream
        self.size = stream.get_size()
        self.max_size = max_size

        declared = self.size
        if declared < 0:
            raise StreamError(f"Declared blob size cannot be negative: {self.size}")
        size_limited = self.max_size > 0
        if size_limited and declared > self.max_size:
            raise StreamError(f"Declared blob size {self.size} exceeds configured limit {self.max_size}")

        # A positive declared size lets the receiver allocate once and fill in
        # place; a size of 0 falls back to an append-only FastBuffer.
        self.pre_allocated = declared > 0
        self.buffer = wrap_view(bytearray(declared)) if self.pre_allocated else FastBuffer()

    def __str__(self):
        return f"Blob[SID:{self.future.get_stream_id()} Size:{self.size}]"
class BlobHandler:
    """Receive-side handler that assembles a streamed blob and invokes a user callback.

    Stream reading is submitted to ``stream_thread_pool`` and the user
    callback to ``callback_thread_pool``. The callback may block on the
    future that the reader completes (e.g. via ``future.result()``).
    """

    def __init__(self, blob_cb: Callable):
        self.blob_cb = blob_cb
        config = CommConfigurator()
        self.chunk_size = config.get_streaming_chunk_size(STREAM_CHUNK_SIZE)
        # Forwarded to BlobTask; a non-positive value disables the size limit there.
        self.max_blob_size = config.get_streaming_max_blob_size()

    @staticmethod
    def _fail(stream: Stream, future: StreamFuture, error: StreamError):
        """Report ``error`` for ``stream``.

        If the stream carries a ``task``, the task is stopped with the error,
        which is how the sender is told about it. Otherwise the error is set
        directly on ``future``.
        """
        if hasattr(stream, "task"):
            stream.task.stop(error)
        else:
            future.set_exception(error)

    def _store_chunk(self, blob_task: BlobTask, buf: BytesAlike, buf_size: int, thread_id: int) -> bool:
        """Store ``buf`` at offset ``buf_size``; return False if the blob was failed."""
        length = len(buf)
        if blob_task.pre_allocated:
            return self._store_pre_allocated_chunk(blob_task, buf, buf_size, length, thread_id)
        return self._append_dynamic_chunk(blob_task, buf, buf_size, length, thread_id)

    def _store_pre_allocated_chunk(
        self, blob_task: BlobTask, buf: BytesAlike, buf_size: int, length: int, thread_id: int
    ) -> bool:
        """Copy ``buf`` into the fixed-size buffer, failing on overrun of the declared size."""
        remaining = len(blob_task.buffer) - buf_size
        if length > remaining:
            log.error(f"{blob_task} Buffer overrun: {thread_id=} {remaining=} {length=} {buf_size=}")
            self._fail(
                blob_task.stream,
                blob_task.future,
                StreamError(f"Buffer overrun: stream produced more data than declared size {blob_task.size}"),
            )
            return False
        blob_task.buffer[buf_size : buf_size + length] = buf
        return True

    def _append_dynamic_chunk(
        self, blob_task: BlobTask, buf: BytesAlike, buf_size: int, length: int, thread_id: int
    ) -> bool:
        """Append ``buf`` to the growable buffer, failing if the configured limit is exceeded."""
        next_size = buf_size + length
        # read() already pulled this chunk, so rejection can overshoot by one chunk at most.
        if blob_task.max_size > 0 and next_size > blob_task.max_size:
            log.error(f"{blob_task} Size limit exceeded: {thread_id=} {next_size=} limit={blob_task.max_size}")
            error = StreamError(
                f"Blob received more data than configured limit {blob_task.max_size}: "
                f"received at least {next_size} bytes"
            )
            self._fail(blob_task.stream, blob_task.future, error)
            return False
        blob_task.buffer.append(buf)
        return True

    def handle_blob_cb(self, future: StreamFuture, stream: Stream, resume: bool, *args, **kwargs) -> int:
        """Entry point for an incoming blob stream.

        Builds a BlobTask and schedules the reader and the user callback on
        separate pools. If the declared size is rejected or the buffer cannot
        be allocated, the stream is failed and no work is scheduled.
        Always returns 0.
        """
        if resume:
            log.warning("Resume is not supported, ignored")
        try:
            blob_task = BlobTask(future, stream, self.max_blob_size)
        except StreamError as ex:
            self._fail(stream, future, ex)
            return 0
        except MemoryError as ex:
            error = StreamError(f"Unable to allocate buffer for declared blob size {stream.get_size()}")
            error.__cause__ = ex
            self._fail(stream, future, error)
            return 0
        stream_thread_pool.submit(self._read_stream, blob_task)
        callback_thread_pool.submit(self._run_blob_cb, future, stream, args, kwargs)
        return 0

    def _run_blob_cb(self, future: StreamFuture, stream: Stream, args: tuple, kwargs: dict):
        """Run blob_cb on the callback pool; preserve exception handling (log + task.stop) as in ByteReceiver."""
        try:
            self.blob_cb(future, *args, **kwargs)
        except Exception as ex:
            # Suppress only when blob_cb is surfacing an already-recorded stream
            # failure (for example by calling future.result()). If blob_cb fails
            # after the future completed successfully, we still need to stop the
            # task so the sender receives the callback error.
            with future.lock:
                already_failed = future.error is not None
            if already_failed:
                kind = "StreamError" if isinstance(ex, StreamError) else "Exception"
                log.debug(f"{kind} from blob_cb suppressed; future already failed: {ex}")
            else:
                log.error(f"blob_cb threw: {ex}\n{secure_format_traceback()}")
                if hasattr(stream, "task"):
                    stream.task.stop(StreamError(f"blob_cb threw {type(ex).__name__}: {ex}"))

    def _read_stream(self, blob_task: BlobTask):
        """Drain ``blob_task.stream`` into its buffer and complete the future.

        Reads in ``self.chunk_size`` pieces until an empty read. A size-limit
        violation is reported through ``_fail`` and ends the read early. A
        pre-allocated blob that ends short of its declared size fails the
        future with a size-mismatch StreamError. Any unexpected exception is
        logged and set on the future.
        """
        try:
            # It's most efficient to read the whole chunk
            size = self.chunk_size
            thread_id = threading.get_native_id()
            buf_size = 0
            while True:
                buf = blob_task.stream.read(size)
                if not buf:
                    break
                try:
                    if not self._store_chunk(blob_task, buf, buf_size, thread_id):
                        return
                except Exception as ex:
                    length = len(buf)
                    log.error(
                        f"{blob_task} memoryview error: {ex} Debug info: "
                        f"{thread_id=} {length=} {buf_size=} {type(buf)=}"
                    )
                    # Bare raise keeps the original traceback intact for the outer handler.
                    raise
                buf_size += len(buf)
            if blob_task.size and blob_task.size != buf_size:
                blob_task.future.set_exception(
                    StreamError(f"Size mismatch: declared {blob_task.size} but received {buf_size} bytes")
                )
                return
            if blob_task.pre_allocated:
                result = blob_task.buffer
            else:
                result = blob_task.buffer.to_bytes()
            blob_task.future.set_result(result)
        except Exception as ex:
            log.error(f"Stream {blob_task} Read error: {ex}")
            log.error(secure_format_traceback())
            blob_task.future.set_exception(ex)
class BlobStreamer:
    """Send blobs through a ByteStreamer and register blob receivers on a ByteReceiver."""

    def __init__(self, byte_streamer: ByteStreamer, byte_receiver: ByteReceiver):
        self.byte_streamer = byte_streamer
        self.byte_receiver = byte_receiver

    def send(
        self, channel: str, topic: str, target: str, message: Message, secure: bool, optional: bool
    ) -> StreamFuture:
        """Stream ``message.payload`` to ``target`` as a blob.

        A ``None`` payload is replaced in place with empty bytes. The payload
        must be bytes, bytearray, memoryview, or a list of those.

        Raises:
            StreamError: if the payload, or any element of a list payload,
                is not a supported bytes-like type.
        """
        if message.payload is None:
            message.payload = bytes(0)
        payload = message.payload
        if not isinstance(payload, (bytes, bytearray, memoryview, list)):
            raise StreamError(f"BLOB is invalid type: {type(message.payload)}")
        if isinstance(payload, list):
            # Reject bad fragments here rather than letting them fail later,
            # with an unrelated exception type, inside BlobStream/wrap_view.
            for item in payload:
                if not isinstance(item, (bytes, bytearray, memoryview)):
                    raise StreamError(f"BLOB list contains invalid type: {type(item)}")
        blob_stream = BlobStream(message.payload, message.headers)
        return self.byte_streamer.send(
            channel, topic, target, message.headers, blob_stream, STREAM_TYPE_BLOB, secure, optional
        )

    def register_blob_callback(self, channel, topic, blob_cb: Callable, *args, **kwargs):
        """Register ``blob_cb`` for blobs on ``channel``/``topic``.

        A BlobHandler wraps ``blob_cb``. Extra positional and keyword
        arguments are forwarded to ``ByteReceiver.register_callback``.
        """
        handler = BlobHandler(blob_cb)
        self.byte_receiver.register_callback(channel, topic, handler.handle_blob_cb, *args, **kwargs)