# Source code for nvflare.app_common.ccwf.cse_client_ctl

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading

from nvflare.apis.controller_spec import Task
from nvflare.apis.fl_constant import FLContextKey, ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.app_common.abstract.model import model_learnable_to_dxo
from nvflare.app_common.abstract.model_persistor import ModelPersistor
from nvflare.app_common.app_constant import AppConstants, ValidateType
from nvflare.app_common.ccwf.client_ctl import ClientSideController
from nvflare.app_common.ccwf.common import Constant, ModelType, make_task_name
from nvflare.fuel.utils.validation_utils import check_non_empty_str, check_positive_number
from nvflare.security.logging import secure_format_traceback


class CrossSiteEvalClientController(ClientSideController):
    """Client-side controller for the client-controlled cross-site evaluation (CSE) workflow.

    Each participating client can act as an evaluator (validating other sites' models),
    an evaluatee (providing its local model), and/or the designated "global client"
    (serving global models from its persistor's inventory).
    """

    def __init__(
        self,
        task_name_prefix=Constant.TN_PREFIX_CROSS_SITE_EVAL,
        submit_model_task_name=AppConstants.TASK_SUBMIT_MODEL,
        validation_task_name=AppConstants.TASK_VALIDATION,
        persistor_id=AppConstants.DEFAULT_PERSISTOR_ID,
        get_model_timeout=Constant.GET_MODEL_TIMEOUT,
    ):
        """Initialize the controller.

        Args:
            task_name_prefix: prefix used to derive the workflow task names
                (eval task, ask-for-model task).
            submit_model_task_name: task name of the executor that produces this
                client's local model.
            validation_task_name: task name of the executor that validates a model.
            persistor_id: component id of the ModelPersistor holding global models.
            get_model_timeout: max seconds to wait for a peer to return its model.

        Raises:
            ValueError/TypeError: from the check_* validators on bad arguments.
        """
        check_positive_number("get_model_timeout", get_model_timeout)
        check_non_empty_str("submit_model_task_name", submit_model_task_name)
        check_non_empty_str("validation_task_name", validation_task_name)
        check_non_empty_str("persistor_id", persistor_id)
        # No learn task / shareable generator: CSE does no training of its own.
        super().__init__(
            task_name_prefix=task_name_prefix,
            learn_task_name="",
            shareable_generator_id="",
            persistor_id=persistor_id,
        )
        # Workflow task names derived from the prefix.
        self.eval_task_name = make_task_name(task_name_prefix, Constant.BASENAME_EVAL)
        self.ask_for_model_task_name = make_task_name(task_name_prefix, Constant.BASENAME_ASK_FOR_MODEL)
        self.submit_model_task_name = submit_model_task_name  # this is for the learner executor
        self.validation_task_name = validation_task_name
        self.my_local_model = None
        self.global_model_inventory = None
        # Executors resolved from the runner in start_run().
        self.submit_model_executor = None
        self.validate_executor = None
        # Global-model inventory (dict), populated in process_config() if this
        # client is the designated global client.
        self.inventory = None
        self.get_model_timeout = get_model_timeout
        # Cached local model (Shareable) produced on first get-model request.
        self.local_model = None
        # Serializes concurrent get-model requests from peers.
        self.model_lock = threading.Lock()
[docs] def start_run(self, fl_ctx: FLContext): super().start_run(fl_ctx) runner = fl_ctx.get_prop(FLContextKey.RUNNER) if self.submit_model_task_name: self.submit_model_executor = runner.find_executor(self.submit_model_task_name) if not self.submit_model_executor: self.system_panic(f"no executor for task {self.submit_model_task_name}", fl_ctx) return if self.validation_task_name: self.validate_executor = runner.find_executor(self.validation_task_name) if not self.validate_executor: self.system_panic(f"no executor for task {self.validation_task_name}", fl_ctx) return
[docs] def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable: if task_name == self.eval_task_name: # server assigned task return self.do_eval(shareable, fl_ctx, abort_signal) elif task_name == self.ask_for_model_task_name: # client-assigned task return self._process_get_model_request(shareable, fl_ctx) return super().execute(task_name, shareable, fl_ctx, abort_signal)
[docs] def process_config(self, fl_ctx: FLContext): eval_local = self.get_config_prop(Constant.EVAL_LOCAL) eval_global = self.get_config_prop(Constant.EVAL_GLOBAL) evaluators = self.get_config_prop(Constant.EVALUATORS) evaluatees = self.get_config_prop(Constant.EVALUATEES) global_client = self.get_config_prop(Constant.GLOBAL_CLIENT) if eval_local and self.me in evaluatees: # do I have any local model? if not self.submit_model_executor: return make_reply(Constant.RC_NO_LOCAL_MODEL) if self.me in evaluators: # I am required to evaluate others if not self.validate_executor: return make_reply(Constant.RC_UNABLE_TO_EVAL) reply = make_reply(ReturnCode.OK) if eval_global and self.me == global_client: # do I have global models? assert isinstance(self.persistor, ModelPersistor) self.inventory = self.persistor.get_model_inventory(fl_ctx) if self.inventory: assert isinstance(self.inventory, dict) reply[Constant.GLOBAL_NAMES] = list(self.inventory.keys()) return reply
    def start_workflow(self, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
        # Intentional no-op: CSE clients only react to assigned tasks; presumably
        # the workflow is driven by the server-side controller — confirm against
        # the server counterpart.
        pass
    def do_eval(self, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
        """Handle a server-assigned eval task: fetch the target model from its
        owner client, validate it with the local validate executor, and return
        the validation result (tagged with model type/name/owner and round).
        """
        model_type = shareable.get(Constant.MODEL_TYPE)
        model_owner = shareable.get(Constant.MODEL_OWNER)
        model_name = shareable.get(Constant.MODEL_NAME)
        current_round = shareable.get(AppConstants.CURRENT_ROUND)

        # ask the model owner for model
        req = Shareable()
        req[Constant.MODEL_NAME] = model_name
        req[Constant.MODEL_TYPE] = model_type

        # Can't proceed without a validator.
        if not self.validate_executor:
            self.log_error(fl_ctx, "got eval request but I don't have a validator")
            return make_reply(Constant.RC_UNABLE_TO_EVAL)

        self.update_status(action="eval:get_model", last_round=current_round)
        self.log_info(fl_ctx, f"asking client {model_owner} for model {model_type} {model_name}")

        task = Task(
            name=self.ask_for_model_task_name,
            data=req,
            timeout=int(self.get_model_timeout),
            secure=self.is_task_secure(fl_ctx),
        )

        # Blocking broadcast to exactly the owner client.
        resp = self.broadcast_and_wait(
            task=task,
            targets=[model_owner],
            min_responses=1,
            fl_ctx=fl_ctx,
        )

        assert isinstance(resp, dict)
        reply = resp.get(model_owner)
        if not reply:
            self.log_error(fl_ctx, f"failed to ask client {model_owner} for model")
            return make_reply(ReturnCode.EXECUTION_EXCEPTION)

        if not isinstance(reply, Shareable):
            self.log_error(fl_ctx, f"client {model_owner} failed to respond to get-model request")
            return make_reply(ReturnCode.EXECUTION_EXCEPTION)

        rc = reply.get_return_code()
        if rc != ReturnCode.OK:
            self.log_error(fl_ctx, f"client {model_owner} failed to respond to share final result request: {rc}")
            return make_reply(rc)

        if abort_signal.triggered:
            return make_reply(ReturnCode.TASK_ABORTED)

        # The reply Shareable IS the model payload; tag it for the validator.
        model_to_validate = reply
        model_to_validate.set_header(AppConstants.VALIDATE_TYPE, ValidateType.MODEL_VALIDATE)
        model_to_validate.set_header(FLContextKey.TASK_NAME, self.validation_task_name)
        if model_type == ModelType.LOCAL:
            # Owner only matters for local models (global models are site-neutral).
            model_to_validate.set_header(AppConstants.MODEL_OWNER, model_owner)

        self.update_status(action="eval:validate", last_round=current_round)
        result = self.validate_executor.execute(
            task_name=self.validation_task_name, shareable=model_to_validate, abort_signal=abort_signal, fl_ctx=fl_ctx
        )
        self.update_status(action="eval:finished", last_round=current_round)

        # Tag the result so the server can attribute it.
        assert isinstance(result, Shareable)
        result.set_header(Constant.MODEL_TYPE, model_type)
        result.set_header(Constant.MODEL_NAME, model_name)
        result.set_header(Constant.MODEL_OWNER, model_owner)
        result.set_header(AppConstants.CURRENT_ROUND, current_round)
        return result
def _process_get_model_request(self, request: Shareable, fl_ctx: FLContext) -> Shareable: with self.model_lock: return self._do_process_get_model_request(request, fl_ctx) def _do_process_get_model_request(self, request: Shareable, fl_ctx: FLContext) -> Shareable: peer_ctx = fl_ctx.get_peer_context() assert isinstance(peer_ctx, FLContext) client_name = peer_ctx.get_identity_name() model_type = request.get(Constant.MODEL_TYPE) model_name = request.get(Constant.MODEL_NAME) if model_type == ModelType.GLOBAL: # get it from model inventory if not self.inventory: self.log_error( fl_ctx, f"got request for global model from client {client_name} but I don't have global models" ) return make_reply(ReturnCode.BAD_REQUEST_DATA) assert isinstance(self.persistor, ModelPersistor) model_learnable = self.persistor.get(model_name, fl_ctx) dxo = model_learnable_to_dxo(model_learnable) self.log_info(fl_ctx, f"sent global model {model_name} to client {client_name}") return dxo.to_shareable() # local model if not self.submit_model_executor: self.log_error( fl_ctx, f"got request for local model from client {client_name} but I don't have local models" ) return make_reply(ReturnCode.BAD_REQUEST_DATA) if not self.local_model: task_data = Shareable() task_data.set_header(AppConstants.SUBMIT_MODEL_NAME, model_name) task_data.set_header(FLContextKey.TASK_NAME, self.submit_model_task_name) abort_signal = Signal() try: result = self.submit_model_executor.execute( task_name=self.submit_model_task_name, shareable=task_data, fl_ctx=fl_ctx, abort_signal=abort_signal ) except: self.log_error( fl_ctx, f"failed to get local model from submit_model_executor: {secure_format_traceback()}" ) return make_reply(ReturnCode.EXECUTION_EXCEPTION) assert isinstance(result, Shareable) rc = result.get_return_code(ReturnCode.OK) if rc != ReturnCode.OK: self.log_error(fl_ctx, f"failed to get local model from submit_model_executor: {rc}") return make_reply(ReturnCode.EXECUTION_EXCEPTION) self.local_model = result 
self.log_info(fl_ctx, f"sent local model {model_name} to client {client_name}") return self.local_model
    def do_learn_task(self, name: str, task_data: Shareable, fl_ctx: FLContext, abort_signal: Signal):
        # Intentional no-op: this controller performs no training (it was
        # constructed with an empty learn_task_name), so there is never a
        # learn task to execute.
        pass