Hello TensorFlow

This example demonstrates how to use NVIDIA FLARE to train a TensorFlow image classifier with federated averaging (FedAvg), using TensorFlow's Keras API for local training on each client.

For detailed documentation, see the Hello TensorFlow example page.

We recommend using the NVIDIA TensorFlow docker for GPU support. If GPU is not required, a Python virtual environment is sufficient.

To run this example with the FLARE API, refer to the hello_world notebook.

Run NVIDIA TensorFlow Container

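Ensure the NVIDIA Container Toolkit is installed, then run the following command, replacing [path_to_NVFlare] with the path to your local NVFlare directory and xx.xx with a container release tag: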

docker run --gpus=all -it --rm -v [path_to_NVFlare]:/NVFlare nvcr.io/nvidia/tensorflow:xx.xx-tf2-py3

NVIDIA FLARE Installation

For complete installation instructions, visit Installation.

pip install nvflare

Clone the example code from GitHub:

git clone https://github.com/NVIDIA/NVFlare.git

Switch to the release branch you want to use and navigate to the hello-tf directory:

cd NVFlare
git switch <release branch>
cd examples/hello-world/hello-tf

Install the dependencies:

pip install -r requirements.txt

Code Structure

hello-tf
|
|-- client.py         # client local training script
|-- model.py          # model definition
|-- job.py            # job recipe that defines client and server configurations
|-- requirements.txt  # dependencies

Data

This example uses the MNIST handwritten digits dataset. client.py downloads it with tf.keras.datasets.mnist.load_data() and splits it between the two simulated sites.
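The client script splits MNIST by site name. site-1 takes the first half of each array, and site-2 takes the second half. The sketch below shows that split, generalized to any number of sites.

The helper name `shard_for_site` and the `site-<k>` naming rule for more than two sites are assumptions for illustration. Neither is part of the example or of NVIDIA FLARE.

```python
def shard_for_site(data, site_name, n_sites=2):
    """Return the contiguous slice of `data` assigned to `site_name`.

    Hypothetical helper: assumes sites are named "site-1" .. "site-<n_sites>",
    as in client.py. With n_sites=2 it reproduces the example's half split.
    Unlike client.py, which keeps the full dataset for unknown names, this
    sketch raises ValueError for a site outside the expected range.
    """
    index = int(site_name.split("-")[1]) - 1  # "site-1" -> 0
    if not 0 <= index < n_sites:
        raise ValueError(f"unexpected site name: {site_name}")
    shard_size = len(data) // n_sites
    start = index * shard_size
    # The last site also takes any remainder left by integer division.
    end = len(data) if index == n_sites - 1 else start + shard_size
    return data[start:end]


samples = list(range(10))
print(shard_for_site(samples, "site-1"))  # [0, 1, 2, 3, 4]
print(shard_for_site(samples, "site-2"))  # [5, 6, 7, 8, 9]
```

For odd lengths, site-1 gets `len(data) // 2` items and site-2 gets the rest. This matches the slicing in client.py.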

Model

The model.py file defines a simple neural network using TensorFlow’s Keras API. The Net model is a sequential architecture designed for image classification, featuring:

  • Flatten Layer: Flattens each 28x28 image into a 784-element vector for the dense layers.

  • Dense Layer: 128 units with ReLU activation for non-linearity.

  • Dropout Layer: 20% dropout rate to mitigate overfitting.

  • Output Layer: 10 units that output raw logits, one per digit class (the loss in client.py is configured with from_logits=True).

This model is used in federated learning with NVIDIA FLARE, trained across clients using the FedAvg algorithm.
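The layer sizes above determine how many trainable parameters each client sends to the server every round. The arithmetic below assumes the 28x28 MNIST input used in client.py.

Flatten and Dropout have no weights. Each Dense layer has `inputs * units` kernel weights plus `units` biases.

```python
def dense_params(n_inputs, n_units):
    """Kernel weights plus one bias per unit for a Keras Dense layer."""
    return n_inputs * n_units + n_units


flattened = 28 * 28                    # Flatten: 28x28 image -> 784 features, no weights
hidden = dense_params(flattened, 128)  # Dense(128): 784*128 + 128 = 100,480
output = dense_params(128, 10)         # Dense(10): 128*10 + 10 = 1,290
# Dropout(0.2) adds no weights either.
total = hidden + output
print(total)  # 101770
```

At roughly 100k float32 values (about 0.4 MB), each round's model exchange in this example is small.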

model code (model.py)
from tensorflow.keras import layers, models


class Net(models.Sequential):
    def __init__(self, input_shape=(None, 28, 28)):
        super().__init__()
        self._input_shape = input_shape
        self.add(layers.Flatten())  # 28x28 image -> 784-element vector
        self.add(layers.Dense(128, activation="relu"))
        self.add(layers.Dropout(0.2))  # active only during training
        self.add(layers.Dense(10))  # raw logits for the 10 digit classes

Client Code

The client code client.py is responsible for training. The training code closely resembles standard TensorFlow/Keras training code, with a few additional lines that receive the global model from the server and send the locally trained weights back.
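Stripped of TensorFlow and FLARE, each round in client.py follows a receive, load, train, send cycle (the evaluation step is omitted here). Parameters travel as a dict that maps each layer name to that layer's list of weights.

The sketch below mimics that cycle with plain Python stand-ins. `FakeServer`, `local_train`, the illustrative layer names, and the scalar "weights" are hypothetical substitutes for `flare.receive()`/`flare.send()`, `model.fit()`, and Keras weight arrays. They are not NVIDIA FLARE APIs.

```python
class FakeServer:
    """Hypothetical stand-in for the FLARE client API (receive/send).

    With a single client there is nothing to aggregate, so the "server"
    simply adopts whatever the client sends as the next global model.
    """

    def __init__(self, params, num_rounds):
        self.params = params
        self.num_rounds = num_rounds
        self.current_round = 0
        self.sent = []

    def is_running(self):
        return self.current_round < self.num_rounds

    def receive(self):
        # Like flare.receive(): hand the client the current global params.
        return {"params": dict(self.params), "current_round": self.current_round}

    def send(self, output):
        # Like flare.send(): record the client's update and advance the round.
        self.sent.append(output)
        self.params = output["params"]
        self.current_round += 1


def local_train(params):
    """Hypothetical training step: nudge every weight by +1.0."""
    return {name: [w + 1.0 for w in weights] for name, weights in params.items()}


# Layer name -> list of weights, mirroring {layer.name: layer.get_weights()}.
# Flatten and Dropout layers carry an empty weight list.
initial = {"flatten": [], "dense": [0.0, 0.0], "dropout": [], "dense_1": [0.0]}
server = FakeServer(initial, num_rounds=3)

while server.is_running():
    incoming = server.receive()
    local_params = incoming["params"]  # load the received global weights
    updated = local_train(local_params)
    server.send({"params": updated, "params_type": "FULL",
                 "current_round": incoming["current_round"]})

print(server.params["dense"])  # [3.0, 3.0]
```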

client code (client.py)
import tensorflow as tf
from model import Net

import nvflare.client as flare
from nvflare.client.tracking import SummaryWriter

WEIGHTS_PATH = "./tf_model.weights.h5"


def main():
    flare.init()
    writer = SummaryWriter()

    sys_info = flare.system_info()
    print(f"system info is: {sys_info}", flush=True)

    model = Net()
    model.build(input_shape=(None, 28, 28))
    model.compile(
        optimizer="adam", loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=["accuracy"]
    )
    model.summary()

    (train_images, train_labels), (
        test_images,
        test_labels,
    ) = tf.keras.datasets.mnist.load_data()
    # scale pixel values from [0, 255] to [0, 1]
    train_images, test_images = (
        train_images / 255.0,
        test_images / 255.0,
    )

    # simulate separate datasets for each client by dividing MNIST dataset in half
    client_name = sys_info["site_name"]
    if client_name == "site-1":
        train_images = train_images[: len(train_images) // 2]
        train_labels = train_labels[: len(train_labels) // 2]
        test_images = test_images[: len(test_images) // 2]
        test_labels = test_labels[: len(test_labels) // 2]
    elif client_name == "site-2":
        train_images = train_images[len(train_images) // 2 :]
        train_labels = train_labels[len(train_labels) // 2 :]
        test_images = test_images[len(test_images) // 2 :]
        test_labels = test_labels[len(test_labels) // 2 :]

    while flare.is_running():
        # receive the current global model from the server
        input_model = flare.receive()
        print(f"current_round={input_model.current_round}")

        sys_info = flare.system_info()
        print(f"system info is: {sys_info}")

        # load the global weights into the local model, layer by layer
        for k, v in input_model.params.items():
            model.get_layer(k).set_weights(v)

        # evaluate the received global model on the local test split
        _, test_global_acc = model.evaluate(test_images, test_labels, verbose=2)
        print(
            f"Accuracy of the received model on round {input_model.current_round} on the test images: {test_global_acc * 100} %"
        )
        writer.add_scalar(tag="local_acc", scalar=test_global_acc, global_step=input_model.current_round)

        # training
        model.fit(train_images, train_labels, epochs=1, validation_data=(test_images, test_labels))

        print("Finished Training")

        model.save_weights(WEIGHTS_PATH)

        sys_info = flare.system_info()
        print(f"system info is: {sys_info}", flush=True)
        print(f"finished round: {input_model.current_round}", flush=True)

        # package the locally trained weights, keyed by layer name
        output_model = flare.FLModel(
            params={layer.name: layer.get_weights() for layer in model.layers},
            params_type="FULL",
            metrics={"accuracy": test_global_acc},
            current_round=input_model.current_round,
        )

        # send the update back to the server
        flare.send(output_model)


if __name__ == "__main__":
    main()

Server Code

In each FedAvg round, the server sends the current global model to the clients, gathers their updated weights, and averages them into the next global model (a scatter-gather workflow). This example uses the built-in FedAvg implementation in NVIDIA FLARE, so no custom server code is needed.
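Conceptually, FedAvg's aggregation step takes the parameter dicts returned by the clients and averages them layer by layer. It usually weights each client by its number of training samples.

The sketch below shows that computation on plain Python lists. The function name and the `num_samples` weighting are assumptions for illustration, and this is not the implementation of NVIDIA FLARE's built-in FedAvg. With equal sample counts, as with the two MNIST halves here, the weighted mean reduces to a plain mean.

```python
def fedavg(client_updates):
    """Weighted average of client params (illustrative sketch).

    client_updates: list of (params, num_samples) pairs, where params maps a
    layer name to a flat list of floats (a stand-in for weight arrays).
    """
    total = sum(n for _, n in client_updates)
    if total <= 0:
        raise ValueError("need at least one sample across clients")
    layer_names = client_updates[0][0].keys()
    averaged = {}
    for name in layer_names:
        size = len(client_updates[0][0][name])
        # Element-wise weighted mean across clients for this layer.
        averaged[name] = [
            sum(params[name][i] * n for params, n in client_updates) / total
            for i in range(size)
        ]
    return averaged


site1 = ({"dense": [1.0, 2.0], "flatten": []}, 30000)
site2 = ({"dense": [3.0, 6.0], "flatten": []}, 30000)
print(fedavg([site1, site2]))  # {'dense': [2.0, 4.0], 'flatten': []}
```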

Job Recipe Code

The job recipe includes client.py and the built-in FedAvg algorithm.

job recipe (job.py)
from model import Net

from nvflare.app_opt.tf.recipes.fedavg import FedAvgRecipe
from nvflare.recipe import SimEnv, add_experiment_tracking

if __name__ == "__main__":
    n_clients = 2
    num_rounds = 3
    train_script = "client.py"

    recipe = FedAvgRecipe(
        name="hello-tf_fedavg",
        num_rounds=num_rounds,
        # Model can be specified as class instance or dict config:
        model=Net(),
        # Alternative: model={"class_path": "model.Net", "args": {}},
        # For pre-trained weights: initial_ckpt="/server/path/to/model.h5",
        min_clients=n_clients,
        train_script=train_script,
    )
    add_experiment_tracking(recipe, tracking_type="tensorboard")

    env = SimEnv(num_clients=n_clients)
    run = recipe.execute(env=env)
    print()
    print("Result can be found in :", run.get_result())
    print("Job Status is:", run.get_status())
    print()

Model Input Options

The model parameter accepts two formats:

  1. Class instance (subclassed Keras model): model=Net() - Convenient and Pythonic

  2. Dict config: model={"class_path": "model.Net", "args": {}} - Better for large models
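A dict config names the model class by its import path instead of passing a constructed object. One plausible way to turn such a config into an instance is with `importlib`, sketched below.

`build_from_config` is a hypothetical helper, and this is not necessarily how NVIDIA FLARE resolves `class_path` internally. The usage line uses a standard-library class so the sketch runs without TensorFlow. With this example's layout, `{"class_path": "model.Net", "args": {}}` would import `Net` from model.py.

```python
import importlib


def build_from_config(config):
    """Instantiate config["class_path"] with keyword arguments config["args"]."""
    module_name, _, class_name = config["class_path"].rpartition(".")
    if not module_name:
        raise ValueError(f"class_path must be 'module.Class': {config['class_path']!r}")
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**config.get("args", {}))


counter = build_from_config({"class_path": "collections.Counter", "args": {"a": 2}})
print(counter)  # Counter({'a': 2})
```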

To resume from pre-trained weights:

recipe = FedAvgRecipe(
    model=Net(),
    initial_ckpt="/server/path/to/pretrained.h5",  # Absolute path
    ...
)

Note

For TensorFlow/Keras, pass either an instance of a subclassed Keras model (for example, Net()) or a dict config as model. A full SavedModel or .h5 model file (as opposed to a weights-only file) contains both the architecture and the weights, so initial_ckpt can be used without model.

Run the Experiment

Run job.py, which uses the job recipe API to create the job and execute it in the simulator:

TF_FORCE_GPU_ALLOW_GROWTH=true TF_GPU_ALLOCATOR=cuda_malloc_async python3 job.py

Access the Logs and Results

Find the running logs and results inside the simulator’s workspace:

$ ls /tmp/nvflare/jobs/workdir
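job.py enables TensorBoard tracking, so the `local_acc` scalars logged by each client are expected to be written as TensorBoard event files somewhere under the workspace. Their exact subdirectories are not documented here.

The sketch below therefore searches the whole tree for files that follow TensorBoard's usual `events.out.tfevents.*` naming. `find_event_files` is a hypothetical helper.

```python
from pathlib import Path


def find_event_files(root):
    """Return sorted paths of TensorBoard event files below `root`."""
    root = Path(root)
    if not root.is_dir():
        return []  # workspace not created yet (job not run)
    return sorted(p for p in root.rglob("events.out.tfevents.*") if p.is_file())


for path in find_event_files("/tmp/nvflare/jobs/workdir"):
    print(path)
```

Any directory it reports can then be passed to `tensorboard --logdir`.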

Notes on Running with GPUs

When using GPUs, TensorFlow attempts to allocate all available GPU memory at startup. To prevent this in multi-client scenarios, set the following flags:

TF_FORCE_GPU_ALLOW_GROWTH=true TF_GPU_ALLOCATOR=cuda_malloc_async
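The same two variables can also be set from Python, for example at the top of job.py, if you prefer to keep them out of the shell command. Processes that the simulator launches are expected to inherit them.

They must be in place before TensorFlow initializes the GPU. This sketch uses `setdefault`, so values already exported in the shell take precedence.

```python
import os

# Must run before TensorFlow touches the GPU; placing it ahead of
# "import tensorflow" is the safe choice. setdefault keeps any value
# already exported in the shell.
os.environ.setdefault("TF_FORCE_GPU_ALLOW_GROWTH", "true")
os.environ.setdefault("TF_GPU_ALLOCATOR", "cuda_malloc_async")

print(os.environ["TF_FORCE_GPU_ALLOW_GROWTH"], os.environ["TF_GPU_ALLOCATOR"])
```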

If you have more GPUs than clients, consider running one client per GPU using the --gpu argument during simulation, e.g., nvflare simulator -n 2 --gpu 0,1 [job].