Hello PyTorch Lightning
This example demonstrates how to use NVIDIA FLARE with PyTorch Lightning to train an image classifier using federated averaging (FedAvg). The complete example code can be found in the hello-lightning directory. We recommend creating a virtual environment and running everything inside it.
NVIDIA FLARE Installation
For the complete installation instructions, see Installation.
pip install nvflare
Get the example code from GitHub:
git clone https://github.com/NVIDIA/NVFlare.git
Then switch to the release branch and navigate to the hello-lightning directory:
cd NVFlare
git switch <release branch>
cd examples/hello-world/hello-lightning
Install the dependencies
pip install -r requirements.txt
Code Structure
.
hello-lightning
|
|-- client.py # client local training script
|-- model.py # model definition
|-- job.py # job recipe that defines client and server configurations
|-- requirements.txt # dependencies
Data
This example uses the CIFAR-10 dataset.
In a real FL experiment, each client would have its own dataset for local training. You can download the CIFAR-10 dataset from the Internet with torchvision’s datasets module and split it so that each client has its own portion. Here, for simplicity’s sake, every client uses the same dataset.
The PyTorch Lightning data module can download the dataset directly. However, every site downloads the same dataset into the same directory. Because of this, training on one site may start before the data is ready, which can cause errors. To avoid this, pre-download the data before starting training by running the following from a terminal:
./prepare_data.sh
DATASET_ROOT="/tmp/nvflare/data"

python3 -c "import torchvision.datasets as datasets; datasets.CIFAR10(root='${DATASET_ROOT}', train=True, download=True)"
python3 -c "import torchvision.datasets as datasets; datasets.CIFAR10(root='${DATASET_ROOT}', train=False, download=True)"
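The pre-download step exists because several simulated sites may call download=True on the same directory at the same time. This example does not do so, but another way to coordinate the sites is to let exactly one process perform the download while the others wait. The stdlib-only sketch below shows that idea. The helper download_once, its marker and lock file names, and the download function you pass in are all hypothetical.

```python
import os
import time
from typing import Callable


def download_once(root: str, download_fn: Callable[[str], None], timeout: float = 600.0) -> None:
    """Hypothetical helper: run download_fn(root) in only one caller; the others wait.

    Coordination uses an atomically created lock file plus a completion marker.
    """
    os.makedirs(root, exist_ok=True)
    done_marker = os.path.join(root, ".download_complete")  # hypothetical file name
    lock_path = os.path.join(root, ".download_lock")  # hypothetical file name
    if os.path.exists(done_marker):
        return  # data is already in place
    try:
        # O_CREAT | O_EXCL is atomic: exactly one caller succeeds in creating the lock.
        fd = os.open(lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
    except FileExistsError:
        # Another caller is downloading; poll until it writes the completion marker.
        deadline = time.monotonic() + timeout
        while not os.path.exists(done_marker):
            if time.monotonic() > deadline:
                raise TimeoutError(f"timed out waiting for the download in {root}")
            time.sleep(0.05)
        return
    try:
        # Re-check: another caller may have finished between our first check and the lock.
        if not os.path.exists(done_marker):
            download_fn(root)
            open(done_marker, "w").close()
    finally:
        os.close(fd)
        os.remove(lock_path)
```

In this example, download_fn could be a function that calls torchvision.datasets.CIFAR10(root=root, train=True, download=True) and then the same call with train=False. A lock file only coordinates processes that share a file system. That holds when the simulator runs all sites on one machine. If download_fn raises, no marker is written, so waiting callers give up after the timeout.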
In PyTorch Lightning, a LightningDataModule is a standardized way to handle data loading and processing. It encapsulates all the steps required to prepare data for training, validation, and testing, making it easier to manage datasets and data loaders in a clean and organized manner. This abstraction helps separate data-related logic from the model and training code, promoting better code organization and reusability.
LightningDataModule
Purpose: The LightningDataModule is designed to encapsulate all data-related operations, including downloading, transforming, and splitting datasets, as well as providing data loaders for training, validation, testing, and prediction.
Key Methods:
- prepare_data(): Used for downloading and preparing data. It runs in a single process (by default, local rank 0 on each node) rather than on every GPU, so it should not assign state to the data module.
- setup(stage): Used to set up datasets for different stages (e.g., ‘fit’, ‘validate’, ‘test’, ‘predict’). This method runs in every process, that is, once per GPU across all nodes.
- train_dataloader(), val_dataloader(), test_dataloader(), predict_dataloader(): These methods return the respective data loaders for each stage.
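The toy simulation below makes the "one process versus every process" distinction concrete. It does not reproduce Lightning's implementation. ToyDataModule and run_fit are hypothetical names, and the simulation only models the rule described above: prepare_data() runs in one process per node, while setup(stage) and the dataloader hooks run in every process.

```python
from typing import List


class ToyDataModule:
    """Hypothetical stand-in that records which hook runs in which process."""

    def __init__(self, log: List[str], rank: int):
        self.log = log
        self.rank = rank

    def prepare_data(self) -> None:
        self.log.append(f"rank{self.rank}:prepare_data")

    def setup(self, stage: str) -> None:
        self.log.append(f"rank{self.rank}:setup({stage})")

    def train_dataloader(self) -> str:
        self.log.append(f"rank{self.rank}:train_dataloader")
        return "train-loader"


def run_fit(num_processes: int) -> List[str]:
    """Simulate one node with num_processes processes (e.g. one per GPU)."""
    log: List[str] = []
    modules = [ToyDataModule(log, rank) for rank in range(num_processes)]
    # prepare_data(): only local rank 0 downloads or writes to disk.
    modules[0].prepare_data()
    # setup() and the dataloader hooks: every process builds its own datasets and loaders.
    for dm in modules:
        dm.setup("fit")
        dm.train_dataloader()
    return log


print(run_fit(2))
```

In real Lightning, the other processes wait for prepare_data() to finish before they call setup(). The toy runs sequentially, so that ordering holds trivially.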
Setup of DataModule
In the CIFAR10DataModule, we have implemented the following:
Initialization (`__init__`): The constructor initializes the data directory and batch size, which are used throughout the data module.
Data Preparation (`prepare_data`): This method downloads the CIFAR-10 dataset if it is not already available in the specified directory. It prepares both the training and test datasets.
Setup (`setup`): This method assigns datasets for different stages:
- For the ‘fit’ and ‘validate’ stages, it splits the CIFAR-10 training dataset into training and validation sets (80% and 20%, respectively).
- For the ‘test’ and ‘predict’ stages, it assigns the test dataset.
Data Loaders: The module provides data loaders for training, validation, testing, and prediction, each configured with the specified batch size.
By using a LightningDataModule, the data handling logic is neatly encapsulated, making it easier to manage and modify data-related operations without affecting the rest of the training code.
import argparse

import torchvision
import torchvision.transforms as transforms
from model import LitNet
from pytorch_lightning import LightningDataModule, Trainer, seed_everything
from torch.utils.data import DataLoader, random_split

# (0) import nvflare lightning client API
import nvflare.client.lightning as flare

seed_everything(7)


DATASET_PATH = "/tmp/nvflare/data"
BATCH_SIZE = 4

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])


class CIFAR10DataModule(LightningDataModule):
    def __init__(self, data_dir: str = DATASET_PATH, batch_size: int = BATCH_SIZE):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def prepare_data(self):
        torchvision.datasets.CIFAR10(root=self.data_dir, train=True, download=True, transform=transform)
        torchvision.datasets.CIFAR10(root=self.data_dir, train=False, download=True, transform=transform)

    def setup(self, stage: str):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage == "validate":
            cifar_full = torchvision.datasets.CIFAR10(
                root=self.data_dir, train=True, download=False, transform=transform
            )
            self.cifar_train, self.cifar_val = random_split(cifar_full, [0.8, 0.2])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage == "predict":
            self.cifar_test = torchvision.datasets.CIFAR10(
                root=self.data_dir, train=False, download=False, transform=transform
            )

    def train_dataloader(self):
        return DataLoader(self.cifar_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=self.batch_size)

    def predict_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=self.batch_size)
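The data module's sizes can be worked out by hand. CIFAR-10 has 50,000 training images and 10,000 test images. The sketch below mirrors the rule documented for torch.utils.data.random_split: take the floor of each fractional length, then hand out any remainder one item at a time, round-robin. It also counts how many batches a DataLoader yields with the default drop_last=False. The helper names split_lengths and num_batches are hypothetical.

```python
import math
from typing import List, Sequence


def split_lengths(n: int, fractions: Sequence[float]) -> List[int]:
    """Integer subset sizes for a fractional split (floor, then distribute the remainder)."""
    lengths = [int(math.floor(n * frac)) for frac in fractions]
    remainder = n - sum(lengths)
    for i in range(remainder):
        lengths[i % len(lengths)] += 1
    return lengths


def num_batches(n: int, batch_size: int) -> int:
    """Batches per epoch when the last, smaller batch is kept (drop_last=False)."""
    return math.ceil(n / batch_size)


train_size, val_size = split_lengths(50_000, [0.8, 0.2])
print(train_size, val_size)  # 40000 10000
print(num_batches(train_size, 16))  # 2500 training steps per epoch at --batch_size 16
print(num_batches(val_size, 16))  # 625 validation batches
print(num_batches(10_000, 16))  # 625 test batches
```

The last figure matches the 625 batches in the test and prediction progress bars of the run output below, where the job is started with --batch_size 16.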
Model
In PyTorch Lightning, a LightningModule is a high-level abstraction built on top of PyTorch that streamlines the process of training models. It encapsulates the model architecture, training, validation, and testing logic, allowing developers to focus on the core components of their models without getting bogged down by the boilerplate code typically associated with PyTorch.
General Summary of a LightningModule
Model Definition: The LightningModule is initialized with the model architecture, which is defined using PyTorch’s nn.Module. This includes layers, activation functions, and any other components necessary for the model.
Forward Pass: The forward method specifies how the input data flows through the model. This is where the core computation of the model is defined.
Training Logic: The training_step method contains the logic for a single training iteration. It computes the loss and any metrics you wish to track, such as accuracy. This method is called automatically during the training loop.
Validation and Testing: Similar to the training step, the validation_step and test_step methods define how the model is evaluated on validation and test datasets, respectively. These methods help in monitoring the model’s performance and generalization.
Optimizer Configuration: The configure_optimizers method specifies the optimizer(s) and learning rate scheduler(s) used during training. This allows for flexible and customizable training strategies.
By using a LightningModule, developers can leverage PyTorch Lightning’s features like distributed training, automatic checkpointing, and logging, making it easier to scale experiments and manage complex training workflows. This abstraction promotes cleaner code, better organization, and easier debugging, ultimately accelerating the model development process.
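The training and validation steps described above each reduce to two numbers per batch: a loss from nn.CrossEntropyLoss and an accuracy. The plain-Python sketch below computes both for a batch of logits. It assumes the loss's default mean reduction and uses the simple definition of accuracy (the fraction of samples whose highest-scoring class matches the label). torchmetrics also offers other averaging modes, so treat this as an illustration of the quantities, not of the library internals.

```python
import math
from typing import Sequence


def cross_entropy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Mean of -log(softmax(logits)[label]) over the batch."""
    total = 0.0
    for row, label in zip(logits, labels):
        m = max(row)  # subtract the max before exponentiating, for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[label]
    return total / len(labels)


def accuracy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Fraction of samples whose highest-scoring class equals the label."""
    correct = sum(
        max(range(len(row)), key=row.__getitem__) == label for row, label in zip(logits, labels)
    )
    return correct / len(labels)


batch_logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]
batch_labels = [0, 1]
print(round(cross_entropy(batch_logits, batch_labels), 4))
print(accuracy(batch_logits, batch_labels))  # 0.5
```

With uniform logits over the 10 CIFAR-10 classes, the loss is log(10), about 2.303. A freshly initialized model should start near that value.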
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from pytorch_lightning import LightningModule
from torchmetrics import Accuracy

NUM_CLASSES = 10
criterion = nn.CrossEntropyLoss()


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class LitNet(LightningModule):
    def __init__(self):
        super().__init__()
        self.save_hyperparameters()
        self.model = Net()
        self.train_acc = Accuracy(task="multiclass", num_classes=NUM_CLASSES)
        self.valid_acc = Accuracy(task="multiclass", num_classes=NUM_CLASSES)
        # (optional) pass additional information via self.__fl_meta__
        self.__fl_meta__ = {}

    def forward(self, x):
        out = self.model(x)
        return out

    def training_step(self, batch, batch_idx):
        x, labels = batch
        outputs = self(x)
        loss = criterion(outputs, labels)
        self.train_acc(outputs, labels)
        self.log("train_loss", loss)
        self.log("train_acc", self.train_acc, on_step=True, on_epoch=False)
        return loss

    def evaluate(self, batch, stage=None):
        x, labels = batch
        outputs = self(x)
        loss = criterion(outputs, labels)
        self.valid_acc(outputs, labels)

        if stage:
            self.log(f"{stage}_loss", loss)
            self.log(f"{stage}_acc", self.valid_acc, on_step=True, on_epoch=True)
        return outputs

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, "test")

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        return self.evaluate(batch)

    def configure_optimizers(self):
        optimizer = optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
        return {"optimizer": optimizer}
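The number 16 * 5 * 5 in fc1 comes from how the 32x32 CIFAR-10 images shrink as they pass through Net. The sketch below works it out with the standard output-size formula for convolution and pooling layers, floor((size + 2*padding - kernel) / stride) + 1. It also counts Net's parameters, which is a useful figure for estimating how much data each FL round exchanges. The helper names are hypothetical; the layer sizes are copied from Net.

```python
from typing import Tuple


def out_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a conv or pooling layer (floor division)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_params(in_ch: int, out_ch: int, kernel: int) -> int:
    return in_ch * out_ch * kernel * kernel + out_ch  # weights + biases


def linear_params(in_features: int, out_features: int) -> int:
    return in_features * out_features + out_features  # weights + biases


def net_shapes(size: int = 32) -> Tuple[int, int]:
    size = out_size(size, 5)  # conv1: 32 -> 28
    size = out_size(size, 2, stride=2)  # pool: 28 -> 14
    size = out_size(size, 5)  # conv2: 14 -> 10
    size = out_size(size, 2, stride=2)  # pool: 10 -> 5
    return size, 16 * size * size  # flattened features fed to fc1


spatial, flat = net_shapes()
total = (
    conv_params(3, 6, 5)
    + conv_params(6, 16, 5)
    + linear_params(flat, 120)
    + linear_params(120, 84)
    + linear_params(84, 10)
)
print(spatial, flat)  # 5 400
print(total)  # 62006
print(total * 4)  # 248024 bytes when stored as float32
```

A model of roughly 62 thousand parameters (about 242 KiB in float32) is small. This is one reason the example runs quickly even though the full model is exchanged every round.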
Client Code
Notice that the training code is almost identical to standard PyTorch Lightning training code. The only difference is a few added lines that receive data from and send data to the server. All changed lines are marked with the numbers (0) to (4) to make them easier to find.
def define_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=4)

    return parser.parse_args()


def main():
    args = define_parser()
    batch_size = args.batch_size

    # (1) flare.init() is only needed if the flare function is used (such as flare.get_site_name())
    flare.init()
    print(f"batch_size={batch_size}, site={flare.get_site_name()}")

    model = LitNet()
    cifar10_dm = CIFAR10DataModule(batch_size=batch_size)
    trainer = Trainer(max_epochs=1, accelerator="auto", devices="auto")

    # (2) patch the lightning trainer
    flare.patch(trainer)

    while flare.is_running():
        # (3) receives FLModel from NVFlare
        # Note that we don't need to pass this input_model to trainer
        # because after flare.patch the trainer.fit/validate will get the
        # global model internally
        input_model = flare.receive()
        print(f"\n[Current Round={input_model.current_round}, Site = {flare.get_site_name()}]\n")

        # (4) evaluate the current global model to allow server-side model selection
        print("--- validate global model ---")
        trainer.validate(model, datamodule=cifar10_dm)

        # perform local training starting with the received global model
        print("--- train new model ---")
        trainer.fit(model, datamodule=cifar10_dm)

        # test local model
        print("--- test new model ---")
        trainer.test(ckpt_path="best", datamodule=cifar10_dm)

        # get predictions
        print("--- prediction with new best model ---")
        trainer.predict(ckpt_path="best", datamodule=cifar10_dm)


if __name__ == "__main__":
    main()
The main flow of client.py runs federated learning (FL) training locally on each client using PyTorch Lightning and NVFlare. Here’s a breakdown of the key steps:
Argument Parsing:
- The define_parser() function parses the command-line arguments, specifically the --batch_size argument, which sets the batch size for data loading.
Initialization:
- The main() function begins by parsing the command-line arguments to get the batch size.
- The flare.init() function is called to initialize the NVFlare client. This is necessary for using NVFlare functions such as flare.get_site_name().
Model and Data Module Setup:
- An instance of LitNet, a PyTorch Lightning model, is created.
- An instance of CIFAR10DataModule is created with the specified batch size to handle data loading and processing.
Trainer Configuration:
- A PyTorch Lightning Trainer is configured with accelerator="auto" and devices="auto". It uses a GPU if one is available and falls back to the CPU otherwise.
NVFlare Integration:
- The flare.patch(trainer) call integrates NVFlare with the PyTorch Lightning trainer so that the trainer can handle federated learning tasks.
Federated Learning Loop:
- A loop runs while flare.is_running() returns True, indicating that the federated learning job is active.
- Within the loop:
  - The global model is received from the NVFlare server using flare.receive().
  - The current round and site name are printed for logging purposes.
  - The global model is validated using trainer.validate().
  - Local training is performed using trainer.fit(), starting from the received global model.
  - The local model is tested using trainer.test().
  - Predictions are made using trainer.predict().
Execution:
- The main() function is executed when the script is run as the main module, starting the entire process.
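The receive, train, and send cycle above can be pictured without NVFlare or Lightning. The toy below is only a mental model. FakeFlare, its methods, and the single-float "model" are hypothetical stand-ins, not the nvflare.client.lightning API. The toy makes the exchange explicit with a send() call. In client.py there is no such call, because the patched trainer takes care of the exchange.

```python
from typing import List


class FakeFlare:
    """Hypothetical stand-in for a server that runs a fixed number of rounds with one client."""

    def __init__(self, num_rounds: int, global_weight: float = 0.0):
        self.num_rounds = num_rounds
        self.current_round = 0
        self.global_weight = global_weight
        self.received_updates: List[float] = []

    def is_running(self) -> bool:
        return self.current_round < self.num_rounds

    def receive(self) -> float:
        return self.global_weight

    def send(self, local_weight: float) -> None:
        self.received_updates.append(local_weight)
        self.global_weight = local_weight  # one client, so the "average" is its update
        self.current_round += 1


def local_train(weight: float, target: float = 1.0, lr: float = 0.25) -> float:
    """One gradient step on the toy loss (weight - target) ** 2."""
    return weight - lr * 2 * (weight - target)


def run_client(flare: FakeFlare) -> None:
    while flare.is_running():
        global_weight = flare.receive()  # plays the role of flare.receive()
        updated = local_train(global_weight)  # plays the role of trainer.fit() from the global model
        flare.send(updated)  # explicit here; handled by the patched trainer in client.py


server = FakeFlare(num_rounds=3)
run_client(server)
print(server.received_updates)  # [0.5, 0.75, 0.875]
```

Each round starts from the previous round's global value. This is why the "model" keeps moving toward the target instead of restarting from zero.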
Server Code
In federated averaging, the server aggregates the model updates from the clients, following a workflow pattern similar to scatter-and-gather. This example directly uses the default federated averaging algorithm provided by NVFlare. The FedAvg class is defined in nvflare.app_common.workflows.fedavg.FedAvg. There is no need to define custom server code for this example.
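For reference, the aggregation at the heart of federated averaging is a weighted mean of the client parameters. The sketch below implements the textbook rule, in which each client is weighted by a count such as its number of training samples. The function name fedavg and the dict-of-lists parameter format are assumptions for illustration. NVFlare's FedAvg workflow may choose its weights differently, for example from metadata that the clients report.

```python
from typing import Dict, List, Sequence


def fedavg(
    client_params: Sequence[Dict[str, List[float]]], weights: Sequence[float]
) -> Dict[str, List[float]]:
    """Weighted average of client parameter dicts (same keys and lengths for every client)."""
    if not client_params or len(client_params) != len(weights):
        raise ValueError("need at least one client and exactly one weight per client")
    total = float(sum(weights))
    averaged: Dict[str, List[float]] = {}
    for name in client_params[0]:
        length = len(client_params[0][name])
        averaged[name] = [
            sum(w * params[name][i] for params, w in zip(client_params, weights)) / total
            for i in range(length)
        ]
    return averaged


site1 = {"fc.weight": [1.0, 2.0], "fc.bias": [0.0]}
site2 = {"fc.weight": [3.0, 6.0], "fc.bias": [4.0]}
print(fedavg([site1, site2], weights=[1, 3]))
# {'fc.weight': [2.5, 5.0], 'fc.bias': [3.0]}
```

With equal weights, the rule reduces to a plain element-wise mean. Equal weights would be expected here, since both sites train on same-sized copies of CIFAR-10.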
Job Recipe Code
The job recipe code is used to define the client and server configurations.
import argparse

import torchvision.datasets as datasets
from model import LitNet

from nvflare.app_opt.pt.recipes.fedavg import FedAvgRecipe
from nvflare.recipe.sim_env import SimEnv

DATASET_ROOT = "/tmp/nvflare/data"


def define_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_clients", type=int, default=2)
    parser.add_argument("--num_rounds", type=int, default=2)
    parser.add_argument("--batch_size", type=int, default=24)

    return parser.parse_args()


def download_data():
    datasets.CIFAR10(root=DATASET_ROOT, train=True, download=True)
    datasets.CIFAR10(root=DATASET_ROOT, train=False, download=True)


def main():
    args = define_parser()

    n_clients = args.n_clients
    num_rounds = args.num_rounds
    batch_size = args.batch_size

    recipe = FedAvgRecipe(
        min_clients=n_clients,
        num_rounds=num_rounds,
        # Model can be specified as class instance or dict config:
        model=LitNet(),
        # Alternative: model={"class_path": "model.LitNet", "args": {}},
        # For pre-trained weights: initial_ckpt="/server/path/to/pretrained.pt",
        train_script="client.py",
        train_args=f"--batch_size {batch_size}",
    )

    env = SimEnv(num_clients=n_clients, num_threads=n_clients)
    recipe.execute(env=env)


if __name__ == "__main__":
    download_data()
    main()
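The train_args string is how the job hands settings to client.py: the recipe's --batch_size value is meant to be picked up by the client's define_parser(). The sketch below assumes the string reaches the script as ordinary command-line arguments, as the matching flag names suggest. Under that assumption, the hand-off can be checked with the standard library alone. The parser copies the client's single option, and shlex.split stands in for however the launcher actually tokenizes the string.

```python
import argparse
import shlex


def client_parser() -> argparse.ArgumentParser:
    """Same option as define_parser() in client.py, but returning the parser itself."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=4)
    return parser


batch_size = 24  # job.py's default for --batch_size
train_args = f"--batch_size {batch_size}"  # built the same way as in job.py
args = client_parser().parse_args(shlex.split(train_args))
print(args.batch_size)  # 24
```

This also explains the three different batch sizes. client.py falls back to 4 when run on its own, and job.py forwards 24 unless told otherwise. The run shown below overrides the value with 16, which appears in the log line batch_size=16.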
Model Input Options
The model parameter accepts two formats:
- Class instance: model=LitNet() (convenient and Pythonic)
- Dict config: model={"class_path": "model.LitNet", "args": {}} (better for large models)
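A dict config names the class by its import path instead of constructing the object up front. One plausible way to turn such a config into an object is to import the module, look up the class, and call it with args as keyword arguments. The resolver below is a hypothetical sketch of that idea, not NVFlare's loader. It is demonstrated on a standard-library class so that it runs anywhere; "model.LitNet" would additionally require model.py to be importable.

```python
import importlib
from typing import Any, Dict


def instantiate(config: Dict[str, Any]) -> Any:
    """Build an object from {"class_path": "pkg.module.ClassName", "args": {...}}."""
    module_name, _, class_name = config["class_path"].rpartition(".")
    if not module_name:
        raise ValueError(f"class_path must include a module: {config['class_path']!r}")
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**config.get("args", {}))


half = instantiate({"class_path": "fractions.Fraction", "args": {"numerator": 1, "denominator": 2}})
print(half)  # 1/2
```

With the example's config, instantiate({"class_path": "model.LitNet", "args": {}}) would import model.py and call LitNet() with no arguments.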
To resume from pre-trained weights:
recipe = FedAvgRecipe(
    model=LitNet(),
    initial_ckpt="/server/path/to/pretrained.pt",  # Absolute path
    ...
)
Run FL Job
This section shows how to run the federated learning job defined by the job recipe above. First, run the following command in your terminal to download the data:
./prepare_data.sh
Command to execute the FL job
Use the following command in your terminal to start the job with the specified number of rounds and batch size. The number of clients defaults to 2 and can be changed with --n_clients.
python job.py --num_rounds 2 --batch_size 16
output
# < ... skip few lines of logs ..>
# 2025-07-22 18:45:45,758 - INFO - Start FedAvg.
# 2025-07-22 18:45:45,759 - INFO - loading initial model from persistor
# 2025-07-22 18:45:45,759 - INFO - Both source_ckpt_file_full_name and ckpt_preload_path are not provided. Using the default model weights initialized on the persistor side.
# 2025-07-22 18:45:45,760 - INFO - Round 0 started.
# 2025-07-22 18:45:45,760 - INFO - Sampled clients: ['site-1', 'site-2']
# 2025-07-22 18:45:45,760 - INFO - Sending task train to ['site-1', 'site-2']
#
# < ... skip .. few lines of logs ..>
#
# 2025-07-22 18:45:50,507 - INFO - batch_size=16, site=site-1
# 2025-07-22 18:45:50,543 - INFO -
# [Current Round=0, Site = site-1]
#
# 2025-07-22 18:45:50,543 - INFO - --- validate global model ---
# 2025-07-22 18:45:50,578 - INFO - batch_size=16, site=site-2
# 2025-07-22 18:45:50,656 - INFO -
# [Current Round=0, Site = site-2]
#
# 2025-07-22 18:45:50,656 - INFO - --- validate global model ---
#
# < ... skip .. few lines of logs ..>
#
# ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
# ┃ Test metric ┃ DataLoader 0 ┃
# ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
# │ test_acc_epoch │ 0.44699999690055847 │
# │ test_loss │ 1.5125484466552734 │
# └───────────────────────────┴───────────────────────────┘
# Testing DataLoader 0: 68%|████████████████████████████████▍ | 422/625 [00:01<00:00, 276.33it/s]2025-07-22 18:46:39,629 - INFO - --- prediction with new best model ---
# Testing DataLoader 0: 76%|████████████████████████████████████▋ | 478/625 [00:01<00:00, 275.61it/s]2025-07-22 18:46:39,837 - INFO - Files already downloaded and verified
# Testing DataLoader 0: 100%|████████████████████████████████████████████████| 625/625 [00:02<00:00, 275.79it/s]
# ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
# ┃ Test metric ┃ DataLoader 0 ┃
# ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
# │ test_acc_epoch │ 0.44699999690055847 │
# │ test_loss │ 1.5125484466552734 │
# └───────────────────────────┴───────────────────────────┘
# 2025-07-22 18:46:40,370 - INFO - --- prediction with new best model ---
# 2025-07-22 18:46:40,431 - INFO - Files already downloaded and verified
# 2025-07-22 18:46:40,577 - INFO - Files already downloaded and verified
# Predicting DataLoader 0: 16%|███████▍ | 103/625 [00:00<00:01, 371.90it/s]2025-07-22 18:46:41,191 - INFO - Files already downloaded and verified
# Predicting DataLoader 0: 100%|█████████████████████████████████████████████| 625/625 [00:01<00:00, 367.54it/s]
# Predicting DataLoader 0: 53%|███████████████████████▊ | 331/625 [00:00<00:00, 346.29it/s]2025-07-22 18:46:42,615 - WARNING - request to stop the job for reason END_RUN received
# Predicting DataLoader 0: 100%|█████████████████████████████████████████████| 625/625 [00:01<00:00, 344.12it/s]
# 2025-07-22 18:46:43,476 - WARNING - request to stop the job for reason END_RUN received