Hello TensorFlow
This example demonstrates how to use NVIDIA FLARE with TensorFlow to train an image classifier using federated averaging (FedAvg). TensorFlow serves as the deep learning training framework in this example.
For detailed documentation, see the Hello TensorFlow example page.
We recommend using the NVIDIA TensorFlow docker for GPU support. If GPU is not required, a Python virtual environment is sufficient.
To run this example with the FLARE API, refer to the hello_world notebook.
Run NVIDIA TensorFlow Container
Ensure the NVIDIA container toolkit is installed. Then execute the following command:
docker run --gpus=all -it --rm -v [path_to_NVFlare]:/NVFlare nvcr.io/nvidia/tensorflow:xx.xx-tf2-py3
NVIDIA FLARE Installation
For complete installation instructions, visit Installation.
pip install nvflare
Clone the example code from GitHub:
git clone https://github.com/NVIDIA/NVFlare.git
Switch to the release branch and navigate to the hello-tf directory:
cd NVFlare
git switch <release branch>
cd examples/hello-world/hello-tf
Install the dependencies:
pip install -r requirements.txt
Code Structure
hello-tf
|
|-- client.py # client local training script
|-- model.py # model definition
|-- job.py # job recipe that defines client and server configurations
|-- requirements.txt # dependencies
Data
This example uses the MNIST handwritten digits dataset, which is loaded within the trainer code.
Model
The model.py file defines a simple neural network using TensorFlow’s Keras API. The Net model is a sequential architecture designed for image classification, featuring:
Flatten Layer: Prepares input data for dense layers.
Dense Layer: 128 units with ReLU activation for non-linearity.
Dropout Layer: 20% dropout rate to mitigate overfitting.
Output Layer: 10 units for classifying MNIST digits.
This model is used in federated learning with NVIDIA FLARE, trained across clients using the FedAvg algorithm.
from tensorflow.keras import layers, models


class Net(models.Sequential):
    def __init__(self, input_shape=(None, 28, 28)):
        super().__init__()
        self._input_shape = input_shape
        self.add(layers.Flatten())
        self.add(layers.Dense(128, activation="relu"))
        self.add(layers.Dropout(0.2))
        self.add(layers.Dense(10))
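As a quick sanity check on the architecture described above, the number of trainable parameters can be worked out by hand. The sketch below assumes 28x28 single-channel MNIST inputs; the helper name dense_params is introduced only for illustration and is not part of the example code.

```python
def dense_params(inputs: int, units: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return inputs * units + units


flattened = 28 * 28                     # Flatten: 784 features, no trainable parameters
hidden = dense_params(flattened, 128)   # Dense(128): 784 * 128 + 128
output = dense_params(128, 10)          # Dense(10): 128 * 10 + 10
total = hidden + output                 # Dropout adds no parameters

print(hidden, output, total)  # 100480 1290 101770
```

The total should match the trainable parameter count that model.summary() prints in client.py once the model is built with the same input shape.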
Client Code
The client code in client.py performs the local training. It closely resembles a standard TensorFlow/Keras training script, with a few additional lines that handle data exchange with the server.
import tensorflow as tf
from model import Net

import nvflare.client as flare
from nvflare.client.tracking import SummaryWriter

WEIGHTS_PATH = "./tf_model.weights.h5"


def main():
    flare.init()
    writer = SummaryWriter()

    sys_info = flare.system_info()
    print(f"system info is: {sys_info}", flush=True)

    model = Net()
    model.build(input_shape=(None, 28, 28))
    model.compile(
        optimizer="adam", loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=["accuracy"]
    )
    model.summary()

    (train_images, train_labels), (
        test_images,
        test_labels,
    ) = tf.keras.datasets.mnist.load_data()
    train_images, test_images = (
        train_images / 255.0,
        test_images / 255.0,
    )

    # simulate separate datasets for each client by dividing MNIST dataset in half
    client_name = sys_info["site_name"]
    if client_name == "site-1":
        train_images = train_images[: len(train_images) // 2]
        train_labels = train_labels[: len(train_labels) // 2]
        test_images = test_images[: len(test_images) // 2]
        test_labels = test_labels[: len(test_labels) // 2]
    elif client_name == "site-2":
        train_images = train_images[len(train_images) // 2 :]
        train_labels = train_labels[len(train_labels) // 2 :]
        test_images = test_images[len(test_images) // 2 :]
        test_labels = test_labels[len(test_labels) // 2 :]

    while flare.is_running():
        input_model = flare.receive()
        print(f"current_round={input_model.current_round}")

        sys_info = flare.system_info()
        print(f"system info is: {sys_info}")

        for k, v in input_model.params.items():
            model.get_layer(k).set_weights(v)

        _, test_global_acc = model.evaluate(test_images, test_labels, verbose=2)
        print(
            f"Accuracy of the received model on round {input_model.current_round} on the test images: {test_global_acc * 100} %"
        )
        writer.add_scalar(tag="local_acc", scalar=test_global_acc, global_step=input_model.current_round)

        # training
        model.fit(train_images, train_labels, epochs=1, validation_data=(test_images, test_labels))

        print("Finished Training")

        model.save_weights(WEIGHTS_PATH)

        sys_info = flare.system_info()
        print(f"system info is: {sys_info}", flush=True)
        print(f"finished round: {input_model.current_round}", flush=True)

        output_model = flare.FLModel(
            params={layer.name: layer.get_weights() for layer in model.layers},
            params_type="FULL",
            metrics={"accuracy": test_global_acc},
            current_round=input_model.current_round,
        )

        flare.send(output_model)


if __name__ == "__main__":
    main()
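The per-site data split in client.py can be isolated into a small function, which makes the behavior easy to check without downloading MNIST. This is a minimal sketch using plain Python lists; the function name split_for_site is hypothetical and does not appear in the example code.

```python
def split_for_site(samples, site_name):
    """Return the portion of `samples` assigned to `site_name`.

    Mirrors the slicing in client.py: "site-1" gets the first half,
    "site-2" gets the second half, and any other name keeps everything.
    """
    half = len(samples) // 2
    if site_name == "site-1":
        return samples[:half]
    if site_name == "site-2":
        return samples[half:]
    return samples


data = list(range(7))
print(split_for_site(data, "site-1"))  # [0, 1, 2]
print(split_for_site(data, "site-2"))  # [3, 4, 5, 6]
```

With an odd number of samples the second site receives the extra one, and a client whose name is neither site-1 nor site-2 trains on the full dataset.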
Server Code
In federated averaging, the server code aggregates model updates from clients, following a scatter-gather workflow pattern. This example uses the default federated averaging algorithm provided by NVFlare, eliminating the need for custom server code.
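To make the aggregation step concrete, here is a minimal sketch of the textbook FedAvg rule: a per-parameter average of client weights, weighted by the number of training samples on each client. It is not NVFlare's implementation. The function name fedavg, the flat-list parameter format, and the assumption that each client reports a sample count are all illustrative.

```python
def fedavg(client_updates):
    """Weighted average of client parameters.

    client_updates: list of (params, num_samples) pairs, where params maps a
    layer name to a flat list of floats (a simplified, hypothetical format).
    """
    total = sum(n for _, n in client_updates)
    if total <= 0:
        raise ValueError("total number of samples must be positive")

    first_params = client_updates[0][0]
    averaged = {}
    for name, values in first_params.items():
        averaged[name] = [
            sum(params[name][i] * n for params, n in client_updates) / total
            for i in range(len(values))
        ]
    return averaged


site_1 = ({"dense": [1.0, 2.0]}, 100)
site_2 = ({"dense": [3.0, 6.0]}, 300)
print(fedavg([site_1, site_2]))  # {'dense': [2.5, 5.0]}
```

In the scatter-gather pattern, the server would send the averaged parameters back to the clients as the starting point for the next round.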
Job Recipe Code
The job recipe includes client.py and the built-in FedAvg algorithm.
from model import Net

from nvflare.app_opt.tf.recipes.fedavg import FedAvgRecipe
from nvflare.recipe import SimEnv, add_experiment_tracking

if __name__ == "__main__":
    n_clients = 2
    num_rounds = 3
    train_script = "client.py"

    recipe = FedAvgRecipe(
        name="hello-tf_fedavg",
        num_rounds=num_rounds,
        # Model can be specified as class instance or dict config:
        model=Net(),
        # Alternative: model={"class_path": "model.Net", "args": {}},
        # For pre-trained weights: initial_ckpt="/server/path/to/model.h5",
        min_clients=n_clients,
        train_script=train_script,
    )
    add_experiment_tracking(recipe, tracking_type="tensorboard")

    env = SimEnv(num_clients=n_clients)
    run = recipe.execute(env=env)
    print()
    print("Result can be found in :", run.get_result())
    print("Job Status is:", run.get_status())
    print()
Model Input Options
The model parameter accepts two formats:
Class instance (subclassed Keras model): model=Net() - Convenient and Pythonic
Dict config: model={"class_path": "model.Net", "args": {}} - Better for large models
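A dict config of this shape is typically turned into an object by importing the module named in class_path and calling the class with args. The sketch below illustrates that general idea only; it is an assumption about the pattern, not NVFlare's actual loading code, and the helper name instantiate is hypothetical. A standard-library class stands in for model.Net so the snippet runs on its own.

```python
import importlib


def instantiate(config):
    """Build an object from {"class_path": "pkg.module.Class", "args": {...}}."""
    module_name, _, class_name = config["class_path"].rpartition(".")
    module = importlib.import_module(module_name)
    cls = getattr(module, class_name)
    return cls(**config.get("args", {}))


# fractions.Fraction is used here only as a stand-in for model.Net
frac = instantiate(
    {"class_path": "fractions.Fraction", "args": {"numerator": 3, "denominator": 4}}
)
print(frac)  # 3/4
```

Because the config is only a string and a dict of arguments, the object does not have to be constructed until it is needed, which is one plausible reason the dict form suits large models.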
To resume from pre-trained weights:
recipe = FedAvgRecipe(
    model=Net(),
    initial_ckpt="/server/path/to/pretrained.h5",  # Absolute path
    ...
)
Note
For TensorFlow/Keras, use a subclassed Keras class instance (for example, Net()) or dict config for model.
SavedModel or .h5 files contain both architecture and weights, so initial_ckpt can be used without model.
Run the Experiment
Execute the script using the job API to create the job and run it with the simulator:
TF_FORCE_GPU_ALLOW_GROWTH=true TF_GPU_ALLOCATOR=cuda_malloc_async python3 job.py
Access the Logs and Results
Find the running logs and results inside the simulator’s workspace:
$ ls /tmp/nvflare/jobs/workdir
Notes on Running with GPUs
When using GPUs, TensorFlow attempts to allocate all available GPU memory at startup. To prevent this in multi-client scenarios, set the following flags:
TF_FORCE_GPU_ALLOW_GROWTH=true TF_GPU_ALLOCATOR=cuda_malloc_async
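As an alternative to prefixing the command, the same variables can be set from Python. This is a sketch under the assumption that they only need to be present in the environment of the process that initializes TensorFlow, and that they are set before TensorFlow is imported. Whether this fits a given deployment depends on how the client processes are launched.

```python
import os

# setdefault keeps any value already exported in the shell
os.environ.setdefault("TF_FORCE_GPU_ALLOW_GROWTH", "true")
os.environ.setdefault("TF_GPU_ALLOCATOR", "cuda_malloc_async")

# Import TensorFlow only after the variables are in place, for example:
# import tensorflow as tf

print(os.environ["TF_FORCE_GPU_ALLOW_GROWTH"], os.environ["TF_GPU_ALLOCATOR"])
```

Using setdefault rather than direct assignment means a value passed on the command line still takes precedence.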
If you have more GPUs than clients, consider running one client per GPU using the --gpu argument during simulation, e.g., nvflare simulator -n 2 --gpu 0,1 [job].