Hello Cross-Site Validation

Before You Start

Before jumping into this guide, make sure you have an environment with NVIDIA FLARE installed.

Follow Getting Started for the general concepts of setting up a Python virtual environment (the recommended environment) and for instructions on installing NVIDIA FLARE.

Prerequisite

This example builds on the Hello Scatter and Gather example based on the ScatterAndGather workflow.

Please make sure you go through it completely as the concepts are heavily tied.

Introduction

This tutorial is meant solely to demonstrate how the NVIDIA FLARE system works, without introducing any actual deep learning concepts.

Through this exercise, you will learn how to use NVIDIA FLARE with numpy to perform cross site validation after training.

The training process is explained in the Hello Scatter and Gather example.

Using simplified weights and metrics, you will be able to clearly see how NVIDIA FLARE performs validation across different sites with little extra work.

The setup of this exercise consists of one server and two clients. The server-side model starts with the weights [[1, 2, 3], [4, 5, 6], [7, 8, 9]].

Cross site validation consists of the following steps:

  • During the initial phase of training with the ScatterAndGather workflow, NPTrainer saves the local model to disk for the clients.

  • The CrossSiteModelEval workflow gets the client models with the submit_model task.

  • The validate task is broadcast to all participating clients with the model shareable containing the model data, and the results from the validate task are saved.
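
The three steps above can be sketched in plain numpy, without any NVIDIA FLARE APIs. This is only an illustrative sketch: the function names collect_models, dummy_validate, and cross_site_validate, the site names, and the local models are hypothetical and are not part of NVIDIA FLARE. The metric mirrors the dummy validation used later in this example, without the random term, and for simplicity each site here scores every collected model, including its own.

```python
import numpy as np


def collect_models(local_models, server_model=None):
    """Step 2 (sketch): gather every site's model, plus the server model if present."""
    models = dict(local_models)
    if server_model is not None:
        models["server"] = server_model
    return models


def dummy_validate(weights):
    """Stand-in metric: sum of the weights divided by their maximum."""
    return float(np.sum(weights / np.max(weights)))


def cross_site_validate(validators, models):
    """Step 3 (sketch): every validating site scores every collected model."""
    return {
        site: {owner: dummy_validate(weights) for owner, weights in models.items()}
        for site in validators
    }


# Step 1 (sketch): hypothetical local models each client saved during training.
base = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=float)
local_models = {"site-1": base + 1, "site-2": base + 2}

models = collect_models(local_models, server_model=base)
results = cross_site_validate(["site-1", "site-2"], models)
for site, scores in results.items():
    print(site, scores)
```

The nested dictionary is keyed first by the validating site and then by the model owner, which is one convenient way to picture the results that accumulate during cross site validation.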

During this exercise, we will see how NVIDIA FLARE takes care of most of the above steps with little work from the user. We will be working with the hello-numpy-cross-val application in the examples folder. Custom FL applications can contain the folders:

  1. custom: contains the custom components (np_trainer.py, np_model_persistor.py, np_validator.py, np_model_locator, np_formatter)

  2. config: contains client and server configurations (config_fed_client.json, config_fed_server.json)

  3. resources: contains the logger config (log_config.json)

Let’s get started. First clone the repo, if you haven’t already:

$ git clone https://github.com/NVIDIA/NVFlare.git

Remember to activate your NVIDIA FLARE Python virtual environment from the installation guide. Ensure numpy is installed.

(nvflare-env) $ python3 -m pip install numpy

Now that you have all your dependencies installed, let’s implement the Federated Learning system.

Training

In the Hello Scatter and Gather example, we implemented the NPTrainer object. In this example, we use the same NPTrainer but extend it to process the submit_model task to work with the CrossSiteModelEval workflow to get the client models.

The code in np_trainer.py saves the model to disk after each training step.
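
A minimal sketch of that save step, using only numpy, might look like the following. The helper names save_local_model and load_local_model, the file name local_model.npy, and the directory layout are assumptions made for illustration; the actual np_trainer.py may organize this differently.

```python
import os
import tempfile

import numpy as np


def save_local_model(model_dir, weights):
    """Write the current weights to disk so they can be submitted later."""
    os.makedirs(model_dir, exist_ok=True)
    path = os.path.join(model_dir, "local_model.npy")  # hypothetical file name
    np.save(path, weights)
    return path


def load_local_model(model_dir):
    """Read the saved weights back, e.g. when a submit_model task arrives."""
    return np.load(os.path.join(model_dir, "local_model.npy"))


weights = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)
with tempfile.TemporaryDirectory() as run_dir:
    model_dir = os.path.join(run_dir, "model")
    for step in range(3):
        weights = weights + 1  # stand-in for one step of "training"
        save_local_model(model_dir, weights)  # overwrite with the latest weights
    restored = load_local_model(model_dir)

print(restored)
```

Because the file is overwritten after every step, whatever is on disk when the submit_model task arrives is the most recent local model.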

Note that the server also produces a global model. The CrossSiteModelEval workflow submits the server model for evaluation after the client models.

Implementing the Validator

The validator is an Executor that is called for validating the models received from the server during the CrossSiteModelEval workflow.

These models could be from other clients or models generated on the server.

np_validator.py
import time

import numpy as np

from nvflare.apis.dxo import DXO, DataKind, from_shareable
from nvflare.apis.executor import Executor
from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.app_common.app_constant import AppConstants
from nvflare.fuel.utils.log_utils import get_obj_logger
from nvflare.security.logging import secure_format_exception

from .constants import NPConstants


class NPValidator(Executor):
    def __init__(
        self,
        epsilon=1,
        sleep_time=0,
        validate_task_name=AppConstants.TASK_VALIDATION,
    ):
        # Init functions of components should be very minimal. Init
        # is called when json is read. A big init will cause json loading to halt
        # for long time.
        super().__init__()

        self.logger = get_obj_logger(self)
        self._random_epsilon = epsilon
        self._sleep_time = sleep_time
        self._validate_task_name = validate_task_name

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        # if event_type == EventType.START_RUN:
        #     Create all major components here. This is a simple app that doesn't need any components.
        # elif event_type == EventType.END_RUN:
        #     # Clean up resources (closing files, joining threads, removing dirs etc.)
        pass

    def execute(
        self,
        task_name: str,
        shareable: Shareable,
        fl_ctx: FLContext,
        abort_signal: Signal,
    ) -> Shareable:
        # Any long tasks should check abort_signal regularly.
        # Otherwise, abort client will not work.
        count, interval = 0, 0.5
        while count < self._sleep_time:
            if abort_signal.triggered:
                return make_reply(ReturnCode.TASK_ABORTED)
            time.sleep(interval)
            count += interval

        if task_name == self._validate_task_name:
            try:
                # First we extract DXO from the shareable.
                try:
                    model_dxo = from_shareable(shareable)
                except Exception as e:
                    self.log_error(
                        fl_ctx, f"Unable to extract model dxo from shareable. Exception: {secure_format_exception(e)}"
                    )
                    return make_reply(ReturnCode.BAD_TASK_DATA)

                # Get model from shareable. data_kind must be WEIGHTS.
                if model_dxo.data and model_dxo.data_kind == DataKind.WEIGHTS:
                    model = model_dxo.data
                else:
                    self.log_error(
                        fl_ctx, "Model DXO doesn't have data or is not of type DataKind.WEIGHTS. Unable to validate."
                    )
                    return make_reply(ReturnCode.BAD_TASK_DATA)

                # Check if key exists in model
                if NPConstants.NUMPY_KEY not in model:
                    self.log_error(fl_ctx, "numpy_key not in model. Unable to validate.")
                    return make_reply(ReturnCode.BAD_TASK_DATA)

                # The workflow provides MODEL_OWNER information in the shareable header.
                model_name = shareable.get_header(AppConstants.MODEL_OWNER, "?")

                # Print properties.
                self.log_info(fl_ctx, f"Model: \n{model}")
                self.log_info(fl_ctx, f"Task name: {task_name}")
                self.log_info(fl_ctx, f"Client identity: {fl_ctx.get_identity_name()}")
                self.log_info(fl_ctx, f"Validating model from {model_name}.")

                # Check abort signal regularly.
                if abort_signal.triggered:
                    return make_reply(ReturnCode.TASK_ABORTED)

                # Do some dummy validation.
                random_epsilon = np.random.random()
                self.log_info(fl_ctx, f"Adding random epsilon {random_epsilon} in validation.")
                val_results = {}
                np_data = model[NPConstants.NUMPY_KEY]
                np_data = np.sum(np_data / np.max(np_data))
                val_results["accuracy"] = np_data + random_epsilon

                # Check abort signal regularly.
                if abort_signal.triggered:
                    return make_reply(ReturnCode.TASK_ABORTED)

                self.log_info(fl_ctx, f"Validation result: {val_results}")

                # Create DXO for metrics and return shareable.
                metric_dxo = DXO(data_kind=DataKind.METRICS, data=val_results)
                return metric_dxo.to_shareable()
            except Exception as e:
                self.log_exception(fl_ctx, f"Exception in NPValidator execute: {secure_format_exception(e)}.")
                return make_reply(ReturnCode.EXECUTION_EXCEPTION)
        else:
            return make_reply(ReturnCode.TASK_UNKNOWN)

The validator is an Executor and implements the execute function which receives a Shareable.

It handles the validate task by computing the sum of the data divided by its maximum, adding a random epsilon, and returning the result packaged with a DXO into a Shareable.
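
As a worked example of that calculation, the sketch below applies it to the server's starting weights from the introduction. The helper name dummy_accuracy is hypothetical, and the random term is passed in explicitly here so the base value is reproducible; the validator itself draws it with np.random.random().

```python
import numpy as np


def dummy_accuracy(weights, epsilon=0.0):
    """Sum of the weights divided by their maximum, plus an epsilon term."""
    return float(np.sum(weights / np.max(weights)) + epsilon)


weights = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=float)

# The entries sum to 45 and the maximum is 9, so the base value is about 45 / 9 = 5.
base_metric = dummy_accuracy(weights)

# np.random.random() returns a value in [0, 1), so the reported "accuracy"
# for these weights falls in roughly [5, 6).
noisy_metric = dummy_accuracy(weights, epsilon=np.random.random())
print(base_metric, noisy_metric)
```

The value is labeled "accuracy" only for demonstration; it is not bounded by 1 and does not measure model quality.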

Note

In our hello-examples, we demonstrate Federated Learning using data that has nothing to do with deep learning. NVIDIA FLARE can be used with any data packaged inside a Shareable object (a subclass of dict), and DXO is the recommended way to manage that data in a standard format.

Cross site validation!

We can run it using the NVFlare simulator:

python3 job_train_and_cse.py

During the first phase, the model will be trained.

During the second phase, cross site validation will happen.

The workflow on the client will change to CrossSiteModelEval as it enters this second phase.

During cross site model evaluation, every client validates other clients’ models and server models (if present). This can produce a lot of results. All the results will be kept in the job’s workspace when it is completed.

Understanding the Output

You can find the running logs and results inside the simulator’s workspace:

ls /tmp/nvflare/jobs/workdir/
server/  site-1/  site-2/  startup/

The cross site validation results:

Congratulations!

You’ve successfully run your numpy federated learning system with cross site validation.

The full source code for this exercise can be found in examples/hello-world/hello-numpy-cross-val.
