Quickstart (TensorFlow 2)¶
Before You Start¶
We recommend you first finish either the Quickstart (PyTorch) or the Quickstart (Numpy) exercise. Those guides go into more depth on the federated learning aspects of NVIDIA FLARE.
Here we assume you have already installed NVIDIA FLARE inside a Python virtual environment and have already cloned the repo.
Introduction¶
Through this exercise, you will integrate NVIDIA FLARE with the popular deep learning framework TensorFlow 2 and learn how to use NVIDIA FLARE to train a simple network on the MNIST dataset using the Scatter and Gather workflow. You will also be introduced to some new components and concepts, including filters, aggregators, and event handlers.
The design of this exercise consists of one server and two clients all having the same TensorFlow 2 model. The following steps compose one cycle of weight updates, called a round:
Clients are responsible for generating individual weight-updates for the model using their own MNIST dataset.
These updates are then sent to the server which will aggregate them to produce a model with new weights.
Finally, the server sends this updated version of the model back to each client.
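To make the aggregation step concrete, here is a minimal sketch of a plain (unweighted) average over two clients' weight dictionaries. It is plain NumPy, independent of NVIDIA FLARE, and the names average_weights, site_1, and site_2 are illustrative only:

import numpy as np

def average_weights(client_updates):
    # element-wise mean over a list of {layer_name: ndarray} dictionaries
    layer_names = client_updates[0].keys()
    return {name: np.mean([u[name] for u in client_updates], axis=0) for name in layer_names}

# two toy client updates for a single 2x2 layer
site_1 = {"dense": np.ones((2, 2))}
site_2 = {"dense": np.zeros((2, 2))}
print(average_weights([site_1, site_2])["dense"])  # every entry is 0.5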
For this exercise, we will be working with the hello-tf2 application in the examples folder.
Custom FL applications can contain the following folders:

custom: contains the custom components (tf2_net.py, trainer.py, filter.py, tf2_model_persistor.py)

config: contains the client and server configurations (config_fed_client.json, config_fed_server.json)

resources: contains the logger config (log.config)
Let’s get started. Since this exercise uses TensorFlow, first install the library inside your virtual environment:
(nvflare-env) $ python3 -m pip install tensorflow
NVIDIA FLARE Client¶
Neural Network¶
With all the required dependencies installed, you are ready to run a Federated Learning system with two clients and one server. Before you start, let’s see what a simplified MNIST network looks like.
import tensorflow as tf


class Net(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.flatten = tf.keras.layers.Flatten(input_shape=(28, 28))
        self.dense1 = tf.keras.layers.Dense(128, activation="relu")
        self.dropout = tf.keras.layers.Dropout(0.2)
        self.dense2 = tf.keras.layers.Dense(10)

    def call(self, x):
        x = self.flatten(x)
        x = self.dense1(x)
        x = self.dropout(x)
        x = self.dense2(x)
        return x
This Net class is the neural network that will be trained on the MNIST dataset. It is plain TensorFlow code with no NVIDIA FLARE dependency, so implement it in a file called tf2_net.py.
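As a quick sanity check, you can build the network in a standalone session before wiring it into NVIDIA FLARE. This snippet is not part of the exercise files; run it from the app's custom directory so tf2_net.py is importable:

import tensorflow as tf
from tf2_net import Net

model = Net()
# feeding a symbolic input builds the layers so the weights exist
_ = model(tf.keras.Input(shape=(28, 28)))
model.summary()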
Dataset & Setup¶
Now you have to implement the class Trainer, a subclass of NVIDIA FLARE's Executor, in a file called trainer.py.
Before you can start training, you need to set up your dataset. In this exercise, you can download it from the Internet via tf.keras's datasets module and split it in half to create a separate dataset for each client. Additionally, you must set up the optimizer, loss function, and any transforms to process the data. Since every step will be encapsulated in the SimpleTrainer class, let's put this preparation stage into one method, setup:
def setup(self, fl_ctx: FLContext):
    (self.train_images, self.train_labels), (
        self.test_images,
        self.test_labels,
    ) = tf.keras.datasets.mnist.load_data()
    self.train_images, self.test_images = (
        self.train_images / 255.0,
        self.test_images / 255.0,
    )

    # simulate separate datasets for each client by dividing MNIST dataset in half
    client_name = fl_ctx.get_identity_name()
    if client_name == "site-1":
        self.train_images = self.train_images[: len(self.train_images) // 2]
        self.train_labels = self.train_labels[: len(self.train_labels) // 2]
        self.test_images = self.test_images[: len(self.test_images) // 2]
        self.test_labels = self.test_labels[: len(self.test_labels) // 2]
    elif client_name == "site-2":
        self.train_images = self.train_images[len(self.train_images) // 2 :]
        self.train_labels = self.train_labels[len(self.train_labels) // 2 :]
        self.test_images = self.test_images[len(self.test_images) // 2 :]
        self.test_labels = self.test_labels[len(self.test_labels) // 2 :]

    model = Net()

    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])
    _ = model(tf.keras.Input(shape=(28, 28)))
    self.var_list = [model.get_layer(index=index).name for index in range(len(model.get_weights()))]
    self.model = model
How can you ensure this setup method is called before the client receives the model from the server? The Trainer class is also an FLComponent, which receives an Event whenever NVIDIA FLARE enters or leaves a certain stage. In this case, the event EventType.START_RUN fires at exactly the right time. Because our trainer is a subclass of FLComponent, you can implement a handler that reacts to this event and calls the setup method:
def handle_event(self, event_type: str, fl_ctx: FLContext):
    if event_type == EventType.START_RUN:
        self.setup(fl_ctx)
Note
This is a new concept you haven't seen in the previous two exercises. The concepts of event and handler are very powerful because you are free to add your own logic that runs at different times and processes various events. The entire list of events fired by NVIDIA FLARE is shown at Event types.
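For illustration only, here is a hypothetical component that reacts to two of those events; the class name and log messages are made up, but the handle_event signature is the same one used by the trainer above:

from nvflare.apis.event_type import EventType
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_context import FLContext


class RunLogger(FLComponent):
    # logs when a run starts and ends
    def handle_event(self, event_type: str, fl_ctx: FLContext):
        if event_type == EventType.START_RUN:
            self.log_info(fl_ctx, "run started")
        elif event_type == EventType.END_RUN:
            self.log_info(fl_ctx, "run ended")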
You now have everything you need, so let's implement the last method, execute, which is called every time the client receives an updated model from the server with the Task we will configure.
Link NVIDIA FLARE with Local Train¶
Take a look at the following code:
def execute(
    self,
    task_name: str,
    shareable: Shareable,
    fl_ctx: FLContext,
    abort_signal: Signal,
) -> Shareable:
    """
    This function is an extended function from the superclass.
    As a supervised-learning-based trainer, the train function will run
    evaluate and train engines based on model weights from `shareable`.
    After finishing training, a new `Shareable` object will be submitted
    to the server for aggregation.

    Args:
        task_name: dispatched task
        shareable: the `Shareable` object received from the server.
        fl_ctx: the `FLContext` object received from the server.
        abort_signal: if triggered, the training will be aborted.

    Returns:
        a new `Shareable` object to be submitted to the server for aggregation.
    """
    # retrieve model weights downloaded from the server's shareable
    if abort_signal.triggered:
        return make_reply(ReturnCode.TASK_ABORTED)

    if task_name != "train":
        return make_reply(ReturnCode.TASK_UNKNOWN)

    dxo = from_shareable(shareable)
    model_weights = dxo.data

    # use previous round's client weights to replace excluded layers from server
    prev_weights = {
        self.model.get_layer(index=key).name: value for key, value in enumerate(self.model.get_weights())
    }

    ordered_model_weights = {key: model_weights.get(key) for key in prev_weights}
    for key in self.var_list:
        value = ordered_model_weights.get(key)
        if np.all(value == 0):
            ordered_model_weights[key] = prev_weights[key]

    # update local model weights with received weights
    self.model.set_weights(list(ordered_model_weights.values()))

    # adjust LR or other training-time info as needed,
    # e.g. via callbacks in the fit function
    self.model.fit(
        self.train_images,
        self.train_labels,
        epochs=self.epochs_per_round,
        validation_data=(self.test_images, self.test_labels),
    )

    # report updated weights in shareable
    weights = {self.model.get_layer(index=key).name: value for key, value in enumerate(self.model.get_weights())}
    dxo = DXO(data_kind=DataKind.WEIGHTS, data=weights)

    self.log_info(fl_ctx, "Local epochs finished. Returning shareable")
    new_shareable = dxo.to_shareable()
    return new_shareable
Every NVIDIA FLARE client receives the model weights from the server in the shareable. This exercise uses a simple ExcludeVars filter, so make sure to replace the excluded (zeroed-out) layer with weights from the client's previous training round:
ordered_model_weights = {key: model_weights.get(key) for key in prev_weights}
for key in self.var_list:
    value = ordered_model_weights.get(key)
    if np.all(value == 0):
        ordered_model_weights[key] = prev_weights[key]
Now update the local model with those received weights:
self.model.set_weights(list(ordered_model_weights.values()))
Then perform a simple self.model.fit so the client's model is trained on its own dataset:
self.model.fit(
    self.train_images,
    self.train_labels,
    epochs=self.epochs_per_round,
    validation_data=(self.test_images, self.test_labels),
)
After local training finishes, the execute method uses the newly trained weights to build a new DXO, converts it into a Shareable, and returns it to the NVIDIA FLARE server.
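If you want to see the DXO-to-Shareable round trip in isolation, here is a short standalone sketch; the layer name and shapes are arbitrary:

import numpy as np
from nvflare.apis.dxo import DXO, DataKind, from_shareable

weights = {"dense1": np.random.rand(128, 10)}
dxo = DXO(data_kind=DataKind.WEIGHTS, data=weights)

shareable = dxo.to_shareable()         # what execute() returns to the server
recovered = from_shareable(shareable)  # what the server (or a filter) unpacks
assert recovered.data_kind == DataKind.WEIGHTS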
NVIDIA FLARE Server & Application¶
Filter¶
A filter can be used for additional data processing in the Shareable, for both inbound and outbound data on the client and/or server.
For this exercise, we use a basic ExcludeVars filter to exclude the variable/layer flatten from the task result as it goes outbound from the client to the server. The excluded layer is replaced with all zeros of the same shape, which reduces the size of the compressed message and ensures that the clients' weights for this variable are not shared with the server.
import re

import numpy as np
from nvflare.apis.dxo import DXO, DataKind, from_shareable
from nvflare.apis.filter import Filter
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable


class ExcludeVars(Filter):
    """
    Exclude/remove variables from a Shareable.

    Args:
        exclude_vars: if not specified (None), nothing is excluded;
            if a list of variable/layer names, only the specified variables are excluded;
            if a string containing a regular expression (e.g. "conv"), only matched variables are excluded.
    """

    def __init__(self, exclude_vars=None):
        super().__init__()
        self.exclude_vars = exclude_vars
        self.skip = False
        if self.exclude_vars is not None:
            if not (
                isinstance(self.exclude_vars, list)
                or isinstance(self.exclude_vars, str)
            ):
                self.skip = True
                self.logger.debug(
                    "Need to provide a list of layer names or a string for regex matching"
                )
                return

            if isinstance(self.exclude_vars, list):
                for var in self.exclude_vars:
                    if not isinstance(var, str):
                        self.skip = True
                        self.logger.debug(
                            "exclude_vars needs to be a list of layer names to exclude."
                        )
                        return
                self.logger.debug(f"Excluding {self.exclude_vars} from shareable")
            elif isinstance(self.exclude_vars, str):
                self.exclude_vars = (
                    re.compile(self.exclude_vars) if self.exclude_vars else None
                )
                if self.exclude_vars is None:
                    self.skip = True
                self.logger.debug(
                    f'Excluding all layers based on regex matches with "{self.exclude_vars}"'
                )
        else:
            self.logger.debug("Not excluding anything")
            self.skip = True

    def process(self, shareable: Shareable, fl_ctx: FLContext) -> Shareable:

        self.log_debug(fl_ctx, "inside filter")
        if self.skip:
            return shareable

        try:
            dxo = from_shareable(shareable)
        except:
            self.log_exception(fl_ctx, "shareable data is not a valid DXO")
            return shareable

        assert isinstance(dxo, DXO)
        if dxo.data_kind not in (DataKind.WEIGHT_DIFF, DataKind.WEIGHTS):
            self.log_debug(fl_ctx, "I cannot handle {}".format(dxo.data_kind))
            return shareable

        if dxo.data is None:
            self.log_debug(fl_ctx, "no data to filter")
            return shareable

        weights = dxo.data

        # parse regex to determine which layers to exclude
        if isinstance(self.exclude_vars, re.Pattern):
            re_pattern = self.exclude_vars
            self.exclude_vars = []
            for var_name in weights.keys():
                if re_pattern.search(var_name):
                    self.exclude_vars.append(var_name)
            self.log_debug(fl_ctx, f"Regex found {self.exclude_vars} matching layers.")
            if len(self.exclude_vars) == 0:
                self.log_warning(
                    fl_ctx, f"No matching layers found with regex {re_pattern}"
                )

        # remove variables
        n_excluded = 0
        var_names = list(weights.keys())  # recast to list to use in the for loop
        n_vars = len(var_names)
        for var_name in var_names:
            if var_name in self.exclude_vars:
                self.log_debug(fl_ctx, f"Excluding {var_name}")
                weights[var_name] = np.zeros(weights[var_name].shape)
                n_excluded += 1
        self.log_debug(
            fl_ctx,
            f"Excluded {n_excluded} of {n_vars} variables. {len(weights.keys())} remaining.",
        )

        dxo.data = weights
        return dxo.update_shareable(shareable)
The filtering procedure occurs in the one required method, process, which receives and returns a Shareable. The parameters for what is excluded and the inbound/outbound option are all set in config_fed_client.json (shown below) and passed in through the constructor.
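The core of the exclusion logic can be tried outside of NVIDIA FLARE. This standalone sketch mimics the regex branch of process above on a toy weights dictionary:

import re

import numpy as np

weights = {"flatten": np.ones(4), "dense1": np.ones(4)}
pattern = re.compile("flatten")

# zero out every layer whose name matches the pattern
for name in list(weights):
    if pattern.search(name):
        weights[name] = np.zeros(weights[name].shape)

print(weights)  # "flatten" is all zeros, "dense1" is untouched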
Model Aggregator¶
The model aggregator is used by the server to aggregate the clients' models into one model within the Scatter and Gather workflow.
In this exercise, we perform a simple average over the two clients' weights with the AccumulateWeightedAggregator, which is configured in config_fed_server.json (shown below).
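Conceptually, a weighted aggregation scales each client's contribution, for example by its number of training steps, before summing; unlike the plain mean sketched in the introduction, clients that trained on more data count for more. This is a simplified sketch of the idea, not the actual AccumulateWeightedAggregator implementation:

import numpy as np

def weighted_average(updates, step_counts):
    # updates: list of {layer_name: ndarray}; step_counts: one weight per client
    total = float(sum(step_counts))
    names = updates[0].keys()
    return {
        name: sum(w * u[name] for w, u in zip(step_counts, updates)) / total
        for name in names
    }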
Model Persistor¶
The model persistor is used to load and save models on the server.
import json
import os
import pickle

import tensorflow as tf
from nvflare.apis.event_type import EventType
from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_context import FLContext
from nvflare.app_common.abstract.model import ModelLearnable, make_model_learnable
from nvflare.app_common.abstract.model_persistor import ModelPersistor
from nvflare.app_common.app_constant import AppConstants
from tf2_net import Net


class TF2ModelPersistor(ModelPersistor):
    def __init__(self, save_name="tf2_model.pkl"):
        super().__init__()
        self.save_name = save_name

    def _initialize(self, fl_ctx: FLContext):
        # get save path from FLContext
        app_root = fl_ctx.get_prop(FLContextKey.APP_ROOT)
        env = None
        run_args = fl_ctx.get_prop(FLContextKey.ARGS)
        if run_args:
            env_config_file_name = os.path.join(app_root, run_args.env)
            if os.path.exists(env_config_file_name):
                try:
                    with open(env_config_file_name) as file:
                        env = json.load(file)
                except:
                    self.system_panic(
                        reason="error opening env config file {}".format(env_config_file_name), fl_ctx=fl_ctx
                    )
                    return

        if env is not None:
            if env.get("APP_CKPT_DIR", None):
                fl_ctx.set_prop(AppConstants.LOG_DIR, env["APP_CKPT_DIR"], private=True, sticky=True)
            if env.get("APP_CKPT") is not None:
                fl_ctx.set_prop(
                    AppConstants.CKPT_PRELOAD_PATH,
                    env["APP_CKPT"],
                    private=True,
                    sticky=True,
                )

        log_dir = fl_ctx.get_prop(AppConstants.LOG_DIR)
        if log_dir:
            self.log_dir = os.path.join(app_root, log_dir)
        else:
            self.log_dir = app_root
        self._pkl_save_path = os.path.join(self.log_dir, self.save_name)
        if not os.path.exists(self.log_dir):
            os.makedirs(self.log_dir)

        fl_ctx.sync_sticky()

    def load_model(self, fl_ctx: FLContext) -> ModelLearnable:
        """
        Initialize and load the model.

        Args:
            fl_ctx: FLContext

        Returns:
            ModelLearnable object
        """

        if os.path.exists(self._pkl_save_path):
            self.logger.info("Loading server weights")
            with open(self._pkl_save_path, "rb") as f:
                model_learnable = pickle.load(f)
        else:
            self.logger.info("Initializing server model")
            network = Net()
            loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
            network.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])
            _ = network(tf.keras.Input(shape=(28, 28)))
            var_dict = {network.get_layer(index=key).name: value for key, value in enumerate(network.get_weights())}
            model_learnable = make_model_learnable(var_dict, dict())
        return model_learnable

    def handle_event(self, event: str, fl_ctx: FLContext):
        if event == EventType.START_RUN:
            self._initialize(fl_ctx)

    def save_model(self, model_learnable: ModelLearnable, fl_ctx: FLContext):
        """
        Persist the ModelLearnable object.

        Args:
            model_learnable: ModelLearnable object
            fl_ctx: FLContext
        """
        model_learnable_info = {k: str(type(v)) for k, v in model_learnable.items()}
        self.logger.info(f"Saving aggregated server weights: \n {model_learnable_info}")
        with open(self._pkl_save_path, "wb") as f:
            pickle.dump(model_learnable, f)
In this exercise, we simply serialize the model weights dictionary with pickle and save it to a log directory computed in _initialize. The file is saved on the FL server, and the weights file name is defined in config_fed_server.json.
Depending on the framework and tools, the method of saving the model may vary.
FLContext is used throughout these functions to provide various useful FL-related information. You can find more details in the documentation.
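After a run completes, you can inspect the pickled weights from a plain Python session. The path below is an assumption based on the run_1 workspace described later and the save_name configured in config_fed_server.json; adjust it to wherever your server workspace lives:

import pickle

# hypothetical path inside the server's run_1 workspace
with open("run_1/app_server/tf2weights.pickle", "rb") as f:
    model_learnable = pickle.load(f)

print(type(model_learnable), list(model_learnable.keys()))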
Application Configuration¶
Finally, inside the config folder there are two files, config_fed_client.json and config_fed_server.json. Here is config_fed_server.json:
{
  "format_version": 2,
  "server": {
    "heart_beat_timeout": 600
  },
  "task_data_filters": [],
  "task_result_filters": [],
  "components": [
    {
      "id": "persistor",
      "path": "tf2_model_persistor.TF2ModelPersistor",
      "args": {
        "save_name": "tf2weights.pickle"
      }
    },
    {
      "id": "shareable_generator",
      "path": "nvflare.app_common.shareablegenerators.full_model_shareable_generator.FullModelShareableGenerator",
      "args": {}
    },
    {
      "id": "aggregator",
      "path": "nvflare.app_common.aggregators.accumulate_model_aggregator.AccumulateWeightedAggregator",
      "args": {
        "expected_data_kind": "WEIGHTS"
      }
    }
  ],
  "workflows": [
    {
      "id": "scatter_gather_ctl",
      "path": "nvflare.app_common.workflows.scatter_and_gather.ScatterAndGather",
      "args": {
        "min_clients": 1,
        "num_rounds": 3,
        "start_round": 0,
        "wait_time_after_min_received": 10,
        "aggregator_id": "aggregator",
        "persistor_id": "persistor",
        "shareable_generator_id": "shareable_generator",
        "train_task_name": "train",
        "train_timeout": 0
      }
    }
  ]
}
Note how the ScatterAndGather workflow is configured to use the included aggregator AccumulateWeightedAggregator and shareable_generator FullModelShareableGenerator, referenced with full Python module paths. The persistor is configured to use TF2ModelPersistor from the custom directory of this hello-tf2 app. Here is config_fed_client.json:
{
  "format_version": 2,
  "executors": [
    {
      "tasks": [
        "train"
      ],
      "executor": {
        "path": "trainer.SimpleTrainer",
        "args": {
          "epochs_per_round": 2
        }
      }
    }
  ],
  "task_result_filters": [
    {
      "tasks": [
        "train"
      ],
      "filters": [
        {
          "path": "filter.ExcludeVars",
          "args": {
            "exclude_vars": [
              "flatten"
            ]
          }
        }
      ]
    }
  ],
  "task_data_filters": []
}
Here, executors is configured with the Trainer implementation SimpleTrainer. Also, we set up filter.ExcludeVars as a task_result_filter and pass in ["flatten"] as the argument. Both of these are configured for the only Task that will be broadcast in the Scatter and Gather workflow, "train".
Train the Model, Federated!¶
Now you can use admin commands to upload, deploy, and start this example app. To do this on a proof of concept local FL system, follow the sections Setting Up the Application Environment in POC Mode and Starting the Application Environment in POC Mode if you have not already.
Running the FL System¶
With the admin client command prompt successfully connected and logged in, enter the commands below in order. Pay close attention to what happens in each of the four terminals. You can see how the admin controls the server and clients with each command.
> upload_app hello-tf2
Uploads the application from the admin client to the server’s staging area.
> set_run_number 1
Creates a run directory in the workspace for the run_number on the server and all clients. The run directory allows for the isolation of different runs so the information in one particular run does not interfere with other runs.
> deploy_app hello-tf2 all
This will make the hello-tf2 application the active one in the run_number workspace. In this exercise, after the above two commands, the server and all the clients know that the hello-tf2 application will reside in the run_1 workspace.
> start_app all
This start_app command instructs the NVIDIA FLARE server and clients to start training with the hello-tf2 application in the run_1 workspace.
From time to time, you can issue check_status server in the admin client to check the overall training progress.
You should now see the training progress in the very first terminal (the one that started the server).
Once the FL run is complete and the server has successfully aggregated the clients' results after all the rounds, run the following commands in the fl_admin client to shut down the system (entering admin when prompted for the username):
> shutdown client
> shutdown server
> bye
To stop all processes, run ./stop_fl.sh.
All artifacts from the FL run can be found in the server run folder you created with set_run_number. In this exercise, the folder is run_1.
Congratulations!
You’ve successfully built and run a federated learning system using TensorFlow 2.
The full source code for this exercise can be found in examples/hello-tf2.