Hello JAX

Warning

This example tracks main, which is the NVFlare development branch for the next release. On main, example requirements.txt files may pin the first upcoming NVFlare version that supports a feature, even before that package is published on PyPI. If the pinned nvflare version is not available yet, install NVFlare from this repo instead of from PyPI.

This example demonstrates how to use NVIDIA FLARE with JAX, Flax, and Optax to train an MNIST classifier using federated averaging (FedAvg). It follows the same hello-world recipe structure as hello-pt, but uses a JAX client training loop and a flattened parameter vector for transport.

Install NVFlare and Dependencies

For the complete installation instructions, see Installation.

pip install nvflare

First get the example code from GitHub:

git clone https://github.com/NVIDIA/NVFlare.git

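Then enter the repository, switch to the branch that matches your installed NVFlare version (or stay on main if you installed NVFlare from this repo), and navigate to the hello-jax directory:

cd NVFlare
git switch <release branch>
cd examples/hello-world/hello-jax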

Install the dependencies:

pip install -r requirements.txt

Code Structure

hello-jax
|
|-- client.py         # client local training script
|-- model.py          # JAX/Flax model helpers
|-- prepare_data.py   # helper that downloads MNIST and writes .npy files
|-- prepare_model.py  # helper that writes the initial flattened checkpoint
|-- job.py            # job recipe that defines client and server configurations
|-- requirements.txt  # dependencies

Data

This example uses the MNIST dataset. Before the simulator starts, prepare_data.py downloads the raw MNIST files once and converts them into .npy files. Each client then loads its own partition from that shared cache.
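The exact preprocessing in prepare_data.py is not reproduced here. As a minimal sketch, assuming the cache holds raw 0-255 pixel values shaped (N, 28, 28, 1) (the NHWC layout implied by INPUT_SHAPE in model.py) plus integer labels, the four files that job.py checks for can be written and read back as follows. Small synthetic arrays stand in for real MNIST, and write_synthetic_cache is a hypothetical helper, not part of the example.

```python
import os
import tempfile

import numpy as np

# File names checked by job.py (REQUIRED_DATA_FILES).
REQUIRED_DATA_FILES = ("train_images.npy", "train_labels.npy", "test_images.npy", "test_labels.npy")


def write_synthetic_cache(data_dir: str, n_train: int = 8, n_test: int = 4, seed: int = 0) -> None:
    """Hypothetical stand-in for prepare_data.py: writes random uint8 images and labels."""
    rng = np.random.default_rng(seed)
    os.makedirs(data_dir, exist_ok=True)
    arrays = {
        "train_images.npy": rng.integers(0, 256, size=(n_train, 28, 28, 1), dtype=np.uint8),
        "train_labels.npy": rng.integers(0, 10, size=(n_train,), dtype=np.uint8),
        "test_images.npy": rng.integers(0, 256, size=(n_test, 28, 28, 1), dtype=np.uint8),
        "test_labels.npy": rng.integers(0, 10, size=(n_test,), dtype=np.uint8),
    }
    for name, array in arrays.items():
        np.save(os.path.join(data_dir, name), array)


def load_like_client(data_dir: str):
    """Same loading and scaling as load_mnist() in client.py (training split only)."""
    train_images = np.load(os.path.join(data_dir, "train_images.npy")).astype(np.float32) / 255.0
    train_labels = np.load(os.path.join(data_dir, "train_labels.npy")).astype(np.int32)
    return train_images, train_labels


with tempfile.TemporaryDirectory() as tmp:
    write_synthetic_cache(tmp)
    images, labels = load_like_client(tmp)
    missing = [n for n in REQUIRED_DATA_FILES if not os.path.isfile(os.path.join(tmp, n))]
    print(images.shape, images.dtype, float(images.max()) <= 1.0, missing)
```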

Model

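The model in model.py is a small convolutional neural network implemented with Flax: two 3×3 convolution layers (32 and 64 filters), each followed by ReLU and 2×2 average pooling, then a 128-unit dense layer and a 10-class output layer. The same file provides helpers that flatten the Flax parameter tree into a single float32 NumPy vector and rebuild the tree from such a vector.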
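To get a sense of how large the flattened transport vector is, the parameter count can be worked out by hand. This sketch assumes Flax's defaults for the layers used here: nn.Conv uses "SAME" padding and adds a bias, nn.avg_pool with a 2×2 window and stride 2 halves each spatial dimension, and nn.Dense adds a bias.

```python
def conv_params(kernel_h: int, kernel_w: int, in_ch: int, out_ch: int) -> int:
    # Kernel weights plus one bias per output channel.
    return kernel_h * kernel_w * in_ch * out_ch + out_ch


def dense_params(in_features: int, out_features: int) -> int:
    # Weight matrix plus one bias per output feature.
    return in_features * out_features + out_features


# 28x28x1 -> Conv(32, SAME) -> pool -> 14x14x32 -> Conv(64, SAME) -> pool -> 7x7x64
flat_features = 7 * 7 * 64

# Flax auto-names the layers of a compact module Conv_0, Conv_1, Dense_0, Dense_1.
layer_sizes = {
    "Conv_0": conv_params(3, 3, 1, 32),
    "Conv_1": conv_params(3, 3, 32, 64),
    "Dense_0": dense_params(flat_features, 128),
    "Dense_1": dense_params(128, 10),
}
total = sum(layer_sizes.values())
size_bytes = total * 4  # float32 uses 4 bytes per value

print(layer_sizes)
print(f"total={total} parameters, ~{size_bytes / 1e6:.2f} MB as float32")
```

Under these assumptions the model has 421,642 parameters, so each flattened update is about 1.69 MB of float32 data, and the first dense layer alone accounts for roughly 95% of it.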

model code (model.py)
"""
JAX/Flax model utilities for the hello-jax MNIST example.
"""

from functools import lru_cache

import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax import linen as nn
from flax.training import train_state
from jax.flatten_util import ravel_pytree


class ConvNet(nn.Module):
    """Small CNN for MNIST classification."""

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(features=128)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x


MODEL = ConvNet()
INPUT_SHAPE = (1, 28, 28, 1)


@lru_cache(maxsize=1)
def _template_tree_and_unravel_fn():
    # Initialize once with a fixed key; the resulting tree is the template that
    # defines leaf order and shapes for flattening and unflattening.
    params = MODEL.init(jax.random.PRNGKey(0), jnp.ones(INPUT_SHAPE, dtype=jnp.float32))["params"]
    _, unravel_fn = ravel_pytree(params)
    return params, unravel_fn


def create_initial_params():
    params, _ = _template_tree_and_unravel_fn()
    return params


def flatten_params(params) -> np.ndarray:
    # Concatenate all leaves into one 1-D float32 NumPy vector for transport.
    flat_params, _ = ravel_pytree(params)
    return np.asarray(flat_params, dtype=np.float32)


def unflatten_params(flat_params):
    # Rebuild the Flax parameter tree from a flat vector using the cached template.
    _, unravel_fn = _template_tree_and_unravel_fn()
    return unravel_fn(jnp.asarray(flat_params, dtype=jnp.float32))


def create_train_state(params, learning_rate: float, momentum: float) -> train_state.TrainState:
    tx = optax.sgd(learning_rate=learning_rate, momentum=momentum)
    return train_state.TrainState.create(apply_fn=MODEL.apply, params=params, tx=tx)
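model.py relies on jax.flatten_util.ravel_pytree, which concatenates every leaf of the parameter tree into one 1-D array and returns a function that restores the original tree. The sketch below reimplements that idea in plain NumPy for a nested dict, so the round trip can be inspected without JAX. The helper ravel_dict is hypothetical; visiting dict keys in sorted order mirrors how JAX flattens dicts, but this is not JAX's implementation.

```python
import numpy as np


def ravel_dict(tree):
    """Flatten a nested dict of arrays into one float32 vector plus an inverse function."""
    paths, shapes, leaves = [], [], []

    def visit(node, prefix):
        # Visit dict keys in sorted order, as JAX does when flattening dicts.
        for key in sorted(node):
            value = node[key]
            if isinstance(value, dict):
                visit(value, prefix + (key,))
            else:
                arr = np.asarray(value, dtype=np.float32)
                paths.append(prefix + (key,))
                shapes.append(arr.shape)
                leaves.append(arr.ravel())

    visit(tree, ())
    flat = np.concatenate(leaves) if leaves else np.zeros((0,), dtype=np.float32)

    def unravel(vector):
        vector = np.asarray(vector, dtype=np.float32)
        if vector.size != flat.size:
            raise ValueError(f"expected {flat.size} values, got {vector.size}")
        out, offset = {}, 0
        for path, shape in zip(paths, shapes):
            size = int(np.prod(shape))
            node = out
            for key in path[:-1]:
                node = node.setdefault(key, {})
            node[path[-1]] = vector[offset:offset + size].reshape(shape)
            offset += size
        return out

    return flat, unravel


params = {
    "Dense_0": {"kernel": np.ones((3, 2)), "bias": np.zeros(2)},
    "Conv_0": {"kernel": np.full((1, 1, 1, 4), 0.5), "bias": np.arange(4.0)},
}
flat, unravel = ravel_dict(params)
restored = unravel(flat * 2.0)  # e.g. an updated vector coming back from the server
print(flat.shape, restored["Dense_0"]["kernel"].shape)
```

The important design point is that the receiver must rebuild exactly the same template. model.py does this by building the tree once from MODEL.init and caching it with lru_cache, so the unravel function always matches the ConvNet structure.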

Client Code

The client code (client.py) keeps the local training loop in JAX while using NVFlare’s client API to receive the current global model and return the updated parameters.

client code (client.py)
"""
Client-side JAX/Flax training script for the hello-jax example.
"""

import argparse
import math
import re

import jax
import jax.numpy as jnp
import numpy as np
import optax
from model import MODEL, create_train_state, flatten_params, unflatten_params

import nvflare.client as flare
from nvflare.apis.fl_constant import FLMetaKey
from nvflare.app_common.np.constants import NPConstants


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--learning_rate", type=float, default=0.05)
    parser.add_argument("--momentum", type=float, default=0.9)
    parser.add_argument("--num_partitions", type=int, default=2)
    parser.add_argument("--data_dir", type=str, default="/tmp/nvflare/data/hello-jax/mnist")
    return parser.parse_args()


def load_mnist(data_dir: str):
    # Pixel values are divided by 255 and cast to float32; labels are cast to int32.
    train_images = np.load(f"{data_dir}/train_images.npy").astype(np.float32) / 255.0
    train_labels = np.load(f"{data_dir}/train_labels.npy").astype(np.int32)
    test_images = np.load(f"{data_dir}/test_images.npy").astype(np.float32) / 255.0
    test_labels = np.load(f"{data_dir}/test_labels.npy").astype(np.int32)

    return (train_images, train_labels), (test_images, test_labels)


def split_for_client(images, labels, client_name: str, num_partitions: int):
    # Sites named like site-1, site-2, ... get the matching 1-indexed partition;
    # names without a trailing number use the full dataset.
    match = re.search(r"(\d+)$", client_name)
    if not match:
        return images, labels

    client_number = int(match.group(1))
    if client_number <= 0:
        raise ValueError(f"Client name '{client_name}' must use 1-indexed site numbering.")

    client_idx = client_number - 1
    partitions = max(num_partitions, 1)
    image_splits = np.array_split(images, partitions)
    label_splits = np.array_split(labels, partitions)
    if client_idx >= len(image_splits):
        raise ValueError(
            f"Client index {client_idx + 1} from site name '{client_name}' exceeds available partitions "
            f"{len(image_splits)}."
        )
    return image_splits[client_idx], label_splits[client_idx]


@jax.jit
def train_step(state, images, labels):
    def loss_fn(params):
        logits = MODEL.apply({"params": params}, images)
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
        return loss, logits

    (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    state = state.apply_gradients(grads=grads)
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)
    return state, loss, accuracy


@jax.jit
def eval_step(params, images, labels):
    logits = MODEL.apply({"params": params}, images)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)
    return loss, accuracy


def train_epoch(state, images, labels, batch_size: int, rng):
    num_examples = len(images)
    if num_examples == 0:
        raise ValueError("No training data available for this client.")

    permutation = np.asarray(jax.random.permutation(rng, num_examples))
    total_loss = 0.0
    total_accuracy = 0.0
    steps = 0

    for start in range(0, num_examples, batch_size):
        end = start + batch_size
        indices = permutation[start:end]
        batch_images = jnp.asarray(images[indices])
        batch_labels = jnp.asarray(labels[indices])
        state, loss, accuracy = train_step(state, batch_images, batch_labels)
        total_loss += float(loss)
        total_accuracy += float(accuracy)
        steps += 1

    return state, total_loss / steps, total_accuracy / steps, steps


def evaluate(params, images, labels, batch_size: int):
    num_examples = len(images)
    if num_examples == 0:
        raise ValueError("No evaluation data available for this client.")

    total_loss = 0.0
    total_accuracy = 0.0
    steps = 0

    for start in range(0, num_examples, batch_size):
        end = start + batch_size
        batch_images = jnp.asarray(images[start:end])
        batch_labels = jnp.asarray(labels[start:end])
        loss, accuracy = eval_step(params, batch_images, batch_labels)
        total_loss += float(loss)
        total_accuracy += float(accuracy)
        steps += 1

    return total_loss / steps, total_accuracy / steps


def main():
    args = parse_args()
    flare.init()

    sys_info = flare.system_info()
    client_name = sys_info["site_name"]

    # Each site trains and evaluates only on its own partition of the shared cache.
    (train_images, train_labels), (test_images, test_labels) = load_mnist(args.data_dir)
    train_images, train_labels = split_for_client(train_images, train_labels, client_name, args.num_partitions)
    test_images, test_labels = split_for_client(test_images, test_labels, client_name, args.num_partitions)

    print(f"site={client_name}, train_samples={len(train_images)}, test_samples={len(test_images)}")

    rng = jax.random.PRNGKey(0)
    while flare.is_running():
        # The global model arrives as one flat NumPy vector; rebuild the Flax tree.
        input_model = flare.receive()
        current_round = input_model.current_round
        flat_params = input_model.params[NPConstants.NUMPY_KEY]
        params = unflatten_params(flat_params)

        received_eval_loss, received_accuracy = evaluate(params, test_images, test_labels, args.batch_size)
        print(
            f"site={client_name}, round={current_round}, "
            f"received_model_eval_loss={received_eval_loss:.4f}, accuracy={received_accuracy:.4f}"
        )

        # Evaluation-only tasks return metrics for the received model without training.
        if flare.is_evaluate():
            flare.send(flare.FLModel(metrics={"accuracy": received_accuracy, "eval_loss": received_eval_loss}))
            continue

        # Fresh optimizer state each round, starting from the received global parameters.
        state = create_train_state(params, learning_rate=args.learning_rate, momentum=args.momentum)
        steps_per_epoch = math.ceil(len(train_images) / args.batch_size)

        for epoch in range(args.epochs):
            rng, epoch_rng = jax.random.split(rng)
            state, train_loss, train_accuracy, _ = train_epoch(
                state,
                train_images,
                train_labels,
                args.batch_size,
                epoch_rng,
            )
            print(
                f"site={client_name}, round={current_round}, epoch={epoch + 1}, "
                f"train_loss={train_loss:.4f}, train_accuracy={train_accuracy:.4f}"
            )

        updated_eval_loss, updated_accuracy = evaluate(state.params, test_images, test_labels, args.batch_size)
        print(
            f"site={client_name}, round={current_round}, "
            f"trained_model_eval_loss={updated_eval_loss:.4f}, accuracy={updated_accuracy:.4f}"
        )

        # Return the full updated parameters, metrics, and the number of local steps.
        updated_params = flatten_params(state.params)
        output_model = flare.FLModel(
            params={NPConstants.NUMPY_KEY: updated_params},
            params_type=flare.ParamsType.FULL,
            metrics={"accuracy": updated_accuracy, "eval_loss": updated_eval_loss},
            meta={FLMetaKey.NUM_STEPS_CURRENT_ROUND: args.epochs * steps_per_epoch},
        )
        flare.send(output_model)


if __name__ == "__main__":
    main()
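split_for_client takes a 1-indexed partition number from the trailing digits of the site name and uses np.array_split, which handles sizes that do not divide evenly by giving the earlier partitions one extra element. The function below is a trimmed, hypothetical copy of that logic (it returns index ranges and omits the error checks) so the behavior can be checked on a toy array:

```python
import re

import numpy as np


def partition_for_site(num_examples: int, site_name: str, num_partitions: int):
    """Return the indices a site receives, following split_for_client in client.py."""
    match = re.search(r"(\d+)$", site_name)
    indices = np.arange(num_examples)
    if not match:
        # Names without a trailing number get the full dataset.
        return indices
    client_idx = int(match.group(1)) - 1
    return np.array_split(indices, max(num_partitions, 1))[client_idx]


for site in ("site-1", "site-2", "site-3"):
    part = partition_for_site(10, site, num_partitions=3)
    print(site, part.tolist())

# MNIST has 60,000 training and 10,000 test images; with two sites each gets half.
print(len(partition_for_site(60_000, "site-2", 2)), len(partition_for_site(10_000, "site-2", 2)))
```

With 10 samples and 3 partitions, site-1 receives indices 0-3 and the other two sites receive 3 indices each.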

Server Code

This example uses the base FedAvgRecipe configured for NumPy parameter exchange. The JAX parameter tree is flattened into a single NumPy vector before it is exchanged with the server, then reconstructed on the client before each training round.
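Because every client returns a single float32 vector of the same length, server-side aggregation reduces to a weighted average of vectors. The sketch below shows that arithmetic in NumPy. It assumes, as is typical for FedAvg, that each update is weighted by the step count the client reports in FLMetaKey.NUM_STEPS_CURRENT_ROUND; the function weighted_average is hypothetical and is not NVFlare's aggregator code.

```python
import numpy as np


def weighted_average(updates, weights):
    """FedAvg-style aggregation of equally sized flat parameter vectors."""
    stacked = np.stack([np.asarray(u, dtype=np.float32) for u in updates])
    w = np.asarray(weights, dtype=np.float64)
    if stacked.ndim != 2 or len(w) != stacked.shape[0]:
        raise ValueError("expected one 1-D update per weight")
    # Normalize the weights, then take the weighted sum over clients.
    return np.tensordot(w / w.sum(), stacked, axes=1).astype(np.float32)


site_1 = np.array([1.0, 2.0, 3.0])
site_2 = np.array([3.0, 2.0, 1.0])
# Two sites with 30,000 samples each at batch size 128 both report 235 steps,
# so the weights are equal and the result is a plain mean.
print(weighted_average([site_1, site_2], [235, 235]))
print(weighted_average([site_1, site_2], [300, 100]))
```

Working on the flat vector keeps aggregation framework-agnostic: the server never needs to know the Flax tree structure, only that all vectors share the same length and ordering.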

Before running the job, prepare two resources:

  • The initial flattened checkpoint, generated by prepare_model.py and passed to FedAvgRecipe through initial_ckpt.

  • The shared MNIST .npy cache, prepared once by prepare_data.py so that the simulated clients neither download the dataset concurrently nor depend on TensorFlow-only data utilities.

Prepare Assets

Prepare the initial checkpoint and dataset using the default locations under /tmp/nvflare/data/hello-jax:

python prepare_model.py
python prepare_data.py

You can also prepare them in custom locations:

python prepare_model.py --output /path/to/initial_model.npy
python prepare_data.py --data_dir /path/to/mnist
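job.py only checks that the asset files exist. To also catch a stale or mismatched checkpoint, it can be loaded and compared against the expected length. A minimal sketch, assuming prepare_model.py writes a bare 1-D float32 array with np.save (consistent with flatten_params in model.py); check_checkpoint is a hypothetical helper, and 421,642 is the ConvNet's parameter count under Flax's default SAME padding and biases:

```python
import os
import tempfile

import numpy as np


def check_checkpoint(path: str, expected_size: int) -> np.ndarray:
    """Load a flattened checkpoint and verify it is a 1-D float32 vector of the expected length."""
    vector = np.load(path)
    if vector.ndim != 1:
        raise ValueError(f"expected a 1-D vector, got shape {vector.shape}")
    if vector.dtype != np.float32:
        raise ValueError(f"expected float32, got {vector.dtype}")
    if vector.size != expected_size:
        raise ValueError(f"expected {expected_size} values, got {vector.size}")
    return vector


with tempfile.TemporaryDirectory() as tmp:
    ckpt = os.path.join(tmp, "initial_model.npy")
    # Stand-in for prepare_model.py: save a zero vector of the expected length.
    np.save(ckpt, np.zeros(421_642, dtype=np.float32))
    loaded = check_checkpoint(ckpt, expected_size=421_642)
    print(loaded.shape, loaded.dtype)
```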

Job Recipe Code

job recipe (job.py)
"""Recipe entrypoint for the hello-jax MNIST example."""

import argparse
import os
import shlex

from nvflare.client.config import ExchangeFormat
from nvflare.fuel.utils.constants import FrameworkType
from nvflare.recipe import FedAvgRecipe, SimEnv

DEFAULT_INITIAL_CKPT = "/tmp/nvflare/data/hello-jax/initial_model.npy"
DEFAULT_DATA_DIR = "/tmp/nvflare/data/hello-jax/mnist"
REQUIRED_DATA_FILES = ("train_images.npy", "train_labels.npy", "test_images.npy", "test_labels.npy")


def define_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_clients", type=int, default=2)
    parser.add_argument("--num_rounds", type=int, default=3)
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--learning_rate", type=float, default=0.05)
    parser.add_argument("--momentum", type=float, default=0.9)
    parser.add_argument("--data_dir", type=str, default=DEFAULT_DATA_DIR)
    parser.add_argument("--initial_ckpt", type=str, default=DEFAULT_INITIAL_CKPT)
    parser.add_argument("--train_script", type=str, default="client.py")
    parser.add_argument(
        "--launch_external_process",
        action="store_true",
        help="Run train_script in a separate subprocess instead of in-process.",
    )
    return parser.parse_args()


def _validate_inputs(initial_ckpt: str, data_dir: str) -> None:
    if not os.path.isfile(initial_ckpt):
        raise FileNotFoundError(
            f"Initial checkpoint not found: {initial_ckpt}. "
            f"Run `python prepare_model.py --output {initial_ckpt}` first."
        )

    missing_files = [name for name in REQUIRED_DATA_FILES if not os.path.isfile(os.path.join(data_dir, name))]
    if missing_files:
        missing_str = ", ".join(missing_files)
        raise FileNotFoundError(
            f"Prepared MNIST files missing in {data_dir}: {missing_str}. "
            f"Run `python prepare_data.py --data_dir {data_dir}` first."
        )


def _build_train_args(args) -> str:
    return shlex.join(
        [
            "--epochs",
            str(args.epochs),
            "--batch_size",
            str(args.batch_size),
            "--learning_rate",
            str(args.learning_rate),
            "--momentum",
            str(args.momentum),
            "--num_partitions",
            str(args.n_clients),
            "--data_dir",
            args.data_dir,
        ]
    )


def main():
    args = define_parser()

    _validate_inputs(args.initial_ckpt, args.data_dir)
    train_args = _build_train_args(args)

    recipe = FedAvgRecipe(
        name="hello-jax",
        min_clients=args.n_clients,
        num_rounds=args.num_rounds,
        initial_ckpt=args.initial_ckpt,
        train_script=args.train_script,
        train_args=train_args,
        launch_external_process=args.launch_external_process,
        framework=FrameworkType.NUMPY,
        server_expected_format=ExchangeFormat.NUMPY,
    )

    env = SimEnv(num_clients=args.n_clients)
    run = recipe.execute(env)
    print()
    print("Job Status is:", run.get_status())
    print("Result can be found in :", run.get_result())
    print()


if __name__ == "__main__":
    main()
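_build_train_args turns the job-level options into the single argument string passed to the client script. shlex.join quotes any token that needs it, so splitting the string with shell rules (shlex.split) recovers the original tokens. A short stdlib-only check, using the defaults from job.py and a hypothetical data directory that contains a space:

```python
import shlex

tokens = [
    "--epochs", "1",
    "--batch_size", "128",
    "--learning_rate", "0.05",
    "--momentum", "0.9",
    "--num_partitions", "2",
    "--data_dir", "/tmp/my data/mnist",  # hypothetical path with a space
]
train_args = shlex.join(tokens)
print(train_args)

# Splitting with shell rules gives back exactly the original tokens.
assert shlex.split(train_args) == tokens
```

Only the path needs quoting here, so it is wrapped in single quotes while the numeric values pass through unchanged.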

Run Job

After the assets are prepared, run the job script to execute the job in a simulation environment.

python job.py

You can adjust the main hyperparameters from the command line as needed:

python job.py --n_clients 2 --num_rounds 3 --epochs 1 --batch_size 128
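Changing --n_clients also changes how much data each site sees, because job.py forwards it to client.py as --num_partitions. Each client reports epochs × ceil(samples / batch_size) in NUM_STEPS_CURRENT_ROUND. The hypothetical helper below works this out for MNIST's 60,000 training images, assuming the np.array_split partitioning used in client.py:

```python
import math


def local_steps(n_clients: int, epochs: int = 1, batch_size: int = 128, num_train: int = 60_000):
    """Samples and steps per round for site-1, mirroring epochs * ceil(len(train_images) / batch_size)."""
    # np.array_split gives the first (num_train % n_clients) sites one extra sample.
    samples = num_train // n_clients + (1 if num_train % n_clients else 0)
    return samples, epochs * math.ceil(samples / batch_size)


for n in (2, 3, 4):
    print(n, local_steps(n))
print(2, local_steps(2, epochs=2, batch_size=64))
```

With the defaults, two clients each train on 30,000 samples for 235 steps per round; three and four clients give 157 and 118 steps, respectively.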

If you prepared the assets in non-default locations, pass them explicitly:

python job.py --initial_ckpt /path/to/initial_model.npy --data_dir /path/to/mnist

Output Summary

  • Initialization: BaseModelController starts the FedAvg workflow, loads the initial flattened checkpoint, and writes simulation output under /tmp/nvflare/simulation/hello-jax.

  • Round 0: site-1 and site-2 are sampled. The received (initial) model scores 0.0527 and 0.0398 accuracy on their local test splits; one epoch of local training then reaches 0.8887 / 0.8857 training accuracy with 0.3616 / 0.3778 loss. Each client log also includes a post-training evaluation line (trained_model_eval_loss and accuracy) before the update is sent back to the server.

  • Round 1: Both sites are sampled again. Received-model accuracy improves to 0.9545 and 0.9799, local training reaches 0.9702 / 0.9686 accuracy with 0.0990 / 0.0999 loss, and each client again evaluates its trained local model. The aggregated validation metric reaches a new best of 0.9671875.

  • Round 2: Received-model accuracy improves again to 0.9762 and 0.9900, and local training finishes at 0.9795 / 0.9790 accuracy with 0.0671 / 0.0683 loss, followed by the same post-training evaluation. The aggregated validation metric reaches a new best of 0.98310546875.

  • Completion: FedAvg finishes after 3 rounds, persists the final NumPy checkpoint to /tmp/nvflare/simulation/hello-jax/server/simulate_job/models/server.npy, and reports the simulation result directory at /tmp/nvflare/simulation/hello-jax.