Quickstart (TensorFlow 2)

Before You Start

We recommend that you first finish either the Quickstart (PyTorch) or the Quickstart (Numpy) exercise. Those guides explain the federated learning aspects of NVIDIA FLARE in more depth.

Here we assume you have already installed NVIDIA FLARE inside a Python virtual environment and have already cloned the repo.

Introduction

Through this exercise, you will integrate NVIDIA FLARE with the popular deep learning framework TensorFlow 2 and learn how to use NVIDIA FLARE to train a simple neural network on the MNIST dataset using the Scatter and Gather workflow. You will also be introduced to some new components and concepts, including filters, aggregators, and event handlers.

The design of this exercise consists of one server and two clients all having the same TensorFlow 2 model. The following steps compose one cycle of weight updates, called a round:

  1. Clients are responsible for generating individual weight updates for the model using their own portion of the MNIST dataset.

  2. These updates are then sent to the server, which aggregates them to produce a model with new weights.

  3. Finally, the server sends this updated version of the model back to each client.
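The round described above can be sketched without any framework. This is a minimal illustration under stated assumptions, not NVIDIA FLARE code: weights are plain dicts of float lists, the hypothetical make_client stands in for local training, and the server uses an unweighted element-wise mean.

```python
from typing import Callable, Dict, List

Weights = Dict[str, List[float]]
Client = Callable[[Weights], Weights]


def average(updates: List[Weights]) -> Weights:
    """Element-wise mean of weight dicts that share the same keys and lengths."""
    n = len(updates)
    return {
        name: [sum(values) / n for values in zip(*(u[name] for u in updates))]
        for name in updates[0]
    }


def run_round(global_weights: Weights, clients: List[Client]) -> Weights:
    # Step 1: each client trains on its own data, starting from a copy of the global weights.
    updates = [client({k: list(v) for k, v in global_weights.items()}) for client in clients]
    # Step 2: the server aggregates the updates into a model with new weights.
    new_weights = average(updates)
    # Step 3: the new weights are what every client starts from in the next round.
    return new_weights


def make_client(step: float) -> Client:
    # Hypothetical "training": shift every weight by a fixed, client-specific step.
    def local_update(weights: Weights) -> Weights:
        return {k: [x + step for x in v] for k, v in weights.items()}

    return local_update


weights = {"dense": [0.0, 1.0]}
for _ in range(3):
    weights = run_round(weights, [make_client(1.0), make_client(-0.5)])
print(weights)  # {'dense': [0.75, 1.75]}
```

Each round moves the global weights by the mean of the two client steps (+0.25), which is why three rounds end at [0.75, 1.75].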

For this exercise, we will be working with the hello-tf2 application in the examples folder. A custom FL application can contain the following folders:

  1. custom: contains the custom components (tf2_net.py, trainer.py, filter.py, tf2_model_persistor.py)

  2. config: contains client and server configurations (config_fed_client.json, config_fed_server.json)

  3. resources: contains the logger config (log.config)

Let’s get started. Since this exercise uses TensorFlow, let’s go ahead and install the library inside our virtual environment:

(nvflare-env) $ python3 -m pip install tensorflow

NVIDIA FLARE Client

Neural Network

With all the required dependencies installed, you are ready to run a Federated Learning system with two clients and one server. Before you start, let’s see what a simplified MNIST network looks like.

tf2_net.py
import tensorflow as tf


class Net(tf.keras.Model):
    def __init__(self):
        super().__init__()
        # 28x28 grayscale image -> 784-element vector
        self.flatten = tf.keras.layers.Flatten(input_shape=(28, 28))
        self.dense1 = tf.keras.layers.Dense(128, activation="relu")
        self.dropout = tf.keras.layers.Dropout(0.2)
        # one raw score (logit) per digit class; no softmax, since the loss uses from_logits=True
        self.dense2 = tf.keras.layers.Dense(10)

    def call(self, x):
        x = self.flatten(x)
        x = self.dense1(x)
        x = self.dropout(x)
        x = self.dense2(x)
        return x

This Net class is a small fully connected neural network to train on the MNIST dataset. It is not related to NVIDIA FLARE, so implement it in its own file called tf2_net.py.
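Since Flatten and Dropout have no weights, model.get_weights() for this Net is expected to return four arrays: a kernel and a bias for each Dense layer, in the order the layers are defined. The short calculation below derives their shapes and the parameter count from the layer sizes alone; it is plain arithmetic and does not call TensorFlow.

```python
import math


def dense_shapes(inputs: int, units: int):
    # A Dense layer holds a kernel of shape (inputs, units) and a bias of shape (units,).
    return [(inputs, units), (units,)]


# Flatten turns each 28x28 image into 784 inputs; dense1 has 128 units, dense2 has 10.
shapes = dense_shapes(28 * 28, 128) + dense_shapes(128, 10)
counts = [math.prod(shape) for shape in shapes]
print(shapes)               # [(784, 128), (128,), (128, 10), (10,)]
print(counts, sum(counts))  # [100352, 128, 1280, 10] 101770
```

Almost all of the roughly 100K parameters sit in the first Dense kernel, which is what the clients exchange with the server each round.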

Dataset & Setup

Now you have to implement the class SimpleTrainer, which is a subclass of Executor in NVIDIA FLARE, in a file called trainer.py.

Before you can start training, you need to set up your dataset. In this exercise, you download it from the Internet via the tf.keras datasets module and split it in half to create a separate dataset for each client. You also need to build and compile the model with its optimizer and loss function, and scale the pixel values before training.

Since every step is encapsulated in the SimpleTrainer class, put this preparation stage into a single method, setup:

    def setup(self, fl_ctx: FLContext):
        # load MNIST and scale pixel values from [0, 255] to [0, 1]
        (self.train_images, self.train_labels), (
            self.test_images,
            self.test_labels,
        ) = tf.keras.datasets.mnist.load_data()
        self.train_images, self.test_images = (
            self.train_images / 255.0,
            self.test_images / 255.0,
        )

        # simulate separate datasets for each client by dividing MNIST dataset in half
        client_name = fl_ctx.get_identity_name()
        if client_name == "site-1":
            self.train_images = self.train_images[: len(self.train_images) // 2]
            self.train_labels = self.train_labels[: len(self.train_labels) // 2]
            self.test_images = self.test_images[: len(self.test_images) // 2]
            self.test_labels = self.test_labels[: len(self.test_labels) // 2]
        elif client_name == "site-2":
            self.train_images = self.train_images[len(self.train_images) // 2 :]
            self.train_labels = self.train_labels[len(self.train_labels) // 2 :]
            self.test_images = self.test_images[len(self.test_images) // 2 :]
            self.test_labels = self.test_labels[len(self.test_labels) // 2 :]

        model = Net()

        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])
        # call the model once on a symbolic input so that its weights are created
        _ = model(tf.keras.Input(shape=(28, 28)))
        # name each weight array after a layer, by position; this relies on the model
        # having as many layers as weight arrays (four of each here)
        self.var_list = [model.get_layer(index=index).name for index in range(len(model.get_weights()))]
        self.model = model
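The halving logic in setup can be checked in isolation. The helper below is hypothetical (it is not part of the hello-tf2 app) and works on any Python sequence, so index ranges stand in for the 60,000 MNIST training images and 10,000 test images.

```python
from typing import Sequence, TypeVar

T = TypeVar("T")


def client_split(data: Sequence[T], client_name: str) -> Sequence[T]:
    # Mirrors setup(): "site-1" keeps the first half, "site-2" the second half,
    # and any other name keeps everything, because neither branch matches it.
    half = len(data) // 2
    if client_name == "site-1":
        return data[:half]
    if client_name == "site-2":
        return data[half:]
    return data


def normalize(pixels: Sequence[int]) -> list:
    # Same scaling as setup(): map 8-bit pixel values from [0, 255] to [0.0, 1.0].
    return [p / 255.0 for p in pixels]


train = range(60_000)
test = range(10_000)
site1_train = client_split(train, "site-1")
site2_train = client_split(train, "site-2")
print(len(site1_train), len(site2_train))  # 30000 30000
print(len(client_split(test, "site-1")), len(client_split(test, "site-2")))  # 5000 5000
print(normalize([0, 51, 255]))  # [0.0, 0.2, 1.0]
```

The two halves are disjoint and together cover the full dataset; with an odd length, "site-2" gets the extra element.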

How can you ensure that this setup method is called before the client receives the model from the server? The SimpleTrainer class is also an FLComponent, which receives an event whenever NVIDIA FLARE enters or leaves a certain stage. In this case, the event EventType.START_RUN matches this requirement exactly. Because the trainer is a subclass of FLComponent, you can implement its handle_event method to handle that event and call setup:

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        if event_type == EventType.START_RUN:
            self.setup(fl_ctx)

Note

This is a new concept you haven’t learned in the previous two exercises. Events and handlers are very powerful because they let you add your own logic that runs at different times and reacts to various events. The entire list of events fired by NVIDIA FLARE is shown at Event types.
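The event/handler pattern itself can be shown in a few lines of plain Python. This is a hypothetical imitation for illustration only: Component, Engine, and the event strings are invented for this sketch and are not NVIDIA FLARE APIs, although the idea (an engine fires events and every component's handle_event decides whether to react) is the one described above.

```python
from typing import Iterable, List

# Placeholder event names chosen for this sketch; FLARE defines its own EventType constants.
START_RUN = "start_run"
END_RUN = "end_run"


class Component:
    def handle_event(self, event_type: str, ctx: dict) -> None:
        """Default handler: ignore every event."""


class Trainer(Component):
    def __init__(self) -> None:
        self.ready = False
        self.seen: List[str] = []

    def setup(self, ctx: dict) -> None:
        # Stand-in for loading data and building the model.
        self.ready = True

    def handle_event(self, event_type: str, ctx: dict) -> None:
        self.seen.append(event_type)
        if event_type == START_RUN:
            self.setup(ctx)


class Engine:
    def __init__(self, components: Iterable[Component]) -> None:
        self.components = list(components)

    def fire_event(self, event_type: str, ctx: dict) -> None:
        # Every registered component sees every event and decides what to do with it.
        for component in self.components:
            component.handle_event(event_type, ctx)


trainer = Trainer()
engine = Engine([trainer, Component()])
engine.fire_event(START_RUN, {})
print(trainer.ready, trainer.seen)  # True ['start_run']
```

Components that do not care about an event simply ignore it, which is why the same dispatch loop can serve trainers, persistors, and other components.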

You now have everything you need. The last method to implement is execute, which is called every time the client receives an updated model from the server with the Task we will configure.

NVIDIA FLARE Server & Application

Filter

A filter can be used for additional processing of the data in a Shareable, for both inbound and outbound data on the client and/or server.

For this exercise, we use a basic ExcludeVars filter to exclude the variable/layer flatten from the task result as it goes outbound from the client to the server. The excluded variable is replaced with all zeros of the same shape, which makes the data compress well and ensures that the clients’ weights for this variable are not shared with the server.

filter.py
import re

import numpy as np
from nvflare.apis.dxo import DXO, DataKind, from_shareable
from nvflare.apis.filter import Filter
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable


class ExcludeVars(Filter):
    """
    Exclude/Remove variables from Shareable

    Args:
        exclude_vars: if not specified (None), no variables are excluded;
                      if list of variable/layer names, only specified variables are excluded;
                      if string containing regular expression (e.g. "conv"), only matching variables are excluded.
    """

    def __init__(self, exclude_vars=None):
        super().__init__()
        self.exclude_vars = exclude_vars
        self.skip = False
        if self.exclude_vars is not None:
            if not (
                isinstance(self.exclude_vars, list)
                or isinstance(self.exclude_vars, str)
            ):
                self.skip = True
                self.logger.debug(
                    "Need to provide a list of layer names or a string for regex matching"
                )
                return

            if isinstance(self.exclude_vars, list):
                for var in self.exclude_vars:
                    if not isinstance(var, str):
                        self.skip = True
                        self.logger.debug(
                            "exclude_vars needs to be a list of variable names to exclude."
                        )
                        return
                self.logger.debug(f"Excluding {self.exclude_vars} from shareable")
            elif isinstance(self.exclude_vars, str):
                self.exclude_vars = (
                    re.compile(self.exclude_vars) if self.exclude_vars else None
                )
                if self.exclude_vars is None:
                    self.skip = True
                self.logger.debug(
                    f'Excluding all layers based on regex matches with "{self.exclude_vars}"'
                )
        else:
            self.logger.debug("Not excluding anything")
            self.skip = True

    def process(self, shareable: Shareable, fl_ctx: FLContext) -> Shareable:

        self.log_debug(fl_ctx, "inside filter")
        if self.skip:
            return shareable

        try:
            dxo = from_shareable(shareable)
        except Exception:
            self.log_exception(fl_ctx, "shareable data is not a valid DXO")
            return shareable

        assert isinstance(dxo, DXO)
        if dxo.data_kind not in (DataKind.WEIGHT_DIFF, DataKind.WEIGHTS):
            self.log_debug(fl_ctx, "I cannot handle {}".format(dxo.data_kind))
            return shareable

        if dxo.data is None:
            self.log_debug(fl_ctx, "no data to filter")
            return shareable

        weights = dxo.data

        # on the first call, resolve a regex into the list of matching variable names
        if isinstance(self.exclude_vars, re.Pattern):
            re_pattern = self.exclude_vars
            self.exclude_vars = []
            for var_name in weights.keys():
                if re_pattern.search(var_name):
                    self.exclude_vars.append(var_name)
            self.log_debug(fl_ctx, f"Regex found {self.exclude_vars} matching layers.")
            if len(self.exclude_vars) == 0:
                self.log_warning(
                    fl_ctx, f"No matching layers found with regex {re_pattern}"
                )

        # replace excluded variables with zeros of the same shape and dtype
        n_excluded = 0
        var_names = list(weights.keys())  # copy the keys so the dict can be updated in the loop
        n_vars = len(var_names)
        for var_name in var_names:
            if var_name in self.exclude_vars:
                self.log_debug(fl_ctx, f"Excluding {var_name}")
                weights[var_name] = np.zeros_like(weights[var_name])
                n_excluded += 1
        self.log_debug(
            fl_ctx,
            f"Zeroed out {n_excluded} of {n_vars} variables.",
        )

        dxo.data = weights
        return dxo.update_shareable(shareable)

The filtering happens in the one required method, process, which receives and returns a Shareable. The variables to exclude are set in config_fed_client.json (shown below) and passed in through the constructor; whether the filter applies to inbound or outbound data is determined by the configuration section it is listed under (task_data_filters or task_result_filters).
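The core of ExcludeVars, selecting variables by exact name or by regular expression and replacing them with zeros, can be sketched without NVIDIA FLARE. This is an illustration, not the filter itself: plain lists stand in for NumPy arrays, no Shareable or DXO is involved, and the function name exclude_vars is hypothetical.

```python
import re
from typing import Dict, List, Union

Weights = Dict[str, List[float]]


def exclude_vars(weights: Weights, exclude: Union[List[str], str]) -> Weights:
    """Return a copy of weights with the selected variables replaced by zeros.

    exclude is either a list of exact variable names or a regular-expression
    string; in the regex case a name is selected if the pattern matches
    anywhere in it (re.search), as in ExcludeVars.process().
    """
    if isinstance(exclude, str):
        pattern = re.compile(exclude)
        selected = {name for name in weights if pattern.search(name)}
    else:
        selected = set(exclude)
    # Zeros keep the original length, so the receiver still sees every variable.
    return {
        name: [0.0] * len(values) if name in selected else list(values)
        for name, values in weights.items()
    }


weights = {"flatten": [0.3, -1.2], "dense": [0.5], "dense_1": [2.0]}
print(exclude_vars(weights, ["flatten"]))  # {'flatten': [0.0, 0.0], 'dense': [0.5], 'dense_1': [2.0]}
print(exclude_vars(weights, "dense"))      # {'flatten': [0.3, -1.2], 'dense': [0.0], 'dense_1': [0.0]}
```

Unlike the real filter, which updates the weights dict in place, this sketch returns a new dict so the input stays untouched.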

Model Aggregator

The model aggregator is used by the server to aggregate the clients’ models into one model within the Scatter and Gather workflow.

In this exercise, we perform a simple average over the two clients’ weights with the AccumulateWeightedAggregator, which is configured for use in config_fed_server.json (shown below).
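An accumulate-then-divide weighted average can be sketched as follows. This is an assumption-laden illustration, not AccumulateWeightedAggregator's implementation: each client contributes a plain dict of float lists plus a scalar weight (for example a sample count; which quantity to use is an assumption of this sketch). With equal weights it reduces to the simple average used in this exercise.

```python
from typing import Dict, List, Tuple

Weights = Dict[str, List[float]]


def weighted_average(contributions: List[Tuple[Weights, float]]) -> Weights:
    """Accumulate weight * value for each client, then divide by the total weight."""
    total = sum(w for _, w in contributions)
    if total <= 0:
        raise ValueError("total aggregation weight must be positive")
    first = contributions[0][0]
    result: Weights = {}
    for name, values in first.items():
        acc = [0.0] * len(values)
        for client_weights, w in contributions:
            for i, x in enumerate(client_weights[name]):
                acc[i] += w * x
        result[name] = [a / total for a in acc]
    return result


site1 = {"dense": [1.0, 3.0]}
site2 = {"dense": [3.0, 5.0]}
print(weighted_average([(site1, 1.0), (site2, 1.0)]))  # {'dense': [2.0, 4.0]}
print(weighted_average([(site1, 3.0), (site2, 1.0)]))  # {'dense': [1.5, 3.5]}
```

Accumulating a running weighted sum lets results be folded in one client at a time as they arrive, with a single division at the end.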

Model Persistor

The model persistor is used to load and save models on the server.

tf2_model_persistor.py
import json
import os
import pickle

import tensorflow as tf
from nvflare.apis.event_type import EventType
from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_context import FLContext
from nvflare.app_common.abstract.model import ModelLearnable, make_model_learnable
from nvflare.app_common.abstract.model_persistor import ModelPersistor
from nvflare.app_common.app_constant import AppConstants

from tf2_net import Net


class TF2ModelPersistor(ModelPersistor):
    def __init__(self, save_name="tf2_model.pkl"):
        super().__init__()
        self.save_name = save_name

    def _initialize(self, fl_ctx: FLContext):
        # get save path from FLContext
        app_root = fl_ctx.get_prop(FLContextKey.APP_ROOT)
        env = None
        run_args = fl_ctx.get_prop(FLContextKey.ARGS)
        if run_args:
            env_config_file_name = os.path.join(app_root, run_args.env)
            if os.path.exists(env_config_file_name):
                try:
                    with open(env_config_file_name) as file:
                        env = json.load(file)
                except Exception:
                    self.system_panic(
                        reason="error opening env config file {}".format(env_config_file_name), fl_ctx=fl_ctx
                    )
                    return

        if env is not None:
            if env.get("APP_CKPT_DIR", None):
                fl_ctx.set_prop(AppConstants.LOG_DIR, env["APP_CKPT_DIR"], private=True, sticky=True)
            if env.get("APP_CKPT") is not None:
                fl_ctx.set_prop(
                    AppConstants.CKPT_PRELOAD_PATH,
                    env["APP_CKPT"],
                    private=True,
                    sticky=True,
                )

        log_dir = fl_ctx.get_prop(AppConstants.LOG_DIR)
        if log_dir:
            self.log_dir = os.path.join(app_root, log_dir)
        else:
            self.log_dir = app_root
        self._pkl_save_path = os.path.join(self.log_dir, self.save_name)
        os.makedirs(self.log_dir, exist_ok=True)

        fl_ctx.sync_sticky()

    def load_model(self, fl_ctx: FLContext) -> ModelLearnable:
        """
        Initialize and load the model.

        Args:
            fl_ctx: FLContext

        Returns:
            ModelLearnable object
        """

        if os.path.exists(self._pkl_save_path):
            self.logger.info("Loading server weights")
            with open(self._pkl_save_path, "rb") as f:
                model_learnable = pickle.load(f)
        else:
            self.logger.info("Initializing server model")
            network = Net()
            loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
            network.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])
            # call the model once on a symbolic input so that its weights are created
            _ = network(tf.keras.Input(shape=(28, 28)))
            # pair layer names with weight arrays by position, as in the trainer's setup()
            var_dict = {network.get_layer(index=key).name: value for key, value in enumerate(network.get_weights())}
            model_learnable = make_model_learnable(var_dict, dict())
        return model_learnable

    def handle_event(self, event: str, fl_ctx: FLContext):
        if event == EventType.START_RUN:
            self._initialize(fl_ctx)

    def save_model(self, model_learnable: ModelLearnable, fl_ctx: FLContext):
        """
        Persist the model object.

        Args:
            model_learnable: ModelLearnable object
            fl_ctx: FLContext
        """
        model_learnable_info = {k: str(type(v)) for k, v in model_learnable.items()}
        self.logger.info(f"Saving aggregated server weights: \n {model_learnable_info}")
        with open(self._pkl_save_path, "wb") as f:
            pickle.dump(model_learnable, f)

In this exercise, we simply serialize the model weights dictionary using pickle and save it to a log directory computed in _initialize. The file is saved on the FL server, and its name is set by the save_name argument in config_fed_server.json. How the model is saved may vary depending on the framework and tools you use.
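The load-or-initialize and save behavior of the persistor can be sketched with only the standard library. PickleWeightStore is a hypothetical stand-in written for this sketch: it keeps plain dicts instead of a ModelLearnable, takes its directory as an argument instead of reading it from FLContext, and uses a temporary directory so it leaves nothing behind.

```python
import os
import pickle
import tempfile
from typing import Callable, Dict, List

Weights = Dict[str, List[float]]


class PickleWeightStore:
    """Hypothetical stand-in for the persistor's load/save behavior."""

    def __init__(self, directory: str, save_name: str = "tf2weights.pickle") -> None:
        os.makedirs(directory, exist_ok=True)
        self.path = os.path.join(directory, save_name)

    def load(self, init: Callable[[], Weights]) -> Weights:
        # Reuse saved weights if a previous run left a file; otherwise build fresh ones.
        if os.path.exists(self.path):
            with open(self.path, "rb") as f:
                return pickle.load(f)
        return init()

    def save(self, weights: Weights) -> None:
        with open(self.path, "wb") as f:
            pickle.dump(weights, f)


with tempfile.TemporaryDirectory() as tmp:
    store = PickleWeightStore(os.path.join(tmp, "run_1"))
    first = store.load(lambda: {"dense": [0.0, 0.0]})   # no file yet, so the initializer runs
    store.save({"dense": [0.25, -0.5]})
    second = store.load(lambda: {"dense": [0.0, 0.0]})  # file exists, so saved weights come back
    print(first, second)  # {'dense': [0.0, 0.0]} {'dense': [0.25, -0.5]}
```

Pickle keeps the example short, but unpickling can execute arbitrary code, so only load weight files from sources you trust.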

FLContext is used throughout these functions to provide various useful FL-related information. You can find more details in the documentation.

Application Configuration

Finally, inside the config folder there are two files, config_fed_client.json and config_fed_server.json.

config_fed_server.json
{
  "format_version": 2,
  "server": {
    "heart_beat_timeout": 600
  },
  "task_data_filters": [],
  "task_result_filters": [],
  "components": [
    {
      "id": "persistor",
      "path": "tf2_model_persistor.TF2ModelPersistor",
      "args": {
        "save_name": "tf2weights.pickle"
      }
    },
    {
      "id": "shareable_generator",
      "path": "nvflare.app_common.shareablegenerators.full_model_shareable_generator.FullModelShareableGenerator",
      "args": {}
    },
    {
      "id": "aggregator",
      "path": "nvflare.app_common.aggregators.accumulate_model_aggregator.AccumulateWeightedAggregator",
      "args": {
        "expected_data_kind": "WEIGHTS"
      }
    }
  ],
  "workflows": [
    {
      "id": "scatter_gather_ctl",
      "path": "nvflare.app_common.workflows.scatter_and_gather.ScatterAndGather",
      "args": {
        "min_clients": 1,
        "num_rounds": 3,
        "start_round": 0,
        "wait_time_after_min_received": 10,
        "aggregator_id": "aggregator",
        "persistor_id": "persistor",
        "shareable_generator_id": "shareable_generator",
        "train_task_name": "train",
        "train_timeout": 0
      }
    }
  ]
}

Note how the ScatterAndGather workflow is configured to use the included aggregator AccumulateWeightedAggregator and shareable generator FullModelShareableGenerator, each referenced by its full Python module path. The persistor is configured to use TF2ModelPersistor from the custom directory of this hello-tf2 app, referenced by its module path within that directory (tf2_model_persistor.TF2ModelPersistor).
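A "path" string such as tf2_model_persistor.TF2ModelPersistor names a module and a class inside it. The sketch below shows, with generic importlib usage, how such a string can be turned into an object and indexed by its id so that entries like aggregator_id can refer to it. It is an assumption for illustration, not NVIDIA FLARE's actual loader; standard-library classes stand in for FLARE components so it runs anywhere, and for custom components the module would have to be importable (for example, because the custom directory is on sys.path).

```python
import importlib
from typing import Any, Dict, List


def resolve_path(path: str) -> Any:
    """Split "package.module.Name" into module and attribute, import, and fetch it."""
    module_name, _, attr_name = path.rpartition(".")
    if not module_name:
        raise ValueError(f"expected a dotted path, got {path!r}")
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)


def build_components(specs: List[Dict[str, Any]]) -> Dict[str, Any]:
    # Instantiate each {"id", "path", "args"} entry and index it by id, so that
    # other entries (like a workflow's aggregator_id) can refer to it by name.
    return {spec["id"]: resolve_path(spec["path"])(**spec.get("args", {})) for spec in specs}


components = build_components([
    {"id": "timeout", "path": "datetime.timedelta", "args": {"minutes": 1, "seconds": 30}},
    {"id": "registry", "path": "collections.OrderedDict", "args": {}},
])
print(components["timeout"])  # 0:01:30
```

Because the "args" object is passed as keyword arguments, its keys must match the constructor's parameter names, as save_name does for TF2ModelPersistor.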

config_fed_client.json
{
  "format_version": 2,
  "executors": [
    {
      "tasks": [
        "train"
      ],
      "executor": {
        "path": "trainer.SimpleTrainer",
        "args": {
          "epochs_per_round": 2
        }
      }
    }
  ],
  "task_result_filters": [
    {
      "tasks": [
        "train"
      ],
      "filters": [
        {
          "path": "filter.ExcludeVars",
          "args": {
            "exclude_vars": [
              "flatten"
            ]
          }
        }
      ]
    }
  ],
  "task_data_filters": []
}

Here, executors is configured with the trainer implementation SimpleTrainer. We also set up filter.ExcludeVars as a task result filter and pass in ["flatten"] as its exclude_vars argument. Both are configured for the only task that will be broadcast in the Scatter and Gather workflow, “train”.
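To see how the sections of config_fed_client.json relate, the sketch below looks up, for a given task name, the executor entry and the data and result filters that list that task. It is an illustration of the configuration's structure under stated assumptions, not FLARE's dispatch code; conceptually, data filters process incoming task data before the executor runs, and result filters process its result on the way out.

```python
import json
from typing import Any, Dict, List, Optional, Tuple

# Compact copy of config_fed_client.json.
CLIENT_CONFIG = json.loads("""
{
  "format_version": 2,
  "executors": [
    {"tasks": ["train"],
     "executor": {"path": "trainer.SimpleTrainer", "args": {"epochs_per_round": 2}}}
  ],
  "task_result_filters": [
    {"tasks": ["train"],
     "filters": [{"path": "filter.ExcludeVars", "args": {"exclude_vars": ["flatten"]}}]}
  ],
  "task_data_filters": []
}
""")


def filters_for(groups: List[Dict[str, Any]], task: str) -> List[str]:
    # Collect filter paths from every group whose "tasks" list contains the task.
    return [f["path"] for group in groups if task in group["tasks"] for f in group["filters"]]


def route_task(config: Dict[str, Any], task: str) -> Tuple[Optional[Dict[str, Any]], List[str], List[str]]:
    """Return (executor entry, data filter paths, result filter paths) for a task name."""
    executor = next((e["executor"] for e in config.get("executors", []) if task in e["tasks"]), None)
    data_filters = filters_for(config.get("task_data_filters", []), task)
    result_filters = filters_for(config.get("task_result_filters", []), task)
    return executor, data_filters, result_filters


executor, data_filters, result_filters = route_task(CLIENT_CONFIG, "train")
print(executor["path"], data_filters, result_filters)
# trainer.SimpleTrainer [] ['filter.ExcludeVars']
```

A task name that no executor lists, such as "validate" here, yields no executor and no filters.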

Train the Model, Federated!

Now you can use admin commands to upload, deploy, and start this example app. To do this on a proof of concept local FL system, follow the sections Setting Up the Application Environment in POC Mode and Starting the Application Environment in POC Mode if you have not already.

Running the FL System

With the admin client command prompt successfully connected and logged in, enter the commands below in order. Pay close attention to what happens in each of the four terminals. You can see how the admin controls the server and clients with each command.

> upload_app hello-tf2

Uploads the application from the admin client to the server’s staging area.

> set_run_number 1

Creates a run directory in the workspace for the run_number on the server and all clients. The run directory allows for the isolation of different runs so the information in one particular run does not interfere with other runs.

> deploy_app hello-tf2 all

This will make the hello-tf2 application the active one in the run_number workspace. In this exercise, after the above two commands, the server and all the clients know the hello-tf2 application will reside in run_1 workspace.

> start_app all

This start_app command instructs the NVIDIA FLARE server and clients to start training with the hello-tf2 application in the run_1 workspace.

From time to time, you can issue check_status server in the admin client to check the entire training progress.

You should now see the training progress in the very first terminal (the one that started the server).

Once the FL run is complete and the server has successfully aggregated the clients’ results after all the rounds, run the following commands in fl_admin to shut down the system (enter admin when prompted for the user name):

> shutdown client
> shutdown server
> bye

In order to stop all processes, run ./stop_fl.sh.

All artifacts from the FL run can be found in the server run folder you created with set_run_number. In this exercise, the folder is run_1.

Congratulations! You’ve successfully built and run a federated learning system using TensorFlow 2. The full source code for this exercise can be found in examples/hello-tf2.