# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import time
import uuid
from abc import ABC, abstractmethod
from typing import Any, Optional, Tuple
from nvflare.apis.signal import Signal
from nvflare.fuel.f3.cellnet.cell import Cell
from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey, ReturnCode
from nvflare.fuel.f3.cellnet.utils import make_reply, new_cell_message
from nvflare.fuel.f3.message import Message
from nvflare.fuel.utils.log_utils import get_obj_logger
from nvflare.security.logging import secure_format_exception
# Cellnet channel/topic pair: DownloadService registers its request handler on it,
# and download_object() sends every chunk request to it.
OBJ_DOWNLOADER_CHANNEL: str = "download_service__"
OBJ_DOWNLOADER_TOPIC: str = "download_service__download"
"""
This package provides a framework for building object downloading capability (file download, tensor download, etc.).
A large object takes a lot of memory space. Sending a large object in one message needs even more memory space since
the object needs to be serialized into a large number of bytes. Additional memory space may still be needed for the
transport layer to send the message. If the message is to be sent to multiple endpoints, even more memory is needed.
Object Downloading can drastically reduce memory consumption:
- Instead of sending the large object in one message, it is divided into many smaller objects;
- Instead of pushing the message to the endpoints, each endpoint requests the data itself. This makes it more reliable
when different endpoints have different speeds.
Object Downloading works as follows:
- The sender prepares the object(s) for downloading. It first creates a transaction to get a tx_id. It then adds each
object to be downloaded (called a Downloadable) to the transaction and gets a reference id (ref_id) for it.
- The sender sends the ref_id(s) to all recipients through a separate message.
- Each recipient then calls the download_object function to download each referenced large object.
Note that the endpoint that received object refs may forward the refs to another endpoint, which then downloads the
referenced object(s).
To develop the downloading capability for a type of object (e.g. a file, a tensor state dict, etc.), you need to provide
the implementation of a Downloadable and a Consumer.
- On the sending side, the Downloadable is responsible for producing the next small object to be sent (a chunk of bytes;
a small subset of the large dict; etc.).
- On the receiving side, the Consumer is responsible for processing the received small objects (writing the received
bytes to a temp file; putting the received small dict to the end result; etc.).
One issue with object downloading is object life cycle management. Since the large objects to be downloaded are usually
temporary, you need to remove them when they are downloaded by all receivers. But the problem is that you don't know how
quickly each receiver can finish downloading these large objects. When a transaction contains multiple objects to be
downloaded, it's even harder to know when all of them are done.
There are two ways to handle this issue: object downloaded callback, and transaction timeout.
You can implement the downloaded_to_one method for the Downloadable object. This method is called when the object is
downloaded to one receiver.
You can also implement the downloaded_to_all method for the Downloadable object. This method is called when the object
is downloaded to all receivers.
Note that the downloaded_to_all method only works if you know how many receivers the object will be downloaded to!
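You can always implement the transaction_done method for the Downloadable object. This method is called when the
transaction is done for some reason (normal completion, timeout, or deletion). After transaction_done has been called
on every object (and the transaction's done-callback, if any, has run), the release method is called on each object so
it can drop its reference to the large object.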
Transaction timeout is measured from the last download activity on any object in the transaction, from any
receiver. For example, suppose you want to send 2 large files to 3 receivers: each time a download request is
received for either file from any of the 3 receivers, the last activity time of the transaction is updated to now.
If no downloading activity is received from any receiver on any objects in the transaction for the specified timeout,
the transaction is considered "timed out", and the transaction_done method is called for each Downloadable object
added to the transaction.
Unlike the Object Streamer, where the object owner pushes small objects to the recipients, with the Object Downloader
each recipient pulls the data from the object owner.
"""
class Downloadable(ABC):
    """Sender-side wrapper around a large object that is served to receivers piece by piece.

    Subclasses must implement ``produce``. All other methods are optional
    life-cycle hooks whose default implementations do nothing.
    """

    def __init__(self, obj: Any):
        # The wrapped large object; it is handed to the transaction-done callback.
        self.base_obj = obj

    def set_transaction(self, tx_id: str, ref_id: str):
        """Hook invoked right after the object has been added to a transaction.

        Override it to remember the transaction ID and/or the ref ID for your own use.

        Args:
            tx_id: ID of the transaction the object now belongs to.
            ref_id: ref ID generated (or supplied) for the object.

        Returns: None
        """

    @abstractmethod
    def produce(self, state: dict, requester: str) -> Tuple[str, Any, dict]:
        """Produce the next small object to send to a receiver (sender side).

        Args:
            state: the download state sent back by the receiver; None on the receiver's first request.
            requester: FQCN of the receiver that is downloading.

        Returns: a tuple of (ProduceRC value, small object to send, new state to send).
            Any return code other than ProduceRC.OK ends the download for this receiver.
        """

    def downloaded_to_one(self, to_receiver: str, status: str):
        """Hook invoked once per receiver when its download of this object ends.

        Args:
            to_receiver: name of the receiver whose download has ended.
            status: DownloadStatus.SUCCESS or DownloadStatus.FAILED.

        Returns: None
        """

    def downloaded_to_all(self):
        """Hook invoked once, after every expected receiver has finished downloading the object.

        It only fires when the transaction was created with a known (positive) number of receivers.
        """

    def transaction_done(self, transaction_id: str, status: str):
        """Hook invoked when the owning transaction ends.

        Args:
            transaction_id: ID of the transaction.
            status: completion status, a TransactionDoneStatus value.

        Returns: None
        """

    def release(self):
        """Drop the infrastructure reference to the source object.

        The transaction calls this after its transaction-done callback has run.
        Override it to clear ``base_obj`` (or any other large reference) so the
        memory can be reclaimed promptly. The default implementation does nothing.
        """
class _PropKey:
REF_ID = "ref_id"
STATE = "state"
DATA = "data"
STATUS = "status"
class _Ref:
    """Bookkeeping record for one Downloadable that has been added to a transaction."""

    def __init__(
        self,
        tx,
        obj: Downloadable,
        ref_id=None,
    ):
        # Use the caller-supplied ref ID when given; otherwise mint a unique one.
        self.rid = ref_id if ref_id else "R" + str(uuid.uuid4())
        self.tx = tx
        self.obj = obj
        # receiver name -> DownloadStatus; each receiver is recorded at most once.
        self.receiver_statuses = {}
        self.num_receivers_done = 0
        self._downloaded_to_all_called = False

    def mark_active(self):
        """Refresh the owning transaction's last-activity timestamp."""
        self.tx.mark_active()

    def obj_downloaded(self, to_receiver: str, status: str):
        """Record that `to_receiver` finished downloading this object with `status`.

        Repeated reports for the same receiver are ignored. Calls the object's
        downloaded_to_one hook for the first report from each receiver. Once the number of distinct
        receivers reaches the transaction's positive num_receivers, it calls downloaded_to_all
        (at most once).
        """
        statuses = self.receiver_statuses
        if to_receiver in statuses:
            return

        statuses[to_receiver] = status
        self.num_receivers_done = len(statuses)

        target = self.obj
        assert isinstance(target, Downloadable)
        target.downloaded_to_one(to_receiver, status)

        owner = self.tx
        assert isinstance(owner, _Transaction)
        everyone_done = 0 < owner.num_receivers <= self.num_receivers_done
        if everyone_done and not self._downloaded_to_all_called:
            self._downloaded_to_all_called = True
            target.downloaded_to_all()
class ProduceRC:
    """Return codes of Downloadable.produce()."""

    OK: str = "ok"  # a piece was produced; more may follow
    ERROR: str = "error"  # producing failed; the download ends as FAILED
    EOF: str = "eof"  # nothing left to send; the download ends as SUCCESS
class DownloadStatus:
    """Per-receiver outcome of downloading one object."""

    SUCCESS: str = "success"
    FAILED: str = "failed"
class TransactionDoneStatus:
    """Reason a transaction ended."""

    FINISHED: str = "finished"  # every object downloaded to every expected receiver
    TIMEOUT: str = "timeout"  # no download activity within the transaction timeout
    DELETED: str = "deleted"  # removed explicitly (delete_transaction / shutdown)
class _FinishedRef:
def __init__(self, receiver_statuses: dict[str, str], timestamp: float):
self.receiver_statuses = receiver_statuses
self.last_active_time = timestamp
def expired(self, now: float, ttl: float) -> bool:
return now - self.last_active_time > ttl
class _Transaction:
    """A group of Downloadables sharing one inactivity timeout and one completion callback."""

    def __init__(
        self,
        timeout: float,
        num_receivers: int,
        tx_id=None,
        transaction_done_cb=None,
        cb_kwargs=None,
    ):
        """Constructor of the transaction object.

        Args:
            timeout: seconds since the last download activity after which the transaction times out.
            num_receivers: number of receivers. 0 (or negative) means unknown/unlimited; the
                transaction then ends only by timeout or deletion.
            tx_id: if provided, use it; otherwise create one.
            transaction_done_cb: optional callable invoked as
                ``transaction_done_cb(tx_id, status, base_objs, **cb_kwargs)`` when the transaction is done.
            cb_kwargs: dict of extra keyword args for ``transaction_done_cb``; None is treated as empty.
        """
        self.tid = tx_id if tx_id else "T" + str(uuid.uuid4())
        self.timeout = timeout
        self.num_receivers = num_receivers
        now = time.time()
        self.last_active_time = now
        self.start_time = now
        self.total_bytes = 0  # bytes produced so far; only used for the completion log line
        self.transaction_done_cb = transaction_done_cb
        # Never None, so the callback can always be invoked with **self.cb_kwargs.
        self.cb_kwargs = cb_kwargs if cb_kwargs is not None else {}
        self.refs = []
        self.logger = get_obj_logger(self)
        # transaction_done() can be reached from several threads (monitor timeout vs. explicit
        # delete); this flag makes the completion sequence run exactly once.
        self._done = False
        self._done_lock = threading.Lock()

    def mark_active(self):
        """Update the last active time of the transaction to now."""
        self.last_active_time = time.time()

    def add_object(
        self,
        obj: Downloadable,
        ref_id=None,
    ):
        """Add a large object (to be downloaded) to the transaction.

        Args:
            obj: the large object to be downloaded.
            ref_id: the ref id to be used, if specified.

        Returns: the new _Ref record for the object.
        """
        r = _Ref(self, obj, ref_id)
        self.refs.append(r)
        obj.set_transaction(self.tid, r.rid)
        return r

    def timed_out(self):
        """Finish the transaction with TransactionDoneStatus.TIMEOUT."""
        self.transaction_done(TransactionDoneStatus.TIMEOUT)

    def is_finished(self):
        """Return True when every object has been downloaded by num_receivers receivers.

        Always False when the number of receivers is unknown (<= 0) or when no object has been
        added yet. In the latter case the owner may still be about to add objects. Treating an
        empty transaction as finished would let the monitor delete it under the owner's feet.
        """
        if self.num_receivers <= 0:
            return False
        if not self.refs:
            return False
        return all(ref.num_receivers_done >= self.num_receivers for ref in self.refs)

    def transaction_done(self, status: str):
        """Finish the transaction: notify every object, run the callback, then release the objects.

        Only the first call has any effect. A failure in one object's hook or in the callback is
        logged and does not prevent the remaining objects from being notified and released.

        Args:
            status: a TransactionDoneStatus value.
        """
        with self._done_lock:
            if self._done:
                return
            self._done = True

        elapsed = time.time() - self.start_time
        size_mb = self.total_bytes / (1024 * 1024)
        self.logger.info(
            f"[server] download tx {self.tid} done: status={status} elapsed={elapsed:.2f}s "
            f"size={size_mb:.1f}MB ({self.total_bytes:,} bytes)"
        )

        # Snapshot the refs and their base_objs BEFORE notifying, so the callback receives the
        # original objects. obj.transaction_done() may clear chunk caches; the source object
        # itself is only released via obj.release() AFTER the callback, so the callback can
        # still observe it (e.g. for memory-GC notifications).
        refs = list(self.refs)
        base_objs = [ref.obj.base_obj for ref in refs]

        for ref in refs:
            try:
                ref.obj.transaction_done(self.tid, status)
            except Exception as ex:
                self.logger.error(
                    f"transaction_done of {type(ref.obj)} failed for tx {self.tid}: {secure_format_exception(ex)}"
                )

        if self.transaction_done_cb:
            try:
                self.transaction_done_cb(self.tid, status, base_objs, **self.cb_kwargs)
            except Exception as ex:
                self.logger.error(f"transaction_done_cb failed for tx {self.tid}: {secure_format_exception(ex)}")

        # Drop the last infrastructure references to the large objects so GC can reclaim them,
        # even if a hook or the callback above failed.
        for ref in refs:
            try:
                ref.obj.release()
            except Exception as ex:
                self.logger.error(
                    f"release of {type(ref.obj)} failed for tx {self.tid}: {secure_format_exception(ex)}"
                )
class TransactionInfo:
    """Public view of a transaction, captured when this object is constructed.

    Attributes:
        timeout: inactivity timeout of the transaction, in seconds.
        num_receivers: number of receivers the objects will be downloaded to; 0 means unknown.
        objects: the Downloadable objects added to the transaction, in insertion order.
    """

    def __init__(self, tx: _Transaction):
        self.timeout, self.num_receivers = tx.timeout, tx.num_receivers
        self.objects = [ref.obj for ref in tx.refs]
class DownloadService:
    """Process-wide service that serves Downloadable objects to remote requesters.

    All state lives in class attributes shared by every cell that uses the service.
    ``_tx_lock`` protects ``_tx_table``, ``_ref_table`` and ``_finished_refs``. User-supplied
    hooks (Downloadable methods, transaction-done callbacks) are never invoked while that lock
    is held, so a hook may safely call back into DownloadService.
    """

    _init_lock = threading.Lock()
    _tx_table = {}  # tx_id -> _Transaction
    _ref_table = {}  # ref_id -> _Ref
    # Ref tombstones let a client retry a lost/delayed EOF reply after the source
    # transaction has been cleaned up without turning a completed transfer into a fatal missing-ref error.
    _finished_refs = {}  # ref_id -> _FinishedRef
    _FINISHED_REFS_TTL = 1800.0  # seconds a tombstone is kept
    _MONITOR_INTERVAL = 5.0  # seconds between transaction-monitor sweeps
    _logger = None
    _tx_monitor = None
    _tx_lock = threading.Lock()
    _initialized_cells = {}  # id(cell) -> True once the download handler is registered on it

    @classmethod
    def _initialize(cls, cell: Cell):
        """Create the logger and monitor thread once, and register the download handler on `cell` once."""
        with cls._init_lock:
            if not cls._logger:
                cls._logger = get_obj_logger(cls)
            if not cls._tx_monitor:
                cls._tx_monitor = threading.Thread(target=cls._monitor_tx, daemon=True)
                cls._tx_monitor.start()
            if not cls._initialized_cells.get(id(cell)):
                cell.register_request_cb(
                    channel=OBJ_DOWNLOADER_CHANNEL,
                    topic=OBJ_DOWNLOADER_TOPIC,
                    cb=cls._handle_download,
                )
                cls._initialized_cells[id(cell)] = True

    @classmethod
    def new_transaction(
        cls,
        cell: Cell,
        timeout: float,
        num_receivers: int = 0,
        tx_id=None,
        transaction_done_cb=None,
        **cb_kwargs,
    ):
        """Create and register a new download transaction.

        Args:
            cell: the cell that serves download requests for this transaction.
            timeout: seconds of inactivity after which the transaction times out.
            num_receivers: number of receivers; 0 means unknown.
            tx_id: transaction ID to use; generated if not provided.
            transaction_done_cb: optional callable invoked as
                ``transaction_done_cb(tx_id, status, base_objs, **cb_kwargs)`` when the transaction is done.
            **cb_kwargs: extra keyword arguments passed to ``transaction_done_cb``.

        Returns: the transaction ID.
        """
        cls._initialize(cell)
        tx = _Transaction(timeout, num_receivers, tx_id, transaction_done_cb, cb_kwargs)
        with cls._tx_lock:
            cls._tx_table[tx.tid] = tx
        return tx.tid

    @classmethod
    def add_object(
        cls,
        transaction_id: str,
        obj: Downloadable,
        ref_id=None,
    ) -> str:
        """Add `obj` to a transaction and return the ref ID that receivers use to download it.

        Raises:
            ValueError: if `obj` is not a Downloadable, or if the transaction does not exist
                (including when it ended while the object was being added).
        """
        if not isinstance(obj, Downloadable):
            raise ValueError(f"obj must be of type {Downloadable} but got {type(obj)}")
        with cls._tx_lock:
            tx = cls._tx_table.get(transaction_id)
        if not tx:
            raise ValueError(f"no such transaction {transaction_id}")

        # tx.add_object() calls the user hook obj.set_transaction(); keep it outside the lock.
        ref = tx.add_object(obj, ref_id)
        with cls._tx_lock:
            if cls._tx_table.get(transaction_id) is not tx:
                # The transaction ended meanwhile; registering the ref now would leave an
                # orphan entry in _ref_table that nothing ever removes.
                raise ValueError(f"no such transaction {transaction_id}")
            cls._ref_table[ref.rid] = ref
            cls._finished_refs.pop(ref.rid, None)
        return ref.rid

    @classmethod
    def delete_transaction(cls, transaction_id: str):
        """Remove a transaction and finish it with TransactionDoneStatus.DELETED (no-op if unknown)."""
        with cls._tx_lock:
            tx = cls._tx_table.get(transaction_id)
            if tx:
                cls._delete_tx(tx)
        if tx:
            tx.transaction_done(TransactionDoneStatus.DELETED)

    @classmethod
    def shutdown(cls):
        """Shutdown and clean up resources.

        Every pending transaction is removed and finished with TransactionDoneStatus.DELETED;
        all ref tombstones are dropped.

        Returns: None
        """
        with cls._tx_lock:
            tx_list = list(cls._tx_table.values())
            for tx in tx_list:
                cls._delete_tx(tx)
            cls._finished_refs.clear()
        for tx in tx_list:
            cls._safe_transaction_done(tx, TransactionDoneStatus.DELETED)

    @classmethod
    def _delete_tx(cls, tx: _Transaction, tombstone_finished_refs: bool = False):
        """Remove `tx` and its refs from the lookup tables. Caller must hold ``_tx_lock``.

        With `tombstone_finished_refs`, each ref's per-receiver statuses are kept as a
        _FinishedRef so late retries can still be answered; otherwise existing tombstones are dropped.
        """
        cls._tx_table.pop(tx.tid, None)
        now = time.time() if tombstone_finished_refs else None
        for r in tx.refs:
            cls._ref_table.pop(r.rid, None)
            if tombstone_finished_refs:
                cls._finished_refs[r.rid] = _FinishedRef(dict(r.receiver_statuses), now)
            else:
                cls._finished_refs.pop(r.rid, None)

    @classmethod
    def _expire_finished_refs(cls, now: float):
        """Drop tombstones older than _FINISHED_REFS_TTL. Caller must hold ``_tx_lock``."""
        if not cls._finished_refs:
            return
        expired_refs = [
            rid for rid, finished_ref in cls._finished_refs.items() if finished_ref.expired(now, cls._FINISHED_REFS_TTL)
        ]
        for rid in expired_refs:
            cls._finished_refs.pop(rid, None)

    @classmethod
    def _get_finished_ref_status(cls, rid: str, requester: str) -> Optional[str]:
        """Return the recorded DownloadStatus of `requester` for a finished ref, or None.

        Caller must hold ``_tx_lock``. An expired tombstone is removed and yields None.
        """
        now = time.time()
        finished_ref = cls._finished_refs.get(rid)
        if not finished_ref:
            return None
        if finished_ref.expired(now, cls._FINISHED_REFS_TTL):
            cls._finished_refs.pop(rid, None)
            return None
        return finished_ref.receiver_statuses.get(requester)

    @classmethod
    def get_transaction_info(cls, transaction_id: str) -> Optional[TransactionInfo]:
        """Return public info of the transaction, or None if it does not exist."""
        with cls._tx_lock:
            tx = cls._tx_table.get(transaction_id)
        if not tx:
            return None
        return TransactionInfo(tx)

    @classmethod
    def get_transaction_id(cls, ref_id: str) -> Optional[str]:
        """Return the ID of the transaction that owns `ref_id`, or None if the ref is unknown."""
        with cls._tx_lock:
            ref = cls._ref_table.get(ref_id)
        if not ref:
            return None
        return ref.tx.tid

    @classmethod
    def _handle_download(cls, request: Message) -> Message:
        """Cell request handler: produce the next piece of the referenced object for the requester."""
        requester = request.get_header(MessageHeaderKey.ORIGIN)
        payload = request.payload
        if not isinstance(payload, dict):
            # The payload comes from the network; reject it instead of asserting.
            cls._logger.error(f"invalid download request from {requester}: expect dict but got {type(payload)}")
            return make_reply(ReturnCode.INVALID_REQUEST)

        rid = payload.get(_PropKey.REF_ID)
        if not rid:
            cls._logger.error(f"missing {_PropKey.REF_ID} in request from {requester}")
            return make_reply(ReturnCode.INVALID_REQUEST)
        current_state = payload.get(_PropKey.STATE)

        with cls._tx_lock:
            ref = cls._ref_table.get(rid)
            # Tombstones are mutated by the monitor, so they are consulted under the lock too.
            finished_status = None if ref else cls._get_finished_ref_status(rid, requester)

        if not ref:
            if finished_status == DownloadStatus.SUCCESS:
                cls._logger.debug(f"finished ref {rid} from {requester} retried - returning EOF")
                return make_reply(ReturnCode.OK, body={_PropKey.STATUS: ProduceRC.EOF})
            elif finished_status == DownloadStatus.FAILED:
                cls._logger.debug(f"finished ref {rid} from {requester} retried - returning ERROR")
                return make_reply(ReturnCode.OK, body={_PropKey.STATUS: ProduceRC.ERROR})
            cls._logger.error(f"no ref found for {rid} from {requester}")
            return make_reply(ReturnCode.INVALID_REQUEST)

        ref.mark_active()
        tx = ref.tx
        try:
            rc, data, new_state = ref.obj.produce(current_state, requester)
        except Exception as ex:
            cls._logger.error(
                f"Object {type(ref.obj)} encountered exception when produce: {secure_format_exception(ex)}"
            )
            return make_reply(ReturnCode.PROCESS_EXCEPTION)

        if rc != ProduceRC.OK:
            # this receiver is done with the object (EOF = success, anything else = failure)
            ref.obj_downloaded(
                requester, status=DownloadStatus.SUCCESS if rc == ProduceRC.EOF else DownloadStatus.FAILED
            )
            return make_reply(ReturnCode.OK, body={_PropKey.STATUS: rc})

        # Accumulate bytes for the timing summary in transaction_done().
        # CacheableObject returns a list of byte-chunks; FileDownloader returns raw bytes.
        # Sum chunk lengths for lists (len(list) counts items, not bytes).
        if data is not None:
            tx.total_bytes += sum(len(c) for c in data) if isinstance(data, list) else len(data)
        return make_reply(
            ReturnCode.OK,
            body={
                _PropKey.STATUS: rc,
                _PropKey.STATE: new_state,
                _PropKey.DATA: data,
            },
        )

    @classmethod
    def _safe_transaction_done(cls, tx: _Transaction, status: str):
        """Run tx.transaction_done(status), logging instead of propagating any exception."""
        try:
            tx.transaction_done(status)
        except Exception as ex:
            if cls._logger:
                cls._logger.error(
                    f"error finishing transaction {tx.tid} with status {status}: {secure_format_exception(ex)}"
                )

    @classmethod
    def _check_transactions(cls):
        """One monitor sweep: end finished and timed-out transactions and expire old tombstones."""
        now = time.time()
        expired_tx = []
        finished_tx = []
        with cls._tx_lock:
            # Iterate over a copy because _delete_tx() removes entries from _tx_table.
            for tx in list(cls._tx_table.values()):
                if tx.is_finished():
                    finished_tx.append(tx)
                    cls._delete_tx(tx, tombstone_finished_refs=True)
                elif now - tx.last_active_time > tx.timeout:
                    expired_tx.append(tx)
                    cls._delete_tx(tx)
            cls._expire_finished_refs(now)

        # The transactions are no longer reachable by requests or delete_transaction(), so
        # their user hooks can run outside the lock and exactly once.
        for tx in expired_tx:
            cls._safe_transaction_done(tx, TransactionDoneStatus.TIMEOUT)
        for tx in finished_tx:
            cls._safe_transaction_done(tx, TransactionDoneStatus.FINISHED)

    @classmethod
    def _monitor_tx(cls):
        """Body of the daemon monitor thread: sweep transactions periodically, forever."""
        while True:
            try:
                cls._check_transactions()
            except Exception as ex:
                # Never let an unexpected error kill the monitor; otherwise no transaction
                # would ever be cleaned up again.
                cls._logger.error(f"exception in download transaction monitor: {secure_format_exception(ex)}")
            time.sleep(cls._MONITOR_INTERVAL)
class Consumer(ABC):
    """Receiver-side sink that processes the pieces of a downloaded object."""

    def __init__(self):
        self.logger = get_obj_logger(self)

    @abstractmethod
    def consume(self, ref_id: str, state: dict, data: Any) -> dict:
        """Process one received piece.

        Args:
            ref_id: ref ID of the object being downloaded.
            state: download state sent by the data owner along with `data`.
            data: the received piece.

        Returns: the new state (a dict) to send back to the data owner with the next request.
        """

    @abstractmethod
    def download_completed(self, ref_id: str):
        """Called once the download has finished successfully.

        Args:
            ref_id: ref ID of the object being downloaded.

        Returns: None
        """

    @abstractmethod
    def download_failed(self, ref_id: str, reason: str):
        """Called once the download has ended unsuccessfully.

        Args:
            ref_id: ref ID of the object being downloaded.
            reason: human-readable explanation of the failure.

        Returns: None
        """
def _sleep_unless_aborted(seconds: float, abort_signal) -> bool:
    """Sleep for up to `seconds`, waking early if `abort_signal` gets triggered.

    Returns True if the signal was (or became) triggered, False if the full time elapsed.
    """
    deadline = time.time() + seconds
    while True:
        if abort_signal and abort_signal.triggered:
            return True
        remaining = deadline - time.time()
        if remaining <= 0:
            return False
        # Poll in short slices so an abort is noticed promptly even during a long backoff.
        time.sleep(min(remaining, 0.5))


def download_object(
    from_fqcn: str,
    ref_id: str,
    per_request_timeout: float,
    cell: Cell,
    consumer: Consumer,
    secure=False,
    optional=False,
    abort_signal: Signal = None,
    max_retries: int = 3,
):
    """Download a large object from the object owner.

    Repeatedly requests the next piece from the owner's DownloadService and feeds it to
    `consumer`, until the owner reports EOF (-> consumer.download_completed) or the download
    fails or is aborted (-> consumer.download_failed).

    Args:
        from_fqcn: the FQCN of the object owner
        ref_id: reference id of the object to be downloaded
        per_request_timeout: timeout for each request to the object owner.
        cell: the cell to be used for communication with the object owner.
        consumer: the Consumer object used for processing received data
        secure: use P2P private communication with the data owner
        optional: suppress log messages
        abort_signal: for signaling abort; also checked while backing off between retries
        max_retries: max number of retries per request on TIMEOUT (default 3).
            Resending the same state causes the producer to re-generate the
            same chunk, so retry is data-safe. Note: CacheableObject's
            _adjust_cache may run twice for the same state on retry, which
            can prematurely evict cache entries in multi-receiver scenarios
            but does not affect data correctness.

    Returns: None

    Raises:
        ValueError: if max_retries is negative.
    """
    logger = get_obj_logger(download_object)
    if max_retries < 0:
        raise ValueError(f"max_retries must be non-negative, got {max_retries}")

    consecutive_timeouts = 0
    total_bytes = 0
    download_start = time.time()
    # Current download state (None = initial request).
    # On retry, the same state is resent so the producer re-generates the same chunk.
    current_state = None

    while True:
        # Build a fresh request each iteration (including retries)
        # to avoid re-encoding an already-encoded message.
        request_payload = {_PropKey.REF_ID: ref_id}
        if current_state is not None:
            request_payload[_PropKey.STATE] = current_state
        request = new_cell_message(headers={}, payload=request_payload)

        start_time = time.time()
        reply = cell.send_request(
            channel=OBJ_DOWNLOADER_CHANNEL,
            target=from_fqcn,
            topic=OBJ_DOWNLOADER_TOPIC,
            request=request,
            timeout=per_request_timeout,
            secure=secure,
            optional=optional,
            abort_signal=abort_signal,
        )
        duration = time.time() - start_time

        if abort_signal and abort_signal.triggered:
            consumer.download_failed(ref_id, f"download aborted after {duration} secs")
            return

        assert isinstance(reply, Message)
        rc = reply.get_header(MessageHeaderKey.RETURN_CODE)
        if rc != ReturnCode.OK:
            # Retry on TIMEOUT: streaming transport may intermittently lose
            # responses. Resending the same state re-generates the same
            # chunk, making retry data-safe (see docstring for caveats).
            if rc == ReturnCode.TIMEOUT and consecutive_timeouts < max_retries:
                consecutive_timeouts += 1
                backoff = min(2.0 * (2 ** (consecutive_timeouts - 1)), 60.0)
                logger.warning(
                    f"[DOWNLOAD_RETRY] Request to {from_fqcn} timed out after {duration:.1f}s "
                    f"(ref={ref_id}, retry {consecutive_timeouts}/{max_retries}, "
                    f"backoff={backoff:.1f}s). Resending same state to re-request the chunk."
                )
                if _sleep_unless_aborted(backoff, abort_signal):
                    consumer.download_failed(ref_id, f"download aborted after {duration} secs")
                    return
                continue

            if rc == ReturnCode.TIMEOUT:
                logger.warning(
                    f"[DOWNLOAD_FAILED] Max retries ({max_retries}) exhausted for {from_fqcn}, "
                    f"ref={ref_id}. Giving up."
                )
            consumer.download_failed(ref_id, f"error requesting data from {from_fqcn} after {duration} secs: {rc}")
            return

        # Log recovery if we were retrying
        if consecutive_timeouts > 0:
            logger.warning(
                f"[DOWNLOAD_RECOVERED] Download from {from_fqcn} recovered after "
                f"{consecutive_timeouts} timeout(s) (ref={ref_id})."
            )
            consecutive_timeouts = 0

        payload = reply.payload
        assert isinstance(payload, dict)
        status = payload.get(_PropKey.STATUS)
        if status == ProduceRC.EOF:
            elapsed = time.time() - download_start
            size_mb = total_bytes / (1024 * 1024)
            logger.info(
                f"[client] download ref={ref_id} done: elapsed={elapsed:.2f}s "
                f"size={size_mb:.1f}MB ({total_bytes:,} bytes)"
            )
            consumer.download_completed(ref_id)
            return
        elif status == ProduceRC.ERROR:
            consumer.download_failed(ref_id, f"producer error after {duration} secs")
            return

        # More data follows.
        # CacheableObject sends a list of byte-chunks; FileDownloader sends raw bytes.
        data = payload.get(_PropKey.DATA)
        if data is not None:
            total_bytes += sum(len(c) for c in data) if isinstance(data, list) else len(data)
        state = payload.get(_PropKey.STATE)
        try:
            new_state = consumer.consume(ref_id, state, data)
        except Exception as ex:
            consumer.download_failed(ref_id, f"exception when consuming data: {secure_format_exception(ex)}")
            return

        if not isinstance(new_state, dict):
            consumer.download_failed(ref_id, f"consumer error: new_state should be dict but got {type(new_state)}")
            return

        if abort_signal and abort_signal.triggered:
            consumer.download_failed(ref_id, "download aborted")
            return

        # Update state for next request
        current_state = new_state