Source code for nvflare.app_common.ccwf.cse_server_ctl

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from nvflare.apis.controller_spec import ClientTask, Task
from nvflare.apis.dxo import from_shareable
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import ReturnCode, Shareable
from nvflare.apis.signal import Signal
from nvflare.apis.workspace import Workspace
from nvflare.app_common.app_constant import AppConstants, ModelName
from nvflare.app_common.app_event_type import AppEventType
from nvflare.app_common.ccwf.common import Constant, ModelType, make_task_name
from nvflare.app_common.ccwf.server_ctl import ServerSideController
from nvflare.app_common.ccwf.val_result_manager import EvalResultManager
from nvflare.fuel.utils.validation_utils import (
    DefaultValuePolicy,
    check_positive_number,
    check_str,
    validate_candidate,
    validate_candidates,
)


class CrossSiteEvalServerController(ServerSideController):
    """Server-side controller that schedules cross-site model evaluation.

    During ``sub_flow`` the controller broadcasts one evaluation task per model
    to all ``evaluators``: first every global model name reported by clients in
    their config replies, then the local model of every evaluatee. Results
    returned by the evaluators are handed to an ``EvalResultManager``.
    """

    def __init__(
        self,
        task_name_prefix=Constant.TN_PREFIX_CROSS_SITE_EVAL,
        start_task_timeout=Constant.START_TASK_TIMEOUT,
        configure_task_timeout=Constant.CONFIG_TASK_TIMEOUT,
        eval_task_timeout=30,
        task_check_period: float = Constant.TASK_CHECK_INTERVAL,
        job_status_check_interval: float = Constant.JOB_STATUS_CHECK_INTERVAL,
        progress_timeout: float = Constant.WORKFLOW_PROGRESS_TIMEOUT,
        private_p2p: bool = True,
        participating_clients=None,
        evaluators=None,
        evaluatees=None,
        global_model_client=None,
        max_status_report_interval: float = Constant.PER_CLIENT_STATUS_REPORT_TIMEOUT,
        eval_result_dir=AppConstants.CROSS_VAL_DIR,
    ):
        """Create the controller.

        Args:
            task_name_prefix: prefix used to build the evaluation task name; also
                forwarded to the base controller.
            start_task_timeout: forwarded to the base controller.
            configure_task_timeout: forwarded to the base controller.
            eval_task_timeout: timeout set on every evaluation task; must be a
                positive number.
            task_check_period: forwarded to the base controller.
            job_status_check_interval: forwarded to the base controller.
            progress_timeout: forwarded to the base controller.
            private_p2p: forwarded to the base controller.
            participating_clients: forwarded to the base controller.
            evaluators: clients asked to evaluate models; a falsy value is
                replaced with an empty list and resolved in ``start_controller``.
            evaluatees: clients whose local models are evaluated; a falsy value
                is replaced with an empty list and resolved in ``start_controller``.
            global_model_client: client used for global model evaluation; a falsy
                value is replaced with an empty string and resolved in
                ``start_controller``.
            max_status_report_interval: forwarded to the base controller.
            eval_result_dir: directory (joined onto the job's run dir) under
                which evaluation results are kept; must be a str.
        """
        if not evaluatees:
            evaluatees = []

        if not evaluators:
            evaluators = []

        super().__init__(
            num_rounds=1,
            task_name_prefix=task_name_prefix,
            start_task_timeout=start_task_timeout,
            configure_task_timeout=configure_task_timeout,
            task_check_period=task_check_period,
            job_status_check_interval=job_status_check_interval,
            participating_clients=participating_clients,
            starting_client="",
            starting_client_policy=DefaultValuePolicy.EMPTY,
            max_status_report_interval=max_status_report_interval,
            result_clients="",
            result_clients_policy=DefaultValuePolicy.EMPTY,
            progress_timeout=progress_timeout,
            private_p2p=private_p2p,
        )

        check_str("eval_result_dir", eval_result_dir)
        check_positive_number("eval_task_timeout", eval_task_timeout)

        if not global_model_client:
            global_model_client = ""
        self.global_model_client = global_model_client

        self.eval_task_name = make_task_name(task_name_prefix, Constant.BASENAME_EVAL)
        self.eval_task_timeout = eval_task_timeout
        # Both flags are decided in start_controller, once the candidates are validated.
        self.eval_local = False
        self.eval_global = False
        self.evaluators = evaluators
        self.evaluatees = evaluatees
        self.eval_result_dir = eval_result_dir
        # global model name => name of the first client that reported it
        self.global_names = {}
        self.eval_manager = None
        # Incremented after each evaluation task is scheduled.
        self.current_round = 0

    def start_controller(self, fl_ctx: FLContext):
        """Validate the evaluation configuration and create the result manager.

        Args:
            fl_ctx: the FL context; its job id locates the run directory.

        Raises:
            RuntimeError: if neither a global model client nor any evaluatee is
                left after validation.
        """
        super().start_controller(fl_ctx)

        self.evaluators = validate_candidates(
            var_name="evaluators",
            candidates=self.evaluators,
            base=self.participating_clients,
            default_policy=DefaultValuePolicy.ALL,
            allow_none=False,
        )

        self.evaluatees = validate_candidates(
            var_name="evaluatees",
            candidates=self.evaluatees,
            base=self.participating_clients,
            default_policy=DefaultValuePolicy.ALL,
            allow_none=True,
        )

        self.global_model_client = validate_candidate(
            var_name="global_model_client",
            candidate=self.global_model_client,
            base=self.participating_clients,
            default_policy=DefaultValuePolicy.ANY,
            allow_none=True,
        )

        if self.global_model_client:
            self.eval_global = True

        if self.evaluatees:
            self.eval_local = True

        if not self.eval_global and not self.eval_local:
            # Name the constructor argument the user can actually set, not the internal flag.
            raise RuntimeError("nothing to evaluate: you must set evaluatees and/or global_model_client")

        workspace: Workspace = self._engine.get_workspace()
        run_dir = workspace.get_run_dir(fl_ctx.get_job_id())
        cross_val_path = os.path.join(run_dir, self.eval_result_dir)
        cross_val_results_dir = os.path.join(cross_val_path, AppConstants.CROSS_VAL_RESULTS_DIR_NAME)
        self.eval_manager = EvalResultManager(cross_val_results_dir)

    def prepare_config(self):
        """Return the evaluation settings as a dict keyed by ``Constant`` names."""
        return {
            Constant.EVAL_LOCAL: self.eval_local,
            Constant.EVAL_GLOBAL: self.eval_global,
            Constant.EVALUATORS: self.evaluators,
            Constant.EVALUATEES: self.evaluatees,
            Constant.GLOBAL_CLIENT: self.global_model_client,
        }

    def process_config_reply(self, client_name: str, reply: Shareable, fl_ctx: FLContext) -> bool:
        """Record the global model names reported in a client's config reply.

        A name that is already known keeps its original owner.

        Args:
            client_name: name of the client that sent the reply.
            reply: the config reply; global model names are read from
                ``Constant.GLOBAL_NAMES``.
            fl_ctx: the FL context, used for logging.

        Returns:
            Always True.
        """
        global_names = reply.get(Constant.GLOBAL_NAMES)
        if global_names:
            for m in global_names:
                if m not in self.global_names:
                    self.global_names[m] = client_name
                    self.log_info(fl_ctx, f"got global model name {m} from {client_name}")
        return True

    def _ask_to_evaluate(
        self, current_round: int, model_name: str, model_type: str, model_owner: str, fl_ctx: FLContext
    ):
        """Broadcast one evaluation task for a single model to all evaluators.

        Args:
            current_round: round number placed in the task data.
            model_name: name of the model to evaluate.
            model_type: type of the model (a ``ModelType`` value).
            model_owner: client that holds the model.
            fl_ctx: the FL context.
        """
        self.log_info(
            fl_ctx,
            f"R{current_round}: asking {self.evaluators} to evaluate {model_type} model '{model_name}' "
            f"on client '{model_owner}'",
        )

        # Create validation task and broadcast to all participating clients.
        task_data = Shareable()
        task_data[AppConstants.CURRENT_ROUND] = current_round
        task_data[Constant.MODEL_OWNER] = model_owner  # client that holds the model
        task_data[Constant.MODEL_NAME] = model_name
        task_data[Constant.MODEL_TYPE] = model_type

        task = Task(
            name=self.eval_task_name,
            data=task_data,
            result_received_cb=self._process_eval_result,
            timeout=self.eval_task_timeout,
        )

        self.broadcast(
            task=task,
            fl_ctx=fl_ctx,
            targets=self.evaluators,
            min_responses=len(self.evaluators),
            wait_time_after_min_received=0,
        )

    def sub_flow(self, abort_signal: Signal, fl_ctx: FLContext):
        """Schedule evaluation tasks for all global models and all local models.

        Scheduling stops as soon as ``abort_signal`` is triggered, so no further
        tasks are broadcast for a workflow that is being aborted.

        Args:
            abort_signal: signal checked before each task is scheduled.
            fl_ctx: the FL context.
        """
        if not self.global_names and not self.evaluatees:
            self.system_panic("there are neither global models nor local models to evaluate!", fl_ctx)
            return

        # ask everyone to evaluate global model
        # NOTE(review): the loop is assumed to sit inside this `if`; the scraped
        # source lost its indentation - confirm against the upstream file.
        if self.eval_global:
            if not self.global_names:
                self.log_warning(fl_ctx, "no global models to evaluate!")

            for m, owner in self.global_names.items():
                if abort_signal is not None and abort_signal.triggered:
                    return

                self._ask_to_evaluate(
                    current_round=self.current_round,
                    model_name=m,
                    model_type=ModelType.GLOBAL,
                    model_owner=owner,
                    fl_ctx=fl_ctx,
                )
                self.current_round += 1

        # ask everyone to eval everyone else's local model
        train_clients = fl_ctx.get_prop(Constant.PROP_KEY_TRAIN_CLIENTS)
        for c in self.evaluatees:
            if abort_signal is not None and abort_signal.triggered:
                return

            if train_clients and c not in train_clients:
                # this client does not have local models
                self.log_info(fl_ctx, f"ignore client {c} since it does not have local models")
                continue

            self._ask_to_evaluate(
                current_round=self.current_round,
                model_name=ModelName.BEST_MODEL,
                model_type=ModelType.LOCAL,
                model_owner=c,
                fl_ctx=fl_ctx,
            )
            self.current_round += 1

    def is_sub_flow_done(self, fl_ctx: FLContext) -> bool:
        """Return True once no evaluation task is still standing."""
        return self.get_num_standing_tasks() == 0

    def _process_eval_result(self, client_task: ClientTask, fl_ctx: FLContext):
        """Task callback: pass a client's evaluation result to ``_accept_eval_result``."""
        # Find name of the client sending this
        result = client_task.result
        client_name = client_task.client.name
        self._accept_eval_result(client_name=client_name, result=result, fl_ctx=fl_ctx)

    def _accept_eval_result(self, client_name: str, result: Shareable, fl_ctx: FLContext):
        """Fire the validation-result event and store the result if it is OK.

        The event is fired for every result, before the return code is checked.

        Args:
            client_name: name of the evaluator that produced the result.
            result: the evaluation result; the model owner, type and name are
                read from its headers.
            fl_ctx: the FL context.
        """
        model_owner = result.get_header(Constant.MODEL_OWNER, "")
        model_type = result.get_header(Constant.MODEL_TYPE)
        model_name = result.get_header(Constant.MODEL_NAME)

        if model_type == ModelType.GLOBAL:
            # global model
            model_owner = "GLOBAL_" + model_name
            model_info = model_owner
        else:
            model_info = f"{model_name} of {model_owner}"

        # Fire event. This needs to be a new local context per each client
        fl_ctx.set_prop(AppConstants.MODEL_OWNER, model_owner, private=True, sticky=False)
        fl_ctx.set_prop(AppConstants.DATA_CLIENT, client_name, private=True, sticky=False)
        fl_ctx.set_prop(AppConstants.VALIDATION_RESULT, result, private=True, sticky=False)
        self.fire_event(AppEventType.VALIDATION_RESULT_RECEIVED, fl_ctx)

        rc = result.get_return_code(ReturnCode.OK)
        if rc != ReturnCode.OK:
            self.log_error(fl_ctx, f"bad evaluation result from client {client_name} on model {model_info}")
        else:
            dxo = from_shareable(result)
            location = self.eval_manager.add_result(evaluatee=model_owner, evaluator=client_name, result=dxo)
            self.log_info(fl_ctx, f"saved evaluation result from {client_name} on model {model_info} in {location}")