nvflare.app_opt.tensor_stream.consumer module
- class TensorConsumer(stream_ctx: dict, fl_ctx: FLContext)
Bases: ObjectConsumer
TensorConsumer handles receiving and reconstructing torch tensors from a stream of byte objects.
- logger¶
Logger for logging messages.
- tensors_map¶
Dictionary to store received tensors.
- total_bytes¶
Dictionary to track total bytes received per root key.
- num_tensors¶
Dictionary to track number of tensors received per root key.
- task_ids¶
Set to track unique task IDs received.
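The four attributes above work together as per-root-key bookkeeping. The following is a minimal sketch of how such state could be updated for each received tensor. It assumes tensors are represented as raw `bytes` (real code would hold torch tensors), and the class `ReceiveStats` and its `record` method are hypothetical illustrations, not part of NVFlare.

```python
from collections import defaultdict


class ReceiveStats:
    """Hypothetical stand-in for TensorConsumer's bookkeeping attributes.

    Payloads are plain bytes here; the real consumer reconstructs torch tensors.
    """

    def __init__(self) -> None:
        # root key -> {tensor name -> payload}
        self.tensors_map: dict[str, dict[str, bytes]] = defaultdict(dict)
        # root key -> total bytes received for that key
        self.total_bytes: dict[str, int] = defaultdict(int)
        # root key -> number of tensors received for that key
        self.num_tensors: dict[str, int] = defaultdict(int)
        # unique task IDs seen across all received items
        self.task_ids: set[str] = set()

    def record(self, task_id: str, root_key: str, name: str, payload: bytes) -> None:
        """Store one tensor payload and update the per-key counters."""
        self.tensors_map[root_key][name] = payload
        self.total_bytes[root_key] += len(payload)
        self.num_tensors[root_key] += 1
        self.task_ids.add(task_id)


stats = ReceiveStats()
stats.record("task-1", "model", "layer1.weight", b"\x00" * 16)
stats.record("task-1", "model", "layer1.bias", b"\x00" * 4)
print(stats.total_bytes["model"], stats.num_tensors["model"], stats.task_ids)
# prints: 20 2 {'task-1'}
```

Keying every counter by root key lets a single consumer keep several independently named groups of tensors apart while still reporting per-group totals.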
Initialize the TensorConsumer.
- Parameters:
stream_ctx (StreamContext) – The stream context for the current operation. (not used)
fl_ctx (FLContext) – The FL context for the current operation. (not used)
- consume(shareable: Shareable, stream_ctx: dict, fl_ctx: FLContext) → tuple[bool, Shareable]
Consume a shareable object and extract tensors.
- Parameters:
shareable (Shareable) – The shareable object containing tensor data.
stream_ctx (StreamContext) – The stream context for the current operation. (not used)
fl_ctx (FLContext) – The FL context for the current operation. (not used)
- Returns:
A tuple containing a success flag and a reply shareable.
- Return type:
tuple[bool, Shareable]
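The (flag, reply) return shape can be illustrated with a small stand-in. This sketch assumes a streamed item is a plain dict with hypothetical `"root_key"` and `"tensors"` keys (not NVFlare's actual Shareable layout), and that `True` means the item was accepted. The dict reply plays the role of the reply Shareable.

```python
from typing import Any


def consume(item: dict[str, Any], received: dict[str, dict[str, bytes]]) -> tuple[bool, dict[str, Any]]:
    """Hypothetical consume step: merge one streamed item into `received`.

    Returns (True, reply) when the item is accepted and (False, reply) otherwise.
    The "root_key" and "tensors" keys are illustrative assumptions.
    """
    root_key = item.get("root_key")
    tensors = item.get("tensors")
    if not isinstance(root_key, str) or not isinstance(tensors, dict):
        # Reject malformed input and explain why in the reply.
        return False, {"status": "error", "reason": "malformed item"}
    # Accumulate tensors under their root key, creating the group on first use.
    received.setdefault(root_key, {}).update(tensors)
    return True, {"status": "ok", "received": len(tensors)}


received: dict[str, dict[str, bytes]] = {}
ok, reply = consume({"root_key": "model", "tensors": {"w": b"\x01\x02"}}, received)
print(ok, reply)
# prints: True {'status': 'ok', 'received': 1}
```

Returning a reply even on failure gives the sending side something to act on instead of a silent drop.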
- finalize(stream_ctx: dict, fl_ctx: FLContext)
Finalize the consumer, ensuring all data is written and resources are released.
It updates the FLContext with the received tensors.
- Parameters:
stream_ctx (StreamContext) – The stream context. (not used)
fl_ctx (FLContext) – The FL context, which is updated with the received tensors.
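The finalize step described above hands the collected tensors to the context and then releases the consumer's own buffers. The sketch below assumes a plain dict `ctx` standing in for FLContext and a hypothetical property name `"received_tensors"`; neither reflects the actual key or mechanism TensorConsumer uses.

```python
def finalize(received: dict[str, dict[str, bytes]], ctx: dict[str, object],
             prop_key: str = "received_tensors") -> None:
    """Hypothetical finalize step: publish everything received into a context mapping.

    `ctx` stands in for FLContext; `prop_key` is an illustrative property name.
    """
    # Publish a copy so clearing the consumer's buffers does not empty the context.
    ctx[prop_key] = {root: dict(tensors) for root, tensors in received.items()}
    # Release the consumer-side buffers once the data has been handed off.
    received.clear()


received = {"model": {"w": b"ab"}}
ctx: dict[str, object] = {}
finalize(received, ctx)
print(ctx, received)
# prints: {'received_tensors': {'model': {'w': b'ab'}}} {}
```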
- log_received(task_id: str, identity: str, peer_name: str)
Log the received tensors for debugging purposes.
- Parameters:
task_id (str) – The task ID associated with the received tensors.
identity (str) – The identity of the peer.
peer_name (str) – The name of the peer.
Process a received shareable object containing tensor data.
- Parameters:
shareable (Shareable) – The shareable object containing tensor data.
- Raises:
ValueError – If the shareable object is invalid or contains errors.
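The docstring above says a received shareable is validated and a ValueError is raised if it is invalid or contains errors. A minimal sketch of that kind of check follows. It assumes a dict-shaped item with hypothetical `"error"` and `"tensors"` keys; the real Shareable layout and the exact checks NVFlare performs are not shown here.

```python
from typing import Any


def process_item(item: Any) -> dict[str, bytes]:
    """Hypothetical validation before tensors are extracted from a streamed item.

    The "error" and "tensors" keys are illustrative assumptions.
    Raises ValueError when the item is malformed or reports an error.
    """
    if not isinstance(item, dict):
        raise ValueError(f"expected a dict-like item, got {type(item).__name__}")
    if item.get("error"):
        raise ValueError(f"sender reported an error: {item['error']}")
    tensors = item.get("tensors")
    # Require a mapping of tensor name -> serialized bytes.
    if not isinstance(tensors, dict) or not all(isinstance(v, bytes) for v in tensors.values()):
        raise ValueError("item has no valid 'tensors' mapping of name -> bytes")
    return tensors


print(process_item({"tensors": {"w": b"\x00"}}))
try:
    process_item({"error": "peer aborted"})
except ValueError as exc:
    print("rejected:", exc)
```

Raising early keeps a malformed item from partially updating the consumer's buffers.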
- class TensorConsumerFactory
Bases: ConsumerFactory
Factory for creating TensorConsumer instances.
- get_consumer(stream_ctx: dict, fl_ctx: FLContext) → ObjectConsumer
Called to get an ObjectConsumer to process a new stream on the receiving side. It is called only once per stream, when the first object of that stream is received.
- Parameters:
stream_ctx (dict) – The context of the stream.
fl_ctx (FLContext) – The FL context.
- Returns:
An ObjectConsumer (here, a TensorConsumer instance).
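Because get_consumer is called only on a stream's first object, each stream ends up with its own consumer and isolated state. The sketch below simulates that receive-side behavior. `StreamConsumer`, `StreamConsumerFactory`, and the `route` dispatcher are hypothetical names; the dispatch loop is an assumption about how a caller might cache consumers per stream, not NVFlare's internal code.

```python
from typing import Any


class StreamConsumer:
    """Hypothetical per-stream consumer with its own receive buffer."""

    def __init__(self, stream_ctx: dict[str, Any]) -> None:
        self.stream_ctx = stream_ctx
        self.items: list[Any] = []

    def consume(self, item: Any) -> None:
        self.items.append(item)


class StreamConsumerFactory:
    """Hypothetical factory: builds one fresh consumer per stream."""

    def __init__(self) -> None:
        self.created = 0  # how many consumers have been built so far

    def get_consumer(self, stream_ctx: dict[str, Any]) -> StreamConsumer:
        self.created += 1
        return StreamConsumer(stream_ctx)


def route(stream_id: str, item: Any, consumers: dict[str, StreamConsumer],
          factory: StreamConsumerFactory) -> StreamConsumer:
    """Hypothetical dispatcher: ask the factory only for a stream's first object."""
    consumer = consumers.get(stream_id)
    if consumer is None:
        consumer = factory.get_consumer({"stream_id": stream_id})
        consumers[stream_id] = consumer
    consumer.consume(item)
    return consumer


factory = StreamConsumerFactory()
consumers: dict[str, StreamConsumer] = {}
for sid, item in [("a", 1), ("a", 2), ("b", 3)]:
    route(sid, item, consumers, factory)
print(factory.created, consumers["a"].items, consumers["b"].items)
# prints: 2 [1, 2] [3]
```

Creating a fresh consumer per stream avoids sharing buffers between concurrent transfers, which is the usual reason for putting a factory in front of a stateful consumer.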