# Source code for nvflare.fuel.f3.stats_pool

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import csv
import json
import sys
import threading
import time
from typing import List, Tuple, Union

# Dictionary keys used by the to_dict()/from_dict() methods of bins and pools in this module.
_KEY_MAX = "max"  # bin: largest recorded value
_KEY_MIN = "min"  # bin: smallest recorded value
_KEY_NAME = "name"  # pool: name
_KEY_DESC = "description"  # pool: description
_KEY_TOTAL = "total"  # bin: sum of recorded values
_KEY_COUNT = "count"  # bin: number of recorded values
_KEY_UNIT = "unit"  # histogram pool: unit of the recorded values
_KEY_MARKS = "marks"  # histogram pool: boundary marks that define the ranges
_KEY_COUNTER_NAMES = "counter_names"  # counter pool: names of the counters
_KEY_CAT_DATA = "cat_data"  # pool: per-category data (bins or counters)


class StatsMode:
    """String constants naming the ways a histogram bin can be displayed."""

    # number of values recorded in the bin
    COUNT = "count"
    # bin count divided by the category's total count
    PERCENT = "percent"
    # mean of the values recorded in the bin
    AVERAGE = "avg"
    # smallest value recorded in the bin
    MIN = "min"
    # largest value recorded in the bin
    MAX = "max"
# Every StatsMode value; parse_hist_mode() accepts a mode string only if it resolves to one of these.
VALID_HIST_MODES = [StatsMode.COUNT, StatsMode.PERCENT, StatsMode.AVERAGE, StatsMode.MAX, StatsMode.MIN]
def format_value(v: float, n=3):
    """Format a number in scientific notation.

    Args:
        v: the value to format; may be None.
        n: number of digits after the decimal point.

    Returns:
        "n/a" when v is None, otherwise v rendered like "1.230e-03".
    """
    return "n/a" if v is None else f"{v:.{n}e}"
class _Bin:
    """Running statistics (count, total, min, max) of the values recorded into one histogram range."""

    def __init__(self, count=0, total_value=0.0, min_value=None, max_value=None):
        self.count = count
        self.total = total_value
        self.min = min_value
        self.max = max_value

    def record_value(self, value: float):
        """Add one value to the bin, updating count, total, min and max."""
        self.count += 1
        self.total += value
        if self.min is None or self.min > value:
            self.min = value
        if self.max is None or self.max < value:
            self.max = value

    def get_content(self, mode=StatsMode.COUNT, total_count=0):
        """Return the bin's display string for the requested mode.

        Args:
            mode: one of the StatsMode values.
            total_count: denominator used by StatsMode.PERCENT.

        Returns:
            "" for an empty bin; the formatted statistic for a known mode;
            "n/a" for an unknown mode, or for PERCENT when total_count is not positive.
        """
        if self.count == 0:
            return ""

        if mode == StatsMode.COUNT:
            return str(self.count)

        if mode == StatsMode.PERCENT:
            # total_count defaults to 0; report "n/a" rather than dividing by zero
            if total_count <= 0:
                return "n/a"
            return str(round(self.count / total_count, 2))

        if mode == StatsMode.AVERAGE:
            return format_value(self.total / self.count)

        if mode == StatsMode.MIN:
            return format_value(self.min)

        if mode == StatsMode.MAX:
            return format_value(self.max)

        return "n/a"

    def to_dict(self) -> dict:
        """Export the bin as a dict; an unset min/max is exported as ""."""
        return {
            _KEY_COUNT: self.count,
            _KEY_TOTAL: self.total,
            _KEY_MIN: self.min if self.min is not None else "",
            _KEY_MAX: self.max if self.max is not None else "",
        }

    @staticmethod
    def from_dict(d: dict):
        """Rebuild a bin from the dict produced by to_dict().

        Raises:
            ValueError: if d is not a dict.
        """
        if not isinstance(d, dict):
            raise ValueError(f"d must be dict but got {type(d)}")

        b = _Bin()
        b.count = d.get(_KEY_COUNT, 0)
        b.total = d.get(_KEY_TOTAL, 0)

        # to_dict() writes "" for an unset min/max, so any string means "no value"
        m = d.get(_KEY_MIN)
        b.min = None if isinstance(m, str) else m

        x = d.get(_KEY_MAX)
        b.max = None if isinstance(x, str) else x
        return b
class StatsPool:
    """Base class of all stats pools; holds the pool's name and description.

    The methods below are placeholders that do nothing and return None.
    """

    def __init__(self, name: str, description: str):
        self.name, self.description = name, description

    def to_dict(self) -> dict:
        """Placeholder for exporting the pool as a dict."""

    def get_table(self, mode):
        """Placeholder for rendering the pool as a table in the given mode."""

    @staticmethod
    def from_dict(d: dict):
        """Placeholder for rebuilding a pool from a dict."""
class RecordWriter:
    """Base class of writers that receive the raw values recorded into a pool.

    Both methods are placeholders that do nothing and return None.
    """

    def write(self, pool_name: str, category: str, value: float, report_time: float):
        """Placeholder for writing one raw record."""

    def close(self):
        """Placeholder for releasing the writer's resources."""
class HistPool(StatsPool):
    """Histogram pool: per category, values are accumulated into bins delimited by ``marks``."""

    def __init__(self, name: str, description: str, marks: Union[List[float], Tuple], unit: str, record_writer=None):
        """Create the pool.

        Args:
            name: pool name.
            description: pool description.
            marks: at least two strictly increasing numbers that delimit the ranges.
            unit: unit of the recorded values.
            record_writer: optional RecordWriter that receives every recorded value.

        Raises:
            TypeError: if record_writer is given but is not a RecordWriter.
            ValueError: if marks is empty, has fewer than two numbers, or is not strictly increasing.
        """
        if record_writer:
            if not isinstance(record_writer, RecordWriter):
                raise TypeError(f"record_writer must be RecordWriter but got {type(record_writer)}")

        StatsPool.__init__(self, name, description)
        self.update_lock = threading.Lock()
        self.unit = unit
        self.marks = marks
        self.record_writer = record_writer  # used for writing raw records
        self.cat_bins = {}  # category name => list of bins

        if not marks:
            raise ValueError("marks not specified")

        if len(marks) < 2:
            raise ValueError(f"marks must have at least two numbers but got {len(marks)}")

        for i in range(1, len(marks)):
            if marks[i] <= marks[i - 1]:
                raise ValueError(f"marks must contain increasing values, but got {marks}")

        # A range is defined: left <= N < right [...)
        # [..., M1) [M1, M2) [M2, M3) [M3, ...)
        m = sys.float_info.max
        self.ranges = [(-m, marks[0])]
        self.range_names = [f"<{marks[0]}"]
        for i in range(len(marks) - 1):
            self.ranges.append((marks[i], marks[i + 1]))
            self.range_names.append(f"{marks[i]}-{marks[i+1]}")
        self.ranges.append((marks[-1], m))
        self.range_names.append(f">={marks[-1]}")

    def record_value(self, category: str, value: float):
        """Record a value into the bin of ``category`` whose range contains it.

        A value that falls in no range (NaN, or magnitude at or beyond the largest
        finite float) is not binned. The record writer, if any, is still called.
        """
        with self.update_lock:
            bins = self.cat_bins.get(category)
            if bins is None:
                bins = [None for _ in range(len(self.ranges))]
                self.cat_bins[category] = bins

            for i, (left, right) in enumerate(self.ranges):
                if left <= value < right:
                    b = bins[i]
                    if not b:
                        b = _Bin()
                        bins[i] = b
                    b.record_value(value)
                    # ranges are disjoint, so no other range can match
                    break

            if self.record_writer:
                self.record_writer.write(pool_name=self.name, category=category, value=value, report_time=time.time())

    def get_table(self, mode=StatsMode.COUNT):
        """Render the pool as a table.

        Args:
            mode: StatsMode value that selects what each cell shows.

        Returns:
            (headers, rows): headers are "category", the names of the ranges that have
            data in at least one category, and "overall"; rows are sorted by category name.
        """
        with self.update_lock:
            range_count = len(self.ranges)

            # determine bins that have values in any category
            has_values = [False] * range_count
            for bins in self.cat_bins.values():
                for i in range(range_count):
                    if bins[i]:
                        has_values[i] = True

            headers = ["category"]
            headers.extend(self.range_names[i] for i in range(range_count) if has_values[i])
            headers.append("overall")

            rows = []
            for cat_name in sorted(self.cat_bins.keys()):
                bins = self.cat_bins[cat_name]

                # aggregate the category's bins into one "overall" bin
                total_count = 0
                total_value = 0.0
                overall_min = None
                overall_max = None
                for b in bins:
                    if not b:
                        continue
                    total_count += b.count
                    total_value += b.total
                    if b.max is not None and (overall_max is None or overall_max < b.max):
                        overall_max = b.max
                    if b.min is not None and (overall_min is None or overall_min > b.min):
                        overall_min = b.min

                r = [cat_name]
                for i in range(len(bins)):
                    if not has_values[i]:
                        continue
                    b = bins[i]
                    r.append(b.get_content(mode, total_count) if b else "")

                overall_bin = _Bin(
                    count=total_count, total_value=total_value, max_value=overall_max, min_value=overall_min
                )
                r.append(overall_bin.get_content(mode, total_count))
                rows.append(r)

            return headers, rows

    def to_dict(self):
        """Export the pool as a dict; an unused bin is exported as ""."""
        with self.update_lock:
            cat_bins = {
                cat: [b.to_dict() if b else "" for b in bins] for cat, bins in self.cat_bins.items()
            }
            return {
                _KEY_NAME: self.name,
                _KEY_DESC: self.description,
                _KEY_MARKS: list(self.marks),
                _KEY_UNIT: self.unit,
                _KEY_CAT_DATA: cat_bins,
            }

    @staticmethod
    def from_dict(d: dict):
        """Rebuild a pool from the dict produced by to_dict().

        Raises:
            ValueError: if the marks are invalid, a category does not have exactly one
                entry per range, or a non-empty bin entry is not a dict.
        """
        p = HistPool(
            name=d.get(_KEY_NAME, ""),
            description=d.get(_KEY_DESC, ""),
            unit=d.get(_KEY_UNIT, ""),
            marks=d.get(_KEY_MARKS),
        )
        cat_bins = d.get(_KEY_CAT_DATA)
        if not cat_bins:
            return p

        for cat, bins in cat_bins.items():
            # a mismatched bin list would otherwise fail later with IndexError in get_table()
            if len(bins) != len(p.ranges):
                raise ValueError(f"category '{cat}' must have {len(p.ranges)} bins but got {len(bins)}")

            in_bins = []
            for b in bins:
                if not b:
                    in_bins.append(None)
                    continue
                # explicit check instead of assert, which is stripped under "python -O"
                if not isinstance(b, dict):
                    raise ValueError(f"bin data of category '{cat}' must be dict but got {type(b)}")
                in_bins.append(_Bin.from_dict(b))
            p.cat_bins[cat] = in_bins
        return p
class CounterPool(StatsPool):
    """Counter pool: per category, a set of named integer counters."""

    def __init__(self, name: str, description: str, counter_names: List[str], dynamic_counter_name=True):
        """Create the pool.

        Args:
            name: pool name.
            description: pool description.
            counter_names: initially known counter names; may be empty or None
                when dynamic_counter_name is True.
            dynamic_counter_name: whether increment() may introduce new counter names.

        Raises:
            ValueError: if counter_names is empty and dynamic_counter_name is False.
        """
        if not counter_names and not dynamic_counter_name:
            raise ValueError("counter_names cannot be empty")
        StatsPool.__init__(self, name, description)
        # Own copy: increment() appends dynamic names and must not mutate the caller's
        # sequence; None (allowed when dynamic) becomes an empty list.
        self.counter_names = list(counter_names) if counter_names else []
        self.cat_counters = {}  # dict of cat_name => counter dict (counter_name => int)
        self.dynamic_counter_name = dynamic_counter_name
        self.update_lock = threading.Lock()

    def increment(self, category: str, counter_name: str, amount=1):
        """Add ``amount`` to the named counter of ``category``.

        Raises:
            ValueError: if counter_name is unknown and dynamic counter names are disabled.
        """
        with self.update_lock:
            if counter_name not in self.counter_names:
                if not self.dynamic_counter_name:
                    raise ValueError(f"'{counter_name}' is not defined in pool '{self.name}'")
                self.counter_names.append(counter_name)

            counters = self.cat_counters.get(category)
            if not counters:
                counters = {}
                self.cat_counters[category] = counters
            counters[counter_name] = counters.get(counter_name, 0) + amount

    def get_table(self, mode=""):
        """Render the pool as a table.

        Args:
            mode: not used by counter pools.

        Returns:
            (headers, rows): headers are "category" followed by the counters that have a
            positive value in at least one category; rows are sorted by category name.
        """
        with self.update_lock:
            eff_counter_names = [
                cn
                for cn in self.counter_names
                if any(counters.get(cn, 0) > 0 for counters in self.cat_counters.values())
            ]

            headers = ["category"]
            headers.extend(eff_counter_names)

            rows = []
            for cat_name in sorted(self.cat_counters.keys()):
                counters = self.cat_counters[cat_name]
                rows.append([cat_name] + [str(counters.get(cn, 0)) for cn in eff_counter_names])
            return headers, rows

    def to_dict(self):
        """Export the pool as a dict; the counters are copied so the result is a snapshot."""
        with self.update_lock:
            return {
                _KEY_NAME: self.name,
                _KEY_DESC: self.description,
                _KEY_COUNTER_NAMES: list(self.counter_names),
                # copies: handing out the live dicts would let them change under the caller
                _KEY_CAT_DATA: {cat: dict(counters) for cat, counters in self.cat_counters.items()},
            }

    @staticmethod
    def from_dict(d: dict):
        """Rebuild a pool from the dict produced by to_dict(); missing category data yields an empty pool."""
        p = CounterPool(
            name=d.get(_KEY_NAME, ""), description=d.get(_KEY_DESC, ""), counter_names=d.get(_KEY_COUNTER_NAMES)
        )
        cat_data = d.get(_KEY_CAT_DATA) or {}
        p.cat_counters = {cat: dict(counters or {}) for cat, counters in cat_data.items()}
        return p
def new_time_pool(name: str, description="", marks=None, record_writer=None) -> HistPool:
    """Create a HistPool whose unit is "second".

    Args:
        name: pool name.
        description: pool description.
        marks: range boundaries; a built-in set from 0.0001 to 2.0 is used when empty or None.
        record_writer: optional RecordWriter passed on to the pool.
    """
    default_marks = (0.0001, 0.0005, 0.001, 0.002, 0.004, 0.008, 0.01, 0.02, 0.04, 0.08, 0.1, 0.2, 0.4, 0.8, 1.0, 2.0)
    return HistPool(
        name=name,
        description=description,
        marks=marks or default_marks,
        unit="second",
        record_writer=record_writer,
    )
def new_message_size_pool(name: str, description="", marks=None, record_writer=None) -> HistPool:
    """Create a HistPool whose unit is "MB".

    Args:
        name: pool name.
        description: pool description.
        marks: range boundaries; a built-in set from 0.01 to 1000 is used when empty or None.
        record_writer: optional RecordWriter passed on to the pool.
    """
    default_marks = (0.01, 0.1, 1, 10, 50, 100, 200, 500, 800, 1000)
    return HistPool(
        name=name,
        description=description,
        marks=marks or default_marks,
        unit="MB",
        record_writer=record_writer,
    )
def parse_hist_mode(mode: str) -> str:
    """Resolve a user-supplied mode string to a StatsMode value.

    An empty mode means COUNT. A mode starting with "p", "c" or "a" resolves to
    PERCENT, COUNT or AVERAGE respectively. Any other mode must equal one of
    VALID_HIST_MODES exactly.

    Returns:
        the resolved mode, or "" if the mode is not recognized.
    """
    if not mode:
        return StatsMode.COUNT

    # checked in the same order as the prefixes are listed
    prefix_to_mode = (
        ("p", StatsMode.PERCENT),
        ("c", StatsMode.COUNT),
        ("a", StatsMode.AVERAGE),
    )
    for prefix, resolved in prefix_to_mode:
        if mode.startswith(prefix):
            return resolved

    return mode if mode in VALID_HIST_MODES else ""
class StatsPoolManager:
    """Process-wide registry of stats pools; all state is kept at class level."""

    _CONFIG_KEY_SAVE_POOLS = "save_pools"

    lock = threading.Lock()  # guards changes to and traversals of the pools dict
    pools = {}  # name => pool
    pool_config = {}
    record_writer = None

    @classmethod
    def _check_name(cls, name, scope):
        """Return the registry key to use for a new pool.

        The lower-cased name is used if it is free; otherwise "name@scope" is tried
        when a scope is given.

        Raises:
            ValueError: if no free key could be found.
        """
        name = name.lower()
        if name not in cls.pools:
            return name

        if scope:
            name = f"{name}@{scope}"
            if name not in cls.pools:
                return name

        raise ValueError(f"pool '{name}' is already defined")

    @classmethod
    def _add_pool(cls, name, scope, create_pool):
        """Pick a free key and register the pool built by ``create_pool(key)``.

        Both steps run under the class lock so that two threads cannot claim the same key.
        """
        with cls.lock:
            name = cls._check_name(name, scope)
            p = create_pool(name)
            cls.pools[name] = p
            return p

    @staticmethod
    def _pool_type(pool):
        """Return the type tag ("hist" or "counter") of a pool, or None if it is neither."""
        if isinstance(pool, HistPool):
            return "hist"
        if isinstance(pool, CounterPool):
            return "counter"
        return None

    @classmethod
    def set_pool_config(cls, config: dict):
        """Merge ``config`` into the pool configuration; keys are stored lower-cased.

        Raises:
            ValueError: if config is not a dict.
        """
        if not isinstance(config, dict):
            raise ValueError(f"config data must be dict but got {type(config)}")
        for k, v in config.items():
            cls.pool_config[k.lower()] = v

    @classmethod
    def set_record_writer(cls, record_writer: RecordWriter):
        """Set the writer handed to histogram pools that are configured to save raw records.

        Raises:
            TypeError: if record_writer is not a RecordWriter.
        """
        if not isinstance(record_writer, RecordWriter):
            raise TypeError(f"record_writer must be RecordWriter but got {type(record_writer)}")
        cls.record_writer = record_writer

    @classmethod
    def _keep_hist_records(cls, name):
        """Tell whether the "save_pools" config lists this pool name or "*"."""
        name = name.lower()
        save_pools_list = cls.pool_config.get(cls._CONFIG_KEY_SAVE_POOLS, None)
        if not save_pools_list:
            return False

        return ("*" in save_pools_list) or (name in save_pools_list)

    @classmethod
    def add_time_hist_pool(cls, name: str, description: str, marks=None, scope=None):
        """Create and register a time histogram pool.

        Raises:
            ValueError: if the name (and its scoped variant) is already registered.
        """
        # check pool config (against the unscoped name)
        record_writer = cls.record_writer if cls._keep_hist_records(name) else None
        return cls._add_pool(
            name, scope, lambda key: new_time_pool(key, description, marks, record_writer=record_writer)
        )

    @classmethod
    def add_msg_size_pool(cls, name: str, description: str, marks=None, scope=None):
        """Create and register a message-size histogram pool.

        Raises:
            ValueError: if the name (and its scoped variant) is already registered.
        """
        record_writer = cls.record_writer if cls._keep_hist_records(name) else None
        return cls._add_pool(
            name, scope, lambda key: new_message_size_pool(key, description, marks, record_writer=record_writer)
        )

    @classmethod
    def add_counter_pool(cls, name: str, description: str, counter_names: list, scope=None):
        """Create and register a counter pool.

        Raises:
            ValueError: if the name (and its scoped variant) is already registered.
        """
        return cls._add_pool(name, scope, lambda key: CounterPool(key, description, counter_names))

    @classmethod
    def get_pool(cls, name: str):
        """Return the pool registered under the lower-cased name, or None."""
        return cls.pools.get(name.lower())

    @classmethod
    def delete_pool(cls, name: str):
        """Remove and return the pool registered under the lower-cased name, or None."""
        with cls.lock:
            return cls.pools.pop(name.lower(), None)

    @classmethod
    def get_table(cls):
        """List the registered pools.

        Returns:
            (headers, rows): one row of [name, type, description] per pool, sorted by
            registry key; the type is "hist", "counter" or "?".
        """
        with cls.lock:
            headers = ["pool", "type", "description"]
            rows = []
            for k in sorted(cls.pools.keys()):
                v = cls.pools[k]
                rows.append([v.name, cls._pool_type(v) or "?", v.description])
            return headers, rows

    @classmethod
    def to_dict(cls):
        """Export all pools as {key: {"type": ..., "pool": ...}}.

        Raises:
            ValueError: if a registered pool is neither a HistPool nor a CounterPool.
        """
        with cls.lock:
            result = {}
            for k in sorted(cls.pools.keys()):
                v = cls.pools[k]
                t = cls._pool_type(v)
                if t is None:
                    raise ValueError(f"unknown type of pool '{k}'")
                result[k] = {"type": t, "pool": v.to_dict()}
            return result

    @classmethod
    def from_dict(cls, d: dict):
        """Replace the registry with the pools described by ``d`` (as produced by to_dict()).

        The new registry is built completely before it is installed, so the current
        pools are left untouched if ``d`` turns out to be invalid.

        Raises:
            ValueError: if an entry lacks its type or pool data, or has an invalid type.
        """
        pools = {}
        for k, v in d.items():
            t = v.get("type")
            if not t:
                raise ValueError("missing pool type")

            pd = v.get("pool")
            if not pd:
                raise ValueError("missing pool data")

            if t == "hist":
                p = HistPool.from_dict(pd)
            elif t == "counter":
                p = CounterPool.from_dict(pd)
            else:
                raise ValueError(f"invalid pool type {t}")
            pools[k] = p

        with cls.lock:
            cls.pools = pools

    @classmethod
    def dump_summary(cls, file_name: str):
        """Write all pools to ``file_name`` as indented JSON."""
        # serialize before opening so a failure does not leave an empty or truncated file
        json_string = json.dumps(cls.to_dict(), indent=4)
        with open(file_name, "w") as f:
            f.write(json_string)

    @classmethod
    def close(cls):
        """Close the record writer, if one has been set."""
        if cls.record_writer:
            cls.record_writer.close()
class CsvRecordHandler(RecordWriter):
    """RecordWriter that appends each raw record as a CSV row: pool_name, category, report_time, value."""

    def __init__(self, file_name):
        # newline="" is what the csv module requires of files it writes to;
        # without it, platforms that translate "\n" produce an extra "\r" per row
        self.file = open(file_name, "w", newline="")
        self.writer = csv.writer(self.file)
        self.lock = threading.Lock()

    def write(self, pool_name: str, category: str, value: float, report_time: float):
        """Write one record.

        Raises:
            ValueError: if pool_name or category contains non-ASCII characters.
        """
        if not pool_name.isascii():
            raise ValueError(f"pool_name {pool_name} contains non-ascii chars")

        if not category.isascii():
            raise ValueError(f"category {category} contains non-ascii chars")

        row = [pool_name, category, report_time, value]
        with self.lock:
            self.writer.writerow(row)

    def close(self):
        """Close the underlying file."""
        # take the write lock so the file is not closed while a row is being written
        with self.lock:
            self.file.close()

    @staticmethod
    def read_records(csv_file_name: str):
        """Load a record file written by this handler.

        Returns:
            dict of pool name => dict of category => list of (report_time, value) tuples.
        """
        pools = {}
        reader = CsvRecordReader(csv_file_name)
        try:
            for rec in reader:
                cats = pools.setdefault(rec.pool_name, {})
                cats.setdefault(rec.category, []).append((rec.report_time, rec.value))
        finally:
            # release the file handle even when a malformed row aborts the loop
            reader.file.close()
        return pools
class StatsRecord:
    """One raw record: the value reported to a category of a pool at a given time."""

    def __init__(self, pool_name, category, report_time, value):
        self.pool_name, self.category = pool_name, category
        self.report_time, self.value = report_time, value
class CsvRecordReader:
    """Iterator over the StatsRecord objects stored in a CSV record file.

    The file is closed automatically once the last row has been read. To release
    it earlier, call close() or use the reader as a context manager.
    """

    def __init__(self, csv_file_name: str):
        self.csv_file_name = csv_file_name
        # newline="" is what the csv module requires of files it reads from
        self.file = open(csv_file_name, newline="")
        self.reader = csv.reader(self.file)

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next record.

        Raises:
            StopIteration: when the file is exhausted or has been closed.
            ValueError: if a row does not have exactly 4 columns or holds non-numeric data.
        """
        # keep raising StopIteration after exhaustion instead of failing on the closed file
        if self.file.closed:
            raise StopIteration

        try:
            row = next(self.reader)
        except StopIteration:
            self.close()
            raise

        if len(row) != 4:
            raise ValueError(f"'{self.csv_file_name}' is not a valid stats pool record file: bad row length {len(row)}")

        pool_name, cat_name = row[0], row[1]
        report_time = float(row[2])
        value = float(row[3])
        return StatsRecord(pool_name, cat_name, report_time, value)

    def close(self):
        """Close the underlying file; calling it more than once is harmless."""
        self.file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()