Source code for nvflare.app_common.np.np_validator

# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import time

import numpy as np

from nvflare.apis.dxo import DXO, DataKind, from_shareable
from nvflare.apis.executor import Executor
from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.app_common.app_constant import AppConstants
from nvflare.security.logging import secure_format_exception

from .constants import NPConstants


class NPValidator(Executor):
    """Executor that runs a dummy validation on a NumPy model.

    For the configured validation task, the incoming shareable is decoded
    into a DXO, the array stored under ``NPConstants.NUMPY_KEY`` is reduced
    to a single number, random noise is added, and the result is returned
    as a ``DataKind.METRICS`` DXO under the key ``"accuracy"``.
    """

    def __init__(
        self,
        epsilon: float = 1,
        sleep_time: float = 0,
        validate_task_name: str = AppConstants.TASK_VALIDATION,
    ):
        """Store the configuration; no heavy work is done here.

        Args:
            epsilon: Scale applied to the random noise (drawn from
                ``np.random.random()``) that is added to the dummy
                accuracy. Defaults to 1, i.e. unscaled noise.
            sleep_time: Seconds to wait before handling a task. The wait
                is interruptible through the abort signal. Defaults to 0.
            validate_task_name: Name of the task this executor handles.
                Any other task name yields ``ReturnCode.TASK_UNKNOWN``.
        """
        # Init functions of components should be very minimal. Init
        # is called when json is read. A big init will cause json loading to halt
        # for long time.
        super().__init__()
        self.logger = logging.getLogger("NPValidator")
        self._random_epsilon = epsilon
        self._sleep_time = sleep_time
        self._validate_task_name = validate_task_name

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        """Handle a framework event; this simple app ignores all events.

        Args:
            event_type: Identifier of the event being fired.
            fl_ctx: FL context associated with the event.
        """
        # Template for apps that need it:
        #   - on the start-of-run event, create all major components;
        #   - on the end-of-run event, clean up resources (closing files,
        #     joining threads, removing dirs etc.).
        # This is a simple app that doesn't need any components.
        pass

    def execute(
        self,
        task_name: str,
        shareable: Shareable,
        fl_ctx: FLContext,
        abort_signal: Signal,
    ) -> Shareable:
        """Run the dummy validation for ``task_name``.

        Args:
            task_name: Name of the task to execute.
            shareable: Task payload; expected to carry a WEIGHTS DXO whose
                data contains ``NPConstants.NUMPY_KEY``.
            fl_ctx: FL context of the current run.
            abort_signal: Signal polled during the initial wait and between
                validation steps.

        Returns:
            A shareable holding a METRICS DXO on success, otherwise a reply
            with one of ``TASK_ABORTED``, ``TASK_UNKNOWN``,
            ``BAD_TASK_DATA`` or ``EXECUTION_EXCEPTION``.
        """
        # Any long tasks should check abort_signal regularly.
        # Otherwise, abort client will not work.
        if self._abort_during_sleep(abort_signal):
            return make_reply(ReturnCode.TASK_ABORTED)

        if task_name != self._validate_task_name:
            return make_reply(ReturnCode.TASK_UNKNOWN)

        try:
            return self._validate(task_name, shareable, fl_ctx, abort_signal)
        except Exception as e:
            self.log_exception(fl_ctx, f"Exception in NPValidator execute: {secure_format_exception(e)}.")
            return make_reply(ReturnCode.EXECUTION_EXCEPTION)

    def _abort_during_sleep(self, abort_signal: Signal) -> bool:
        """Wait ``self._sleep_time`` seconds; return True if aborted meanwhile."""
        interval = 0.5
        remaining = self._sleep_time
        while remaining > 0:
            if abort_signal.triggered:
                return True
            # Clamp the last slice so the total wait never exceeds the
            # configured sleep time (e.g. sleep_time=0.2 sleeps 0.2s, not 0.5s).
            step = min(interval, remaining)
            time.sleep(step)
            remaining -= step
        return False

    def _validate(
        self,
        task_name: str,
        shareable: Shareable,
        fl_ctx: FLContext,
        abort_signal: Signal,
    ) -> Shareable:
        """Extract the model from ``shareable`` and compute the dummy metric."""
        # First we extract DXO from the shareable.
        try:
            model_dxo = from_shareable(shareable)
        except Exception as e:
            self.log_error(
                fl_ctx, f"Unable to extract model dxo from shareable. Exception: {secure_format_exception(e)}"
            )
            return make_reply(ReturnCode.BAD_TASK_DATA)

        # Get model from shareable. data_kind must be WEIGHTS.
        if not (model_dxo.data and model_dxo.data_kind == DataKind.WEIGHTS):
            self.log_error(
                fl_ctx, "Model DXO doesn't have data or is not of type DataKind.WEIGHTS. Unable to validate."
            )
            return make_reply(ReturnCode.BAD_TASK_DATA)
        model = model_dxo.data

        # Check if key exists in model
        if NPConstants.NUMPY_KEY not in model:
            self.log_error(fl_ctx, "numpy_key not in model. Unable to validate.")
            return make_reply(ReturnCode.BAD_TASK_DATA)

        # The workflow provides MODEL_OWNER information in the shareable header.
        model_name = shareable.get_header(AppConstants.MODEL_OWNER, "?")

        # Print properties.
        self.log_info(fl_ctx, f"Model: \n{model}")
        self.log_info(fl_ctx, f"Task name: {task_name}")
        self.log_info(fl_ctx, f"Client identity: {fl_ctx.get_identity_name()}")
        self.log_info(fl_ctx, f"Validating model from {model_name}.")

        # Check abort signal regularly.
        if abort_signal.triggered:
            return make_reply(ReturnCode.TASK_ABORTED)

        # Do some dummy validation. The noise is scaled by the configured
        # epsilon; with the default epsilon=1 this equals the raw sample.
        random_epsilon = self._random_epsilon * np.random.random()
        self.log_info(fl_ctx, f"Adding random epsilon {random_epsilon} in validation.")
        val_results = {}
        np_data = model[NPConstants.NUMPY_KEY]
        # Normalize by the maximum element, then sum everything to one scalar.
        np_data = np.sum(np_data / np.max(np_data))
        val_results["accuracy"] = np_data + random_epsilon

        # Check abort signal regularly.
        if abort_signal.triggered:
            return make_reply(ReturnCode.TASK_ABORTED)

        self.log_info(fl_ctx, f"Validation result: {val_results}")

        # Create DXO for metrics and return shareable.
        metric_dxo = DXO(data_kind=DataKind.METRICS, data=val_results)
        return metric_dxo.to_shareable()