# Source code for nvflare.edge.assessors.buff_model_manager

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
from typing import Dict, Optional, Set

import numpy as np

from nvflare.apis.dxo import DXO, DataKind
from nvflare.apis.fl_context import FLContext
from nvflare.app_common.abstract.model import make_model_learnable
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.app_event_type import AppEventType
from nvflare.edge.aggregators.model_update_dxo import ModelUpdateDXOAggregator
from nvflare.edge.assessors.model_manager import ModelManager
from nvflare.edge.mud import ModelUpdate


class _ModelState:
    """Bookkeeping for one model version.

    Holds the aggregator that accumulates the updates reported for this
    version, the devices seen so far, and the time of the latest update.
    """

    def __init__(self, aggr: ModelUpdateDXOAggregator):
        # No update has been received yet, hence no timestamp.
        self.last_update_time = None
        # Merged view of the ``devices`` mappings of all accepted updates.
        self.devices = {}
        self.aggregator = aggr

    def accept(self, model_update: ModelUpdate, fl_ctx: FLContext):
        """Record a model update and forward its payload to the aggregator.

        Args:
            model_update: the update to record; its ``devices`` are merged into
                ``self.devices`` and its ``update`` is handed to the aggregator.
            fl_ctx: FL context passed through to the aggregator.

        Returns:
            Whatever the aggregator's ``accept`` call returns.
        """
        received_at = time.time()
        self.last_update_time = received_at

        reported_devices = model_update.devices
        self.devices.update(reported_devices)

        outcome = self.aggregator.accept(model_update.update, fl_ctx)
        return outcome


class BuffModelManager(ModelManager):
    """Model manager that builds new global models from buffered device updates.

    Updates are accumulated per model version in ``self.updates`` (a dict of
    model version to ``_ModelState``). Once enough device updates have been
    counted, the buffered updates are combined and added to the current global
    weights to produce the next model version.

    NOTE(review): ``self.updates``, ``self.current_model``,
    ``self.current_model_version`` and the ``log_*`` / ``fire_event`` helpers
    are presumably provided by ``ModelManager``; its definition is not visible
    in this file.
    """

    def __init__(
        self,
        num_updates_for_model: int,
        max_model_history: Optional[int] = None,
        global_lr: float = 1.0,
        staleness_weight: bool = False,
    ):
        """Initialize the ModelManager.

        The aggregation scheme and weights are calculated following the FedBuff paper
        "Federated Learning with Buffered Asynchronous Aggregation".

        The staleness_weight can be enabled to apply staleness weighting to model updates.

        Special cases for max_model_history:
        - If None: keep every model version, and only remove a version when all devices
          processing it have reported back (the version is no longer related to any
          device_id in the current_selection from device_manager).

        Args:
            num_updates_for_model (int): Number of updates required before generating a new model version.
            max_model_history (int): Maximum number of historical model versions to keep in memory.
                - None (default): keep every version until all devices processing a particular
                  version report back.
                - positive integer: keep only the latest n versions.
            global_lr (float): Global learning rate for model aggregation, default is 1.0.
            staleness_weight (bool): Whether to apply staleness weighting to model updates,
                default is False.
        """
        super().__init__()
        self.num_updates_for_model = num_updates_for_model
        # Number of device updates counted since the last model was generated.
        self.num_updates_counter = 0
        self.max_model_history = max_model_history
        self.global_lr = global_lr
        self.staleness_weight = staleness_weight

    def initialize_model(self, model: DXO, fl_ctx: FLContext):
        """Set the initial global model and create its update buffer.

        Args:
            model: DXO holding the initial global weights.
            fl_ctx: FL context (not used here).
        """
        self.current_model = model
        # updates is a dict of model version to _ModelState
        self.updates[self.current_model_version] = _ModelState(ModelUpdateDXOAggregator())

    def prune_model_versions(self, versions_to_keep: Set[int], fl_ctx: FLContext) -> None:
        """Drop the buffered state of model versions that are no longer needed.

        A version is removed when it is not in ``versions_to_keep``, or when
        ``max_model_history`` is set and the version is at least that many
        versions behind the current one.

        Args:
            versions_to_keep: versions that must be retained (unless too old).
            fl_ctx: FL context used for logging.
        """
        history = self.max_model_history
        newest = self.current_model_version
        versions_to_remove = {
            v for v in self.updates if v not in versions_to_keep or (history and newest - v >= history)
        }

        # Remove in ascending order so the log output is deterministic.
        for v in sorted(versions_to_remove):
            self.log_info(fl_ctx, f"removed model version {v}")
            self.updates.pop(v)

        # log the current total number of model versions
        self.log_info(fl_ctx, f"current total number of active model versions: {len(self.updates)}")

    def _aggregate_buffered_updates(self, fl_ctx: FLContext):
        """Combine all buffered per-version updates into one delta and reset the buffers.

        Each version's accumulated values are divided by that version's update
        count and, when ``staleness_weight`` is enabled, scaled by
        ``1 / (1 + sqrt(current_model_version - version))``. The caller bumps
        ``current_model_version`` before calling this, so updates for the
        immediately preceding version are weighted with a staleness of 1.

        Returns:
            A ``(delta, num_updates)`` tuple: ``delta`` maps each key to its
            combined value, ``num_updates`` is the total count consumed.
        """
        delta = {}
        num_updates = 0
        for version, state in self.updates.items():
            aggr = state.aggregator
            count = aggr.count
            if count <= 0:
                # Nothing buffered for this version.
                continue

            if self.staleness_weight:
                weight = 1 / (1 + (self.current_model_version - version) ** 0.5)
            else:
                weight = 1.0

            for key, value in aggr.dict.items():
                contribution = weight * value / count
                if key in delta:
                    delta[key] = delta[key] + contribution
                else:
                    delta[key] = contribution

            # Reset aggr after counting its contribution
            aggr.reset(fl_ctx)
            num_updates += count
        return delta, num_updates

    def generate_new_model(self, fl_ctx: FLContext) -> None:
        """Create the next model version from the current weights and buffered updates.

        Version 1 is the initial global weights unchanged. For later versions
        the combined delta (scaled by ``global_lr``) is added to the current
        global weights. The new model is stored as the current model, a fresh
        update buffer is created for it, the update counter is reset, and
        ``GLOBAL_WEIGHTS_UPDATED`` is fired with the model set in ``fl_ctx``.

        Args:
            fl_ctx: FL context used for logging, props and the fired event.
        """
        self.current_model_version += 1
        # counter to confirm the number of updates
        num_updates = 0
        global_weights = self.current_model.data

        if self.current_model_version == 1:
            # Initial global weights
            new_model = global_weights
        else:
            delta, num_updates = self._aggregate_buffered_updates(fl_ctx)

            # Start from the current global weights so that parameters not
            # touched by any buffered update are carried forward instead of
            # being dropped (with no updates the model would become empty).
            new_model = dict(global_weights)
            for key, value in delta.items():
                if key not in global_weights:
                    self.log_error(fl_ctx, f"key {key} not in global model")
                    # As before, an unknown key keeps its aggregated value as-is.
                    new_model[key] = value
                    continue
                new_model[key] = np.array(global_weights[key]) + value * self.global_lr

        # create the ModelState for the new model version
        self.updates[self.current_model_version] = _ModelState(ModelUpdateDXOAggregator())
        self.log_info(fl_ctx, f"generated new model version {self.current_model_version} with {num_updates} updates")

        # update the current model
        # convert new_model items from numpy arrays to lists for serialization
        new_model = {k: v.tolist() if isinstance(v, np.ndarray) else v for k, v in new_model.items()}
        self.current_model = DXO(data_kind=DataKind.WEIGHTS, data=new_model)

        # reset the num_updates_counter
        self.num_updates_counter = 0

        # set fl_ctx and fire the event
        # wrap new_model to a learnable
        learnable = make_model_learnable(new_model, {})
        fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, learnable, private=True, sticky=True)
        fl_ctx.set_prop(AppConstants.CURRENT_ROUND, self.current_model_version, private=True, sticky=True)
        self.fire_event(AppEventType.GLOBAL_WEIGHTS_UPDATED, fl_ctx)

    def process_updates(self, model_updates: Dict[int, ModelUpdate], fl_ctx: FLContext) -> bool:
        """Buffer a batch of child updates and generate a new model when enough arrived.

        Updates are skipped when the version is not positive, the update is
        empty, the version is too old with respect to ``max_model_history``,
        or no buffered state exists for the version.

        Args:
            model_updates: dict of model version to the update for that version.
            fl_ctx: FL context used for logging and aggregation.

        Returns:
            The result of the last aggregator ``accept`` call, or True when no
            update reached the aggregator.
        """
        accepted = True
        for model_version, model_update in model_updates.items():
            if model_version <= 0:
                continue

            if not model_update:
                self.log_error(fl_ctx, f"bad child update version {model_version}: no update data")
                continue

            # Check if version is too old before accepting
            if self.max_model_history:
                # if max_model_history is set, output warning for updates that are too old
                if self.current_model_version - model_version >= self.max_model_history:
                    self.log_warning(
                        fl_ctx,
                        f"dropped child update version {model_version}. Current version {self.current_model_version}. Max history {self.max_model_history}",
                    )
                    continue

            # Accept the update and aggregate it to the corresponding model version
            model_state = self.updates.get(model_version)
            if model_state is None:
                # The version was pruned or never existed: skip it rather than
                # failing the whole batch on a None dereference.
                self.log_warning(
                    fl_ctx,
                    f"dropped child update version {model_version}: no model state for this version",
                )
                continue

            accepted = model_state.accept(model_update, fl_ctx)
            self.log_info(
                fl_ctx,
                f"processed child update V{model_version} with {len(model_update.devices)} devices: {accepted=}",
            )
            # update the global num_updates_counter
            self.num_updates_counter += len(model_update.devices)

        current_model_state = self.updates.get(self.current_model_version)
        if isinstance(current_model_state, _ModelState):
            if self.num_updates_counter >= self.num_updates_for_model:
                self.log_info(
                    fl_ctx,
                    f"Globally got {self.num_updates_counter} updates: generate new model version",
                )
                self.generate_new_model(fl_ctx)

        return accepted