# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.streaming import StreamableEngine, StreamContext
from nvflare.fuel.utils.log_utils import get_obj_logger
from .producer import TensorProducer
from .types import TENSORS_CHANNEL, TensorsMap
from .utils import get_dxo_from_ctx, get_targets_from_ctx_and_prop_key, get_topic_for_ctx_prop_key
[docs]
class TensorSender:
"""Handles sending tensors between server and clients."""
def __init__(
self,
engine: StreamableEngine,
ctx_prop_key: str,
format: str,
tasks: list[str],
channel: str = TENSORS_CHANNEL,
):
"""Initialize the TensorSender.
Args:
engine (StreamableEngine): The streamable engine to use for streaming.
format (str): The format of the tensors to send.
tasks (list[str]): The list of tasks to send tensors for.
ctx_prop_key (str): The context property key to send tensors for.
channel (str): The channel to use for streaming. Default is TENSORS_CHANNEL.
"""
self.engine = engine
self.ctx_prop_key = ctx_prop_key
self.format = format
self.tasks = tasks
self.channel = channel
# key: task_id, value: tensors to send to the peer
self.task_params: dict[str, TensorsMap] = defaultdict(dict)
self.logger = get_obj_logger(self)
[docs]
def store_tensors(self, fl_ctx: FLContext):
"""Parse tensors from the FLContext and store them for sending.
Args:
fl_ctx (FLContext): The FLContext for the current operation.
"""
peer_name = fl_ctx.get_peer_context().get_identity_name()
task_id = fl_ctx.get_prop(FLContextKey.TASK_ID, None)
if not task_id:
raise ValueError("No task_id found in FLContext.")
try:
dxo = get_dxo_from_ctx(fl_ctx, self.ctx_prop_key, self.tasks)
except ValueError as exc:
self.logger.warning(f"{exc} Nothing to send.")
return False
params = dxo.data
self.task_params[task_id] = params
self.logger.info(f"Stored reference to params to be sent to peer '{peer_name}'. Task ID: '{task_id}'.")
del params
[docs]
def send(
self,
fl_ctx: FLContext,
tensor_send_timeout: float,
) -> None:
"""Send tensors to the peer.
Args:
fl_ctx (FLContext): The FLContext for the current operation.
tensor_send_timeout (float): Timeout for each tensor entry transfer.
"""
peer_name = fl_ctx.get_peer_context().get_identity_name()
task_id = fl_ctx.get_prop(FLContextKey.TASK_ID, None)
if not task_id:
raise ValueError("No task_id found in FLContext.")
targets = get_targets_from_ctx_and_prop_key(fl_ctx, self.ctx_prop_key)
# Important: pop the tensors to release memory after sending
# Each task_id is unique per client, so we only send once per task_id
params = self.task_params.pop(task_id, None)
if not params:
raise ValueError(f"No tensors stored for peer '{peer_name}'. Task ID: '{task_id}'.")
producer = TensorProducer(params, task_id, tensor_send_timeout)
msg = f"Starting to send tensors to peer '{peer_name}'."
msg += f" Task ID: '{task_id}'."
self.logger.info(msg)
self._send_tensors(targets, producer, fl_ctx)
def _send_tensors(self, targets: list[str], producer: TensorProducer, fl_ctx: FLContext):
"""Send tensors to the peer using the StreamableEngine."""
stream_ctx = StreamContext()
self.engine.stream_objects(
channel=self.channel,
topic=get_topic_for_ctx_prop_key(self.ctx_prop_key),
stream_ctx=stream_ctx,
targets=targets,
producer=producer,
fl_ctx=fl_ctx,
optional=False,
secure=False,
)