FL Experiment Tracking with MLflow

Introduction

In this example, clients stream their statistics to the server as events, and the server writes those statistics to MLflow. It is similar to the FL Experiment Tracking with TensorBoard Streaming example but uses MLflow as the experiment-tracking back end. The example is located in the “mlflow” directory under experiment-tracking in the advanced examples folder.

The setup of this exercise consists of one server and two clients. The clients stream their statistics to the server as events with MLflowWriter, and only the server writes data to the MLflow tracking server, using MLflowReceiver. As a result, the server is the only party that needs to handle authentication and communication with the MLflow tracking server, and buffering the data before sending it reduces the amount of communication.
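The division of labor can be illustrated with a toy, in-process simulation. This is a minimal sketch, not NVFlare code: `ToyTrackingBackend`, `ToyServerReceiver`, and `ToyClient` are hypothetical stand-ins for the MLflow tracking server, MLflowReceiver, and a client using MLflowWriter, and the client names are illustrative. It only shows that clients produce records while a single server-side object is the one that talks to the tracking backend.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ToyTrackingBackend:
    """Hypothetical stand-in for the MLflow tracking server."""
    writes: List[dict] = field(default_factory=list)

    def write(self, record: dict) -> None:
        # Only the server-side receiver ever calls this method.
        self.writes.append(record)


@dataclass
class ToyServerReceiver:
    """Hypothetical stand-in for MLflowReceiver: the single writer to the backend."""
    backend: ToyTrackingBackend

    def on_event(self, client_name: str, payload: dict) -> None:
        # Tag each record with its origin so per-client metrics stay separate.
        self.backend.write({"client": client_name, **payload})


@dataclass
class ToyClient:
    """Hypothetical stand-in for a client that streams statistics as events."""
    name: str
    receiver: ToyServerReceiver

    def log_metrics(self, metrics: Dict[str, float], step: int) -> None:
        # The client never touches the backend; it only emits an event.
        self.receiver.on_event(self.name, {**metrics, "step": step})


backend = ToyTrackingBackend()
server = ToyServerReceiver(backend)
clients = [ToyClient("site-1", server), ToyClient("site-2", server)]
for step in range(2):
    for c in clients:
        c.log_metrics({"train_loss": 1.0 / (step + 1)}, step)

print(len(backend.writes))  # 4
```

Authentication and connection settings for the tracking server would only ever live in the server-side object, which is the property this example's design relies on.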

Note

Like FL Experiment Tracking with TensorBoard Streaming, this exercise differs from Hello PyTorch by using the Learner API along with the LearnerExecutor. In short, the execution flow is abstracted away into the LearnerExecutor, so you only need to implement the required methods in the Learner class. This is not the focus of this guide; you can learn more at Learner and LearnerExecutor.

Let’s get started. Make sure you have an environment with NVIDIA FLARE installed as described in Getting Started. First, clone the repo:

$ git clone https://github.com/NVIDIA/NVFlare.git

Remember to activate your NVIDIA FLARE Python virtual environment from the installation guide.

Install the required dependencies from the example directory (NVFlare/examples/advanced/experiment-tracking/mlflow):

(nvflare-env) $ python3 -m pip install -r requirements.txt

When running, make sure PYTHONPATH includes the custom files of the example. Replace the path below with the path to the directory that contains the “pt” directory with the custom files:

(nvflare-env) $ export PYTHONPATH=${YOUR PATH TO NVFLARE}/examples/advanced/experiment-tracking

Adding MLflow Logging to Configurations

Inside the config folder, there are two files: config_fed_client.json and config_fed_server.json.

Take a look at the components section of the client config at line 24. The first component is pt_learner, which contains the initialization, training, and validation logic. learner_with_mlflow.py (under NVFlare/examples/advanced/experiment-tracking/pt) contains the training code written with MLflowWriter calls.

The MLflowWriter mimics the syntax of MLflow to make it easier to reuse existing code that already uses MLflow for metrics tracking. Instead of writing to the MLflow tracking server, however, MLflowWriter creates and sends an event within NVFlare that carries the information to track.
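To make the idea concrete, here is a minimal sketch of a writer that keeps MLflow-style method names but, instead of contacting a tracking server, packages each call into an event and hands it to a callback. `ToyEventWriter`, the `fire_event` callback, and the payload keys are hypothetical and are not the NVFlare MLflowWriter implementation; only the event name `analytix_log_stats` comes from this guide.

```python
from typing import Callable, Dict, List, Optional

EVENT_TYPE = "analytix_log_stats"  # local event type named in this guide


class ToyEventWriter:
    """Hypothetical writer with MLflow-like method names that emits events instead of writing to MLflow."""

    def __init__(self, fire_event: Callable[[str, dict], None]):
        self._fire_event = fire_event

    def log_metric(self, key: str, value: float, step: Optional[int] = None) -> None:
        # Mirrors the shape of mlflow.log_metric(key, value, step), but only emits an event.
        self._fire_event(EVENT_TYPE, {"kind": "metric", "key": key, "value": value, "step": step})

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        # Copy the dict so later mutation by the caller does not change the event payload.
        self._fire_event(EVENT_TYPE, {"kind": "metrics", "metrics": dict(metrics), "step": step})

    def log_text(self, text: str, artifact_file_path: str) -> None:
        self._fire_event(EVENT_TYPE, {"kind": "text", "text": text, "path": artifact_file_path})


# Usage: capture the emitted events in a list instead of sending them anywhere.
captured: List[tuple] = []
writer = ToyEventWriter(lambda event_type, payload: captured.append((event_type, payload)))
writer.log_metrics({"train_loss": 0.5, "running_loss": 0.0}, step=10)
writer.log_text("last running_loss reset at '0' step", "running_loss_reset.txt")
```

Because the call sites look like `mlflow.log_metrics(...)` or `mlflow.log_text(...)`, training code can switch to an event-based writer with few changes, which is the motivation this guide gives for MLflowWriter.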

Finally, ConvertToFedEvent converts local events to federated events. This changes the event analytix_log_stats into the federated event fed.analytix_log_stats, which is then streamed from the clients to the server.
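The renaming can be pictured as adding a `fed.` prefix to the configured local event types. The helper below is a hypothetical sketch, not the ConvertToFedEvent source: the function name and the `events_to_convert` parameter are illustrative, and it assumes event types outside the configured list are left untouched.

```python
from typing import Iterable, Optional

FED_PREFIX = "fed."


def to_fed_event_type(event_type: str, events_to_convert: Iterable[str]) -> Optional[str]:
    """Return the federated event name for a configured local event, or None if it is not converted."""
    if event_type not in set(events_to_convert):
        return None
    if event_type.startswith(FED_PREFIX):
        # Already federated; avoid producing "fed.fed.<name>".
        return event_type
    return FED_PREFIX + event_type


print(to_fed_event_type("analytix_log_stats", ["analytix_log_stats"]))  # fed.analytix_log_stats
```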

Under the components section in the server config, we have the MLflowReceiver. This component receives events from the clients and buffers them internally before writing to the MLflow tracking server. The default “buffer_flush_time” is one second, but it can be configured as an arg in the MLflowReceiver component config.

Notice how the accepted event type "fed.analytix_log_stats" matches the output of ConvertToFedEvent in the client config.
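Filtering by event type and buffering are what let the receiver ignore unrelated events and write in batches. Below is a minimal sketch, assuming a flush is triggered on the next accepted event once `buffer_flush_time` seconds have passed since the last flush. `BufferedReceiver`, its `write_batch` callback, and the injectable `clock` are hypothetical; only `buffer_flush_time` and the event type `fed.analytix_log_stats` come from this guide. The real MLflowReceiver may flush on a background timer instead.

```python
import time
from typing import Callable, Iterable, List


class BufferedReceiver:
    """Hypothetical receiver that accepts configured event types and writes them in time-based batches."""

    def __init__(
        self,
        write_batch: Callable[[List[dict]], None],
        events: Iterable[str] = ("fed.analytix_log_stats",),
        buffer_flush_time: float = 1.0,
        clock: Callable[[], float] = time.monotonic,
    ):
        self._write_batch = write_batch
        self._events = set(events)
        self._flush_time = buffer_flush_time
        self._clock = clock
        self._buffer: List[dict] = []
        self._last_flush = clock()

    def handle_event(self, event_type: str, payload: dict) -> None:
        if event_type not in self._events:
            return  # ignore event types this receiver was not configured for
        self._buffer.append(payload)
        if self._clock() - self._last_flush >= self._flush_time:
            self.flush()

    def flush(self) -> None:
        # Hand the whole buffer to the writer in one call, then start a fresh buffer.
        if self._buffer:
            self._write_batch(self._buffer)
            self._buffer = []
        self._last_flush = self._clock()


# Usage with a fake clock so the timing is deterministic.
now = [0.0]
batches: List[List[dict]] = []
rx = BufferedReceiver(batches.append, buffer_flush_time=1.0, clock=lambda: now[0])
rx.handle_event("fed.analytix_log_stats", {"train_loss": 0.9})  # t=0.0: buffered
now[0] = 0.5
rx.handle_event("fed.analytix_log_stats", {"train_loss": 0.8})  # t=0.5: buffered
rx.handle_event("analytix_log_stats", {"train_loss": 0.7})  # local (non-fed) type: ignored
now[0] = 1.2
rx.handle_event("fed.analytix_log_stats", {"train_loss": 0.6})  # t=1.2: flushes 3 records
```

A receiver built this way would also need to call `flush()` once more when the run ends, so that records still in the buffer are not lost.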

Adding MLflow Logging to Your Code

In this exercise, all of the MLflow code additions will be made in learner_with_mlflow.py.

First, we initialize the MLflow writer that we defined in the client config:

        # metrics streaming setup
        self.writer = parts.get(self.analytic_sender_id)  # user configuration from config_fed_client.json

The LearnerExecutor passes the component dictionary into the parts parameter of initialize(). We can then access the MLflowWriter component defined in config_fed_client.json by using self.analytic_sender_id as the key into the parts dictionary. Note that self.analytic_sender_id defaults to "analytic_sender", but it can also be set in the client config and passed into the constructor.
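The lookup itself is an ordinary dictionary access. The sketch below shows the pattern with a hypothetical `ToyLearner` (not the NVFlare Learner base class): the sender id defaults to `"analytic_sender"` and can be overridden through the constructor, as described above. Raising an error when the id is missing is an added assumption; `dict.get` on its own returns `None`, and the mistake would only surface later at the first `log_metrics` call.

```python
from typing import Any, Dict


class ToyLearner:
    """Hypothetical learner showing how a writer component is picked out of `parts`."""

    def __init__(self, analytic_sender_id: str = "analytic_sender"):
        self.analytic_sender_id = analytic_sender_id
        self.writer = None

    def initialize(self, parts: Dict[str, Any]) -> None:
        # `parts` maps component ids from the client config to component objects.
        self.writer = parts.get(self.analytic_sender_id)
        if self.writer is None:
            raise ValueError(
                f"no component with id '{self.analytic_sender_id}' in parts; "
                "check the components section of config_fed_client.json"
            )


learner = ToyLearner()
learner.initialize({"analytic_sender": "mlflow-writer-object", "pt_learner": object()})
print(learner.writer)  # mlflow-writer-object
```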

Now that our writer is set to MLflowWriter, we can stream training metrics to the server in local_train():

    def local_train(self, fl_ctx, abort_signal):
        # Basic training
        current_round = fl_ctx.get_prop(FLContextKey.TASK_DATA).get_header("current_round")
        for epoch in range(self.epochs):
            self.model.train()
            running_loss = 0.0
            for i, batch in enumerate(self.train_loader):
                if abort_signal.triggered:
                    return

                images, labels = batch[0].to(self.device), batch[1].to(self.device)
                self.optimizer.zero_grad()

                predictions = self.model(images)
                cost = self.loss(predictions, labels)
                cost.backward()
                self.optimizer.step()

                running_loss += cost.cpu().detach().numpy() / images.size()[0]
                if i % 3000 == 0:
                    self.log_info(
                        fl_ctx, f"Epoch: {epoch}/{self.epochs}, Iteration: {i}, Loss: {running_loss/3000}"
                    )
                    running_loss = 0.0
                    self.writer.log_text(
                        f"last running_loss reset at '{len(self.train_loader) * epoch + i}' step",
                        "running_loss_reset.txt",
                    )

                # Stream training loss at each step
                current_step = self.n_iterations * self.epochs * current_round + self.n_iterations * epoch + i
                self.writer.log_metrics({"train_loss": cost.item(), "running_loss": running_loss}, current_step)
182

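We use self.writer.log_metrics() at the end of every iteration to stream the training loss and running loss with a global step, and self.writer.log_text() to record each running_loss reset as a text artifact. The learner also sends the validation accuracy at the end of each epoch; that code is in a part of learner_with_mlflow.py not shown in this excerpt.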
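The current_step formula gives every training iteration a unique, increasing step across rounds, so MLflow plots the losses from all rounds on a single axis. Below it is pulled out into a standalone function for a worked example; `global_step` is a hypothetical name, and the sketch assumes `self.n_iterations` is the number of iterations per epoch (for example `len(self.train_loader)`), which the excerpt does not show.

```python
def global_step(current_round: int, epoch: int, i: int, n_iterations: int, epochs: int) -> int:
    """Same formula as current_step in local_train(): rounds, then epochs, then iterations."""
    return n_iterations * epochs * current_round + n_iterations * epoch + i


# With 100 iterations per epoch and 2 local epochs, each round spans 200 steps.
print(global_step(current_round=0, epoch=0, i=0, n_iterations=100, epochs=2))   # 0
print(global_step(current_round=0, epoch=1, i=99, n_iterations=100, epochs=2))  # 199
print(global_step(current_round=1, epoch=0, i=0, n_iterations=100, epochs=2))   # 200
```

Note that the log_text call in local_train() uses `len(self.train_loader) * epoch + i`, which does not include current_round, so the step numbers recorded in running_loss_reset.txt restart with every round.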

For the full list of currently supported methods, see the MLflowWriter API reference.

Train the Model, Federated!

Now you can use the admin command prompt to submit and start this example job. To do this on a proof-of-concept local FL system, follow the sections Setting Up the Application Environment in POC Mode and Starting the Application Environment in POC Mode if you have not done so already.

Running the FL System

With the admin client command prompt successfully connected and logged in, enter the command below.

> submit_job hello-pt-mlflow

Pay close attention to what happens in each of the four terminals. You can see how the admin submits the job to the server and how the JobRunner on the server automatically picks up the job to deploy and start the run.

This command uploads the job configuration from the admin client to the server. A job id will be returned, and we can use that id to access job information.

Note

If we use submit_job [app], that app will be treated as a single-app job.

From time to time, you can issue check_status server in the admin client to check the entire training progress.

You should now see the training progress in the very first terminal (the one that started the server).

Viewing the MLflow UI

By default, MLflow creates its experiment logs in a directory named “mlruns” in the workspace. For example, if your server workspace is located at “/example_workspace/workspace/example_project/prod_00/server-1”, you can launch the MLflow UI by pointing --backend-store-uri at that mlruns directory:

mlflow ui --backend-store-uri /example_workspace/workspace/example_project/prod_00/server-1/mlruns
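If you are unsure where the logs ended up, a small script can search the server workspace for mlruns directories to pass to --backend-store-uri. This is a hypothetical convenience helper that uses only the Python standard library; it makes no assumption about how deep mlruns sits below the workspace, and the script name find_mlruns.py is illustrative.

```python
import sys
from pathlib import Path
from typing import List


def find_mlruns_dirs(workspace: str) -> List[Path]:
    """Return every directory named 'mlruns' under `workspace`, sorted for stable output."""
    root = Path(workspace)
    return sorted(p for p in root.rglob("mlruns") if p.is_dir())


if __name__ == "__main__":
    if len(sys.argv) != 2:
        print("usage: python find_mlruns.py <server-workspace>")
    else:
        # Print each candidate path; any of them can be passed to `mlflow ui --backend-store-uri`.
        for path in find_mlruns_dirs(sys.argv[1]):
            print(path)
```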

Accessing the results

The results of each job are usually stored inside the server-side workspace.

Please refer to Access server-side workspace for details on how to access it.

Shutdown FL system

Once the FL run is complete, the server has aggregated the clients’ results from all rounds, and cross-site model evaluation is finished, run the following commands in the fl_admin to shut down the system (enter admin when prompted for the password):

> shutdown client
> shutdown server
> bye

Congratulations!

You can now see the live training metrics of each client in MLflow, streamed through the server.

The full source code for this exercise can be found in examples/advanced/experiment-tracking/mlflow.