# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import threading
import time
from typing import Optional, Tuple

from nvflare.apis.dxo import DXO
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.edge.aggregators.num_dxo import NumDXOAggregator
from nvflare.edge.assessor import Assessment, Assessor
from nvflare.edge.mud import BaseState, ModelUpdate, StateUpdateReply, StateUpdateReport
class _ModelState:
    """Per-model-version bookkeeping: the aggregator plus the devices that reported."""

    def __init__(self, aggr: NumDXOAggregator):
        """Create an empty state around the given aggregator.

        Args:
            aggr: aggregator that will receive every update accepted for this version.
        """
        # No update has been received yet.
        self.last_update_time = None
        # Devices that have contributed to this version, merged from each model update.
        self.devices = {}
        self.aggregator = aggr

    def accept(self, model_update: ModelUpdate, fl_ctx: FLContext):
        """Record a model update and hand its payload to the aggregator.

        The timestamp and the contributing devices are recorded before the
        aggregator is invoked, so they are kept regardless of the aggregator's verdict.

        Args:
            model_update: the reported update; its ``devices`` are merged into
                ``self.devices`` and its ``update`` is passed to the aggregator.
            fl_ctx: FL context forwarded to the aggregator.

        Returns:
            Whatever the aggregator's ``accept`` returns.
        """
        received_at = time.time()
        self.last_update_time = received_at
        self.devices.update(model_update.devices)
        verdict = self.aggregator.accept(model_update.update, fl_ctx)
        return verdict
class AsyncNumAssessor(Assessor):
    """Assessor that aggregates numeric updates asynchronously across model versions.

    The assessor keeps one ``_ModelState`` per live model version, maintains a
    selection of devices, and produces a new model version (a single number)
    once enough devices have reported against the current version.
    """

    def __init__(
        self,
        num_updates_for_model,
        max_model_version,
        max_model_history,
        device_selection_size,
        min_hole_to_fill=1,
        device_reuse=True,
    ):
        """Initialize the assessor.

        Args:
            num_updates_for_model: number of reporting devices recorded for the
                current model version that triggers generation of a new version.
            max_model_version: ``assess`` reports the workflow as done once the
                current model version reaches this value.
            max_model_history: model versions whose age (in versions) reaches this
                value are discarded when a new model is generated; child updates
                older than this are dropped.
            device_selection_size: target number of devices in the selection. Also
                the number of available devices required before the initial model
                is generated.
            min_hole_to_fill: minimum number of empty selection slots required
                before the selection is refilled after model updates are processed.
            device_reuse: if True, a device that reported an update is removed from
                the used-device record and can be selected again.
        """
        super().__init__()
        self.current_model_version = 0
        self.current_model = None
        self.current_selection_version = 0
        self.current_selection = {}  # device_id => model version it was selected for
        self.updates = {}  # model_version => _ModelState
        self.available_devices = {}
        self.used_devices = {}  # device_id => {"model_version", "selection_version"}
        self.num_updates_for_model = num_updates_for_model
        self.max_model_version = max_model_version
        self.max_model_history = max_model_history
        self.device_selection_size = device_selection_size
        self.min_hole_to_fill = min_hole_to_fill
        self.device_reuse = device_reuse
        # Serializes processing of child updates (see process_child_update).
        self.update_lock = threading.Lock()
        self.start_time = None

    def start_task(self, fl_ctx: FLContext) -> Shareable:
        """Record the task start time and return the current base state.

        Args:
            fl_ctx: FL context (not used here).

        Returns:
            A Shareable built from a ``BaseState`` holding the current model,
            model version, device selection and selection version.
        """
        self.start_time = time.time()
        base_state = BaseState(
            model_version=self.current_model_version,
            model=self.current_model,
            device_selection_version=self.current_selection_version,
            device_selection=self.current_selection,
        )
        return base_state.to_shareable()

    def _generate_new_model(self, fl_ctx: FLContext):
        """Advance the model version and compute the new model value.

        The new value is the sum, over all retained versions, of that version's
        average score (``value / count``, or 0.0 when nothing was counted)
        weighted by ``1 / age``, where age is the distance from the new version.
        Versions whose age reaches ``max_model_history`` are removed afterwards.

        Args:
            fl_ctx: FL context used for logging.
        """
        total = 0.0
        self.current_model_version += 1
        old_model_versions = []
        aggr_info = {}
        for v, ms in self.updates.items():
            # The version was incremented above and the new state is only added
            # after this loop, so every v is older than the current version.
            weight = 1 / (self.current_model_version - v)
            assert isinstance(ms, _ModelState)
            aggr = ms.aggregator
            assert isinstance(aggr, NumDXOAggregator)
            score = aggr.value / aggr.count if aggr.count > 0 else 0.0
            aggr_info[v] = {"weight": weight, "value": aggr.value, "count": aggr.count, "score": score}
            total += weight * score
            if self.current_model_version - v >= self.max_model_history:
                # Collected for later removal; the dict must not change while iterating.
                old_model_versions.append(v)

        # create the ModelState for the new model version
        self.updates[self.current_model_version] = _ModelState(NumDXOAggregator())
        self.log_info(fl_ctx, f"model version info: {aggr_info}")
        self.log_info(fl_ctx, f"generated new model version {self.current_model_version}: value={total}")

        for v in old_model_versions:
            self.updates.pop(v)

        if old_model_versions:
            self.log_info(fl_ctx, f"removed old model versions {old_model_versions}")

        self.current_model = DXO(data_kind="number", data={"value": total})

    def process_child_update(self, update: Shareable, fl_ctx: FLContext) -> Tuple[bool, Optional[Shareable]]:
        """Process a state update report from a child while holding the update lock.

        Args:
            update: Shareable form of a ``StateUpdateReport``.
            fl_ctx: FL context.

        Returns:
            A tuple ``(accepted, reply)`` as produced by ``_do_child_update``.
        """
        with self.update_lock:
            return self._do_child_update(update, fl_ctx)

    def _do_child_update(self, update: Shareable, fl_ctx: FLContext) -> Tuple[bool, Optional[Shareable]]:
        """Apply a child's report and build the reply.

        Steps: merge newly reported available devices; feed each reported model
        update to the matching version's state; free the reporting devices from
        the selection; generate a new model when the current version has enough
        reporting devices; refill the selection when enough slots are empty;
        generate the initial model once enough devices are available.

        Args:
            update: Shareable form of a ``StateUpdateReport``.
            fl_ctx: FL context.

        Returns:
            A tuple ``(accepted, reply)``. ``accepted`` is True unless the last
            processed model update was rejected. ``reply`` is the Shareable form
            of a ``StateUpdateReply``; it carries the model only when the child's
            reported model version differs from the current one.
        """
        report = StateUpdateReport.from_shareable(update)
        if report.available_devices:
            self.available_devices.update(report.available_devices)
            self.log_debug(
                fl_ctx,
                f"assessor got reported {len(report.available_devices)} available devices from child. "
                f"total num available devices: {len(self.available_devices)}",
            )

        accepted = True
        if report.model_updates:
            self.log_info(fl_ctx, f"got reported {len(report.model_updates)} model versions")
            for model_version, model_update in report.model_updates.items():
                if model_version <= 0:
                    continue

                if not model_update:
                    self.log_error(fl_ctx, f"bad child update version {model_version}: no update data")
                    continue

                # NOTE(review): this check uses ">" while _generate_new_model removes
                # versions with ">=", so an update exactly max_model_history versions
                # old passes here and is then reported below as having no model state.
                if self.current_model_version - model_version > self.max_model_history:
                    # this version is too old
                    self.log_info(
                        fl_ctx,
                        f"dropped child update version {model_version}. Current version {self.current_model_version}",
                    )
                    continue

                model_state = self.updates.get(model_version)
                if not model_state:
                    self.log_error(fl_ctx, f"No model state for version {model_version}")
                    continue

                # NOTE(review): the result of each version overwrites the previous
                # one, so only the last processed version decides the returned flag.
                # Confirm against the caller whether an earlier rejection should stick.
                accepted = model_state.accept(model_update, fl_ctx)
                self.log_info(
                    fl_ctx,
                    f"processed child update V{model_version} with {len(model_update.devices)} devices: {accepted=}",
                )

                # remove reported devices from selection, and from used devices if device_reuse is enabled
                # indicating that the reported devices becomes available again for reuse
                for k in model_update.devices:
                    if k not in self.current_selection:
                        self.log_error(
                            fl_ctx,
                            f"got update from device {k} but it's not in device selection",
                        )
                    self.current_selection.pop(k, None)
                    if self.device_reuse:
                        self.used_devices.pop(k, None)

            # Enough reporting devices for the current version => new model version.
            current_model_state = self.updates.get(self.current_model_version)
            if not isinstance(current_model_state, _ModelState):
                self.log_error(
                    fl_ctx, f"bad model state for version {self.current_model_version}: {type(current_model_state)}"
                )
            else:
                num_updates = len(current_model_state.devices)
                if num_updates >= self.num_updates_for_model:
                    self.log_info(
                        fl_ctx,
                        f"model V{self.current_model_version} got {num_updates} updates: generate new model version",
                    )
                    self._generate_new_model(fl_ctx)

            # recompute selection
            num_holes = self.device_selection_size - len(self.current_selection)
            if num_holes >= self.min_hole_to_fill:
                self._fill_selection(fl_ctx)
        else:
            self.log_debug(fl_ctx, "no model updates")

        # reply
        if self.current_model_version == 0:
            # do we have enough devices?
            if len(self.available_devices) >= self.device_selection_size:
                self.log_info(fl_ctx, f"got {len(self.available_devices)} devices - generate initial model")
                self._generate_new_model(fl_ctx)
                self._fill_selection(fl_ctx)

        # Only send the model when the child does not already have the current version.
        model = None
        if self.current_model_version != report.current_model_version:
            model = self.current_model

        reply = StateUpdateReply(
            model_version=self.current_model_version,
            model=model,
            device_selection_version=self.current_selection_version,
            device_selection=self.current_selection,
        )
        return accepted, reply.to_shareable()

    def _fill_selection(self, fl_ctx: FLContext):
        """Fill empty slots in the device selection with randomly chosen devices.

        Candidates are the available devices that are not recorded as used. When
        there is at least one empty slot the selection version is incremented,
        even if no candidate could be added. A warning is logged when the
        selection ends up smaller than ``device_selection_size``.

        Args:
            fl_ctx: FL context used for logging.
        """
        num_holes = self.device_selection_size - len(self.current_selection)
        self.log_info(fl_ctx, f"filling {num_holes} holes in selection list")
        if num_holes > 0:
            self.current_selection_version += 1

            # remove all used devices from available devices
            candidates = list(self.available_devices.keys() - self.used_devices.keys())

            # Draw all picks at once (uniformly, without replacement) instead of
            # rebuilding the candidate list for every single pick.
            num_picks = min(num_holes, len(candidates))
            for device_id in random.sample(candidates, num_picks):
                self.current_selection[device_id] = self.current_model_version
                self.used_devices[device_id] = {
                    "model_version": self.current_model_version,
                    "selection_version": self.current_selection_version,
                }

        self.log_info(
            fl_ctx,
            f"current selection with {len(self.current_selection)} items: V{self.current_selection_version}; {dict(sorted(self.current_selection.items()))}",
        )
        if len(self.current_selection) < self.device_selection_size:
            self.log_warning(
                fl_ctx,
                f"current selection has only {len(self.current_selection)} devices, which is less than the expected {self.device_selection_size} devices. Please check the configuration to make sure this is expected.",
            )

    def assess(self, fl_ctx: FLContext) -> Assessment:
        """Decide whether the workflow should continue.

        Args:
            fl_ctx: FL context used for logging.

        Returns:
            ``Assessment.WORKFLOW_DONE`` once the current model version has reached
            ``max_model_version``; otherwise ``Assessment.CONTINUE``.
        """
        if self.current_model_version >= self.max_model_version:
            model_version = self.current_model_version
            selection_version = self.current_selection_version
            self.log_info(
                fl_ctx,
                f"Max model version {self.max_model_version} reached: {model_version=} {selection_version=} "
                f"num of devices used: {len(self.used_devices)}",
            )
            return Assessment.WORKFLOW_DONE
        else:
            return Assessment.CONTINUE