Hello PyTorch

Before You Start

Feel free to refer to the detailed documentation at any point to learn more about the specifics of NVIDIA FLARE.

Make sure you have an environment with NVIDIA FLARE installed.

The installation guide explains how to set up a Python virtual environment (the recommended environment) and how to install NVIDIA FLARE.


Through this exercise, you will integrate NVIDIA FLARE with the popular deep learning framework PyTorch and learn how to use NVIDIA FLARE to train a convolutional network with the CIFAR10 dataset using the included Scatter and Gather workflow.

The setup of this exercise consists of one server and two clients.

The following steps compose one cycle of weight updates, called a round:

  1. Clients are responsible for generating individual weight-updates for the model using their own CIFAR10 dataset.

  2. These updates are then sent to the server, which aggregates them to produce a model with new weights.

  3. Finally, the server sends this updated version of the model back to each client.
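
The three steps above can be sketched in a few lines of plain Python. This is a toy illustration only: the model is a short list of floats, local_update is a hypothetical stand-in for real training, and the server uses a plain unweighted average. None of these names come from NVIDIA FLARE.

```python
from typing import List

Weights = List[float]


def local_update(global_weights: Weights, client_shift: float) -> Weights:
    """Hypothetical stand-in for local training: shift every weight by a client-specific amount."""
    return [w + client_shift for w in global_weights]


def aggregate(client_weights: List[Weights]) -> Weights:
    """Server-side step: element-wise unweighted mean of the clients' weights."""
    n_clients = len(client_weights)
    return [sum(column) / n_clients for column in zip(*client_weights)]


def run_round(global_weights: Weights, client_shifts: List[float]) -> Weights:
    # Step 1: each client produces its own updated weights.
    updates = [local_update(global_weights, shift) for shift in client_shifts]
    # Step 2: the server aggregates the updates into new global weights.
    new_global = aggregate(updates)
    # Step 3: the new global weights are what every client starts from next round.
    return new_global


global_weights = [0.0, 1.0]
for round_idx in range(2):
    global_weights = run_round(global_weights, client_shifts=[1.0, 3.0])
    print(f"round {round_idx}: {global_weights}")
```

With one server and two clients, as in this exercise, each call to run_round corresponds to one round.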

For this exercise, we will be working with the hello-pt application in the examples folder. Custom FL applications can contain the folders:

  1. custom: contains the custom components (simple_network.py, cifar10trainer.py)

  2. config: contains client and server configurations (config_fed_client.json, config_fed_server.json)

  3. resources: contains the logger config (log.config)

Now that you have a rough idea of what is going on, let’s get started. First clone the repo:

$ git clone https://github.com/NVIDIA/NVFlare.git

Now remember to activate your NVIDIA FLARE Python virtual environment from the installation guide.

Since you will use PyTorch and torchvision for this exercise, let’s go ahead and install both libraries:

(nvflare-env) $ python3 -m pip install torch torchvision


There is a pending fix related to Pillow, PyTorch==1.9, and NumPy. If you see an exception related to enumerate(self.train_loader), downgrade Pillow to 8.2.0:

(nvflare-env) $ python3 -m pip install torch torchvision Pillow==8.2.0
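
To confirm which versions ended up in the active environment, a small check with the standard library's importlib.metadata may help. This is a convenience sketch, not part of the hello-pt example; installed_version is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of a package, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


# Report the packages this exercise depends on.
for name in ("torch", "torchvision", "Pillow"):
    print(f"{name}: {installed_version(name) or 'not installed'}")
```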

If you would like to go ahead and run the exercise now, you can skip directly to the section Train the Model, Federated!


Neural Network

With all the required dependencies installed, you are ready to run federated learning with two clients and one server. The training procedure and network architecture are adapted from the PyTorch tutorial Training a Classifier.

Let’s see what an extremely simplified CIFAR10 training looks like:

import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleNetwork(nn.Module):
    def __init__(self):
        super(SimpleNetwork, self).__init__()

        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

This SimpleNetwork class is your convolutional neural network to train with the CIFAR10 dataset. This is not related to NVIDIA FLARE, so we implement it in a file called simple_network.py.
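
The input size of fc1, 16 * 5 * 5, follows from how the convolution and pooling layers shrink a 32 x 32 CIFAR10 image. The arithmetic can be checked with a short sketch; conv_out is a hypothetical helper that applies the standard output-size formula for an unpadded square window.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output side length of a square convolution or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1


side = 32                                   # CIFAR10 images are 3 x 32 x 32
side = conv_out(side, kernel=5)             # conv1: 32 -> 28
side = conv_out(side, kernel=2, stride=2)   # pool:  28 -> 14
side = conv_out(side, kernel=5)             # conv2: 14 -> 10
side = conv_out(side, kernel=2, stride=2)   # pool:  10 -> 5

flattened = 16 * side * side                # conv2 has 16 output channels
print(side, flattened)                      # prints "5 400"
```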

Dataset & Setup

Now implement the custom class Cifar10Trainer as an NVIDIA FLARE Executor in a file called cifar10trainer.py.

In a real FL experiment, each client would have their own dataset used for their local training. For simplicity’s sake, you can download the same CIFAR10 dataset from the Internet via torchvision’s datasets module. Additionally, you need to set up the optimizer, loss function and transform to process the data. You can think of all of this code as part of your local training loop, as every deep learning training has a similar setup.

Since you will encapsulate every training-related step in the Cifar10Trainer class, let’s put this preparation stage into the __init__ method:

class Cifar10Trainer(Executor):
    def __init__(
        self,
        data_path="~/data",
        lr=0.01,
        epochs=5,
        train_task_name=AppConstants.TASK_TRAIN,
        submit_model_task_name=AppConstants.TASK_SUBMIT_MODEL,
        exclude_vars=None,
    ):
        """Cifar10 Trainer handles train and submit_model tasks. During train_task, it trains a
        simple network on CIFAR10 dataset. For submit_model task, it sends the locally trained model
        (if present) to the server.

        Args:
            lr (float, optional): Learning rate. Defaults to 0.01
            epochs (int, optional): Epochs. Defaults to 5
            train_task_name (str, optional): Task name for train task. Defaults to "train".
            submit_model_task_name (str, optional): Task name for submit model. Defaults to "submit_model".
            exclude_vars (list): List of variables to exclude during model loading.
        """
        super(Cifar10Trainer, self).__init__()

        self._lr = lr
        self._epochs = epochs
        self._train_task_name = train_task_name
        self._submit_model_task_name = submit_model_task_name
        self._exclude_vars = exclude_vars

        # Training setup
        self.model = SimpleNetwork()
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.loss = nn.CrossEntropyLoss()
        self.optimizer = SGD(self.model.parameters(), lr=lr, momentum=0.9)

        # Create Cifar10 dataset for training.
        transforms = Compose(
            [
                ToTensor(),
                Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )
        self._train_dataset = CIFAR10(root=data_path, transform=transforms, download=True, train=True)
        self._train_loader = DataLoader(self._train_dataset, batch_size=4, shuffle=True)
        self._n_iterations = len(self._train_loader)
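
Two numbers in this setup are worth working out by hand: what Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) does to a pixel value, and how many iterations one epoch contains. The following is a plain-Python sketch of the arithmetic only; normalize is a hypothetical helper, not the torchvision class.

```python
import math


def normalize(value: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Per-channel arithmetic applied by Normalize: (value - mean) / std."""
    return (value - mean) / std


# ToTensor() scales pixel values to [0, 1]; Normalize then maps them to [-1, 1].
print(normalize(0.0), normalize(0.5), normalize(1.0))  # prints "-1.0 0.0 1.0"

# The CIFAR10 training split has 50,000 images. With batch_size=4 and the
# DataLoader default drop_last=False, one epoch is ceil(50000 / 4) batches.
n_iterations = math.ceil(50_000 / 4)
print(n_iterations)  # prints "12500"
```

The second value is what len(self._train_loader) evaluates to, and it is later reported to the server as the number of steps in the current round.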

Local Train

Now that you have your network and dataset set up in the Cifar10Trainer class, let’s also implement a local training loop in a method called local_train:

    def local_train(self, fl_ctx, weights, abort_signal):
        # Set the model weights
        self.model.load_state_dict(state_dict=weights)

        # Basic training
        self.model.train()
        for epoch in range(self._epochs):
            running_loss = 0.0
            for i, batch in enumerate(self._train_loader):
                if abort_signal.triggered:
                    # If abort_signal is triggered, we simply return.
                    # The outside function will check it again and decide steps to take.
                    return

                images, labels = batch[0].to(self.device), batch[1].to(self.device)
                self.optimizer.zero_grad()

                predictions = self.model(images)
                cost = self.loss(predictions, labels)
                cost.backward()
                self.optimizer.step()

                running_loss += cost.cpu().detach().numpy() / images.size()[0]
                if i % 3000 == 0:
                    self.log_info(
                        fl_ctx, f"Epoch: {epoch}/{self._epochs}, Iteration: {i}, " f"Loss: {running_loss/3000}"
                    )
                    running_loss = 0.0
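
The abort check at the top of the batch loop is a cooperative-cancellation pattern: the loop polls a flag, stops quietly when it is set, and leaves the decision about what to report to the caller. The sketch below isolates that pattern in plain Python. FakeSignal and train_batches are hypothetical stand-ins that only mimic a triggered attribute; they are not NVIDIA FLARE classes.

```python
class FakeSignal:
    """Hypothetical stand-in for the abort signal: only exposes a `triggered` flag."""

    def __init__(self) -> None:
        self.triggered = False


def train_batches(n_batches: int, abort_signal: FakeSignal, abort_at: int = -1) -> int:
    """Process batches until done or aborted; return how many were completed."""
    completed = 0
    for i in range(n_batches):
        if i == abort_at:
            # Simulate an abort request arriving from outside the loop.
            abort_signal.triggered = True
        if abort_signal.triggered:
            # Same pattern as local_train: stop quietly and let the caller decide.
            return completed
        completed += 1  # placeholder for the real forward/backward pass
    return completed


signal = FakeSignal()
done = train_batches(10, signal, abort_at=4)
# The caller checks the flag again to decide what to report.
status = "TASK_ABORTED" if signal.triggered else "OK"
print(done, status)  # prints "4 TASK_ABORTED"
```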


Everything up to this point is completely independent of NVIDIA FLARE. It is purely a PyTorch deep learning exercise. You will now build the NVIDIA FLARE application based on this PyTorch code.

Integrate NVIDIA FLARE with Local Train

NVIDIA FLARE makes it easy to integrate your local train code into the NVIDIA FLARE API.

The simplest way to do this is to subclass the Executor class and implement one method execute, which is called every time the client receives an updated model from the server with the task “train” (the server will broadcast the “train” task in the Scatter and Gather workflow we will configure below). We can then call our local train inside the execute method.


The execute method inside the Executor class is where all of the client-side computation occurs. In these exercises, we update the weights by training on a local dataset; however, it is important to remember that NVIDIA FLARE is not restricted to deep learning. The type of data passed between the server and the clients, and the computations that the clients perform, can be anything, as long as all of the FL components agree on the same format.

Take a look at the following code:

    def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
        try:
            if task_name == self._train_task_name:
                # Get model weights
                try:
                    dxo = from_shareable(shareable)
                except Exception:
                    self.log_error(fl_ctx, "Unable to extract dxo from shareable.")
                    return make_reply(ReturnCode.BAD_TASK_DATA)

                # Ensure data kind is weights.
                if not dxo.data_kind == DataKind.WEIGHTS:
                    self.log_error(fl_ctx, f"data_kind expected WEIGHTS but got {dxo.data_kind} instead.")
                    return make_reply(ReturnCode.BAD_TASK_DATA)

                # Convert weights to tensor. Run training
                torch_weights = {k: torch.as_tensor(v) for k, v in dxo.data.items()}
                self.local_train(fl_ctx, torch_weights, abort_signal)

                # Check the abort_signal after training.
                # local_train returns early if abort_signal is triggered.
                if abort_signal.triggered:
                    return make_reply(ReturnCode.TASK_ABORTED)

                # Save the local model after training.
                # (save_local_model is a helper method of this class, not shown here.)
                self.save_local_model(fl_ctx)

                # Get the new state dict and send as weights
                new_weights = self.model.state_dict()
                new_weights = {k: v.cpu().numpy() for k, v in new_weights.items()}

                outgoing_dxo = DXO(
                    data_kind=DataKind.WEIGHTS,
                    data=new_weights,
                    meta={MetaKey.NUM_STEPS_CURRENT_ROUND: self._n_iterations},
                )
                return outgoing_dxo.to_shareable()
            elif task_name == self._submit_model_task_name:
                # Load local model
                ml = self.load_local_model(fl_ctx)

                # Get the model parameters and create dxo from it
                dxo = model_learnable_to_dxo(ml)
                return dxo.to_shareable()
            else:
                return make_reply(ReturnCode.TASK_UNKNOWN)
        except Exception:
            self.log_exception(fl_ctx, "Exception in simple trainer.")
            return make_reply(ReturnCode.EXECUTION_EXCEPTION)

The concept of Shareable is described in the Shareable section of the documentation. Essentially, every NVIDIA FLARE client receives the model weights from the server in shareable format. The shareable is passed into the execute method, which returns a new shareable back to the server. The data is managed by using DXO (see Data Exchange Object (DXO) for details).

Thus, the first step is to retrieve the model weights delivered by the server via the shareable. This is the first part of the code block above, before local_train is called.

We then perform a local train so the client’s model is trained with its own dataset.

After finishing the local train, the execute method builds a new shareable with the newly trained weights and metadata and returns it to the NVIDIA FLARE server for aggregation.
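
The control flow just described (validate the incoming data, train, reply with new weights and metadata) can be mimicked without NVIDIA FLARE installed. The sketch below is a deliberately simplified stand-in: ToyDXO, toy_execute, the string return codes, and the "num_steps" metadata key are all hypothetical and only imitate the shape of the exchange, not the real DXO or Shareable classes.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class ToyDXO:
    """Hypothetical, simplified stand-in for a data exchange object."""

    data_kind: str
    data: Dict[str, List[float]]
    meta: Dict[str, Any] = field(default_factory=dict)


def toy_execute(task_name: str, incoming: ToyDXO) -> Dict[str, Any]:
    """Mimic the control flow of execute: validate, train, reply."""
    if task_name != "train":
        return {"return_code": "TASK_UNKNOWN"}
    if incoming.data_kind != "WEIGHTS":
        return {"return_code": "BAD_TASK_DATA"}

    # Placeholder for local training: halve every weight.
    new_weights = {name: [w * 0.5 for w in values] for name, values in incoming.data.items()}

    # Reply with the new weights plus metadata the server can use when aggregating.
    outgoing = ToyDXO(data_kind="WEIGHTS", data=new_weights, meta={"num_steps": 3})
    return {"return_code": "OK", "dxo": outgoing}


reply = toy_execute("train", ToyDXO("WEIGHTS", {"fc3.bias": [0.5, 1.0]}))
print(reply["return_code"], reply["dxo"].data)
```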

There is additional logic to handle the “submit_model” task, but that is for the CrossSiteModelEval workflow, so we will be addressing that in a later example.


The FLContext is used to set and retrieve FL related information among the FL components via set_prop() and get_prop() as well as get services provided by the underlying infrastructure. You can find more details in the documentation.

NVIDIA FLARE Server & Application

In this exercise, you can use the default settings, which leverage NVIDIA FLARE built-in components for the NVIDIA FLARE server. These built-in components are commonly used in most deep learning scenarios. However, you are encouraged to build your own components to fully customize NVIDIA FLARE to fit your environment, which we will demonstrate in the following exercises.

Application Configuration

Inside the config folder there are two files, config_fed_client.json and config_fed_server.json. The client configuration, config_fed_client.json, is shown first:

 1{
 2  "format_version": 2,
 3
 4  "executors": [
 5    {
 6      "tasks": ["train", "submit_model"],
 7      "executor": {
 8        "path": "cifar10trainer.Cifar10Trainer",
 9        "args": {
10          "lr": 0.01,
11          "epochs": 1
12        }
13      }
14    },
15    {
16      "tasks": ["validate"],
17      "executor": {
18        "path": "cifar10validator.Cifar10Validator",
19        "args": {
20        }
21      }
22    }
23  ],
24  "task_result_filters": [
25  ],
26  "task_data_filters": [
27  ],
28  "components": [
29  ]
30}

Take a look at line 8.

This is the Cifar10Trainer you just implemented.

The NVIDIA FLARE client loads this application configuration and picks your implementation.

You can easily change it to another class so your NVIDIA FLARE client has different training logic.

The tasks “train” and “submit_model” have been configured to work with the Cifar10Trainer Executor. The “validate” task for Cifar10Validator and the “submit_model” task are used for the CrossSiteModelEval workflow, so we will be addressing that in a later example.
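
The "path" and "args" pair in each entry is a common way to describe a component in configuration: a dotted path names a class, and the arguments are passed to its constructor. The sketch below shows the general mechanism using importlib. It is an illustration only, not NVIDIA FLARE's actual loader, and build_component is a hypothetical name.

```python
import importlib
from typing import Any, Dict


def build_component(spec: Dict[str, Any]) -> Any:
    """Instantiate the class named by spec["path"] with keyword arguments spec["args"]."""
    module_name, _, class_name = spec["path"].rpartition(".")
    module = importlib.import_module(module_name)
    cls = getattr(module, class_name)
    return cls(**spec.get("args", {}))


# A standard-library class is used here so the sketch runs anywhere; in the
# client config the path would be "cifar10trainer.Cifar10Trainer".
spec = {"path": "fractions.Fraction", "args": {"numerator": 1, "denominator": 4}}
component = build_component(spec)
print(type(component).__name__, component)  # prints "Fraction 1/4"
```

Swapping in different training logic then amounts to changing the "path" string and its "args".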

 1{
 2  "format_version": 2,
 3
 4  "server": {
 5    "heart_beat_timeout": 600
 6  },
 7  "task_data_filters": [],
 8  "task_result_filters": [],
 9  "components": [
10    {
11      "id": "persistor",
12      "path": "nvflare.app_common.pt.pt_file_model_persistor.PTFileModelPersistor",
13      "args": {
14        "model": {
15          "path": "simple_network.SimpleNetwork"
16        }
17      }
18    },
19    {
20      "id": "shareable_generator",
21      "path": "nvflare.app_common.shareablegenerators.full_model_shareable_generator.FullModelShareableGenerator",
22      "args": {}
23    },
24    {
25      "id": "aggregator",
26      "path": "nvflare.app_common.aggregators.intime_accumulate_model_aggregator.InTimeAccumulateWeightedAggregator",
27      "args": {
28        "expected_data_kind": "WEIGHTS"
29      }
30    },
31    {
32      "id": "model_locator",
33      "path": "pt_model_locator.PTModelLocator",
34      "args": {}
35    },
36    {
37      "id": "json_generator",
38      "path": "nvflare.app_common.widgets.validation_json_generator.ValidationJsonGenerator",
39      "args": {}
40    }
41  ],
42  "workflows": [
43      {
44        "id": "scatter_and_gather",
45        "name": "ScatterAndGather",
46        "args": {
47            "min_clients" : 2,
48            "num_rounds" : 2,
49            "start_round": 0,
50            "wait_time_after_min_received": 10,
51            "aggregator_id": "aggregator",
52            "persistor_id": "persistor",
53            "shareable_generator_id": "shareable_generator",
54            "train_task_name": "train",
55            "train_timeout": 0
56        }
57      },
58      {
59        "id": "cross_site_validate",
60        "name": "CrossSiteModelEval",
61        "args": {
62          "model_locator_id": "model_locator"
63        }
64      }
65  ]
66}

As mentioned earlier, the server application configuration leverages NVIDIA FLARE built-in components. Remember, you are encouraged to change them to your own classes whenever you have different application logic.

Note that on line 12, persistor points to PTFileModelPersistor. NVIDIA FLARE provides a built-in PyTorch implementation of a model persistor; for other frameworks or libraries, you will have to implement your own.

The Scatter and Gather workflow is implemented by ScatterAndGather and is configured to make use of the components with id “aggregator”, “persistor”, and “shareable_generator”. The workflow code is all open source now, so feel free to study and use it as inspiration to write your own workflows to support your needs.
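
Because workflows refer to components by id, a mistyped id is an easy mistake to make when editing the server configuration by hand. The sketch below checks that every workflow argument ending in "_id" names a declared component. It is a hypothetical sanity check, not an NVIDIA FLARE tool, and treating the "_id" suffix as a component reference is an assumption made for this example. The embedded JSON is a trimmed copy of the server configuration.

```python
import json

# Trimmed copy of the server configuration: only ids and id references are kept.
CONFIG_TEXT = """
{
  "components": [
    {"id": "persistor"},
    {"id": "shareable_generator"},
    {"id": "aggregator"},
    {"id": "model_locator"}
  ],
  "workflows": [
    {
      "id": "scatter_and_gather",
      "args": {
        "aggregator_id": "aggregator",
        "persistor_id": "persistor",
        "shareable_generator_id": "shareable_generator",
        "train_task_name": "train"
      }
    },
    {"id": "cross_site_validate", "args": {"model_locator_id": "model_locator"}}
  ]
}
"""


def missing_component_ids(config: dict) -> list:
    """Return (workflow id, arg name, value) for every *_id arg with no matching component."""
    known = {component["id"] for component in config["components"]}
    missing = []
    for workflow in config["workflows"]:
        for arg_name, value in workflow.get("args", {}).items():
            if arg_name.endswith("_id") and value not in known:
                missing.append((workflow["id"], arg_name, value))
    return missing


config = json.loads(CONFIG_TEXT)
print(missing_component_ids(config))  # prints "[]"
```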

Train the Model, Federated!

Now you can use the admin command prompt to submit and start this example job. To do this on a proof-of-concept local FL system, follow the sections Setting Up the Application Environment in POC Mode and Starting the Application Environment in POC Mode if you have not already.

Running the FL System

With the admin client command prompt successfully connected and logged in, enter the command below.

> submit_job hello-pt

Pay close attention to what happens in each of the four terminals. You can see how the admin submits the job to the server and how the JobRunner on the server automatically picks up the job to deploy and start the run.

This command uploads the job configuration from the admin client to the server. A job id will be returned, and we can use that id to access job information.


If we use submit_job [app] then that app will be treated as a single app job.

From time to time, you can issue check_status server in the admin client to check the entire training progress.

You should now see how the training is progressing in the very first terminal (the one that started the server).

Accessing the results

The results of each job will usually be stored inside the server side workspace.

Please refer to the section on accessing the server-side workspace for details.

Shutdown FL system

Once the FL run is complete, the server has successfully aggregated the clients’ results after all the rounds, and cross-site model evaluation is finished, run the following commands in fl_admin to shut down the system (enter admin when prompted for the password):

> shutdown client
> shutdown server
> bye

Congratulations! You’ve successfully built and run your first federated learning system.

The full source code for this exercise can be found in examples/hello-pt.