Source code for nvflare.app_common.workflows.statistics_controller

# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
from typing import Callable, Dict, List, Optional

from nvflare.apis.client import Client
from nvflare.apis.controller_spec import ClientTask, Task
from nvflare.apis.dxo import from_shareable
from nvflare.apis.event_type import EventType
from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import Controller
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal
from nvflare.app_common.abstract.statistics_spec import Bin, Histogram, StatisticConfig
from nvflare.app_common.abstract.statistics_writer import StatisticsWriter
from nvflare.app_common.app_constant import StatisticsConstants as StC
from nvflare.app_common.statistics.numeric_stats import get_global_stats
from nvflare.app_common.statistics.statisitcs_objects_decomposer import fobs_registration
from nvflare.fuel.utils import fobs


class StatisticsController(Controller):
    def __init__(
        self,
        statistic_configs: Dict[str, dict],
        writer_id: str,
        wait_time_after_min_received: int = 1,
        result_wait_timeout: int = 10,
        precision=4,
        min_clients: Optional[int] = None,
        enable_pre_run_task: bool = True,
    ):
        """Controller for Statistics.

        Args:
            statistic_configs: defines the input statistics to be computed and each statistic's
                configuration; see below for details.
            writer_id: ID of the StatisticsWriter. The StatisticsWriter will save the result to the
                output it is configured with.
            wait_time_after_min_received: number of seconds to wait after results from the specified
                minimum number of clients have been received.
            result_wait_timeout: number of seconds to wait for all results. Note this timer starts
                after the min_clients results have arrived, while we wait for the result-processing
                callbacks; this becomes important when the data size to be processed is large.
            precision: number of decimal digits kept in the reported statistics.
            min_clients: if specified, the minimum number of clients we have to wait for before processing.
            enable_pre_run_task: whether to run a pre-run handshake task before the statistics tasks.

        For statistic_configs, the key is one of the statistics' names: sum, count, mean, stddev,
        histogram; the value is the arguments needed. All statistics except histogram require no
        arguments.

        .. code-block:: text

            "statistic_configs": {
                "count": {},
                "mean": {},
                "sum": {},
                "stddev": {},
                "histogram": {
                    "*": {"bins": 20},
                    "Age": {"bins": 10, "range": [0, 120]}
                }
            },

        Histogram requires the following arguments:
            1) number of bins or buckets of the histogram
            2) the histogram range values [min, max]

        These arguments can differ for each feature. Here are a few examples:

        .. code-block:: text

            "histogram": {
                "*": {"bins": 20},
                "Age": {"bins": 10, "range": [0, 120]}
            }

        This configuration specifies that the feature 'Age' will have 10 bins and the range [0, 120).
        For all other features, the default ("*") configuration is used, with bins = 20. Since no
        histogram range is specified for them, the Statistics controller must dynamically estimate the
        histogram range for each such feature. The estimated global range (est. global min, est. global
        max) is then used as the histogram range.

        To dynamically estimate such a histogram range, we need each client to provide its local min
        and max values so that the global min and max values can be calculated. In order to protect
        data privacy and avoid data leakage, noise is added to the local min/max values before they are
        sent to the controller. Therefore the controller only gets 'estimated' values, and the global
        min/max are estimated, or more accurately, they are noised global min/max values.

        Here is another example:

        .. code-block:: text

            "histogram": {
                "density": {"bins": 10, "range": [0, 120]}
            }

        In this example, there is no default histogram configuration for other features. This will work
        correctly if there is only one feature called "density", but it will fail if there are other
        features in the dataset.

        In the following configuration:

        .. code-block:: text

            "statistic_configs": {
                "count": {},
                "mean": {},
                "stddev": {}
            }

        only count, mean and stddev statistics are specified, so the statistics controller will only
        set tasks to calculate these three statistics.
        """
        super().__init__()
        self.statistic_configs: Dict[str, dict] = statistic_configs
        self.writer_id = writer_id
        self.task_name = StC.FED_STATS_TASK
        self.client_statistics = {}
        self.global_statistics = {}
        self.client_features = {}
        self.result_wait_timeout = result_wait_timeout
        self.wait_time_after_min_received = wait_time_after_min_received
        self.precision = precision
        self.min_clients = min_clients
        self.result_cb_status = {}
        self.client_handshake_ok = {}
        self.enable_pre_run_task = enable_pre_run_task

        self.result_callback_fns: Dict[str, Callable] = {
            StC.STATS_1st_STATISTICS: self.results_cb,
            StC.STATS_2nd_STATISTICS: self.results_cb,
        }
        fobs_registration()
        self.fl_ctx = None
        self.abort_job_in_error = {
            ReturnCode.EXECUTION_EXCEPTION: True,
            ReturnCode.TASK_UNKNOWN: True,
            ReturnCode.EXECUTION_RESULT_ERROR: False,
            ReturnCode.TASK_DATA_FILTER_ERROR: True,
            ReturnCode.TASK_RESULT_FILTER_ERROR: True,
        }

    def start_controller(self, fl_ctx: FLContext):
        if self.statistic_configs is None or len(self.statistic_configs) == 0:
            self.system_panic(
                "At least one statistic_config must be configured for task StatisticsController", fl_ctx=fl_ctx
            )
        self.fl_ctx = fl_ctx
        clients = fl_ctx.get_engine().get_clients()
        if not self.min_clients:
            self.min_clients = len(clients)

    def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
        self.log_info(fl_ctx, f"{self.task_name} control flow started.")
        if abort_signal.triggered:
            return False

        if self.enable_pre_run_task:
            self.pre_run_task_flow(abort_signal, fl_ctx)

        self.statistics_task_flow(abort_signal, fl_ctx, StC.STATS_1st_STATISTICS)
        self.statistics_task_flow(abort_signal, fl_ctx, StC.STATS_2nd_STATISTICS)

        if not StatisticsController._wait_for_all_results(
            self.logger, self.result_wait_timeout, self.min_clients, self.client_statistics, 1.0, abort_signal
        ):
            self.log_info(fl_ctx, f"task {self.task_name} timed out while waiting for all results.")
            return False

        self.log_info(fl_ctx, "start post processing")
        self.post_fn(self.task_name, fl_ctx)

        self.log_info(fl_ctx, f"task {self.task_name} control flow end.")
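
    # Note on the flow above: statistics are gathered in two rounds, as defined by
    # StC.ordered_statistics. The 1st task collects the basic statistics (count, sum,
    # mean, and, when a histogram without explicit ranges is configured, noised local
    # min/max); the 2nd task computes variance/stddev and histograms against the global
    # values aggregated from the 1st round (see _prepare_inputs below).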

    def stop_controller(self, fl_ctx: FLContext):
        pass

    def process_result_of_unknown_task(
        self, client: Client, task_name: str, client_task_id: str, result: Shareable, fl_ctx: FLContext
    ):
        pass

    def _get_all_statistic_configs(self) -> List[StatisticConfig]:
        all_statistics = {
            StC.STATS_COUNT: StatisticConfig(StC.STATS_COUNT, {}),
            StC.STATS_FAILURE_COUNT: StatisticConfig(StC.STATS_FAILURE_COUNT, {}),
            StC.STATS_SUM: StatisticConfig(StC.STATS_SUM, {}),
            StC.STATS_MEAN: StatisticConfig(StC.STATS_MEAN, {}),
            StC.STATS_VAR: StatisticConfig(StC.STATS_VAR, {}),
            StC.STATS_STDDEV: StatisticConfig(StC.STATS_STDDEV, {}),
        }
        if StC.STATS_HISTOGRAM in self.statistic_configs:
            hist_config = self.statistic_configs[StC.STATS_HISTOGRAM]
            all_statistics[StC.STATS_MIN] = StatisticConfig(StC.STATS_MIN, hist_config)
            all_statistics[StC.STATS_MAX] = StatisticConfig(StC.STATS_MAX, hist_config)
            all_statistics[StC.STATS_HISTOGRAM] = StatisticConfig(StC.STATS_HISTOGRAM, hist_config)

        return [all_statistics[k] for k in all_statistics if k in self.statistic_configs]
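
    # For example, with the (hypothetical) configuration
    #     statistic_configs = {"count": {}, "histogram": {"*": {"bins": 20}}}
    # the candidate set above gains "min"/"max" entries carrying the histogram config,
    # but the final filter keeps only the names that actually appear in
    # statistic_configs, i.e. "count" and "histogram".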

    def pre_run_task_flow(self, abort_signal: Signal, fl_ctx: FLContext):
        client_name = fl_ctx.get_identity_name()
        self.log_info(fl_ctx, f"start pre_run task for client {client_name}")
        inputs = Shareable()
        target_statistics: List[StatisticConfig] = self._get_all_statistic_configs()
        inputs[StC.STATS_TARGET_STATISTICS] = fobs.dumps(target_statistics)
        results_cb_fn = self.results_pre_run_cb

        if abort_signal.triggered:
            return False

        task = Task(name=StC.FED_STATS_PRE_RUN, data=inputs, result_received_cb=results_cb_fn)
        self.broadcast_and_wait(
            task=task,
            targets=None,
            min_responses=self.min_clients,
            fl_ctx=fl_ctx,
            wait_time_after_min_received=self.wait_time_after_min_received,
            abort_signal=abort_signal,
        )
        self.log_info(fl_ctx, f"client {client_name} pre_run task flow end.")

    def statistics_task_flow(self, abort_signal: Signal, fl_ctx: FLContext, statistic_task: str):
        self.log_info(fl_ctx, f"start prepare inputs for task {statistic_task}")
        inputs = self._prepare_inputs(statistic_task)
        results_cb_fn = self._get_result_cb(statistic_task)

        self.log_info(fl_ctx, f"task: {self.task_name} statistics_flow for {statistic_task} started.")

        if abort_signal.triggered:
            return False

        task_props = {StC.STATISTICS_TASK_KEY: statistic_task}
        task = Task(name=self.task_name, data=inputs, result_received_cb=results_cb_fn, props=task_props)

        self.broadcast_and_wait(
            task=task,
            targets=None,
            min_responses=self.min_clients,
            fl_ctx=fl_ctx,
            wait_time_after_min_received=self.wait_time_after_min_received,
            abort_signal=abort_signal,
        )

        self.global_statistics = get_global_stats(self.global_statistics, self.client_statistics, statistic_task)

        self.log_info(fl_ctx, f"task {self.task_name} statistics_flow for {statistic_task} flow end.")

    def handle_client_errors(self, rc: str, client_task: ClientTask, fl_ctx: FLContext):
        client_name = client_task.client.name
        task_name = client_task.task.name
        abort = self.abort_job_in_error[rc]
        if abort:
            self.system_panic(
                f"Failed in client-site statistics_executor for {client_name} during task {task_name}. "
                f"statistics controller is exiting.",
                fl_ctx=fl_ctx,
            )
            self.log_info(fl_ctx, f"Execution failed for {client_name}")
        else:
            self.log_info(fl_ctx, f"Execution result was not received from {client_name}")

    def results_pre_run_cb(self, client_task: ClientTask, fl_ctx: FLContext):
        client_name = client_task.client.name
        task_name = client_task.task.name
        self.log_info(fl_ctx, f"Processing {task_name} pre_run from client {client_name}")
        result = client_task.result
        rc = result.get_return_code()
        if rc == ReturnCode.OK:
            self.log_info(fl_ctx, f"Received pre-run handshake result from client:{client_name} for task {task_name}")
            # record the handshake status per client instead of replacing the whole dict,
            # so statuses already recorded for other clients are preserved
            self.client_handshake_ok[client_name] = True
            fl_ctx.set_prop(StC.PRE_RUN_RESULT, {client_name: from_shareable(result)})
            self.fire_event(EventType.PRE_RUN_RESULT_AVAILABLE, fl_ctx)
        else:
            if rc in self.abort_job_in_error.keys():
                self.handle_client_errors(rc, client_task, fl_ctx)
            self.client_handshake_ok[client_name] = False

        # Cleanup task result
        client_task.result = None

    def results_cb(self, client_task: ClientTask, fl_ctx: FLContext):
        client_name = client_task.client.name
        task_name = client_task.task.name
        self.log_info(fl_ctx, f"Processing {task_name} result from client {client_name}")

        result = client_task.result
        rc = result.get_return_code()
        if rc == ReturnCode.OK:
            self.log_info(fl_ctx, f"Received result entries from client:{client_name} for task {task_name}")
            dxo = from_shareable(result)
            client_result = dxo.data

            statistics_task = client_result[StC.STATISTICS_TASK_KEY]
            self.log_info(fl_ctx, f"handle client {client_name} results for statistics task: {statistics_task}")
            statistics = fobs.loads(client_result[statistics_task])

            for statistic in statistics:
                if statistic not in self.client_statistics:
                    self.client_statistics[statistic] = {client_name: statistics[statistic]}
                else:
                    self.client_statistics[statistic].update({client_name: statistics[statistic]})

            ds_features = client_result.get(StC.STATS_FEATURES, None)
            if ds_features:
                self.client_features.update({client_name: fobs.loads(ds_features)})

            # record the callback status per branch; an unconditional trailing assignment
            # would overwrite the False status recorded for failed clients below
            self.result_cb_status[client_name] = {client_task.task.props[StC.STATISTICS_TASK_KEY]: True}
        elif rc in self.abort_job_in_error.keys():
            self.handle_client_errors(rc, client_task, fl_ctx)
            self.result_cb_status[client_name] = {client_task.task.props[StC.STATISTICS_TASK_KEY]: False}
        else:
            self.result_cb_status[client_name] = {client_task.task.props[StC.STATISTICS_TASK_KEY]: True}

        # Cleanup task result
        client_task.result = None
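
    # Illustrative shape of self.client_statistics accumulated by results_cb
    # (client/dataset/feature names and values are made up):
    #
    #     {
    #         "count": {
    #             "site-1": {"train": {"Age": 630}},
    #             "site-2": {"train": {"Age": 402}},
    #         },
    #         "mean": {
    #             "site-1": {"train": {"Age": 32.9}},
    #             "site-2": {"train": {"Age": 34.1}},
    #         },
    #     }
    #
    # i.e. keyed as statistic -> client -> dataset -> feature.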

    def _validate_min_clients(self, min_clients: int, client_statistics: dict) -> bool:
        self.logger.info("check if min_client result received for all features")

        resulting_clients = {}
        for statistic in client_statistics:
            clients = client_statistics[statistic].keys()
            if len(clients) < min_clients:
                return False

            for client in clients:
                ds_feature_statistics = client_statistics[statistic][client]
                for ds_name in ds_feature_statistics:
                    if ds_name not in resulting_clients:
                        resulting_clients[ds_name] = set()

                    if ds_feature_statistics[ds_name]:
                        resulting_clients[ds_name].update([client])

        for ds in resulting_clients:
            if len(resulting_clients[ds]) < min_clients:
                return False
        return True
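
    # For example, with min_clients = 2, a client_statistics dict such as
    #     {"count": {"site-1": {"train": {"Age": 630}}}}
    # fails the first check (only one client returned "count"), so post_fn below
    # panics instead of combining partial results.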

    def post_fn(self, task_name: str, fl_ctx: FLContext):
        ok_to_proceed = self._validate_min_clients(self.min_clients, self.client_statistics)
        if not ok_to_proceed:
            self.system_panic(
                f"results from the required {self.min_clients} clients were not all received, abort the job.", fl_ctx
            )
        else:
            self.log_info(fl_ctx, "combine all clients' statistics")
            ds_stats = self._combine_all_statistics()

            self.log_info(fl_ctx, "save statistics result to persistence store")
            writer: StatisticsWriter = fl_ctx.get_engine().get_component(self.writer_id)
            writer.save(ds_stats, overwrite_existing=True, fl_ctx=fl_ctx)

    def _combine_all_statistics(self):
        result = {}
        filtered_client_statistics = [
            statistic for statistic in self.client_statistics if statistic in self.statistic_configs
        ]
        filtered_global_statistics = [
            statistic for statistic in self.global_statistics if statistic in self.statistic_configs
        ]

        for statistic in filtered_client_statistics:
            for client in self.client_statistics[statistic]:
                for ds in self.client_statistics[statistic][client]:
                    for feature_name in self.client_statistics[statistic][client][ds]:
                        if feature_name not in result:
                            result[feature_name] = {}
                        if statistic not in result[feature_name]:
                            result[feature_name][statistic] = {}
                        if client not in result[feature_name][statistic]:
                            result[feature_name][statistic][client] = {}
                        if ds not in result[feature_name][statistic][client]:
                            result[feature_name][statistic][client][ds] = {}
                        if statistic == StC.STATS_HISTOGRAM:
                            hist: Histogram = self.client_statistics[statistic][client][ds][feature_name]
                            buckets = StatisticsController._apply_histogram_precision(hist.bins, self.precision)
                            result[feature_name][statistic][client][ds] = buckets
                        else:
                            result[feature_name][statistic][client][ds] = round(
                                self.client_statistics[statistic][client][ds][feature_name], self.precision
                            )

        precision = self.precision
        for statistic in filtered_global_statistics:
            for ds in self.global_statistics[statistic]:
                for feature_name in self.global_statistics[statistic][ds]:
                    if StC.GLOBAL not in result[feature_name][statistic]:
                        result[feature_name][statistic][StC.GLOBAL] = {}
                    if ds not in result[feature_name][statistic][StC.GLOBAL]:
                        result[feature_name][statistic][StC.GLOBAL][ds] = {}
                    if statistic == StC.STATS_HISTOGRAM:
                        hist: Histogram = self.global_statistics[statistic][ds][feature_name]
                        buckets = StatisticsController._apply_histogram_precision(hist.bins, self.precision)
                        result[feature_name][statistic][StC.GLOBAL][ds] = buckets
                    else:
                        result[feature_name][statistic][StC.GLOBAL].update(
                            {ds: round(self.global_statistics[statistic][ds][feature_name], precision)}
                        )

        return result

    @staticmethod
    def _apply_histogram_precision(bins: List[Bin], precision) -> List[Bin]:
        buckets = []
        for bucket in bins:
            buckets.append(
                Bin(
                    round(bucket.low_value, precision),
                    round(bucket.high_value, precision),
                    bucket.sample_count,
                )
            )
        return buckets

    @staticmethod
    def _get_target_statistics(statistic_configs: dict, ordered_statistics: list) -> List[StatisticConfig]:
        # make sure of the execution order of the statistics calculation
        targets = []
        if statistic_configs:
            for statistic in statistic_configs:
                # If the target statistics include histogram and we are not in the 2nd statistics task,
                # we only need to estimate the global min/max; we only need that estimate because of the
                # histogram statistic. If the user provided the global min/max for a specified feature,
                # then we do nothing; if the user did not provide the global min/max for the feature,
                # then we need to ask the clients to provide the locally estimated min/max for that
                # feature, and we then use the locally estimated min/max to estimate the global min/max.
                # To do that, we calculate the local min/max in the 1st statistics task.
                # In all cases, we still send the STATS_MIN/MAX tasks, but the client executor may or
                # may not delegate to the stats generator to calculate the local min/max, depending on
                # whether the global bin ranges are specified. To do this, we send over the histogram
                # configuration when calculating the local min/max.
                if statistic == StC.STATS_HISTOGRAM and statistic not in ordered_statistics:
                    targets.append(StatisticConfig(StC.STATS_MIN, statistic_configs[StC.STATS_HISTOGRAM]))
                    targets.append(StatisticConfig(StC.STATS_MAX, statistic_configs[StC.STATS_HISTOGRAM]))

                if statistic == StC.STATS_STDDEV and statistic in ordered_statistics:
                    targets.append(StatisticConfig(StC.STATS_VAR, {}))

                for rm in ordered_statistics:
                    if rm == statistic:
                        targets.append(StatisticConfig(statistic, statistic_configs[statistic]))
        return targets

    def _prepare_inputs(self, statistic_task: str) -> Shareable:
        inputs = Shareable()
        target_statistics: List[StatisticConfig] = StatisticsController._get_target_statistics(
            self.statistic_configs, StC.ordered_statistics[statistic_task]
        )
        for tm in target_statistics:
            if tm.name == StC.STATS_HISTOGRAM:
                if StC.STATS_MIN in self.global_statistics:
                    inputs[StC.STATS_MIN] = self.global_statistics[StC.STATS_MIN]
                if StC.STATS_MAX in self.global_statistics:
                    inputs[StC.STATS_MAX] = self.global_statistics[StC.STATS_MAX]

            if tm.name == StC.STATS_VAR:
                if StC.STATS_COUNT in self.global_statistics:
                    inputs[StC.STATS_GLOBAL_COUNT] = self.global_statistics[StC.STATS_COUNT]
                if StC.STATS_MEAN in self.global_statistics:
                    inputs[StC.STATS_GLOBAL_MEAN] = self.global_statistics[StC.STATS_MEAN]

        inputs[StC.STATISTICS_TASK_KEY] = statistic_task
        inputs[StC.STATS_TARGET_STATISTICS] = fobs.dumps(target_statistics)
        return inputs

    @staticmethod
    def _wait_for_all_results(
        logger,
        result_wait_timeout: float,
        requested_client_size: int,
        client_statistics: dict,
        sleep_time: float = 1,
        abort_signal=None,
    ) -> bool:
        """Waits for all results.

        For each statistic, we check if results from the number of requested clients (min_clients or
        all clients) are available; if not, we wait until result_wait_timeout. result_wait_timeout is
        reset for the next statistic: it is a per-statistic timeout, not an overall timeout for all
        results.

        Args:
            logger: logger used to report the waiting progress.
            result_wait_timeout: timeout to wait for each statistic; reset for each statistic.
            requested_client_size: requested number of clients, usually min_clients or all clients.
            client_statistics: client-specific statistics received so far.
            sleep_time: number of seconds to sleep between checks.
            abort_signal: abort signal.

        Returns: False when the job is aborted, else True.
        """
        # record, for each statistic, the number of clients already processed
        statistics_client_received = {}

        # current statistics obtained so far (across all clients)
        statistic_names = client_statistics.keys()
        for m in statistic_names:
            statistics_client_received[m] = len(client_statistics[m].keys())

        timeout = result_wait_timeout
        for m in statistics_client_received:
            if requested_client_size > statistics_client_received[m]:
                t = 0
                while t < timeout and requested_client_size > statistics_client_received[m]:
                    if abort_signal and abort_signal.triggered:
                        return False
                    msg = (
                        f"not all clients have returned the statistic '{m}'; waiting for {sleep_time} seconds. "
                        f"currently available clients are '{client_statistics[m].keys()}'."
                    )
                    logger.info(msg)
                    time.sleep(sleep_time)
                    t += sleep_time
                    # check and update the number of clients processed for this statistic again
                    statistics_client_received[m] = len(client_statistics[m].keys())

        return True

    def _get_result_cb(self, statistics_task: str):
        return self.result_callback_fns[statistics_task]
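

# A minimal, self-contained sketch (illustrative, not part of the controller API)
# showing how the two static helpers behave. The configuration values, site names,
# and numbers below are made-up assumptions.
if __name__ == "__main__":
    import logging

    logging.basicConfig(level=logging.INFO)
    demo_logger = logging.getLogger("statistics_controller_demo")

    # With a histogram configured but no explicit range, the 1st statistics task
    # also requests local min/max so the controller can estimate the global
    # histogram range (see _get_target_statistics above).
    demo_configs = {"count": {}, "mean": {}, "histogram": {"*": {"bins": 20}}}
    first_task_targets = StatisticsController._get_target_statistics(
        demo_configs, StC.ordered_statistics[StC.STATS_1st_STATISTICS]
    )
    print([t.name for t in first_task_targets])

    # _wait_for_all_results returns True as soon as every statistic has results from
    # the requested number of clients (two fake clients here, so no waiting occurs).
    fake_client_statistics = {"count": {"site-1": {"train": {"Age": 630}}, "site-2": {"train": {"Age": 402}}}}
    print(
        StatisticsController._wait_for_all_results(
            demo_logger,
            result_wait_timeout=5,
            requested_client_size=2,
            client_statistics=fake_client_statistics,
        )
    )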