# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import uuid
from typing import Any, Dict, List, Tuple
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import ReturnCode, Shareable, make_reply
from nvflare.apis.streaming import ConsumerFactory, ObjectConsumer, ObjectProducer, StreamableEngine, StreamContext
from nvflare.fuel.utils.log_utils import get_obj_logger
from nvflare.fuel.utils.validation_utils import check_positive_int, check_positive_number
from .streamer_base import StreamerBase
# Common namespace prefix shared by every key defined below.
_PREFIX = "FileStreamer."

# Keys kept in the stream context.
_KEY_FILE_NAME = f"{_PREFIX}file_name"
_KEY_FILE_LOCATION = f"{_PREFIX}file_location"
_KEY_FILE_SIZE = f"{_PREFIX}file_size"

# Keys carried by the Shareable of each chunk.
_KEY_DATA_SIZE = f"{_PREFIX}size"
_KEY_DATA = f"{_PREFIX}data"
_KEY_EOF = f"{_PREFIX}eof"
class _ChunkConsumer(ObjectConsumer):
    """Consumer that writes the chunks of one incoming stream into a local file."""

    def __init__(self, stream_ctx: StreamContext, dest_dir):
        """Open the destination file and record its path in the stream context.

        Args:
            stream_ctx: the stream context; the file name and file size are read from it,
                and the path of the destination file is stored into it
            dest_dir: the directory in which the destination file is created
        """
        self.file_name = stream_ctx.get(_KEY_FILE_NAME)
        self.logger = get_obj_logger(self)
        self.dest_dir = dest_dir
        self.file_size = stream_ctx.get(_KEY_FILE_SIZE)
        self.received_size = 0

        # The file on disk is named with a random UUID; the name taken from the
        # stream context is only remembered in self.file_name.
        destination = os.path.join(dest_dir, str(uuid.uuid4()))
        self.file = open(destination, "wb")
        stream_ctx[_KEY_FILE_LOCATION] = destination

    def _size_error(self, message: str) -> ValueError:
        """Log the message as an error and return a ValueError carrying it for the caller to raise."""
        self.logger.error(message)
        return ValueError(message)

    def consume(
        self,
        shareable: Shareable,
        stream_ctx: StreamContext,
        fl_ctx: FLContext,
    ) -> Tuple[bool, Shareable]:
        """Write one received chunk to the file and tell whether more chunks are expected.

        Args:
            shareable: the received chunk, holding the data, its declared size and the EOF flag
            stream_ctx: the stream context (not used here)
            fl_ctx: the FLContext object (not used here)

        Returns: a tuple of (continue, reply):
            - continue is True while the EOF flag has not been received, False once it has;
            - reply is an OK reply

        Raises:
            ValueError: if the declared chunk size does not match the data received, or if
                at EOF the total received size differs from the expected file size
        """
        payload = shareable.get(_KEY_DATA)
        declared_size = shareable.get(_KEY_DATA_SIZE)

        if not payload:
            # an empty chunk is only acceptable when it is declared as empty
            if declared_size != 0:
                raise self._size_error(f"no data received but expected size is {declared_size}")
        elif declared_size != len(payload):
            raise self._size_error(
                f"received data {len(payload)} does not match expected size {declared_size}"
            )
        else:
            self.received_size += declared_size
            self.file.write(payload)

        if not shareable.get(_KEY_EOF):
            # continue streaming
            return True, make_reply(ReturnCode.OK)

        # stop streaming - the whole file must have arrived by now
        if self.received_size != self.file_size:
            raise self._size_error(
                f"received size {self.received_size} does not match expected file size {self.file_size}"
            )
        return False, make_reply(ReturnCode.OK)

    def finalize(self, stream_ctx: StreamContext, fl_ctx: FLContext):
        """Close the destination file.

        Args:
            stream_ctx: the stream context; the file location is read from it for logging
            fl_ctx: the FLContext object (not used here)
        """
        if not self.file:
            return

        closed_path = stream_ctx.get(_KEY_FILE_LOCATION)
        self.file.close()
        self.logger.debug(f"closed file {closed_path}")
class _ChunkConsumerFactory(ConsumerFactory):
    """Factory that creates a _ChunkConsumer writing into a fixed destination directory."""

    def __init__(self, dest_dir: str):
        """Remember the destination directory handed to every consumer created.

        Args:
            dest_dir: the directory in which received files are created
        """
        self.dest_dir = dest_dir

    def get_consumer(self, stream_ctx: StreamContext, fl_ctx: FLContext) -> ObjectConsumer:
        """Create the consumer for a stream.

        Args:
            stream_ctx: the stream context passed on to the consumer
            fl_ctx: the FLContext object (not used here)

        Returns: a new _ChunkConsumer bound to this factory's destination directory
        """
        return _ChunkConsumer(stream_ctx=stream_ctx, dest_dir=self.dest_dir)
class _ChunkProducer(ObjectProducer):
    """Producer that reads an open file chunk by chunk and wraps each chunk in a Shareable."""

    def __init__(self, file, chunk_size, timeout):
        """Set up the producer.

        Args:
            file: the open file object to read chunks from
            chunk_size: the size passed to each read of the file
            timeout: the timeout value returned together with every produced chunk
        """
        self.logger = get_obj_logger(self)
        self.file = file
        self.chunk_size = chunk_size
        self.timeout = timeout
        self.eof = False

    def produce(
        self,
        stream_ctx: StreamContext,
        fl_ctx: FLContext,
    ) -> Tuple[Shareable, float]:
        """Read the next chunk from the file and package it for sending.

        Args:
            stream_ctx: the stream context (not used here)
            fl_ctx: the FLContext object (not used here)

        Returns: a tuple of (chunk, timeout):
            - chunk is a Shareable holding the data, its size and the EOF flag;
            - timeout is the timeout configured for this producer
        """
        chunk = self.file.read(self.chunk_size)
        size = len(chunk) if chunk else 0

        # an empty or short read means there is nothing left in the file
        if not chunk or size < self.chunk_size:
            self.eof = True

        self.logger.debug(f"sending chunk {size=}")

        message = Shareable()
        message[_KEY_DATA] = chunk
        message[_KEY_DATA_SIZE] = size
        message[_KEY_EOF] = self.eof
        return message, self.timeout

    def process_replies(
        self,
        replies: Dict[str, Shareable],
        stream_ctx: StreamContext,
        fl_ctx: FLContext,
    ) -> Any:
        """Inspect the targets' replies to the last chunk and decide how to proceed.

        Args:
            replies: the reply of each target, keyed by target name
            stream_ctx: the stream context (not used here)
            fl_ctx: the FLContext object (not used here)

        Returns:
            - False if any target replied with a return code other than OK;
            - the dict of replies once the EOF chunk has been produced and all replies are OK;
            - None if all replies are OK but EOF has not been reached yet
        """
        ok_replies = {}
        failed = False
        for target, reply in replies.items():
            rc = reply.get_return_code(ReturnCode.OK)
            if rc != ReturnCode.OK:
                self.logger.error(f"error from target {target}: {rc}")
                failed = True
                continue
            ok_replies[target] = reply

        if failed:
            # done - failed
            return False

        # done - succeeded when EOF was sent; otherwise not done yet - continue streaming
        return ok_replies if self.eof else None
class FileStreamer(StreamerBase):
    """Static helpers for streaming a file to targets and for receiving streamed files."""

    @staticmethod
    def register_stream_processing(
        fl_ctx: FLContext,
        channel: str,
        topic: str,
        dest_dir: str = None,
        stream_done_cb=None,
        chunk_consumed_cb=None,
        **cb_kwargs,
    ):
        """Register for stream processing on the receiving side.

        Args:
            fl_ctx: the FLContext object
            channel: the app channel
            topic: the app topic
            dest_dir: the destination dir for received file. If not specified, system temp dir is used
            stream_done_cb: if specified, the callback to be called when the file is completely received
            chunk_consumed_cb: if specified, the callback to be called when a chunk is processed
            **cb_kwargs: the kwargs for the stream_done_cb

        Returns: None

        Raises:
            ValueError: if dest_dir is not an existing directory
            RuntimeError: if the engine of fl_ctx is not a StreamableEngine

        Notes: the stream_done_cb must follow stream_done_cb_signature as defined in apis.streaming.
        """
        if not dest_dir:
            dest_dir = tempfile.gettempdir()

        if not os.path.isdir(dest_dir):
            raise ValueError(f"dest_dir '{dest_dir}' is not a valid dir")

        engine = fl_ctx.get_engine()
        if not isinstance(engine, StreamableEngine):
            raise RuntimeError(f"engine must be StreamableEngine but got {type(engine)}")

        engine.register_stream_processing(
            channel=channel,
            topic=topic,
            factory=_ChunkConsumerFactory(dest_dir),
            stream_done_cb=stream_done_cb,
            consumed_cb=chunk_consumed_cb,
            **cb_kwargs,
        )

    @staticmethod
    def stream_file(
        channel: str,
        topic: str,
        stream_ctx: StreamContext,
        targets: List[str],
        file_name: str,
        fl_ctx: FLContext,
        chunk_size=None,
        chunk_timeout=None,
        optional=False,
        secure=False,
    ) -> Tuple[str, bool]:
        """Stream a file to one or more targets.

        Args:
            channel: the app channel
            topic: the app topic
            stream_ctx: context data of the stream
            targets: targets that the file will be sent to
            file_name: full path to the file to be streamed
            fl_ctx: a FLContext object
            chunk_size: size of each chunk to be streamed. If not specified, default to 1M bytes.
            chunk_timeout: timeout for each chunk of data sent to targets. If not specified, default to 5.0.
            optional: whether the file is optional
            secure: whether P2P security is required

        Returns: a tuple of (RC, Result):
            - RC is ReturnCode.OK or ReturnCode.ERROR;
            - Result is whether the streaming completed successfully

        Raises:
            ValueError: if file_name is not an existing file
            RuntimeError: if the engine of fl_ctx is not a StreamableEngine

        Notes: this is a blocking call - only returns after the streaming is done.
        """
        if not os.path.isfile(file_name):
            raise ValueError(f"file {file_name} is not a valid file")

        if not chunk_size:
            chunk_size = 1024 * 1024
        check_positive_int("chunk_size", chunk_size)

        if not chunk_timeout:
            chunk_timeout = 5.0
        check_positive_number("chunk_timeout", chunk_timeout)

        file_size = os.stat(file_name).st_size

        if not stream_ctx:
            stream_ctx = {}
        stream_ctx[_KEY_FILE_SIZE] = file_size

        # Validate the engine before opening the file, so that no file handle is
        # acquired when streaming cannot be started.
        engine = fl_ctx.get_engine()
        if not isinstance(engine, StreamableEngine):
            raise RuntimeError(f"engine must be StreamableEngine but got {type(engine)}")

        with open(file_name, "rb") as file:
            producer = _ChunkProducer(file, chunk_size, chunk_timeout)

            # only the base name is put into the stream context, not the sender's full path
            stream_ctx[_KEY_FILE_NAME] = os.path.basename(file_name)

            # the file must stay open until the (blocking) streaming call returns
            return engine.stream_objects(
                channel=channel,
                topic=topic,
                stream_ctx=stream_ctx,
                targets=targets,
                producer=producer,
                fl_ctx=fl_ctx,
                optional=optional,
                secure=secure,
            )

    @staticmethod
    def get_file_name(stream_ctx: StreamContext):
        """Get the file base name property from stream context.

        This method is intended to be used by the stream_done_cb() function of the receiving side.

        Args:
            stream_ctx: the stream context

        Returns: file base name
        """
        return stream_ctx.get(_KEY_FILE_NAME)

    @staticmethod
    def get_file_location(stream_ctx: StreamContext):
        """Get the file location property from stream context.

        This method is intended to be used by the stream_done_cb() function of the receiving side.

        Args:
            stream_ctx: the stream context

        Returns: location (full file path) of the received file
        """
        return stream_ctx.get(_KEY_FILE_LOCATION)

    @staticmethod
    def get_file_size(stream_ctx: StreamContext):
        """Get the file size property from stream context.

        This method is intended to be used by the stream_done_cb() function of the receiving side.

        Args:
            stream_ctx: the stream context

        Returns: size (in bytes) of the received file
        """
        return stream_ctx.get(_KEY_FILE_SIZE)