# Source code for nvflare.app_common.widgets.intime_model_selector

# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from nvflare.apis.dxo import DataKind, MetaKey, from_shareable
from nvflare.apis.event_type import EventType
from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.app_event_type import AppEventType
from nvflare.security.logging import secure_format_exception
from nvflare.widgets.widget import Widget


class IntimeModelSelector(Widget):
    """Widget that tracks client validation metrics to detect a new best global model.

    For each round, the validation metric attached to every accepted client
    contribution is accumulated into a weighted sum. Right before aggregation,
    the weighted average is compared with the best value recorded so far; when
    it is strictly greater, ``AppEventType.GLOBAL_BEST_MODEL_AVAILABLE`` is fired.
    """

    def __init__(
        self,
        weigh_by_local_iter=False,
        aggregation_weights=None,
        validation_metric_name=MetaKey.INITIAL_METRICS,
        key_metric: str = "val_accuracy",
        negate_key_metric: bool = False,
    ):
        """Handler to determine if the model is globally best.

        Args:
            weigh_by_local_iter (bool, optional): whether the metrics should be weighted by
                trainer's iteration number.
            aggregation_weights (dict, optional): a mapping of client name to float for
                aggregation. Defaults to None.
            validation_metric_name (str, optional): key used to save initial validation metric
                in the DXO meta properties (defaults to MetaKey.INITIAL_METRICS).
            key_metric: if metrics are a `dict`, `key_metric` can select the metric used for
                global model selection. Defaults to "val_accuracy".
            negate_key_metric: Whether to invert the key metric. Should be used if key metric
                is a loss. Defaults to `False`.
        """
        super().__init__()

        # Start at -inf so the first accumulated average always counts as an improvement.
        self.val_metric = self.best_val_metric = -np.inf
        self.weigh_by_local_iter = weigh_by_local_iter
        self.validation_metric_name = validation_metric_name
        self.aggregation_weights = aggregation_weights or {}
        self.key_metric = key_metric
        self.negate_key_metric = negate_key_metric

        self.logger.info(f"model selection weights control: {aggregation_weights}")
        self._reset_stats()

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        """Dispatch the events this widget reacts to; all other events are ignored.

        Args:
            event_type: the type of the fired event.
            fl_ctx: the FL context that accompanies the event.
        """
        if event_type == EventType.START_RUN:
            self._startup()
        elif event_type == AppEventType.ROUND_STARTED:
            self._reset_stats()
        elif event_type == AppEventType.BEFORE_CONTRIBUTION_ACCEPT:
            self._before_accept(fl_ctx)
        elif event_type == AppEventType.BEFORE_AGGREGATION:
            self._before_aggregate(fl_ctx)

    def _startup(self):
        """Clear the per-round accumulators at the start of a run."""
        self._reset_stats()

    def _reset_stats(self):
        """Zero the weighted metric sum and the sum of weights."""
        self.validation_metric_weighted_sum = 0
        self.validation_metric_sum_of_weights = 0

    def _before_accept(self, fl_ctx: FLContext):
        """Accumulate the validation metric carried by the peer's contribution.

        Args:
            fl_ctx: FL context whose peer context holds the contributed shareable.

        Returns:
            True if the metric was added to the running sums, False if the
            contribution was skipped (invalid DXO, unsupported data kind, no data,
            round 0, round mismatch, or missing metric).
        """
        peer_ctx = fl_ctx.get_peer_context()
        if peer_ctx is None:
            # Without a peer context there is no shareable and no client identity to read.
            self.log_warning(fl_ctx, "no peer context available, cannot read contribution")
            return False

        shareable: Shareable = peer_ctx.get_prop(FLContextKey.SHAREABLE)
        try:
            dxo = from_shareable(shareable)
        except Exception as e:
            self.log_exception(
                fl_ctx, f"shareable data is not a valid DXO. Received Exception: {secure_format_exception(e)}"
            )
            return False

        if dxo.data_kind not in (DataKind.WEIGHT_DIFF, DataKind.WEIGHTS, DataKind.COLLECTION):
            self.log_debug(fl_ctx, "cannot handle {}".format(dxo.data_kind))
            return False

        if dxo.data is None:
            self.log_debug(fl_ctx, "no data to filter")
            return False

        contribution_round = shareable.get_cookie(AppConstants.CONTRIBUTION_ROUND)
        client_name = peer_ctx.get_identity_name(default="?")
        current_round = fl_ctx.get_prop(AppConstants.CURRENT_ROUND)

        if current_round == 0:
            self.log_debug(fl_ctx, "skipping round 0")
            return False  # There is no aggregated model at round 0

        if contribution_round != current_round:
            self.log_warning(
                fl_ctx,
                f"discarding shareable from {client_name} for round: {contribution_round}. "
                f"Current round is: {current_round}",
            )
            return False

        validation_metric = dxo.get_meta_prop(self.validation_metric_name)
        if validation_metric is None:
            self.log_warning(fl_ctx, f"validation metric not existing in {client_name}")
            return False

        # select key metric if dictionary of metrics is provided
        if isinstance(validation_metric, dict):
            if self.key_metric not in validation_metric:
                self.log_warning(
                    fl_ctx,
                    f"validation metric `{self.key_metric}` not in metrics from {client_name}: "
                    f"{list(validation_metric.keys())}",
                )
                return False
            validation_metric = validation_metric[self.key_metric]

        if self.negate_key_metric:
            # Selection below uses "greater is better"; negating lets a loss-like metric fit that test.
            validation_metric = -1.0 * validation_metric

        self.log_info(fl_ctx, f"validation metric {validation_metric} from client {client_name}")

        n_iter = 1.0
        if self.weigh_by_local_iter:
            n_iter = dxo.get_meta_prop(MetaKey.NUM_STEPS_CURRENT_ROUND, 1.0)
            if n_iter is None:
                # The default only covers a missing key; a stored None would break the product below.
                n_iter = 1.0

        aggregation_weights = self.aggregation_weights.get(client_name, 1.0)
        self.log_debug(fl_ctx, f"aggregation weight: {aggregation_weights}")

        weight = n_iter * aggregation_weights
        self.validation_metric_weighted_sum += validation_metric * weight
        self.validation_metric_sum_of_weights += weight
        return True

    def _before_aggregate(self, fl_ctx):
        """Compute the round's weighted metric and announce a new best model if it improved.

        Args:
            fl_ctx: FL context used for logging, for publishing the validation
                result, and for firing the event.

        Returns:
            False if nothing was accumulated this round, True otherwise.
        """
        if self.validation_metric_sum_of_weights == 0:
            self.log_debug(fl_ctx, "nothing accumulated")
            return False

        self.val_metric = self.validation_metric_weighted_sum / self.validation_metric_sum_of_weights
        # Log through the fl_ctx-aware helper, consistent with the rest of this class.
        self.log_debug(fl_ctx, f"weighted validation metric {self.val_metric}")

        if self.val_metric > self.best_val_metric:
            self.best_val_metric = self.val_metric
            current_round = fl_ctx.get_prop(AppConstants.CURRENT_ROUND)
            self.log_info(fl_ctx, f"new best validation metric at round {current_round}: {self.best_val_metric}")

            # Fire event to notify that the current global model is a new best
            fl_ctx.set_prop(AppConstants.VALIDATION_RESULT, self.best_val_metric, private=True, sticky=False)
            self.fire_event(AppEventType.GLOBAL_BEST_MODEL_AVAILABLE, fl_ctx)

        self._reset_stats()
        return True
class IntimeModelSelectionHandler(IntimeModelSelector):
    """Alias kept under the old class name; it logs a rename warning on construction."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments to `IntimeModelSelector`, then warn about the rename."""
        rename_notice = "'IntimeModelSelectionHandler' was renamed to 'IntimeModelSelector'"
        super().__init__(*args, **kwargs)
        self.logger.warning(rename_notice)