# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from typing import Dict, Set
from nvflare.edge.assessors.device_manager import DeviceManager
from nvflare.fuel.utils.validation_utils import check_positive_int
class BuffDeviceManager(DeviceManager):
    """Device manager that keeps a fixed-size selection of devices and refills it as slots free up.

    Devices are drawn at random from the pool of available devices that have not been used yet.
    A slot ("hole") opens whenever a device is removed from the current selection; once enough
    holes have accumulated (see ``min_hole_to_fill``), the selection is topped up again.
    """

    def __init__(
        self,
        device_selection_size: int,
        min_hole_to_fill: int = 1,
        device_reuse: bool = True,
    ):
        """Initialize the BuffDeviceManager.

        BuffDeviceManager is responsible for managing the selection of devices for model training.
        It maintains a list of available devices, tracks the current selection, and refills the
        selection as needed.

        The device_selection_size determines how many "concurrent" devices can be selected for the
        training session. The min_hole_to_fill determines how many empty slots must exist before
        the selection is refilled:

        - An empty slot is created when any device reports its update back.
        - To fill a slot, a new device is selected from the available device pool.

        The device_reuse flag indicates whether devices can be reused across different model
        versions; if False, new devices are always selected when filling holes.

        Args:
            device_selection_size (int): Number of devices to keep in the selection, i.e. the
                number of devices training concurrently.
            min_hole_to_fill (int): Minimum number of empty slots in the selection before
                refilling. Defaults to 1: as soon as one update is received, a new device is
                sampled and sent the current task.
            device_reuse (bool): Whether to allow reusing devices across different model versions.
                Defaults to True.

        Both integer arguments are validated with ``check_positive_int``, which is expected to
        raise if either is not a positive integer.
        """
        super().__init__()
        check_positive_int("device_selection_size", device_selection_size)
        check_positive_int("min_hole_to_fill", min_hole_to_fill)
        self.device_selection_size = device_selection_size
        self.min_hole_to_fill = min_hole_to_fill
        # NOTE(review): device_reuse is only stored here; no method of this class reads it, so its
        # effect presumably lives in the caller (e.g. whether it calls remove_devices_from_used).
        self.device_reuse = device_reuse
        # Bumped each time fill_selection finds at least one hole; stored in each used_devices
        # entry as "selection_version". available_devices / used_devices / current_selection are
        # presumably initialized by DeviceManager.__init__ -- confirm.
        self.current_selection_version = 0

    def _num_holes(self) -> int:
        """Return how many slots the current selection is short of ``device_selection_size``.

        The result may be zero or negative when the selection is already full.
        """
        return self.device_selection_size - len(self.current_selection)

    def _usable_devices(self) -> Set[str]:
        """Return the IDs of available devices that are not recorded in ``used_devices``."""
        return set(self.available_devices.keys()).difference(self.used_devices.keys())

    def update_available_devices(self, devices: Dict, fl_ctx) -> None:
        """Merge newly reported devices into the pool of available devices.

        Args:
            devices (Dict): Devices keyed by device ID, as reported from a child; entries with an
                existing ID overwrite the previous entry.
            fl_ctx: FL context, used for logging.
        """
        self.available_devices.update(devices)
        self.log_debug(
            fl_ctx,
            f"assessor got reported {len(devices)} available devices from child. "
            f"total num available devices: {len(self.available_devices)}",
        )

    def fill_selection(self, current_model_version: int, fl_ctx) -> None:
        """Top up the current selection with randomly chosen, not-yet-used devices.

        Each newly selected device is mapped to ``current_model_version`` in
        ``current_selection`` and recorded in ``used_devices`` together with the new selection
        version. If fewer usable devices exist than holes, every usable device is selected and a
        warning is logged.

        Args:
            current_model_version (int): Model version the newly selected devices will train on.
            fl_ctx: FL context, used for logging.
        """
        num_holes = self._num_holes()
        self.log_info(fl_ctx, f"filling {num_holes} holes in selection list")
        if num_holes > 0:
            self.current_selection_version += 1
            # Sort before sampling: set iteration order of string IDs depends on per-process hash
            # randomization, so a sorted pool makes the draw reproducible under random.seed().
            # random.sample draws without replacement in one pass instead of rebuilding a list
            # for every hole.
            pool = sorted(self._usable_devices())
            for device_id in random.sample(pool, min(num_holes, len(pool))):
                # current_selection keeps track of devices selected for a particular model version
                self.current_selection[device_id] = current_model_version
                self.used_devices[device_id] = {
                    "model_version": current_model_version,
                    "selection_version": self.current_selection_version,
                }
        self.log_info(
            fl_ctx,
            f"current selection with {len(self.current_selection)} items: "
            f"V{self.current_selection_version}; {dict(sorted(self.current_selection.items()))}",
        )
        if len(self.current_selection) < self.device_selection_size:
            self.log_warning(
                fl_ctx,
                f"current selection has only {len(self.current_selection)} devices, which is less than the expected "
                f"{self.device_selection_size} devices. Please check the configuration to make sure this is expected.",
            )

    def remove_devices_from_selection(self, devices: Set[str], fl_ctx) -> None:
        """Drop the given devices from the current selection, opening holes; unknown IDs are ignored."""
        for device_id in devices:
            self.current_selection.pop(device_id, None)

    def remove_devices_from_used(self, devices: Set[str], fl_ctx) -> None:
        """Forget that the given devices were used so they become selectable again if still available.

        Unknown IDs are ignored.
        """
        for device_id in devices:
            self.used_devices.pop(device_id, None)

    def has_enough_devices(self, fl_ctx) -> bool:
        """Return True if the usable (available and unused) devices can fill every current hole."""
        return len(self._usable_devices()) >= self._num_holes()

    def should_fill_selection(self, fl_ctx) -> bool:
        """Return True once at least ``min_hole_to_fill`` slots of the selection are empty."""
        return self._num_holes() >= self.min_hole_to_fill

    def get_active_model_versions(self, fl_ctx) -> Set[int]:
        """Return the distinct model versions assigned to currently selected devices."""
        return set(self.current_selection.values())