Source code for nvflare.app_common.widgets.event_recorder

# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from typing import Dict, List, Optional

from nvflare.apis.client_engine_spec import ClientEngineSpec
from nvflare.apis.event_type import EventType
from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.server_engine_spec import ServerEngineSpec
from nvflare.apis.shareable import Shareable
from nvflare.widgets.widget import Widget


class _CtxPropReq:
    """Expectation placed on a single property of an FLContext.

    Args:
        dtype: type the prop value must be an instance of; None disables the type check.
        is_private: True if the prop must be private, False if it must be public,
            None to leave the visibility unchecked.
        is_sticky: True if the prop must be sticky, False if it must be non-sticky,
            None to leave the stickiness unchecked.
        allow_none: whether a None value is acceptable for this prop.
    """

    def __init__(self, dtype, is_private, is_sticky, allow_none: bool = False):
        # Expectations on the stored value itself.
        self.dtype = dtype
        self.allow_none = allow_none

        # Expectations on how the prop was registered in the context.
        self.is_private = is_private
        self.is_sticky = is_sticky


class _EventReq(object):
    """Requirements for FL and peer context when an event is fired.

    Arguments:
        ctx_reqs: A dictionary that describes the requirements for fl_ctx. It maps property names to _CtxPropReq
        peer_ctx_reqs: A dictionary that describes the requirements for peer_ctx. It maps property names to _CtxPropReq
        ctx_block_list: Names of props that must NOT be present in fl_ctx. None means no prop is blocked.
        peer_ctx_block_list: Names of props that must NOT be present in peer_ctx. None means no prop is blocked.
    """

    def __init__(
        self,
        ctx_reqs: Dict[str, _CtxPropReq],
        peer_ctx_reqs: Dict[str, _CtxPropReq],
        ctx_block_list: Optional[List[str]] = None,
        peer_ctx_block_list: Optional[List[str]] = None,
    ):
        self.ctx_reqs = ctx_reqs  # prop name => _CtxPropReq
        self.peer_ctx_reqs = peer_ctx_reqs

        # None is used as the default instead of [] so that instances never share one mutable
        # default list; each instance created without a block list gets its own empty list.
        self.ctx_block_list = [] if ctx_block_list is None else ctx_block_list
        self.peer_ctx_block_list = [] if peer_ctx_block_list is None else peer_ctx_block_list


class _EventStats:
    """Per-event-type counters; every counter starts at zero."""

    def __init__(self):
        # How many times the event has been seen.
        self.call_count = 0

        # Violations found on individual props: absent, None when not allowed,
        # wrong value type, wrong private/sticky attribute, or present although blocked.
        self.prop_missing = self.prop_none_value = 0
        self.prop_dtype_mismatch = self.prop_attr_mismatch = 0
        self.prop_block_list_violation = 0

        # Peer context was required for the event but was not available.
        self.peer_ctx_missing = 0


class EventRecorder(Widget):
    """Widget that logs every event it receives and validates context props against per-event requirements.

    For each event type registered in ``event_reqs``, the props of the FLContext (and of its peer
    context, when peer requirements exist) are checked for presence, None-ness, data type and
    private/sticky attributes. Violations are logged as errors and counted per event type in
    ``event_stats``; the accumulated counters are logged when END_RUN is handled.
    """

    # Keys of the bookkeeping props this recorder stores in the contexts while handling an event.
    _KEY_CTX_TYPE = "ctx_type"
    _KEY_EVENT_TYPE = "event_type"
    _KEY_EVENT_STATS = "event_stats"
    _KEY_EVENT_REQ = "event_req"

    def __init__(self, log_file_name=None):
        """A component to record all system-wide events.

        Args:
            log_file_name (str, optional): the log filename to save recorded events. Defaults to None,
                in which case "event_recorded.txt" is used.
        """
        super().__init__()

        all_ctx_reqs = {
            "__run_num__": _CtxPropReq(dtype=str, is_private=False, is_sticky=True),
            "__identity_name__": _CtxPropReq(dtype=str, is_private=False, is_sticky=True),
        }

        run_req = _EventReq(ctx_reqs=all_ctx_reqs, peer_ctx_reqs={})
        self.event_reqs = {EventType.START_RUN: run_req, EventType.END_RUN: run_req}  # event type => _EventReq
        self.event_stats = {}  # event_type => _EventStats
        self._log_handler_added = False
        self.log_file_name = log_file_name if log_file_name else "event_recorded.txt"

    def event_tag(self, fl_ctx: FLContext):
        """Build a short tag identifying the event currently recorded in ``fl_ctx``.

        Args:
            fl_ctx: context carrying the event type (and possibly the event id) of the current event.

        Returns:
            "[type=<type>, id=<id>]" when the context has a truthy event id, otherwise "[<type>]".
            The type is "?" when no event type has been stored in the context.
        """
        event_type = fl_ctx.get_prop(self._KEY_EVENT_TYPE, "?")
        event_id = fl_ctx.get_prop(FLContextKey.EVENT_ID, None)
        if event_id:
            return "[type={}, id={}]".format(event_type, event_id)
        else:
            return "[{}]".format(event_type)

    def event_error_tag(self, fl_ctx: FLContext):
        """Build the prefix used by error messages: the event tag plus which context is being checked.

        Args:
            fl_ctx: the context under validation; its ctx type prop is "?" when not set.

        Returns:
            A string of the form "Event <tag>: in <ctx_type>,".
        """
        ctx_type = fl_ctx.get_prop(self._KEY_CTX_TYPE, "?")
        return "Event {}: in {},".format(self.event_tag(fl_ctx), ctx_type)

    def validate_prop(self, prop_name: str, req: _CtxPropReq, fl_ctx: FLContext):
        """Check one prop of ``fl_ctx`` against its requirement, logging and counting every violation.

        Args:
            prop_name: name of the prop to check.
            req: the requirement the prop must satisfy.
            fl_ctx: the context (FL or peer) to look the prop up in. It must already carry the
                event stats object under ``_KEY_EVENT_STATS``.
        """
        stats = fl_ctx.get_prop(self._KEY_EVENT_STATS, None)
        detail = fl_ctx.get_prop_detail(prop_name)
        if not isinstance(detail, dict):
            # No detail record means the prop is absent; nothing else can be checked.
            stats.prop_missing += 1
            self.logger.error("{} required prop '{}' doesn't exist".format(self.event_error_tag(fl_ctx), prop_name))
            return

        value = detail["value"]
        if value is None:
            # A None value is judged only by allow_none. Running the dtype check on it would report
            # a spurious mismatch for props that are explicitly allowed to be None, and would count
            # a disallowed None twice.
            if not req.allow_none:
                stats.prop_none_value += 1
                self.logger.error(
                    "{} prop '{}' is None, but None is not allowed".format(self.event_error_tag(fl_ctx), prop_name)
                )
        elif req.dtype is not None and not isinstance(value, req.dtype):
            stats.prop_dtype_mismatch += 1
            self.logger.error(
                "{} prop '{}' should be {}, but got {}".format(
                    self.event_error_tag(fl_ctx), prop_name, req.dtype, type(value)
                )
            )

        # For is_private / is_sticky a requirement of None means "don't care": neither branch fires.
        if req.is_private and not detail["private"]:
            stats.prop_attr_mismatch += 1
            self.logger.error(
                "{} prop '{}' should be private but is public".format(self.event_error_tag(fl_ctx), prop_name)
            )

        if req.is_private is not None and not req.is_private and detail["private"]:
            stats.prop_attr_mismatch += 1
            self.logger.error(
                "{} prop '{}' should be public but is private".format(self.event_error_tag(fl_ctx), prop_name)
            )

        if req.is_sticky and not detail["sticky"]:
            stats.prop_attr_mismatch += 1
            self.logger.error(
                "{} prop '{}' should be sticky but is non-sticky".format(self.event_error_tag(fl_ctx), prop_name)
            )

        if req.is_sticky is not None and not req.is_sticky and detail["sticky"]:
            stats.prop_attr_mismatch += 1
            self.logger.error(
                "{} prop '{}' should be non-sticky but is sticky".format(self.event_error_tag(fl_ctx), prop_name)
            )

    def check_block_list(self, block_list, fl_ctx: FLContext):
        """Log and count every prop in ``block_list`` that is present in ``fl_ctx``.

        Args:
            block_list: names of props that are not expected in the context.
            fl_ctx: the context (FL or peer) to inspect; it must carry the event stats object.
        """
        stats = fl_ctx.get_prop(self._KEY_EVENT_STATS, None)
        for prop_name in block_list:
            detail = fl_ctx.get_prop_detail(prop_name)
            if detail:
                stats.prop_block_list_violation += 1
                self.logger.error("{} prop {} is not expected".format(self.event_error_tag(fl_ctx), prop_name))

    def check_props(self, fl_ctx: FLContext):
        """Validate ``fl_ctx`` and, when peer requirements exist, its peer context.

        The requirement to apply is read from the ``_KEY_EVENT_REQ`` prop of ``fl_ctx``.

        Args:
            fl_ctx: the context of the event being handled.
        """
        event_req = fl_ctx.get_prop(self._KEY_EVENT_REQ)
        stats = fl_ctx.get_prop(self._KEY_EVENT_STATS)

        for prop_name, req in event_req.ctx_reqs.items():
            self.validate_prop(prop_name, req, fl_ctx)

        self.check_block_list(event_req.ctx_block_list, fl_ctx)

        # The peer context is only inspected when something is required from it.
        if event_req.peer_ctx_reqs:
            peer_ctx = fl_ctx.get_peer_context()
            if not peer_ctx:
                stats.peer_ctx_missing += 1
                self.logger.error("{} expected peer_ctx not present".format(self.event_error_tag(fl_ctx)))
            else:
                for prop_name, req in event_req.peer_ctx_reqs.items():
                    self.validate_prop(prop_name, req, peer_ctx)
                self.check_block_list(event_req.peer_ctx_block_list, peer_ctx)

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        """Record the event, validate the context against the event's requirements and update stats.

        Args:
            event_type: type of the fired event.
            fl_ctx: context the event was fired with.
        """
        if not self._log_handler_added:
            # The output file lives in the job's app dir, which is only known once a context is
            # available, so the file handler is attached on the first event rather than in __init__.
            workspace = fl_ctx.get_engine().get_workspace()
            app_dir = workspace.get_app_dir(fl_ctx.get_job_id())
            output_file_handler = logging.FileHandler(os.path.join(app_dir, self.log_file_name))
            formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
            output_file_handler.setFormatter(formatter)
            self.logger.addHandler(output_file_handler)
            self._log_handler_added = True

        event_stats = self.event_stats.get(event_type, None)
        if not event_stats:
            event_stats = _EventStats()
            self.event_stats[event_type] = event_stats

        # Bookkeeping props are private and non-sticky.
        fl_ctx.set_prop(key=self._KEY_EVENT_STATS, value=event_stats, private=True, sticky=False)
        fl_ctx.set_prop(key=self._KEY_EVENT_TYPE, value=event_type, private=True, sticky=False)
        fl_ctx.set_prop(key=self._KEY_CTX_TYPE, value="fl_ctx", private=True, sticky=False)
        self.log_info(fl_ctx, "Got event {}".format(self.event_tag(fl_ctx)), fire_event=False)
        event_stats.call_count += 1

        peer_ctx = fl_ctx.get_peer_context()
        if peer_ctx:
            # Mirror the bookkeeping props into the peer context so that the validation helpers
            # (which read stats and tags from the context they are given) work on it as well.
            event_id = fl_ctx.get_prop(FLContextKey.EVENT_ID)
            peer_ctx.set_prop(key=FLContextKey.EVENT_ID, value=event_id, private=True, sticky=False)
            peer_ctx.set_prop(key=self._KEY_EVENT_STATS, value=event_stats, private=True, sticky=False)
            peer_ctx.set_prop(key=self._KEY_EVENT_TYPE, value=event_type, private=True, sticky=False)
            peer_ctx.set_prop(key=self._KEY_CTX_TYPE, value="peer_ctx", private=True, sticky=False)
            self.log_info(
                fl_ctx, "Peer Context for event {}: {}".format(self.event_tag(fl_ctx), peer_ctx), fire_event=False
            )

        event_req = self.event_reqs.get(event_type, None)
        fl_ctx.set_prop(key=self._KEY_EVENT_REQ, value=event_req, private=True, sticky=False)
        if event_req:
            self.check_props(fl_ctx)

        if event_type == EventType.END_RUN:
            # print stats
            for e, s in self.event_stats.items():
                self.log_info(fl_ctx, "Stats of {}: {}".format(e, vars(s)), fire_event=False)
class ServerEventRecorder(EventRecorder):
    """Event recorder for the server side.

    On START_RUN it plants four test props (private/public x sticky/non-sticky) in the FLContext,
    and it registers requirements for the task-data-filter events that check which of the server
    and client test props are visible in the FL context and in the peer context.
    """

    def __init__(self, log_file_name=None):
        """Server-specific event recorder.

        Args:
            log_file_name (str, optional): the log filename to save recorded events. Defaults to None,
                in which case EventRecorder falls back to its own default file name.
        """
        super().__init__(log_file_name=log_file_name)

        task_data_filter_reqs = _EventReq(
            ctx_reqs={
                "__engine__": _CtxPropReq(dtype=ServerEngineSpec, is_private=True, is_sticky=True),
                FLContextKey.TASK_ID: _CtxPropReq(dtype=str, is_private=True, is_sticky=False),
                FLContextKey.TASK_NAME: _CtxPropReq(dtype=str, is_private=True, is_sticky=False),
                FLContextKey.TASK_DATA: _CtxPropReq(dtype=Shareable, is_private=True, is_sticky=False, allow_none=True),
                "testPrivateServerSticky": _CtxPropReq(dtype=str, is_private=True, is_sticky=True),
                "testPublicServerSticky": _CtxPropReq(dtype=str, is_private=False, is_sticky=True),
            },
            ctx_block_list=[
                "testPrivateServerNonSticky",
                "testPublicServerNonSticky",
                "testPrivateClientNonSticky",
                "testPublicClientNonSticky",
                "testPrivateClientSticky",
                "testPublicClientSticky",
            ],
            peer_ctx_reqs={
                "__run_num__": _CtxPropReq(dtype=str, is_private=None, is_sticky=None),
                "__identity_name__": _CtxPropReq(dtype=str, is_private=None, is_sticky=None),
                "testPublicClientSticky": _CtxPropReq(dtype=str, is_private=None, is_sticky=None),
            },
            peer_ctx_block_list=[
                "__engine__",
                "testPrivateClientSticky",
                "testPrivateClientNonSticky",
                "testPublicClientNonSticky",
            ],
        )

        # The same requirements apply before and after task data filtering.
        self.event_reqs.update(
            {
                EventType.BEFORE_TASK_DATA_FILTER: task_data_filter_reqs,
                EventType.AFTER_TASK_DATA_FILTER: task_data_filter_reqs,
            }
        )

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        """Plant the server test props on START_RUN, then record the event as the base class does.

        Args:
            event_type: type of the fired event.
            fl_ctx: context the event was fired with.
        """
        if event_type == EventType.START_RUN:
            # One prop for each private/public x sticky/non-sticky combination.
            fl_ctx.set_prop(
                key="testPrivateServerSticky", value="this is a server private sticky", private=True, sticky=True
            )
            fl_ctx.set_prop(
                key="testPublicServerSticky", value="this is a server public sticky", private=False, sticky=True
            )
            fl_ctx.set_prop(
                key="testPrivateServerNonSticky",
                value="this is a server private non-sticky",
                private=True,
                sticky=False,
            )
            fl_ctx.set_prop(
                key="testPublicServerNonSticky", value="this is a server public non-sticky", private=False, sticky=False
            )

        super().handle_event(event_type, fl_ctx)
class ClientEventRecorder(EventRecorder):
    """Event recorder for the client side.

    On START_RUN it plants four test props (private/public x sticky/non-sticky) in the FLContext,
    and it registers requirements for the task-data-filter events that check which of the client
    and server test props are visible in the FL context and in the peer context.
    """

    def __init__(self, log_file_name=None):
        """Client-specific event recorder.

        Args:
            log_file_name (str, optional): the log filename to save recorded events. Defaults to None,
                in which case EventRecorder falls back to its own default file name.
        """
        super().__init__(log_file_name=log_file_name)

        task_data_filter_reqs = _EventReq(
            ctx_reqs={
                "__engine__": _CtxPropReq(dtype=ClientEngineSpec, is_private=True, is_sticky=True),
                FLContextKey.TASK_ID: _CtxPropReq(dtype=str, is_private=True, is_sticky=False),
                FLContextKey.TASK_NAME: _CtxPropReq(dtype=str, is_private=True, is_sticky=False),
                FLContextKey.TASK_DATA: _CtxPropReq(dtype=Shareable, is_private=True, is_sticky=False, allow_none=True),
                "testPrivateClientSticky": _CtxPropReq(dtype=str, is_private=True, is_sticky=True),
                "testPublicClientSticky": _CtxPropReq(dtype=str, is_private=False, is_sticky=True),
            },
            ctx_block_list=[
                "testPrivateServerNonSticky",
                "testPublicServerNonSticky",
                "testPrivateClientNonSticky",
                "testPublicClientNonSticky",
                "testPrivateServerSticky",
                "testPublicServerSticky",
            ],
            peer_ctx_reqs={
                "__run_num__": _CtxPropReq(dtype=str, is_private=None, is_sticky=None),
                "__identity_name__": _CtxPropReq(dtype=str, is_private=None, is_sticky=None),
                "testPublicServerSticky": _CtxPropReq(dtype=str, is_private=None, is_sticky=None),
            },
            peer_ctx_block_list=[
                "__engine__",
                "testPrivateServerSticky",
                "testPrivateServerNonSticky",
                "testPublicServerNonSticky",
            ],
        )

        # The same requirements apply before and after task data filtering.
        self.event_reqs.update(
            {
                EventType.BEFORE_TASK_DATA_FILTER: task_data_filter_reqs,
                EventType.AFTER_TASK_DATA_FILTER: task_data_filter_reqs,
            }
        )

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        """Plant the client test props on START_RUN, then record the event as the base class does.

        Args:
            event_type: type of the fired event.
            fl_ctx: context the event was fired with.
        """
        if event_type == EventType.START_RUN:
            # One prop for each private/public x sticky/non-sticky combination.
            fl_ctx.set_prop(
                key="testPrivateClientSticky", value="this is a client private sticky", private=True, sticky=True
            )
            fl_ctx.set_prop(
                key="testPublicClientSticky", value="this is a client public sticky", private=False, sticky=True
            )
            fl_ctx.set_prop(
                key="testPrivateClientNonSticky",
                value="this is a client private non-sticky",
                private=True,
                sticky=False,
            )
            fl_ctx.set_prop(
                key="testPublicClientNonSticky", value="this is a client public non-sticky", private=False, sticky=False
            )

        super().handle_event(event_type, fl_ctx)