Source code for nvflare.app_opt.tensor_stream.server

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
from collections import defaultdict
from threading import Lock

from nvflare.apis.event_type import EventType
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.streaming import StreamableEngine
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.app_event_type import AppEventType
from nvflare.client.config import ExchangeFormat

from .receiver import TensorReceiver
from .sender import TensorSender
from .utils import clean_task_data


[docs] class TensorServerStreamer(FLComponent): """Handles sending task data tensors to clients and receiving task results from clients. It uses a StreamableEngine, TensorReceiver, and TensorSender to manage tensor streaming on the server side. Attributes: format (str): The format of the tensors to send/receive. Default is "pytorch". tensor_send_timeout (float): Timeout for tensor entry transfer operations. Default is 30.0 seconds. engine (StreamableEngine): The StreamableEngine used for tensor streaming. sender (TensorSender): The TensorSender used to send tensors to clients. receiver (TensorReceiver): The TensorReceiver used to receive tensors from clients. start_sending_time (dict[int, float]): The timestamp when sending to clients started for the current round. seen_tasks (dict[int, set[str]]): The set of task IDs seen in the current round. num_task_data_sent (dict[int, int]): The number of task data sent to clients successfully for the current round. num_task_skipped (dict[int, int]): The number of task data skipped (not sent) to clients for the current round. data_cleaned (dict[int, bool]): Flag indicating whether the task data has been cleaned from the FLContext for the current round. lock (Lock): A lock to protect shared data structures. Methods: initialize(fl_ctx): Initializes the TensorServerStreamer component. handle_event(event_type, fl_ctx): Handles events for the TensorSender component. send_tensors_to_client(fl_ctx): Sends tensors to the client after task data filtering. clean_counters(current_round): Cleans the counters for the current round. wait_clients_to_complete(num_clients, fl_ctx): Waits until all clients have received the tensors or timeout occurs. try_to_clean_task_data(fl_ctx): Cleans the task data in the FLContext if all clients have received the tensors. """ # Buffer time added to wait_send_task_data_all_clients_timeout when calculating minimum get_task_timeout # This accounts for network latency, processing time, and other overhead GET_TASK_TIMEOUT_BUFFER = 60.0 # seconds def __init__( self, format: str = ExchangeFormat.PYTORCH, tasks: list[str] = None, tensor_send_timeout: float = 30.0, wait_send_task_data_all_clients_timeout: float = 300.0, ): """Initialize the TensorServerStreamer component. The server automatically communicates the required minimum get_task_timeout to clients to prevent fast clients from timing out while waiting for slow clients to receive tensors. Background: Fast clients finish receiving tensors early and immediately request the next task. However, the server blocks waiting for all clients to receive tensors (up to wait_send_task_data_all_clients_timeout). Without proper timeout configuration, fast clients would timeout and fail. Automatic Timeout Management: - Server calculates: min_timeout = wait_send_task_data_all_clients_timeout + 60s buffer - Server sends this requirement to clients in task responses - Clients automatically adjust their get_task_timeout if it's too small - Transparent logging shows when auto-adjustment occurs Optional Manual Configuration: Users can still explicitly set get_task_timeout in config_fed_client.json to override the automatic behavior if needed: { "get_task_timeout": 400.0 // Explicit override } Args: format (str): The format of the tensors to send/receive. Default is ExchangeFormat.PYTORCH. tasks (list[str]): The list of tasks to send tensors for. Default is None, which means the "train" task. tensor_send_timeout (float): Timeout for each tensor chunk transfer operation. Default is 30.0 seconds. wait_send_task_data_all_clients_timeout (float): Maximum time to wait for all clients to receive task tensors. Default is 300.0 seconds. """ super().__init__() self.format = format self.tasks = tasks if tasks is not None else ["train"] self.tensor_send_timeout = tensor_send_timeout self.wait_task_data_sent_to_all_clients_timeout = wait_send_task_data_all_clients_timeout self.engine: StreamableEngine = None self.sender: TensorSender = None self.receiver: TensorReceiver = None self.start_sending_time: dict[int, float] = defaultdict(float) self.seen_tasks: dict[int, set[str]] = defaultdict(set) self.num_task_data_sent: dict[int, int] = defaultdict(int) self.num_task_skipped: dict[int, int] = defaultdict(int) self.data_cleaned: dict[int, bool] = defaultdict(bool) self.lock = Lock()
[docs] def initialize(self, fl_ctx: FLContext): """Initialize the TensorServerSender component. Args: fl_ctx (FLContext): The FLContext for the current operation. """ engine: StreamableEngine = fl_ctx.get_engine() if not engine: self.system_panic(f"Engine not found. {self.__class__.__name__} exiting.", fl_ctx) return if not isinstance(engine, StreamableEngine): self.system_panic( f"Engine is not a StreamableEngine. {self.__class__.__name__} exiting.", fl_ctx, ) return self.engine = engine try: self.receiver = TensorReceiver(engine, FLContextKey.TASK_RESULT, self.format) self.sender = TensorSender(engine, FLContextKey.TASK_DATA, self.format, self.tasks) except Exception as e: self.system_panic(str(e), fl_ctx) return # Set minimum get_task_timeout requirement for clients # This will be automatically communicated to clients in task responses via GetTaskCommand recommended_get_task_timeout = self.wait_task_data_sent_to_all_clients_timeout + self.GET_TASK_TIMEOUT_BUFFER fl_ctx.set_prop(FLContextKey.MIN_GET_TASK_TIMEOUT, recommended_get_task_timeout, sticky=True) self.log_info( fl_ctx, f"TensorServerStreamer: Requiring clients to use get_task_timeout >= {recommended_get_task_timeout}s " f"(wait_send_task_data_all_clients_timeout={self.wait_task_data_sent_to_all_clients_timeout}s " f"+ {self.GET_TASK_TIMEOUT_BUFFER}s buffer). " f"This will be automatically communicated to clients.", )
[docs] def handle_event(self, event_type: str, fl_ctx: FLContext): """Handle events for the TensorServerStreamer component. Args: event_type (str): The type of event to handle. fl_ctx (FLContext): The FLContext for the current operation. """ if event_type == EventType.START_RUN: self.initialize(fl_ctx) elif event_type == EventType.BEFORE_TASK_DATA_FILTER: current_round = fl_ctx.get_prop(AppConstants.CURRENT_ROUND) task_id = fl_ctx.get_prop(FLContextKey.TASK_ID) self.seen_tasks[current_round].add(task_id) elif event_type == EventType.AFTER_TASK_DATA_FILTER: # Store tensors after filtering (to get the filtered reference) # Then send to each client self.sender.store_tensors(fl_ctx) self.send_tensors_to_client(fl_ctx) num_clients = len(self.engine.get_clients()) self.wait_sending_task_data_all_clients(num_clients, fl_ctx) self.try_to_clean_task_data(num_clients, fl_ctx) elif event_type == EventType.BEFORE_TASK_RESULT_FILTER: task_id = fl_ctx.get_prop(FLContextKey.TASK_ID) peer_name = fl_ctx.get_peer_context().get_identity_name() try: self.receiver.wait_for_tensors(task_id, peer_name) self.receiver.set_ctx_with_tensors(fl_ctx) except Exception as e: self.system_panic(str(e), fl_ctx) elif event_type == AppEventType.ROUND_DONE: current_round = fl_ctx.get_prop(AppConstants.CURRENT_ROUND) # Clear received tensors in case they were set back to the FLContext # it can happen when the aggregator only accepts part of the clients for task_id in self.seen_tasks[current_round]: if task_id in self.receiver.tensors: self.receiver.tensors.pop(task_id) self.clean_counters(current_round)
[docs] def clean_counters(self, current_round: int): """Clean the counters for the current round. Args: current_round (int): The current round number. """ with self.lock: self.num_task_data_sent.pop(current_round, None) self.num_task_skipped.pop(current_round, None) self.data_cleaned.pop(current_round, None) self.seen_tasks.pop(current_round, None) self.start_sending_time.pop(current_round, None)
[docs] def send_tensors_to_client(self, fl_ctx: FLContext): """Send tensors to the client after task data filtering. Args: fl_ctx (FLContext): The FLContext for the current operation. """ current_round = fl_ctx.get_prop(AppConstants.CURRENT_ROUND) with self.lock: if not self.start_sending_time.get(current_round): self.start_sending_time[current_round] = time.time() try: self.sender.send(fl_ctx, self.tensor_send_timeout) except ValueError as e: self.log_warning(fl_ctx, f"No tensors to send to client: {str(e)}") success = False else: success = True with self.lock: if success: self.num_task_data_sent[current_round] += 1 else: self.num_task_skipped[current_round] += 1
[docs] def wait_sending_task_data_all_clients(self, num_clients: int, fl_ctx: FLContext): """Wait until all clients have received the task data tensors or timeout occurs. Args: num_clients (int): The number of clients to wait for. fl_ctx (FLContext): The FLContext for the current operation. Raises: TimeoutError: If not all clients have received the tensors within the timeout period. """ current_round = fl_ctx.get_prop(AppConstants.CURRENT_ROUND) wait_timeout = self.wait_task_data_sent_to_all_clients_timeout while True: time.sleep(0.1) num_processed = self.num_task_data_sent[current_round] + self.num_task_skipped[current_round] if num_processed >= num_clients: return if time.time() - self.start_sending_time[current_round] > wait_timeout: self.system_panic( "Timeout waiting for all clients to receive tensors. " f"Sent to {self.num_task_data_sent[current_round]} out of {num_clients}," f" skipped {self.num_task_skipped[current_round]}.", fl_ctx, ) return
[docs] def try_to_clean_task_data(self, num_clients: int, fl_ctx: FLContext): """Clean the task data in the FLContext. Args: num_clients (int): The number of clients to wait for. fl_ctx (FLContext): The FLContext to clean the task data from. """ current_round = fl_ctx.get_prop(AppConstants.CURRENT_ROUND) # only clean if we successfully sent to all clients with self.lock: if not self.data_cleaned[current_round] and self.num_task_data_sent[current_round] >= num_clients: self.log_info( fl_ctx, f"Tensors were sent to all clients, removing them from task data. " f"Sent {self.num_task_data_sent[current_round]} out of {num_clients}", ) clean_task_data(fl_ctx) self.data_cleaned[current_round] = True