# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import uuid
from typing import List, Tuple
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import ReturnCode, Shareable, make_reply
from nvflare.apis.streaming import ConsumerFactory, ObjectConsumer, StreamableEngine, StreamContext
from nvflare.fuel.utils.validation_utils import check_positive_int, check_positive_number
from .streamer_base import ( # noqa: F401
KEY_DATA,
KEY_DATA_SIZE,
KEY_EOF,
KEY_FILE_LOCATION,
KEY_FILE_NAME,
KEY_FILE_SIZE,
BaseChunkConsumer,
BaseChunkProducer,
StreamerBase,
)
class _ChunkConsumer(BaseChunkConsumer):
    """Consumer that writes received chunks into a uniquely named file under ``dest_dir``."""

    def __init__(self, stream_ctx: StreamContext, dest_dir):
        """Open the destination file and publish its path in the stream context.

        Args:
            stream_ctx: the stream context; KEY_FILE_NAME and KEY_FILE_SIZE are read from it
                and KEY_FILE_LOCATION is written to it
            dest_dir: the directory in which the received file is created
        """
        super().__init__()
        self.dest_dir = dest_dir
        self.file_name = stream_ctx.get(KEY_FILE_NAME)
        self.file_size = stream_ctx.get(KEY_FILE_SIZE)
        self.received_size = 0

        # The on-disk name is a random UUID, not the file name sent by the peer.
        location = os.path.join(dest_dir, str(uuid.uuid4()))
        self.file = open(location, "wb")
        stream_ctx[KEY_FILE_LOCATION] = location

    def consume(
        self,
        shareable: Shareable,
        stream_ctx: StreamContext,
        fl_ctx: FLContext,
    ) -> Tuple[bool, Shareable]:
        """Write one received chunk to the file and verify the total size at EOF.

        Args:
            shareable: the received message carrying KEY_DATA, KEY_DATA_SIZE and KEY_EOF
            stream_ctx: the stream context
            fl_ctx: the FLContext object

        Returns: a tuple of (continue, reply):
            - continue is False once the EOF chunk has been processed, True otherwise;
            - reply is an OK reply

        Raises:
            ValueError: if, at EOF, the number of bytes received differs from the expected file size
        """
        payload = shareable.get(KEY_DATA)
        payload_size = shareable.get(KEY_DATA_SIZE)
        self._validate_chunk(payload, payload_size)

        if payload:
            # The counter advances by the size declared in the message.
            self.received_size += payload_size
            self.file.write(payload)

        if not shareable.get(KEY_EOF):
            return True, make_reply(ReturnCode.OK)

        if self.received_size != self.file_size:
            err = f"received size {self.received_size} does not match expected file size {self.file_size}"
            self.logger.error(err)
            raise ValueError(err)

        return False, make_reply(ReturnCode.OK)

    def finalize(self, stream_ctx: StreamContext, fl_ctx: FLContext):
        """Close the destination file, if there is one.

        Args:
            stream_ctx: the stream context, used to look up the file location for logging
            fl_ctx: the FLContext object
        """
        if not self.file:
            return

        file_location = stream_ctx.get(KEY_FILE_LOCATION)
        self.file.close()
        self.logger.debug(f"closed file {file_location}")
class _ChunkConsumerFactory(ConsumerFactory):
    """Factory that creates a ``_ChunkConsumer`` writing into a fixed destination dir."""

    def __init__(self, dest_dir: str):
        """Remember the destination directory handed to every consumer created.

        Args:
            dest_dir: the directory in which received files are created
        """
        self.dest_dir = dest_dir

    def get_consumer(self, stream_ctx: StreamContext, fl_ctx: FLContext) -> ObjectConsumer:
        """Create a consumer for the stream described by ``stream_ctx``.

        Args:
            stream_ctx: the stream context passed on to the consumer
            fl_ctx: the FLContext object

        Returns: a new _ChunkConsumer
        """
        consumer = _ChunkConsumer(stream_ctx, self.dest_dir)
        return consumer
class _ChunkProducer(BaseChunkProducer):
    """Producer that reads an open file object in chunks of at most ``chunk_size``."""

    def __init__(self, file, chunk_size, timeout):
        """Store the file object and the per-chunk settings.

        Args:
            file: the file object to read chunks from
            chunk_size: the size passed to each read of the file
            timeout: the timeout value returned along with every chunk
        """
        super().__init__()
        self.file = file
        self.chunk_size = chunk_size
        self.timeout = timeout

    def produce(
        self,
        stream_ctx: StreamContext,
        fl_ctx: FLContext,
    ) -> Tuple[Shareable, float]:
        """Read the next chunk from the file and wrap it in a Shareable.

        Args:
            stream_ctx: the stream context
            fl_ctx: the FLContext object

        Returns: a tuple of (message, timeout):
            - message carries KEY_DATA, KEY_DATA_SIZE and KEY_EOF;
            - timeout is the timeout given at construction
        """
        chunk = self.file.read(self.chunk_size)
        size = len(chunk) if chunk else 0

        # An empty or short read means there is nothing more to read from the file.
        if not chunk or size < self.chunk_size:
            self.eof = True

        self.logger.debug(f"sending chunk {size=}")

        message = Shareable()
        message[KEY_DATA] = chunk
        message[KEY_DATA_SIZE] = size
        message[KEY_EOF] = self.eof
        return message, self.timeout
[docs]
class FileStreamer(StreamerBase):
    """Static helpers for streaming a file to targets and for receiving streamed files."""

    @staticmethod
    def register_stream_processing(
        fl_ctx: FLContext,
        channel: str,
        topic: str,
        dest_dir: str = None,
        stream_done_cb=None,
        chunk_consumed_cb=None,
        **cb_kwargs,
    ):
        """Register for stream processing on the receiving side.

        Args:
            fl_ctx: the FLContext object
            channel: the app channel
            topic: the app topic
            dest_dir: the destination dir for received file. If not specified, system temp dir is used
            stream_done_cb: if specified, the callback to be called when the file is completely received
            chunk_consumed_cb: if specified, the callback to be called when a chunk is processed
            **cb_kwargs: the kwargs for the stream_done_cb

        Returns: None

        Raises:
            ValueError: if dest_dir is not an existing directory
            RuntimeError: if the engine obtained from fl_ctx is not a StreamableEngine

        Notes: the stream_done_cb must follow stream_done_cb_signature as defined in apis.streaming.
        """
        if not dest_dir:
            dest_dir = tempfile.gettempdir()

        if not os.path.isdir(dest_dir):
            raise ValueError(f"dest_dir '{dest_dir}' is not a valid dir")

        engine = fl_ctx.get_engine()
        if not isinstance(engine, StreamableEngine):
            raise RuntimeError(f"engine must be StreamableEngine but got {type(engine)}")

        engine.register_stream_processing(
            channel=channel,
            topic=topic,
            factory=_ChunkConsumerFactory(dest_dir),
            stream_done_cb=stream_done_cb,
            consumed_cb=chunk_consumed_cb,
            **cb_kwargs,
        )

    @staticmethod
    def stream_file(
        channel: str,
        topic: str,
        stream_ctx: StreamContext,
        targets: List[str],
        file_name: str,
        fl_ctx: FLContext,
        chunk_size=None,
        chunk_timeout=None,
        optional=False,
        secure=False,
    ) -> Tuple[str, bool]:
        """Stream a file to one or more targets.

        Args:
            channel: the app channel
            topic: the app topic
            stream_ctx: context data of the stream. If empty or None, a new dict is used
            targets: targets that the file will be sent to
            file_name: full path to the file to be streamed
            fl_ctx: a FLContext object
            chunk_size: size of each chunk to be streamed. If not specified, default to 1M bytes.
            chunk_timeout: timeout for each chunk of data sent to targets. If not specified, default to 5.0.
            optional: whether the file is optional
            secure: whether P2P security is required

        Returns: a tuple of (RC, Result):
            - RC is ReturnCode.OK or ReturnCode.ERROR;
            - Result is whether the streaming completed successfully

        Raises:
            ValueError: if file_name is not an existing file
            RuntimeError: if the engine obtained from fl_ctx is not a StreamableEngine

        Notes: this is a blocking call - only returns after the streaming is done.
        """
        if not os.path.isfile(file_name):
            raise ValueError(f"file {file_name} is not a valid file")

        if not chunk_size:
            chunk_size = 1024 * 1024
        check_positive_int("chunk_size", chunk_size)

        if not chunk_timeout:
            chunk_timeout = 5.0
        check_positive_number("chunk_timeout", chunk_timeout)

        file_size = os.stat(file_name).st_size

        if not stream_ctx:
            stream_ctx = {}
        stream_ctx[KEY_FILE_SIZE] = file_size

        # Validate the engine before opening the file so that an unusable engine
        # fails fast without acquiring a file handle.
        engine = fl_ctx.get_engine()
        if not isinstance(engine, StreamableEngine):
            raise RuntimeError(f"engine must be StreamableEngine but got {type(engine)}")

        # Only the base name is sent; the full local path is not put into the context.
        stream_ctx[KEY_FILE_NAME] = os.path.basename(file_name)

        # The file must stay open for the whole (blocking) streaming call because
        # the producer reads from it chunk by chunk.
        with open(file_name, "rb") as file:
            producer = _ChunkProducer(file, chunk_size, chunk_timeout)
            return engine.stream_objects(
                channel=channel,
                topic=topic,
                stream_ctx=stream_ctx,
                targets=targets,
                producer=producer,
                fl_ctx=fl_ctx,
                optional=optional,
                secure=secure,
            )

    @staticmethod
    def get_file_name(stream_ctx: StreamContext):
        """Get the file base name property from stream context.

        This method is intended to be used by the stream_done_cb() function of the receiving side.

        Args:
            stream_ctx: the stream context

        Returns: file base name
        """
        return stream_ctx.get(KEY_FILE_NAME)

    @staticmethod
    def get_file_location(stream_ctx: StreamContext):
        """Get the file location property from stream context.

        This method is intended to be used by the stream_done_cb() function of the receiving side.

        Args:
            stream_ctx: the stream context

        Returns: location (full file path) of the received file
        """
        return stream_ctx.get(KEY_FILE_LOCATION)

    @staticmethod
    def get_file_size(stream_ctx: StreamContext):
        """Get the file size property from stream context.

        This method is intended to be used by the stream_done_cb() function of the receiving side.

        Args:
            stream_ctx: the stream context

        Returns: size (in bytes) of the received file
        """
        return stream_ctx.get(KEY_FILE_SIZE)