Source code for nvflare.fuel.f3.streaming.obj_downloader

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import time
import uuid
from abc import ABC, abstractmethod
from typing import Any

from nvflare.apis.signal import Signal
from nvflare.fuel.f3.cellnet.cell import Cell
from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey, ReturnCode
from nvflare.fuel.f3.cellnet.utils import make_reply, new_cell_message
from nvflare.fuel.f3.message import Message
from nvflare.fuel.utils.log_utils import get_obj_logger
from nvflare.fuel.utils.validation_utils import check_callable, check_object_type
from nvflare.security.logging import secure_format_exception

OBJ_DOWNLOADER_CHANNEL = "obj_downloader__"
OBJ_DOWNLOADER_TOPIC = "obj_downloader__download"

"""
This package provides a framework for building object downloading capability (e.g. file download).

A large object takes a lot of memory space. Sending a large object in one message needs even more memory space since
the object needs to be serialized into large number of bytes. Additional memory space may still be needed for the
transport layer to send the message. If the message is to be sent to multiple endpoints, even more memory is needed.

Object Downloading can drastically reduce memory consumption:
- Instead of sending the large object in one message, it is divided into many smaller objects;
- Instead of pushing the message to the endpoints, each endpoint will come to request. This makes it more reliable when
different endpoints have different speed.

Object Downloading works as follows:
- The sender prepares the object(s) for downloading. It first creates a transaction to get a tx_id. It then adds each
object to be downloaded to the transaction, and get a reference id (ref_id).
- The sender sends the ref_id(s) to all recipients through a separate message.
- Each recipient then calls the download_object function to download each referenced large object.

To develop the downloading capability for a type of object (e.g. a file, a large dict, etc.), you need to provide
the implementation of a Producer and a Consumer.
- On the sending side, the Producer is responsible for producing the next small object to be sent (a chunk of bytes;
a small subset of the large dict; etc.).
- On the receiving side, the Consumer is responsible for processing the received small objects (writing the received
bytes to a temp file; putting the received small dict to the end result; etc.).

One issue with object downloading is object life cycle management. Since the large objects to be downloaded are usually
temporary, you need to remove them when they are downloaded by all sites. But the problem is that you don't know how
quickly each site can finish downloading these large objects. When a transaction contains multiple objects to be
downloaded, it's even harder to know it.

There are two ways to handle this issue: object downloaded callback, and transaction timeout.

You can register an object_downloaded CB when adding an object to transaction. When the object is fully downloaded
to a site, this CB will be called. The obj_downloaded CB must follow this signature:

    downloaded_cb(ref_id: str, to_site: str, status: str, obj: Any, **cb_kwargs)

where ref_id is the reference id of the object;
to_site is the FQCN of the site that has just finished downloading;
status is the status of downloading, as defined in DownloadStatus class;
obj is the large object that was just downloaded;
cb_kwargs are the kw args registered with the CB.

Transaction timeout is the amount of time after the last downloading activity on any object in the
transaction from any site. For example, suppose you want to send 2 large files to 3 sites, each time a download
request is received on any file from any of the 3 sites, the last activity time of the transaction is updated to now.
If no downloading activity is received from any site on any objects in the transaction for the specified timeout,
the transaction is considered "timed out", and the timeout callback registered with the transaction is called.
The transaction timeout CB must follow this signature:

    timeout_cb(tx_id: str, objs: List[Any], **cb_kwargs)

where tx_id is the ID of the transaction;
objs is the list of large objects registered with the transaction;
cb_kwargs are the kw args registered with the CB.

You may need to use both mechanisms to fully take care of object life cycles. The object downloaded CB may never be
called since the site somehow didn't finish the downloading. In reality the timeout mechanism may be sufficient.

Unlike with Object Streamer that the object owner pushes small objects to the recipients; with Object Downloader,
each recipient pulls the data from the object owner.
"""


class _PropKey:
    REF_ID = "ref_id"
    STATE = "state"
    DATA = "data"
    STATUS = "status"


class _Ref:

    def __init__(
        self,
        tx,
        obj: Any,
        ref_id=None,
        obj_downloaded_cb=None,
        **cb_kwargs,
    ):
        if ref_id:
            # use provided ref_id
            self.rid = ref_id
        else:
            self.rid = "R" + str(uuid.uuid4())
        self.tx = tx
        self.obj = obj
        self.obj_downloaded_cb = obj_downloaded_cb
        self.cb_kwargs = cb_kwargs

    def mark_active(self):
        self.tx.mark_active()

    def obj_downloaded(self, to_site: str, status: str):
        if self.obj_downloaded_cb:
            self.obj_downloaded_cb(self.rid, to_site, status, self.obj, **self.cb_kwargs)


[docs] class ProduceRC: """Defines return code for the Producer's produce method.""" OK = "ok" ERROR = "error" EOF = "eof"
[docs] class DownloadStatus: SUCCESS = "success" FAILED = "failed"
[docs] class Producer(ABC): def __init__(self): self.logger = get_obj_logger(self)
[docs] @abstractmethod def produce(self, ref_id: str, obj: Any, state: dict, requester: str) -> (str, Any, dict): """Produce a small object to be sent (on object sender side). Args: ref_id: the ref id of the object being downloaded obj: the large object state: current state of downloading, received from the downloading site requester: the FQCN of the site that is downloading Returns: a tuple of (return code, a small object to be sent, new state to be sent). """ pass
class _Transaction: def __init__( self, producer: Producer, timeout: float, tx_id=None, timeout_cb=None, **cb_kwargs, ): """Constructor of the transaction object. Args: producer: the Producer object to produce small objects. timeout: amount of time since last activity tx_id: if provided, use it; otherwise create one timeout_cb: the CB to be called when the transaction timed out **cb_kwargs: args to be passed to the timeout CB """ check_callable("timeout_cb", timeout_cb) check_object_type("producer", producer, Producer) if tx_id: self.tid = tx_id else: self.tid = "T" + str(uuid.uuid4()) self.producer = producer self.timeout = timeout self.timeout_cb = timeout_cb self.cb_kwargs = cb_kwargs self.last_active_time = time.time() self.refs = [] def produce(self, ref_id: str, obj: Any, state: dict, requester: str): """Called to produce the next small object to be sent. Args: ref_id: ref id of the object being downloaded obj: the large object being downloaded state: current state received from the downloading site (requester). requester: FQCN of the requester Returns: """ return self.producer.produce(ref_id, obj, state, requester) def mark_active(self): """Called to update the last active time of the transaction. Returns: """ self.last_active_time = time.time() def add_download_object( self, obj: Any, ref_id=None, obj_downloaded_cb=None, **cb_kwargs, ): """Add a large object (to be downloaded) to the transaction. Args: obj: the large object to be downloaded ref_id: the ref id to be used, if specified obj_downloaded_cb: the CB to be called when the object is fully downloaded **cb_kwargs: args to be passed to the CB. Returns: """ r = _Ref(self, obj, ref_id, obj_downloaded_cb, **cb_kwargs) self.refs.append(r) return r def timed_out(self): """Called when the transaction is timed out. Returns: """ if self.timeout_cb: self.timeout_cb(self.tid, [r.obj for r in self.refs], **self.cb_kwargs)
[docs] class ObjDownloader: _init_lock = threading.Lock() _tx_table = {} _ref_table = {} _logger = None _tx_monitor = None _tx_lock = threading.Lock() _initialized_cells = {} @classmethod def _initialize(cls, cell: Cell): with cls._init_lock: if not cls._logger: cls._logger = get_obj_logger(cls) if not cls._tx_monitor: cls._tx_monitor = threading.Thread(target=cls._monitor_tx, daemon=True) cls._tx_monitor.start() initialized = cls._initialized_cells.get(id(cell)) if not initialized: # register CBs cell.register_request_cb( channel=OBJ_DOWNLOADER_CHANNEL, topic=OBJ_DOWNLOADER_TOPIC, cb=cls._handle_download, ) cls._initialized_cells[id(cell)] = True
[docs] @classmethod def new_transaction(cls, cell: Cell, producer: Producer, timeout: float, tx_id=None, timeout_cb=None, **cb_kwargs): cls._initialize(cell) tx = _Transaction(producer, timeout, tx_id, timeout_cb, **cb_kwargs) with cls._tx_lock: cls._tx_table[tx.tid] = tx return tx.tid
[docs] @classmethod def add_download_object( cls, transaction_id: str, obj: Any, ref_id=None, obj_downloaded_cb=None, **cb_kwargs, ) -> str: if obj_downloaded_cb is not None: check_callable("obj_downloaded_cb", obj_downloaded_cb) tx = cls._tx_table.get(transaction_id) if not tx: raise ValueError(f"no such transaction {transaction_id}") assert isinstance(tx, _Transaction) ref = tx.add_download_object(obj, ref_id, obj_downloaded_cb, **cb_kwargs) with cls._tx_lock: cls._ref_table[ref.rid] = ref return ref.rid
[docs] @classmethod def delete_transaction(cls, transaction_id: str, call_cb=False): with cls._tx_lock: tx = cls._tx_table.get(transaction_id) if tx: cls._delete_tx(tx, call_cb)
[docs] @classmethod def shutdown(cls): """Shutdown and clean up resources. Returns: None """ with cls._tx_lock: tx_list = list(cls._tx_table.values()) if tx_list: for tx in tx_list: cls._delete_tx(tx, True)
@classmethod def _delete_tx(cls, tx: _Transaction, call_cb=False): if call_cb: try: tx.timed_out() except Exception as ex: cls._logger.error(f"exception from timeout_cb: {secure_format_exception(ex)}") cls._tx_table.pop(tx.tid, None) # remove all refs for r in tx.refs: cls._ref_table.pop(r.rid, None) @classmethod def _handle_download(cls, request: Message) -> Message: requester = request.get_header(MessageHeaderKey.ORIGIN) payload = request.payload assert isinstance(payload, dict) rid = payload.get(_PropKey.REF_ID) if not rid: cls._logger.erro(f"missing {_PropKey.REF_ID} in request from {requester}") return make_reply(ReturnCode.INVALID_REQUEST) current_state = payload.get(_PropKey.STATE) with cls._tx_lock: ref = cls._ref_table.get(rid) if not ref: cls._logger.error(f"no ref found for {rid} from {requester}") return make_reply(ReturnCode.INVALID_REQUEST) assert isinstance(ref, _Ref) ref.mark_active() tx = ref.tx assert isinstance(tx, _Transaction) try: rc, data, new_state = tx.produce(rid, ref.obj, current_state, requester) except Exception as ex: cls._logger.error(f"Producer {type(tx.producer)} encountered exception: {secure_format_exception(ex)}") return make_reply(ReturnCode.PROCESS_EXCEPTION) if rc != ProduceRC.OK: # already done ref.obj_downloaded( requester, status=DownloadStatus.SUCCESS if rc == ProduceRC.EOF else DownloadStatus.FAILED ) return make_reply(ReturnCode.OK, body={_PropKey.STATUS: rc}) else: # continue return make_reply( ReturnCode.OK, body={ _PropKey.STATUS: rc, _PropKey.STATE: new_state, _PropKey.DATA: data, }, ) @classmethod def _monitor_tx(cls): while True: now = time.time() expired_tx = [] with cls._tx_lock: for tid, tx in cls._tx_table.items(): assert isinstance(tx, _Transaction) if now - tx.last_active_time > tx.timeout: expired_tx.append(tx) if expired_tx: for tx in expired_tx: cls._delete_tx(tx, True) time.sleep(5.0)
[docs] class Consumer(ABC): def __init__(self): self.logger = get_obj_logger(self)
[docs] @abstractmethod def consume(self, ref_id: str, state: dict, data: Any) -> dict: """Called to process the received data. Args: ref_id: ref id of the object being downloaded state: current state of downloading data: data to be processed Returns: new state to be sent back to the data owner. """ pass
[docs] @abstractmethod def download_completed(self, ref_id: str): """Called when the downloading is finished successfully. Args: ref_id: ref id of the object being downloaded Returns: None """ pass
[docs] @abstractmethod def download_failed(self, ref_id: str, reason: str): """Called when the downloading is finished unsuccessfully. Args: ref_id: ref id of the object being downloaded reason: explain the reason of failure Returns: None """ pass
[docs] def download_object( from_fqcn: str, ref_id: str, per_request_timeout: float, cell: Cell, consumer: Consumer, secure=False, optional=False, abort_signal: Signal = None, ): """Download a large object from the object owner. Args: from_fqcn: the FQCN of the object owner ref_id: reference id of the object to be downloaded per_request_timeout: timeout for each request to the object owner. cell: the cell to be used for communication withe object owner. consumer: the Consumer object used for processing received data secure: use P2P private communication with the data owner optional: supress log messages abort_signal: for signaling abort Returns: None """ request = new_cell_message( headers={}, payload={ _PropKey.REF_ID: ref_id, }, ) while True: start_time = time.time() reply = cell.send_request( channel=OBJ_DOWNLOADER_CHANNEL, target=from_fqcn, topic=OBJ_DOWNLOADER_TOPIC, request=request, timeout=per_request_timeout, secure=secure, optional=optional, abort_signal=abort_signal, ) duration = time.time() - start_time if abort_signal and abort_signal.triggered: consumer.download_failed(ref_id, f"download aborted after {duration} secs") return assert isinstance(reply, Message) rc = reply.get_header(MessageHeaderKey.RETURN_CODE) if rc != ReturnCode.OK: consumer.download_failed(ref_id, f"error requesting data from {from_fqcn} after {duration} secs: {rc}") return payload = reply.payload assert isinstance(payload, dict) status = payload.get(_PropKey.STATUS) if status == ProduceRC.EOF: consumer.download_completed(ref_id) return elif status == ProduceRC.ERROR: consumer.download_failed(ref_id, f"producer error after {duration} secs") return # continue data = payload.get(_PropKey.DATA) state = payload.get(_PropKey.STATE) try: new_state = consumer.consume(ref_id, state, data) except Exception as ex: consumer.download_failed(ref_id, f"exception when consuming data: {secure_format_exception(ex)}") return if not isinstance(new_state, dict): consumer.download_failed(ref_id, f"consumer error: new_state should be dict but got {type(new_state)}") return if abort_signal and abort_signal.triggered: consumer.download_failed(ref_id, "download aborted") return # ask for more request = new_cell_message(headers={}, payload={_PropKey.REF_ID: ref_id, _PropKey.STATE: new_state})