# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from math import sqrt
from typing import Dict, List, TypeVar
from nvflare.app_common.abstract.statistics_spec import Bin, BinRange, DataType, Feature, Histogram, HistogramType
from nvflare.app_common.app_constant import StatisticsConstants as StC
# Generic type variable. NOTE(review): not referenced by any function in this file; confirm before removing.
T = TypeVar("T")
[docs]
def get_initial_structure(client_metrics: dict, ordered_metrics: dict) -> dict:
    """Calculate initial output structure that is common at all the hierarchical levels.

    Every (dataset, feature) pair reported by any site for any statistic is
    collected, and each target metric gets its own zero-initialized copy of
    that ``{dataset: {feature: 0}}`` layout.

    Args:
        client_metrics: Local stats for each client, indexed as
            ``client_metrics[stat][site][dataset][feature]``.
        ordered_metrics: Ordered target metrics (an iterable of metric names).

    Returns:
        A dict ``{metric: {dataset: {feature: 0}}}``. The nested dicts of
        different metrics are independent objects.
    """
    # Union of features per dataset across all stats and all sites. Using
    # setdefault (rather than re-assigning an empty dict for every stat/site)
    # keeps features that only some sites or stats report.
    features_by_ds = {}
    for site_metrics in client_metrics.values():
        for ds_metrics in site_metrics.values():
            for ds, features in ds_metrics.items():
                ds_features = features_by_ds.setdefault(ds, {})
                for feature in features:
                    ds_features[feature] = 0

    # Each metric gets its own copies so later in-place accumulation into one
    # metric can never leak into another.
    return {
        metric: {ds: dict(features) for ds, features in features_by_ds.items()}
        for metric in ordered_metrics
    }
[docs]
def create_output_structure(
    client_metrics: dict, metric_task: str, ordered_metrics: dict, hierarchy_config: dict
) -> dict:
    """Recursively calculate the hierarchical global stats structure from the given hierarchy config.

    ``hierarchy_config`` is deep-copied and then walked:

    * every dict node that has at least one list-valued entry gets a
      ``StC.GLOBAL`` section built by :func:`get_initial_structure`;
    * every string item inside such a list (a client name) is replaced by
      ``{StC.NAME: <client>, StC.LOCAL: <initial structure>}``;
    * non-string list items and non-list values are visited recursively.

    Args:
        client_metrics: Local stats for each client.
        metric_task: Statistics task.
        ordered_metrics: Ordered target metrics.
        hierarchy_config: Hierarchy configuration for the global stats (not modified).

    Returns:
        A dict containing hierarchical global stats structure.
    """

    def recursively_add_values(structure: dict, value_json: dict, metric_task: str, ordered_metrics: dict):
        if isinstance(structure, dict):
            new_items = {}
            for key, value in list(structure.items()):
                if key == StC.NAME:
                    continue
                if isinstance(value, list):
                    # A node with child lists is an aggregation level and needs its own
                    # global section; build it only once even if there are several lists.
                    if StC.GLOBAL not in new_items:
                        new_items[StC.GLOBAL] = get_initial_structure(value_json, ordered_metrics)
                    for i, item in enumerate(value):
                        if isinstance(item, str):
                            # Leaf client name -> client node with its own local stats.
                            value[i] = {
                                StC.NAME: item,
                                StC.LOCAL: get_initial_structure(value_json, ordered_metrics),
                            }
                        else:
                            recursively_add_values(item, value_json, metric_task, ordered_metrics)
                else:
                    recursively_add_values(value, value_json, metric_task, ordered_metrics)
            # Added after the loop so the dict is not resized while being iterated.
            structure.update(new_items)
        elif isinstance(structure, list):
            for item in structure:
                recursively_add_values(item, value_json, metric_task, ordered_metrics)
        return structure

    # Work on a copy so the caller's hierarchy config is left untouched.
    filled_structure = copy.deepcopy(hierarchy_config)
    return recursively_add_values(filled_structure, client_metrics, metric_task, ordered_metrics)
[docs]
def get_output_structure(client_metrics: dict, metric_task: str, ordered_metrics: dict, hierarchy_config: dict) -> dict:
    """Create required global statistics hierarchical output structure.

    The result starts with a top-level ``StC.GLOBAL`` entry and is then merged
    with the structure returned by :func:`create_output_structure`. Any key that
    structure also provides (possibly including ``StC.GLOBAL``) takes its value
    from it, while ``StC.GLOBAL`` keeps its leading position.

    Args:
        client_metrics: Local stats for each client.
        metric_task: Statistics task.
        ordered_metrics: Ordered target metrics.
        hierarchy_config: Hierarchy configuration for the global stats.

    Returns:
        A dict containing hierarchical global stats structure that also includes
        top level global stats structure.
    """
    result = {StC.GLOBAL: get_initial_structure(client_metrics, ordered_metrics)}
    hierarchical = create_output_structure(client_metrics, metric_task, ordered_metrics, hierarchy_config)
    result.update(hierarchical)
    return result
[docs]
def update_output_strcture(
    client_metrics: dict,
    metric_task: str,
    ordered_metrics: dict,
    global_metrics: dict,
) -> None:
    """Reset the target metrics throughout an existing hierarchical output structure.

    ``global_metrics`` is walked in place. Every ``StC.GLOBAL`` and
    ``StC.LOCAL`` section has its entries for ``ordered_metrics`` replaced by a
    freshly zero-initialized structure from :func:`get_initial_structure`.
    List-valued entries are visited recursively. Reaching a ``StC.LOCAL``
    section ends the visit of that node.

    Args:
        client_metrics: Local stats for each client.
        metric_task: Statistics task.
        ordered_metrics: Ordered target metrics.
        global_metrics: The current global metrics (a dict node or a list of nodes).

    Returns:
        None. ``global_metrics`` is modified in place.
    """
    if isinstance(global_metrics, list):
        for child in global_metrics:
            update_output_strcture(client_metrics, metric_task, ordered_metrics, child)
        return
    if not isinstance(global_metrics, dict):
        return

    for key, value in list(global_metrics.items()):
        if key == StC.NAME:
            continue
        if key == StC.GLOBAL:
            # A fresh structure per section so no two sections share nested dicts.
            global_metrics[key].update(get_initial_structure(client_metrics, ordered_metrics))
        elif key == StC.LOCAL:
            # Client leaf: refresh its local section and stop visiting this node.
            global_metrics[key].update(get_initial_structure(client_metrics, ordered_metrics))
            return
        elif isinstance(value, list):
            update_output_strcture(client_metrics, metric_task, ordered_metrics, value)
[docs]
def get_global_stats(global_metrics: dict, client_metrics: dict, metric_task: str, hierarchy_config: dict) -> dict:
    """Get global hierarchical statistics for the given hierarchy config.

    The target metrics of ``metric_task`` are processed in the order given by
    ``StC.ordered_statistics[metric_task]``; metrics absent from
    ``client_metrics`` are skipped. Count, failure count, sum, var, min, max
    and histogram are folded into the hierarchy one client at a time. Mean
    (from sum/count) and stddev (from var) are derived from the values already
    aggregated at every level.

    NOTE(review): mean/stddev rely on the ordering in ``StC.ordered_statistics``
    placing sum/count before mean and var before stddev; confirm.

    Args:
        global_metrics: The current global metrics.
        client_metrics: Local stats for each client.
        metric_task: Statistics task.
        hierarchy_config: Hierarchy configuration for the global stats.

    Returns:
        A dict containing global hierarchical statistics.
    """
    ordered_metrics = [name for name in StC.ordered_statistics[metric_task] if name in client_metrics]

    # The first call builds the structure; later calls reset the target metrics in place.
    if StC.GLOBAL in global_metrics:
        update_output_strcture(client_metrics, metric_task, ordered_metrics, global_metrics)
    else:
        global_metrics = get_output_structure(client_metrics, metric_task, ordered_metrics, hierarchy_config)

    # Metrics aggregated client by client, mapped to their aggregation function.
    per_client_aggregators = {
        StC.STATS_COUNT: accumulate_hierarchical_metrics,
        StC.STATS_FAILURE_COUNT: accumulate_hierarchical_metrics,
        StC.STATS_SUM: accumulate_hierarchical_metrics,
        StC.STATS_VAR: accumulate_hierarchical_metrics,
        StC.STATS_MAX: get_hierarchical_mins_or_maxs,
        StC.STATS_MIN: get_hierarchical_mins_or_maxs,
        StC.STATS_HISTOGRAM: get_hierarchical_histograms,
    }

    for metric in ordered_metrics:
        if metric == StC.STATS_MEAN:
            global_metrics = get_hierarchical_means(metric, global_metrics)
            continue
        if metric == StC.STATS_STDDEV:
            global_metrics = get_hierarchical_stddevs(global_metrics)
            continue
        aggregate = per_client_aggregators.get(metric)
        if aggregate is None:
            continue
        for client_name, client_stats in client_metrics[metric].items():
            global_metrics = aggregate(metric, client_name, client_stats, global_metrics, hierarchy_config)
    return global_metrics
[docs]
def accumulate_hierarchical_metrics(
    metric: str, client_name: str, metrics: dict, global_metrics: dict, hierarchy_config: dict
) -> dict:
    """Accumulate metrics at each hierarchical level.

    For every dataset/feature in ``metrics``, the client's value is added (``+=``)
    to:

    * the ``StC.GLOBAL`` section of the top-level node (the dict without ``StC.NAME``);
    * the ``StC.GLOBAL`` section of every named node on the client's hierarchy path;
    * the ``StC.LOCAL`` section of the client's own node.

    A named node that is neither on the path nor the client stops being
    scanned at its ``StC.NAME`` key. NOTE(review): that pruning therefore
    assumes ``StC.NAME`` precedes the node's child lists in key order.

    Args:
        metric: Metric to accumulate.
        client_name: Client name.
        metrics: Client metrics, indexed as ``metrics[dataset][feature]``.
        global_metrics: The current global metrics (updated in place).
        hierarchy_config: Hierarchy configuration for the global stats.

    Returns:
        A dict containing accumulated hierarchical global statistics.
    """

    def visit(node, dataset: str, feature: str, org: list, amount) -> None:
        if not isinstance(node, dict):
            return
        is_top_level = StC.NAME not in node
        for key, value in node.items():
            if key == StC.GLOBAL and is_top_level:
                node[StC.GLOBAL][metric][dataset][feature] += amount
                continue
            if key == StC.NAME:
                if org and value in org:
                    # Client belongs to this org: update its aggregate before descending.
                    node[StC.GLOBAL][metric][dataset][feature] += amount
                elif value == client_name:
                    # The client's own node.
                    node[StC.LOCAL][metric][dataset][feature] += amount
                else:
                    # Unrelated branch of the hierarchy.
                    break
            if isinstance(value, list):
                for child in value:
                    visit(child, dataset, feature, org, amount)

    client_org = get_client_hierarchy(copy.deepcopy(hierarchy_config), client_name)
    for dataset, features in metrics.items():
        for feature, amount in features.items():
            visit(global_metrics, dataset, feature, client_org, amount)
    return global_metrics
[docs]
def get_hierarchical_mins_or_maxs(
    metric: str, client_name: str, metrics: dict, global_metrics: dict, hierarchy_config: dict
) -> dict:
    """Calculate min or max at each hierarchical level.

    ``metric == "min"`` selects :func:`min`; any other value selects :func:`max`.
    The client's value is combined into the ``StC.GLOBAL`` section of the
    top-level node and of every named node on the client's hierarchy path. It is
    written as-is into the client's own ``StC.LOCAL`` section.

    NOTE(review): an aggregated value is only combined with the operator when it
    is truthy; otherwise it is overwritten. The placeholders start at 0
    (``get_initial_structure``), so a genuine min/max of exactly 0 is treated as
    "unset" and replaced by the next client's value. Fixing this needs a
    placeholder that cannot collide with real data (e.g. ``None``).

    Args:
        metric: Metric to accumulate.
        client_name: Client name.
        metrics: Client metrics, indexed as ``metrics[dataset][feature]``.
        global_metrics: The current global metrics (updated in place).
        hierarchy_config: Hierarchy configuration for the global stats.

    Returns:
        A dict containing updated hierarchical global statistics with
        accumulated mins or maxs.
    """
    op = min if metric == "min" else max

    def combine_into(section: dict, dataset: str, feature: str, value) -> None:
        per_dataset = section[metric][dataset]
        current = per_dataset[feature]
        per_dataset[feature] = op(current, value) if current else value

    def visit(node, dataset: str, feature: str, org: list, value) -> None:
        if not isinstance(node, dict):
            return
        is_top_level = StC.NAME not in node
        for key, child in node.items():
            if key == StC.GLOBAL and is_top_level:
                combine_into(node[StC.GLOBAL], dataset, feature, value)
                continue
            if key == StC.NAME:
                if org and child in org:
                    # Client belongs to this org: update its aggregate before descending.
                    combine_into(node[StC.GLOBAL], dataset, feature, value)
                elif child == client_name:
                    # The client's own node simply records its value.
                    node[StC.LOCAL][metric][dataset][feature] = value
                else:
                    # Unrelated branch of the hierarchy.
                    break
            if isinstance(child, list):
                for item in child:
                    visit(item, dataset, feature, org, value)

    client_org = get_client_hierarchy(copy.deepcopy(hierarchy_config), client_name)
    for dataset, features in metrics.items():
        for feature, value in features.items():
            visit(global_metrics, dataset, feature, client_org, value)
    return global_metrics
[docs]
def get_hierarchical_means(metric: str, global_metrics: dict) -> dict:
    """Calculate means at each hierarchical level.

    For every ``StC.GLOBAL`` and ``StC.LOCAL`` section in the hierarchy,
    ``metric`` is set to ``sum / count`` for each dataset/feature listed under
    the top-level global count.

    A section whose count is 0 keeps its existing (placeholder) value instead of
    raising ``ZeroDivisionError``. This happens, for example, for a client that
    appears in the hierarchy config but sent no statistics this round.

    Args:
        metric: Name of the mean metric to fill in.
        global_metrics: The current global metrics (updated in place).

    Returns:
        A dict containing updated hierarchical global statistics with
        accumulated means.
    """

    def set_mean(section: dict, dataset: str, feature: str) -> None:
        count = section[StC.STATS_COUNT][dataset][feature]
        if count == 0:
            # The mean is undefined without samples; leave the placeholder untouched.
            return
        section[metric][dataset][feature] = section[StC.STATS_SUM][dataset][feature] / count

    def recursively_update_org_means(node, dataset: str, feature: str) -> None:
        if not isinstance(node, dict):
            return
        for key, value in node.items():
            if key == StC.GLOBAL or key == StC.LOCAL:
                set_mean(value, dataset, feature)
            elif isinstance(value, list):
                for item in value:
                    recursively_update_org_means(item, dataset, feature)

    # Iterate each hierarchical level and calculate 'mean' from 'sum' and 'count'.
    top_counts = global_metrics[StC.GLOBAL][StC.STATS_COUNT]
    for dataset, features in top_counts.items():
        for feature in features:
            recursively_update_org_means(global_metrics, dataset, feature)
    return global_metrics
[docs]
def _merge_histogram(current, histogram):
    """Merge ``histogram`` into ``current`` and return the resulting ``Histogram``.

    When ``current`` is falsy (missing, ``None`` or the ``0`` placeholder), a
    ``HistogramType.STANDARD`` copy of ``histogram``'s bins is returned.
    Otherwise the sample counts of bins with matching ``(low, high)`` ranges are
    summed, keeping ``current``'s bin order and histogram type. Incoming bins
    whose range is absent from ``current`` are appended at the end, in their
    incoming order, instead of being silently dropped.

    Args:
        current: The histogram accumulated so far, or a falsy placeholder.
        histogram: The incoming histogram to add.

    Returns:
        A new ``Histogram``; neither argument is modified.
    """
    if not current:
        return Histogram(
            HistogramType.STANDARD,
            [Bin(bucket.low_value, bucket.high_value, bucket.sample_count) for bucket in histogram.bins],
        )

    counts = bins_to_dict(current.bins)
    extra_bins = []  # (low, high, range) of incoming bins unknown to ``current``
    for bucket in histogram.bins:
        bin_range = BinRange(bucket.low_value, bucket.high_value)
        if bin_range in counts:
            counts[bin_range] += bucket.sample_count
        else:
            counts[bin_range] = bucket.sample_count
            extra_bins.append((bucket.low_value, bucket.high_value, bin_range))

    # Keep the existing bin order, then add any ranges seen for the first time.
    merged_bins = [
        Bin(gb.low_value, gb.high_value, counts[BinRange(gb.low_value, gb.high_value)]) for gb in current.bins
    ]
    merged_bins.extend(Bin(low, high, counts[bin_range]) for low, high, bin_range in extra_bins)
    return Histogram(current.hist_type, merged_bins)


def get_hierarchical_histograms(
    metric: str, client_name: str, metrics: dict, global_metrics: dict, hierarchy_config: dict
) -> dict:
    """Calculate histograms at each hierarchical level.

    Each of the client's histograms is merged (see ``_merge_histogram``) into:

    * the ``StC.GLOBAL`` section of the top-level node (the dict without ``StC.NAME``);
    * the ``StC.GLOBAL`` section of every named node on the client's hierarchy path;
    * the client's own ``StC.LOCAL`` section.

    Args:
        metric: Metric to accumulate.
        client_name: Client name.
        metrics: Client metrics, indexed as ``metrics[dataset][feature]`` -> histogram.
        global_metrics: The current global metrics (updated in place).
        hierarchy_config: Hierarchy configuration for the global stats.

    Returns:
        A dict containing updated hierarchical global statistics with
        accumulated histograms.
    """

    def merge_into(section: dict, dataset: str, feature: str, histogram) -> None:
        dataset_hists = section[metric][dataset]
        dataset_hists[feature] = _merge_histogram(dataset_hists.get(feature), histogram)

    def recursively_accumulate_org_histograms(node, dataset: str, feature: str, org: list, histogram) -> None:
        if not isinstance(node, dict):
            return
        for key, value in node.items():
            if key == StC.GLOBAL and StC.NAME not in node:
                # Top-level (unnamed) node aggregates every client.
                merge_into(node[StC.GLOBAL], dataset, feature, histogram)
                continue
            if key == StC.NAME:
                if org and value in org:
                    # The client belongs to this org so update current global metrics before sending it further
                    merge_into(node[StC.GLOBAL], dataset, feature, histogram)
                elif value == client_name:
                    # This is a client local metrics update
                    merge_into(node[StC.LOCAL], dataset, feature, histogram)
                else:
                    # Unrelated branch of the hierarchy; skip the rest of this node.
                    break
            if isinstance(value, list):
                for item in value:
                    recursively_accumulate_org_histograms(item, dataset, feature, org, histogram)

    client_org = get_client_hierarchy(copy.deepcopy(hierarchy_config), client_name)
    for dataset, features in metrics.items():
        for feature, histogram in features.items():
            recursively_accumulate_org_histograms(global_metrics, dataset, feature, client_org, histogram)
    return global_metrics
[docs]
def get_hierarchical_stddevs(global_metrics: dict) -> dict:
    """Calculate stddevs at each hierarchical level.

    Every ``StC.GLOBAL`` and ``StC.LOCAL`` section in the hierarchy gets
    ``StC.STATS_STDDEV`` set to ``sqrt`` of its ``StC.STATS_VAR`` value. This is
    done for each dataset/feature listed under the top-level global variance.

    Args:
        global_metrics: The current global metrics (updated in place).

    Returns:
        A dict containing updated hierarchical global statistics with
        accumulated stddevs.
    """

    def refresh(node, dataset: str, feature: str) -> None:
        if not isinstance(node, dict):
            return
        for key, child in node.items():
            if key == StC.GLOBAL or key == StC.LOCAL:
                section = node[key]
                section[StC.STATS_STDDEV][dataset][feature] = sqrt(section[StC.STATS_VAR][dataset][feature])
            if isinstance(child, list):
                for item in child:
                    refresh(item, dataset, feature)

    top_level_var = global_metrics[StC.GLOBAL][StC.STATS_VAR]
    for dataset, features in top_level_var.items():
        for feature in features:
            refresh(global_metrics, dataset, feature)
    return global_metrics
[docs]
def get_hierarchical_levels(data: dict, level: int = 0, levels_dict: dict = None) -> dict:
    """Calculate number of hierarchical levels from the given hierarchy config.

    Each dict key other than ``StC.NAME`` is mapped to the depth at which it is
    first encountered. Lists do not add depth; descending into a dict value does.

    Args:
        data: Hierarchy configuration for the global stats.
        level: The current hierarchical level (used for recursive calls).
        levels_dict: The accumulated levels dict (used for recursive calls);
            it is filled in place and returned.

    Returns:
        A dict mapping hierarchy keys to their levels.
    """
    levels = {} if levels_dict is None else levels_dict
    if isinstance(data, dict):
        for key, child in data.items():
            if key == StC.NAME:
                continue
            # Only the first (shallowest-visited) occurrence of a key is recorded.
            levels.setdefault(key, level)
            get_hierarchical_levels(child, level + 1, levels)
    elif isinstance(data, list):
        for entry in data:
            get_hierarchical_levels(entry, level, levels)
    return levels
[docs]
def get_client_hierarchy(hierarchy_config: dict, client_name: str, path=None) -> list:
    """Calculate hierarchy for the given client name.

    Args:
        hierarchy_config: Hierarchy configuration for the global stats.
        client_name: Client name.
        path: The accumulated hierarchy path (used for recursive calls).

    Returns:
        The list of ``StC.NAME`` values of the named nodes enclosing
        ``client_name``, outermost first. This is an empty list when the client
        is listed directly under a top-level key, and ``None`` when the client
        is not found.
    """
    if path is None:
        path = []
    if isinstance(hierarchy_config, dict):
        for value in hierarchy_config.values():
            if isinstance(value, list):
                result = get_client_hierarchy(value, client_name, path)
                # An empty path is a valid hit (top-level client), so compare with
                # None rather than relying on truthiness.
                if result is not None:
                    return result
    elif isinstance(hierarchy_config, list):
        for item in hierarchy_config:
            if item == client_name:
                return path
            if isinstance(item, dict):
                result = get_client_hierarchy(item, client_name, path + [item.get(StC.NAME)])
                if result is not None:
                    return result
    return None
[docs]
def bins_to_dict(bins: List[Bin]) -> Dict[BinRange, float]:
    """Convert histogram bins to a 'dict'.

    Args:
        bins: Histogram bins.

    Returns:
        A dict mapping each bin's ``BinRange(low_value, high_value)`` to its
        sample count. If several bins share a range, the last one wins.
    """
    return {BinRange(b.low_value, b.high_value): b.sample_count for b in bins}
[docs]
def filter_numeric_features(ds_features: Dict[str, List[Feature]]) -> Dict[str, List[Feature]]:
    """Filter numeric features.

    Args:
        ds_features: A dict mapping dataset names to their features.

    Returns:
        A dict with the same dataset names, each keeping only the features whose
        ``data_type`` is ``DataType.INT`` or ``DataType.FLOAT`` (order preserved).
    """
    numeric_types = (DataType.INT, DataType.FLOAT)
    return {
        ds_name: [feature for feature in features if feature.data_type in numeric_types]
        for ds_name, features in ds_features.items()
    }