Source code for nvflare.app_common.executors.statistics.statistics_task_handler

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional

from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal
from nvflare.app_common.abstract.statistics_spec import Feature, Histogram, HistogramType, StatisticConfig, Statistics
from nvflare.app_common.abstract.task_handler import TaskHandler
from nvflare.app_common.app_constant import StatisticsConstants as StC
from nvflare.app_common.statistics.numeric_stats import filter_numeric_features
from nvflare.app_common.statistics.statisitcs_objects_decomposer import fobs_registration
from nvflare.app_common.statistics.statistics_config_utils import get_feature_bin_range
from nvflare.fuel.utils import fobs
from nvflare.security.logging import secure_format_exception


class StatisticsTaskHandler(TaskHandler):
    """
    StatisticsTaskHandler is to be used together with StatisticsExecutor.

    StatisticsExecutor is a client-side executor that performs local statistics generation and communicates
    with the FL Server global statistics controller. The actual local statistics calculation is delegated to
    the Statistics spec implementor.
    """

    def __init__(self, generator_id: str, precision: int = 4):
        """
        Args:
            generator_id: id of the local ``Statistics`` component; passed to ``TaskHandler`` together with the
                expected ``Statistics`` type.
            precision: number of decimal places used when rounding float results (sum, mean, stddev, variance).
        """
        super().__init__(generator_id, Statistics)
        self.stats_generator: Optional[Statistics] = None
        self.precision = precision
        # Presumably registers FOBS decomposers for the statistics objects so that fobs.dumps/loads below work.
        fobs_registration()

    def initialize(self, fl_ctx: FLContext):
        """Run base-class initialization, then keep the resolved local component as the statistics generator."""
        super().initialize(fl_ctx)
        self.stats_generator = self.local_comp

    def execute_task(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
        """Execute a federated statistics task for this client.

        Args:
            task_name: ``StC.FED_STATS_PRE_RUN`` for the initial handshake, or ``StC.FED_STATS_TASK`` to compute
                local statistics.
            shareable: task payload. ``StC.STATS_TARGET_STATISTICS`` holds a FOBS-serialized list of
                ``StatisticConfig``; for ``FED_STATS_TASK``, ``StC.STATISTICS_TASK_KEY`` names the current round.
            fl_ctx: FL context.
            abort_signal: not used by this handler.

        Returns:
            For the pre-run task, whatever the generator's ``pre_run`` returns. For the statistics task, a
            Shareable holding the round name under ``StC.STATISTICS_TASK_KEY``, the serialized results under the
            round name, and (for the first round only) the serialized numeric features under ``StC.STATS_FEATURES``.

        Raises:
            RuntimeError: if ``task_name`` is not a known statistics task.
        """
        client_name = fl_ctx.get_identity_name()
        self.log_info(fl_ctx, f"Executing task '{task_name}' for client: '{client_name}'")

        if task_name == StC.FED_STATS_PRE_RUN:
            # initial handshake
            target_statistics: List[StatisticConfig] = fobs.loads(shareable.get(StC.STATS_TARGET_STATISTICS))
            return self.pre_run(target_statistics)

        if task_name != StC.FED_STATS_TASK:
            raise RuntimeError(ReturnCode.TASK_UNKNOWN)

        ds_features = self.get_numeric_features()
        statistics_task = shareable.get(StC.STATISTICS_TASK_KEY)
        target_statistics: List[StatisticConfig] = fobs.loads(shareable.get(StC.STATS_TARGET_STATISTICS))

        # Compare by name: testing the bare string against a list of StatisticConfig objects does not match
        # (unless StatisticConfig happened to compare equal to str), which appended a duplicate config every time.
        requested_names = {tm.name for tm in target_statistics}
        if StC.STATS_FAILURE_COUNT not in requested_names:
            target_statistics.append(StatisticConfig(StC.STATS_FAILURE_COUNT, {}))

        statistics_result = {}
        functions = self.statistic_functions()
        for tm in target_statistics:
            fn = functions[tm.name]
            statistics_result[tm.name] = {}
            self._populate_result_statistics(statistics_result, ds_features, tm, shareable, fl_ctx, fn)

        # always add count for data privacy needs
        if StC.STATS_COUNT not in statistics_result:
            tm = StatisticConfig(StC.STATS_COUNT, {})
            statistics_result[tm.name] = {}
            self._populate_result_statistics(statistics_result, ds_features, tm, shareable, fl_ctx, self.get_count)

        result = Shareable()
        result[StC.STATISTICS_TASK_KEY] = statistics_task
        if statistics_task == StC.STATS_1st_STATISTICS:
            result[StC.STATS_FEATURES] = fobs.dumps(ds_features)
        result[statistics_task] = fobs.dumps(statistics_result)
        return result

    def statistic_functions(self) -> dict:
        """Map each supported statistic name to the bound method that computes it."""
        return {
            StC.STATS_COUNT: self.get_count,
            StC.STATS_FAILURE_COUNT: self.get_failure_count,
            StC.STATS_SUM: self.get_sum,
            StC.STATS_MEAN: self.get_mean,
            StC.STATS_STDDEV: self.get_stddev,
            StC.STATS_VAR: self.get_variance_with_mean,
            StC.STATS_HISTOGRAM: self.get_histogram,
            StC.STATS_MAX: self.get_max_value,
            StC.STATS_MIN: self.get_min_value,
        }

    def _populate_result_statistics(self, statistics_result, ds_features, tm: StatisticConfig, shareable, fl_ctx, fn):
        """Compute statistic ``tm`` with ``fn`` for every feature of every dataset.

        Results are stored as ``statistics_result[tm.name][ds_name][feature_name]``. A failure for one feature is
        logged and that feature's entry is left out; the remaining features are still computed.
        """
        for ds_name in ds_features:
            statistics_result[tm.name][ds_name] = {}
            features: List[Feature] = ds_features[ds_name]
            for feature in features:
                try:
                    statistics_result[tm.name][ds_name][feature.feature_name] = fn(
                        ds_name, feature.feature_name, tm, shareable, fl_ctx
                    )
                except Exception as e:
                    self.log_exception(
                        fl_ctx,
                        f"Failed to populate result statistics of dataset {ds_name}"
                        f" and feature {feature.feature_name} with exception: {secure_format_exception(e)}",
                    )

    def get_numeric_features(self) -> Dict[str, List[Feature]]:
        """Return the generator's features per dataset, filtered down to numeric features."""
        ds_features: Dict[str, List[Feature]] = self.stats_generator.features()
        return filter_numeric_features(ds_features)

    def pre_run(self, target_statistics: List[StatisticConfig]):
        """Forward the requested statistic names (and histogram settings, if any) to the generator's ``pre_run``.

        When a histogram is requested, the per-feature number of bins and configured bin ranges are read from
        the histogram config; otherwise both are passed as None.
        """
        feature_num_of_bins = None
        feature_bin_ranges = None
        target_statistic_keys = []
        for mc in target_statistics:
            target_statistic_keys.append(mc.name)
            if mc.name == StC.STATS_HISTOGRAM:
                hist_config = mc.config
                feature_num_of_bins = {}
                feature_bin_ranges = {}
                for feature_name in hist_config:
                    num_of_bins: int = self.get_number_of_bins(feature_name, hist_config)
                    feature_num_of_bins[feature_name] = num_of_bins
                    bin_range = get_feature_bin_range(feature_name, hist_config)
                    feature_bin_ranges[feature_name] = bin_range

        return self.stats_generator.pre_run(target_statistic_keys, feature_num_of_bins, feature_bin_ranges)

    def get_count(
        self,
        dataset_name: str,
        feature_name: str,
        statistic_configs: StatisticConfig,
        inputs: Shareable,
        fl_ctx: FLContext,
    ) -> int:
        """Return the generator's local count for the feature."""
        result = self.stats_generator.count(dataset_name, feature_name)
        return result

    def get_failure_count(
        self,
        dataset_name: str,
        feature_name: str,
        statistic_configs: StatisticConfig,
        inputs: Shareable,
        fl_ctx: FLContext,
    ) -> int:
        """Return the generator's local failure count for the feature."""
        result = self.stats_generator.failure_count(dataset_name, feature_name)
        return result

    def get_sum(
        self,
        dataset_name: str,
        feature_name: str,
        statistic_configs: StatisticConfig,
        inputs: Shareable,
        fl_ctx: FLContext,
    ) -> float:
        """Return the local sum, rounded to ``self.precision`` decimal places."""
        result = round(self.stats_generator.sum(dataset_name, feature_name), self.precision)
        return result

    def get_mean(
        self,
        dataset_name: str,
        feature_name: str,
        statistic_configs: StatisticConfig,
        inputs: Shareable,
        fl_ctx: FLContext,
    ) -> float:
        """Return the local mean, rounded to ``self.precision`` decimal places.

        Computed as sum / count when both are available; otherwise the generator's ``mean`` is used.
        """
        count = self.stats_generator.count(dataset_name, feature_name)
        sum_value = self.stats_generator.sum(dataset_name, feature_name)
        if count is not None and sum_value is not None:
            return round(sum_value / count, self.precision)
        # user did not implement count and/or sum, call mean directly.
        return round(self.stats_generator.mean(dataset_name, feature_name), self.precision)

    def get_stddev(
        self,
        dataset_name: str,
        feature_name: str,
        statistic_configs: StatisticConfig,
        inputs: Shareable,
        fl_ctx: FLContext,
    ) -> float:
        """Return the local standard deviation, rounded to ``self.precision`` decimal places."""
        result = round(self.stats_generator.stddev(dataset_name, feature_name), self.precision)
        return result

    def get_variance_with_mean(
        self,
        dataset_name: str,
        feature_name: str,
        statistic_configs: StatisticConfig,
        inputs: Shareable,
        fl_ctx: FLContext,
    ) -> float:
        """Return the local variance computed against the global mean and count supplied in ``inputs``.

        Returns None when the global mean or count is not available for this dataset/feature.
        """
        if StC.STATS_GLOBAL_MEAN not in inputs or StC.STATS_GLOBAL_COUNT not in inputs:
            return None

        global_mean = self._get_global_value_from_input(StC.STATS_GLOBAL_MEAN, dataset_name, feature_name, inputs)
        global_count = self._get_global_value_from_input(StC.STATS_GLOBAL_COUNT, dataset_name, feature_name, inputs)
        if global_mean is None or global_count is None:
            # Return early: rounding is only meaningful once the generator has produced a value.
            return None

        result = self.stats_generator.variance_with_mean(dataset_name, feature_name, global_mean, global_count)
        return round(result, self.precision)

    def get_histogram(
        self,
        dataset_name: str,
        feature_name: str,
        statistic_configs: StatisticConfig,
        inputs: Shareable,
        fl_ctx: FLContext,
    ) -> Histogram:
        """Return the local histogram, using the global min/max from ``inputs`` to define the bin range.

        A configured bin range for the feature (see ``get_bin_range``) overrides the global min/max. When the
        global min or max is unavailable, an empty standard histogram is returned.
        """
        if StC.STATS_MIN not in inputs or StC.STATS_MAX not in inputs:
            return Histogram(HistogramType.STANDARD, list())

        global_min_value = self._get_global_value_from_input(StC.STATS_MIN, dataset_name, feature_name, inputs)
        global_max_value = self._get_global_value_from_input(StC.STATS_MAX, dataset_name, feature_name, inputs)
        if global_min_value is None or global_max_value is None:
            return Histogram(HistogramType.STANDARD, list())

        hist_config: dict = statistic_configs.config
        num_of_bins: int = self.get_number_of_bins(feature_name, hist_config)
        bin_range: List[float] = self.get_bin_range(feature_name, global_min_value, global_max_value, hist_config)
        return self.stats_generator.histogram(dataset_name, feature_name, num_of_bins, bin_range[0], bin_range[1])

    def get_max_value(
        self,
        dataset_name: str,
        feature_name: str,
        statistic_configs: StatisticConfig,
        inputs: Shareable,
        fl_ctx: FLContext,
    ) -> float:
        """Return the max value to report for the feature.

        If the statistic's config defines a bin range for the feature, its upper bound is returned instead of
        the client's actual max (presumably so the true local max is not shared); otherwise the generator's
        ``max_value`` is returned.
        """
        hist_config: dict = statistic_configs.config
        feature_bin_range = get_feature_bin_range(feature_name, hist_config)
        if feature_bin_range is None:
            return self.stats_generator.max_value(dataset_name, feature_name)
        return feature_bin_range[1]

    def get_min_value(
        self,
        dataset_name: str,
        feature_name: str,
        statistic_configs: StatisticConfig,
        inputs: Shareable,
        fl_ctx: FLContext,
    ) -> float:
        """Return the min value to report for the feature.

        If the statistic's config defines a bin range for the feature, its lower bound is returned instead of
        the client's actual min (presumably so the true local min is not shared); otherwise the generator's
        ``min_value`` is returned.
        """
        hist_config: dict = statistic_configs.config
        feature_bin_range = get_feature_bin_range(feature_name, hist_config)
        if feature_bin_range is None:
            return self.stats_generator.min_value(dataset_name, feature_name)
        return feature_bin_range[0]

    def get_number_of_bins(self, feature_name: str, hist_config: dict) -> int:
        """Return the configured number of histogram bins for ``feature_name``.

        A feature-specific ``hist_config[feature_name][StC.STATS_BINS]`` takes precedence over the wildcard
        ``hist_config["*"][StC.STATS_BINS]``.

        Raises:
            Exception: if no bins value is configured for the feature, or the configured value is falsy.
        """
        err_msg = (
            f"feature name = '{feature_name}': "
            f"missing required '{StC.STATS_BINS}' config in histogram config = {hist_config}"
        )
        num_of_bins = None
        try:
            if feature_name in hist_config:
                num_of_bins = hist_config[feature_name][StC.STATS_BINS]
            elif "*" in hist_config:
                num_of_bins = hist_config["*"][StC.STATS_BINS]
        except KeyError as e:
            # keep the original KeyError as the cause for easier debugging
            raise Exception(err_msg) from e

        if not num_of_bins:
            raise Exception(err_msg)
        return num_of_bins

    def get_bin_range(
        self, feature_name: str, global_min_value: float, global_max_value: float, hist_config: dict
    ) -> List[float]:
        """Return the configured bin range for the feature, or ``[global_min_value, global_max_value]`` if none."""
        bin_range = get_feature_bin_range(feature_name, hist_config)
        if bin_range is None:
            bin_range = [global_min_value, global_max_value]
        return bin_range

    def _get_global_value_from_input(self, statistic_key: str, dataset_name: str, feature_name: str, inputs):
        """Look up ``inputs[statistic_key][dataset_name][feature_name]``.

        Falls back to the ``"*"`` entry of the same statistic and dataset when the feature has no entry. Returns
        None when the dataset is absent, or when neither the feature nor the wildcard entry is present.
        """
        per_dataset = inputs[statistic_key]
        if dataset_name not in per_dataset:
            return None

        per_feature = per_dataset[dataset_name]
        if feature_name in per_feature:
            return per_feature[feature_name]
        # Wildcard fallback: check and read "*" under the SAME statistic key. The previous code tested
        # inputs[StC.STATS_MIN] and then indexed the missing feature name, which raised KeyError.
        if "*" in per_feature:
            return per_feature["*"]
        return None