Hello TensorFlow with Job API
Before You Start
Feel free to refer to the detailed documentation at any point to learn more about the specifics of NVIDIA FLARE.
We recommend you first finish the Hello FedAvg with NumPy exercise since it introduces the federated learning concepts of NVIDIA FLARE.
Make sure you have an environment with NVIDIA FLARE installed.
See Getting Started for the general concept of setting up a Python virtual environment (the recommended environment) and for how to install NVIDIA FLARE.
Here we assume you have already installed NVIDIA FLARE inside a Python virtual environment and have cloned the repository.
Introduction
Through this exercise, you will integrate NVIDIA FLARE with the popular deep learning framework TensorFlow and learn how to use NVIDIA FLARE to train a simple fully connected neural network on the MNIST dataset using the FedAvg workflow.
You will also be introduced to some new components and concepts, including filters, aggregators, and event handlers.
The setup of this exercise consists of one server and two clients.
The following steps make up one cycle of weight updates, called a round:
1. Each client generates its own weight updates for the model by training on its local MNIST data.
2. The clients send these updates to the server, which aggregates them into a model with new weights.
3. The server sends this updated model back to each client.
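The three steps of a round can be sketched in plain Python. This is a simplified illustration, not NVIDIA FLARE's FedAvg implementation: the model is assumed to be a single flat list of floats, local training is replaced by a hypothetical local_update that just shifts every weight by a client-specific amount, and the server takes an equally weighted average of the client results.

```python
from typing import Dict, List

Weights = List[float]  # simplification: the whole model as one flat list of floats


def local_update(global_weights: Weights, shift: float) -> Weights:
    """Hypothetical stand-in for local training: shift every weight by a client-specific amount."""
    return [w + shift for w in global_weights]


def fedavg_aggregate(updates: Dict[str, Weights]) -> Weights:
    """Element-wise, equally weighted average of the clients' weights."""
    n_clients = len(updates)
    return [sum(values) / n_clients for values in zip(*updates.values())]


def run_round(global_weights: Weights, client_shifts: Dict[str, float]) -> Weights:
    # Step 1: every client trains locally, starting from the current global weights.
    updates = {site: local_update(global_weights, shift) for site, shift in client_shifts.items()}
    # Step 2: the server aggregates the client results into new global weights.
    new_global = fedavg_aggregate(updates)
    # Step 3: the server sends new_global back to every client for the next round.
    return new_global


weights = [0.0, 1.0]
for round_num in range(3):
    weights = run_round(weights, {"site-1": 1.0, "site-2": 3.0})
    print(f"round {round_num}: {weights}")
```

Each round here moves the global weights by the average of the two client shifts, so after three rounds the starting weights [0.0, 1.0] become [6.0, 7.0].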
For this exercise, we will be working with the hello-tf application in the examples folder.
Let’s get started. Since this task is using TensorFlow, let’s go ahead and install the library inside our virtual environment:
(nvflare-env) $ python3 -m pip install tensorflow
With all the required dependencies installed, you are ready to run a Federated Learning system with two clients and one server. If you would like to run the exercise now, you can run the fedavg_script_executor_hello-tf.py script, which builds the job with the Job API and runs the job with the FLARE Simulator.
NVIDIA FLARE Job API
The fedavg_script_executor_hello-tf.py script for this hello-tf example is very similar to the fedavg_script_executor_hello-numpy.py script for the Hello FedAvg with NumPy example and to the corresponding script for the Hello PyTorch example. Apart from the names of the job and the client script, the only difference is the line that defines the initial global model for the server:
# Define the initial global model and send to server
job.to(TFNet(), "server")
NVIDIA FLARE Client Training Script
The training script for this example, hello-tf_fl.py, is the main script that runs on the clients. It contains the TensorFlow-specific training logic.
Neural Network
Let’s see what a simplified MNIST network looks like.
from tensorflow.keras import layers, models


class TFNet(models.Sequential):
    def __init__(self, input_shape=(None, 28, 28)):
        super().__init__()
        self._input_shape = input_shape
        self.add(layers.Flatten())
        self.add(layers.Dense(128, activation="relu"))
        self.add(layers.Dropout(0.2))
        self.add(layers.Dense(10))
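This TFNet class is the neural network trained on the MNIST dataset: a Flatten layer followed by two Dense layers with Dropout in between, so it is a fully connected network rather than a convolutional one. It has no dependency on NVIDIA FLARE and is implemented in a file called tf_net.py.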
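To see what the server and clients actually exchange each round, it helps to work out the weight shapes of TFNet for 28x28 inputs. The sketch below does that arithmetic in plain Python without TensorFlow. The layer names are hypothetical (Keras generates names such as "dense" and "dense_1" automatically), and the shapes assume a Dense layer stores a kernel followed by a bias. Layers without weights, such as Flatten and Dropout, still appear as keys with an empty list, which is what their get_weights() returns.

```python
from math import prod
from typing import Dict, List, Optional, Tuple

# Hypothetical layer names: Keras assigns names automatically, so the real keys may differ.
# Each entry is (layer name, (inputs, units)) for a Dense layer, or None for a weightless layer.
TFNET_LAYERS: List[Tuple[str, Optional[Tuple[int, int]]]] = [
    ("flatten", None),          # 28x28 image -> 784-element vector, no weights
    ("dense", (28 * 28, 128)),  # hidden layer with ReLU
    ("dropout", None),          # randomly zeroes activations during training, no weights
    ("dense_1", (128, 10)),     # one logit per digit class
]


def weight_shapes(layers) -> Dict[str, List[Tuple[int, ...]]]:
    """Shapes of the arrays each layer holds: kernel, then bias; empty for weightless layers."""
    shapes: Dict[str, List[Tuple[int, ...]]] = {}
    for name, dims in layers:
        if dims is None:
            shapes[name] = []
        else:
            fan_in, units = dims
            shapes[name] = [(fan_in, units), (units,)]
    return shapes


def count_params(shapes: Dict[str, List[Tuple[int, ...]]]) -> int:
    """Total number of scalar weights across all arrays."""
    return sum(prod(shape) for arrays in shapes.values() for shape in arrays)


shapes = weight_shapes(TFNET_LAYERS)
print(shapes)
print(count_params(shapes))
```

For TFNet this gives 784 × 128 + 128 = 100,480 weights in the first Dense layer and 128 × 10 + 10 = 1,290 in the second, or 101,770 in total.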
Dataset & Setup
Before starting training, you need to set up your dataset.
In this exercise, it is downloaded from the Internet via the datasets module of tf.keras and split in half to create a separate dataset for each client. This split is only for demonstration; in a real-world scenario, each client would have its own, separately collected data.
The model's optimizer and loss function also need to be configured.
All of this happens before the while flare.is_running(): line in hello-tf_fl.py.
import tensorflow as tf
from tf_net import TFNet  # tf_net.py sits next to this script

import nvflare.client as flare

WEIGHTS_PATH = "./tf_model.weights.h5"  # assumed output path; Keras 3 requires the ".weights.h5" suffix

flare.init()  # initialize the FLARE Client API before calling other flare functions
sys_info = flare.system_info()  # contains "site_name", used below to pick this client's data

model = TFNet()
model.build(input_shape=(None, 28, 28))
model.compile(
    optimizer="adam", loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=["accuracy"]
)
model.summary()

# download MNIST and scale pixel values from [0, 255] to [0, 1]
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_images, test_images = train_images / 255.0, test_images / 255.0

# simulate separate datasets for each client by dividing MNIST dataset in half
client_name = sys_info["site_name"]
if client_name == "site-1":
    train_images = train_images[: len(train_images) // 2]
    train_labels = train_labels[: len(train_labels) // 2]
    test_images = test_images[: len(test_images) // 2]
    test_labels = test_labels[: len(test_labels) // 2]
elif client_name == "site-2":
    train_images = train_images[len(train_images) // 2 :]
    train_labels = train_labels[len(train_labels) // 2 :]
    test_images = test_images[len(test_images) // 2 :]
    test_labels = test_labels[len(test_labels) // 2 :]
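The halving logic above can be factored into a small helper and checked without TensorFlow or MNIST. split_for_site is a hypothetical name used only for this sketch; it is not part of the example. Slicing behaves the same way on the NumPy arrays returned by load_data() as on the plain lists used here.

```python
from typing import Sequence, TypeVar

T = TypeVar("T")


def split_for_site(data: Sequence[T], site_name: str) -> Sequence[T]:
    """Hypothetical helper mirroring the if/elif block in hello-tf_fl.py."""
    half = len(data) // 2
    if site_name == "site-1":
        return data[:half]  # first half
    if site_name == "site-2":
        return data[half:]  # second half
    return data  # any other site name keeps the full dataset, as in the original code


labels = list(range(7))
print(split_for_site(labels, "site-1"))
print(split_for_site(labels, "site-2"))
```

For MNIST's 60,000 training and 10,000 test images, each site gets 30,000 and 5,000 respectively. With an odd length, site-2 receives the extra item. Because the original code only handles site-1 and site-2, a third client would train on the full dataset.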
Client Local Train
The client receives the global model from the server as input_model, loads its weights into the local model layer by layer, and evaluates it on the local test set. It then calls model.fit so that the model is trained on the client's own dataset:
while flare.is_running():
    # receive the current global model from the server
    input_model = flare.receive()
    print(f"current_round={input_model.current_round}")

    sys_info = flare.system_info()
    print(f"system info is: {sys_info}")

    # load the global weights into the local model, layer by layer
    for k, v in input_model.params.items():
        model.get_layer(k).set_weights(v)

    # evaluate the received global model on the local test set
    _, test_global_acc = model.evaluate(test_images, test_labels, verbose=2)
    print(
        f"Accuracy of the received model on round {input_model.current_round} on the test images: {test_global_acc * 100} %"
    )

    # training
    model.fit(train_images, train_labels, epochs=1, validation_data=(test_images, test_labels))

    print("Finished Training")

    model.save_weights(WEIGHTS_PATH)

    sys_info = flare.system_info()
    print(f"system info is: {sys_info}", flush=True)
    print(f"finished round: {input_model.current_round}", flush=True)

    # send the locally trained weights, plus the received global model's accuracy, back to the server
    output_model = flare.FLModel(
        params={layer.name: layer.get_weights() for layer in model.layers},
        params_type="FULL",
        metrics={"accuracy": test_global_acc},
        current_round=input_model.current_round,
    )

    flare.send(output_model)
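After local training, the client sends the newly trained weights back to the NVIDIA FLARE server in the params field of an FLModel, keyed by layer name. The metrics field carries the accuracy that the received global model achieved on the client's test set.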
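Because the weights travel as a dictionary keyed by layer name, the exchange only works if every participant builds the same TFNet with matching layer names. The plain-Python sketch below models that round trip without TensorFlow: FakeLayer is a hypothetical stand-in for a Keras layer, and pack_params and unpack_params are illustrative names, not NVIDIA FLARE or Keras functions. They mirror the dictionary comprehension used to build params and the get_layer(k).set_weights(v) loop in the client script.

```python
from typing import Dict, List

Params = Dict[str, List[List[float]]]


class FakeLayer:
    """Hypothetical stand-in for a Keras layer; each weight array is a plain list of floats."""

    def __init__(self, name: str, weights: List[List[float]]):
        self.name = name
        self._weights = [list(w) for w in weights]

    def get_weights(self) -> List[List[float]]:
        return [list(w) for w in self._weights]  # copies, so callers cannot mutate internal state

    def set_weights(self, weights: List[List[float]]) -> None:
        if len(weights) != len(self._weights):
            raise ValueError(f"{self.name}: expected {len(self._weights)} arrays, got {len(weights)}")
        self._weights = [list(w) for w in weights]


def pack_params(layers: List[FakeLayer]) -> Params:
    # Same shape as: params={layer.name: layer.get_weights() for layer in model.layers}
    return {layer.name: layer.get_weights() for layer in layers}


def unpack_params(layers: List[FakeLayer], params: Params) -> None:
    # Same idea as: for k, v in input_model.params.items(): model.get_layer(k).set_weights(v)
    by_name = {layer.name: layer for layer in layers}
    for name, weights in params.items():
        by_name[name].set_weights(weights)  # raises KeyError if a layer name is unknown


trained = [FakeLayer("flatten", []), FakeLayer("dense", [[0.5, -0.5], [0.1]])]
fresh = [FakeLayer("flatten", []), FakeLayer("dense", [[0.0, 0.0], [0.0]])]
unpack_params(fresh, pack_params(trained))
print(pack_params(fresh))
```

Weightless layers such as Flatten round-trip as empty lists, and a name that does not exist on the receiving side fails loudly instead of being silently skipped.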
NVIDIA FLARE Server & Application
In this example, the server runs the FedAvg workflow with two clients and three rounds, and otherwise uses the default settings.
If you export the job with the export function, you can inspect the generated configurations for the server and each client. The server configuration is config_fed_server.json, located in the config folder of app_server:
{
    "format_version": 2,
    "workflows": [
        {
            "id": "controller",
            "path": "nvflare.app_common.workflows.fedavg.FedAvg",
            "args": {
                "num_clients": 2,
                "num_rounds": 3
            }
        }
    ],
    "components": [
        {
            "id": "json_generator",
            "path": "nvflare.app_common.widgets.validation_json_generator.ValidationJsonGenerator",
            "args": {}
        },
        {
            "id": "model_selector",
            "path": "nvflare.app_common.widgets.intime_model_selector.IntimeModelSelector",
            "args": {
                "aggregation_weights": {},
                "key_metric": "accuracy"
            }
        },
        {
            "id": "persistor",
            "path": "nvflare.app_opt.tf.model_persistor.TFModelPersistor",
            "args": {
                "model": {
                    "path": "src.tf_net.TFNet",
                    "args": {}
                }
            }
        }
    ],
    "task_data_filters": [],
    "task_result_filters": []
}
The Job API generates this configuration automatically, and the server application uses NVIDIA FLARE built-in components.
Note that persistor points to TFModelPersistor. This is configured automatically when the model is added to the server with the to function: the Job API detects that the model is a TensorFlow model and sets up TFModelPersistor.
Client Configuration
The client configuration is config_fed_client.json, located in the config folder of each client app folder:
{
    "format_version": 2,
    "executors": [
        {
            "tasks": [
                "*"
            ],
            "executor": {
                "path": "nvflare.app_common.executors.script_executor.ScriptExecutor",
                "args": {
                    "task_script_path": "src/hello-tf_fl.py"
                }
            }
        }
    ],
    "components": [
        {
            "id": "event_to_fed",
            "path": "nvflare.app_common.widgets.convert_to_fed_event.ConvertToFedEvent",
            "args": {
                "events_to_convert": [
                    "analytix_log_stats"
                ]
            }
        }
    ],
    "task_data_filters": [],
    "task_result_filters": []
}
The task_script_path is set to the path of the client training script.
The full source code for this exercise can be found in examples/hello-tf.