# Source code for nvflare.fuel.f3.streaming.file_streamer

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from pathlib import Path
from typing import Callable, Optional

from nvflare.fuel.f3.connection import BytesAlike
from nvflare.fuel.f3.message import Message
from nvflare.fuel.f3.streaming.byte_receiver import ByteReceiver
from nvflare.fuel.f3.streaming.byte_streamer import STREAM_TYPE_FILE, ByteStreamer
from nvflare.fuel.f3.streaming.stream_const import StreamHeaderKey
from nvflare.fuel.f3.streaming.stream_types import Stream, StreamFuture
from nvflare.fuel.f3.streaming.stream_utils import stream_thread_pool

log = logging.getLogger(__name__)


class FileStream(Stream):
    """A ``Stream`` whose bytes come from a file on the local file system.

    The file is opened in binary mode by the constructor and stays open until
    ``close`` is called.
    """

    def __init__(self, file_name: str, headers: Optional[dict]):
        """Open ``file_name`` and initialize the stream with the file's size.

        Args:
            file_name: Path of the file to read.
            headers: Optional headers handed through to ``Stream.__init__``.

        Raises:
            OSError: If the file cannot be opened or its size cannot be determined.
        """
        self.file = open(file_name, "rb")
        try:
            # Seeking to the end returns the file size; rewind afterwards so
            # that the first read starts at byte 0.
            size = self.file.seek(0, os.SEEK_END)
            self.file.seek(0, os.SEEK_SET)
            super().__init__(size, headers)
        except BaseException:
            # Construction failed half-way: nobody will ever call close() on
            # this object, so release the file handle here before re-raising.
            self.file.close()
            raise

    def read(self, chunk_size: int) -> BytesAlike:
        """Return up to ``chunk_size`` bytes from the file.

        An empty bytes object is returned once the end of the file is reached.
        """
        return self.file.read(chunk_size)

    def close(self):
        """Flag the stream as closed and close the underlying file."""
        self.closed = True
        self.file.close()
class FileHandler:
    """Receives a stream and saves it to a file whose name is chosen by a callback.

    ``file_cb`` is called as ``file_cb(future, original_name, *args, **kwargs)``
    and must return the path the incoming data is written to.
    """

    def __init__(self, file_cb: Callable):
        """Remember the callback used to pick the destination file name.

        Args:
            file_cb: Callable returning the destination path for a stream.
        """
        self.file_cb = file_cb
        # Size of the most recently accepted stream. Kept for backward
        # compatibility; the writer receives the size as an explicit argument.
        self.size = 0
        self.file_name = None

    def handle_file_cb(self, future: StreamFuture, stream: Stream, resume: bool, *args, **kwargs) -> int:
        """Stream callback: resolve the destination and start writing in the background.

        Args:
            future: Future of the incoming stream; its headers carry the sender's file name.
            stream: Stream to read the file content from.
            resume: Resume request flag; resuming is not supported and only triggers a warning.
            *args: Extra positional arguments forwarded to ``file_cb``.
            **kwargs: Extra keyword arguments forwarded to ``file_cb``.

        Returns:
            Always 0.
        """
        if resume:
            log.warning("Resume is not supported, ignored")

        expected_size = stream.get_size()
        self.size = expected_size
        original_name = future.headers.get(StreamHeaderKey.FILE_NAME)
        file_name = self.file_cb(future, original_name, *args, **kwargs)

        # Hand the expected size to the worker explicitly. Reading it back
        # from self.size on the worker thread would pick up the wrong value
        # if another stream reaches this handler in the meantime.
        stream_thread_pool.submit(self._write_to_file, file_name, future, stream, expected_size)
        return 0

    def _write_to_file(
        self, file_name: str, future: StreamFuture, stream: Stream, expected_size: Optional[int] = None
    ):
        """Copy ``stream`` into ``file_name`` and complete ``future`` with the file name.

        Args:
            file_name: Destination path, opened for binary writing.
            future: Future that receives ``file_name`` as its result on success.
            stream: Source of the data; read until it returns an empty buffer.
            expected_size: Size announced for the stream. Falls back to ``self.size`` when omitted.
        """
        if expected_size is None:
            expected_size = self.size

        chunk_size = ByteStreamer.get_chunk_size()
        file_size = 0
        try:
            # The context manager closes the file even if a read/write fails.
            with open(file_name, "wb") as file:
                while True:
                    buf = stream.read(chunk_size)
                    if not buf:
                        break
                    file_size += len(buf)
                    file.write(buf)
        except Exception:
            # This runs on a pool thread whose own future is discarded, so an
            # unlogged error would vanish without a trace.
            # NOTE(review): the stream future is left unresolved on failure;
            # consider reporting the error through it if StreamFuture allows.
            log.exception("Failed to write stream to file %s", file_name)
            raise

        if expected_size and (expected_size != file_size):
            log.warning(f"Size doesn't match: {expected_size} <> {file_size}")

        future.set_result(file_name)
class FileStreamer:
    """Sends files as streams and registers callbacks for receiving them."""

    def __init__(self, byte_streamer: ByteStreamer, byte_receiver: ByteReceiver):
        """Store the streamer used for sending and the receiver used for callbacks.

        Args:
            byte_streamer: Used by ``send`` to transmit the file stream.
            byte_receiver: Used by ``register_file_callback`` to register handlers.
        """
        self.byte_streamer = byte_streamer
        self.byte_receiver = byte_receiver

    def send(
        self, channel: str, topic: str, target: str, message: Message, secure=False, optional=False
    ) -> StreamFuture:
        """Stream the file named by ``message.payload`` to ``target``.

        The stream size and the file's base name are added to the message
        headers before the stream is handed to the byte streamer.

        Args:
            channel: Channel passed to the byte streamer.
            topic: Topic passed to the byte streamer.
            target: Destination passed to the byte streamer.
            message: Message whose payload is the path of the file to send.
            secure: Passed through to the byte streamer.
            optional: Passed through to the byte streamer.

        Returns:
            The future returned by ``ByteStreamer.send``.

        Raises:
            OSError: If the file cannot be opened.
        """
        file_name = Path(message.payload).name
        file_stream = FileStream(message.payload, message.headers)
        try:
            message.add_headers(
                {
                    StreamHeaderKey.SIZE: file_stream.get_size(),
                    StreamHeaderKey.FILE_NAME: file_name,
                }
            )

            return self.byte_streamer.send(
                channel, topic, target, message.headers, file_stream, STREAM_TYPE_FILE, secure, optional
            )
        except BaseException:
            # The stream was never handed over successfully, so nothing else
            # will close it; release the open file before propagating.
            file_stream.close()
            raise

    def register_file_callback(self, channel, topic, file_cb: Callable, *args, **kwargs):
        """Register ``file_cb`` to name the destination file for streams on ``channel``/``topic``.

        A new ``FileHandler`` wraps ``file_cb``; its ``handle_file_cb`` is what
        gets registered with the byte receiver, together with ``*args`` and
        ``**kwargs``.
        """
        handler = FileHandler(file_cb)
        self.byte_receiver.register_callback(channel, topic, handler.handle_file_cb, *args, **kwargs)