Hello Cyclic Weight Transfer

Cyclic Weight Transfer (CWT) is an alternative to FedAvg. Instead of averaging updates from all sites on the server, CWT uses the Cyclic Controller to pass the model weights from one site to the next, so that each site fine-tunes the model in turn.
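The idea can be illustrated without NVFlare at all. The following is a minimal sketch, assuming a toy least-squares model and a fixed site order; `local_finetune` and `cyclic_weight_transfer` are hypothetical names used only for this illustration and are not part of the NVFlare API.

```python
import numpy as np


def local_finetune(weights, x, y, lr=0.1, steps=20):
    """Hypothetical local trainer: a few gradient steps of least-squares regression."""
    w = weights.copy()
    for _ in range(steps):
        grad = 2.0 * x.T @ (x @ w - y) / len(y)
        w -= lr * grad
    return w


def cyclic_weight_transfer(initial_weights, site_data, num_rounds):
    """Pass a single set of weights through every site in turn, num_rounds times."""
    weights = initial_weights
    for _ in range(num_rounds):
        for x, y in site_data:  # fixed site order (an assumption of this sketch)
            weights = local_finetune(weights, x, y)
    return weights


rng = np.random.default_rng(0)
true_w = np.array([2.0, -1.0])

# Two sites, each holding its own private samples of the same relationship.
site_data = []
for _ in range(2):
    x = rng.normal(size=(50, 2))
    site_data.append((x, x @ true_w))

final_w = cyclic_weight_transfer(np.zeros(2), site_data, num_rounds=3)
print(final_w)
```

No averaging step appears anywhere: the weights that leave one site are exactly the weights the next site starts from.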

Note

This example uses the MNIST handwritten digits dataset and will load its data within the trainer code.

Running TensorFlow with GPU

We recommend using the NVIDIA TensorFlow Docker container if you want to use a GPU. If you don’t need a GPU, a Python virtual environment is sufficient.

Run NVIDIA TensorFlow container

Please install the NVIDIA container toolkit first. Then run the following command:

docker run --gpus=all -it --rm -v [path_to_NVFlare]:/NVFlare nvcr.io/nvidia/tensorflow:xx.xx-tf2-py3

Notes on running with GPUs

If you run the example on GPUs, note that by default TensorFlow attempts to allocate all available GPU memory at startup. When multiple clients share a GPU, you have to prevent this by setting the following environment variables.

TF_FORCE_GPU_ALLOW_GROWTH=true TF_GPU_ALLOCATOR=cuda_malloc_async
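If exporting the variables in the shell is inconvenient, they can also be set from Python. This is a minimal sketch, assuming it runs at the very top of the training script; the helper name `apply_gpu_flags` is hypothetical. The safest place to call it is before `import tensorflow`, so that the values are in place when TensorFlow initializes the GPU.

```python
import os

# The two variables from this section.
GPU_FLAGS = {
    "TF_FORCE_GPU_ALLOW_GROWTH": "true",
    "TF_GPU_ALLOCATOR": "cuda_malloc_async",
}


def apply_gpu_flags(env=os.environ):
    """Set each flag unless it is already defined, and return the effective values."""
    for name, value in GPU_FLAGS.items():
        env.setdefault(name, value)  # keep any value exported in the shell
    return {name: env[name] for name in GPU_FLAGS}


print(apply_gpu_flags())
```

Using `setdefault` means a value already exported in the shell takes precedence over the defaults in the script.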

Install NVFlare

For the complete installation instructions, see Installation.

pip install nvflare

Get the example code from GitHub:

git clone https://github.com/NVIDIA/NVFlare.git
cd NVFlare
git switch <release branch>
cd examples/hello-world/hello-cyclic

Install the dependencies

pip install -r requirements.txt

Code Structure

The example directory is organized as follows:

hello-cyclic
|
|-- client.py           # client local training script
|-- model.py            # model definition
|-- job.py              # job recipe that defines client and server configurations
|-- prepare_data.sh     # script to download the data
|-- requirements.txt    # dependencies

Data

In this example, we will use the MNIST dataset, which is provided by the TensorFlow Keras API.

Model

The model.py file defines a simple neural network using TensorFlow’s Keras API. The Net model is a sequential architecture designed for image classification, featuring:

  • Flatten Layer: Prepares input data for dense layers.

  • Dense Layer: 128 units with ReLU activation for non-linearity.

  • Dropout Layer: 20% dropout rate to mitigate overfitting.

  • Output Layer: 10 units for classifying MNIST digits.

Model (model.py)
from tensorflow.keras import layers, models


class Net(models.Sequential):
    def __init__(self, input_shape=(None, 28, 28)):
        super().__init__()
        self._input_shape = input_shape
        self.add(layers.Flatten())
        self.add(layers.Dense(128, activation="relu"))
        self.add(layers.Dropout(0.2))
        self.add(layers.Dense(10))
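As a quick sanity check on the size of this network, the number of trainable parameters can be worked out by hand. The sketch below uses only the layer sizes listed above; `dense_params` is a hypothetical helper, and the Flatten and Dropout layers contribute no parameters.

```python
def dense_params(n_inputs, n_units):
    """Parameters of a fully connected layer: one weight per input/unit pair plus one bias per unit."""
    return n_inputs * n_units + n_units


flattened = 28 * 28                     # Flatten turns each 28x28 image into 784 values
hidden = dense_params(flattened, 128)   # Dense(128, activation="relu")
output = dense_params(128, 10)          # Dense(10)
total = hidden + output

print(hidden, output, total)  # prints: 100480 1290 101770
```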

Client Code

The client code in client.py is responsible for local training. Notice that the training code is almost identical to standard TensorFlow Keras training code. The only difference is a few added lines that receive the model from the server and send the trained result back.

Client Code (client.py)
import tensorflow as tf
from model import Net

import nvflare.client as flare

WEIGHTS_PATH = "./tf_model.weights.h5"


def main():
    flare.init()

    sys_info = flare.system_info()
    print(f"system info is: {sys_info}", flush=True)

    model = Net()
    model.build(input_shape=(None, 28, 28))
    model.compile(
        optimizer="adam", loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=["accuracy"]
    )
    model.summary()

    (train_images, train_labels), (
        test_images,
        test_labels,
    ) = tf.keras.datasets.mnist.load_data()
    train_images, test_images = (
        train_images / 255.0,
        test_images / 255.0,
    )

    # simulate separate datasets for each client by dividing MNIST dataset in half
    client_name = sys_info["site_name"]
    if client_name == "site-1":
        train_images = train_images[: len(train_images) // 2]
        train_labels = train_labels[: len(train_labels) // 2]
        test_images = test_images[: len(test_images) // 2]
        test_labels = test_labels[: len(test_labels) // 2]
    elif client_name == "site-2":
        train_images = train_images[len(train_images) // 2 :]
        train_labels = train_labels[len(train_labels) // 2 :]
        test_images = test_images[len(test_images) // 2 :]
        test_labels = test_labels[len(test_labels) // 2 :]

    while flare.is_running():
        input_model = flare.receive()
        print(f"current_round={input_model.current_round}")

        sys_info = flare.system_info()
        print(f"system info is: {sys_info}")

        for k, v in input_model.params.items():
            model.get_layer(k).set_weights(v)

        _, test_global_acc = model.evaluate(test_images, test_labels, verbose=2)
        print(
            f"Accuracy of the received model on round {input_model.current_round} on the test images: {test_global_acc * 100} %"
        )

        # training
        model.fit(train_images, train_labels, epochs=1, validation_data=(test_images, test_labels))

        print("Finished Training")

        model.save_weights(WEIGHTS_PATH)

        sys_info = flare.system_info()
        print(f"system info is: {sys_info}", flush=True)
        print(f"finished round: {input_model.current_round}", flush=True)

        output_model = flare.FLModel(
            params={layer.name: layer.get_weights() for layer in model.layers},
            params_type="FULL",
            metrics={"accuracy": test_global_acc},
            current_round=input_model.current_round,
        )

        flare.send(output_model)


if __name__ == "__main__":
    main()
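The client script hard-codes the split for two sites. If you want to experiment with more clients, the same idea generalizes to contiguous shards. The following is a minimal sketch, assuming site names follow the `site-<n>` pattern used above with 1-based numbering; `partition_for_site` is a hypothetical helper and is not part of the example code.

```python
import numpy as np


def partition_for_site(images, labels, site_name, n_sites):
    """Return the contiguous shard of (images, labels) for a site named 'site-<n>'."""
    index = int(site_name.rsplit("-", 1)[1]) - 1  # "site-1" -> shard 0
    if not 0 <= index < n_sites:
        raise ValueError(f"{site_name} is outside the range of {n_sites} sites")
    # array_split tolerates lengths that are not divisible by n_sites; for even
    # lengths and two sites it yields the same halves as the slicing in client.py.
    image_shards = np.array_split(images, n_sites)
    label_shards = np.array_split(labels, n_sites)
    return image_shards[index], label_shards[index]


# Stand-in data with the MNIST image shape, so the sketch runs without downloading anything.
fake_images = np.zeros((10, 28, 28))
fake_labels = np.arange(10)

first_images, first_labels = partition_for_site(fake_images, fake_labels, "site-1", 2)
second_images, second_labels = partition_for_site(fake_images, fake_labels, "site-2", 2)
print(first_labels, second_labels)
```

Images and labels are split with the same call and the same shard count, which keeps each image aligned with its label.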

Server Code

In cyclic weight transfer, the server code is responsible for relaying the model from one client to the next. We will directly use the default federated cyclic algorithm provided by NVFlare.
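To make the hand-off order concrete, the sketch below enumerates which client trains at each step. It assumes a fixed client order that repeats every round, which is an assumption of this illustration rather than a statement about how NVFlare orders clients; `relay_order` is a hypothetical helper.

```python
def relay_order(clients, num_rounds):
    """List the (round, client) training steps in the order the model visits them."""
    order = []
    for round_num in range(num_rounds):
        for client in clients:
            order.append((round_num, client))
    return order


steps = relay_order(["site-1", "site-2"], num_rounds=3)

# Pair every step with the one that follows it to show where the result goes next.
for (round_num, client), next_step in zip(steps, steps[1:] + [None]):
    target = next_step[1] if next_step else "the server as the final model"
    print(f"round {round_num}: {client} trains, result goes to {target}")
```

With two clients and three rounds, the model is trained six times in total, once per client per round.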

Job Recipe

Job Recipe (job.py)
from model import Net

from nvflare.app_opt.tf.recipes.cyclic import CyclicRecipe
from nvflare.recipe import SimEnv

if __name__ == "__main__":
    n_clients = 2
    num_rounds = 3
    train_script = "client.py"

    recipe = CyclicRecipe(
        num_rounds=num_rounds,
        min_clients=n_clients,
        # Model can be specified as class instance or dict config:
        model=Net(),
        # Alternative: model={"class_path": "model.Net", "args": {}},
        # For pre-trained weights: initial_ckpt="/server/path/to/model.h5",
        train_script=train_script,
    )

    env = SimEnv(num_clients=n_clients)
    run = recipe.execute(env=env)
    print()
    print("Result can be found in :", run.get_result())
    print("Job Status is:", run.get_status())
    print()

Run the Experiment

Prepare the data first, then run the job:

bash ./prepare_data.sh
python job.py

Access the Logs and Results

You can find the running logs and results inside the simulator’s workspace:

$ ls "/tmp/nvflare/simulation/cyclic"
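To inspect the workspace from Python instead of the shell, a short script along these lines can help. It is a minimal sketch that assumes nothing about the files inside the workspace; `list_workspace_files` is a hypothetical helper, and the default path is the one shown above.

```python
from pathlib import Path


def list_workspace_files(workspace="/tmp/nvflare/simulation/cyclic", pattern="*"):
    """Return sorted paths, relative to the workspace, of all files matching pattern."""
    root = Path(workspace)
    if not root.is_dir():
        return []  # the job has not been run yet, or it used another workspace
    return sorted(str(p.relative_to(root)) for p in root.rglob(pattern) if p.is_file())


for name in list_workspace_files():
    print(name)
```

Passing a pattern such as `"*.txt"` narrows the listing to matching files only.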