# Source code for nvflare.widgets.info_collector

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime

from nvflare.apis.analytix import AnalyticsData
from nvflare.apis.dxo import from_shareable
from nvflare.apis.event_type import EventType
from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable

from .widget import Widget


class GroupInfoCollector(object):
    """Collects information as a dict of dicts, keyed by group name."""

    def __init__(self):
        """Records the information using a dict of dict.

        Note:
            Key is group name and value is the information dictionary.
        """
        self.info = {}

    def set_info(self, group_name: str, info: dict):
        """Replaces the information dictionary of a group.

        The given dict is stored as is (not copied), so any previous content
        of the group is discarded.

        Args:
            group_name (str): The name of the group.
            info (dict): The dict to be recorded for the group.
        """
        self.info[group_name] = info

    def add_info(self, group_name: str, info: dict):
        """Merges information into a group, creating the group if needed.

        Args:
            group_name (str): The name of the group.
            info (dict): The key/value pairs to merge into the group. Keys that
                already exist in the group are overwritten.
        """
        if group_name not in self.info:
            # Store a shallow copy: later add_info() calls update the group's
            # dict in place, and that must not mutate the caller's dict.
            self.info[group_name] = dict(info)
        else:
            self.info[group_name].update(info)
class InfoCollector(Widget):
    """Widget that collects categorized information and error reports for a run.

    Note:
        self.categories structure:
            category (dict)
                group (dict)
                    key/value (dict)
    """

    CATEGORY_STATS = "stats"
    CATEGORY_ERROR = "error"

    EVENT_TYPE_GET_STATS = "info_collector.get_stats"
    CTX_KEY_STATS_COLLECTOR = "info_collector.stats_collector"

    # Maps each handled log event type to the key its message is recorded under.
    _LOG_EVENT_KEYS = {
        EventType.CRITICAL_LOG_AVAILABLE: "critical",
        EventType.ERROR_LOG_AVAILABLE: "error",
        EventType.WARNING_LOG_AVAILABLE: "warning",
        EventType.EXCEPTION_LOG_AVAILABLE: "exception",
    }

    def __init__(self):
        """A widget for information collection."""
        super().__init__()
        self.categories = {}
        self.engine = None

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        """Reacts to run lifecycle events and log-available events.

        START_RUN clears all collected information and remembers the engine;
        END_RUN forgets the engine; critical/error/warning/exception log events
        are recorded in the error category.

        Args:
            event_type (str): The type of the event being fired.
            fl_ctx (FLContext): The context the event was fired with.
        """
        if event_type == EventType.START_RUN:
            self.reset_all()
            self.engine = fl_ctx.get_engine()
        elif event_type == EventType.END_RUN:
            self.engine = None
        elif event_type in self._LOG_EVENT_KEYS:
            self._record_log_event(event_type, fl_ctx)

    def _record_log_event(self, event_type: str, fl_ctx: FLContext):
        """Extracts the message of a log event and records it as an error."""
        # The group is the event origin when one is set, otherwise "general".
        origin = fl_ctx.get_prop(FLContextKey.EVENT_ORIGIN, None)
        group_name = str(origin) if origin else "general"

        data = fl_ctx.get_prop(FLContextKey.EVENT_DATA, None)
        if not isinstance(data, Shareable):
            # not a valid error report
            self.log_error(
                fl_ctx=fl_ctx,
                msg="wrong event data type for event {}: expect Shareable but got {}".format(
                    event_type, type(data)
                ),
                fire_event=False,
            )
            return

        try:
            dxo = from_shareable(data)
        except Exception:
            # Deliberately not a bare "except:", so KeyboardInterrupt and
            # SystemExit are not swallowed here.
            self.log_exception(
                fl_ctx=fl_ctx, msg="invalid event data type for event {}".format(event_type), fire_event=False
            )
            return

        analytic_data = AnalyticsData.from_dxo(dxo)
        if not analytic_data:
            return

        self.add_error(group_name=group_name, key=self._LOG_EVENT_KEYS[event_type], err=analytic_data.value)

    def get_run_stats(self) -> dict:
        """Gets status for this current run.

        Returns:
            A dictionary that contains the status for this run. It is empty
            when no engine is currently set.
        """
        # NOTE: it's important to assign self.engine to a new var!
        # This is because another thread may fire the END_RUN event, which will cause
        # self.engine to be set to None, just after checking it being None and before using it!
        # For the same reason only the local "engine" is used below, never self.engine.
        engine = self.engine
        if not engine:
            return {}

        # NOTE: we need a new context here to make sure all sticky props are copied!
        # We create a new GroupInfoCollector to hold status info.
        # Do not use the InfoCollector itself for thread safety - multiple calls to
        # this method (from parallel admin commands) are possible at the same time!
        with engine.new_context() as fl_ctx:
            coll = GroupInfoCollector()
            fl_ctx.set_prop(key=self.CTX_KEY_STATS_COLLECTOR, value=coll, sticky=False, private=True)

            engine.fire_event(event_type=self.EVENT_TYPE_GET_STATS, fl_ctx=fl_ctx)

            # Get the collector back from the fl_ctx, it could have been updated by other component.
            coll = fl_ctx.get_prop(InfoCollector.CTX_KEY_STATS_COLLECTOR)
            return coll.info

    def _get_or_create_category(self, category_name: str) -> dict:
        """Returns the dict of a category, installing a new one if missing or empty."""
        category = self.categories.get(category_name, None)
        if not category:
            category = dict()
            self.categories[category_name] = category
        return category

    def add_info(self, category_name: str, group_name: str, key: str, value):
        """Adds information to the specified category / group.

        Args:
            category_name (str): The top level distinction is called category.
            group_name (str): One level down category is called group
            key (str): The key to be recorded inside the dict.
            value (str): The value to be recorded inside the dict.
        """
        category = self._get_or_create_category(category_name)

        group = category.get(group_name, None)
        if not group:
            # Also replaces a group that is present but empty or None.
            group = dict()
            category[group_name] = group

        group[key] = value

    def set_info(self, category_name: str, group_name: str, info: dict):
        """Sets information to the specified category / group.

        Args:
            category_name (str): The top level distinction is called category.
            group_name (str): One level down category is called group
            info (dict): The dict to be recorded.

        Note:
            This sets the entire dictionary vs add_info only add a key-value pair.
        """
        category = self._get_or_create_category(category_name)
        category[group_name] = info

    def get_category(self, category_name: str):
        """Gets the category dict.

        Args:
            category_name (str): The name of the category.

        Returns:
            A dictionary of specified category, or None if it does not exist.
        """
        return self.categories.get(category_name, None)

    def get_group(self, category_name: str, group_name: str):
        """Gets the group dict.

        Args:
            category_name (str): The name of the category.
            group_name (str): The name of the group_name.

        Returns:
            A dictionary of specified category/group, or None if the category
            is missing or empty, or the group does not exist.
        """
        cat = self.categories.get(category_name, None)
        if not cat:
            return None
        return cat.get(group_name, None)

    def reset_all(self):
        """Resets all information collected."""
        self.categories = {}

    def reset_category(self, category_name: str):
        """Resets the specified category information collected.

        Args:
            category_name (str): The name of the category.
        """
        self.categories[category_name] = {}

    def reset_group(self, category_name: str, group_name: str):
        """Resets the specified category/group information collected.

        Does nothing when the category is missing or empty.

        Args:
            category_name (str): The name of the category.
            group_name (str): The name of the group_name.
        """
        cat = self.categories.get(category_name, None)
        if not cat:
            return
        # Index the dict itself; subscripting the bound method cat.get raises TypeError.
        cat[group_name] = {}

    def add_error(self, group_name: str, key: str, err: str):
        """Adds error information to error category.

        The recorded value is the error text prefixed with the current local
        time formatted as "%Y-%m-%d %H:%M:%S".

        Args:
            group_name (str): One level down category is called group
            key (str): The key to be recorded inside the dict.
            err (str): The error value to be put in.
        """
        now = datetime.datetime.now()
        value = "{}: {}".format(now.strftime("%Y-%m-%d %H:%M:%S"), err)
        self.add_info(category_name=self.CATEGORY_ERROR, group_name=group_name, key=key, value=value)

    def get_errors(self):
        """Gets the error category information."""
        return self.get_category(self.CATEGORY_ERROR)

    def reset_errors(self):
        """Resets the error category information."""
        self.reset_category(self.CATEGORY_ERROR)