Source code for nvflare.app_opt.feature_election.controller

# Copyright (c) 2026, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import os
from typing import Dict, Optional

import numpy as np

from nvflare.apis.client import Client
from nvflare.apis.controller_spec import ClientTask
from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import Controller, Task
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal

logger = logging.getLogger(__name__)


class FeatureElectionController(Controller):
    """
    Three-phase FL controller for federated feature selection and FedAvg training.

    Phase 1 — Local Feature Selection: each client runs its configured FS method and
    returns a feature mask and per-feature scores.

    Phase 2 — Tuning & Global Mask Distribution: the server optionally runs hill-climbing
    to find the optimal ``freedom_degree``, then aggregates client masks via weighted
    voting and distributes the global feature mask to all clients. If fewer than
    ``min_clients`` clients acknowledge the mask, the entire workflow is aborted.

    Phase 3 — FedAvg Training: standard federated averaging on the reduced feature set
    for ``num_rounds`` rounds.

    Args:
        freedom_degree: Threshold in [0, 1] controlling which features survive the vote.
            0 = intersection (all clients must select), 1 = union (any client suffices).
        aggregation_mode: ``'weighted'`` weights each client by sample count;
            ``'uniform'`` treats all clients equally. Applies to mask voting and to the
            tuning objective; Phase 3 FedAvg always weights by sample count.
        min_clients: Minimum number of clients that must respond in each phase.
        num_rounds: Number of FedAvg training rounds in Phase 3.
        task_name: Must match the ``task_name`` configured on ``FeatureElectionExecutor``.
        train_timeout: Per-phase timeout in seconds. It is used as the task timeout for
            every broadcast: the election, each tuning round, mask distribution and
            each training round.
        auto_tune: If ``True``, Phase 2 runs hill-climbing to optimise ``freedom_degree``.
            Has no effect when ``tuning_rounds=0`` (a warning is logged in that case).
        tuning_rounds: Number of hill-climbing iterations. Must be >= 2 for meaningful
            tuning. ``tuning_rounds=0`` disables tuning (with a warning if
            ``auto_tune=True``). ``tuning_rounds=1`` is also disabled, with its own
            warning explaining why.
        wait_time_after_min_received: Seconds to wait for additional client responses
            after ``min_clients`` have already replied. Set to ``0`` only for local
            simulation. A non-zero value (default 10 s) prevents slower clients from
            being silently excluded in heterogeneous production networks.

    Raises:
        ValueError: If ``aggregation_mode`` is not ``'weighted'`` or ``'uniform'``.
    """

    def __init__(
        self,
        freedom_degree: float = 0.5,
        aggregation_mode: str = "weighted",
        min_clients: int = 2,
        num_rounds: int = 5,
        task_name: str = "feature_election",
        train_timeout: int = 300,
        auto_tune: bool = False,
        tuning_rounds: int = 0,
        wait_time_after_min_received: int = 10,
    ):
        super().__init__()

        if aggregation_mode not in ("weighted", "uniform"):
            raise ValueError(
                f"aggregation_mode must be 'weighted' or 'uniform', got {aggregation_mode!r}. "
                "Check the 'aggregation_mode' field in your job configuration."
            )

        # Configuration
        self.freedom_degree = freedom_degree
        self.aggregation_mode = aggregation_mode
        self.custom_task_name = task_name
        self.min_clients = min_clients
        self.fl_rounds = num_rounds
        self.train_timeout = train_timeout
        self.wait_time_after_min_received = wait_time_after_min_received
        self.auto_tune = auto_tune
        # Tuning rounds only matter when auto_tune is on; normalise to 0 otherwise.
        self.tuning_rounds = tuning_rounds if auto_tune else 0

        if auto_tune and self.tuning_rounds == 0:
            logger.warning(
                "auto_tune=True has no effect when tuning_rounds=0 (the default). "
                "Set tuning_rounds >= 2 to enable hill-climbing optimisation of freedom_degree."
            )
        elif auto_tune and self.tuning_rounds == 1:
            logger.warning(
                "auto_tune requires tuning_rounds >= 2 to explore alternative freedom degrees "
                "(one baseline evaluation plus at least one neighbour to compare). "
                "Got tuning_rounds=1; auto-tuning will be disabled."
            )
            self.tuning_rounds = 0

        # State
        self.global_feature_mask = None
        self.global_weights = None
        self.cached_client_selections = {}
        self.phase_results = {}

        # Hill climbing for auto-tuning: (freedom_degree, score) pairs in evaluation order.
        self.tuning_history = []
        self.search_step = 0.1
        self.current_direction = 1
        self.n_features = None

    def advance_tuning(self, score: float, first_step: bool = False) -> None:
        """Record a tuning-round score and update freedom_degree for the next round.

        This is the public interface for the simulation path in
        :meth:`FeatureElection.simulate_election`, so the simulation does not need
        to mutate private controller state directly. The real FL path in
        ``control_flow`` uses the same internal helpers.

        Args:
            score: Weighted evaluation score for the current ``freedom_degree``.
            first_step: ``True`` only on the very first tuning round. It is passed
                through to ``_calculate_next_fd`` to seed the initial direction.
        """
        self.tuning_history.append((self.freedom_degree, score))
        self.freedom_degree = self._calculate_next_fd(first_step=first_step)

    def start_controller(self, fl_ctx: FLContext) -> None:
        """Log controller start-up; no other initialisation is performed here."""
        logger.info("Initializing FeatureElectionController (Base Controller Mode)")

    def stop_controller(self, fl_ctx: FLContext):
        """Write the final mask and freedom degree to ``feature_election_results.json`` in the run dir.

        Write failures (``OSError``) are logged and do not propagate.
        """
        workspace = fl_ctx.get_engine().get_workspace()
        run_dir = workspace.get_run_dir(fl_ctx.get_job_id())
        results = {
            "global_mask": self.global_feature_mask.tolist() if self.global_feature_mask is not None else None,
            "freedom_degree": float(self.freedom_degree),
            "num_features_selected": (
                int(np.sum(self.global_feature_mask)) if self.global_feature_mask is not None else 0
            ),
        }
        try:
            with open(os.path.join(run_dir, "feature_election_results.json"), "w") as f:
                json.dump(results, f, indent=2)
        except OSError as e:
            logger.error(f"Failed to write feature election results to {run_dir}: {e}")
        logger.info("Stopping Feature Election Controller")

    def process_result_of_unknown_task(
        self, client: Client, task_name: str, client_task_id: str, result: Shareable, fl_ctx: FLContext
    ):
        """
        Called when a result is received for an unknown task.

        This is a fallback; normally results arrive through the task's result callback.
        The result is only logged and otherwise ignored.
        """
        logger.warning(f"Received result for unknown task '{task_name}' from {client.name}")

    def control_flow(self, abort_signal: Signal, fl_ctx: FLContext) -> None:
        """Main orchestration loop: run the three phases in order, stopping early if one fails.

        Any exception is logged, the abort signal is triggered, and the exception
        is re-raised.
        """
        try:
            # --- PHASE 1: LOCAL FEATURE SELECTION (ELECTION) ---
            if not self._phase_one_election(abort_signal, fl_ctx):
                return

            # --- PHASE 2: TUNING & GLOBAL MASKING ---
            if not self._phase_two_tuning_and_masking(abort_signal, fl_ctx):
                return

            # --- PHASE 3: AGGREGATION ROUNDS (FL TRAINING) ---
            self._phase_three_aggregation(abort_signal, fl_ctx)

            logger.info("Feature Election Workflow Completed Successfully.")

        except Exception as e:
            logger.exception(f"Workflow failed: {e}")
            # Signal.trigger requires a value; calling it without one raises TypeError
            # inside this handler and masks the original exception.
            abort_signal.trigger(True)
            raise

    # ==============================================================================
    # PHASE IMPLEMENTATIONS
    # ==============================================================================

    def _result_received_cb(self, client_task: ClientTask, fl_ctx: FLContext):
        """Store a client's result in ``self.phase_results`` if it is present and OK.

        Missing results and non-OK return codes are logged and dropped.
        """
        client_name = client_task.client.name
        result = client_task.result

        if result is None:
            logger.warning(f"No result from client {client_name}")
            return

        rc = result.get_return_code()
        if rc != ReturnCode.OK:
            logger.warning(f"Client {client_name} returned error: {rc}")
            return

        self.phase_results[client_name] = result
        logger.debug(f"Received result from {client_name}")

    def _broadcast_and_gather(
        self, task_data: Shareable, abort_signal: Signal, fl_ctx: FLContext, timeout: Optional[int] = None
    ) -> Dict[str, Shareable]:
        """Broadcast one task to all clients and collect the successful results.

        Args:
            task_data: Payload sent to every client.
            abort_signal: Abort signal forwarded to ``broadcast_and_wait``.
            fl_ctx: FL context of the running job.
            timeout: Task timeout in seconds. ``None`` (the default) uses
                ``self.train_timeout``, so every phase honours the documented
                per-phase timeout. Pass an explicit value to override it.

        Returns:
            Mapping of client name to result ``Shareable``. Only results with
            ``ReturnCode.OK`` are included.
        """
        if timeout is None:
            timeout = self.train_timeout

        # Rebind (not clear) so a dict returned by a previous call stays intact.
        self.phase_results = {}

        task = Task(
            name=self.custom_task_name,
            data=task_data,
            timeout=timeout,
            result_received_cb=self._result_received_cb,
        )

        # wait_time_after_min_received > 0 gives slower clients a window to respond
        # after min_clients have already replied, preventing silent exclusion.
        self.broadcast_and_wait(
            task=task,
            min_responses=self.min_clients,
            wait_time_after_min_received=self.wait_time_after_min_received,
            fl_ctx=fl_ctx,
            abort_signal=abort_signal,
        )

        # Backup collection: pick up any OK results the callback did not record.
        for client_task in task.client_tasks:
            client_name = client_task.client.name
            if client_name not in self.phase_results and client_task.result is not None:
                rc = client_task.result.get_return_code()
                if rc == ReturnCode.OK:
                    self.phase_results[client_name] = client_task.result
                    logger.debug(f"Collected result from task.client_tasks: {client_name}")

        logger.info(f"Collected {len(self.phase_results)} results")
        return self.phase_results

    def _phase_one_election(self, abort_signal: Signal, fl_ctx: FLContext) -> bool:
        """Collect local feature-selection votes. Returns False if no usable vote arrived."""
        logger.info("=== PHASE 1: Local Feature Selection & Election ===")

        task_data = Shareable()
        task_data["request_type"] = "feature_selection"

        results = self._broadcast_and_gather(task_data, abort_signal, fl_ctx)

        if not results:
            logger.error("No feature votes received. Aborting.")
            return False

        self.cached_client_selections = self._extract_client_data(results)

        if not self.cached_client_selections:
            logger.error("Received responses, but failed to extract selection data. Aborting.")
            return False

        logger.info(f"Phase 1 Complete. Processed votes from {len(self.cached_client_selections)} clients.")
        return True

    def _phase_two_tuning_and_masking(self, abort_signal: Signal, fl_ctx: FLContext) -> bool:
        """Optionally tune freedom_degree, build the global mask and distribute it.

        Returns:
            False if an abort was signalled, or if fewer than ``min_clients``
            clients acknowledged the mask; True otherwise.
        """
        logger.info("=== PHASE 2: Tuning & Global Mask Generation ===")

        # 1. Run the tuning loop (if enabled)
        if self.auto_tune and self.tuning_rounds > 0:
            self._run_auto_tuning(abort_signal, fl_ctx)

        # 2. Generate the final mask (recorded even on abort so stop_controller can persist it)
        final_mask = self.aggregate_selections(self.cached_client_selections)
        self.global_feature_mask = final_mask

        n_sel = int(np.sum(final_mask))
        logger.info(
            f"Final Global Mask: {n_sel} features selected "
            f"(FD={self.freedom_degree:.4f}, aggregation_mode={self.aggregation_mode})"
        )

        if abort_signal.triggered:
            logger.warning("Abort signal received before global mask distribution; Phase 3 will not run.")
            return False

        # 3. Distribute the mask to clients
        task_data = Shareable()
        task_data["request_type"] = "apply_mask"
        task_data["global_feature_mask"] = final_mask.tolist()

        mask_results = self._broadcast_and_gather(task_data, abort_signal, fl_ctx)
        if len(mask_results) < self.min_clients:
            logger.warning(
                f"Global mask distribution incomplete: only {len(mask_results)}/{self.min_clients} "
                "clients acknowledged. The entire FL workflow is being aborted — "
                "Phase 3 aggregation will not run."
            )
            return False
        logger.info(f"Global mask distributed to {len(mask_results)} clients")
        return True

    def _run_auto_tuning(self, abort_signal: Signal, fl_ctx: FLContext) -> None:
        """Hill-climb freedom_degree over ``tuning_rounds`` evaluations, then keep the best one."""
        logger.info(f"Starting Auto-tuning ({self.tuning_rounds} rounds)...")

        for i in range(self.tuning_rounds):
            if abort_signal.triggered:
                logger.warning("Abort signal received during tuning")
                break

            # Evaluate the current freedom_degree
            mask = self.aggregate_selections(self.cached_client_selections)
            task_data = Shareable()
            task_data["request_type"] = "tuning_eval"
            task_data["tuning_mask"] = mask.tolist()

            results = self._broadcast_and_gather(task_data, abort_signal, fl_ctx)

            score = self._aggregate_tuning_score(results)
            if score is None:
                logger.warning(
                    f"Tuning round {i + 1}: no clients returned a valid score; "
                    "skipping history entry to avoid corrupting hill-climbing signal."
                )
                continue

            logger.info(f"Tuning Round {i + 1}/{self.tuning_rounds}: FD={self.freedom_degree:.4f} -> Score={score:.4f}")
            self.tuning_history.append((self.freedom_degree, score))

            # Early exit when the last 3 evaluated scores are indistinguishably
            # flat. Further hill-climbing cannot improve freedom_degree on a
            # plateau and would only waste FL communication rounds.
            if len(self.tuning_history) >= 3:
                recent = [s for _, s in self.tuning_history[-3:]]
                score_range = max(recent) - min(recent)
                if score_range < 1e-4:
                    logger.info(
                        f"Tuning early exit at round {i + 1}: score plateau detected "
                        f"(range {score_range:.2e} < 1e-4 over last 3 rounds). "
                        "Selecting best freedom_degree from evaluated history."
                    )
                    break

            # Calculate the next FD for the next iteration (if not the last round).
            # The first step is keyed on the history length, not the loop index: if
            # round 0 yielded no valid score, round 1 is the first recorded point, and
            # _calculate_next_fd would otherwise never move (it needs 2 points).
            if i < self.tuning_rounds - 1:
                self.freedom_degree = self._calculate_next_fd(first_step=(len(self.tuning_history) == 1))

        # Select the best FD from the evaluated options
        if self.tuning_history:
            best_fd, best_score = max(self.tuning_history, key=lambda x: x[1])
            self.freedom_degree = best_fd
            logger.info(f"Tuning Complete. Optimal Freedom Degree: {best_fd:.4f} (Score: {best_score:.4f})")
        else:
            logger.warning("No tuning results, keeping initial freedom_degree")

    def _aggregate_tuning_score(self, results: Dict[str, Shareable]) -> Optional[float]:
        """Average client ``tuning_score`` values, weighted per ``aggregation_mode``.

        The weighting matches mask aggregation, so the tuning objective is
        consistent with the configured ``aggregation_mode``.

        Returns:
            The weighted mean score, or None if no client returned a score
            (total weight of zero).
        """
        weighted_score, total_weight = 0.0, 0.0
        for result in results.values():
            if "tuning_score" not in result:
                continue
            n = result.get("num_samples", 1) if self.aggregation_mode == "weighted" else 1
            weighted_score += result["tuning_score"] * n
            total_weight += n
        if total_weight == 0.0:
            return None
        return weighted_score / total_weight

    def _phase_three_aggregation(self, abort_signal: Signal, fl_ctx: FLContext):
        """Run ``fl_rounds`` FedAvg rounds, proceeding with partial results when clients are missing."""
        logger.info(f"=== PHASE 3: Aggregation Rounds (FL Training - {self.fl_rounds} Rounds) ===")

        completed_rounds = 0
        for i in range(1, self.fl_rounds + 1):
            if abort_signal.triggered:
                logger.warning(
                    f"Abort signal received during FL training after {completed_rounds}/{self.fl_rounds} rounds"
                )
                break

            logger.info(f"--- FL Round {i}/{self.fl_rounds} ---")

            task_data = Shareable()
            task_data["request_type"] = "train"

            if self.global_weights is not None:
                task_data["params"] = {
                    k: v.tolist() if isinstance(v, np.ndarray) else v for k, v in self.global_weights.items()
                }

            results = self._broadcast_and_gather(task_data, abort_signal, fl_ctx, timeout=self.train_timeout)

            if len(results) < self.min_clients:
                logger.warning(
                    f"FL Round {i}: only {len(results)}/{self.min_clients} clients responded; "
                    "proceeding with partial results."
                )

            self._aggregate_weights(results)
            completed_rounds += 1

        if completed_rounds == self.fl_rounds:
            logger.info(f"FL Training phase complete ({completed_rounds}/{self.fl_rounds} rounds)")
        else:
            logger.warning(f"FL Training phase ended early: {completed_rounds}/{self.fl_rounds} rounds completed")

    # ==============================================================================
    # HELPER METHODS
    # ==============================================================================

    def _aggregate_weights(self, results: Dict[str, Shareable]):
        """FedAvg-style weight aggregation into ``self.global_weights``.

        Each client's ``params`` dict is weighted by its ``num_samples`` (default 1),
        regardless of ``aggregation_mode``. The first client with non-empty params
        defines the expected keys and shapes. A client is skipped entirely if its
        keys or shapes disagree, so a partial update never corrupts the average.
        If no client contributes, ``self.global_weights`` is left unchanged.
        """
        total_samples = 0
        n_contributors = 0
        weighted_weights: Dict[str, np.ndarray] = {}

        for shareable in results.values():
            weights = shareable.get("params")
            if not weights:
                # Missing, None or empty params contribute nothing. Counting an empty
                # dict's num_samples would inflate the denominator and bias the average.
                continue

            n = shareable.get("num_samples", 1)
            arrays = {k: np.array(v) for k, v in weights.items()}

            if not weighted_weights:
                # Float accumulator: an integer-typed template would make the in-place
                # ``+=`` below raise a casting error on float updates.
                weighted_weights = {k: np.zeros(a.shape, dtype=np.float64) for k, a in arrays.items()}

            # Validate all keys before accumulating. Check both directions:
            # missing client keys and unexpected client keys.
            missing_keys = [k for k in weighted_weights if k not in arrays]
            if missing_keys:
                logger.warning(f"Client weights are missing expected keys {missing_keys}; skipping client")
                continue

            for k, arr in arrays.items():
                if k not in weighted_weights:
                    logger.warning(f"Unexpected weight key '{k}' from client, skipping client")
                    break
                if weighted_weights[k].shape != arr.shape:
                    logger.error(
                        f"Weight shape mismatch for key '{k}': expected {weighted_weights[k].shape}, got {arr.shape}"
                    )
                    break
            else:
                # No validation failure (the loop did not break): accumulate this client.
                for k, arr in arrays.items():
                    weighted_weights[k] += arr * n
                total_samples += n
                n_contributors += 1

        if total_samples > 0:
            self.global_weights = {k: v / total_samples for k, v in weighted_weights.items()}
            logger.info(f"Aggregated weights from {n_contributors} clients ({total_samples} samples)")
        else:
            logger.warning("Weight aggregation skipped: no clients returned valid parameters; global weights unchanged")

    def _extract_client_data(self, results: Dict[str, Shareable]) -> Dict[str, Dict]:
        """Validate and extract feature-selection votes from Phase 1 client results.

        ``self.n_features`` is taken from the first client's 1-D mask when not
        already known. A client's vote is skipped, with a warning, when any of
        these hold:

        * ``selected_features`` is missing or not 1-D;
        * the mask length differs from ``n_features``;
        * the mask selects nothing;
        * ``feature_scores`` is missing or has a different shape from the mask.

        Returns:
            ``{client_name: {"selected_features", "feature_scores", "num_samples"}}``.
        """
        client_data = {}
        for client_name, contrib in results.items():
            if "selected_features" not in contrib:
                logger.warning(f"Client {client_name} returned no 'selected_features'; skipping its vote.")
                continue

            selected = np.array(contrib["selected_features"])
            if selected.ndim != 1:
                logger.warning(
                    f"Client {client_name} returned a feature mask with shape {selected.shape}; "
                    "expected a 1-D mask. Skipping its vote."
                )
                continue

            if self.n_features is None:
                self.n_features = len(selected)
                logger.debug(f"Inferred n_features={self.n_features} from {client_name}")
            elif len(selected) != self.n_features:
                # Ragged masks would break np.array(masks) in aggregate_selections.
                logger.warning(
                    f"Client {client_name} returned a mask of length {len(selected)}, "
                    f"expected {self.n_features}; skipping its vote."
                )
                continue

            # Reject all-zero masks: a client that selected no features would
            # silently bias the global mask toward the intersection of other clients'
            # masks without contributing any signal of its own. This mirrors the
            # ValueError raised in the simulation path (simulate_election).
            if not np.any(selected):
                logger.warning(
                    f"Client {client_name} returned an all-False feature mask; skipping this "
                    "client's vote to avoid corrupting the global mask. "
                    "Consider lowering the regularisation strength "
                    "(e.g. reduce 'alpha' for Lasso/ElasticNet)."
                )
                continue

            if "feature_scores" not in contrib:
                logger.warning(f"Client {client_name} returned no 'feature_scores'; skipping its vote.")
                continue
            scores = np.array(contrib["feature_scores"])
            if scores.shape != selected.shape:
                logger.warning(
                    f"Client {client_name} returned feature_scores with shape {scores.shape}, "
                    f"expected {selected.shape}; skipping its vote."
                )
                continue

            client_data[client_name] = {
                "selected_features": selected,
                "feature_scores": scores,
                "num_samples": contrib.get("num_samples", 1),
            }
            logger.debug(f"Extracted {int(np.sum(selected))} features from {client_name}")
        return client_data

    def aggregate_selections(self, client_selections: Dict[str, Dict]) -> np.ndarray:
        """
        Aggregate feature selections from all clients.

        Freedom degree controls the blend between intersection and union:
        - FD <= 0.01: Intersection (only features selected by ALL clients)
        - FD >= 0.99: Union (features selected by ANY client)
        - otherwise: intersection plus the top-scoring share of the remaining
          union features (see ``_weighted_election``)

        Raises:
            ValueError: If ``client_selections`` is empty and ``self.n_features``
                is unknown.
        """
        if not client_selections:
            logger.warning("No client selections to aggregate")
            n = self.n_features
            if n is None:
                logger.error("Cannot create empty mask: self.n_features is None")
                raise ValueError("Total number of features (n_features) must be known before aggregation")
            return np.zeros(n, dtype=bool)

        masks = np.array([s["selected_features"] for s in client_selections.values()])
        scores = np.array([s["feature_scores"] for s in client_selections.values()])
        sample_counts = [s["num_samples"] for s in client_selections.values()]

        # Normalise sample counts to weights; fall back to uniform if they sum to <= 0.
        total = sum(sample_counts)
        weights = np.array(sample_counts) / total if total > 0 else np.ones(len(sample_counts)) / len(sample_counts)

        intersection = np.all(masks, axis=0)
        union = np.any(masks, axis=0)

        if self.freedom_degree <= 0.01:
            return intersection
        if self.freedom_degree >= 0.99:
            return union

        return self._weighted_election(masks, scores, weights, intersection, union)

    def _weighted_election(
        self, masks: np.ndarray, scores: np.ndarray, weights: np.ndarray, intersection: np.ndarray, union: np.ndarray
    ) -> np.ndarray:
        """
        Perform weighted voting for features in the difference set (union minus intersection).

        Each client's scores are min-max normalised over its selected features
        (unselected features get 0). They are then summed with per-client weights
        chosen by ``aggregation_mode``. The top ``ceil(|diff| * freedom_degree)``
        diff features by aggregated score are added to the intersection.
        """
        diff_mask = union & ~intersection
        if not np.any(diff_mask):
            return intersection

        agg_scores = np.zeros(len(intersection))

        if self.aggregation_mode == "uniform":
            # Equal weight for all clients
            effective_weights = np.ones(len(weights)) / len(weights)
        else:
            # "weighted" mode: sample-size-based weights
            effective_weights = weights

        for i, (m, s) in enumerate(zip(masks, scores)):
            valid = m.astype(bool)
            if not np.any(valid):
                logger.warning(f"Client {i} has no selected features, skipping")
                continue
            min_s, max_s = np.min(s[valid]), np.max(s[valid])
            if max_s > min_s:
                norm_s = np.where(valid, (s - min_s) / (max_s - min_s), 0.0)
            else:
                # All selected scores equal: give every selected feature a neutral 0.5.
                norm_s = np.where(valid, 0.5, 0.0)
            agg_scores += norm_s * effective_weights[i]

        n_add = int(np.ceil(np.sum(diff_mask) * self.freedom_degree))
        if n_add <= 0:
            return intersection

        diff_indices = np.where(diff_mask)[0]
        diff_scores = agg_scores[diff_indices]
        # Stable sort so ties resolve deterministically by feature index.
        top_indices = diff_indices[np.argsort(diff_scores, kind="stable")[-n_add:]]
        selected_diff = np.zeros_like(diff_mask)
        selected_diff[top_indices] = True
        return intersection | selected_diff

    def _calculate_next_fd(self, first_step: bool) -> float:
        """Hill-climbing step to find the optimal freedom degree, clipped to [0.05, 1.0].

        On the first step the direction is chosen to avoid an immediate boundary
        clip. Afterwards the last two history entries are compared: an improvement
        keeps the direction, otherwise the direction flips and the step halves
        (minimum 1e-3). Returns the current FD unchanged if fewer than two points
        have been recorded.
        """
        MIN_FD, MAX_FD = 0.05, 1.0

        if first_step:
            # Choose the initial direction so the first step stays within [MIN_FD, MAX_FD]
            # without clipping, which would waste a tuning round on a no-op move.
            # Check both boundaries symmetrically: prefer the direction with more headroom.
            near_max = self.freedom_degree + self.search_step > MAX_FD
            near_min = self.freedom_degree - self.search_step < MIN_FD
            if near_max and not near_min:
                self.current_direction = -1
            elif near_min and not near_max:
                self.current_direction = 1
            # If near both (search_step is very large relative to the range), keep
            # current_direction and let np.clip handle the boundary; the step will
            # land at the nearer bound, which is the best available move.
            return float(np.clip(self.freedom_degree + (self.current_direction * self.search_step), MIN_FD, MAX_FD))

        if len(self.tuning_history) < 2:
            return self.freedom_degree

        curr_fd, curr_score = self.tuning_history[-1]
        prev_fd, prev_score = self.tuning_history[-2]

        if curr_score > prev_score:
            new_fd = curr_fd + (self.current_direction * self.search_step)
        else:
            self.current_direction *= -1
            self.search_step = max(self.search_step * 0.5, 1e-3)
            # Step from curr_fd (not prev_fd) so the explorer backtracks from its
            # current position rather than skipping the region between the last two
            # evaluated points.
            new_fd = curr_fd + (self.current_direction * self.search_step)

        return float(np.clip(new_fd, MIN_FD, MAX_FD))