# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Callable, Optional
from nvflare.fuel.f3.connection import BytesAlike
from nvflare.fuel.f3.message import Message
from nvflare.fuel.f3.streaming.byte_receiver import ByteReceiver
from nvflare.fuel.f3.streaming.byte_streamer import STREAM_TYPE_BLOB, ByteStreamer
from nvflare.fuel.f3.streaming.stream_const import EOS
from nvflare.fuel.f3.streaming.stream_types import Stream, StreamError, StreamFuture
from nvflare.fuel.f3.streaming.stream_utils import FastBuffer, stream_thread_pool, wrap_view
from nvflare.fuel.utils.buffer_list import BufferList
from nvflare.security.logging import secure_format_traceback
log = logging.getLogger(__name__)
[docs]class BlobStream(Stream):
def __init__(self, blob: BytesAlike, headers: Optional[dict]):
size = self.buffer_len(blob)
super().__init__(size, headers)
if not isinstance(blob, list):
self.blob_view = wrap_view(blob)
self.buffer_list = None
else:
self.blob_view = [wrap_view(b) for b in blob]
self.buffer_list = BufferList(self.blob_view)
[docs] def read(self, chunk_size: int) -> BytesAlike:
if self.pos >= self.get_size():
return EOS
next_pos = self.pos + chunk_size
if next_pos > self.get_size():
next_pos = self.get_size()
if self.buffer_list:
buf = self.buffer_list.read(self.pos, next_pos)
else:
buf = self.blob_view[self.pos : next_pos]
self.pos = next_pos
return buf
[docs] @staticmethod
def buffer_len(buffer: BytesAlike):
if not isinstance(buffer, list):
return len(buffer)
return sum(len(buf) for buf in buffer)
[docs]class BlobTask:
def __init__(self, future: StreamFuture, stream: Stream):
self.future = future
self.stream = stream
self.size = stream.get_size()
self.pre_allocated = self.size > 0
if self.pre_allocated:
self.buffer = wrap_view(bytearray(self.size))
else:
self.buffer = FastBuffer()
[docs]class BlobHandler:
def __init__(self, blob_cb: Callable):
self.blob_cb = blob_cb
[docs] def handle_blob_cb(self, future: StreamFuture, stream: Stream, resume: bool, *args, **kwargs) -> int:
if resume:
log.warning("Resume is not supported, ignored")
blob_task = BlobTask(future, stream)
stream_thread_pool.submit(self._read_stream, blob_task)
self.blob_cb(future, *args, **kwargs)
return 0
@staticmethod
def _read_stream(blob_task: BlobTask):
try:
# It's most efficient to use the same chunk size as the stream
chunk_size = ByteStreamer.get_chunk_size()
buf_size = 0
while True:
buf = blob_task.stream.read(chunk_size)
if not buf:
break
length = len(buf)
if blob_task.pre_allocated:
blob_task.buffer[buf_size : buf_size + length] = buf
else:
blob_task.buffer.append(buf)
buf_size += length
if blob_task.size and blob_task.size != buf_size:
log.warning(
f"Stream {blob_task.future.get_stream_id()} size doesn't match: " f"{blob_task.size} <> {buf_size}"
)
if blob_task.pre_allocated:
result = blob_task.buffer
else:
result = blob_task.buffer.to_bytes()
blob_task.future.set_result(result)
except Exception as ex:
log.error(f"Stream {blob_task.future.get_stream_id()} read error: {ex}")
log.debug(secure_format_traceback())
blob_task.future.set_exception(ex)
[docs]class BlobStreamer:
def __init__(self, byte_streamer: ByteStreamer, byte_receiver: ByteReceiver):
self.byte_streamer = byte_streamer
self.byte_receiver = byte_receiver
[docs] def send(
self, channel: str, topic: str, target: str, message: Message, secure: bool, optional: bool
) -> StreamFuture:
if message.payload is None:
message.payload = bytes(0)
if not isinstance(message.payload, (bytes, bytearray, memoryview, list)):
raise StreamError(f"BLOB is invalid type: {type(message.payload)}")
blob_stream = BlobStream(message.payload, message.headers)
return self.byte_streamer.send(
channel, topic, target, message.headers, blob_stream, STREAM_TYPE_BLOB, secure, optional
)
[docs] def register_blob_callback(self, channel, topic, blob_cb: Callable, *args, **kwargs):
handler = BlobHandler(blob_cb)
self.byte_receiver.register_callback(channel, topic, handler.handle_blob_cb, *args, **kwargs)