Source code for nvflare.edge.assessors.model_update

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import threading
import time
from typing import Optional

from nvflare.apis.event_type import EventType
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.app_common.abstract.learnable_persistor import LearnablePersistor
from nvflare.app_common.abstract.model import ModelLearnable, model_learnable_to_dxo
from nvflare.app_common.app_event_type import AppEventType
from nvflare.edge.assessor import Assessment, Assessor
from nvflare.edge.mud import BaseState, StateUpdateReply, StateUpdateReport


[docs] class ModelUpdateAssessor(Assessor): def __init__( self, persistor_id, model_manager_id, device_manager_id, max_model_version, device_wait_timeout: Optional[float] = None, device_status_log_interval: Optional[float] = 30.0, ): """Initialize the ModelUpdateAssessor. Enable both asynchronous and synchronous model updates from clients. For asynchronous updates, the staleness is calculated based on the starting and current model version. And the aggregation scheme and weights are calculated following FedBuff paper "Federated Learning with Buffered Asynchronous Aggregation". Args: persistor_id (str): ID of the persistor component used to load and save models. model_manager_id (str): ID of the model manager component. device_manager_id (str): ID of the device manager component. max_model_version (int): Maximum model version to stop the workflow. device_wait_timeout (float, optional): Timeout in seconds for waiting for sufficient devices. None means no timeout. Default is None. device_status_log_interval (float, optional): Interval in seconds for logging device status. Default is 30 seconds. """ Assessor.__init__(self) self.persistor_id = persistor_id self.model_manager_id = model_manager_id self.device_manager_id = device_manager_id self.persistor = None self.model_manager = None self.device_manager = None self.max_model_version = max_model_version self.update_lock = threading.Lock() self.device_wait_timeout = device_wait_timeout self.device_wait_start_time = None self._last_device_status_log_time = time.time() self.device_status_log_interval = device_status_log_interval self.register_event_handler(EventType.START_RUN, self._handle_start_run) def _is_device_wait_timeout_exceeded(self, fl_ctx: FLContext) -> bool: """Check if device wait timeout has been exceeded. Args: fl_ctx: FL context Returns: bool: True if timeout exceeded, False otherwise """ if self.device_wait_start_time is None or self.device_wait_timeout is None: return False try: elapsed = time.time() - self.device_wait_start_time if elapsed > self.device_wait_timeout: usable_devices = set(self.device_manager.get_available_devices(fl_ctx).keys()) - set( self.device_manager.get_used_devices(fl_ctx).keys() ) self.log_error( fl_ctx, f"Device wait timeout ({self.device_wait_timeout}s) exceeded. " f"Elapsed time: {elapsed:.1f}s. " f"Total devices: {len(self.device_manager.get_available_devices(fl_ctx))}, " f"usable: {len(usable_devices)}, " f"expected: {self.device_manager.device_selection_size}. " f"Device_reuse flag: {self.device_manager.device_reuse}. " "Stopping the job.", ) return True return False except Exception as e: self.log_error(fl_ctx, f"Error checking device timeout: {e}") return False def _log_device_status(self, fl_ctx: FLContext): """Log device status information independently of timeout logic.""" if self.device_status_log_interval is None: return current_time = time.time() elapsed = current_time - self._last_device_status_log_time if elapsed >= self.device_status_log_interval: usable_devices = set(self.device_manager.get_available_devices(fl_ctx).keys()) - set( self.device_manager.get_used_devices(fl_ctx).keys() ) # Add timeout info if we're actually waiting with a timeout timeout_msg = "" if self.device_wait_start_time is not None and self.device_wait_timeout is not None: remaining_time = self.device_wait_timeout - (current_time - self.device_wait_start_time) timeout_msg = f" Timeout in {remaining_time:.1f} seconds." elif self.device_wait_start_time is not None: timeout_msg = " No timeout set (waiting indefinitely)." self.log_info( fl_ctx, f"Device Status: " f"Total: {len(self.device_manager.available_devices)}, " f"usable: {len(usable_devices)}, " f"expected: {self.device_manager.device_selection_size}.{timeout_msg}", ) self._last_device_status_log_time = current_time def _handle_start_run(self, event_type: str, fl_ctx: FLContext): engine = fl_ctx.get_engine() # Get persistor component self.persistor = engine.get_component(self.persistor_id) if not isinstance(self.persistor, LearnablePersistor): self.system_panic(reason="persistor must be a Persistor type object", fl_ctx=fl_ctx) return # Get model manager component self.model_manager = engine.get_component(self.model_manager_id) if not self.model_manager: self.system_panic(reason=f"cannot find model manager component '{self.model_manager_id}'", fl_ctx=fl_ctx) return # Get device manager component self.device_manager = engine.get_component(self.device_manager_id) if not self.device_manager: self.system_panic(reason=f"cannot find device manager component '{self.device_manager_id}'", fl_ctx=fl_ctx) return if self.persistor: model = self.persistor.load(fl_ctx) if not isinstance(model, ModelLearnable): self.system_panic( reason=f"Expected model loaded by persistor to be `ModelLearnable` but received {type(model)}", fl_ctx=fl_ctx, ) return # Wrap learnable model into a DXO dxo_model = model_learnable_to_dxo(model) self.model_manager.initialize_model(dxo_model, fl_ctx) self.fire_event(AppEventType.INITIAL_MODEL_LOADED, fl_ctx) else: self.system_panic(reason="cannot find persistor component '{}'".format(self.persistor_id), fl_ctx=fl_ctx) return
[docs] def start_task(self, fl_ctx: FLContext) -> Shareable: # empty base state to start with base_state = BaseState( model_version=0, model=None, device_selection_version=0, device_selection={}, ) return base_state.to_shareable()
[docs] def process_child_update(self, update: Shareable, fl_ctx: FLContext) -> (bool, Optional[Shareable]): with self.update_lock: return self._do_child_update(update, fl_ctx)
def _do_child_update(self, update: Shareable, fl_ctx: FLContext) -> (bool, Optional[Shareable]): report = StateUpdateReport.from_shareable(update) # Update available devices if report.available_devices: self.device_manager.update_available_devices(report.available_devices, fl_ctx) # Reset wait timer if we now have enough devices if self.device_wait_start_time is not None and self.device_manager.has_enough_devices(fl_ctx): self.device_wait_start_time = None self.log_info(fl_ctx, "Sufficient devices now available, resetting wait timer") # Check for device wait timeout if we are waiting for devices if self.device_wait_start_time is not None and self._is_device_wait_timeout_exceeded(fl_ctx): # Timeout exceeded, prepare an empty reply and stop the job usable_devices = set(self.device_manager.get_available_devices(fl_ctx).keys()) - set( self.device_manager.get_used_devices(fl_ctx).keys() ) self.log_error( fl_ctx, f"Total devices: {len(self.device_manager.available_devices)}, usable: {len(usable_devices)}, expected: {self.device_manager.device_selection_size}. " f"Device_reuse flag is set to: {self.device_manager.device_reuse}. " "Not enough devices joining, please adjust the server params. Stopping the job.", ) reply = StateUpdateReply( model_version=0, model=None, device_selection_version=self.device_manager.current_selection_version, device_selection=self.device_manager.get_selection(fl_ctx), ) return False, reply.to_shareable() accepted = True if report.model_updates: self.log_info(fl_ctx, f"got reported {len(report.model_updates)} model versions") # Process model updates accepted = self.model_manager.process_updates(report.model_updates, fl_ctx) # Remove reported devices from selection for model_update in report.model_updates.values(): if model_update: self.device_manager.remove_devices_from_selection(set(model_update.devices.keys()), fl_ctx) # if device_reuse, remove devices from used_devices # indicating that the reported devices becomes available again for reuse if self.device_manager.device_reuse: self.device_manager.remove_devices_from_used(set(model_update.devices.keys()), fl_ctx) else: self.log_debug(fl_ctx, "no model updates") # Handle device selection if self.device_manager.should_fill_selection(fl_ctx): # check if we have enough devices to fill selection if self.device_manager.has_enough_devices(fl_ctx): if self.model_manager.current_model_version == 0: self.log_info(fl_ctx, "Generate initial model and fill selection") self.model_manager.generate_new_model(fl_ctx) self.device_manager.fill_selection(self.model_manager.current_model_version, fl_ctx) # prune old model versions that are no longer active active_model_versions = self.device_manager.get_active_model_versions(fl_ctx) self.model_manager.prune_model_versions(active_model_versions, fl_ctx) # Reset wait timer since we have enough devices self.device_wait_start_time = None else: # Start wait timer if not already started if self.device_wait_start_time is None: self.device_wait_start_time = time.time() self.log_info(fl_ctx, f"Starting device wait timer (timeout: {self.device_wait_timeout}s)") # Prepare reply model = None if self.model_manager.current_model_version != report.current_model_version: model = self.model_manager.get_current_model(fl_ctx) reply = StateUpdateReply( model_version=self.model_manager.current_model_version, model=model, device_selection_version=self.device_manager.current_selection_version, device_selection=self.device_manager.get_selection(fl_ctx), ) return accepted, reply.to_shareable()
[docs] def assess(self, fl_ctx: FLContext) -> Assessment: # Check if we're waiting for devices and timeout exceeded self._log_device_status(fl_ctx) if self._is_device_wait_timeout_exceeded(fl_ctx): self.log_error(fl_ctx, "Job stopped due to insufficient devices joining within timeout period") return Assessment.WORKFLOW_DONE elif self.model_manager.current_model_version >= self.max_model_version: model_version = self.model_manager.current_model_version self.log_info(fl_ctx, f"Max model version {self.max_model_version} reached: {model_version=}") return Assessment.WORKFLOW_DONE else: return Assessment.CONTINUE