nvflare.app_opt.pt.tensor_downloader module

class DiskTensorConsumer(temp_dir: str)[source]

Bases: ItemConsumer

Writes raw safetensors bytes to disk without deserializing to tensors.
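The idea of persisting received bytes untouched can be sketched with the standard library alone. `RawBytesDiskWriter` below is a hypothetical stand-in, not the NVFlare class: the file naming scheme and the assumption that `items` is a list of `bytes` payloads are made up for illustration.

```python
import os
import shutil
import tempfile
from typing import List


class RawBytesDiskWriter:
    """Hypothetical stand-in: stores each received payload as-is in temp_dir."""

    def __init__(self, temp_dir: str):
        self.temp_dir = temp_dir
        self.count = 0

    def consume_items(self, items: List[bytes], result: List[str]) -> List[str]:
        # Write every payload verbatim; nothing is parsed or deserialized here.
        for payload in items:
            path = os.path.join(self.temp_dir, f"chunk_{self.count}.bin")
            with open(path, "wb") as f:
                f.write(payload)
            result.append(path)
            self.count += 1
        return result

    def cleanup(self) -> None:
        # Remove partially written files, e.g. after a failed download.
        shutil.rmtree(self.temp_dir, ignore_errors=True)


temp_dir = tempfile.mkdtemp()
writer = RawBytesDiskWriter(temp_dir)
paths = writer.consume_items([b"\x00\x01", b"\x02"], [])
sizes = [os.path.getsize(p) for p in paths]
writer.cleanup()
```

Because the payloads are never decoded, memory use in this sketch is bounded by the size of one batch of items rather than by the size of the whole model.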

cleanup() None[source]

consume_items(items: List[Any], result: Any) Any[source]

Process items and return updated result.

download_failed(ref_id, reason: str)[source]

Called when the download fails.

Parameters:
  • ref_id – reference ID of the object being downloaded

  • reason – explanation of why the download failed

Returns: None

release() None[source]

class TensorConsumer(tensors_received_cb, cb_kwargs)[source]

Bases: ItemConsumer

consume_items(items: List[Any], result: Any) Any[source]

Process items and return updated result.

class TensorDownloadable(tensors: dict[str, Tensor], max_chunk_size: int)[source]

Bases: CacheableObject

Constructor of CacheableObject, applied to a dict of tensors.

Parameters:
  • tensors – the tensors to be downloaded (described as obj in the inherited CacheableObject docstring).

  • max_chunk_size – maximum number of bytes in each chunk.

Notes: The object must be divisible into multiple items. One chunk is generated for each item.

get_item_count() int[source]

The subclass must implement this method to return the number of items the object contains.

Returns: the number of items the object contains

produce_item(index: int) bytes[source]

This method is called to produce the chunk for the specified item.

Parameters:
  • index – index of the item.

Returns: a chunk for the item
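The item/chunk contract described above (a fixed item count, and one chunk of bytes per item index) can be illustrated with a minimal sketch. `BytesDictDownloadable` and `parse_chunk` are hypothetical names, and the length-prefixed key encoding is an assumption chosen to keep the example stdlib-only; it is not the serialization used by TensorDownloadable.

```python
from typing import Dict, List, Tuple


class BytesDictDownloadable:
    """Hypothetical stand-in: one item per dict entry, one chunk per item."""

    def __init__(self, blobs: Dict[str, bytes]):
        self.blobs = blobs
        self.keys: List[str] = list(blobs.keys())

    def get_item_count(self) -> int:
        return len(self.keys)

    def produce_item(self, index: int) -> bytes:
        # Prefix the payload with a length-delimited key so that a receiver
        # can rebuild the mapping (assumed encoding, for illustration only).
        name = self.keys[index]
        key = name.encode("utf-8")
        return len(key).to_bytes(4, "big") + key + self.blobs[name]


def parse_chunk(chunk: bytes) -> Tuple[str, bytes]:
    # Inverse of produce_item: split the chunk back into (key, payload).
    key_len = int.from_bytes(chunk[:4], "big")
    key = chunk[4:4 + key_len].decode("utf-8")
    return key, chunk[4 + key_len:]


source = BytesDictDownloadable({"layer.weight": b"\x01\x02\x03", "layer.bias": b"\x04"})
chunks = [source.produce_item(i) for i in range(source.get_item_count())]
rebuilt = dict(parse_chunk(c) for c in chunks)
```

A receiver that asks for indices `0` through `get_item_count() - 1` ends up with the full mapping, which is the property the two methods are meant to guarantee together.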

add_tensors(downloader: ObjectDownloader, tensors: dict[str, Tensor], max_chunk_size: int = 2097152) str[source]

Add tensors to be downloaded to the specified downloader.

Parameters:
  • downloader – the downloader to add tensors to.

  • tensors – state dict to be downloaded

  • max_chunk_size – maximum chunk size in bytes (default 2097152, i.e. 2 MiB)

Returns: reference id for the state dict.

cleanup_active_disk_tensor_downloads(reason: str = 'download aborted') None[source]

Clean up partial tensor offload directories still owned by active disk consumers.

download_tensors(from_fqcn: str, ref_id: str, per_request_timeout: float, cell: Cell, secure=False, optional=False, abort_signal=None, tensors_received_cb=None, **cb_kwargs) Tuple[str, dict[str, Tensor] | None][source]

Download the referenced state dict from the source.

Parameters:
  • from_fqcn – FQCN of the data source.

  • ref_id – reference ID of the state dict to be downloaded.

  • per_request_timeout – timeout for requests sent to the data source.

  • cell – cell to be used for communicating to the data source.

  • secure – whether to use P2P private mode for communication

  • optional – whether to suppress log messages about communication

  • abort_signal – signal for aborting download.

  • tensors_received_cb – the callback to be called when one set of tensors is received

  • cb_kwargs – additional keyword arguments passed to tensors_received_cb

Returns: tuple of (error message if any, downloaded state dict).
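The calling pattern implied by the `(error, result)` return value can be sketched as follows. `fake_download` is a hypothetical stand-in that mimics only the return shape, so the example runs without NVFlare. The way the callback is invoked (one batch plus `**cb_kwargs`) and the use of an empty string for "no error" are assumptions; checking the error by truthiness works whether success is reported as `""` or `None`.

```python
from typing import Any, Callable, Dict, Optional, Tuple


def fake_download(
    ref_id: str,
    tensors_received_cb: Optional[Callable[..., Any]] = None,
    **cb_kwargs,
) -> Tuple[str, Optional[Dict[str, list]]]:
    """Hypothetical stand-in that mimics only the (error, result) return shape."""
    store = {"ref-1": [{"w": [1.0, 2.0]}, {"b": [0.5]}]}
    if ref_id not in store:
        return f"unknown ref_id {ref_id}", None
    merged: Dict[str, list] = {}
    for batch in store[ref_id]:
        if tensors_received_cb is not None:
            tensors_received_cb(batch, **cb_kwargs)  # assumed call convention
        merged.update(batch)
    return "", merged


seen = []


def on_batch(batch, tag):
    # Record which keys arrived in each batch, labelled with the extra kwarg.
    seen.append((tag, sorted(batch)))


err, tensors = fake_download("ref-1", tensors_received_cb=on_batch, tag="round-0")
if err:
    raise RuntimeError(err)

missing_err, missing = fake_download("no-such-ref")
```

Checking the error before touching the result matters because the second element may be `None` when the download fails.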

download_tensors_to_disk(from_fqcn: str, ref_id: str, per_request_timeout: float, cell: Cell, secure=False, optional=False, abort_signal=None) Tuple[str, LazyTensorDict | None][source]

Download tensors to disk instead of memory. The parameters have the same meaning as the corresponding parameters of download_tensors.

Returns: tuple of (error message if any, LazyTensorDict for lazy access).
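Lazy access to data stored on disk can be pictured as a read-only mapping that opens a file only when a key is requested. `LazyBytesDict` below is a hypothetical stand-in written for illustration; it returns raw bytes and makes no claim about how LazyTensorDict is implemented or which methods it exposes.

```python
import os
import tempfile
from collections.abc import Mapping
from typing import Dict, Iterator


class LazyBytesDict(Mapping):
    """Hypothetical stand-in: maps keys to files and reads a file only on access."""

    def __init__(self, paths: Dict[str, str]):
        self._paths = paths
        self.reads = 0  # counts file reads, to make the laziness observable

    def __getitem__(self, key: str) -> bytes:
        with open(self._paths[key], "rb") as f:
            self.reads += 1
            return f.read()

    def __iter__(self) -> Iterator[str]:
        return iter(self._paths)

    def __len__(self) -> int:
        return len(self._paths)


with tempfile.TemporaryDirectory() as d:
    paths = {}
    for name, data in {"w": b"\x01\x02", "b": b"\x03"}.items():
        paths[name] = os.path.join(d, name + ".bin")
        with open(paths[name], "wb") as f:
            f.write(data)

    lazy = LazyBytesDict(paths)
    keys = list(lazy)                 # listing keys reads no file contents
    reads_after_listing = lazy.reads
    value = lazy["w"]                 # only this entry is read from disk
    reads_after_access = lazy.reads
```

The benefit of this pattern is that only the entries actually accessed are held in memory at any one time.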