nvflare.app_opt.tensor_stream.receiver module

class TensorReceiver(engine: StreamableEngine, ctx_prop_key: str, format: str = 'pytorch', channel: str = 'tensor_stream')

Bases: object

A component to receive tensors from clients using NVFlare’s streaming capabilities.

Initialize the TensorReceiver.

Parameters:
  • engine (StreamableEngine) – The streamable engine to use for streaming.

  • ctx_prop_key (str) – The context property key to receive tensors for.

  • format (str) – The format of the tensors to receive. Default is ExchangeFormat.PYTORCH ('pytorch').

  • channel (str) – The channel to use for streaming. Default is TENSORS_CHANNEL ('tensor_stream').

on_tensor_received(task_id: str, tensor: dict[str, Tensor] | dict[str, dict[str, Tensor]])

Callback invoked when tensors are received.

Parameters:
  • task_id (str) – The task ID associated with the tensors.

  • tensor (TensorsMap) – The received tensors, either a flat dict[str, Tensor] or a nested dict[str, dict[str, Tensor]].
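As the signature shows, the tensor argument can take one of two shapes. The following is a minimal, hypothetical sketch of a callback that records each task's tensors and tells the two shapes apart. The names `received` and `is_nested`, the example keys, and the use of plain lists as placeholders for `torch.Tensor` are all illustrative assumptions, not part of NVFlare's API.

```python
from typing import Any, Dict, Union

# Placeholder for torch.Tensor so the sketch runs without PyTorch (assumption).
Tensor = Any
TensorsMap = Union[Dict[str, Tensor], Dict[str, Dict[str, Tensor]]]

# Hypothetical per-task store: task ID -> tensors received for that task.
received: Dict[str, TensorsMap] = {}


def on_tensor_received(task_id: str, tensor: TensorsMap) -> None:
    """Hypothetical callback: record the tensors delivered for task_id."""
    received[task_id] = tensor


def is_nested(tensor: TensorsMap) -> bool:
    """True for the dict[str, dict[str, Tensor]] form, False for the flat form."""
    return bool(tensor) and all(isinstance(v, dict) for v in tensor.values())


# Flat form: parameter name -> tensor.
on_tensor_received("task-1", {"fc.weight": [0.1, 0.2], "fc.bias": [0.0]})
# Nested form: outer key (hypothetical sub-model name) -> parameter map.
on_tensor_received("task-2", {"encoder": {"w": [1.0]}, "decoder": {"w": [2.0]}})

print(is_nested(received["task-1"]), is_nested(received["task-2"]))  # False True
```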

set_ctx_with_tensors(fl_ctx: FLContext)

Update the context with the received tensors.

Parameters:
  • fl_ctx (FLContext) – The FLContext for the current operation.
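One way to picture "updating the context" is to merge the received tensors into the payload held under the receiver's ctx_prop_key. The sketch below assumes exactly that design, which this page does not confirm. `FakeContext` is a hypothetical stand-in exposing `set_prop`/`get_prop` accessors in the style of FLContext, and, unlike the real method, which takes only `fl_ctx`, the sketch passes `ctx_prop_key` and the tensors explicitly because it has no receiver state.

```python
from typing import Any, Dict


class FakeContext:
    """Hypothetical stand-in for FLContext exposing set_prop/get_prop."""

    def __init__(self) -> None:
        self._props: Dict[str, Any] = {}

    def set_prop(self, key: str, value: Any) -> None:
        self._props[key] = value

    def get_prop(self, key: str, default: Any = None) -> Any:
        return self._props.get(key, default)


def set_ctx_with_tensors(
    fl_ctx: FakeContext, ctx_prop_key: str, tensors: Dict[str, Any]
) -> None:
    """Merge received tensors into the payload under ctx_prop_key (assumed design)."""
    payload = dict(fl_ctx.get_prop(ctx_prop_key, {}))  # copy; start empty if unset
    payload.update(tensors)  # received tensors fill in (or override) entries
    fl_ctx.set_prop(ctx_prop_key, payload)


ctx = FakeContext()
ctx.set_prop("task_result", {"meta": "round-3"})  # hypothetical key and payload
set_ctx_with_tensors(ctx, "task_result", {"fc.weight": [0.5, 0.5]})
print(ctx.get_prop("task_result"))
# {'meta': 'round-3', 'fc.weight': [0.5, 0.5]}
```

Copying the payload before updating keeps the sketch from mutating an object that other code may still hold a reference to.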

wait_for_tensors(task_id: str, peer_name: str, timeout: float = 5.0)

Wait for tensors to be received for a specific task ID.

Tensors are always sent before the task data or result so that they arrive before any related events are handled. However, the consumer may still be processing the final tensor chunk when the task data or result is received. This processing usually finishes within a few milliseconds, but occasionally the client or server receives the task before the tensors are fully available. To handle this race safely, the method waits for the tensors of the given task ID, up to a default timeout of 5 seconds.

Parameters:
  • task_id (str) – The task ID to wait for.

  • peer_name (str) – The peer name associated with the task.

  • timeout (float) – The maximum time to wait, in seconds. Default is 5.0.
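The race described above, where the task arrives slightly before the last tensor chunk has been processed, is a classic producer/consumer wait with a deadline. A minimal sketch using `threading.Condition.wait_for` follows. `TensorWaiter` and `deliver` are hypothetical names, keying by `(peer_name, task_id)` is an assumption suggested by the signature, and returning `None` on timeout is a choice made for the sketch; this page does not document what the real method does when the timeout expires.

```python
import threading
from typing import Any, Dict, Tuple

Key = Tuple[str, str]  # (peer_name, task_id)


class TensorWaiter:
    """Hypothetical sketch of waiting for streamed tensors with a timeout."""

    def __init__(self) -> None:
        self._cond = threading.Condition()
        self._ready: Dict[Key, Any] = {}

    def deliver(self, peer_name: str, task_id: str, tensors: Any) -> None:
        # Called by the streaming consumer once the final chunk is processed.
        with self._cond:
            self._ready[(peer_name, task_id)] = tensors
            self._cond.notify_all()

    def wait_for_tensors(self, task_id: str, peer_name: str, timeout: float = 5.0) -> Any:
        # Block until tensors for (peer_name, task_id) arrive or the timeout expires.
        key = (peer_name, task_id)
        with self._cond:
            arrived = self._cond.wait_for(lambda: key in self._ready, timeout=timeout)
            # Hypothetical policy: hand the tensors over once, or None on timeout.
            return self._ready.pop(key) if arrived else None


waiter = TensorWaiter()
# Simulate the consumer finishing the last chunk shortly after the task arrives.
threading.Timer(0.05, waiter.deliver, args=("site-1", "task-7", {"w": [1.0]})).start()
print(waiter.wait_for_tensors("task-7", "site-1"))  # {'w': [1.0]}
print(waiter.wait_for_tensors("task-8", "site-1", timeout=0.1))  # None
```

Because `wait_for` re-checks the predicate under the lock, tensors delivered before the wait begins are found immediately, and the timeout only bounds the rare case in which the consumer is still finishing.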