Hello TensorFlow 2

Before You Start

We recommend you first finish either the Hello PyTorch or the Hello Scatter and Gather exercise.

Those guides go more in depth in explaining the federated learning aspect of NVIDIA FLARE.

Here we assume you have already installed NVIDIA FLARE inside a Python virtual environment and have already cloned the repo.

Introduction

Through this exercise, you will integrate NVIDIA FLARE with the popular deep learning framework TensorFlow 2 and learn how to use NVIDIA FLARE to train a small neural network with the MNIST dataset using the Scatter and Gather workflow. You will also be introduced to some new components and concepts, including filters, aggregators, and event handlers.

The setup of this exercise consists of one server and two clients.

The following steps compose one cycle of weight updates, called a round:

  1. Clients are responsible for generating individual weight-updates for the model using their own MNIST dataset.

  2. These updates are then sent to the server which will aggregate them to produce a model with new weights.

  3. Finally, the server sends this updated version of the model back to each client.
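The three steps of a round can be sketched in plain Python. This is only an illustration of the cycle, not NVIDIA FLARE code: the names local_update, aggregate and run_round are hypothetical, and "training" is replaced by a fixed shift per client so that the result is easy to follow.

```python
from typing import Dict, List

Weights = Dict[str, List[float]]


def local_update(global_weights: Weights, client_shift: float) -> Weights:
    """Stand-in for local training: shift every value by a client-specific amount."""
    return {name: [v + client_shift for v in values] for name, values in global_weights.items()}


def aggregate(updates: List[Weights]) -> Weights:
    """Server step: element-wise mean of the clients' weights."""
    n = len(updates)
    return {
        name: [sum(vals) / n for vals in zip(*(u[name] for u in updates))]
        for name in updates[0]
    }


def run_round(global_weights: Weights, client_shifts: List[float]) -> Weights:
    # 1. each client produces its own update from the current global model
    updates = [local_update(global_weights, shift) for shift in client_shifts]
    # 2. the server aggregates the updates into new weights
    new_global = aggregate(updates)
    # 3. the new weights are what every client receives for the next round
    return new_global


weights = {"dense": [0.0, 1.0]}
for _ in range(3):  # three rounds with two simulated clients
    weights = run_round(weights, client_shifts=[0.5, 1.5])
```

Each round moves the global weights by the mean of the two client shifts, which is the behavior a simple average over two clients would produce.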

For this exercise, we will be working with the hello-tf2 application in the examples folder. Custom FL applications can contain the folders:

  1. custom: contains the custom components (tf2_net.py, trainer.py, filter.py, tf2_model_persistor.py)

  2. config: contains client and server configurations (config_fed_client.json, config_fed_server.json)

  3. resources: contains the logger config (log.config)

Let’s get started. Since this task is using TensorFlow, let’s go ahead and install the library inside our virtual environment:

(nvflare-env) $ python3 -m pip install tensorflow

NVIDIA FLARE Client

Neural Network

With all the required dependencies installed, you are ready to run a Federated Learning system with two clients and one server.

Before you start, let’s see what a simplified MNIST network looks like.

tf2_net.py
import tensorflow as tf


class Net(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.flatten = tf.keras.layers.Flatten(input_shape=(28, 28))
        self.dense1 = tf.keras.layers.Dense(128, activation="relu")
        self.dropout = tf.keras.layers.Dropout(0.2)
        self.dense2 = tf.keras.layers.Dense(10)

    def call(self, x):
        x = self.flatten(x)
        x = self.dense1(x)
        x = self.dropout(x)
        x = self.dense2(x)
        return x

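This Net class is the neural network to train with the MNIST dataset. As the code shows, it is a small fully connected network (Flatten, Dense, Dropout, Dense) rather than a convolutional one. It is not specific to NVIDIA FLARE, so implement it in a file called tf2_net.py.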
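To get a feel for the size of this model, the number of trainable parameters can be worked out by hand. The sketch below assumes the layer sizes shown in tf2_net.py and uses the attribute names from that file as dictionary keys; dense_params is a hypothetical helper, not part of TensorFlow.

```python
def dense_params(n_in: int, n_out: int) -> int:
    """Trainable parameters of a fully connected layer: weights plus biases."""
    return n_in * n_out + n_out


flatten_out = 28 * 28  # Flatten turns a 28x28 image into 784 values

layer_params = {
    "flatten": 0,  # reshaping only, nothing to train
    "dense1": dense_params(flatten_out, 128),
    "dropout": 0,  # dropout has no trainable parameters
    "dense2": dense_params(128, 10),
}
total_params = sum(layer_params.values())
```

Only the two Dense layers carry weights, which is worth keeping in mind when reading the filter and persistor sections later on.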

Dataset & Setup

Now you have to implement the class SimpleTrainer, which is a subclass of Executor in NVIDIA FLARE, in a file called trainer.py.

Before you can start training, you need to set up your dataset. In this exercise, you can download it from the Internet via tf.keras’s datasets module, and split it in half to create a separate dataset for each client. Additionally, you must set up the optimizer, loss function and transform to process the data.

Since every step will be encapsulated in the SimpleTrainer class, let’s put this preparation stage into one method setup:

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        if event_type == EventType.START_RUN:
            self.setup(fl_ctx)

    def setup(self, fl_ctx: FLContext):
        (self.train_images, self.train_labels), (
            self.test_images,
            self.test_labels,
        ) = tf.keras.datasets.mnist.load_data()
        self.train_images, self.test_images = (
            self.train_images / 255.0,
            self.test_images / 255.0,
        )

        # simulate separate datasets for each client by dividing MNIST dataset in half
        client_name = fl_ctx.get_identity_name()
        if client_name == "site-1":
            self.train_images = self.train_images[: len(self.train_images) // 2]
            self.train_labels = self.train_labels[: len(self.train_labels) // 2]
            self.test_images = self.test_images[: len(self.test_images) // 2]
            self.test_labels = self.test_labels[: len(self.test_labels) // 2]
        elif client_name == "site-2":
            self.train_images = self.train_images[len(self.train_images) // 2 :]
            self.train_labels = self.train_labels[len(self.train_labels) // 2 :]
            self.test_images = self.test_images[len(self.test_images) // 2 :]
            self.test_labels = self.test_labels[len(self.test_labels) // 2 :]

        model = Net()

        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])
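The per-client split in setup can be isolated into a small helper. The function below is a hypothetical refactoring for illustration (split_for_client does not exist in the example), written against plain sequences so it runs without TensorFlow; the same slicing works on NumPy arrays.

```python
from typing import Sequence, TypeVar

T = TypeVar("T")


def split_for_client(data: Sequence[T], client_name: str) -> Sequence[T]:
    """Return the half of the data that belongs to the given client."""
    half = len(data) // 2
    if client_name == "site-1":
        return data[:half]  # first half
    if client_name == "site-2":
        return data[half:]  # second half (gets the extra item if the length is odd)
    return data  # any other name keeps the full data, as in setup


samples = list(range(5))
site_1_samples = split_for_client(samples, "site-1")
site_2_samples = split_for_client(samples, "site-2")
```

Note that the two halves never overlap, so each simulated client trains and evaluates on data the other one has not seen.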

How can you ensure this setup method is called before the client receives the model from the server?

The SimpleTrainer class is also an FLComponent, which receives an Event whenever NVIDIA FLARE enters or leaves a certain stage.

In this case, there is an Event called EventType.START_RUN which perfectly matches these requirements. Because our trainer is a subclass of FLComponent, you can implement the handle_event handler, shown above together with setup, to handle the event and call the setup method. The attributes involved are first initialized to None in the constructor:

        self.test_images, self.test_labels = None, None
        self.model = None
        self.var_list = None

Note

This is a new concept you haven’t learned in the previous two exercises.

The concepts of event and handler are very powerful because you are free to add your own logic so that it runs at different times and processes various events.

The entire list of events fired by NVIDIA FLARE is shown at Event types.
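The event and handler pattern can be sketched without NVIDIA FLARE at all. Everything in the block below is hypothetical and for illustration only: the event names are toy strings (the real ones live in EventType), ToyTrainer is not an FLComponent, and a plain dict stands in for FLContext.

```python
from typing import Any, Dict, List

# Toy event names for illustration only; NVIDIA FLARE defines its own in EventType.
START_RUN = "start_run"
END_RUN = "end_run"


class ToyTrainer:
    """Hypothetical component that prepares itself when the run starts."""

    def __init__(self) -> None:
        self.dataset: List[int] = []
        self.events_seen: List[str] = []

    def handle_event(self, event_type: str, context: Dict[str, Any]) -> None:
        # every event reaches the handler; the component decides which ones matter
        self.events_seen.append(event_type)
        if event_type == START_RUN:
            self.setup(context)

    def setup(self, context: Dict[str, Any]) -> None:
        # pretend to load data whose size is carried in the context
        self.dataset = list(range(context["num_samples"]))


def fire_event(event_type: str, components: List[ToyTrainer], context: Dict[str, Any]) -> None:
    """Deliver one event to every registered component."""
    for component in components:
        component.handle_event(event_type, context)


trainer = ToyTrainer()
fire_event(START_RUN, [trainer], {"num_samples": 4})
fire_event(END_RUN, [trainer], {"num_samples": 4})
```

Because setup runs in response to the start event, the component is ready before any task reaches it, which is the guarantee the trainer in this exercise relies on.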

You now have everything you need, so let’s implement the last method, execute, which is called every time the client receives an updated model from the server with the Task we will configure.

NVIDIA FLARE Server & Application

Filter

A filter can be used for additional processing of the data in the Shareable, for both inbound and outbound data on the client and/or the server.

For this exercise, we use a basic exclude_var filter to exclude the variable/layer flatten from the task result as it goes outbound from the client to the server. The excluded layer is replaced with all zeros of the same shape, which reduces the compressed size and ensures that the clients’ weights for this variable are not shared with the server.

filter.py
import re

import numpy as np

from nvflare.apis.dxo import DXO, DataKind, from_shareable
from nvflare.apis.filter import Filter
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable


class ExcludeVars(Filter):
    """
        Exclude/Remove variables from Shareable

    Args:
        exclude_vars: if not specified (None), nothing is excluded;
                      if list of variable/layer names, only specified variables are excluded;
                      if string containing regular expression (e.g. "conv"), only matched variables are being excluded.
    """

    def __init__(self, exclude_vars=None):
        super().__init__()
        self.exclude_vars = exclude_vars
        self.skip = False
        if self.exclude_vars is not None:
            if not (isinstance(self.exclude_vars, list) or isinstance(self.exclude_vars, str)):
                self.skip = True
                self.logger.debug("Need to provide a list of layer names or a string for regex matching")
                return

            if isinstance(self.exclude_vars, list):
                for var in self.exclude_vars:
                    if not isinstance(var, str):
                        self.skip = True
                        self.logger.debug("exclude_vars needs to be a list of variable/layer names to exclude.")
                        return
                self.logger.debug(f"Excluding {self.exclude_vars} from shareable")
            elif isinstance(self.exclude_vars, str):
                # an empty string means there is nothing to match, so the filter is skipped
                self.exclude_vars = re.compile(self.exclude_vars) if self.exclude_vars else None
                if self.exclude_vars is None:
                    self.skip = True
                self.logger.debug(f'Excluding all layers based on regex matches with "{self.exclude_vars}"')
        else:
            self.logger.debug("Not excluding anything")
            self.skip = True

    def process(self, shareable: Shareable, fl_ctx: FLContext) -> Shareable:

        self.log_debug(fl_ctx, "inside filter")
        if self.skip:
            return shareable

        try:
            dxo = from_shareable(shareable)
        except Exception:
            self.log_exception(fl_ctx, "shareable data is not a valid DXO")
            return shareable

        assert isinstance(dxo, DXO)
        if dxo.data_kind not in (DataKind.WEIGHT_DIFF, DataKind.WEIGHTS):
            self.log_debug(fl_ctx, "I cannot handle {}".format(dxo.data_kind))
            return shareable

        if dxo.data is None:
            self.log_debug(fl_ctx, "no data to filter")
            return shareable

        weights = dxo.data

        # resolve a regex into the list of matching variable names
        if isinstance(self.exclude_vars, re.Pattern):
            re_pattern = self.exclude_vars
            self.exclude_vars = []
            for var_name in weights.keys():
                if re_pattern.search(var_name):
                    self.exclude_vars.append(var_name)
            self.log_debug(fl_ctx, f"Regex found {self.exclude_vars} matching layers.")
            if len(self.exclude_vars) == 0:
                self.log_warning(fl_ctx, f"No matching layers found with regex {re_pattern}")

        # replace excluded variables with zeros of the same shape
        n_excluded = 0
        var_names = list(weights.keys())  # needs to recast to list to be used in for loop
        n_vars = len(var_names)
        for var_name in var_names:
            # self.logger.info(f"Checking {var_name}")
            if var_name in self.exclude_vars:
                self.log_debug(fl_ctx, f"Excluding {var_name}")
                weights[var_name] = np.zeros(weights[var_name].shape)
                n_excluded += 1
        self.log_debug(
            fl_ctx,
            f"Excluded {n_excluded} of {n_vars} variables. {len(weights.keys())} remaining.",
        )

        dxo.data = weights
        return dxo.update_shareable(shareable)

The filtering procedure occurs in the one required method, process, which receives and returns a shareable. The parameters for what is excluded and the inbound/outbound option are all set in config_fed_client.json (shown later below) and passed in through the constructor.
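The core of the filter, choosing variable names and zeroing them out, can be tried in isolation. The sketch below is a simplified stand-in, not the ExcludeVars class: select_excluded and zero_out are hypothetical helpers, the variable names are made up, and flat Python lists replace NumPy arrays so the block has no dependencies.

```python
import re
from typing import Dict, List, Optional, Union

Weights = Dict[str, List[float]]


def select_excluded(var_names: List[str], exclude_vars: Optional[Union[str, List[str]]]) -> List[str]:
    """Resolve exclude_vars (None, a list of names, or a regex string) to concrete names."""
    if not exclude_vars:  # None, empty string or empty list: exclude nothing
        return []
    if isinstance(exclude_vars, str):
        pattern = re.compile(exclude_vars)
        return [name for name in var_names if pattern.search(name)]
    return [name for name in var_names if name in exclude_vars]


def zero_out(weights: Weights, excluded: List[str]) -> Weights:
    """Replace excluded variables with zeros of the same length; keep the rest."""
    return {
        name: [0.0] * len(values) if name in excluded else list(values)
        for name, values in weights.items()
    }


weights = {"flatten": [1.0, 2.0], "dense": [3.0], "dense_1": [4.0]}
by_list = select_excluded(list(weights), ["flatten"])
by_regex = select_excluded(list(weights), "dense")
filtered = zero_out(weights, by_list)
```

Keeping every key in the result, with zeros in place of the excluded values, mirrors the filter above: the structure of the weights stays intact while the excluded content is not shared.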

Model Aggregator

The model aggregator is used by the server to aggregate the clients’ models into one model within the Scatter and Gather workflow.

In this exercise, we perform a simple average over the two clients’ weights with the InTimeAccumulateWeightedAggregator and configure it for use in config_fed_server.json (shown later below).
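The idea of accumulating contributions as they arrive and dividing once at the end can be sketched as follows. This is a simplified illustration under stated assumptions, not the implementation of InTimeAccumulateWeightedAggregator: the class RunningWeightedAverage and its methods are hypothetical, weights are flat lists, and only a per-site aggregation weight is taken into account.

```python
from typing import Dict, List

Weights = Dict[str, List[float]]


class RunningWeightedAverage:
    """Hypothetical aggregator: adds each contribution on arrival, averages on demand."""

    def __init__(self, aggregation_weights: Dict[str, float]) -> None:
        self.aggregation_weights = aggregation_weights
        self.totals: Weights = {}
        self.weight_sum = 0.0

    def accept(self, site: str, weights: Weights) -> None:
        # sites without an explicit entry count with weight 1.0 (an assumption of this sketch)
        w = self.aggregation_weights.get(site, 1.0)
        for name, values in weights.items():
            acc = self.totals.setdefault(name, [0.0] * len(values))
            for i, value in enumerate(values):
                acc[i] += w * value
        self.weight_sum += w

    def aggregate(self) -> Weights:
        # divide the accumulated sums by the total weight seen so far
        return {name: [t / self.weight_sum for t in totals] for name, totals in self.totals.items()}


equal = RunningWeightedAverage({"site-1": 1.0, "site-2": 1.0})
equal.accept("site-1", {"dense": [1.0, 2.0]})
equal.accept("site-2", {"dense": [3.0, 6.0]})
simple_average = equal.aggregate()

skewed = RunningWeightedAverage({"site-1": 3.0, "site-2": 1.0})
skewed.accept("site-1", {"dense": [1.0, 2.0]})
skewed.accept("site-2", {"dense": [3.0, 6.0]})
weighted_average = skewed.aggregate()
```

With both sites weighted 1.0, as in this exercise, the result reduces to the plain mean; unequal weights pull the result toward the more heavily weighted site.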

Model Persistor

The model persistor is used to load and save models on the server.

tf2_model_persistor.py
import json
import os

import tensorflow as tf
from tf2_net import Net

from nvflare.apis.event_type import EventType
from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_context import FLContext
from nvflare.app_common.abstract.model import ModelLearnable, make_model_learnable
from nvflare.app_common.abstract.model_persistor import ModelPersistor
from nvflare.app_common.app_constant import AppConstants
from nvflare.fuel.utils import fobs


class TF2ModelPersistor(ModelPersistor):
    def __init__(self, save_name="tf2_model.fobs"):
        super().__init__()
        self.save_name = save_name

    def _initialize(self, fl_ctx: FLContext):
        # get save path from FLContext
        app_root = fl_ctx.get_prop(FLContextKey.APP_ROOT)
        env = None
        run_args = fl_ctx.get_prop(FLContextKey.ARGS)
        if run_args:
            env_config_file_name = os.path.join(app_root, run_args.env)
            if os.path.exists(env_config_file_name):
                try:
                    with open(env_config_file_name) as file:
                        env = json.load(file)
                except Exception:
                    self.system_panic(
                        reason="error opening env config file {}".format(env_config_file_name), fl_ctx=fl_ctx
                    )
                    return

        if env is not None:
            if env.get("APP_CKPT_DIR", None):
                fl_ctx.set_prop(AppConstants.LOG_DIR, env["APP_CKPT_DIR"], private=True, sticky=True)
            if env.get("APP_CKPT") is not None:
                fl_ctx.set_prop(
                    AppConstants.CKPT_PRELOAD_PATH,
                    env["APP_CKPT"],
                    private=True,
                    sticky=True,
                )

        log_dir = fl_ctx.get_prop(AppConstants.LOG_DIR)
        if log_dir:
            self.log_dir = os.path.join(app_root, log_dir)
        else:
            self.log_dir = app_root
        self._fobs_save_path = os.path.join(self.log_dir, self.save_name)
        if not os.path.exists(self.log_dir):
            os.makedirs(self.log_dir)

        fl_ctx.sync_sticky()

    def load_model(self, fl_ctx: FLContext) -> ModelLearnable:
        """Initializes and loads the Model.

        Args:
            fl_ctx: FLContext

        Returns:
            Model object
        """

        if os.path.exists(self._fobs_save_path):
            self.logger.info("Loading server weights")
            with open(self._fobs_save_path, "rb") as f:
                model_learnable = fobs.load(f)
        else:
            self.logger.info("Initializing server model")
            network = Net()
            loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
            network.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])
            _ = network(tf.keras.Input(shape=(28, 28)))
            var_dict = {network.get_layer(index=key).name: value for key, value in enumerate(network.get_weights())}
            model_learnable = make_model_learnable(var_dict, dict())
        return model_learnable

    def handle_event(self, event: str, fl_ctx: FLContext):
        if event == EventType.START_RUN:
            self._initialize(fl_ctx)

    def save_model(self, model_learnable: ModelLearnable, fl_ctx: FLContext):
        """Saves model.

        Args:
            model_learnable: ModelLearnable object
            fl_ctx: FLContext
        """
        model_learnable_info = {k: str(type(v)) for k, v in model_learnable.items()}
        self.logger.info(f"Saving aggregated server weights: \n {model_learnable_info}")
        with open(self._fobs_save_path, "wb") as f:
            fobs.dump(model_learnable, f)

In this exercise, we simply serialize the model weights dictionary with NVIDIA FLARE’s fobs module (fobs.dump and fobs.load) and save it to a log directory determined in _initialize. The file is saved on the FL server and the weights file name is defined in config_fed_server.json. Depending on the frameworks and tools, the methods of saving the model may vary.
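The load-or-initialize behavior of the persistor can be illustrated with the standard library alone. In this sketch json stands in for the serializer used by the example, and load_or_init and save_weights are hypothetical helpers, so treat it as an outline of the pattern rather than the persistor itself.

```python
import json
import os
import tempfile
from typing import Callable, Dict, List

Weights = Dict[str, List[float]]


def load_or_init(path: str, init_fn: Callable[[], Weights]) -> Weights:
    """Load saved weights if the file exists, otherwise create fresh ones."""
    if os.path.exists(path):
        with open(path) as f:
            return json.load(f)
    return init_fn()


def save_weights(path: str, weights: Weights) -> None:
    """Write weights to disk, creating the target directory when needed."""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w") as f:
        json.dump(weights, f)


def fresh_weights() -> Weights:
    return {"dense": [0.0, 0.0]}


with tempfile.TemporaryDirectory() as app_root:
    save_path = os.path.join(app_root, "checkpoints", "weights.json")
    first_load = load_or_init(save_path, fresh_weights)   # nothing saved yet
    save_weights(save_path, {"dense": [1.0, 2.0]})         # e.g. after aggregation
    second_load = load_or_init(save_path, fresh_weights)  # picks up the saved file
```

The second load returns the saved values instead of a fresh model, which is how a server run can resume from previously aggregated weights.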

FLContext is used throughout these functions to provide various useful FL-related information. You can find more details in the documentation.

Application Configuration

Finally, inside the config folder there are two files, config_fed_client.json and config_fed_server.json.

config_fed_server.json
{
  "format_version": 2,
  "server": {
    "heart_beat_timeout": 600
  },
  "task_data_filters": [],
  "task_result_filters": [],
  "components": [
    {
      "id": "persistor",
      "path": "tf2_model_persistor.TF2ModelPersistor",
      "args": {
        "save_name": "tf2weights.fobs"
      }
    },
    {
      "id": "shareable_generator",
      "path": "nvflare.app_common.shareablegenerators.full_model_shareable_generator.FullModelShareableGenerator",
      "args": {}
    },
    {
      "id": "aggregator",
      "path": "nvflare.app_common.aggregators.intime_accumulate_model_aggregator.InTimeAccumulateWeightedAggregator",
      "args": {
        "expected_data_kind": "WEIGHTS",
        "aggregation_weights": {
          "site-1": 1.0,
          "site-2": 1.0
        }
      }
    }
  ],
  "workflows": [
    {
      "id": "scatter_gather_ctl",
      "path": "nvflare.app_common.workflows.scatter_and_gather.ScatterAndGather",
      "args": {
        "min_clients": 2,
        "num_rounds": 3,
        "start_round": 0,
        "wait_time_after_min_received": 10,
        "aggregator_id": "aggregator",
        "persistor_id": "persistor",
        "shareable_generator_id": "shareable_generator",
        "train_task_name": "train",
        "train_timeout": 0
      }
    }
  ]
}

Note how the ScatterAndGather workflow is configured to use the included aggregator InTimeAccumulateWeightedAggregator and shareable_generator FullModelShareableGenerator. The persistor is configured to use TF2ModelPersistor in the custom directory of this hello-tf2 app. Each component is referenced by its Python module path, and the workflow refers to the components through their ids.
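How the workflow arguments point at components by id can be checked with a few lines of Python. The sketch below parses an abbreviated copy of the server configuration; resolve_component_refs is a hypothetical helper, and the rule that any argument ending in "_id" is a component reference is an assumption made for this illustration, not a documented NVIDIA FLARE convention.

```python
import json
from typing import Any, Dict

# Abbreviated copy of config_fed_server.json (component args omitted).
SERVER_CONFIG = json.loads("""
{
  "components": [
    {"id": "persistor", "path": "tf2_model_persistor.TF2ModelPersistor"},
    {"id": "shareable_generator",
     "path": "nvflare.app_common.shareablegenerators.full_model_shareable_generator.FullModelShareableGenerator"},
    {"id": "aggregator",
     "path": "nvflare.app_common.aggregators.intime_accumulate_model_aggregator.InTimeAccumulateWeightedAggregator"}
  ],
  "workflows": [
    {"id": "scatter_gather_ctl",
     "path": "nvflare.app_common.workflows.scatter_and_gather.ScatterAndGather",
     "args": {"min_clients": 2, "num_rounds": 3,
              "aggregator_id": "aggregator",
              "persistor_id": "persistor",
              "shareable_generator_id": "shareable_generator"}}
  ]
}
""")


def resolve_component_refs(config: Dict[str, Any]) -> Dict[str, str]:
    """Map every '*_id' workflow argument to the module path of the component it names."""
    paths_by_id = {c["id"]: c["path"] for c in config["components"]}
    resolved = {}
    for workflow in config["workflows"]:
        for key, value in workflow["args"].items():
            if key.endswith("_id"):
                if value not in paths_by_id:
                    raise KeyError(f"{key} refers to unknown component id {value!r}")
                resolved[key] = paths_by_id[value]
    return resolved


component_refs = resolve_component_refs(SERVER_CONFIG)
```

A check like this catches a mistyped id before a job is submitted, since a workflow argument that names a missing component raises an error immediately.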

config_fed_client.json
{
  "format_version": 2,
  "executors": [
    {
      "tasks": [
        "train"
      ],
      "executor": {
        "path": "trainer.SimpleTrainer",
        "args": {
          "epochs_per_round": 2
        }
      }
    }
  ],
  "task_result_filters": [
    {
      "tasks": [
        "train"
      ],
      "filters": [
        {
          "path": "filter.ExcludeVars",
          "args": {
            "exclude_vars": [
              "flatten"
            ]
          }
        }
      ]
    }
  ],
  "task_data_filters": []
}

Here, executors is configured with the Trainer implementation SimpleTrainer. We also set up filter.ExcludeVars under task_result_filters and pass in ["flatten"] as the exclude_vars argument. Both are configured for “train”, the only Task that will be broadcast in the Scatter and Gather workflow.
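To see what the client configuration means for a given task, the following sketch looks up the executor and the filters that apply to it. plan_for_task is a hypothetical helper written for this illustration, and the dictionary is a direct transcription of config_fed_client.json above; it does not reflect how NVIDIA FLARE itself parses the file.

```python
from typing import Any, Dict, List

CLIENT_CONFIG: Dict[str, Any] = {
    "format_version": 2,
    "executors": [
        {
            "tasks": ["train"],
            "executor": {"path": "trainer.SimpleTrainer", "args": {"epochs_per_round": 2}},
        }
    ],
    "task_result_filters": [
        {
            "tasks": ["train"],
            "filters": [{"path": "filter.ExcludeVars", "args": {"exclude_vars": ["flatten"]}}],
        }
    ],
    "task_data_filters": [],
}


def filters_for(config: Dict[str, Any], section: str, task: str) -> List[str]:
    """Paths of all filters in the given section that are registered for the task."""
    return [
        f["path"]
        for entry in config.get(section, [])
        if task in entry["tasks"]
        for f in entry["filters"]
    ]


def plan_for_task(config: Dict[str, Any], task: str) -> Dict[str, Any]:
    """Summarize which executor and filters handle a task on the client."""
    executor = next(
        (e["executor"]["path"] for e in config["executors"] if task in e["tasks"]),
        None,  # no executor registered for this task
    )
    return {
        "executor": executor,
        "task_data_filters": filters_for(config, "task_data_filters", task),
        "task_result_filters": filters_for(config, "task_result_filters", task),
    }


train_plan = plan_for_task(CLIENT_CONFIG, "train")
```

For “train”, the plan shows one executor, no inbound filter and one outbound filter, matching the description above.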

Train the Model, Federated!

Now you can use the admin command prompt to submit and start this example job. To do this on a proof of concept local FL system, follow the sections Setting Up the Application Environment in POC Mode and Starting the Application Environment in POC Mode if you have not already.

Running the FL System

With the admin client command prompt successfully connected and logged in, enter the command below.

> submit_job hello-tf2

Pay close attention to what happens in each of the four terminals. You can see how the admin submits the job to the server and how the JobRunner on the server automatically picks up the job to deploy and start the run.

This command uploads the job configuration from the admin client to the server. A job id will be returned, and we can use that id to access job information.

Note

If we use submit_job [app] then that app will be treated as a single app job.

From time to time, you can issue check_status server in the admin client to check the overall training progress.

You should now see how the training is progressing in the very first terminal (the one that started the server).

Accessing the results

The results of each job are usually stored in the server-side workspace.

Please refer to access server-side workspace for accessing the server side workspace.

Shutdown FL system

Once the FL run is complete, the server has successfully aggregated the clients’ results after all the rounds, and cross site model evaluation is finished, run the following commands in the fl_admin to shut down the system (entering admin when prompted for the password):

> shutdown client
> shutdown server
> bye

Congratulations!

You’ve successfully built and run a federated learning system using TensorFlow 2.

The full source code for this exercise can be found in examples/hello-tf2.
