Hello JAX
Warning
This example tracks main, which is the NVFlare development branch for the next release.
On main, example requirements.txt files may pin the first upcoming NVFlare version
that supports a feature, even before that package is published on PyPI.
If the pinned nvflare version is not available yet, install NVFlare from this repo
instead of from PyPI.
This example demonstrates how to use NVIDIA FLARE with JAX, Flax, and Optax to train an MNIST classifier using federated averaging (FedAvg). It follows the same hello-world recipe structure as hello-pt, but uses a JAX client training loop and a flattened parameter vector for transport.
Install NVFLARE and Dependencies
For the complete installation instructions, see Installation.
pip install nvflare
First get the example code from GitHub:
git clone https://github.com/NVIDIA/NVFlare.git
Then enter the cloned repository, switch to the release branch, and navigate to the hello-jax directory:
cd NVFlare
git switch <release branch>
cd examples/hello-world/hello-jax
Install the dependencies:
pip install -r requirements.txt
Code Structure
hello-jax
|
|-- client.py # client local training script
|-- model.py # JAX/Flax model helpers
|-- prepare_data.py # helper that downloads MNIST and writes .npy files
|-- prepare_model.py # helper that writes the initial flattened checkpoint
|-- job.py # job recipe that defines client and server configurations
|-- requirements.txt # dependencies
Data
This example uses the MNIST dataset. The prepare_data.py helper downloads the raw MNIST files once, before the simulator starts, and converts them into .npy files. The job script only checks that those files exist, and each client then loads from that prepared cache.
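Each client keeps only its own shard of the cached arrays. The sketch below mirrors the split_for_client logic in client.py using plain Python index ranges instead of arrays: the trailing number in a 1-indexed site name such as site-1 selects one of num_partitions contiguous, near-equal slices. The helper name shard_range is hypothetical, and the slice sizes follow the np.array_split rule that the first n % k shards receive one extra element.

```python
import re


def shard_range(site_name: str, num_examples: int, num_partitions: int) -> range:
    """Return the index range a site would keep (hypothetical helper)."""
    match = re.search(r"(\d+)$", site_name)
    if not match:
        # No trailing number: the site keeps the full dataset.
        return range(num_examples)

    client_idx = int(match.group(1)) - 1  # site names are 1-indexed
    if not 0 <= client_idx < num_partitions:
        raise ValueError(f"{site_name} does not map to one of {num_partitions} partitions")

    # Same sizing rule as np.array_split: the first `extra` shards get one more item.
    base, extra = divmod(num_examples, num_partitions)
    start = client_idx * base + min(client_idx, extra)
    stop = start + base + (1 if client_idx < extra else 0)
    return range(start, stop)


if __name__ == "__main__":
    for site in ("site-1", "site-2"):
        r = shard_range(site, num_examples=60000, num_partitions=2)
        print(site, r.start, r.stop)
```

With two sites and the 60,000 MNIST training images, each site ends up with a contiguous block of 30,000 examples.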
Model
The model in model.py is a small convolutional neural network implemented with Flax.
"""
JAX/Flax model utilities for the hello-jax MNIST example.
"""

from functools import lru_cache

import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax import linen as nn
from flax.training import train_state
from jax.flatten_util import ravel_pytree


class ConvNet(nn.Module):
    """Small CNN for MNIST classification."""

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(features=128)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x


MODEL = ConvNet()
INPUT_SHAPE = (1, 28, 28, 1)


@lru_cache(maxsize=1)
def _template_tree_and_unravel_fn():
    params = MODEL.init(jax.random.PRNGKey(0), jnp.ones(INPUT_SHAPE, dtype=jnp.float32))["params"]
    _, unravel_fn = ravel_pytree(params)
    return params, unravel_fn


def create_initial_params():
    params, _ = _template_tree_and_unravel_fn()
    return params


def flatten_params(params) -> np.ndarray:
    flat_params, _ = ravel_pytree(params)
    return np.asarray(flat_params, dtype=np.float32)


def unflatten_params(flat_params):
    _, unravel_fn = _template_tree_and_unravel_fn()
    return unravel_fn(jnp.asarray(flat_params, dtype=jnp.float32))


def create_train_state(params, learning_rate: float, momentum: float) -> train_state.TrainState:
    tx = optax.sgd(learning_rate=learning_rate, momentum=momentum)
    return train_state.TrainState.create(apply_fn=MODEL.apply, params=params, tx=tx)
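The flatten_params and unflatten_params helpers rest on jax.flatten_util.ravel_pytree, which returns a 1-D array together with a function that restores the original tree structure. The following is a minimal sketch on a toy two-leaf tree rather than the ConvNet parameters; the names toy_params, flat_np, and restored are illustrative only.

```python
import jax.numpy as jnp
import numpy as np
from jax.flatten_util import ravel_pytree

# A toy parameter tree standing in for the Flax "params" dict (illustrative only).
toy_params = {
    "dense": {
        "kernel": jnp.arange(6, dtype=jnp.float32).reshape(2, 3),
        "bias": jnp.zeros(3, dtype=jnp.float32),
    }
}

# ravel_pytree concatenates every leaf into one 1-D array and returns
# a function that rebuilds the original structure from such a vector.
flat, unravel_fn = ravel_pytree(toy_params)
flat_np = np.asarray(flat, dtype=np.float32)  # NumPy vector used for transport

# Rebuild the tree from the NumPy vector, as unflatten_params does.
restored = unravel_fn(jnp.asarray(flat_np))

print(flat_np.shape)
print(restored["dense"]["kernel"].shape)
```

The unravel function is tied to the tree structure it was created from, so the sender and the receiver need a template built from the same model definition. In model.py that template comes from MODEL.init and is cached with lru_cache.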
Client Code
The client code (client.py) keeps the local training loop in JAX while using NVFlare’s client API to receive the current global model and return the updated parameters.
"""
Client-side JAX/Flax training script for the hello-jax example.
"""

import argparse
import math
import re

import jax
import jax.numpy as jnp
import numpy as np
import optax
from model import MODEL, create_train_state, flatten_params, unflatten_params

import nvflare.client as flare
from nvflare.apis.fl_constant import FLMetaKey
from nvflare.app_common.np.constants import NPConstants


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--learning_rate", type=float, default=0.05)
    parser.add_argument("--momentum", type=float, default=0.9)
    parser.add_argument("--num_partitions", type=int, default=2)
    parser.add_argument("--data_dir", type=str, default="/tmp/nvflare/data/hello-jax/mnist")
    return parser.parse_args()


def load_mnist(data_dir: str):
    train_images = np.load(f"{data_dir}/train_images.npy").astype(np.float32) / 255.0
    train_labels = np.load(f"{data_dir}/train_labels.npy").astype(np.int32)
    test_images = np.load(f"{data_dir}/test_images.npy").astype(np.float32) / 255.0
    test_labels = np.load(f"{data_dir}/test_labels.npy").astype(np.int32)

    return (train_images, train_labels), (test_images, test_labels)


def split_for_client(images, labels, client_name: str, num_partitions: int):
    match = re.search(r"(\d+)$", client_name)
    if not match:
        return images, labels

    client_number = int(match.group(1))
    if client_number <= 0:
        raise ValueError(f"Client name '{client_name}' must use 1-indexed site numbering.")

    client_idx = client_number - 1
    partitions = max(num_partitions, 1)
    image_splits = np.array_split(images, partitions)
    label_splits = np.array_split(labels, partitions)
    if client_idx >= len(image_splits):
        raise ValueError(
            f"Client index {client_idx + 1} from site name '{client_name}' exceeds available partitions "
            f"{len(image_splits)}."
        )
    return image_splits[client_idx], label_splits[client_idx]


@jax.jit
def train_step(state, images, labels):
    def loss_fn(params):
        logits = MODEL.apply({"params": params}, images)
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
        return loss, logits

    (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    state = state.apply_gradients(grads=grads)
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)
    return state, loss, accuracy


@jax.jit
def eval_step(params, images, labels):
    logits = MODEL.apply({"params": params}, images)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)
    return loss, accuracy


def train_epoch(state, images, labels, batch_size: int, rng):
    num_examples = len(images)
    if num_examples == 0:
        raise ValueError("No training data available for this client.")

    permutation = np.asarray(jax.random.permutation(rng, num_examples))
    total_loss = 0.0
    total_accuracy = 0.0
    steps = 0

    for start in range(0, num_examples, batch_size):
        end = start + batch_size
        indices = permutation[start:end]
        batch_images = jnp.asarray(images[indices])
        batch_labels = jnp.asarray(labels[indices])
        state, loss, accuracy = train_step(state, batch_images, batch_labels)
        total_loss += float(loss)
        total_accuracy += float(accuracy)
        steps += 1

    return state, total_loss / steps, total_accuracy / steps, steps


def evaluate(params, images, labels, batch_size: int):
    num_examples = len(images)
    if num_examples == 0:
        raise ValueError("No evaluation data available for this client.")

    total_loss = 0.0
    total_accuracy = 0.0
    steps = 0

    for start in range(0, num_examples, batch_size):
        end = start + batch_size
        batch_images = jnp.asarray(images[start:end])
        batch_labels = jnp.asarray(labels[start:end])
        loss, accuracy = eval_step(params, batch_images, batch_labels)
        total_loss += float(loss)
        total_accuracy += float(accuracy)
        steps += 1

    return total_loss / steps, total_accuracy / steps


def main():
    args = parse_args()
    flare.init()

    sys_info = flare.system_info()
    client_name = sys_info["site_name"]

    (train_images, train_labels), (test_images, test_labels) = load_mnist(args.data_dir)
    train_images, train_labels = split_for_client(train_images, train_labels, client_name, args.num_partitions)
    test_images, test_labels = split_for_client(test_images, test_labels, client_name, args.num_partitions)

    print(f"site={client_name}, train_samples={len(train_images)}, test_samples={len(test_images)}")

    rng = jax.random.PRNGKey(0)
    while flare.is_running():
        input_model = flare.receive()
        current_round = input_model.current_round
        flat_params = input_model.params[NPConstants.NUMPY_KEY]
        params = unflatten_params(flat_params)

        received_eval_loss, received_accuracy = evaluate(params, test_images, test_labels, args.batch_size)
        print(
            f"site={client_name}, round={current_round}, "
            f"received_model_eval_loss={received_eval_loss:.4f}, accuracy={received_accuracy:.4f}"
        )

        if flare.is_evaluate():
            flare.send(flare.FLModel(metrics={"accuracy": received_accuracy, "eval_loss": received_eval_loss}))
            continue

        state = create_train_state(params, learning_rate=args.learning_rate, momentum=args.momentum)
        steps_per_epoch = math.ceil(len(train_images) / args.batch_size)

        for epoch in range(args.epochs):
            rng, epoch_rng = jax.random.split(rng)
            state, train_loss, train_accuracy, _ = train_epoch(
                state,
                train_images,
                train_labels,
                args.batch_size,
                epoch_rng,
            )
            print(
                f"site={client_name}, round={current_round}, epoch={epoch + 1}, "
                f"train_loss={train_loss:.4f}, train_accuracy={train_accuracy:.4f}"
            )

        updated_eval_loss, updated_accuracy = evaluate(state.params, test_images, test_labels, args.batch_size)
        print(
            f"site={client_name}, round={current_round}, "
            f"trained_model_eval_loss={updated_eval_loss:.4f}, accuracy={updated_accuracy:.4f}"
        )

        updated_params = flatten_params(state.params)
        output_model = flare.FLModel(
            params={NPConstants.NUMPY_KEY: updated_params},
            params_type=flare.ParamsType.FULL,
            metrics={"accuracy": updated_accuracy, "eval_loss": updated_eval_loss},
            meta={FLMetaKey.NUM_STEPS_CURRENT_ROUND: args.epochs * steps_per_epoch},
        )
        flare.send(output_model)


if __name__ == "__main__":
    main()
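The NUM_STEPS_CURRENT_ROUND value that the client reports in meta is the number of epochs times the number of batches per epoch, where the last batch may be partial. As a worked example with the default arguments, and assuming the standard 60,000-image MNIST training set split evenly across two sites:

```python
import math

train_examples = 60_000   # standard MNIST training set size
num_partitions = 2        # default --num_partitions
batch_size = 128          # default --batch_size
epochs = 1                # default --epochs

examples_per_site = train_examples // num_partitions
# ceil keeps the final partial batch (30000 = 234 * 128 + 48).
steps_per_epoch = math.ceil(examples_per_site / batch_size)
num_steps_current_round = epochs * steps_per_epoch

print(examples_per_site, steps_per_epoch, num_steps_current_round)
```

Each site therefore reports 235 steps per round under the defaults. Raising --epochs scales that number linearly.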
Server Code
This example uses the base FedAvgRecipe configured for NumPy parameter exchange. The JAX parameter tree is flattened into a single NumPy vector before it is exchanged with the server, then reconstructed on the client before each training round.
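Because every client returns a vector of the same length, aggregation reduces to element-wise arithmetic. The sketch below shows a weighted average of two flattened updates in NumPy. It assumes each update is weighted by the step count the client reports. weighted_average is a hypothetical helper for illustration, not NVFlare's aggregator.

```python
import numpy as np


def weighted_average(updates, weights):
    """Element-wise weighted mean of equally shaped 1-D vectors (hypothetical helper)."""
    stacked = np.stack(updates).astype(np.float32)  # shape: (num_clients, num_params)
    w = np.asarray(weights, dtype=np.float32)
    return (stacked * w[:, None]).sum(axis=0) / w.sum()


# Two toy client updates; real vectors would hold every model parameter.
site_1 = np.array([1.0, 2.0, 3.0], dtype=np.float32)
site_2 = np.array([3.0, 4.0, 5.0], dtype=np.float32)

# Equal step counts give a plain mean; unequal counts pull toward the busier site.
equal = weighted_average([site_1, site_2], weights=[235, 235])
skewed = weighted_average([site_1, site_2], weights=[300, 100])

print(equal)
print(skewed)
```

Working on flat vectors keeps the server side independent of JAX: it never needs to know the shape of the parameter tree.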
Before running the job, prepare two resources:
The initial flattened checkpoint is generated by prepare_model.py and passed to FedAvgRecipe through initial_ckpt.
The shared MNIST .npy cache is prepared once by prepare_data.py so both simulated clients do not try to download the dataset at the same time or rely on TensorFlow-only data utilities.
Prepare Assets
Prepare the initial checkpoint and dataset using the default locations under /tmp/nvflare/data/hello-jax:
python prepare_model.py
python prepare_data.py
You can also prepare them in custom locations:
python prepare_model.py --output /path/to/initial_model.npy
python prepare_data.py --data_dir /path/to/mnist
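The checkpoint is a single .npy file holding the flattened parameter vector. A minimal sketch of writing and reading such a file with NumPy follows. The vector contents and the temporary path are placeholders, and prepare_model.py itself may differ in detail.

```python
import os
import tempfile

import numpy as np

# Placeholder vector; the real checkpoint holds the flattened ConvNet parameters.
initial_vector = np.zeros(8, dtype=np.float32)

with tempfile.TemporaryDirectory() as tmp_dir:
    ckpt_path = os.path.join(tmp_dir, "initial_model.npy")
    np.save(ckpt_path, initial_vector)  # write the 1-D float32 vector
    loaded = np.load(ckpt_path)         # read it back before the directory is removed

print(loaded.dtype, loaded.shape)
```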
Job Recipe Code
"""Recipe entrypoint for the hello-jax MNIST example."""

import argparse
import os
import shlex

from nvflare.client.config import ExchangeFormat
from nvflare.fuel.utils.constants import FrameworkType
from nvflare.recipe import FedAvgRecipe, SimEnv

DEFAULT_INITIAL_CKPT = "/tmp/nvflare/data/hello-jax/initial_model.npy"
DEFAULT_DATA_DIR = "/tmp/nvflare/data/hello-jax/mnist"
REQUIRED_DATA_FILES = ("train_images.npy", "train_labels.npy", "test_images.npy", "test_labels.npy")


def define_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_clients", type=int, default=2)
    parser.add_argument("--num_rounds", type=int, default=3)
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--learning_rate", type=float, default=0.05)
    parser.add_argument("--momentum", type=float, default=0.9)
    parser.add_argument("--data_dir", type=str, default=DEFAULT_DATA_DIR)
    parser.add_argument("--initial_ckpt", type=str, default=DEFAULT_INITIAL_CKPT)
    parser.add_argument("--train_script", type=str, default="client.py")
    parser.add_argument(
        "--launch_external_process",
        action="store_true",
        help="Run train_script in a separate subprocess instead of in-process.",
    )
    return parser.parse_args()


def _validate_inputs(initial_ckpt: str, data_dir: str) -> None:
    if not os.path.isfile(initial_ckpt):
        raise FileNotFoundError(
            f"Initial checkpoint not found: {initial_ckpt}. "
            f"Run `python prepare_model.py --output {initial_ckpt}` first."
        )

    missing_files = [name for name in REQUIRED_DATA_FILES if not os.path.isfile(os.path.join(data_dir, name))]
    if missing_files:
        missing_str = ", ".join(missing_files)
        raise FileNotFoundError(
            f"Prepared MNIST files missing in {data_dir}: {missing_str}. "
            f"Run `python prepare_data.py --data_dir {data_dir}` first."
        )


def _build_train_args(args) -> str:
    return shlex.join(
        [
            "--epochs",
            str(args.epochs),
            "--batch_size",
            str(args.batch_size),
            "--learning_rate",
            str(args.learning_rate),
            "--momentum",
            str(args.momentum),
            "--num_partitions",
            str(args.n_clients),
            "--data_dir",
            args.data_dir,
        ]
    )


def main():
    args = define_parser()

    _validate_inputs(args.initial_ckpt, args.data_dir)
    train_args = _build_train_args(args)

    recipe = FedAvgRecipe(
        name="hello-jax",
        min_clients=args.n_clients,
        num_rounds=args.num_rounds,
        initial_ckpt=args.initial_ckpt,
        train_script=args.train_script,
        train_args=train_args,
        launch_external_process=args.launch_external_process,
        framework=FrameworkType.NUMPY,
        server_expected_format=ExchangeFormat.NUMPY,
    )

    env = SimEnv(num_clients=args.n_clients)
    run = recipe.execute(env)
    print()
    print("Job Status is:", run.get_status())
    print("Result can be found in :", run.get_result())
    print()


if __name__ == "__main__":
    main()
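job.py builds the client's argument string with shlex.join rather than plain string concatenation, so values that contain spaces survive as single arguments. A small standalone illustration, using a made-up data directory:

```python
import shlex

# A hypothetical data directory containing a space.
train_args = shlex.join(["--epochs", "1", "--data_dir", "/tmp/my data/mnist"])
print(train_args)

# Splitting the string with shell rules recovers the original argument list.
recovered = shlex.split(train_args)
print(recovered)
```

The round trip through shlex.split shows that the quoted path comes back as one argument instead of two.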
Run Job
After the assets are prepared, run the job script to execute the job in a simulation environment.
python job.py
You can adjust the main hyperparameters from the command line as needed:
python job.py --n_clients 2 --num_rounds 3 --epochs 1 --batch_size 128
If you prepared the assets in non-default locations, pass them explicitly:
python job.py --initial_ckpt /path/to/initial_model.npy --data_dir /path/to/mnist
Output Summary
Initialization: BaseModelController starts the FedAvg workflow, loads the initial flattened checkpoint, and writes simulation output under /tmp/nvflare/simulation/hello-jax.
Round 0: site-1 and site-2 are sampled, evaluate the received model at 0.0527 and 0.0398 accuracy, then train for one epoch to 0.8887/0.8857 training accuracy with 0.3616/0.3778 loss. The client log also includes a post-training evaluation line with trained_model_eval_loss and accuracy before the update is sent back to the server.
Round 1: Both sites are sampled again, received-model accuracy improves to 0.9545 and 0.9799, local training reaches 0.9702/0.9686 accuracy with 0.0990/0.0999 loss, the client log includes a second evaluation pass on the trained local model, and the aggregated validation metric becomes a new best at 0.9671875.
Round 2: Received-model accuracy improves again to 0.9762 and 0.9900, local training finishes at 0.9795/0.9790 accuracy with 0.0671/0.0683 loss, the client log again reports trained_model_eval_loss and accuracy after local training, and the aggregated validation metric becomes a new best at 0.98310546875.
Completion: FedAvg finishes after 3 rounds, persists the final NumPy checkpoint to /tmp/nvflare/simulation/hello-jax/server/simulate_job/models/server.npy, and reports the simulation result directory at /tmp/nvflare/simulation/hello-jax.
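The persisted checkpoint holds the flattened parameter vector, so its expected length can be worked out from the layer shapes in model.py. The arithmetic below assumes Flax's default SAME padding for nn.Conv (so only the pooling layers shrink the 28x28 input) and a bias term on every layer.

```python
# Spatial size after two 2x2 average pools with stride 2: 28 -> 14 -> 7.
side = 28 // 2 // 2

layers = {
    "conv1": 3 * 3 * 1 * 32 + 32,            # 3x3 kernel, 1 -> 32 channels, plus bias
    "conv2": 3 * 3 * 32 * 64 + 64,           # 3x3 kernel, 32 -> 64 channels, plus bias
    "dense1": side * side * 64 * 128 + 128,  # flattened 7x7x64 features -> 128 units
    "dense2": 128 * 10 + 10,                 # 128 -> 10 class logits
}

total_params = sum(layers.values())
size_mib = total_params * 4 / 2**20  # float32 uses 4 bytes per value

print(layers)
print(total_params, round(size_mib, 2))
```

Under these assumptions the vector has 421,642 entries. One way to sanity-check a run is to load the checkpoint with np.load and compare its size against that total.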