Hello Cyclic Weight Transfer
Cyclic Weight Transfer (CWT) is an alternative to FedAvg. CWT uses the Cyclic Controller to pass the model weights from one site to the next for repeated fine-tuning.
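The relay idea can be illustrated with a minimal, framework-free sketch. Everything in it is hypothetical and for illustration only: `local_update` stands in for a site's fine-tuning step, the weights are a plain list of floats, and the fixed visiting order is an assumption. It is not NVFlare's Cyclic Controller.

```python
from typing import Dict, List, Tuple

Weights = List[float]


def local_update(weights: Weights, site_target: float, lr: float = 0.5) -> Weights:
    """Hypothetical stand-in for fine-tuning: pull each weight toward the site's target."""
    return [w + lr * (site_target - w) for w in weights]


def cyclic_weight_transfer(
    initial: Weights, site_targets: Dict[str, float], num_rounds: int
) -> Tuple[Weights, List[str]]:
    """Pass one set of weights from site to site, fine-tuning at each stop."""
    weights = list(initial)
    visits: List[str] = []
    for _ in range(num_rounds):
        # Assumed fixed order: the dictionary's insertion order.
        for site, target in site_targets.items():
            # Each site starts from the weights left by the previous site.
            weights = local_update(weights, target)
            visits.append(site)
    return weights, visits


final_weights, visit_order = cyclic_weight_transfer(
    [0.0], {"site-1": 1.0, "site-2": 3.0}, num_rounds=3
)
print(visit_order)
print(final_weights)
```

Unlike FedAvg, no averaging step appears anywhere in this loop: a single model travels through the sites, so the result depends on the visiting order.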
Note
This example uses the MNIST handwritten digits dataset and will load its data within the trainer code.
Running TensorFlow with GPU
We recommend using the NVIDIA TensorFlow Docker container if you want to use a GPU. If you don’t need a GPU, a Python virtual environment is sufficient.
Run NVIDIA TensorFlow container
Please install the NVIDIA container toolkit first. Then run the following command:
docker run --gpus=all -it --rm -v [path_to_NVFlare]:/NVFlare nvcr.io/nvidia/tensorflow:xx.xx-tf2-py3
Notes on running with GPUs
If you run the example on GPUs, note that by default TensorFlow attempts to allocate all available GPU memory at startup. When multiple clients are involved, you must prevent this by setting the following environment variables:
TF_FORCE_GPU_ALLOW_GROWTH=true TF_GPU_ALLOCATOR=cuda_malloc_async
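One way to apply these settings, sketched here as an assumption rather than the example's own mechanism, is to place them in the process environment from Python before TensorFlow is imported. `configure_tf_gpu_env` is a hypothetical helper name; only the two variable names and values come from the note above.

```python
import os
from typing import Dict


def configure_tf_gpu_env() -> Dict[str, str]:
    """Set the two GPU memory variables unless the user has already set them."""
    settings = {
        "TF_FORCE_GPU_ALLOW_GROWTH": "true",
        "TF_GPU_ALLOCATOR": "cuda_malloc_async",
    }
    for name, value in settings.items():
        # setdefault keeps any value that was exported in the shell.
        os.environ.setdefault(name, value)
    return {name: os.environ[name] for name in settings}


# Call this before `import tensorflow` so the values are in place
# by the time TensorFlow sets up its GPU devices.
applied = configure_tf_gpu_env()
print(applied)
```

Exporting the same variables in the shell that launches the job achieves the same effect without touching the code.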
Install NVFlare
For the complete installation instructions, see Installation.
pip install nvflare
Get the example code from GitHub:
git clone https://github.com/NVIDIA/NVFlare.git
cd NVFlare
git switch <release branch>
cd examples/hello-world/hello-cyclic
Install the dependencies:
pip install -r requirements.txt
Code Structure
hello-cyclic
|
|-- client.py # client local training script
|-- model.py # model definition
|-- job.py # job recipe that defines client and server configurations
|-- prepare_data.sh # script to download the data
|-- requirements.txt # dependencies
Data
In this example, we use the MNIST dataset, which is provided by the TensorFlow Keras API.
Model
The model.py file defines a simple neural network using TensorFlow’s Keras API. The Net model is a sequential architecture designed for image classification, featuring:
Flatten Layer: Prepares input data for dense layers.
Dense Layer: 128 units with ReLU activation for non-linearity.
Dropout Layer: 20% dropout rate to mitigate overfitting.
Output Layer: 10 units for classifying MNIST digits.
from tensorflow.keras import layers, models


class Net(models.Sequential):
    def __init__(self, input_shape=(None, 28, 28)):
        super().__init__()
        self._input_shape = input_shape
        self.add(layers.Flatten())
        self.add(layers.Dense(128, activation="relu"))
        self.add(layers.Dropout(0.2))
        self.add(layers.Dense(10))
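As a sanity check on the layer list above, the trainable parameter count can be worked out by hand. The sketch below assumes 28x28 inputs (matching the `input_shape` default) and the usual dense-layer count of `inputs * units + units`; `dense_params` is a helper introduced here for illustration and is not part of model.py.

```python
def dense_params(n_inputs: int, n_units: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return n_inputs * n_units + n_units


flatten_out = 28 * 28                     # Flatten has no parameters
hidden = dense_params(flatten_out, 128)   # Dense(128, relu)
output = dense_params(128, 10)            # Dense(10); Dropout adds no parameters
total = hidden + output

print(f"hidden={hidden} output={output} total={total}")
```

If the model is built with a 28x28 input as in client.py, the total reported by `model.summary()` should agree with this figure.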
Client Code
The client code in client.py is responsible for training. Notice that the training code is almost identical to standard TensorFlow training code.
The only difference is that we added a few lines to receive data from and send data to the server.
import tensorflow as tf
from model import Net

import nvflare.client as flare

WEIGHTS_PATH = "./tf_model.weights.h5"


def main():
    flare.init()

    sys_info = flare.system_info()
    print(f"system info is: {sys_info}", flush=True)

    model = Net()
    model.build(input_shape=(None, 28, 28))
    model.compile(
        optimizer="adam", loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=["accuracy"]
    )
    model.summary()

    (train_images, train_labels), (
        test_images,
        test_labels,
    ) = tf.keras.datasets.mnist.load_data()
    train_images, test_images = (
        train_images / 255.0,
        test_images / 255.0,
    )

    # simulate separate datasets for each client by dividing MNIST dataset in half
    client_name = sys_info["site_name"]
    if client_name == "site-1":
        train_images = train_images[: len(train_images) // 2]
        train_labels = train_labels[: len(train_labels) // 2]
        test_images = test_images[: len(test_images) // 2]
        test_labels = test_labels[: len(test_labels) // 2]
    elif client_name == "site-2":
        train_images = train_images[len(train_images) // 2 :]
        train_labels = train_labels[len(train_labels) // 2 :]
        test_images = test_images[len(test_images) // 2 :]
        test_labels = test_labels[len(test_labels) // 2 :]

    while flare.is_running():
        input_model = flare.receive()
        print(f"current_round={input_model.current_round}")

        sys_info = flare.system_info()
        print(f"system info is: {sys_info}")

        for k, v in input_model.params.items():
            model.get_layer(k).set_weights(v)

        _, test_global_acc = model.evaluate(test_images, test_labels, verbose=2)
        print(
            f"Accuracy of the received model on round {input_model.current_round} on the test images: {test_global_acc * 100} %"
        )

        # training
        model.fit(train_images, train_labels, epochs=1, validation_data=(test_images, test_labels))

        print("Finished Training")

        model.save_weights(WEIGHTS_PATH)

        sys_info = flare.system_info()
        print(f"system info is: {sys_info}", flush=True)
        print(f"finished round: {input_model.current_round}", flush=True)

        output_model = flare.FLModel(
            params={layer.name: layer.get_weights() for layer in model.layers},
            params_type="FULL",
            metrics={"accuracy": test_global_acc},
            current_round=input_model.current_round,
        )

        flare.send(output_model)


if __name__ == "__main__":
    main()
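The per-site split in the script can be isolated into a small helper to make the partitioning rule explicit. This is a sketch that uses plain Python lists in place of the MNIST arrays; `split_for_site` is a hypothetical name that does not exist in client.py. Like the script, it leaves the data untouched for any site name other than `site-1` or `site-2`.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def split_for_site(data: Sequence[T], site_name: str) -> List[T]:
    """Return the first half for site-1, the second half for site-2, else everything."""
    half = len(data) // 2
    if site_name == "site-1":
        return list(data[:half])
    if site_name == "site-2":
        return list(data[half:])
    # Any other site name falls through and keeps the full dataset.
    return list(data)


samples = list(range(10))
for name in ("site-1", "site-2", "site-3"):
    print(name, split_for_site(samples, name))
```

The fall-through case matters when the simulation runs with more than two clients: under this rule, every additional site would train on the full dataset.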
Server Code
In cyclic weight transfer, the server code is responsible for relaying the model weights from one client to the next. We directly use the default cyclic algorithm provided by NVFlare.
Job Recipe
from model import Net

from nvflare.app_opt.tf.recipes.cyclic import CyclicRecipe
from nvflare.recipe import SimEnv

if __name__ == "__main__":
    n_clients = 2
    num_rounds = 3
    train_script = "client.py"

    recipe = CyclicRecipe(
        num_rounds=num_rounds,
        min_clients=n_clients,
        # Model can be specified as class instance or dict config:
        model=Net(),
        # Alternative: model={"class_path": "model.Net", "args": {}},
        # For pre-trained weights: initial_ckpt="/server/path/to/model.h5",
        train_script=train_script,
    )

    env = SimEnv(num_clients=n_clients)
    run = recipe.execute(env=env)
    print()
    print("Result can be found in :", run.get_result())
    print("Job Status is:", run.get_status())
    print()
Run the Experiment
Prepare the data first:
bash ./prepare_data.sh
Then run the job:
python job.py
Access the Logs and Results
You can find the running logs and results inside the simulator’s workspace:
$ ls "/tmp/nvflare/simulation/cyclic"