# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import time
from typing import Union
from nvflare.apis.client import Client
from nvflare.apis.controller_spec import ClientTask, Task
from nvflare.apis.dxo import DXO, from_file, from_shareable, get_leaf_dxos
from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import Controller
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal
from nvflare.apis.workspace import Workspace
from nvflare.app_common.abstract.formatter import Formatter
from nvflare.app_common.abstract.model_locator import ModelLocator
from nvflare.app_common.app_constant import AppConstants, ModelName
from nvflare.app_common.app_event_type import AppEventType
from nvflare.security.logging import secure_format_exception
from nvflare.widgets.info_collector import GroupInfoCollector, InfoCollector
class CrossSiteModelEval(Controller):
    """Controller that evaluates every submitted model on every participating client.

    The workflow implemented by this class is:

    1. Optionally broadcast a "submit model" task so that each participating client
       returns a model (the request carries ``ModelName.BEST_MODEL`` in its header).
    2. Persist every received client model, and every server model returned by the
       optional ``ModelLocator`` component, in the cross-validation models directory.
    3. For each persisted model, broadcast a validation task to all participating
       clients and persist each returned validation result in the results directory.
    """

    def __init__(
        self,
        task_check_period=0.5,
        cross_val_dir=AppConstants.CROSS_VAL_DIR,
        submit_model_timeout=600,
        validation_timeout: int = 6000,
        model_locator_id="",
        formatter_id="",
        submit_model_task_name=AppConstants.TASK_SUBMIT_MODEL,
        validation_task_name=AppConstants.TASK_VALIDATION,
        cleanup_models=False,
        participating_clients=None,
        wait_for_clients_timeout=300,
    ):
        """Cross Site Model Validation workflow.

        Args:
            task_check_period (float, optional): How often to check for new tasks or tasks being finished.
                Defaults to 0.5.
            cross_val_dir (str, optional): Path to cross site validation directory relative to run directory.
                Defaults to ``AppConstants.CROSS_VAL_DIR``.
            submit_model_timeout (int, optional): Timeout of submit_model_task. Defaults to 600 secs.
            validation_timeout (int, optional): Timeout for validate_model task. Defaults to 6000 secs.
            model_locator_id (str, optional): ID for model_locator component. Defaults to "".
            formatter_id (str, optional): ID for formatter component. Defaults to "".
            submit_model_task_name (str, optional): Name of submit_model task. Defaults to
                ``AppConstants.TASK_SUBMIT_MODEL``. If empty, no submit_model task is broadcast.
            validation_task_name (str, optional): Name of validate_model task. Defaults to
                ``AppConstants.TASK_VALIDATION``.
            cleanup_models (bool, optional): Whether or not models should be deleted after run. Defaults to False.
            participating_clients (list, optional): List of participating client names. If not provided, defaults
                to all clients connected at start of controller.
            wait_for_clients_timeout (int, optional): Timeout for clients to appear. Defaults to 300 secs

        Raises:
            TypeError: If an argument has the wrong type.
            ValueError: If any of the timeouts is negative.
        """
        super().__init__(task_check_period=task_check_period)

        if not isinstance(task_check_period, float):
            raise TypeError("task_check_period must be float but got {}".format(type(task_check_period)))
        if not isinstance(cross_val_dir, str):
            raise TypeError("cross_val_dir must be a string but got {}".format(type(cross_val_dir)))
        if not isinstance(submit_model_timeout, int):
            raise TypeError("submit_model_timeout must be int but got {}".format(type(submit_model_timeout)))
        if not isinstance(validation_timeout, int):
            raise TypeError("validation_timeout must be int but got {}".format(type(validation_timeout)))
        if not isinstance(model_locator_id, str):
            raise TypeError("model_locator_id must be a string but got {}".format(type(model_locator_id)))
        if not isinstance(formatter_id, str):
            raise TypeError("formatter_id must be a string but got {}".format(type(formatter_id)))
        if not isinstance(submit_model_task_name, str):
            raise TypeError("submit_model_task_name must be a string but got {}".format(type(submit_model_task_name)))
        if not isinstance(validation_task_name, str):
            raise TypeError("validation_task_name must be a string but got {}".format(type(validation_task_name)))
        if not isinstance(cleanup_models, bool):
            raise TypeError("cleanup_models must be bool but got {}".format(type(cleanup_models)))

        # An empty/None value means "use whatever clients are connected", so only
        # a non-empty value is type checked.
        if participating_clients:
            if not isinstance(participating_clients, list):
                raise TypeError("participating_clients must be a list but got {}".format(type(participating_clients)))
            if not all(isinstance(x, str) for x in participating_clients):
                raise TypeError("participating_clients must be strings")

        if submit_model_timeout < 0:
            raise ValueError("submit_model_timeout must be greater than or equal to 0.")
        if validation_timeout < 0:
            # The message names the actual parameter so the caller can find it.
            raise ValueError("validation_timeout must be greater than or equal to 0.")
        if wait_for_clients_timeout < 0:
            raise ValueError("wait_for_clients_timeout must be greater than or equal to 0.")

        self._cross_val_dir = cross_val_dir
        self._model_locator_id = model_locator_id
        self._formatter_id = formatter_id
        self._submit_model_task_name = submit_model_task_name
        self._validation_task_name = validation_task_name
        self._submit_model_timeout = submit_model_timeout
        self._validation_timeout = validation_timeout
        self._wait_for_clients_timeout = wait_for_clients_timeout
        self._cleanup_models = cleanup_models
        self._participating_clients = participating_clients

        # name -> {model name -> saved result file path, or {} when the result was an error}
        self._val_results = {}
        # "SRV_<name>" -> path of the saved server model
        self._server_models = {}
        # model name -> path of the saved client model (None until one is received)
        self._client_models = {}

        # Resolved in start_controller().
        self._formatter = None
        self._cross_val_models_dir = None
        self._cross_val_results_dir = None
        self._model_locator = None

    def start_controller(self, fl_ctx: FLContext):
        """Prepare directories, resolve components and register the participating clients.

        Calls ``system_panic`` and returns early if a configured model locator or
        formatter component is not of the expected type.

        Args:
            fl_ctx (FLContext): FL context of the current run.
        """
        # If the list of participating clients is not provided, include all clients currently available.
        if not self._participating_clients:
            clients = self._engine.get_clients()
            self._participating_clients = [c.name for c in clients]

        # Create shareable dirs for models and results
        workspace: Workspace = self._engine.get_workspace()
        run_dir = workspace.get_run_dir(fl_ctx.get_job_id())
        cross_val_path = os.path.join(run_dir, self._cross_val_dir)
        self._cross_val_models_dir = os.path.join(cross_val_path, AppConstants.CROSS_VAL_MODEL_DIR_NAME)
        self._cross_val_results_dir = os.path.join(cross_val_path, AppConstants.CROSS_VAL_RESULTS_DIR_NAME)

        # Fire the init event.
        fl_ctx.set_prop(AppConstants.CROSS_VAL_MODEL_PATH, self._cross_val_models_dir)
        fl_ctx.set_prop(AppConstants.CROSS_VAL_RESULTS_PATH, self._cross_val_results_dir)
        self.fire_event(AppEventType.CROSS_VAL_INIT, fl_ctx)

        # Cleanup/create the cross val models and results directories
        if os.path.exists(self._cross_val_models_dir):
            shutil.rmtree(self._cross_val_models_dir)
        if os.path.exists(self._cross_val_results_dir):
            shutil.rmtree(self._cross_val_results_dir)

        # Recreate new directories.
        os.makedirs(self._cross_val_models_dir)
        os.makedirs(self._cross_val_results_dir)

        # Get components
        if self._model_locator_id:
            self._model_locator = self._engine.get_component(self._model_locator_id)
            if not isinstance(self._model_locator, ModelLocator):
                self.system_panic(
                    reason="bad model locator {}: expect ModelLocator but got {}".format(
                        self._model_locator_id, type(self._model_locator)
                    ),
                    fl_ctx=fl_ctx,
                )
                return

        if self._formatter_id:
            self._formatter = self._engine.get_component(self._formatter_id)
            if not isinstance(self._formatter, Formatter):
                self.system_panic(
                    reason=f"formatter {self._formatter_id} is not an instance of Formatter.", fl_ctx=fl_ctx
                )
                return

        if not self._formatter:
            self.log_info(fl_ctx, "Formatter not found. Stats will not be printed.")

        for c_name in self._participating_clients:
            self._client_models[c_name] = None
            self._val_results[c_name] = {}

    def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
        """Run the cross-site evaluation and block until all standing tasks are done.

        Any exception raised inside the flow is logged and escalated through
        ``system_panic``.

        Args:
            abort_signal (Signal): Checked between steps; when triggered the flow returns.
            fl_ctx (FLContext): FL context of the current run.
        """
        try:
            # wait until there are some clients
            if not self._wait_for_clients(abort_signal, fl_ctx):
                return

            if abort_signal.triggered:
                self.log_info(fl_ctx, "Abort signal triggered. Finishing model validation.")
                return

            self.log_info(fl_ctx, f"Beginning model validation with clients: {self._participating_clients}.")

            if self._submit_model_task_name:
                shareable = Shareable()
                shareable.set_header(AppConstants.SUBMIT_MODEL_NAME, ModelName.BEST_MODEL)
                submit_model_task = Task(
                    name=self._submit_model_task_name,
                    data=shareable,
                    result_received_cb=self._receive_local_model_cb,
                    timeout=self._submit_model_timeout,
                )
                self.broadcast(
                    task=submit_model_task,
                    targets=self._participating_clients,
                    fl_ctx=fl_ctx,
                    min_responses=len(self._participating_clients),
                )

            if abort_signal.triggered:
                self.log_info(fl_ctx, "Abort signal triggered. Finishing model validation.")
                return

            # Load server models and assign those tasks
            if self._model_locator:
                success = self._locate_server_models(fl_ctx)
                if not success:
                    return

                for server_model in self._server_models:
                    self._send_validation_task(server_model, fl_ctx)
            else:
                self.log_info(fl_ctx, "ModelLocator not present. No server models will be included.")

            while self.get_num_standing_tasks():
                if abort_signal.triggered:
                    self.log_info(fl_ctx, "Abort signal triggered. Finishing cross site validation.")
                    return
                self.log_debug(fl_ctx, "Checking standing tasks to see if cross site validation finished.")
                time.sleep(self._task_check_period)
        except Exception as e:
            error_msg = f"Exception in cross site validator control_flow: {secure_format_exception(e)}"
            self.log_exception(fl_ctx, error_msg)
            self.system_panic(error_msg, fl_ctx)

    def _wait_for_clients(self, abort_signal: Signal, fl_ctx: FLContext) -> bool:
        """Poll the engine until at least one client is known.

        Args:
            abort_signal (Signal): Checked before every poll so the wait can be interrupted.
            fl_ctx (FLContext): FL context of the current run.

        Returns:
            bool: True if participating clients are available; False if the wait was
            aborted or ``wait_for_clients_timeout`` elapsed without any client.
        """
        start_time = time.time()
        while not self._participating_clients:
            if abort_signal.triggered:
                self.log_info(fl_ctx, "Abort signal triggered. Finishing model validation.")
                return False

            self._participating_clients = [c.name for c in self._engine.get_clients()]
            if self._participating_clients:
                # Clients just showed up: do not report "no clients", sleep, or time out.
                break

            if time.time() - start_time > self._wait_for_clients_timeout:
                self.log_info(fl_ctx, "No clients available - quit model validation.")
                return False

            self.log_info(fl_ctx, "No clients available - waiting ...")
            time.sleep(2.0)

        # Clients discovered here were unknown to start_controller(); give them the
        # same bookkeeping entries without touching entries that already exist.
        for c_name in self._participating_clients:
            self._client_models.setdefault(c_name, None)
            self._val_results.setdefault(c_name, {})
        return True

    def stop_controller(self, fl_ctx: FLContext):
        """Cancel all tasks and, if ``cleanup_models`` is set, delete the saved model files.

        Args:
            fl_ctx (FLContext): FL context of the current run.
        """
        self.cancel_all_tasks(fl_ctx=fl_ctx)

        if not self._cleanup_models:
            return

        self.log_info(fl_ctx, "Removing local models kept for validation.")
        for model_name, model_path in self._server_models.items():
            if self._remove_model_file(model_path, fl_ctx):
                self.log_debug(fl_ctx, f"Removing server model {model_name} at {model_path}.")
        for model_name, model_path in self._client_models.items():
            if self._remove_model_file(model_path, fl_ctx):
                self.log_debug(fl_ctx, f"Removing client {model_name}'s model at {model_path}.")

    def _remove_model_file(self, model_path, fl_ctx: FLContext) -> bool:
        """Delete one saved model file, logging instead of raising on failure.

        A failure to delete one file must not prevent the remaining files from
        being cleaned up.

        Args:
            model_path: Path of the file to delete; may be None or empty.
            fl_ctx (FLContext): FL context used for logging.

        Returns:
            bool: True if the file existed and was removed, False otherwise.
        """
        if not model_path or not os.path.isfile(model_path):
            return False
        try:
            os.remove(model_path)
        except OSError as e:
            self.log_warning(fl_ctx, f"Unable to remove model file {model_path}: {secure_format_exception(e)}")
            return False
        return True

    def _receive_local_model_cb(self, client_task: ClientTask, fl_ctx: FLContext):
        """Result callback of the submit_model task: accept the model a client sent back."""
        client_name = client_task.client.name
        result: Shareable = client_task.result

        self._accept_local_model(client_name=client_name, result=result, fl_ctx=fl_ctx)

        # Cleanup task result
        client_task.result = None

    def _before_send_validate_task_cb(self, client_task: ClientTask, fl_ctx: FLContext):
        """Load the model to validate from disk and attach it to the outgoing task.

        Calls ``system_panic`` and returns without setting task data if the model
        cannot be loaded.
        """
        model_name = client_task.task.props[AppConstants.MODEL_OWNER]

        try:
            model_dxo: DXO = self._load_validation_content(model_name, self._cross_val_models_dir, fl_ctx)
        except ValueError as e:
            reason = (
                f"Error in loading model shareable for {model_name}: {secure_format_exception(e)}. "
                "CrossSiteModelEval exiting."
            )
            self.log_error(fl_ctx, reason)
            self.system_panic(reason, fl_ctx)
            return

        if not model_dxo:
            self.system_panic(
                f"Model contents for {model_name} not found in {self._cross_val_models_dir}. "
                "CrossSiteModelEval exiting",
                fl_ctx=fl_ctx,
            )
            return

        model_shareable = model_dxo.to_shareable()
        model_shareable.set_header(AppConstants.MODEL_OWNER, model_name)
        # _accept_val_result() reads the model owner back from this cookie.
        model_shareable.add_cookie(AppConstants.MODEL_OWNER, model_name)
        client_task.task.data = model_shareable

        fl_ctx.set_prop(AppConstants.DATA_CLIENT, client_task.client, private=True, sticky=False)
        fl_ctx.set_prop(AppConstants.MODEL_OWNER, model_name, private=True, sticky=False)
        fl_ctx.set_prop(AppConstants.MODEL_TO_VALIDATE, model_shareable, private=True, sticky=False)
        fl_ctx.set_prop(AppConstants.PARTICIPATING_CLIENTS, self._participating_clients, private=True, sticky=False)
        self.fire_event(AppEventType.SEND_MODEL_FOR_VALIDATION, fl_ctx)

    def _after_send_validate_task_cb(self, client_task: ClientTask, fl_ctx: FLContext):
        """Drop the model payload from the task once it has been sent."""
        # Once task is sent clear data to restore memory
        client_task.task.data = None

    def _receive_val_result_cb(self, client_task: ClientTask, fl_ctx: FLContext):
        """Result callback of the validation task: accept the result a client sent back."""
        # Find name of the client sending this
        result = client_task.result
        client_name = client_task.client.name

        self._accept_val_result(client_name=client_name, result=result, fl_ctx=fl_ctx)

        client_task.result = None

    def _locate_server_models(self, fl_ctx: FLContext) -> bool:
        """Fetch all server models from the model locator and save them to disk.

        Each model is stored under the name ``"SRV_" + name`` and registered in
        ``_server_models`` and ``_val_results``.

        Returns:
            bool: True on success; False (after ``system_panic``) if the locator
            returned something that is not a DXO or a model could not be saved.
        """
        # Load models from model_locator
        self.log_info(fl_ctx, "Locating server models.")
        server_model_names = self._model_locator.get_model_names(fl_ctx)

        unique_names = []
        for name in server_model_names:
            # Get the model
            dxo = self._model_locator.locate_model(name, fl_ctx)
            if not isinstance(dxo, DXO):
                self.system_panic(f"ModelLocator produced invalid data: expect DXO but got {type(dxo)}.", fl_ctx)
                return False

            # Save to workspace
            unique_name = "SRV_" + name
            unique_names.append(unique_name)
            try:
                save_path = self._save_dxo_content(unique_name, self._cross_val_models_dir, dxo, fl_ctx)
            except Exception:
                # Not a bare except: KeyboardInterrupt/SystemExit must still propagate.
                self.log_exception(fl_ctx, f"Unable to save shareable contents of server model {unique_name}")
                self.system_panic(f"Unable to save shareable contents of server model {unique_name}", fl_ctx)
                return False

            self._server_models[unique_name] = save_path
            self._val_results[unique_name] = {}

        if unique_names:
            self.log_info(fl_ctx, f"Server models loaded: {unique_names}.")
        else:
            self.log_info(fl_ctx, "no server models to validate!")
        return True

    def _accept_local_model(self, client_name: str, result: Shareable, fl_ctx: FLContext):
        """Handle a model submitted by a client.

        Fires ``RECEIVE_BEST_MODEL``; a result with a non-OK return code is logged
        and ignored, otherwise every leaf DXO in the result is saved and sent out
        for validation.

        Args:
            client_name (str): Name of the client that submitted the model.
            result (Shareable): The client's reply to the submit_model task.
            fl_ctx (FLContext): FL context of the current run.
        """
        fl_ctx.set_prop(AppConstants.RECEIVED_MODEL, result, private=True, sticky=False)
        fl_ctx.set_prop(AppConstants.RECEIVED_MODEL_OWNER, client_name, private=True, sticky=False)
        fl_ctx.set_prop(AppConstants.CROSS_VAL_DIR, self._cross_val_dir, private=True, sticky=False)
        self.fire_event(AppEventType.RECEIVE_BEST_MODEL, fl_ctx)

        # get return code
        rc = result.get_return_code()
        if rc and rc != ReturnCode.OK:
            # Raise errors if bad peer context or execution exception.
            if rc in [ReturnCode.MISSING_PEER_CONTEXT, ReturnCode.BAD_PEER_CONTEXT]:
                self.log_error(fl_ctx, "Peer context is bad or missing. No model submitted for this client.")
            elif rc in [ReturnCode.EXECUTION_EXCEPTION, ReturnCode.TASK_UNKNOWN]:
                self.log_error(
                    fl_ctx, "Execution Exception on client during model submission. No model submitted for this client."
                )
            # Ignore contribution if result invalid.
            # (TASK_UNKNOWN is already handled by the branch above.)
            elif rc in [
                ReturnCode.EXECUTION_RESULT_ERROR,
                ReturnCode.TASK_DATA_FILTER_ERROR,
                ReturnCode.TASK_RESULT_FILTER_ERROR,
            ]:
                self.log_error(fl_ctx, "Execution result is not a shareable. Model submission will be ignored.")
            else:
                self.log_error(fl_ctx, "Return code set. Model submission from client will be ignored.")
        else:
            # Save shareable in models directory.
            try:
                self.log_debug(fl_ctx, "Extracting DXO from shareable.")
                dxo = from_shareable(result)
            except ValueError as e:
                self.log_error(
                    fl_ctx,
                    f"Ignored bad result from {client_name}: {secure_format_exception(e)}",
                )
                return

            # The DXO could contain multiple sub-DXOs (e.g. received from a T2 system)
            leaf_dxos, errors = get_leaf_dxos(dxo, client_name)
            if errors:
                for err in errors:
                    self.log_error(fl_ctx, f"Bad result from {client_name}: {err}")
            for k, v in leaf_dxos.items():
                self._save_client_model(k, v, fl_ctx)

    def _save_client_model(self, model_name: str, dxo: DXO, fl_ctx: FLContext):
        """Save a client model to the models directory and send it out for validation."""
        save_path = self._save_dxo_content(model_name, self._cross_val_models_dir, dxo, fl_ctx)
        self.log_info(fl_ctx, f"Saved client model {model_name} to {save_path}")
        self._client_models[model_name] = save_path

        # Send a model to this client to validate
        self._send_validation_task(model_name, fl_ctx)

    def _send_validation_task(self, model_name: str, fl_ctx: FLContext):
        """Broadcast a validation task for ``model_name`` to all participating clients.

        The task is created with an empty Shareable; the model itself is attached
        by ``_before_send_validate_task_cb`` and cleared again by
        ``_after_send_validate_task_cb``.
        """
        self.log_info(fl_ctx, f"Sending {model_name} model to all participating clients for validation.")

        # Create validation task and broadcast to all participating clients.
        task = Task(
            name=self._validation_task_name,
            data=Shareable(),
            before_task_sent_cb=self._before_send_validate_task_cb,
            after_task_sent_cb=self._after_send_validate_task_cb,
            result_received_cb=self._receive_val_result_cb,
            timeout=self._validation_timeout,
            props={AppConstants.MODEL_OWNER: model_name},
        )

        self.broadcast(
            task=task,
            fl_ctx=fl_ctx,
            targets=self._participating_clients,
            min_responses=len(self._participating_clients),
            wait_time_after_min_received=0,
        )

    def _accept_val_result(self, client_name: str, result: Shareable, fl_ctx: FLContext):
        """Handle a validation result returned by a client.

        Fires ``VALIDATION_RESULT_RECEIVED``; a result with a non-OK return code is
        logged and recorded as an empty result, otherwise every leaf DXO in the
        result is saved to the results directory.

        Args:
            client_name (str): Name of the client that ran the validation.
            result (Shareable): The client's reply to the validation task.
            fl_ctx (FLContext): FL context of the current run.
        """
        model_owner = result.get_cookie(AppConstants.MODEL_OWNER, "")

        # Fire event. This needs to be a new local context per each client
        fl_ctx.set_prop(AppConstants.MODEL_OWNER, model_owner, private=True, sticky=False)
        fl_ctx.set_prop(AppConstants.DATA_CLIENT, client_name, private=True, sticky=False)
        fl_ctx.set_prop(AppConstants.VALIDATION_RESULT, result, private=True, sticky=False)
        self.fire_event(AppEventType.VALIDATION_RESULT_RECEIVED, fl_ctx)

        rc = result.get_return_code()
        if rc and rc != ReturnCode.OK:
            # Raise errors if bad peer context or execution exception.
            if rc in [ReturnCode.MISSING_PEER_CONTEXT, ReturnCode.BAD_PEER_CONTEXT]:
                self.log_error(fl_ctx, "Peer context is bad or missing.")
            elif rc in [ReturnCode.EXECUTION_EXCEPTION, ReturnCode.TASK_UNKNOWN]:
                self.log_error(fl_ctx, "Execution Exception in model validation.")
            elif rc in [
                ReturnCode.EXECUTION_RESULT_ERROR,
                ReturnCode.TASK_DATA_FILTER_ERROR,
                ReturnCode.TASK_RESULT_FILTER_ERROR,
            ]:
                self.log_error(fl_ctx, "Execution result is not a shareable. Validation results will be ignored.")
            else:
                self.log_error(
                    fl_ctx,
                    f"Client {client_name} sent results for validating {model_owner} model with return code set."
                    " Logging empty results.",
                )

            # Every failed validation is recorded as an empty result.
            if client_name not in self._val_results:
                self._val_results[client_name] = {}
            self._val_results[client_name][model_owner] = {}
        else:
            try:
                dxo = from_shareable(result)
            except ValueError as e:
                reason = (
                    f"Bad validation result from {client_name} on model {model_owner}. "
                    f"Exception: {secure_format_exception(e)}"
                )
                self.log_exception(fl_ctx, reason)
                return

            # The DXO could contain multiple sub-DXOs (e.g. received from a T2 system)
            leaf_dxos, errors = get_leaf_dxos(dxo, client_name)
            if errors:
                for err in errors:
                    self.log_error(fl_ctx, f"Bad result from {client_name}: {err}")
            for k, v in leaf_dxos.items():
                self._save_validation_result(k, model_owner, v, fl_ctx)

    def _save_validation_result(self, client_name: str, model_name: str, dxo, fl_ctx):
        """Save one validation result to disk and record its path in ``_val_results``.

        The file is named ``<client_name>_<model_name>`` inside the results directory.
        """
        file_name = client_name + "_" + model_name
        file_path = self._save_dxo_content(file_name, self._cross_val_results_dir, dxo, fl_ctx)

        client_results = self._val_results.get(client_name, None)
        if not client_results:
            client_results = {}
            self._val_results[client_name] = client_results
        client_results[model_name] = file_path
        self.log_info(
            fl_ctx, f"Saved validation result from client '{client_name}' on model '{model_name}' in {file_path}"
        )

    def _save_dxo_content(self, name: str, save_dir: str, dxo: DXO, fl_ctx: FLContext) -> str:
        """Saves a DXO to a file called ``name`` inside ``save_dir``.

        Args:
            name (str): File name to save the DXO under
            save_dir (str): Directory in which to save
            dxo (DXO): DXO object
            fl_ctx (FLContext): FLContext object

        Returns:
            str: Path to the file saved.

        Raises:
            ValueError: If the DXO cannot be written; the original error is chained as the cause.
        """
        # Save the model with name as the filename to shareable directory
        data_filename = os.path.join(save_dir, name)
        try:
            dxo.to_file(data_filename)
        except Exception as e:
            raise ValueError(f"Unable to save DXO to {data_filename}: {secure_format_exception(e)}") from e
        return data_filename

    def _load_validation_content(self, name: str, load_dir: str, fl_ctx: FLContext) -> Union[DXO, None]:
        """Loads the DXO stored in the file called ``name`` inside ``load_dir``.

        Args:
            name (str): File name the DXO was saved under
            load_dir (str): Directory to load from
            fl_ctx (FLContext): FLContext object

        Returns:
            The object returned by ``from_file`` for that path.

        Raises:
            ValueError: If loading fails; the original error is chained as the cause.
        """
        # Load shareable from disk
        shareable_filename = os.path.join(load_dir, name)

        # load shareable
        try:
            dxo: DXO = from_file(shareable_filename)
            self.log_debug(fl_ctx, f"Loading cross validation shareable content with name: {name}.")
        except Exception as e:
            raise ValueError(
                f"Exception in loading shareable content for {name}: {secure_format_exception(e)}"
            ) from e
        return dxo

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        """Contribute formatted validation results when stats are being collected.

        On ``InfoCollector.EVENT_TYPE_GET_STATS``, if a formatter is configured and
        a stats collector is present in ``fl_ctx``, the formatter output is added
        to the collector under this controller's name.

        Raises:
            TypeError: If the stats collector in ``fl_ctx`` is not a GroupInfoCollector.
        """
        super().handle_event(event_type=event_type, fl_ctx=fl_ctx)

        if event_type == InfoCollector.EVENT_TYPE_GET_STATS:
            if self._formatter:
                collector = fl_ctx.get_prop(InfoCollector.CTX_KEY_STATS_COLLECTOR, None)
                if collector:
                    if not isinstance(collector, GroupInfoCollector):
                        raise TypeError("collector must be GroupInfoCollector but got {}".format(type(collector)))

                    fl_ctx.set_prop(AppConstants.VALIDATION_RESULT, self._val_results, private=True, sticky=False)
                    val_info = self._formatter.format(fl_ctx)

                    collector.add_info(
                        group_name=self._name,
                        info={"val_results": val_info},
                    )
            else:
                self.log_warning(fl_ctx, "No formatter provided. Validation results can't be printed.")

    def process_result_of_unknown_task(
        self, client: Client, task_name: str, client_task_id: str, result: Shareable, fl_ctx: FLContext
    ):
        """Route a result whose task is no longer tracked to the matching accept handler.

        Results for the submit_model and validation task names are processed as
        usual; anything else is logged and ignored.
        """
        if task_name == self._submit_model_task_name:
            self._accept_local_model(client_name=client.name, result=result, fl_ctx=fl_ctx)
        elif task_name == self._validation_task_name:
            self._accept_val_result(client_name=client.name, result=result, fl_ctx=fl_ctx)
        else:
            self.log_error(fl_ctx, "Ignoring result from unknown task.")