Source code for nvflare.app_opt.tensor_stream.client

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nvflare.apis.event_type import EventType
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.streaming import StreamableEngine
from nvflare.client.config import ExchangeFormat

from .receiver import TensorReceiver
from .sender import TensorSender
from .utils import clean_task_result


class TensorClientStreamer(FLComponent):
    """Client-side component that streams task tensors from and to the server.

    On the receiving side it waits for the tensors that belong to the current
    task and places them into the FLContext before the task data filters run.
    On the sending side it streams the task result tensors to the server after
    the task result filters have run.

    Attributes:
        format (str): The format of the tensors to exchange.
            Default is ExchangeFormat.PYTORCH.
        tasks (list[str]): The task names for which tensors are sent.
            Default is ["train"].
        tensor_send_timeout (float): Timeout for tensor entry transfer operations.
            Default is 30.0 seconds.
        engine (StreamableEngine): The engine used for tensor streaming; set by
            initialize().
        sender (TensorSender): The sender used for the send in progress; it is
            None whenever no send is in progress.
        receiver (TensorReceiver): The receiver used to receive tensors from the
            server; set by initialize().

    Methods:
        initialize(fl_ctx): Validates the engine and creates the receiver.
        handle_event(event_type, fl_ctx): Dispatches the framework events this
            component reacts to.
        send_tensors_to_server(fl_ctx): Sends the task result tensors to the server.
    """

    def __init__(
        self,
        format: str = ExchangeFormat.PYTORCH,
        tasks: list[str] = None,
        tensor_send_timeout: float = 30.0,
    ):
        """Initialize the TensorClientStreamer component.

        Args:
            format (str): The format of the tensors to send.
                Default is ExchangeFormat.PYTORCH.
            tasks (list[str]): The list of tasks to send tensors for.
                Default is None, which means the "train" task.
            tensor_send_timeout (float): Timeout for tensor entry transfer
                operations. Default is 30.0 seconds.
        """
        super().__init__()
        self.format = format
        # A None default (instead of a list literal) avoids sharing one mutable
        # default list between instances.
        self.tasks = tasks if tasks is not None else ["train"]
        self.tensor_send_timeout = tensor_send_timeout
        # Populated by initialize() when the run starts.
        self.engine: StreamableEngine = None
        # Only set for the duration of send_tensors_to_server().
        self.sender: TensorSender = None
        self.receiver: TensorReceiver = None

    def initialize(self, fl_ctx: FLContext):
        """Initialize the TensorClientStreamer component.

        Fetches the engine from the context, checks that it is a
        StreamableEngine and creates the TensorReceiver for task data. Any
        failure is reported through system_panic and leaves the component
        uninitialized.

        Args:
            fl_ctx (FLContext): The FLContext for the current operation.
        """
        engine: StreamableEngine = fl_ctx.get_engine()
        if not engine:
            self.system_panic(f"Engine not found. {self.__class__.__name__} exiting.", fl_ctx)
            return

        if not isinstance(engine, StreamableEngine):
            self.system_panic(
                f"Engine is not a StreamableEngine. {self.__class__.__name__} exiting.",
                fl_ctx,
            )
            return

        self.engine = engine
        try:
            self.receiver = TensorReceiver(engine, FLContextKey.TASK_DATA, self.format)
        except Exception as e:
            self.system_panic(str(e), fl_ctx)
            return

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        """Handle the framework events relevant to tensor streaming.

        - START_RUN: initialize the component.
        - BEFORE_TASK_DATA_FILTER: wait for the tensors of the current task and
          store them in the context.
        - AFTER_TASK_RESULT_FILTER: send the task result tensors to the server.

        Failures while receiving or sending are reported through system_panic.

        Args:
            event_type (str): The type of event to handle.
            fl_ctx (FLContext): The FLContext for the current operation.
        """
        if event_type == EventType.START_RUN:
            self.initialize(fl_ctx)
        elif event_type == EventType.BEFORE_TASK_DATA_FILTER:
            task_id = fl_ctx.get_prop(FLContextKey.TASK_ID)
            peer_name = fl_ctx.get_peer_context().get_identity_name()
            try:
                self.receiver.wait_for_tensors(task_id, peer_name)
                self.receiver.set_ctx_with_tensors(fl_ctx)
            except Exception as e:
                self.system_panic(str(e), fl_ctx)
        elif event_type == EventType.AFTER_TASK_RESULT_FILTER:
            try:
                self.send_tensors_to_server(fl_ctx)
            except Exception as e:
                self.system_panic(str(e), fl_ctx)

    def send_tensors_to_server(self, fl_ctx: FLContext):
        """Send the task result tensors to the server.

        A new TensorSender is created for each call. clean_task_result() is
        invoked only when the send completed without raising. The sender is
        always dropped afterwards, whether the call succeeded or failed.

        Args:
            fl_ctx (FLContext): The FLContext for the current operation.

        Raises:
            Exception: Anything raised by store_tensors(), by send() other than
                ValueError, or by clean_task_result() propagates to the caller.
        """
        self.sender = TensorSender(self.engine, FLContextKey.TASK_RESULT, self.format, self.tasks)
        try:
            # store_tensors() sits inside the outer try so that a failure here
            # also releases the sender (and the tensors it references).
            self.sender.store_tensors(fl_ctx)
            try:
                self.sender.send(fl_ctx, self.tensor_send_timeout)
            except ValueError as e:
                # Only a ValueError raised by send() is treated as "nothing to send".
                self.log_warning(fl_ctx, f"No tensors to send to server: {str(e)}")
            else:
                clean_task_result(fl_ctx)
        finally:
            # Clear sender to release any references to tensors
            self.sender = None