nvflare.app_opt.tensor_stream.consumer module

class TensorConsumer(stream_ctx: dict, fl_ctx: FLContext)

Bases: ObjectConsumer

TensorConsumer handles receiving and reconstructing torch tensors from a stream of byte objects.

logger

Logger for logging messages.

tensors_map

Dictionary to store received tensors.

total_bytes

Dictionary to track total bytes received per root key.

num_tensors

Dictionary to track number of tensors received per root key.

task_ids

Set to track unique task IDs received.

Initialize the TensorConsumer.

Parameters:
  • stream_ctx (StreamContext) – The stream context for the current operation. (not used)

  • fl_ctx (FLContext) – The FL context for the current operation. (not used)
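The attributes listed above amount to receive-side bookkeeping grouped by root key. The following is a minimal sketch of that bookkeeping in plain Python. It assumes tensors arrive as already-serialized bytes and that a "root key" is a simple string grouping key. `SketchTensorConsumer` and its `record` method are hypothetical and are not part of the NVFlare API.

```python
import logging
from collections import defaultdict


class SketchTensorConsumer:
    """Hypothetical stand-in mirroring the documented TensorConsumer attributes.

    Not the NVFlare implementation: tensors are modelled as raw bytes and the
    root key is treated as a plain string grouping key (assumption).
    """

    def __init__(self) -> None:
        self.logger = logging.getLogger(self.__class__.__name__)
        self.tensors_map: dict[str, dict[str, bytes]] = {}  # root key -> {tensor name: data}
        self.total_bytes: dict[str, int] = defaultdict(int)  # root key -> bytes received
        self.num_tensors: dict[str, int] = defaultdict(int)  # root key -> tensors received
        self.task_ids: set[str] = set()  # unique task IDs seen so far

    def record(self, task_id: str, root_key: str, name: str, data: bytes) -> None:
        # Hypothetical helper: update every structure the documented attributes describe.
        self.tensors_map.setdefault(root_key, {})[name] = data
        self.total_bytes[root_key] += len(data)
        self.num_tensors[root_key] += 1
        self.task_ids.add(task_id)


c = SketchTensorConsumer()
c.record("task-1", "model", "fc.weight", b"\x00" * 16)
c.record("task-1", "model", "fc.bias", b"\x00" * 4)
print(c.total_bytes["model"], c.num_tensors["model"])  # 20 2
```

Using `defaultdict` for the counters avoids initializing each root key explicitly; whether the real class does the same is not stated here.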

consume(shareable: Shareable, stream_ctx: dict, fl_ctx: FLContext) tuple[bool, Shareable]

Consume a shareable object and extract tensors.

Parameters:
  • shareable (Shareable) – The shareable object containing tensor data.

  • stream_ctx (StreamContext) – The stream context for the current operation. (not used)

  • fl_ctx (FLContext) – The FL context for the current operation. (not used)

Returns:

A tuple containing a success flag and a reply shareable.

Return type:

tuple[bool, Shareable]
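The (success flag, reply) return contract can be illustrated with a small stand-in. This is a sketch only: it assumes a plain dict in place of `Shareable` and a hypothetical `"tensors"` payload key. `SketchConsumer` is not NVFlare API, and the real payload layout is not documented here.

```python
from typing import Any

Shareable = dict[str, Any]  # hypothetical stand-in for nvflare's Shareable


class SketchConsumer:
    """Hypothetical consumer illustrating the (success flag, reply) contract of consume()."""

    def __init__(self) -> None:
        self.received: dict[str, bytes] = {}

    def consume(self, shareable: Shareable, stream_ctx: dict, fl_ctx: Any) -> tuple[bool, Shareable]:
        # stream_ctx and fl_ctx are accepted but ignored, matching the "(not used)" notes.
        tensors = shareable.get("tensors")  # hypothetical payload key
        if not isinstance(tensors, dict):
            # Signal failure to the sender together with an explanatory reply.
            return False, {"error": "missing tensor payload"}
        self.received.update(tensors)
        # Signal success and acknowledge how many tensors were taken in.
        return True, {"ack": len(tensors)}


consumer = SketchConsumer()
ok, reply = consumer.consume({"tensors": {"w": b"1234"}}, {}, None)
print(ok, reply)  # True {'ack': 1}
```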

finalize(stream_ctx: dict, fl_ctx: FLContext)

Finalize the consumer, ensuring all data is written and resources are released.

It updates the FLContext with the received tensors.

Parameters:
  • stream_ctx (StreamContext) – The stream context. (not used)

  • fl_ctx (FLContext) – The FL context. (not used)
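finalize is described as updating the FLContext with the received tensors. The following is a minimal sketch of that hand-off. It assumes a dict-backed property bag in place of the real FLContext. `SketchContext`, `publish_received_tensors`, and the property name `received_tensors` are all hypothetical.

```python
from typing import Any


class SketchContext:
    """Hypothetical stand-in for FLContext: a simple property bag."""

    def __init__(self) -> None:
        self._props: dict[str, Any] = {}

    def set_prop(self, key: str, value: Any) -> None:
        self._props[key] = value

    def get_prop(self, key: str, default: Any = None) -> Any:
        return self._props.get(key, default)


RECEIVED_TENSORS_KEY = "received_tensors"  # hypothetical property name


def publish_received_tensors(tensors_map: dict[str, dict[str, bytes]], ctx: SketchContext) -> None:
    # Publish a shallow copy of everything accumulated during consume() ...
    ctx.set_prop(RECEIVED_TENSORS_KEY, dict(tensors_map))
    # ... then drop the consumer's own top-level references to release resources.
    tensors_map.clear()


ctx = SketchContext()
state = {"model": {"w": b"ab"}}
publish_received_tensors(state, ctx)
print(ctx.get_prop(RECEIVED_TENSORS_KEY))  # {'model': {'w': b'ab'}}
```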

log_received(task_id: str, identity: str, peer_name: str)

Log the received tensors for debugging purposes.

Parameters:
  • task_id (str) – The ID of the task whose tensors were received.

  • identity (str) – The identity of the peer.

  • peer_name (str) – The name of the peer.
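A debug summary of this kind can be sketched with the standard `logging` module. This sketch presumes the real method reads its counts from the consumer's `num_tensors` and `total_bytes` attributes. Because the sketch has no instance, it passes those values in explicitly. `log_received_summary` and its message format are hypothetical.

```python
import logging


def log_received_summary(logger: logging.Logger, task_id: str, identity: str, peer_name: str,
                         num_tensors: int, total_bytes: int) -> str:
    # Build one debug line summarising what arrived for this task.
    # The message is also returned so callers (and tests) can inspect it.
    msg = (f"[{identity}] task {task_id}: received {num_tensors} tensors "
           f"({total_bytes / 1024 / 1024:.2f} MiB) from {peer_name}")
    logger.debug(msg)
    return msg


line = log_received_summary(logging.getLogger("sketch"), "t1", "server", "site-1", 3, 1048576)
print(line)  # [server] task t1: received 3 tensors (1.00 MiB) from site-1
```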

process_shareable(shareable: Shareable)

Process a received shareable object containing tensor data.

Parameters:

shareable (Shareable) – The shareable object containing tensor data.

Raises:

ValueError – If the shareable object is invalid or contains errors.
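The ValueError contract can be illustrated with a validate-then-store sketch. It assumes a hypothetical payload layout of `{"root_key": str, "tensors": {name: bytes}}`. The actual fields NVFlare places in the shareable are not documented here. `validate_and_store` is a hypothetical name.

```python
from typing import Any


def validate_and_store(shareable: dict[str, Any], tensors_map: dict[str, dict[str, bytes]]) -> None:
    """Hypothetical validation step; the payload layout is an assumption."""
    root_key = shareable.get("root_key")
    tensors = shareable.get("tensors")
    # Validate everything first so a bad payload leaves tensors_map untouched.
    if not isinstance(root_key, str) or not root_key:
        raise ValueError("shareable is missing a valid 'root_key'")
    if not isinstance(tensors, dict) or not tensors:
        raise ValueError(f"shareable for {root_key!r} carries no tensors")
    for name, data in tensors.items():
        if not isinstance(data, (bytes, bytearray)):
            raise ValueError(f"tensor {name!r} is not serialized bytes")
    # Only now merge the validated tensors under their root key.
    bucket = tensors_map.setdefault(root_key, {})
    for name, data in tensors.items():
        bucket[name] = bytes(data)


store: dict[str, dict[str, bytes]] = {}
validate_and_store({"root_key": "model", "tensors": {"w": b"ab"}}, store)
print(store)  # {'model': {'w': b'ab'}}
```

Validating before mutating keeps the receive state consistent even when a malformed object arrives mid-stream.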

class TensorConsumerFactory

Bases: ConsumerFactory

Factory for creating TensorConsumer instances.

get_consumer(stream_ctx: dict, fl_ctx: FLContext) ObjectConsumer

Creates and returns a TensorConsumer instance.

Called on the receiving side to get an ObjectConsumer that processes a new stream. It is called only once per stream, when the first streaming object of that stream is received.

Parameters:
  • stream_ctx – The context of the stream.

  • fl_ctx – The FLContext object.

Returns: a new TensorConsumer (an ObjectConsumer)
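Because get_consumer is called once per stream, each stream presumably receives its own consumer with isolated state. The following is a minimal sketch of that factory pattern. `SketchConsumer` and `SketchConsumerFactory` are hypothetical stand-ins, not the NVFlare classes.

```python
from typing import Any


class SketchConsumer:
    """Hypothetical per-stream consumer holding its own receive state."""

    def __init__(self, stream_ctx: dict, fl_ctx: Any) -> None:
        self.tensors_map: dict[str, bytes] = {}


class SketchConsumerFactory:
    """Hypothetical factory: one new consumer per stream, created on the first object."""

    def get_consumer(self, stream_ctx: dict, fl_ctx: Any) -> SketchConsumer:
        return SketchConsumer(stream_ctx, fl_ctx)


factory = SketchConsumerFactory()
a = factory.get_consumer({}, None)
b = factory.get_consumer({}, None)
a.tensors_map["w"] = b"1"
print(b.tensors_map)  # {}
```

Creating a fresh object per stream keeps tensors from concurrent streams from mixing, at the cost of one consumer allocation per stream.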