# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import struct
import tempfile
import threading
import weakref
from typing import Any, List, Optional, Tuple
import torch
from safetensors.torch import load as load_tensors
from safetensors.torch import save as save_tensors
from nvflare.app_common.utils.tensor_disk_offload_context import _TENSOR_DISK_OFFLOAD_ROOT_DIR
from nvflare.fuel.f3.cellnet.cell import Cell
from nvflare.fuel.f3.streaming.cacheable import CacheableObject, ItemConsumer
from nvflare.fuel.f3.streaming.download_service import download_object
from nvflare.fuel.f3.streaming.obj_downloader import ObjectDownloader
from .lazy_tensor_dict import LazyTensorDict, _cleanup_temp_dir
# Default ``max_chunk_size`` for add_tensors(): 2 * 1024 * 1024 bytes.
_TWO_MB = 2 * 1024 * 1024
# Registry of DiskTensorConsumer instances that have been constructed but not yet
# released or cleaned up. It is a WeakSet, so being registered does not by itself
# keep a consumer object alive.
_ACTIVE_DISK_TENSOR_CONSUMERS = weakref.WeakSet()
# Guards the registry above (add / discard / snapshot) as well as each consumer's
# ``_cleaned`` flag, which is only read or written while this lock is held.
_ACTIVE_DISK_TENSOR_CONSUMERS_LOCK = threading.Lock()
[docs]
def cleanup_active_disk_tensor_downloads(reason: str = "download aborted") -> None:
    """Fail every disk tensor consumer that is still registered as active.

    Each registered consumer has ``download_failed`` invoked with the fixed ref id
    ``"active_disk_tensor_download"`` and the supplied reason, which in turn triggers
    that consumer's own cleanup of its partial offload directory.

    Args:
        reason: failure reason forwarded to every active consumer.
    """
    # Take a snapshot while holding the lock, then notify after releasing it:
    # download_failed() reaches DiskTensorConsumer.cleanup(), which acquires this same
    # (non-reentrant) lock.
    with _ACTIVE_DISK_TENSOR_CONSUMERS_LOCK:
        snapshot = [*_ACTIVE_DISK_TENSOR_CONSUMERS]

    for active_consumer in snapshot:
        active_consumer.download_failed("active_disk_tensor_download", reason)
[docs]
class TensorDownloadable(CacheableObject):
    """Cacheable wrapper around a tensor dict that is served one tensor per item."""

    def __init__(self, tensors: dict[str, torch.Tensor], max_chunk_size: int):
        """Record the tensor names and hand the dict to the base class.

        Args:
            tensors: mapping of tensor name to tensor that will be offered for download.
            max_chunk_size: passed through unchanged to ``CacheableObject``.
        """
        self.size = len(tensors)
        # Freeze the key order once so an item index always maps to the same tensor.
        self.keys = [*tensors.keys()]
        super().__init__(tensors, max_chunk_size)

    def get_item_count(self) -> int:
        """Return the number of items, i.e. the number of tensors in the dict."""
        return self.size

    def produce_item(self, index: int) -> bytes:
        """Serialize the tensor at position ``index`` as a single-entry safetensors blob.

        Args:
            index: position into the key list captured at construction time.

        Returns:
            The safetensors-encoded bytes for ``{key: tensor}``.
        """
        name = self.keys[index]
        # base_obj is not assigned in this class; presumably the base class stores the
        # dict passed to its constructor there - confirm against CacheableObject.
        return save_tensors({name: self.base_obj[name]})
[docs]
class TensorConsumer(ItemConsumer):
    """Item consumer that deserializes received safetensors blobs into memory."""

    def __init__(self, tensors_received_cb, cb_kwargs):
        """Store the optional per-batch callback and its keyword arguments.

        Args:
            tensors_received_cb: callable invoked with each batch of decoded tensors,
                or None to simply accumulate the tensors.
            cb_kwargs: keyword arguments passed to the callback on every invocation.

        Raises:
            ValueError: if a callback is given but is not callable.
        """
        ItemConsumer.__init__(self)
        self.tensors_received_cb = tensors_received_cb
        self.cb_kwargs = cb_kwargs
        if not (tensors_received_cb is None or callable(tensors_received_cb)):
            raise ValueError("tensors_received_cb must be callable")

    def consume_items(self, items: List[Any], result: Any) -> Any:
        """Decode one batch of items and fold it into the running result.

        Args:
            items: list of safetensors-encoded payloads.
            result: dict accumulated by earlier calls, or None on the first call.

        Returns:
            The accumulated dict. Without a callback the decoded tensors are merged
            in; with a callback, its return value is merged in only if it is a dict.

        Raises:
            TypeError: if ``items`` is not a list.
            ValueError: if a payload does not decode to a dict.
        """
        if not isinstance(items, list):
            raise TypeError(f"items must be list but got {type(items)}")

        merged = {} if result is None else result

        batch = {}
        for payload in items:
            decoded = load_tensors(payload)
            if not isinstance(decoded, dict):
                raise ValueError("cannot load received bytes to tensors")
            batch.update(decoded)

        callback = self.tensors_received_cb
        if not callback:
            merged.update(batch)
            return merged

        # With a callback, the raw tensors are never merged directly; only a dict
        # returned by the callback contributes to the result.
        outcome = callback(batch, **self.cb_kwargs)
        if isinstance(outcome, dict):
            merged.update(outcome)
        return merged
[docs]
def add_tensors(
    downloader: ObjectDownloader,
    tensors: dict[str, torch.Tensor],
    max_chunk_size: int = _TWO_MB,
) -> str:
    """Register a state dict with a downloader so that peers can fetch it.

    Args:
        downloader: the downloader that will serve the tensors.
        tensors: state dict to be made downloadable.
        max_chunk_size: max chunk size handed to the downloadable wrapper.

    Returns:
        The reference id returned by the downloader for this state dict.
    """
    return downloader.add_object(TensorDownloadable(tensors, max_chunk_size))
[docs]
def download_tensors(
    from_fqcn: str,
    ref_id: str,
    per_request_timeout: float,
    cell: Cell,
    secure=False,
    optional=False,
    abort_signal=None,
    tensors_received_cb=None,
    **cb_kwargs,
) -> Tuple[str, Optional[dict[str, torch.Tensor]]]:
    """Fetch the referenced state dict from its source into memory.

    Args:
        from_fqcn: FQCN of the data source.
        ref_id: reference ID of the state dict to be downloaded.
        per_request_timeout: timeout for each request sent to the data source.
        cell: cell used for communicating with the data source.
        secure: P2P private mode for communication.
        optional: suppress log messages of communication.
        abort_signal: signal for aborting the download.
        tensors_received_cb: callback invoked whenever one batch of tensors has
            been received and decoded.
        **cb_kwargs: extra keyword arguments forwarded to ``tensors_received_cb``.

    Returns:
        Tuple of (error message if any, downloaded state dict), both read from the
        consumer after the download call returns.
    """
    sink = TensorConsumer(tensors_received_cb, cb_kwargs)
    request = dict(
        from_fqcn=from_fqcn,
        ref_id=ref_id,
        consumer=sink,
        per_request_timeout=per_request_timeout,
        cell=cell,
        secure=secure,
        optional=optional,
        abort_signal=abort_signal,
    )
    download_object(**request)
    return sink.error, sink.result
def _extract_safetensors_keys(data: bytes) -> list[str]:
    """Return the tensor names listed in a safetensors payload's header.

    Only the header is parsed; tensor data is never deserialized. The layout relied
    on here is an 8-byte little-endian unsigned header length, followed by that many
    bytes of JSON whose top-level object is keyed by tensor name (plus an optional
    ``__metadata__`` entry, which is skipped).

    Args:
        data: serialized safetensors payload. Any byte-oriented bytes-like object is
            accepted (``bytes``, ``bytearray`` or ``memoryview``).

    Returns:
        Tensor names in header order, excluding ``__metadata__``.

    Raises:
        ValueError: if the payload is shorter than the length prefix, declares an
            empty header, declares a header that extends past the payload, or the
            header is not valid JSON / not a JSON object.
    """
    prefix_len = 8
    if len(data) < prefix_len:
        raise ValueError("Invalid safetensors data: too short")

    # Read the length prefix in place rather than slicing out a copy first.
    (header_size,) = struct.unpack_from("<Q", data)
    if header_size == 0:
        raise ValueError("Invalid safetensors data: empty header")

    header_end = prefix_len + header_size
    if header_end > len(data):
        raise ValueError("Invalid safetensors data: header size exceeds payload length")

    try:
        # json.loads only accepts str/bytes/bytearray. Without the bytes() conversion a
        # well-formed payload held in a memoryview would be misreported as bad JSON.
        header = json.loads(bytes(data[prefix_len:header_end]))
    except Exception as e:
        raise ValueError("Invalid safetensors data: invalid JSON header") from e

    if not isinstance(header, dict):
        raise ValueError("Invalid safetensors data: header must be JSON object")

    return [name for name in header if name != "__metadata__"]
[docs]
class DiskTensorConsumer(ItemConsumer):
    """Item consumer that spools raw safetensors payloads to files.

    Payloads are written to disk as received; only their headers are inspected, so
    no tensor is deserialized. The accumulated result maps each tensor name to a
    ``(file_path, key)`` pair.
    """

    def __init__(self, temp_dir: str):
        """Bind the consumer to its spool directory and register it as active.

        Args:
            temp_dir: directory into which the chunk files are written.
        """
        ItemConsumer.__init__(self)
        self._temp_dir = temp_dir
        self._cleaned = False
        self._file_counter = 0
        with _ACTIVE_DISK_TENSOR_CONSUMERS_LOCK:
            _ACTIVE_DISK_TENSOR_CONSUMERS.add(self)

    def release(self) -> None:
        """Unregister the consumer while leaving its files on disk.

        The ``_cleaned`` flag is set so that any later ``cleanup()`` call returns
        without touching the directory.
        """
        with _ACTIVE_DISK_TENSOR_CONSUMERS_LOCK:
            self._cleaned = True
            _ACTIVE_DISK_TENSOR_CONSUMERS.discard(self)

    def cleanup(self) -> None:
        """Unregister the consumer and clean up its spool directory, at most once."""
        with _ACTIVE_DISK_TENSOR_CONSUMERS_LOCK:
            first_call = not self._cleaned
            if first_call:
                self._cleaned = True
                _ACTIVE_DISK_TENSOR_CONSUMERS.discard(self)

        # The directory cleanup itself runs after the registry lock has been released.
        if first_call:
            _cleanup_temp_dir(self._temp_dir)

    def _spool_chunk(self, payload) -> str:
        """Write one payload to the next ``chunk_<n>.safetensors`` file; return its path."""
        chunk_path = os.path.join(self._temp_dir, f"chunk_{self._file_counter}.safetensors")
        self._file_counter += 1
        with open(chunk_path, "wb") as sink:
            sink.write(payload)
        return chunk_path

    def consume_items(self, items: List[Any], result: Any) -> Any:
        """Spool a batch of payloads to disk and index the tensors they contain.

        Args:
            items: list of raw safetensors payloads.
            result: dict accumulated by earlier calls, or None on the first call.

        Returns:
            The accumulated dict mapping tensor name to ``(file_path, key)``.

        Raises:
            TypeError: if ``items`` is not a list.
            ValueError: if a payload header is invalid, or a tensor name was already
                recorded from an earlier payload.
        """
        if not isinstance(items, list):
            raise TypeError(f"items must be list but got {type(items)}")

        index = {} if result is None else result

        for payload in items:
            # Parse the header before writing, so an invalid payload creates no file.
            tensor_names = _extract_safetensors_keys(payload)
            chunk_path = self._spool_chunk(payload)
            for name in tensor_names:
                if name in index:
                    raise ValueError(
                        f"Duplicate tensor key '{name}' seen in multiple safetensors chunks; "
                        "streaming data may be malformed."
                    )
                index[name] = (chunk_path, name)
        return index

    def download_failed(self, ref_id, reason: str):
        """Let the base class record the failure, then clean up right away.

        Args:
            ref_id: reference id of the failed download.
            reason: description of the failure.
        """
        super().download_failed(ref_id, reason)
        # Clean up eagerly here even though the outer caller may call cleanup() again
        # on its consumer.error path: a repeated call is a no-op, because cleanup()
        # returns early once the _cleaned flag has been set.
        self.cleanup()
[docs]
def download_tensors_to_disk(
    from_fqcn: str,
    ref_id: str,
    per_request_timeout: float,
    cell: Cell,
    secure=False,
    optional=False,
    abort_signal=None,
) -> Tuple[str, Optional[LazyTensorDict]]:
    """Fetch the referenced tensors into files on disk rather than into memory.

    A fresh ``nvflare_tensors_*`` directory is created under the offload root taken
    from the cell's FOBS context, and received payloads are spooled into it.

    Args:
        from_fqcn: FQCN of the data source.
        ref_id: reference ID of the tensors to be downloaded.
        per_request_timeout: timeout for each request sent to the data source.
        cell: cell used for communication; also supplies the FOBS context.
        secure: P2P private mode for communication.
        optional: suppress log messages of communication.
        abort_signal: signal for aborting the download.

    Returns:
        Tuple of (error message if any, LazyTensorDict for lazy access). On error
        the second element is None; on success the first element is None.

    Raises:
        RuntimeError: if the offload root directory is not set in the FOBS context.
    """
    root_dir = cell.get_fobs_context().get(_TENSOR_DISK_OFFLOAD_ROOT_DIR)
    if not root_dir:
        raise RuntimeError(f"{_TENSOR_DISK_OFFLOAD_ROOT_DIR} is not set in FOBS context")

    temp_dir = tempfile.mkdtemp(prefix="nvflare_tensors_", dir=root_dir)
    consumer = DiskTensorConsumer(temp_dir)

    request = dict(
        from_fqcn=from_fqcn,
        ref_id=ref_id,
        consumer=consumer,
        per_request_timeout=per_request_timeout,
        cell=cell,
        secure=secure,
        optional=optional,
        abort_signal=abort_signal,
    )
    try:
        download_object(**request)
    except Exception:
        # Do not leave a partially filled directory behind when the download raises.
        consumer.cleanup()
        raise

    if not consumer.error:
        key_to_file = consumer.result
        if key_to_file is None:
            key_to_file = {}
        # Unregister without cleaning: temp_dir is handed to the LazyTensorDict.
        consumer.release()
        return None, LazyTensorDict(key_to_file=key_to_file, temp_dir=temp_dir)

    consumer.cleanup()
    return consumer.error, None