# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os
from typing import Dict
import numpy as np
from nvflare.apis.client import Client
from nvflare.apis.controller_spec import ClientTask
from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import Controller, Task
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal
logger = logging.getLogger(__name__)
[docs]
class FeatureElectionController(Controller):
    """
    Three-phase FL controller for federated feature selection and FedAvg training.

    Phase 1 — Local Feature Selection: each client runs its configured FS method and
    returns a feature mask and per-feature scores.

    Phase 2 — Tuning & Global Mask Distribution: the server optionally runs hill-climbing
    to find the optimal ``freedom_degree``, then aggregates client masks via weighted voting
    and distributes the global feature mask to all clients. If fewer than ``min_clients``
    clients acknowledge the mask, the entire workflow is aborted.

    Phase 3 — FedAvg Training: standard federated averaging on the reduced feature set
    for ``num_rounds`` rounds.

    Args:
        freedom_degree: Threshold in [0, 1] controlling which features survive the vote.
            0 = intersection (all clients must select), 1 = union (any client suffices).
            Values <= 0.01 are treated as pure intersection and values >= 0.99 as
            pure union.
        aggregation_mode: ``'weighted'`` weights each client by sample count;
            ``'uniform'`` treats all clients equally. Any other value raises
            ``ValueError``.
        min_clients: Minimum number of client responses requested from every broadcast
            (passed as ``min_responses``). Phase 2 aborts the workflow when fewer clients
            acknowledge the global mask; Phase 3 only logs a warning and aggregates the
            partial results.
        num_rounds: Number of FedAvg training rounds in Phase 3.
        task_name: Must match the ``task_name`` configured on ``FeatureElectionExecutor``.
        train_timeout: Timeout in seconds passed to each Phase 3 training task.
            Phase 1 and Phase 2 tasks are created with ``timeout=0``.
        auto_tune: If ``True``, Phase 2 runs hill-climbing to optimise ``freedom_degree``.
            Has no effect when ``tuning_rounds=0`` (a warning is logged in that case).
        tuning_rounds: Number of hill-climbing iterations. Must be >= 2 for meaningful
            tuning; ``tuning_rounds=0`` disables tuning (with a warning if ``auto_tune=True``);
            ``tuning_rounds=1`` is also disabled (same warning).
        wait_time_after_min_received: Seconds to wait for additional client responses after
            ``min_clients`` have already replied. Set to ``0`` only for local simulation;
            a non-zero value (default 10 s) prevents slower clients from being silently
            excluded in heterogeneous production networks.
    """
def __init__(
    self,
    freedom_degree: float = 0.5,
    aggregation_mode: str = "weighted",
    min_clients: int = 2,
    num_rounds: int = 5,
    task_name: str = "feature_election",
    train_timeout: int = 300,
    auto_tune: bool = False,
    tuning_rounds: int = 0,
    wait_time_after_min_received: int = 10,
):
    """Validate the configuration and initialise all controller state.

    Raises:
        ValueError: If ``aggregation_mode`` is neither ``'weighted'`` nor ``'uniform'``.
    """
    super().__init__()

    allowed_modes = ("weighted", "uniform")
    if aggregation_mode not in allowed_modes:
        raise ValueError(
            f"aggregation_mode must be 'weighted' or 'uniform', got {aggregation_mode!r}. "
            "Check the 'aggregation_mode' field in your job configuration."
        )

    # --- Configuration -------------------------------------------------------
    self.freedom_degree = freedom_degree
    self.aggregation_mode = aggregation_mode
    self.custom_task_name = task_name
    self.min_clients = min_clients
    self.fl_rounds = num_rounds
    self.train_timeout = train_timeout
    self.wait_time_after_min_received = wait_time_after_min_received
    self.auto_tune = auto_tune

    # Tuning rounds only count when auto-tuning was requested; 0 and 1 both
    # end up disabling tuning, each with its own explanatory warning.
    self.tuning_rounds = tuning_rounds if auto_tune else 0
    if auto_tune:
        if self.tuning_rounds == 0:
            logger.warning(
                "auto_tune=True has no effect when tuning_rounds=0 (the default). "
                "Set tuning_rounds >= 2 to enable hill-climbing optimisation of freedom_degree."
            )
        elif self.tuning_rounds == 1:
            logger.warning(
                "auto_tune requires tuning_rounds >= 2 to explore alternative freedom degrees "
                "(one baseline evaluation plus at least one neighbour to compare). "
                "Got tuning_rounds=1; auto-tuning will be disabled."
            )
            self.tuning_rounds = 0

    # --- Workflow state ------------------------------------------------------
    self.global_feature_mask = None
    self.global_weights = None
    self.cached_client_selections = {}
    self.phase_results = {}
    self.n_features = None

    # --- Hill-climbing state used by auto-tuning -----------------------------
    self.tuning_history = []
    self.search_step = 0.1
    self.current_direction = 1
[docs]
def advance_tuning(self, score: float, first_step: bool = False) -> None:
    """Store the score of the current ``freedom_degree`` and move to the next one.

    Public entry point intended for the simulation path
    (:meth:`FeatureElection.simulate_election`), so that callers do not have to
    touch the controller's tuning state themselves.

    Args:
        score: Weighted evaluation score obtained with the current ``freedom_degree``.
        first_step: ``True`` only for the very first tuning round; forwarded to
            ``_calculate_next_fd`` so it can seed the initial search direction.
    """
    evaluated_point = (self.freedom_degree, score)
    self.tuning_history.append(evaluated_point)

    next_fd = self._calculate_next_fd(first_step=first_step)
    self.freedom_degree = next_fd
[docs]
def start_controller(self, fl_ctx: FLContext) -> None:
    """Log that the controller is starting; no other initialisation is performed here."""
    startup_message = "Initializing FeatureElectionController (Base Controller Mode)"
    logger.info(startup_message)
[docs]
def stop_controller(self, fl_ctx: FLContext):
    """Write the election outcome to ``feature_election_results.json`` in the run directory.

    The file holds the global mask (or ``None`` if none was produced), the final
    ``freedom_degree`` and the number of selected features. A failure to write
    the file is logged and does not propagate.
    """
    engine = fl_ctx.get_engine()
    run_dir = engine.get_workspace().get_run_dir(fl_ctx.get_job_id())

    mask = self.global_feature_mask
    if mask is None:
        mask_as_list = None
    else:
        mask_as_list = mask.tolist()
    fd_value = float(self.freedom_degree)
    if mask is None:
        selected_count = 0
    else:
        selected_count = int(np.sum(mask))

    results = {
        "global_mask": mask_as_list,
        "freedom_degree": fd_value,
        "num_features_selected": selected_count,
    }

    try:
        with open(os.path.join(run_dir, "feature_election_results.json"), "w") as f:
            json.dump(results, f, indent=2)
    except OSError as e:
        # Best effort: losing the summary file must not break controller shutdown.
        logger.error(f"Failed to write feature election results to {run_dir}: {e}")

    logger.info("Stopping Feature Election Controller")
[docs]
def process_result_of_unknown_task(
    self, client: Client, task_name: str, client_task_id: str, result: Shareable, fl_ctx: FLContext
):
    """Handle a result that arrives for a task this controller does not recognise.

    This is a fallback only: results for tasks created by this controller are
    collected through ``_result_received_cb``. The unexpected result is logged
    and otherwise ignored.
    """
    message = f"Received result for unknown task '{task_name}' from {client.name}"
    logger.warning(message)
[docs]
def control_flow(self, abort_signal: Signal, fl_ctx: FLContext) -> None:
    """Run the three workflow phases in order.

    Phase 3 only runs when Phase 1 and Phase 2 both report success; a failed
    phase ends the workflow silently (the phase itself logs the reason).
    Any exception is logged, triggers ``abort_signal`` and is re-raised.
    """
    try:
        # Short-circuit: Phase 2 is never started when Phase 1 fails.
        selection_ready = self._phase_one_election(
            abort_signal, fl_ctx
        ) and self._phase_two_tuning_and_masking(abort_signal, fl_ctx)
        if not selection_ready:
            return

        # --- PHASE 3: AGGREGATION ROUNDS (FL TRAINING) ---
        self._phase_three_aggregation(abort_signal, fl_ctx)
        logger.info("Feature Election Workflow Completed Successfully.")
    except Exception as e:
        logger.exception(f"Workflow failed: {e}")
        abort_signal.trigger()
        raise
# ==============================================================================
# PHASE IMPLEMENTATIONS
# ==============================================================================
def _result_received_cb(self, client_task: ClientTask, fl_ctx: FLContext):
    """Store one client's result in ``self.phase_results`` if it is usable.

    Registered as ``result_received_cb`` on every task this controller creates.
    Missing results and results whose return code is not ``ReturnCode.OK`` are
    logged and dropped.
    """
    client_name = client_task.client.name
    result = client_task.result

    if result is None:
        logger.warning(f"No result from client {client_name}")
        return

    rc = result.get_return_code()
    if rc == ReturnCode.OK:
        # Keyed by client name, so a repeated reply overwrites the earlier one.
        self.phase_results[client_name] = result
        logger.debug(f"Received result from {client_name}")
    else:
        logger.warning(f"Client {client_name} returned error: {rc}")
def _broadcast_and_gather(
    self, task_data: Shareable, abort_signal: Signal, fl_ctx: FLContext, timeout: int = 0
) -> Dict[str, Shareable]:
    """Broadcast one task and return the successful results keyed by client name.

    Args:
        task_data: Payload sent to the clients.
        abort_signal: Forwarded to ``broadcast_and_wait``.
        fl_ctx: Forwarded to ``broadcast_and_wait``.
        timeout: Forwarded to the ``Task`` constructor (defaults to ``0``).

    Returns:
        ``self.phase_results``: a fresh dict for this call holding every result
        whose return code is ``ReturnCode.OK``.
    """
    # Start from an empty buffer so results of an earlier phase cannot leak in.
    self.phase_results = {}

    task = Task(
        name=self.custom_task_name,
        data=task_data,
        timeout=timeout,
        result_received_cb=self._result_received_cb,
    )

    # wait_time_after_min_received > 0 gives slower clients a window to respond
    # after min_clients have already replied, preventing silent exclusion.
    self.broadcast_and_wait(
        task=task,
        min_responses=self.min_clients,
        wait_time_after_min_received=self.wait_time_after_min_received,
        fl_ctx=fl_ctx,
        abort_signal=abort_signal,
    )

    # Backup path: pick up OK results attached to the task that the callback
    # did not already record.
    for client_task in task.client_tasks:
        client_name = client_task.client.name
        if client_name in self.phase_results or client_task.result is None:
            continue
        if client_task.result.get_return_code() != ReturnCode.OK:
            continue
        self.phase_results[client_name] = client_task.result
        logger.debug(f"Collected result from task.client_tasks: {client_name}")

    logger.info(f"Collected {len(self.phase_results)} results")
    return self.phase_results
def _phase_one_election(self, abort_signal: Signal, fl_ctx: FLContext) -> bool:
    """Phase 1: request local feature selections and cache the usable votes.

    Returns:
        ``True`` when at least one client vote could be extracted into
        ``self.cached_client_selections``; ``False`` otherwise.
    """
    logger.info("=== PHASE 1: Local Feature Selection & Election ===")

    request = Shareable()
    request["request_type"] = "feature_selection"
    responses = self._broadcast_and_gather(request, abort_signal, fl_ctx)

    if responses:
        self.cached_client_selections = self._extract_client_data(responses)
        if self.cached_client_selections:
            logger.info(f"Phase 1 Complete. Processed votes from {len(self.cached_client_selections)} clients.")
            return True
        logger.error("Received responses, but failed to extract selection data. Aborting.")
    else:
        logger.error("No feature votes received. Aborting.")
    return False
def _phase_two_tuning_and_masking(self, abort_signal: Signal, fl_ctx: FLContext) -> bool:
    """Phase 2: optionally tune ``freedom_degree``, then build and distribute the global mask.

    When auto-tuning is enabled, each tuning round aggregates a candidate mask
    with the current ``freedom_degree``, asks the clients to evaluate it
    (``tuning_eval``) and records the combined score in ``self.tuning_history``.
    The best-scoring ``freedom_degree`` is kept. The final mask is stored in
    ``self.global_feature_mask`` and sent to the clients (``apply_mask``).

    Returns:
        ``True`` if at least ``min_clients`` clients acknowledged the mask;
        ``False`` if the workflow was aborted or too few clients acknowledged.
    """
    logger.info("=== PHASE 2: Tuning & Global Mask Generation ===")

    # 1. Run Tuning Loop (if enabled)
    if self.auto_tune and self.tuning_rounds > 0:
        logger.info(f"Starting Auto-tuning ({self.tuning_rounds} rounds)...")
        for i in range(self.tuning_rounds):
            if abort_signal.triggered:
                logger.warning("Abort signal received during tuning")
                break

            # Evaluate current freedom_degree
            mask = self.aggregate_selections(self.cached_client_selections)
            task_data = Shareable()
            task_data["request_type"] = "tuning_eval"
            task_data["tuning_mask"] = mask.tolist()
            results = self._broadcast_and_gather(task_data, abort_signal, fl_ctx)

            # Aggregate scores using the same weighting as mask aggregation so
            # the tuning objective is consistent with the actual aggregation_mode.
            weighted_score, total_weight = 0.0, 0.0
            for v in results.values():
                if "tuning_score" not in v:
                    continue
                n = v.get("num_samples", 1) if self.aggregation_mode == "weighted" else 1
                weighted_score += v["tuning_score"] * n
                total_weight += n

            if total_weight == 0.0:
                logger.warning(
                    f"Tuning round {i + 1}: no clients returned a valid score; "
                    "skipping history entry to avoid corrupting hill-climbing signal."
                )
                continue

            score = weighted_score / total_weight
            logger.info(
                f"Tuning Round {i + 1}/{self.tuning_rounds}: FD={self.freedom_degree:.4f} -> Score={score:.4f}"
            )
            self.tuning_history.append((self.freedom_degree, score))

            # Early exit when the last 3 evaluated scores are indistinguishably
            # flat — further hill-climbing cannot improve freedom_degree on a
            # plateau and would only waste FL communication rounds.
            if len(self.tuning_history) >= 3:
                recent = [s for _, s in self.tuning_history[-3:]]
                score_range = max(recent) - min(recent)
                if score_range < 1e-4:
                    logger.info(
                        f"Tuning early exit at round {i + 1}: score plateau detected "
                        f"(range {score_range:.2e} < 1e-4 over last 3 rounds). "
                        "Selecting best freedom_degree from evaluated history."
                    )
                    break

            # Calculate next FD for next iteration (if not last round).
            # The "first step" is the first round that actually produced a score,
            # not loop index 0: if round 0 was skipped above, keying on the index
            # would never seed the search direction and the following rounds
            # would re-evaluate the same freedom_degree.
            if i < self.tuning_rounds - 1:
                is_first_scored_round = len(self.tuning_history) == 1
                self.freedom_degree = self._calculate_next_fd(first_step=is_first_scored_round)

        # Select best FD from evaluated options
        if self.tuning_history:
            best_fd, best_score = max(self.tuning_history, key=lambda x: x[1])
            self.freedom_degree = best_fd
            logger.info(f"Tuning Complete. Optimal Freedom Degree: {best_fd:.4f} (Score: {best_score:.4f})")
        else:
            logger.warning("No tuning results, keeping initial freedom_degree")

    # 2. Generate Final Mask
    final_mask = self.aggregate_selections(self.cached_client_selections)
    self.global_feature_mask = final_mask
    n_sel = np.sum(final_mask)
    logger.info(
        f"Final Global Mask: {n_sel} features selected "
        f"(FD={self.freedom_degree:.4f}, aggregation_mode={self.aggregation_mode})"
    )

    # The mask is kept on the controller (so it can still be saved on shutdown),
    # but an aborted run must not start another broadcast whose empty outcome
    # would then be reported as an incomplete distribution.
    if abort_signal.triggered:
        logger.warning("Abort signal received before global mask distribution; Phase 3 aggregation will not run.")
        return False

    # 3. Distribute mask to clients
    task_data = Shareable()
    task_data["request_type"] = "apply_mask"
    task_data["global_feature_mask"] = final_mask.tolist()
    mask_results = self._broadcast_and_gather(task_data, abort_signal, fl_ctx)

    if len(mask_results) < self.min_clients:
        logger.warning(
            f"Global mask distribution incomplete: only {len(mask_results)}/{self.min_clients} "
            "clients acknowledged. The entire FL workflow is being aborted — "
            "Phase 3 aggregation will not run."
        )
        return False

    logger.info(f"Global mask distributed to {len(mask_results)} clients")
    return True
def _phase_three_aggregation(self, abort_signal: Signal, fl_ctx: FLContext):
    """Phase 3: run ``self.fl_rounds`` rounds of FedAvg training.

    Each round broadcasts a ``train`` request (including the current global
    weights once they exist), then averages whatever the clients returned.
    A round with fewer than ``min_clients`` responses is still aggregated,
    with a warning. The loop stops early when ``abort_signal`` is triggered.
    """
    logger.info(f"=== PHASE 3: Aggregation Rounds (FL Training - {self.fl_rounds} Rounds) ===")

    completed_rounds = 0
    for i in range(1, self.fl_rounds + 1):
        if abort_signal.triggered:
            logger.warning(
                f"Abort signal received during FL training after {completed_rounds}/{self.fl_rounds} rounds"
            )
            break

        logger.info(f"--- FL Round {i}/{self.fl_rounds} ---")

        task_data = Shareable()
        task_data["request_type"] = "train"
        if self.global_weights is not None:
            # numpy arrays are converted to plain lists before being sent.
            outgoing_params = {}
            for k, v in self.global_weights.items():
                outgoing_params[k] = v.tolist() if isinstance(v, np.ndarray) else v
            task_data["params"] = outgoing_params

        results = self._broadcast_and_gather(task_data, abort_signal, fl_ctx, timeout=self.train_timeout)
        if len(results) < self.min_clients:
            logger.warning(
                f"FL Round {i}: only {len(results)}/{self.min_clients} clients responded; "
                "proceeding with partial results."
            )

        # Aggregate Weights (FedAvg)
        self._aggregate_weights(results)
        completed_rounds += 1

    if completed_rounds != self.fl_rounds:
        logger.warning(f"FL Training phase ended early: {completed_rounds}/{self.fl_rounds} rounds completed")
    else:
        logger.info(f"FL Training phase complete ({completed_rounds}/{self.fl_rounds} rounds)")
# ==============================================================================
# HELPER METHODS
# ==============================================================================
def _aggregate_weights(self, results: Dict[str, Shareable]):
    """FedAvg-style weight aggregation.

    Computes the ``num_samples``-weighted mean of the clients' ``params`` and
    stores it in ``self.global_weights``. The first client with non-empty
    ``params`` defines the expected keys and shapes; any later client whose
    keys or shapes differ is skipped entirely so a partial update cannot
    corrupt the average. ``self.global_weights`` is left unchanged when no
    client contributes.

    Args:
        results: Client results keyed by client name.
    """
    total_samples = 0
    n_contributors = 0
    weighted_weights = {}

    for shareable in results.values():
        if "params" not in shareable:
            continue
        weights = shareable.get("params")
        if not weights:
            # None or an empty mapping carries no parameters; counting its
            # num_samples would inflate the denominator of the average.
            continue
        n = shareable.get("num_samples", 1)

        # Convert once, always to float64: clients may send plain lists of
        # ints, and integer accumulators cannot absorb float updates in place.
        arrays = {k: np.array(v, dtype=np.float64) for k, v in weights.items()}

        if not weighted_weights:
            # Initialize accumulators from the first valid weights
            weighted_weights = {k: np.zeros(a.shape, dtype=np.float64) for k, a in arrays.items()}

        # Validate all keys before accumulating — a partial update would corrupt FedAvg.
        # Check both directions: unexpected client keys and missing client keys.
        missing_keys = [k for k in weighted_weights if k not in arrays]
        if missing_keys:
            logger.warning(f"Client weights are missing expected keys {missing_keys}; skipping client")
            continue

        unexpected_keys = [k for k in arrays if k not in weighted_weights]
        if unexpected_keys:
            logger.warning(f"Unexpected weight key '{unexpected_keys[0]}' from client, skipping client")
            continue

        mismatched = [k for k, a in arrays.items() if a.shape != weighted_weights[k].shape]
        if mismatched:
            k = mismatched[0]
            logger.error(
                f"Weight shape mismatch for key '{k}': expected {weighted_weights[k].shape}, got {arrays[k].shape}"
            )
            continue

        for k, a in arrays.items():
            weighted_weights[k] += a * n
        total_samples += n
        n_contributors += 1

    if total_samples > 0:
        self.global_weights = {k: v / total_samples for k, v in weighted_weights.items()}
        logger.info(
            f"Aggregated weights from {n_contributors}/{len(results)} clients ({total_samples} samples)"
        )
    else:
        logger.warning("Weight aggregation skipped: no clients returned valid parameters; global weights unchanged")
def _extract_client_data(self, results: Dict[str, Shareable]) -> Dict[str, Dict]:
    """Extract feature selection data from client results.

    A client's vote is kept only when it carries both ``selected_features`` and
    ``feature_scores`` as one-dimensional sequences of the same length, that
    length matches ``self.n_features`` (inferred from the first well-formed
    response when not yet known), and at least one feature is selected.
    Every other response is skipped with a warning, so one malformed client
    cannot crash the later aggregation.

    Args:
        results: Client results keyed by client name.

    Returns:
        Mapping of client name to a dict with ``selected_features`` and
        ``feature_scores`` (numpy arrays) and ``num_samples`` (default 1).
    """
    client_data = {}
    for key, contrib in results.items():
        if "selected_features" not in contrib:
            continue

        if "feature_scores" not in contrib:
            logger.warning(
                f"Client {key} returned 'selected_features' without 'feature_scores'; "
                "skipping this client's vote."
            )
            continue

        selected = np.array(contrib["selected_features"])
        feature_scores = np.array(contrib["feature_scores"])

        # Mask and scores must be flat and aligned element by element.
        if selected.ndim != 1 or feature_scores.shape != selected.shape:
            logger.warning(
                f"Client {key} returned malformed selection data "
                f"(mask shape {selected.shape}, scores shape {feature_scores.shape}); "
                "skipping this client's vote."
            )
            continue

        # Get n_features from first well-formed client response
        if self.n_features is None:
            self.n_features = len(selected)
            logger.debug(f"Inferred n_features={self.n_features} from {key}")

        # Masks of different lengths cannot be stacked for voting.
        if len(selected) != self.n_features:
            logger.warning(
                f"Client {key} returned {len(selected)} features but {self.n_features} were expected; "
                "skipping this client's vote."
            )
            continue

        # Reject all-zero masks: a client that selected no features would
        # silently bias the global mask toward the intersection of other clients'
        # masks without contributing any signal of its own. This mirrors the
        # ValueError raised in the simulation path (simulate_election).
        if not np.any(selected):
            logger.warning(
                f"Client {key} returned an all-False feature mask; skipping this "
                "client's vote to avoid corrupting the global mask. "
                "Consider lowering the regularisation strength "
                "(e.g. reduce 'alpha' for Lasso/ElasticNet)."
            )
            continue

        client_data[key] = {
            "selected_features": selected,
            "feature_scores": feature_scores,
            "num_samples": contrib.get("num_samples", 1),
        }
        logger.debug(f"Extracted {np.sum(selected)} features from {key}")
    return client_data
[docs]
def aggregate_selections(self, client_selections: Dict[str, Dict]) -> np.ndarray:
    """Combine the clients' feature masks into one global mask.

    ``self.freedom_degree`` decides how permissive the result is:
      - FD <= 0.01: intersection (only features selected by ALL clients)
      - FD >= 0.99: union (features selected by ANY client)
      - otherwise: the intersection plus the best-scoring contested features,
        chosen by ``_weighted_election``

    Args:
        client_selections: Per-client dicts with ``selected_features``,
            ``feature_scores`` and ``num_samples``.

    Returns:
        Boolean mask over the features.

    Raises:
        ValueError: If there are no selections and ``self.n_features`` is unknown.
    """
    if not client_selections:
        logger.warning("No client selections to aggregate")
        if self.n_features is None:
            logger.error("Cannot create empty mask: self.n_features is None")
            raise ValueError("Total number of features (n_features) must be known before aggregation")
        return np.zeros(self.n_features, dtype=bool)

    votes = list(client_selections.values())
    mask_rows = [vote["selected_features"] for vote in votes]
    score_rows = [vote["feature_scores"] for vote in votes]
    sample_counts = [vote["num_samples"] for vote in votes]

    masks = np.array(mask_rows)
    scores = np.array(score_rows)

    # Normalise sample counts to weights; fall back to equal weights when the
    # counts do not add up to a positive total.
    total = sum(sample_counts)
    if total > 0:
        weights = np.array(sample_counts) / total
    else:
        weights = np.ones(len(sample_counts)) / len(sample_counts)

    intersection = np.all(masks, axis=0)
    union = np.any(masks, axis=0)

    # Near-extreme freedom degrees skip the election altogether.
    if self.freedom_degree <= 0.01:
        return intersection
    elif self.freedom_degree >= 0.99:
        return union
    return self._weighted_election(masks, scores, weights, intersection, union)
def _weighted_election(
self, masks: np.ndarray, scores: np.ndarray, weights: np.ndarray, intersection: np.ndarray, union: np.ndarray
) -> np.ndarray:
"""
Perform weighted voting for features in the difference set.
Uses aggregation_mode to determine weighting strategy.
"""
diff_mask = union & ~intersection
if not np.any(diff_mask):
return intersection
# Compute aggregated scores based on aggregation_mode
agg_scores = np.zeros(len(intersection))
# Determine weights based on aggregation mode
if self.aggregation_mode == "uniform":
# Equal weight for all clients
effective_weights = np.ones(len(weights)) / len(weights)
else: # "weighted" mode (default)
# Use sample-size-based weights
effective_weights = weights
for i, (m, s) in enumerate(zip(masks, scores)):
valid = m.astype(bool)
if not np.any(valid):
logger.warning(f"Client {i} has no selected features, skipping")
continue
min_s, max_s = np.min(s[valid]), np.max(s[valid])
if max_s > min_s:
norm_s = np.where(valid, (s - min_s) / (max_s - min_s), 0.0)
else:
norm_s = np.where(valid, 0.5, 0.0)
agg_scores += norm_s * effective_weights[i]
# Select top features from (Union - Intersection) based on freedom_degree
n_add = int(np.ceil(np.sum(diff_mask) * self.freedom_degree))
if n_add > 0:
diff_indices = np.where(diff_mask)[0]
diff_scores = agg_scores[diff_indices]
top_indices = diff_indices[np.argsort(diff_scores, kind="stable")[-n_add:]]
selected_diff = np.zeros_like(diff_mask)
selected_diff[top_indices] = True
return intersection | selected_diff
# No features to add
else:
return intersection
def _calculate_next_fd(self, first_step: bool) -> float:
"""Hill-climbing to find optimal freedom degree"""
MIN_FD, MAX_FD = 0.05, 1.0
if first_step:
# Choose the initial direction so the first step stays within [MIN_FD, MAX_FD]
# without clipping, which would waste a tuning round on a no-op move.
# Check both boundaries symmetrically: prefer the direction with more headroom.
near_max = self.freedom_degree + self.search_step > MAX_FD
near_min = self.freedom_degree - self.search_step < MIN_FD
if near_max and not near_min:
self.current_direction = -1
elif near_min and not near_max:
self.current_direction = 1
# If near both (search_step is very large relative to the range), keep
# current_direction and let np.clip handle the boundary; the step will
# land at the nearer bound, which is the best available move.
return np.clip(self.freedom_degree + (self.current_direction * self.search_step), MIN_FD, MAX_FD)
if len(self.tuning_history) < 2:
return self.freedom_degree
curr_fd, curr_score = self.tuning_history[-1]
prev_fd, prev_score = self.tuning_history[-2]
if curr_score > prev_score:
new_fd = curr_fd + (self.current_direction * self.search_step)
else:
self.current_direction *= -1
self.search_step = max(self.search_step * 0.5, 1e-3)
# Step from curr_fd (not prev_fd) so the explorer backtracks from its
# current position rather than skipping the region between the last two
# evaluated points.
new_fd = curr_fd + (self.current_direction * self.search_step)
return np.clip(new_fd, MIN_FD, MAX_FD)