Source code for nvflare.app_common.np.np_downloader

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from io import BytesIO
from typing import Any, List, Optional, Tuple

import numpy as np

from nvflare.fuel.f3.cellnet.cell import Cell
from nvflare.fuel.f3.streaming.cacheable import CacheableObject, ItemConsumer
from nvflare.fuel.f3.streaming.download_service import download_object
from nvflare.fuel.f3.streaming.obj_downloader import ObjectDownloader

_TWO_MB = 2 * 1024 * 1024


class ArrayDownloadable(CacheableObject):
    """Expose a dict of numpy arrays as a sequence of downloadable items.

    Each dict entry is one item: item ``i`` is the array stored under the
    i-th key (in the dict's iteration order at construction time),
    serialized on demand as a single-entry ``.npz`` payload.
    """

    def __init__(self, arrays: dict[str, np.ndarray], max_chunk_size: int):
        """Record the keys and item count, then initialize the base class.

        Args:
            arrays: mapping of array name to numpy array.
            max_chunk_size: forwarded unchanged to ``CacheableObject.__init__``.
        """
        # Set before calling the base initializer, in case it queries the
        # item count or keys during construction.
        self.size = len(arrays)
        self.keys = list(arrays)
        super().__init__(arrays, max_chunk_size)

    def get_item_count(self) -> int:
        """Return the number of arrays recorded at construction time."""
        return self.size

    def produce_item(self, index: int) -> bytes:
        """Serialize the array at position ``index`` to ``.npz`` bytes.

        Args:
            index: position of the wanted key in ``self.keys``.

        Returns:
            The bytes written by ``np.savez`` for a one-entry archive that
            maps the key to its array.
        """
        name = self.keys[index]
        # NOTE(review): base_obj is presumably the ``arrays`` dict stored by
        # CacheableObject.__init__ - confirm against the base class.
        single_entry = {name: self.base_obj[name]}
        with BytesIO() as buffer:
            np.savez(file=buffer, allow_pickle=False, **single_entry)
            return buffer.getvalue()
class ArrayConsumer(ItemConsumer):
    """Turn received ``.npz`` byte items back into numpy arrays.

    Decoded arrays are either merged directly into the running result or
    handed to an optional callback, whose dict return value (if any) is
    merged into the result instead.
    """

    def __init__(self, arrays_received_cb, cb_kwargs):
        """Create the consumer.

        Args:
            arrays_received_cb: optional callable invoked as
                ``arrays_received_cb(arrays, **cb_kwargs)`` for every batch
                of decoded arrays; ``None`` disables the callback.
            cb_kwargs: extra keyword arguments for the callback. ``None`` is
                treated as "no extra arguments".

        Raises:
            ValueError: if ``arrays_received_cb`` is neither ``None`` nor callable.
        """
        ItemConsumer.__init__(self)

        # Validate before storing anything so a bad callback never ends up
        # on a constructed instance.
        if arrays_received_cb is not None and not callable(arrays_received_cb):
            raise ValueError("arrays_received_cb must be callable")

        self.arrays_received_cb = arrays_received_cb
        # Normalize None so ``**self.cb_kwargs`` cannot fail later.
        self.cb_kwargs = {} if cb_kwargs is None else cb_kwargs

    @staticmethod
    def _to_dict(item: bytes) -> dict:
        """Decode one ``.npz`` payload into a ``{name: array}`` dict.

        Pickle is disabled on load, so object arrays in the payload are
        rejected by numpy rather than unpickled.
        """
        with np.load(BytesIO(item), allow_pickle=False) as npz_obj:
            # Indexing materializes each array while the archive is open.
            return {name: npz_obj[name] for name in npz_obj.files}

    def consume_items(self, items: List[Any], result: Any) -> Any:
        """Decode a batch of received items and fold them into ``result``.

        Args:
            items: list of ``.npz`` byte payloads.
            result: the accumulated result so far, or ``None`` on the first
                call (a new dict is started in that case).

        Returns:
            The updated result. Without a callback, the decoded arrays are
            merged into it. With a callback, its return value is merged only
            when that value is a dict; any other return leaves the result
            unchanged.

        Raises:
            TypeError: if ``items`` is not a list.
        """
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if not isinstance(items, list):
            raise TypeError(f"items must be a list but got {type(items)}")

        if result is None:
            result = {}

        arrays = {}
        for item in items:
            arrays.update(self._to_dict(item))

        if self.arrays_received_cb is None:
            result.update(arrays)
            return result

        cb_result = self.arrays_received_cb(arrays, **self.cb_kwargs)
        if isinstance(cb_result, dict):
            result.update(cb_result)
        return result
def add_arrays(
    downloader: ObjectDownloader,
    arrays: dict[str, np.ndarray],
    max_chunk_size: int = _TWO_MB,
) -> str:
    """Register a dict of arrays with a downloader.

    The arrays are wrapped in an ``ArrayDownloadable`` and handed to
    ``downloader.add_object``.

    Args:
        downloader: the downloader that will serve the arrays.
        arrays: mapping of array name to numpy array to be served.
        max_chunk_size: max chunk size passed to ``ArrayDownloadable``.

    Returns:
        The value returned by ``downloader.add_object``, i.e. the reference
        id for the arrays.
    """
    return downloader.add_object(ArrayDownloadable(arrays, max_chunk_size))
def download_arrays(
    from_fqcn: str,
    ref_id: str,
    per_request_timeout: float,
    cell: Cell,
    secure=False,
    optional=False,
    abort_signal=None,
    arrays_received_cb=None,
    **cb_kwargs,
) -> Tuple[str, Optional[dict[str, np.ndarray]]]:
    """Download the arrays identified by ``ref_id`` from a remote source.

    An ``ArrayConsumer`` is created for the transfer and passed to
    ``download_object`` together with the connection parameters.

    Args:
        from_fqcn: FQCN of the data source.
        ref_id: reference ID of the arrays to be downloaded.
        per_request_timeout: timeout for requests sent to the data source.
        cell: cell to be used for communicating with the data source.
        secure: P2P private mode for communication.
        optional: suppress log messages of communication.
        abort_signal: signal for aborting the download.
        arrays_received_cb: optional callback invoked each time a set of
            arrays has been received.
        **cb_kwargs: extra keyword arguments passed to ``arrays_received_cb``
            on every invocation.

    Returns:
        Tuple of ``(consumer.error, consumer.result)`` read from the consumer
        once ``download_object`` returns: the error message, if any, and the
        downloaded arrays.
    """
    array_consumer = ArrayConsumer(arrays_received_cb, cb_kwargs)

    # All arguments are passed by keyword, so their order is irrelevant.
    download_object(
        consumer=array_consumer,
        from_fqcn=from_fqcn,
        ref_id=ref_id,
        cell=cell,
        per_request_timeout=per_request_timeout,
        abort_signal=abort_signal,
        secure=secure,
        optional=optional,
    )

    error, arrays = array_consumer.error, array_consumer.result
    return error, arrays