Hello JAX
=========

.. warning::

   This example tracks ``main``, which is the NVFlare development branch for the next release. On ``main``, example ``requirements.txt`` files may pin the first upcoming NVFlare version that supports a feature, even before that package is published on PyPI. If the pinned ``nvflare`` version is not available yet, install NVFlare from this repo instead of from PyPI.

This example demonstrates how to use NVIDIA FLARE with JAX, Flax, and Optax to train an MNIST classifier using federated averaging (FedAvg). It follows the same hello-world recipe structure as ``hello-pt``, but uses a JAX client training loop and a flattened parameter vector for transport.

Install NVFLARE and Dependencies
--------------------------------

For the complete installation instructions, see `Installation `_.

.. code-block:: bash

    pip install nvflare

First get the example code from GitHub:

.. code-block:: bash

    git clone https://github.com/NVIDIA/NVFlare.git

Then navigate to the hello-jax directory:

.. code-block:: bash

    git switch
    cd examples/hello-world/hello-jax

Install the dependencies:

.. code-block:: bash

    pip install -r requirements.txt

Code Structure
--------------

.. code-block:: text

    hello-jax
    |
    |-- client.py           # client local training script
    |-- model.py            # JAX/Flax model helpers
    |-- prepare_data.py     # helper that downloads MNIST and writes .npy files
    |-- prepare_model.py    # helper that writes the initial flattened checkpoint
    |-- job.py              # job recipe that defines client and server configurations
    |-- requirements.txt    # dependencies

Data
----

This example uses the `MNIST `_ dataset. The job script downloads the raw MNIST files once before the simulator starts and converts them into ``.npy`` files. Each client then loads from that prepared cache.

Model
-----

The model in :github_nvflare_link:`model.py ` is a small convolutional neural network implemented with Flax.

.. literalinclude:: ../../../examples/hello-world/hello-jax/model.py
   :language: python
   :linenos:
   :caption: model code (model.py)
   :lines: 14-

Client Code
-----------

The client code (:github_nvflare_link:`client.py `) keeps the local training loop in JAX while using NVFlare's client API to receive the current global model and return the updated parameters.

.. literalinclude:: ../../../examples/hello-world/hello-jax/client.py
   :language: python
   :linenos:
   :caption: client code (client.py)
   :lines: 14-

Server Code
-----------

This example uses the base ``FedAvgRecipe`` configured for NumPy parameter exchange. The JAX parameter tree is flattened into a single NumPy vector before it is exchanged with the server, then reconstructed on the client before each training round.

Before running the job, prepare two resources:

- The initial flattened checkpoint is generated by ``prepare_model.py`` and passed to ``FedAvgRecipe`` through ``initial_ckpt``.
- The shared MNIST ``.npy`` cache is prepared once by ``prepare_data.py`` so both simulated clients do not try to download the dataset at the same time or rely on TensorFlow-only data utilities.

Prepare Assets
--------------

Prepare the initial checkpoint and dataset using the default locations under ``/tmp/nvflare/data/hello-jax``:

.. code-block:: bash

    python prepare_model.py
    python prepare_data.py

You can also prepare them in custom locations:

.. code-block:: bash

    python prepare_model.py --output /path/to/initial_model.npy
    python prepare_data.py --data_dir /path/to/mnist

Job Recipe Code
---------------

.. literalinclude:: ../../../examples/hello-world/hello-jax/job.py
   :language: python
   :linenos:
   :caption: job recipe (job.py)
   :lines: 14-

Run Job
-------

After the assets are prepared, run the job script to execute the job in a simulation environment.

.. code-block:: bash

    python job.py

You can adjust the main hyperparameters from the command line as needed:

.. code-block:: bash

    python job.py --n_clients 2 --num_rounds 3 --epochs 1 --batch_size 128

If you prepared the assets in non-default locations, pass them explicitly:

.. code-block:: bash

    python job.py --initial_ckpt /path/to/initial_model.npy --data_dir /path/to/mnist

Output Summary
--------------

- **Initialization**: ``BaseModelController`` starts the FedAvg workflow, loads the initial flattened checkpoint, and writes simulation output under ``/tmp/nvflare/simulation/hello-jax``.
- **Round 0**: ``site-1`` and ``site-2`` are sampled, evaluate the received model at ``0.0527`` and ``0.0398`` accuracy, then train for one epoch to ``0.8887`` / ``0.8857`` training accuracy with ``0.3616`` / ``0.3778`` loss. The client log also includes a post-training evaluation line with ``trained_model_eval_loss`` and ``accuracy`` before the update is sent back to the server.
- **Round 1**: Both sites are sampled again, received-model accuracy improves to ``0.9545`` and ``0.9799``, local training reaches ``0.9702`` / ``0.9686`` accuracy with ``0.0990`` / ``0.0999`` loss, the client log includes a second evaluation pass on the trained local model, and the aggregated validation metric becomes a new best at ``0.9671875``.
- **Round 2**: Received-model accuracy improves again to ``0.9762`` and ``0.9900``, local training finishes at ``0.9795`` / ``0.9790`` accuracy with ``0.0671`` / ``0.0683`` loss, the client log again reports ``trained_model_eval_loss`` and ``accuracy`` after local training, and the aggregated validation metric becomes a new best at ``0.98310546875``.
- **Completion**: FedAvg finishes after 3 rounds, persists the final NumPy checkpoint to ``/tmp/nvflare/simulation/hello-jax/server/simulate_job/models/server.npy``, and reports the simulation result directory at ``/tmp/nvflare/simulation/hello-jax``.
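
The flattened-vector transport described under Server Code can be illustrated with a small NumPy-only sketch. It is not the code from ``model.py`` or ``client.py``: the helper names ``flatten_params``, ``unflatten_params``, and ``fedavg``, the parameter names, and the sample counts are all hypothetical, and a plain dictionary of arrays stands in for the Flax parameter tree. In JAX itself, ``jax.flatten_util.ravel_pytree`` provides a comparable flatten/unflatten pair.

.. code-block:: python

    import numpy as np


    def flatten_params(params):
        """Flatten a dict of arrays into one 1-D vector plus its layout (hypothetical helper)."""
        names = sorted(params)  # fixed order so every site agrees on the layout
        layout = [(name, params[name].shape) for name in names]
        flat = np.concatenate(
            [np.asarray(params[name], dtype=np.float32).ravel() for name in names]
        )
        return flat, layout


    def unflatten_params(flat, layout):
        """Rebuild the dict of arrays from a flat vector and its layout (hypothetical helper)."""
        params, offset = {}, 0
        for name, shape in layout:
            size = int(np.prod(shape))
            params[name] = flat[offset:offset + size].reshape(shape)
            offset += size
        return params


    def fedavg(updates):
        """Average flat vectors, weighting each by its number of local examples."""
        vectors = np.stack([vec for vec, _ in updates])
        counts = [count for _, count in updates]
        return np.average(vectors, axis=0, weights=counts)


    # Two toy "site" updates with made-up values and example counts.
    site_1 = {"dense/kernel": np.ones((2, 3)), "dense/bias": np.zeros(3)}
    site_2 = {"dense/kernel": np.full((2, 3), 3.0), "dense/bias": np.ones(3)}

    flat_1, layout = flatten_params(site_1)
    flat_2, _ = flatten_params(site_2)

    # Server side: aggregate the flat vectors. Client side: rebuild the tree.
    global_flat = fedavg([(flat_1, 100), (flat_2, 300)])
    global_params = unflatten_params(global_flat, layout)

Because the server only sees a single 1-D array, it can aggregate updates without knowing anything about the model structure; only the clients need the layout to rebuild the parameter tree.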