# Source code for nvflare.fuel.f3.streaming.file_downloader

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path
import tempfile
import uuid
from typing import Any, List, Optional

from nvflare.fuel.f3.cellnet.cell import Cell
from nvflare.fuel.f3.streaming.obj_downloader import Consumer, ObjDownloader, Producer, ProduceRC, download_object
from nvflare.fuel.utils.validation_utils import check_positive_int

# Default chunk size in bytes (5 MiB); _ChunkProducer falls back to it when no chunk_size is given.
DEFAULT_CHUNK_SIZE = 5 * 1024 * 1024

"""
This package implements file downloading capability based on the ObjDownloader framework.
It provides implementation of the Producer and Consumer objects, required by ObjDownloader.
"""


class _StateKey:
    """Keys of the state dict passed between the chunk producer and the chunk consumer."""

    # Running count of bytes received so far; the producer uses it as the read offset for the next chunk.
    RECEIVED_BYTES = "received_bytes"


class _File:
    """The "object" to be downloaded: a file name together with its size in bytes."""

    def __init__(self, file_name):
        """Remember the file name and capture the file's size at construction time.

        Args:
            file_name: name of the file.

        Raises:
            OSError: if the file cannot be stat'ed (e.g. it does not exist).
        """
        # Same value os.path.getsize reports: the st_size field of the stat result.
        size_in_bytes = os.stat(file_name).st_size
        self.name = file_name
        self.size = size_in_bytes


class _ChunkProducer(Producer):
    """Producer that serves a _File to requesters one chunk at a time."""

    def __init__(self, chunk_size=None):
        """Set up the producer.

        Args:
            chunk_size: max number of bytes sent per chunk. A falsy value (None or 0)
                selects DEFAULT_CHUNK_SIZE. The value is then validated with
                check_positive_int (presumably raising on anything that is not a
                positive int -- confirm against validation_utils).
        """
        Producer.__init__(self)
        if not chunk_size:
            chunk_size = DEFAULT_CHUNK_SIZE

        check_positive_int("chunk_size", chunk_size)
        self.chunk_size = chunk_size

    def produce(self, ref_id: str, obj, state: dict, requester: str) -> (str, Any, dict):
        """Produce the next chunk of the file for the requester.

        Args:
            ref_id: reference ID of the object being downloaded (not used here).
            obj: the _File being downloaded.
            state: state reported by the requester; its RECEIVED_BYTES entry is used as the
                read offset. A missing or empty state means "start from offset 0".
            requester: identity of the requester, used only in log messages.

        Returns:
            a tuple of (return code, chunk, new state):
              - (ProduceRC.ERROR, None, None) if the reported byte count is not a
                non-negative int, or if no data could be read at the requested offset;
              - (ProduceRC.EOF, None, None) if the requester already has obj.size bytes or more;
              - (ProduceRC.OK, chunk, state) otherwise, where state carries the updated
                RECEIVED_BYTES count.
        """
        assert isinstance(obj, _File)
        received_bytes = 0
        if state:
            received_bytes = state.get(_StateKey.RECEIVED_BYTES, 0)

        if not isinstance(received_bytes, int) or received_bytes < 0:
            self.logger.error(f"bad {_StateKey.RECEIVED_BYTES} {received_bytes} from {requester}")
            return ProduceRC.ERROR, None, None

        if received_bytes >= obj.size:
            # already done
            return ProduceRC.EOF, None, None

        # The last chunk may be shorter than chunk_size.
        num_bytes_to_send = min(self.chunk_size, obj.size - received_bytes)
        with open(obj.name, "rb") as f:
            f.seek(received_bytes)
            chunk = f.read(num_bytes_to_send)

        if not chunk:
            # obj.size was captured when the _File was created. If the file has shrunk since,
            # the read hits EOF and yields nothing. Replying OK with an empty chunk would leave
            # the requester's byte count unchanged, so it would keep requesting the same offset.
            self.logger.error(
                f"no data read from file {obj.name} at offset {received_bytes} (expected size {obj.size})"
            )
            return ProduceRC.ERROR, None, None

        self.logger.debug(f"{received_bytes=}; sending {len(chunk)} bytes")
        return ProduceRC.OK, chunk, {_StateKey.RECEIVED_BYTES: received_bytes + len(chunk)}


class FileDownloader(ObjDownloader):
    """File downloading built on ObjDownloader, using _ChunkProducer and _ChunkConsumer."""

    @classmethod
    def new_transaction(
        cls,
        cell: Cell,
        timeout: float,
        timeout_cb,
        **cb_kwargs,
    ):
        """Create a new file download transaction.

        Args:
            cell: the cell for communication with recipients
            timeout: timeout for the transaction
            timeout_cb: CB to be called when the transaction is timed out
            **cb_kwargs: args to be passed to the CB

        Returns: transaction id

        The timeout_cb must follow this signature:

            cb(tx_id, file_names: List[str], **cb_args)

        """
        # The app's CB is passed along as app_timeout_cb; _tx_timeout receives it and
        # calls it with file names instead of the internal _File objects.
        return ObjDownloader.new_transaction(
            cell=cell,
            producer=_ChunkProducer(),
            timeout=timeout,
            timeout_cb=cls._tx_timeout,
            app_timeout_cb=timeout_cb,
            **cb_kwargs,
        )

    @classmethod
    def _tx_timeout(cls, tx_id: str, objs: List[Any], app_timeout_cb, **cb_kwargs):
        """Timeout adapter: call the app's timeout CB (if any) with the names of the files.

        Args:
            tx_id: ID of the timed-out transaction
            objs: the objects of the transaction; each is expected to have a "name" attribute
            app_timeout_cb: the app's CB, or a falsy value if none was given
            **cb_kwargs: args to be passed to the app's CB
        """
        if app_timeout_cb:
            file_names = [obj.name for obj in objs]
            app_timeout_cb(tx_id, file_names, **cb_kwargs)

    @classmethod
    def add_file(
        cls,
        transaction_id: str,
        file_name: str,
        ref_id=None,
        file_downloaded_cb=None,
        **cb_kwargs,
    ) -> str:
        """Add a file to be downloaded to the specified transaction.

        Args:
            transaction_id: ID of the transaction
            file_name: name of the file to be downloaded
            ref_id: ref id to be used, if provided
            file_downloaded_cb: CB to be called when the file is done downloading
            **cb_kwargs: args to be passed to the CB

        Returns: reference id for the file.

        Raises:
            OSError: if the size of the file cannot be determined (e.g. the file does not exist).

        The file_downloaded_cb must follow this signature:

            cb(ref_id: str, to_site: str, status: str, file_name: str, **cb_kwargs)

        """
        obj = _File(file_name)
        return ObjDownloader.add_download_object(
            transaction_id=transaction_id,
            obj=obj,
            ref_id=ref_id,
            obj_downloaded_cb=cls._file_downloaded,
            app_downloaded_cb=file_downloaded_cb,
            **cb_kwargs,
        )

    @classmethod
    def _file_downloaded(cls, ref_id: str, to_site: str, status: str, obj: _File, app_downloaded_cb, **cb_kwargs):
        """Download adapter: call the app's CB (if any) with the file name instead of the _File object."""
        if app_downloaded_cb:
            app_downloaded_cb(ref_id, to_site, status, obj.name, **cb_kwargs)

    @classmethod
    def download_file(
        cls,
        from_fqcn: str,
        ref_id: str,
        per_request_timeout: float,
        cell: Cell,
        location: Optional[str] = None,
        secure=False,
        optional=False,
        abort_signal=None,
    ) -> (str, Optional[str]):
        """Download the referenced file from the file owner.

        Args:
            from_fqcn: FQCN of the file owner.
            ref_id: reference ID of the file to be downloaded.
            per_request_timeout: timeout for requests sent to the file owner.
            cell: cell to be used for communicating to the file owner.
            location: dir for keeping the received file. If not specified, will use temp dir.
            secure: P2P private mode for communication
            optional: suppress log messages of communication
            abort_signal: signal for aborting download.

        Returns: tuple of (error message if any, full path of the downloaded file).

        Raises:
            ValueError: if location is specified but does not exist or is not a directory.

        """
        if location is not None:
            if not os.path.exists(location):
                raise ValueError(f"location '{location}' does not exist")

            if not os.path.isdir(location):
                raise ValueError(f"location '{location}' is not a valid dir")
        else:
            location = tempfile.gettempdir()

        # The consumer opens the destination file as soon as it is created.
        consumer = _ChunkConsumer(location)
        try:
            download_object(
                from_fqcn=from_fqcn,
                ref_id=ref_id,
                consumer=consumer,
                per_request_timeout=per_request_timeout,
                cell=cell,
                secure=secure,
                optional=optional,
                abort_signal=abort_signal,
            )
        finally:
            # Release the file handle even if download_object raises or returns without
            # invoking a completion callback. Closing an already-closed file is a no-op.
            consumer.file.close()
        return consumer.error, consumer.file_path
class _ChunkConsumer(Consumer):
    """Consumer that appends received chunks to a newly created, uniquely named file."""

    def __init__(self, location: str):
        """Create and open the destination file.

        Args:
            location: directory in which the destination file is created. The file name
                is a random UUID.
        """
        Consumer.__init__(self)
        self.total_bytes = 0
        self.error = None
        self.location = location

        unique_name = str(uuid.uuid4())
        self.file_path = os.path.join(location, unique_name)
        self.file = open(self.file_path, "wb")
        self.logger.debug(f"created file {self.file_path}")

    def consume(self, ref_id, state: dict, data: Any) -> dict:
        """Append one chunk to the file.

        Args:
            ref_id: reference ID of the file being downloaded (not used here).
            state: state that came with the chunk (not used here).
            data: the chunk; must be bytes.

        Returns: a state dict carrying the total number of bytes received so far.
        """
        assert isinstance(data, bytes)
        self.file.write(data)
        self.total_bytes = self.total_bytes + len(data)
        self.logger.debug(f"received {self.total_bytes} bytes for file {self.file_path}")
        new_state = {_StateKey.RECEIVED_BYTES: self.total_bytes}
        return new_state

    def download_failed(self, ref_id, reason: str):
        """Log the failure, record the reason in self.error, and close the file."""
        self.logger.error(f"failed to download file with ref {ref_id}: {reason}")
        self.error = reason
        self.file.close()

    def download_completed(self, ref_id: str):
        """Close the file once the download has completed."""
        self.file.close()
        self.logger.debug(f"closed file {self.file_path}")
def download_file(
    from_fqcn: str,
    ref_id: str,
    per_request_timeout: float,
    cell: Cell,
    location: str = None,
    secure=False,
    optional=False,
    abort_signal=None,
) -> (str, Optional[str]):
    """Module-level shortcut for FileDownloader.download_file.

    All arguments are forwarded unchanged.

    Args:
        from_fqcn: FQCN of the file owner.
        ref_id: reference ID of the file to be downloaded.
        per_request_timeout: timeout for requests sent to the file owner.
        cell: cell to be used for communicating to the file owner.
        location: dir for keeping the received file. If not specified, will use temp dir.
        secure: P2P private mode for communication
        optional: suppress log messages of communication
        abort_signal: signal for aborting download.

    Returns: tuple of (error message if any, full path of the downloaded file).
    """
    return FileDownloader.download_file(
        from_fqcn=from_fqcn,
        ref_id=ref_id,
        per_request_timeout=per_request_timeout,
        cell=cell,
        location=location,
        secure=secure,
        optional=optional,
        abort_signal=abort_signal,
    )