Source code for nvflare.app_common.executors.learner_executor

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nvflare.apis.dxo import MetaKey, from_shareable
from nvflare.apis.event_type import EventType
from nvflare.apis.executor import Executor
from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.app_common.abstract.learner_spec import Learner
from nvflare.app_common.app_constant import AppConstants, ValidateType
from nvflare.security.logging import secure_format_exception


class LearnerExecutor(Executor):
    def __init__(
        self,
        learner_id,
        train_task=AppConstants.TASK_TRAIN,
        submit_model_task=AppConstants.TASK_SUBMIT_MODEL,
        validate_task=AppConstants.TASK_VALIDATION,
    ):
        """Key component to run learner on clients.

        Args:
            learner_id (str): id of the learner object
            train_task (str, optional): task name for train. Defaults to AppConstants.TASK_TRAIN.
            submit_model_task (str, optional): task name for submit model. Defaults to AppConstants.TASK_SUBMIT_MODEL.
            validate_task (str, optional): task name for validation. Defaults to AppConstants.TASK_VALIDATION.
        """
        super().__init__()
        self.learner_id = learner_id
        # Resolved from the engine by initialize(); stays None until a valid Learner is found.
        self.learner = None
        self.train_task = train_task
        self.submit_model_task = submit_model_task
        self.validate_task = validate_task
        # Becomes True only after initialize() has completed successfully (see execute()).
        self.is_initialized = False

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        """Reacts to ABORT_TASK and END_RUN events; all other events are ignored.

        On ABORT_TASK the learner's ``abort`` is called, and on END_RUN ``finalize`` is called.
        Both are skipped (with a warning when a learner exists) if ``self.unsafe`` is truthy.
        NOTE(review): ``self.unsafe`` is not set in this class; presumably it is provided by
        the Executor base class -- confirm.

        Args:
            event_type (str): the type of the fired event.
            fl_ctx (FLContext): FL context of the event.
        """
        if event_type == EventType.ABORT_TASK:
            try:
                if self.learner:
                    if not self.unsafe:
                        self.learner.abort(fl_ctx)
                    else:
                        self.log_warning(fl_ctx, f"skipped abort of unsafe learner {self.learner.__class__.__name__}")
            except Exception as e:
                # An abort failure is logged but never propagated out of the event handler.
                self.log_exception(fl_ctx, f"learner abort exception: {secure_format_exception(e)}")
        elif event_type == EventType.END_RUN:
            if not self.unsafe:
                self.finalize(fl_ctx)
            elif self.learner:
                self.log_warning(fl_ctx, f"skipped finalize of unsafe learner {self.learner.__class__.__name__}")

    def initialize(self, fl_ctx: FLContext):
        """Looks up the learner component by ``learner_id`` and initializes it.

        Args:
            fl_ctx (FLContext): FL context used to obtain the engine.

        Raises:
            TypeError: if the component registered under ``learner_id`` is not a Learner.
            Exception: any error raised while getting or initializing the learner is logged
                and re-raised.
        """
        try:
            engine = fl_ctx.get_engine()
            learner = engine.get_component(self.learner_id)
            # Check the type before storing the component, so that a wrong object is never
            # left in self.learner for handle_event()/finalize() to call abort/finalize on.
            if not isinstance(learner, Learner):
                raise TypeError(f"learner must be Learner type. Got: {type(learner)}")
            self.learner = learner
            self.learner.initialize(engine.get_all_components(), fl_ctx)
        except Exception as e:
            self.log_exception(fl_ctx, f"learner initialize exception: {secure_format_exception(e)}")
            raise

    def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
        """Dispatches a task to train, submit_model or validate based on its name.

        The learner is initialized lazily on the first task. If initialization raises, the
        exception propagates and initialization is attempted again on the next task.

        Args:
            task_name (str): name of the task to run.
            shareable (Shareable): task input.
            fl_ctx (FLContext): FL context of the task.
            abort_signal (Signal): signal passed on to the train/validate handlers.

        Returns:
            The Shareable produced by the matching handler, or a reply with
            ReturnCode.TASK_UNKNOWN when ``task_name`` matches none of the configured tasks.
        """
        self.log_info(fl_ctx, f"Client trainer got task: {task_name}")

        if not self.is_initialized:
            self.initialize(fl_ctx)
            # Set the flag only after success: otherwise a failed initialization would be
            # skipped on later tasks, which would then run against an unusable learner.
            self.is_initialized = True

        if task_name == self.train_task:
            return self.train(shareable, fl_ctx, abort_signal)
        elif task_name == self.submit_model_task:
            return self.submit_model(shareable, fl_ctx)
        elif task_name == self.validate_task:
            return self.validate(shareable, fl_ctx, abort_signal)
        else:
            self.log_error(fl_ctx, f"Could not handle task: {task_name}")
            return make_reply(ReturnCode.TASK_UNKNOWN)

    def train(self, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
        """Runs a before-train validation followed by training on the learner.

        Args:
            shareable (Shareable): task input; its VALIDATE_TYPE header is set to
                ValidateType.BEFORE_TRAIN_VALIDATE before the learner is called.
            fl_ctx (FLContext): FL context of the task.
            abort_signal (Signal): signal passed on to the learner.

        Returns:
            A reply with ReturnCode.EMPTY_RESULT if the learner's train result is empty or not
            a Shareable. Otherwise the train result, with MetaKey.INITIAL_METRICS added to its
            DXO meta when the before-train validation returned ReturnCode.OK.
        """
        self.log_debug(fl_ctx, f"train abort signal: {abort_signal.triggered}")

        shareable.set_header(AppConstants.VALIDATE_TYPE, ValidateType.BEFORE_TRAIN_VALIDATE)
        validate_result: Shareable = self.learner.validate(shareable, fl_ctx, abort_signal)

        train_result = self.learner.train(shareable, fl_ctx, abort_signal)
        if not (train_result and isinstance(train_result, Shareable)):
            return make_reply(ReturnCode.EMPTY_RESULT)

        # if the learner returned the valid BEFORE_TRAIN_VALIDATE result, set the INITIAL_METRICS in
        # the train result, which can be used for best model selection.
        if (
            validate_result
            and isinstance(validate_result, Shareable)
            and validate_result.get_return_code() == ReturnCode.OK
        ):
            try:
                metrics_dxo = from_shareable(validate_result)
                train_dxo = from_shareable(train_result)
                train_dxo.meta[MetaKey.INITIAL_METRICS] = metrics_dxo.data.get(MetaKey.INITIAL_METRICS, 0)
                return train_dxo.to_shareable()
            except ValueError:
                # On a ValueError from the DXO handling, return the train result unchanged.
                return train_result
        else:
            return train_result

    def submit_model(self, shareable: Shareable, fl_ctx: FLContext) -> Shareable:
        """Asks the learner for the model named in the SUBMIT_MODEL_NAME header.

        Args:
            shareable (Shareable): task input carrying the SUBMIT_MODEL_NAME header.
            fl_ctx (FLContext): FL context of the task.

        Returns:
            The learner's result if it is a non-empty Shareable, otherwise a reply with
            ReturnCode.EMPTY_RESULT.
        """
        model_name = shareable.get_header(AppConstants.SUBMIT_MODEL_NAME)
        submit_model_result = self.learner.get_model_for_validation(model_name, fl_ctx)
        if submit_model_result and isinstance(submit_model_result, Shareable):
            return submit_model_result
        else:
            return make_reply(ReturnCode.EMPTY_RESULT)

    def validate(self, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
        """Runs a model validation on the learner.

        Args:
            shareable (Shareable): task input; its VALIDATE_TYPE header is set to
                ValidateType.MODEL_VALIDATE before the learner is called.
            fl_ctx (FLContext): FL context of the task.
            abort_signal (Signal): signal passed on to the learner.

        Returns:
            The learner's result if it is a non-empty Shareable, otherwise a reply with
            ReturnCode.EMPTY_RESULT.
        """
        self.log_debug(fl_ctx, f"validate abort_signal {abort_signal.triggered}")

        shareable.set_header(AppConstants.VALIDATE_TYPE, ValidateType.MODEL_VALIDATE)
        validate_result: Shareable = self.learner.validate(shareable, fl_ctx, abort_signal)
        if validate_result and isinstance(validate_result, Shareable):
            return validate_result
        else:
            return make_reply(ReturnCode.EMPTY_RESULT)

    def finalize(self, fl_ctx: FLContext):
        """Finalizes the learner, if one was set; errors are logged and not re-raised.

        Args:
            fl_ctx (FLContext): FL context passed to the learner's ``finalize``.
        """
        try:
            if self.learner:
                self.learner.finalize(fl_ctx)
        except Exception as e:
            self.log_exception(fl_ctx, f"learner finalize exception: {secure_format_exception(e)}")