# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from collections import Counter, defaultdict
from typing import Dict, Set
from nvflare.edge.assessors.device_manager import DeviceManager
from nvflare.edge.mud import PropKey
from nvflare.fuel.utils.validation_utils import check_positive_int
class BuffDeviceManager(DeviceManager):
    """Device manager that keeps a fixed-size buffer of selected devices.

    The manager tracks which devices are currently selected, which have already been used, and
    which client reported each device. When enough slots ("holes") open up in the selection it
    refills them from the pool of available-but-unused devices, either spreading the picks across
    clients ("balanced") or picking uniformly at random ("random").

    ``available_devices`` and ``current_selection`` are read and written here but are not created
    in this class; they are expected to be provided by the ``DeviceManager`` base class.
    """

    def __init__(
        self,
        device_selection_size: int,
        initial_min_client_num: int = 1,
        min_hole_to_fill: int = 1,
        device_reuse: bool = True,
        device_sampling_strategy: str = "balanced",
    ):
        """Initialize the BuffDeviceManager.

        BuffDeviceManager is responsible for managing the selection of devices for model training.
        It maintains a list of available devices, tracks the current selection, and refills the
        selection as needed.

        The device_selection_size determines how many "concurrent" devices can be selected for the
        training session. The min_hole_to_fill determines how many empty slots should be created
        before refilling.

        - An empty slot is created when any device reports its update back.
        - To fill a slot, a new device is selected from the available device pool.

        The device_reuse flag indicates whether devices can be reused across different model
        versions; if False, new devices are always selected when filling holes. The flag is only
        stored here; this class does not read it itself.

        Args:
            device_selection_size (int): Number of devices to select for each model update round.
            initial_min_client_num (int): Minimum number of clients to have at the beginning.
                This can be useful for initial model dispatch.
            min_hole_to_fill (int): Minimum number of empty slots in device selection before
                refilling. Defaults to 1 - once received an update, immediately sample a new
                device and send the current task to it.
            device_reuse (bool): Whether to allow reusing devices across different model versions.
                Defaults to True.
            device_sampling_strategy (str): Strategy for sampling devices when filling selection.
                Defaults to "balanced".

                - "balanced": try to balance the usage of devices across clients.
                - "random": randomly select devices from the available pool.

        Raises:
            ValueError: If device_sampling_strategy is neither "balanced" nor "random".
                The three integer arguments are additionally validated by check_positive_int.
        """
        super().__init__()
        check_positive_int("device_selection_size", device_selection_size)
        check_positive_int("min_hole_to_fill", min_hole_to_fill)
        check_positive_int("initial_min_client_num", initial_min_client_num)
        if device_sampling_strategy not in ("balanced", "random"):
            raise ValueError(
                f"device_sampling_strategy must be 'balanced' or 'random', got '{device_sampling_strategy}'"
            )

        self.device_selection_size = device_selection_size
        self.initial_min_client_num = initial_min_client_num
        self.min_hole_to_fill = min_hole_to_fill
        self.device_reuse = device_reuse
        self.device_sampling_strategy = device_sampling_strategy

        # Incremented every time fill_selection finds at least one hole to fill.
        self.current_selection_version = 0
        # device_id -> {"model_version": ..., "selection_version": ...} for devices already picked.
        self.used_devices = {}
        # device_id -> name of the client that reported the device.
        self.device_client_map = {}

    def _balanced_device_sampling(self, usable_devices: Set[str], num_holes: int) -> Set[str]:
        """Sample devices while balancing the picks across clients.

        The holes are first split as evenly as possible over the clients that own usable devices
        (clients are visited in random order, so the remainder lands on random clients). Holes a
        client could not absorb are then handed out round-robin, one device at a time, to the
        clients that still have devices left. Devices without a known client are used only as a
        last resort, when the mapped devices cannot fill every hole, so that the selection is not
        left short while usable devices remain.

        Args:
            usable_devices: Set of device IDs that can be selected.
            num_holes: Number of devices to sample.

        Returns:
            Set of selected device IDs; at most ``min(num_holes, len(usable_devices))`` entries.
        """
        if not usable_devices or num_holes <= 0:
            return set()

        # Group the candidates by owning client; keep the ones with no known client aside.
        client_devices = defaultdict(list)
        unmapped_devices = []
        for device_id in usable_devices:
            if device_id in self.device_client_map:
                client_devices[self.device_client_map[device_id]].append(device_id)
            else:
                unmapped_devices.append(device_id)

        selected_devices = set()
        remaining_holes = num_holes

        if client_devices:
            client_names = list(client_devices)
            random.shuffle(client_names)
            # Shuffle each pool once so that taking from the end is a uniform draw without
            # replacement, avoiding repeated list scans/removals.
            for pool in client_devices.values():
                random.shuffle(pool)

            # First pass: even share per client, remainder spread over the first few clients.
            base_quota, extra_holes = divmod(num_holes, len(client_names))
            for i, client_name in enumerate(client_names):
                pool = client_devices[client_name]
                quota = base_quota + 1 if i < extra_holes else base_quota
                take = min(quota, len(pool))  # a client cannot give more than it has
                if take > 0:
                    selected_devices.update(pool[-take:])
                    del pool[-take:]
                    remaining_holes -= take

            # Second pass: round-robin over clients that still have devices, random start order.
            active_clients = [name for name in client_names if client_devices[name]]
            random.shuffle(active_clients)
            while remaining_holes > 0 and active_clients:
                next_round = []
                for client_name in active_clients:
                    if remaining_holes <= 0:
                        break
                    pool = client_devices[client_name]
                    selected_devices.add(pool.pop())
                    remaining_holes -= 1
                    if pool:
                        next_round.append(client_name)
                active_clients = next_round

        # Last resort (and the only path when no device has a client mapping): plain random pick.
        if remaining_holes > 0 and unmapped_devices:
            selected_devices.update(
                random.sample(unmapped_devices, min(remaining_holes, len(unmapped_devices)))
            )

        return selected_devices

    def update_available_devices(self, devices: Dict, fl_ctx) -> None:
        """Merge newly reported devices into the available pool and record their clients.

        Args:
            devices: Mapping of device ID to device object. Each object must offer ``to_dict()``;
                the value stored under ``PropKey.CLIENT_NAME`` is taken as the owning client.
            fl_ctx: Context passed through to the logging call.
        """
        self.available_devices.update(devices)
        self.log_debug(
            fl_ctx,
            f"assessor got reported {len(devices)} available devices from child. "
            f"total num available devices: {len(self.available_devices)}",
        )

        # Remember which client each device came from; devices without a client name are skipped.
        for device_id, device in devices.items():
            client_name = device.to_dict().get(PropKey.CLIENT_NAME)
            if client_name:
                self.device_client_map[device_id] = client_name

    def fill_selection(self, current_model_version: int, fl_ctx) -> None:
        """Fill the empty slots of the current selection with unused available devices.

        Every newly picked device is recorded in ``current_selection`` (device ID -> model version)
        and in ``used_devices``. A warning is logged if the selection is still smaller than
        ``device_selection_size`` afterwards.

        Args:
            current_model_version: Model version to associate with the newly selected devices.
            fl_ctx: Context passed through to the logging calls.

        Raises:
            ValueError: If ``device_sampling_strategy`` holds an unsupported value.
        """
        num_holes = self.device_selection_size - len(self.current_selection)
        self.log_info(fl_ctx, f"filling {num_holes} holes in selection list")

        if num_holes > 0:
            self.current_selection_version += 1
            # Only devices that have not been used yet are candidates.
            usable_devices = set(self.available_devices.keys()) - set(self.used_devices.keys())
            if usable_devices:
                if self.device_sampling_strategy == "balanced":
                    # try to balance the usage of devices across clients
                    selected_devices = self._balanced_device_sampling(usable_devices, num_holes)
                elif self.device_sampling_strategy == "random":
                    # random.sample needs a sequence, hence the one-off list conversion.
                    selected_devices = random.sample(
                        list(usable_devices), min(num_holes, len(usable_devices))
                    )
                else:
                    raise ValueError(f"Invalid device sampling strategy: {self.device_sampling_strategy}")

                for device_id in selected_devices:
                    # current_selection keeps track of devices selected for a particular model version
                    self.current_selection[device_id] = current_model_version
                    self.used_devices[device_id] = {
                        "model_version": current_model_version,
                        "selection_version": self.current_selection_version,
                    }

        # NOTE(review): the two log calls below are placed at function level (run on every call);
        # confirm against the upstream file, whose indentation was not available.
        self.log_info(
            fl_ctx,
            f"current selection with {len(self.current_selection)} items: "
            f"V{self.current_selection_version}; {dict(sorted(self.current_selection.items()))}",
        )
        if len(self.current_selection) < self.device_selection_size:
            self.log_warning(
                fl_ctx,
                f"current selection has only {len(self.current_selection)} devices, "
                f"which is less than the expected {self.device_selection_size} devices. "
                "Please check the configuration to make sure this is expected.",
            )

    def remove_devices_from_selection(self, devices: Set[str], fl_ctx) -> None:
        """Drop the given device IDs from the current selection; unknown IDs are ignored.

        Args:
            devices: Device IDs to remove.
            fl_ctx: Unused by this implementation.
        """
        for device_id in devices:
            self.current_selection.pop(device_id, None)

    def remove_devices_from_used(self, devices: Set[str], fl_ctx) -> None:
        """Drop the given device IDs from the used-device record; unknown IDs are ignored.

        Once removed, a device counts as usable again in later calls to fill_selection.

        Args:
            devices: Device IDs to remove.
            fl_ctx: Unused by this implementation.
        """
        for device_id in devices:
            self.used_devices.pop(device_id, None)

    def has_enough_devices_and_clients(self, fl_ctx) -> bool:
        """Tell whether every current hole could be filled and enough clients have been seen.

        Args:
            fl_ctx: Unused by this implementation.

        Returns:
            True if the number of available-but-unused devices covers all empty slots and the
            number of distinct clients in ``device_client_map`` is at least
            ``initial_min_client_num``; False otherwise.
        """
        num_holes = self.device_selection_size - len(self.current_selection)
        usable_devices = set(self.available_devices.keys()) - set(self.used_devices.keys())
        if len(usable_devices) < num_holes:
            return False

        # Further check if we have enough clients
        unique_clients = set(self.device_client_map.values())
        return len(unique_clients) >= self.initial_min_client_num

    def should_fill_selection(self, fl_ctx) -> bool:
        """Tell whether the number of empty slots has reached ``min_hole_to_fill``.

        Args:
            fl_ctx: Unused by this implementation.

        Returns:
            True if at least ``min_hole_to_fill`` slots of the selection are empty.
        """
        num_holes = self.device_selection_size - len(self.current_selection)
        return num_holes >= self.min_hole_to_fill

    def get_active_model_versions(self, fl_ctx) -> Set[int]:
        """Return the distinct model versions referenced by the current selection.

        Args:
            fl_ctx: Unused by this implementation.

        Returns:
            Set of the values stored in ``current_selection``.
        """
        return set(self.current_selection.values())