Source code for nvflare.app_common.widgets.metrics_artifact_writer

# Copyright (c) 2026, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import math
import os
from typing import Any, Dict, Optional, Tuple

import numpy as np

from nvflare.apis.event_type import EventType
from nvflare.apis.fl_constant import FLContextKey, FLMetaKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.job_def import JobMetaKey
from nvflare.apis.shareable import Shareable
from nvflare.app_common.abstract.fl_model import FLModel
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.app_event_type import AppEventType
from nvflare.app_common.utils.file_utils import resolve_path_under_root
from nvflare.app_common.utils.fl_model_utils import FLModelUtils
from nvflare.widgets.widget import Widget

METRICS_AGGREGATION_INFO = AppConstants.METRICS_AGGREGATION_INFO


class MetricsArtifactWriter(Widget):
    """Writes safe, machine-readable round and summary metric artifacts.

    The writer consumes metrics already produced by workflows/controllers. It records dynamic metric names as
    values instead of object keys so downstream consumers do not need to treat client-provided names as JSON
    object structure.

    Two files are produced under ``results_dir`` inside the job's run directory:

    * ``round_file_name``: one compact JSON object per line, appended after each aggregation.
    * ``summary_file_name``: a single indented JSON document written at the end of the run.

    Args:
        results_dir: directory, relative to the run directory, that receives both files.
        summary_file_name: plain file name (no directory part) of the summary document.
        round_file_name: plain file name (no directory part) of the per-round JSONL file.
        limits: optional overrides for the size limits. Recognized keys and their defaults:
            ``max_metric_name_length`` (256), ``max_string_value_length`` (1024),
            ``max_metrics_per_site_per_round`` (512), ``max_sites_per_round`` (10000),
            ``max_site_metric_records_per_round`` (10000), ``max_skipped_metrics_per_round`` (1024),
            ``max_round_record_bytes`` (1048576), ``max_summary_bytes`` (1048576),
            ``max_int_bit_length`` (1023).
    """

    def __init__(
        self,
        results_dir: str = "metrics",
        summary_file_name: str = "metrics_summary.json",
        round_file_name: str = "round_metrics.jsonl",
        limits: Optional[Dict[str, int]] = None,
    ):
        super().__init__()
        self.results_dir = results_dir
        self.summary_file_name = summary_file_name
        self.round_file_name = round_file_name

        limits = limits or {}
        self.max_metric_name_length = limits.get("max_metric_name_length", 256)
        self.max_string_value_length = limits.get("max_string_value_length", 1024)
        self.max_metrics_per_site_per_round = limits.get("max_metrics_per_site_per_round", 512)
        self.max_sites_per_round = limits.get("max_sites_per_round", 10000)
        self.max_site_metric_records_per_round = limits.get("max_site_metric_records_per_round", 10000)
        self.max_skipped_metrics_per_round = limits.get("max_skipped_metrics_per_round", 1024)
        self.max_round_record_bytes = limits.get("max_round_record_bytes", 1048576)
        self.max_summary_bytes = limits.get("max_summary_bytes", 1048576)
        self.max_int_bit_length = limits.get("max_int_bit_length", 1023)
        self._reset()

    def _reset(self):
        """Clear all per-run state so the instance can be reused for a new run."""
        self._has_metrics = False
        self._final_round = None
        self._final_aggregated_metrics = []
        self._best_selection = None
        self._key_metric = None
        self._aggregation = None
        self._metric_source = None
        self._metric_split = None
        self._round_file_path = None
        self._summary_file_path = None
        # The three dicts below are keyed by round number and are drained by
        # _handle_after_aggregation when that round is aggregated.
        self._round_sites = {}
        self._round_skipped = {}
        self._round_site_metric_counts = {}

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        """Dispatch the events this widget reacts to; every other event is ignored.

        Args:
            event_type: the fired event type.
            fl_ctx: the FL context that accompanies the event.
        """
        if event_type == EventType.START_RUN:
            self._reset()
        elif event_type == AppEventType.AFTER_CONTRIBUTION_ACCEPT:
            self._handle_after_contribution_accept(fl_ctx)
        elif event_type == AppEventType.AFTER_AGGREGATION:
            self._handle_after_aggregation(fl_ctx)
        elif event_type == AppEventType.GLOBAL_BEST_MODEL_AVAILABLE:
            self._handle_global_best_model_available(fl_ctx)
        elif event_type == EventType.END_RUN:
            self._write_summary_if_needed(fl_ctx)

    def _handle_after_aggregation(self, fl_ctx: FLContext):
        """Build the record for the just-aggregated round and append it to the round file.

        Site details come from the aggregation info found in the result meta when present; otherwise the
        sites collected from accepted contributions are used, unless the info sets
        ``use_contribution_sites`` to ``False``. Nothing is written when there are no aggregated metrics,
        no sites and no skipped entries.
        """
        aggr_result = fl_ctx.get_prop(AppConstants.AGGREGATION_RESULT, None)
        aggr_result = self._to_fl_model(aggr_result)
        if aggr_result is None:
            return

        meta = aggr_result.meta or {}
        info = meta.get(METRICS_AGGREGATION_INFO, {})
        if not isinstance(info, dict):
            info = {}

        skipped = []
        aggregated_metrics = self._normalize_metrics(
            aggr_result.metrics, site=None, skipped=skipped, for_aggregation=True
        )
        current_round = self._get_current_round(aggr_result, fl_ctx)

        # Always drain the per-round buffers, even if they end up unused, so they do not accumulate.
        fallback_sites = self._round_sites.pop(current_round, [])
        fallback_skipped = self._round_skipped.pop(current_round, [])
        self._round_site_metric_counts.pop(current_round, None)

        sites = self._normalize_sites(info.get("sites"), skipped)
        site_weights = self._normalize_site_weights(info.get("site_weights"), skipped)
        # Only an explicit False disables the fallback; any other value keeps it enabled.
        use_contribution_sites = info.get("use_contribution_sites", True) is not False
        if not sites and use_contribution_sites:
            sites = fallback_sites
            self._merge_skipped(skipped, fallback_skipped)
        self._apply_site_weights(sites, site_weights)

        if not aggregated_metrics and not sites and not skipped:
            return

        aggregation = self._sanitize_json_object(info.get("aggregation"))
        key_metric = self._normalize_key_metric(info.get("key_metric"))
        record = {
            "round": current_round,
            "aggregated_metrics": aggregated_metrics,
            "sites": sites,
            "skipped_metrics": skipped,
        }
        if aggregation:
            record["aggregation"] = aggregation
        if key_metric:
            record["key_metric"] = key_metric
        self._append_round_record(fl_ctx, record)

        # Remember what the summary needs; the latest round wins.
        self._has_metrics = True
        self._final_round = current_round
        self._final_aggregated_metrics = aggregated_metrics
        self._metric_source = self._safe_text(info.get("metric_source"))
        self._metric_split = self._safe_text(info.get("metric_split"))
        self._aggregation = aggregation if aggregation else self._aggregation
        if key_metric:
            self._key_metric = key_metric

    @staticmethod
    def _to_fl_model(value):
        """Return ``value`` as an ``FLModel``, converting a ``Shareable`` if needed.

        Returns:
            The model, or ``None`` when ``value`` is neither type or the conversion raises.
        """
        if isinstance(value, FLModel):
            return value
        if isinstance(value, Shareable):
            try:
                return FLModelUtils.from_shareable(value)
            except Exception:
                # Best effort: a result that cannot be converted is treated as "no metrics".
                return None
        return None

    def _handle_global_best_model_available(self, fl_ctx: FLContext):
        """Record the normalized best-model selection info, if any, for the summary."""
        selection = self._normalize_selection_info(fl_ctx.get_prop(AppConstants.METRICS_SELECTION_INFO, None))
        if not selection:
            return
        self._best_selection = selection
        key_metric = selection.get("key_metric")
        if key_metric:
            self._key_metric = key_metric

    def _handle_after_contribution_accept(self, fl_ctx: FLContext):
        """Buffer the metrics of one accepted contribution under its round.

        The contribution is read from the ``TRAINING_RESULT`` prop, falling back to the peer context's
        shareable. Contributions explicitly marked as not accepted, unconvertible results and results
        without metrics are ignored. Per-round site and metric-record limits are enforced here.
        """
        accepted = fl_ctx.get_prop(AppConstants.AGGREGATION_ACCEPTED, True)
        if accepted is False:
            return

        result = fl_ctx.get_prop(AppConstants.TRAINING_RESULT, None)
        if result is None:
            peer_ctx = fl_ctx.get_peer_context()
            if peer_ctx:
                result = peer_ctx.get_prop(FLContextKey.SHAREABLE, None)
        if result is None:
            return

        try:
            model = FLModelUtils.from_shareable(result)
        except Exception:
            # Best effort: a contribution that cannot be converted contributes no metrics.
            return
        if not model.metrics:
            return

        current_round = self._get_current_round(model, fl_ctx)
        skipped = []
        site_name = self._get_site_name(model, fl_ctx)
        metrics = self._normalize_metrics(model.metrics, site=site_name, skipped=skipped)
        if skipped:
            self._extend_round_skipped(current_round, skipped)
        if not metrics:
            return

        sites = self._round_sites.setdefault(current_round, [])
        if len(sites) >= self.max_sites_per_round:
            too_many_sites = []
            self._add_skipped(too_many_sites, site_name, "", "too_many_sites")
            self._extend_round_skipped(current_round, too_many_sites)
            return

        metric_count = self._round_site_metric_counts.get(current_round, 0)
        if metric_count + len(metrics) > self.max_site_metric_records_per_round:
            too_many_metrics = []
            self._add_skipped(too_many_metrics, site_name, "", "too_many_metrics")
            self._extend_round_skipped(current_round, too_many_metrics)
            allowed = self.max_site_metric_records_per_round - metric_count
            if allowed <= 0:
                return
            # Keep as many metrics as the remaining per-round budget allows.
            metrics = metrics[:allowed]
        self._round_site_metric_counts[current_round] = metric_count + len(metrics)

        site = {
            "name": self._sanitize_name(site_name),
            "metrics": metrics,
        }
        meta = model.meta or {}
        weight = self._safe_weight(meta.get(FLMetaKey.NUM_STEPS_CURRENT_ROUND))
        if weight is not None:
            site["weight"] = weight
            site["weight_key"] = FLMetaKey.NUM_STEPS_CURRENT_ROUND
        sites.append(site)

    def _normalize_sites(self, sites, skipped):
        """Normalize the ``sites`` list of the aggregation info.

        Args:
            sites: expected to be a list of dicts with ``name``, ``metrics`` and optional ``weight`` /
                ``weight_key``; anything else yields an empty result.
            skipped: list that receives a record for every dropped metric or site.

        Returns:
            A list of site records, each with a sanitized ``name`` and a non-empty ``metrics`` list.
        """
        if not isinstance(sites, list):
            return []

        normalized_sites = []
        metric_count = 0
        for site in sites:
            if not isinstance(site, dict):
                continue
            if len(normalized_sites) >= self.max_sites_per_round:
                self._add_skipped(skipped, site.get("name"), "", "too_many_sites")
                break

            metrics = self._normalize_metrics(site.get("metrics"), site=site.get("name"), skipped=skipped)
            if not metrics:
                continue
            if metric_count + len(metrics) > self.max_site_metric_records_per_round:
                self._add_skipped(skipped, site.get("name"), "", "too_many_metrics")
                allowed = self.max_site_metric_records_per_round - metric_count
                if allowed <= 0:
                    break
                metrics = metrics[:allowed]
            metric_count += len(metrics)

            site_record = {
                "name": self._sanitize_name(site.get("name", "")),
                "metrics": metrics,
            }
            weight = self._safe_weight(site.get("weight"))
            if weight is not None:
                site_record["weight"] = weight
            weight_key = self._safe_text(site.get("weight_key"))
            if weight_key:
                site_record["weight_key"] = weight_key
            normalized_sites.append(site_record)
        return normalized_sites

    def _normalize_site_weights(self, site_weights, skipped):
        """Normalize the ``site_weights`` list of the aggregation info into a name-keyed dict.

        Args:
            site_weights: expected to be a list of dicts with ``name`` and ``weight`` / ``weight_key``.
            skipped: list that receives a ``too_many_sites`` record when the site limit is hit.

        Returns:
            Mapping of sanitized site name to a dict holding ``weight`` and/or ``weight_key``. Entries with
            an empty name or with neither field usable are omitted.
        """
        if not isinstance(site_weights, list):
            return {}

        normalized_weights = {}
        for site_weight in site_weights:
            if not isinstance(site_weight, dict):
                continue
            if len(normalized_weights) >= self.max_sites_per_round:
                self._add_skipped(skipped, site_weight.get("name"), "", "too_many_sites")
                break

            name = self._sanitize_name(site_weight.get("name", ""))
            if not name:
                continue

            record = {}
            weight = self._safe_weight(site_weight.get("weight"))
            if weight is not None:
                record["weight"] = weight
            weight_key = self._safe_text(site_weight.get("weight_key"))
            if weight_key:
                record["weight_key"] = weight_key
            if record:
                normalized_weights[name] = record
        return normalized_weights

    @staticmethod
    def _apply_site_weights(sites, site_weights):
        """Update each site record in place with the weight info registered under its name."""
        if not site_weights:
            return
        for site in sites:
            weight_info = site_weights.get(site.get("name"))
            if weight_info:
                site.update(weight_info)

    def _normalize_metrics(self, metrics, site, skipped, for_aggregation=False):
        """Convert a ``{name: value}`` metrics dict into a list of ``{"name", "value"}`` records.

        Args:
            metrics: the metrics dict; any non-dict yields an empty list.
            site: site name recorded on skipped entries, or ``None`` to omit it.
            skipped: list that receives a record for every metric that is dropped.
            for_aggregation: when true, string values are rejected as unsupported.

        Returns:
            At most ``max_metrics_per_site_per_round`` normalized metric records.
        """
        if not isinstance(metrics, dict):
            return []

        result = []
        for name, value in metrics.items():
            if len(result) >= self.max_metrics_per_site_per_round:
                self._add_skipped(skipped, site, name, "too_many_metrics")
                break
            if not isinstance(name, str):
                self._add_skipped(skipped, site, "", "non_string_metric_name")
                continue
            metric_name = self._sanitize_name(name)
            if not metric_name:
                self._add_skipped(skipped, site, "", "empty_metric_name")
                continue
            normalized, reason = self._normalize_metric_value(value, for_aggregation=for_aggregation)
            if reason:
                self._add_skipped(skipped, site, metric_name, reason)
                continue
            result.append({"name": metric_name, "value": normalized})
        return result

    def _normalize_metric_value(self, value, for_aggregation=False) -> Tuple[Any, Optional[str]]:
        """Validate a single metric value.

        Returns:
            ``(value, None)`` when the value is acceptable, otherwise ``(None, reason)`` where reason is one
            of ``number_too_large``, ``non_finite_number``, ``string_too_long`` or ``unsupported_type``.
        """
        value = self._to_python_scalar(value)
        # bool must be tested first because bool is a subclass of int.
        if isinstance(value, bool):
            return value, None
        if isinstance(value, int) and not isinstance(value, bool):
            if value.bit_length() > self.max_int_bit_length:
                return None, "number_too_large"
            return value, None
        if isinstance(value, float):
            if not math.isfinite(value):
                return None, "non_finite_number"
            return value, None
        if not for_aggregation and isinstance(value, str):
            if len(value) > self.max_string_value_length:
                return None, "string_too_long"
            return self._strip_control_chars(value), None
        return None, "unsupported_type"

    def _normalize_key_metric(self, key_metric):
        """Normalize a key-metric descriptor.

        Returns:
            A dict with a sanitized ``name`` plus ``mode`` (only ``"max"`` / ``"min"``) and ``mode_source``
            when usable, or ``None`` when the input is not a dict or has no usable name.
        """
        if not isinstance(key_metric, dict):
            return None
        name = key_metric.get("name")
        if not isinstance(name, str) or not name:
            return None
        sanitized_name = self._sanitize_name(name)
        if not sanitized_name:
            # A name made only of control characters is as unusable as an empty one.
            return None

        normalized = {"name": sanitized_name}
        mode = key_metric.get("mode")
        if mode in ("max", "min"):
            normalized["mode"] = mode
        mode_source = self._safe_text(key_metric.get("mode_source"))
        if mode_source:
            normalized["mode_source"] = mode_source
        return normalized

    def _normalize_selection_info(self, selection):
        """Normalize best-model selection info into the fields the summary copies.

        Returns:
            A dict containing any of ``best_round``, ``best_metrics``, ``best_aggregated_metrics``,
            ``key_metric``, ``best_metric_source`` and ``best_metric_detail_source``, or ``None`` when the
            input is not a dict or nothing usable was found.
        """
        if not isinstance(selection, dict):
            return None

        normalized = {}
        best_round = self._safe_round(selection.get("best_round"))
        if best_round is not None:
            normalized["best_round"] = best_round
        best_metrics = self._normalize_metric_collection(selection.get("best_metrics"), for_aggregation=True)
        if best_metrics:
            normalized["best_metrics"] = best_metrics
        best_aggregated_metrics = self._normalize_metric_collection(
            selection.get("best_aggregated_metrics"), for_aggregation=True
        )
        if best_aggregated_metrics:
            normalized["best_aggregated_metrics"] = best_aggregated_metrics
        key_metric = self._normalize_key_metric(selection.get("key_metric"))
        if key_metric:
            normalized["key_metric"] = key_metric
        source = self._safe_text(selection.get("source"))
        if source:
            normalized["best_metric_source"] = source
        metric_source = self._safe_text(selection.get("metric_source"))
        if metric_source:
            normalized["best_metric_detail_source"] = metric_source
        return normalized or None

    def _normalize_metric_collection(self, metrics, for_aggregation=False):
        """Normalize metrics given either as a dict or as a list of ``{"name", "value"}`` dicts.

        Unlike ``_normalize_metrics``, unusable entries are dropped without being reported.
        """
        if isinstance(metrics, dict):
            return self._normalize_metrics(metrics, site=None, skipped=[], for_aggregation=for_aggregation)
        if not isinstance(metrics, list):
            return []

        result = []
        for metric in metrics:
            if len(result) >= self.max_metrics_per_site_per_round:
                break
            if not isinstance(metric, dict):
                continue
            name = metric.get("name")
            if not isinstance(name, str):
                continue
            metric_name = self._sanitize_name(name)
            if not metric_name:
                continue
            normalized, reason = self._normalize_metric_value(metric.get("value"), for_aggregation=for_aggregation)
            if reason:
                continue
            result.append({"name": metric_name, "value": normalized})
        return result

    def _append_round_record(self, fl_ctx, record):
        """Append one round record as a single JSON line to the round file.

        The record is first fitted to ``max_round_record_bytes``. If it cannot be serialized or is still
        too large, a minimal record is written instead; if even that fails, nothing is written.
        """
        self._ensure_paths(fl_ctx)
        os.makedirs(os.path.dirname(self._round_file_path), exist_ok=True)

        record = self._fit_round_record(record)
        line = self._safe_json_dumps(record)
        if line is None:
            record = self._make_minimal_round_record(record, "serialization_failed")
            line = self._safe_json_dumps(record)
            if line is None:
                return
        if len(line.encode("utf-8")) > self.max_round_record_bytes:
            record = self._make_minimal_round_record(record, "round_record_too_large")
            line = self._safe_json_dumps(record)
            if line is None or len(line.encode("utf-8")) > self.max_round_record_bytes:
                return

        with open(self._round_file_path, "a", encoding="utf-8") as f:
            f.write(line + "\n")

    def _write_summary_if_needed(self, fl_ctx):
        """Write the summary document if at least one round record was produced.

        When the document exceeds ``max_summary_bytes`` the ``notes`` are dropped; if it is still too large
        or cannot be serialized, no summary is written.
        """
        if not self._has_metrics:
            return
        self._ensure_paths(fl_ctx)

        summary = {
            "schema_version": "1",
            "status": "metrics_reported",
            "final_round": self._final_round,
            "final_aggregated_metrics": self._final_aggregated_metrics,
            "round_metrics_file": self.round_file_name,
            "notes": [
                "Aggregated metrics are weighted averages of client-reported metric values.",
                "Nonlinear metrics are not recomputed from pooled predictions.",
            ],
        }
        job_name = self._get_job_name(fl_ctx)
        if job_name:
            summary["job_name"] = job_name
        if self._metric_source:
            summary["metric_source"] = self._metric_source
        if self._metric_split:
            summary["metric_split"] = self._metric_split

        # The key metric attached to the best-model selection takes precedence over the per-round one.
        if self._best_selection and self._best_selection.get("key_metric"):
            summary["key_metric"] = self._best_selection["key_metric"]
        elif self._key_metric:
            summary["key_metric"] = self._key_metric

        if self._best_selection:
            for field in (
                "best_round",
                "best_metrics",
                "best_aggregated_metrics",
                "best_metric_source",
                "best_metric_detail_source",
            ):
                if field in self._best_selection:
                    summary[field] = self._best_selection[field]
        if self._aggregation:
            summary["aggregation"] = self._sanitize_json_object(self._aggregation)

        data = self._safe_json_dumps(summary, indent=2)
        if data is None:
            return
        if len(data.encode("utf-8")) > self.max_summary_bytes:
            summary.pop("notes", None)
            data = self._safe_json_dumps(summary, indent=2)
            if data is None or len(data.encode("utf-8")) > self.max_summary_bytes:
                return

        os.makedirs(os.path.dirname(self._summary_file_path), exist_ok=True)
        with open(self._summary_file_path, "w", encoding="utf-8") as f:
            f.write(data)

    def _fit_round_record(self, record):
        """Return a copy of ``record`` trimmed to stay within ``max_round_record_bytes``.

        ``sites`` and then ``skipped_metrics`` entries are added while they fit. Each entry is charged its
        compact JSON size plus two bytes, which over-estimates the separator cost, so the estimate never
        undercounts the real size. When an entry does not fit, a ``{"name": "", "reason":
        "<field>_truncated"}`` marker is added to ``skipped_metrics`` and fitting stops. Room for the marker
        is made by backing out the most recently added entries if necessary, so the marker itself does not
        push the record over the limit.
        """
        budget = self.max_round_record_bytes
        fitted = {
            "round": record.get("round"),
            "aggregated_metrics": self._fit_json_list(record.get("aggregated_metrics", []), budget),
            "sites": [],
            "skipped_metrics": [],
        }
        for field in ("aggregation", "key_metric"):
            if field in record:
                fitted[field] = record[field]

        used = len((self._safe_json_dumps(fitted) or "").encode("utf-8"))
        accepted = []  # (field, charged size) in insertion order, so entries can be backed out
        for field in ("sites", "skipped_metrics"):
            for item in record.get(field, []):
                item_json = self._safe_json_dumps(item)
                if item_json is None:
                    continue
                item_size = len(item_json.encode("utf-8")) + 2
                if used + item_size > budget:
                    marker = {"name": "", "reason": f"{field}_truncated"}
                    marker_size = len((self._safe_json_dumps(marker) or "").encode("utf-8")) + 2
                    while accepted and used + marker_size > budget:
                        dropped_field, dropped_size = accepted.pop()
                        fitted[dropped_field].pop()
                        used -= dropped_size
                    fitted["skipped_metrics"].append(marker)
                    return fitted
                fitted[field].append(item)
                accepted.append((field, item_size))
                used += item_size
        return fitted

    def _fit_json_list(self, items, max_bytes):
        """Return the leading serializable ``items`` whose combined compact JSON size fits ``max_bytes``.

        Items that cannot be serialized are skipped; the first item that does not fit ends the list.
        """
        result = []
        used = 0
        for item in items:
            item_json = self._safe_json_dumps(item)
            if item_json is None:
                continue
            item_size = len(item_json.encode("utf-8")) + 1
            if used + item_size > max_bytes:
                break
            result.append(item)
            used += item_size
        return result

    def _make_minimal_round_record(self, record, reason):
        """Build a fallback record holding only the round, the aggregated metrics and one skip reason."""
        return {
            "round": record.get("round"),
            "aggregated_metrics": self._fit_json_list(
                record.get("aggregated_metrics", []), self.max_round_record_bytes
            ),
            "sites": [],
            "skipped_metrics": [{"name": "", "reason": reason}],
        }

    @staticmethod
    def _safe_json_dumps(value, indent=None):
        """Serialize ``value`` to strict JSON (NaN/Infinity rejected).

        Compact separators are used unless ``indent`` is given.

        Returns:
            The JSON text, or ``None`` when serialization raises ``TypeError``, ``ValueError`` or
            ``OverflowError``.
        """
        try:
            return json.dumps(value, allow_nan=False, indent=indent, separators=None if indent else (",", ":"))
        except (TypeError, ValueError, OverflowError):
            return None

    def _ensure_paths(self, fl_ctx):
        """Resolve and cache both output paths under the job's run directory (done once per run).

        Raises:
            TypeError: if a configured file name is not a string.
            ValueError: if a configured file name contains a directory part or is empty, ``.`` or ``..``.
        """
        if self._summary_file_path and self._round_file_path:
            return
        self._validate_file_name(self.summary_file_name, "metrics summary file name")
        self._validate_file_name(self.round_file_name, "round metrics file name")
        run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id())
        self._summary_file_path = resolve_path_under_root(
            run_dir, os.path.join(self.results_dir, self.summary_file_name), "metrics summary path"
        )
        self._round_file_path = resolve_path_under_root(
            run_dir, os.path.join(self.results_dir, self.round_file_name), "round metrics path"
        )

    @staticmethod
    def _validate_file_name(file_name: str, path_name: str):
        """Require ``file_name`` to be a plain file name with no directory component.

        Raises:
            TypeError: if ``file_name`` is not a string.
            ValueError: if ``file_name`` has a directory part or is empty, ``.`` or ``..``.
        """
        if not isinstance(file_name, str):
            raise TypeError(f"{path_name} must be str but got {type(file_name)}")
        if os.path.basename(file_name) != file_name or file_name in ("", ".", ".."):
            raise ValueError(f"{path_name} {file_name} must be relative and stay inside the metrics directory.")

    def _get_current_round(self, aggr_result, fl_ctx):
        """Determine the round number for a model.

        Candidates are tried in order: the model's ``current_round``, the model meta (under
        ``AppConstants.CURRENT_ROUND`` and ``"current_round"``), then the FL context prop. A candidate is
        accepted only if it is a non-bool integer (NumPy integer scalars included) within
        ``max_int_bit_length``, because the value may originate from a client and is used both as a
        dictionary key and in the JSON output.

        Returns:
            The round as a Python ``int``, or ``None`` when no candidate is valid.
        """
        round_number = self._coerce_round(aggr_result.current_round)
        if round_number is not None:
            return round_number

        meta = aggr_result.meta or {}
        for key in (AppConstants.CURRENT_ROUND, "current_round"):
            round_number = self._coerce_round(meta.get(key))
            if round_number is not None:
                return round_number

        return self._coerce_round(fl_ctx.get_prop(AppConstants.CURRENT_ROUND, None))

    def _coerce_round(self, value):
        """Return ``value`` as a bounded Python ``int``, or ``None`` if it is not a valid round number."""
        number = self._safe_number(value)
        return number if isinstance(number, int) else None

    def _get_job_name(self, fl_ctx):
        """Return the job name from the job meta, falling back to the sanitized job id."""
        job_name = self._get_job_name_from_meta(fl_ctx.get_prop(FLContextKey.JOB_META, None))
        if job_name:
            return job_name
        return self._safe_text(fl_ctx.get_job_id())

    def _get_job_name_from_meta(self, meta):
        """Return the first non-empty job name or job folder name found in ``meta``, else ``None``.

        Both the enum members and their ``.value`` are tried as keys.
        """
        if not isinstance(meta, dict):
            return None
        for key in (
            JobMetaKey.JOB_NAME.value,
            JobMetaKey.JOB_NAME,
            JobMetaKey.JOB_FOLDER_NAME.value,
            JobMetaKey.JOB_FOLDER_NAME,
        ):
            job_name = self._safe_text(meta.get(key))
            if job_name:
                return job_name
        return None

    def _get_site_name(self, model, fl_ctx):
        """Return the contributing site's name.

        The model meta is consulted first, then the peer context identity; if neither yields a name,
        ``AppConstants.CLIENT_UNKNOWN`` is returned.
        """
        meta = model.meta or {}
        for key in (FLMetaKey.SITE_NAME, "client_name", "site_name"):
            value = meta.get(key)
            if isinstance(value, str) and value:
                return value
        peer_ctx = fl_ctx.get_peer_context()
        if peer_ctx:
            identity = peer_ctx.get_identity_name(default=None)
            if identity:
                return identity
        return AppConstants.CLIENT_UNKNOWN

    @staticmethod
    def _safe_round(value):
        """Return ``value`` if it is an ``int`` that is not a ``bool``, else ``None``."""
        if isinstance(value, bool):
            return None
        if isinstance(value, int):
            return value
        return None

    def _safe_number(self, value):
        """Return ``value`` as a bounded ``int`` or finite ``float``; ``None`` for anything else.

        NumPy scalars are converted to Python scalars first; bools are rejected.
        """
        value = self._to_python_scalar(value)
        if isinstance(value, bool):
            return None
        if isinstance(value, int):
            if value.bit_length() <= self.max_int_bit_length:
                return value
            return None
        if isinstance(value, float) and math.isfinite(value):
            return value
        return None

    def _safe_weight(self, value):
        """Return ``value`` as a strictly positive number, or ``None`` if it is not one."""
        weight = self._safe_number(value)
        if weight is None or weight <= 0:
            return None
        return weight

    @staticmethod
    def _to_python_scalar(value):
        """Convert a NumPy scalar to its Python equivalent; return anything else unchanged."""
        if isinstance(value, np.generic):
            try:
                return value.item()
            except (TypeError, ValueError, OverflowError):
                return value
        return value

    def _safe_text(self, value):
        """Return ``value`` truncated to ``max_metric_name_length`` with control characters removed.

        Returns ``None`` when ``value`` is not a string.
        """
        if not isinstance(value, str):
            return None
        return self._strip_control_chars(value[: self.max_metric_name_length])

    def _sanitize_name(self, name):
        """Like ``_safe_text`` but returns an empty string for non-string input."""
        if not isinstance(name, str):
            return ""
        return self._strip_control_chars(name[: self.max_metric_name_length])

    @staticmethod
    def _strip_control_chars(value: str):
        """Remove characters with code points below 32 and the DEL character (0x7f)."""
        return "".join(ch for ch in value if ord(ch) >= 32 and ch != "\x7f")

    def _add_skipped(self, skipped, site, name, reason):
        """Append a skip record to ``skipped`` unless it already holds the per-round maximum.

        The ``site`` field is included only when ``site`` is not ``None``.
        """
        if len(skipped) >= self.max_skipped_metrics_per_round:
            return
        record = {"name": self._sanitize_name(name), "reason": reason}
        if site is not None:
            record["site"] = self._sanitize_name(site)
        skipped.append(record)

    def _extend_round_skipped(self, current_round, skipped):
        """Add ``skipped`` records to the buffer kept for ``current_round``."""
        if not skipped:
            return
        round_skipped = self._round_skipped.setdefault(current_round, [])
        self._merge_skipped(round_skipped, skipped)

    def _merge_skipped(self, target, skipped):
        """Append ``skipped`` records to ``target`` until the per-round maximum is reached."""
        for record in skipped:
            if len(target) >= self.max_skipped_metrics_per_round:
                return
            target.append(record)

    def _sanitize_json_object(self, value):
        """Recursively reduce ``value`` to JSON-safe dicts, lists and scalars.

        Dict entries with non-string keys are dropped and the remaining keys are sanitized; children and
        list elements that sanitize to ``None`` are dropped; scalars are validated with
        ``_normalize_metric_value`` (strings allowed).

        Returns:
            The sanitized structure, or ``None`` for an unsupported scalar.
        """
        if isinstance(value, dict):
            result = {}
            for key, child in value.items():
                if not isinstance(key, str):
                    continue
                child_value = self._sanitize_json_object(child)
                if child_value is not None:
                    result[self._sanitize_name(key)] = child_value
            return result
        if isinstance(value, list):
            return [v for v in (self._sanitize_json_object(child) for child in value) if v is not None]
        normalized, reason = self._normalize_metric_value(value, for_aggregation=False)
        if reason:
            return None
        return normalized