Source code for nvflare.fuel.f3.streaming.transfer_progress

# Copyright (c) 2026, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import time
from dataclasses import dataclass
from typing import Callable, Dict, Iterable, Mapping, Optional, Tuple

from nvflare.fuel.utils.validation_utils import check_positive_number

DIRECTION_TASK_PAYLOAD_DOWNLOAD = "task_payload_download"
DIRECTION_RESULT_UPLOAD = "result_upload"
VALID_TRANSFER_DIRECTIONS = {DIRECTION_TASK_PAYLOAD_DOWNLOAD, DIRECTION_RESULT_UPLOAD}

STREAMING_IDLE_TIMEOUT = "streaming_idle_timeout"
STREAMING_MAX_PEER_SILENCE = "streaming_max_peer_silence"

DEFAULT_STREAMING_IDLE_TIMEOUT = 600.0
DEFAULT_STREAMING_MAX_PEER_SILENCE = 900.0
STREAMING_MAX_PEER_SILENCE_IDLE_MULTIPLIER = 1.5
STREAM_PROGRESS_COMPLETION_ACK_GRACE = 30.0
_RECEIVER_ID_UNSET = object()


[docs] def check_positive_finite_number(name: str, value) -> float: check_positive_number(name, value) value = float(value) if not math.isfinite(value): raise ValueError(f"{name} must be finite, but got {value}") return value
[docs] class TransferProgressState: ACTIVE = "active" COMPLETED = "completed" FAILED = "failed" ABORTED = "aborted" TERMINAL_STATES = {COMPLETED, FAILED, ABORTED} VALID_STATES = {ACTIVE, COMPLETED, FAILED, ABORTED}
TransferProgressKey = Tuple[str, str, str, str, Optional[str]]
[docs] @dataclass(frozen=True) class StreamingProgressConfig: streaming_idle_timeout: float = DEFAULT_STREAMING_IDLE_TIMEOUT streaming_max_peer_silence: float = DEFAULT_STREAMING_MAX_PEER_SILENCE
[docs] @dataclass class TransferProgressRecord: job_id: str task_id: str transfer_id: str direction: str receiver_id: Optional[str] sequence: int bytes_done: int items_done: Optional[int] started_time: float last_progress_time: float state: str = TransferProgressState.ACTIVE transfer_id_kind: Optional[str] = None @property def key(self) -> TransferProgressKey: return self.job_id, self.task_id, self.transfer_id, self.direction, self.receiver_id @property def terminal(self) -> bool: return self.state in TransferProgressState.TERMINAL_STATES
[docs] @dataclass(frozen=True) class TransferProgressUpdate: accepted: bool progressed: bool record: Optional[TransferProgressRecord] reason: str = ""
[docs] def resolve_streaming_progress_config( config: Optional[Mapping[str, object]] = None, *, streaming_idle_timeout: Optional[float] = None, streaming_max_peer_silence: Optional[float] = None, ) -> StreamingProgressConfig: """Resolve generic streaming progress timeouts. Explicit keyword values take precedence over mapping values. When the idle timeout is raised above the default and max peer silence is not explicit, derive max silence from the idle timeout. """ config = config or {} idle_explicit = streaming_idle_timeout is not None or STREAMING_IDLE_TIMEOUT in config if streaming_idle_timeout is None: streaming_idle_timeout = config.get(STREAMING_IDLE_TIMEOUT, DEFAULT_STREAMING_IDLE_TIMEOUT) if streaming_max_peer_silence is None: streaming_max_peer_silence = config.get(STREAMING_MAX_PEER_SILENCE) streaming_idle_timeout = check_positive_finite_number(STREAMING_IDLE_TIMEOUT, streaming_idle_timeout) if streaming_max_peer_silence is None: if idle_explicit and streaming_idle_timeout > DEFAULT_STREAMING_IDLE_TIMEOUT: streaming_max_peer_silence = max( DEFAULT_STREAMING_MAX_PEER_SILENCE, STREAMING_MAX_PEER_SILENCE_IDLE_MULTIPLIER * streaming_idle_timeout, ) else: streaming_max_peer_silence = DEFAULT_STREAMING_MAX_PEER_SILENCE else: streaming_max_peer_silence = check_positive_finite_number( STREAMING_MAX_PEER_SILENCE, streaming_max_peer_silence ) return StreamingProgressConfig( streaming_idle_timeout=streaming_idle_timeout, streaming_max_peer_silence=streaming_max_peer_silence, )
[docs] class TransferProgressTracker: """Direction-neutral monotonic progress tracker for streamed transfers.""" def __init__( self, *, idle_timeout: float = DEFAULT_STREAMING_IDLE_TIMEOUT, clock: Optional[Callable[[], float]] = None, ): self.set_idle_timeout(idle_timeout) self._clock = clock or time.time self._records: Dict[TransferProgressKey, TransferProgressRecord] = {}
[docs] def set_idle_timeout(self, idle_timeout: float): self.idle_timeout = check_positive_finite_number("idle_timeout", idle_timeout)
[docs] def update( self, *, job_id: str, task_id: str, transfer_id: str, direction: str, sequence: int, bytes_done: int, items_done: Optional[int] = None, state: str = TransferProgressState.ACTIVE, transfer_id_kind: Optional[str] = None, receiver_id: Optional[str] = None, timestamp: Optional[float] = None, ) -> TransferProgressUpdate: self._validate_direction(direction) self._validate_update(sequence=sequence, bytes_done=bytes_done, items_done=items_done, state=state) now = self._clock() if timestamp is None else float(timestamp) receiver_id = self._normalize_receiver_id(direction, receiver_id) key = self._make_key( job_id=job_id, task_id=task_id, transfer_id=transfer_id, direction=direction, receiver_id=receiver_id, ) record = self._records.get(key) if record is None: new_record = TransferProgressRecord( job_id=job_id, task_id=task_id, transfer_id=transfer_id, direction=direction, receiver_id=receiver_id, sequence=sequence, bytes_done=bytes_done, items_done=items_done, started_time=now, last_progress_time=now, state=state, transfer_id_kind=transfer_id_kind, ) self._records[key] = new_record return TransferProgressUpdate(accepted=True, progressed=True, record=new_record) if record.terminal: return TransferProgressUpdate(accepted=False, progressed=False, record=record, reason="terminal") if sequence <= record.sequence: return TransferProgressUpdate(accepted=False, progressed=False, record=record, reason="stale_sequence") if bytes_done < record.bytes_done: return TransferProgressUpdate(accepted=False, progressed=False, record=record, reason="bytes_regressed") if record.items_done is not None and items_done is not None and items_done < record.items_done: return TransferProgressUpdate(accepted=False, progressed=False, record=record, reason="items_regressed") next_items_done = self._next_items_done(record.items_done, items_done) counters_advanced = bytes_done > record.bytes_done or self._items_advanced(record.items_done, next_items_done) terminal_observed = state in TransferProgressState.TERMINAL_STATES progressed = counters_advanced or terminal_observed record.sequence = sequence record.bytes_done = max(record.bytes_done, bytes_done) record.items_done = next_items_done record.state = state if transfer_id_kind and not record.transfer_id_kind: record.transfer_id_kind = transfer_id_kind if progressed: record.last_progress_time = max(record.last_progress_time, now) return TransferProgressUpdate(accepted=True, progressed=progressed, record=record)
[docs] def get_record( self, *, job_id: str, task_id: str, transfer_id: str, direction: str, receiver_id: Optional[str] = None, ) -> Optional[TransferProgressRecord]: return self._records.get( self._make_key( job_id=job_id, task_id=task_id, transfer_id=transfer_id, direction=direction, receiver_id=receiver_id, ) )
[docs] def records( self, *, job_id: Optional[str] = None, task_id: Optional[str] = None, direction: Optional[str] = None, receiver_id: object = _RECEIVER_ID_UNSET, ) -> Iterable[TransferProgressRecord]: filter_receiver = receiver_id is not _RECEIVER_ID_UNSET normalized_receiver_id = None if filter_receiver: normalized_receiver_id = self._normalize_receiver_id(direction, receiver_id) if direction else receiver_id return [ record for record in self._records.values() if (job_id is None or record.job_id == job_id) and (task_id is None or record.task_id == task_id) and (direction is None or record.direction == direction) and (not filter_receiver or record.receiver_id == normalized_receiver_id) ]
[docs] def is_stalled( self, *, job_id: str, task_id: str, transfer_id: str, direction: str, receiver_id: Optional[str] = None, now: Optional[float] = None, ) -> bool: record = self.get_record( job_id=job_id, task_id=task_id, transfer_id=transfer_id, direction=direction, receiver_id=receiver_id, ) if record is None or record.terminal: return False return self._is_record_stalled(record, self._clock() if now is None else now)
[docs] def stalled_records( self, now: Optional[float] = None, *, job_id: Optional[str] = None, task_id: Optional[str] = None, direction: Optional[str] = None, receiver_id: object = _RECEIVER_ID_UNSET, ) -> Iterable[TransferProgressRecord]: check_time = self._clock() if now is None else now return [ record for record in self.records(job_id=job_id, task_id=task_id, direction=direction, receiver_id=receiver_id) if self._is_record_stalled(record, check_time) ]
[docs] def mark_terminal( self, *, job_id: str, task_id: str, transfer_id: str, direction: str, state: str, sequence: Optional[int] = None, receiver_id: Optional[str] = None, timestamp: Optional[float] = None, ) -> TransferProgressUpdate: if state not in TransferProgressState.TERMINAL_STATES: raise ValueError(f"terminal state must be one of {TransferProgressState.TERMINAL_STATES}, but got {state}") record = self.get_record( job_id=job_id, task_id=task_id, transfer_id=transfer_id, direction=direction, receiver_id=receiver_id, ) if record is None: if sequence is None: sequence = 0 return self.update( job_id=job_id, task_id=task_id, transfer_id=transfer_id, direction=direction, receiver_id=receiver_id, sequence=sequence, bytes_done=0, items_done=None, state=state, timestamp=timestamp, ) if sequence is None: sequence = record.sequence + 1 return self.update( job_id=job_id, task_id=task_id, transfer_id=transfer_id, direction=direction, receiver_id=receiver_id, sequence=sequence, bytes_done=record.bytes_done, items_done=record.items_done, state=state, transfer_id_kind=record.transfer_id_kind, timestamp=timestamp, )
[docs] def prune( self, *, before_time: Optional[float] = None, include_active: bool = False, direction: Optional[str] = None, ) -> int: if before_time is None: before_time = self._clock() if direction is not None: self._validate_direction(direction) keys_to_remove = [ key for key, record in self._records.items() if (direction is None or record.direction == direction) and (include_active or record.terminal) and record.last_progress_time <= before_time ] for key in keys_to_remove: del self._records[key] return len(keys_to_remove)
[docs] def remove( self, *, job_id: str, task_id: str, transfer_id: str, direction: str, receiver_id: Optional[str] = None, ) -> bool: return ( self._records.pop( self._make_key( job_id=job_id, task_id=task_id, transfer_id=transfer_id, direction=direction, receiver_id=receiver_id, ), None, ) is not None )
[docs] def clear(self): self._records.clear()
@staticmethod def _validate_update(*, sequence: int, bytes_done: int, items_done: Optional[int], state: str): if not isinstance(sequence, int): raise TypeError(f"sequence must be an int, but got {type(sequence)}.") if sequence < 0: raise ValueError(f"sequence must >= 0, but got {sequence}") if not isinstance(bytes_done, int): raise TypeError(f"bytes_done must be an int, but got {type(bytes_done)}.") if bytes_done < 0: raise ValueError(f"bytes_done must >= 0, but got {bytes_done}") if items_done is not None: if not isinstance(items_done, int): raise TypeError(f"items_done must be an int, but got {type(items_done)}.") if items_done < 0: raise ValueError(f"items_done must >= 0, but got {items_done}") if state not in TransferProgressState.VALID_STATES: raise ValueError(f"state must be one of {TransferProgressState.VALID_STATES}, but got {state}") @staticmethod def _validate_direction(direction: str): if direction not in VALID_TRANSFER_DIRECTIONS: raise ValueError(f"direction must be one of {VALID_TRANSFER_DIRECTIONS}, but got {direction}") @staticmethod def _next_items_done(current_items_done: Optional[int], update_items_done: Optional[int]) -> Optional[int]: if current_items_done is None: return update_items_done if update_items_done is None: return current_items_done return max(current_items_done, update_items_done) @staticmethod def _items_advanced(current_items_done: Optional[int], next_items_done: Optional[int]) -> bool: if next_items_done is None: return False if current_items_done is None: return next_items_done > 0 return next_items_done > current_items_done @staticmethod def _normalize_receiver_id(direction: Optional[str], receiver_id: Optional[str]) -> Optional[str]: if direction != DIRECTION_RESULT_UPLOAD: return None return None if receiver_id is None else str(receiver_id) @classmethod def _make_key( cls, *, job_id: str, task_id: str, transfer_id: str, direction: str, receiver_id: Optional[str] = None, ) -> TransferProgressKey: return job_id, task_id, transfer_id, direction, cls._normalize_receiver_id(direction, receiver_id) def _is_record_stalled(self, record: TransferProgressRecord, now: float) -> bool: return not record.terminal and now - record.last_progress_time >= self.idle_timeout