Quickstart (PyTorch with TensorBoard)
Introduction
In this exercise, you will learn how to stream TensorBoard events from the clients to the server in order to visualize live training metrics from a central place on the server.
This exercise uses the hello-pt-tb application in the examples folder, which builds on Quickstart (PyTorch) by adding TensorBoard streaming.
Note
This exercise also differs from Quickstart (PyTorch) in that it uses the Learner API together with the LearnerExecutor. In short, the LearnerExecutor takes care of the execution flow, so you only need to implement the required methods in the Learner class. This is not the focus of this guide; you can learn more in the Learner and LearnerExecutor documentation.
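To make the division of labor concrete, here is a minimal, self-contained sketch of the idea behind the Learner/LearnerExecutor split. It is not NVFlare's implementation: ToyLearner, ToyLearnerExecutor, their methods, and the dictionary "weights" are hypothetical stand-ins. Only the task names (train, submit_model, validate) come from the client config used later in this guide. The executor owns the dispatch from task name to method, and the learner only implements the task-specific logic.

```python
from typing import Callable, Dict


class ToyLearner:
    """Hypothetical learner: only implements task-specific logic."""

    def __init__(self) -> None:
        self.weights = {"w": 0.0}

    def train(self, weights: Dict[str, float]) -> Dict[str, float]:
        # Pretend one round of training nudges every weight by +1.0
        self.weights = {k: v + 1.0 for k, v in weights.items()}
        return self.weights

    def validate(self, weights: Dict[str, float]) -> Dict[str, float]:
        # Pretend the validation metric is the sum of the weights
        return {"accuracy": sum(weights.values())}

    def submit_model(self, weights: Dict[str, float]) -> Dict[str, float]:
        # Return a copy of the locally held model
        return dict(self.weights)


class ToyLearnerExecutor:
    """Hypothetical executor: maps task names to learner methods."""

    def __init__(self, learner: ToyLearner) -> None:
        self.handlers: Dict[str, Callable[[Dict[str, float]], Dict[str, float]]] = {
            "train": learner.train,
            "submit_model": learner.submit_model,
            "validate": learner.validate,
        }

    def execute(self, task_name: str, data: Dict[str, float]) -> Dict[str, float]:
        if task_name not in self.handlers:
            raise ValueError(f"unknown task: {task_name}")
        return self.handlers[task_name](data)


executor = ToyLearnerExecutor(ToyLearner())
print(executor.execute("train", {"w": 0.5}))     # {'w': 1.5}
print(executor.execute("validate", {"w": 1.5}))  # {'accuracy': 1.5}
```

Adding a new task in this toy version means adding one learner method and one dictionary entry; the dispatch code itself does not change.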
Let’s get started. Make sure you have an environment with NVIDIA FLARE installed as described in the installation guide. First, clone the repo:
$ git clone https://github.com/NVIDIA/NVFlare.git
Next, activate the NVIDIA FLARE Python virtual environment you set up in the installation guide. Since this exercise uses PyTorch, torchvision, and TensorBoard, install these libraries:
(nvflare-env) $ python3 -m pip install torch torchvision tensorboard
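As an optional sanity check, a small script like the following can confirm that the three libraries are visible from the activated environment. It relies only on the standard library's importlib.util.find_spec, so it does not actually import PyTorch; the helper name check_packages is hypothetical.

```python
import importlib.util
from typing import Dict, Iterable


def check_packages(names: Iterable[str]) -> Dict[str, bool]:
    """Report whether each top-level package can be found, without importing it."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


status = check_packages(["torch", "torchvision", "tensorboard"])
for name, found in status.items():
    print(f"{name}: {'installed' if found else 'MISSING'}")
```

If any entry prints MISSING, the most likely cause is that the pip command ran outside the nvflare-env environment.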
Adding TensorBoard Streaming to Configurations
Inside the config folder there are two files, config_fed_client.json and config_fed_server.json. Here is config_fed_client.json:
 1  {
 2      "format_version": 2,
 3
 4      "executors": [
 5          {
 6              "tasks": [
 7                  "train",
 8                  "submit_model",
 9                  "validate"
10              ],
11              "executor": {
12                  "id": "Executor",
13                  "path": "nvflare.app_common.executors.learner_executor.LearnerExecutor",
14                  "args": {
15                      "learner_id": "pt_learner"
16                  }
17              }
18          }
19      ],
20      "task_result_filters": [
21      ],
22      "task_data_filters": [
23      ],
24      "components": [
25          {
26              "id": "pt_learner",
27              "path": "pt_learner.PTLearner",
28              "args": {
29                  "lr": 0.01,
30                  "epochs": 5,
31                  "analytic_sender_id": "analytic_sender"
32              }
33          },
34          {
35              "id": "analytic_sender",
36              "name": "AnalyticsSender",
37              "args": {}
38          },
39          {
40              "id": "event_to_fed",
41              "name": "ConvertToFedEvent",
42              "args": {"events_to_convert": ["analytix_log_stats"], "fed_event_prefix": "fed."}
43          }
44      ]
45  }
Take a look at the components section of the client config at line 24. The first component is pt_learner, which contains the initialization, training, and validation logic; pt_learner.py is where we will add our TensorBoard streaming changes.
Next is the AnalyticsSender, which implements common methods that follow the signatures of the PyTorch SummaryWriter. This makes it easy for pt_learner to log metrics and send events.
Finally, ConvertToFedEvent converts local events into federated events. It turns the local event analytix_log_stats into the federated event fed.analytix_log_stats, which is then streamed from the clients to the server.
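All of these components are wired together by string ids, so a mismatched id is easy to miss. The following sketch (plain Python, not an NVFlare tool) parses a trimmed copy of the client config and reports every argument ending in _id whose value does not name a defined component, which covers both the executor's learner_id and pt_learner's analytic_sender_id. Treating every "*_id" string argument as a component reference is an assumption of this sketch, and dangling_id_refs is a hypothetical helper.

```python
import json

# Trimmed copy of config_fed_client.json (only the fields this check needs)
CLIENT_CFG = json.loads("""
{
  "executors": [
    {"tasks": ["train", "submit_model", "validate"],
     "executor": {"id": "Executor",
                  "path": "nvflare.app_common.executors.learner_executor.LearnerExecutor",
                  "args": {"learner_id": "pt_learner"}}}
  ],
  "components": [
    {"id": "pt_learner", "path": "pt_learner.PTLearner",
     "args": {"lr": 0.01, "epochs": 5, "analytic_sender_id": "analytic_sender"}},
    {"id": "analytic_sender", "name": "AnalyticsSender", "args": {}},
    {"id": "event_to_fed", "name": "ConvertToFedEvent",
     "args": {"events_to_convert": ["analytix_log_stats"], "fed_event_prefix": "fed."}}
  ]
}
""")


def dangling_id_refs(cfg):
    """Return (owner, arg, value) for every *_id arg that names no defined component.

    Heuristic (an assumption of this sketch): any string arg whose key ends
    in "_id" is treated as a reference to a component id.
    """
    defined = {c["id"] for c in cfg.get("components", [])}
    owners = [(c["id"], c.get("args", {})) for c in cfg.get("components", [])]
    owners += [(e["executor"]["id"], e["executor"].get("args", {}))
               for e in cfg.get("executors", [])]
    missing = []
    for owner, args in owners:
        for key, value in args.items():
            if key.endswith("_id") and isinstance(value, str) and value not in defined:
                missing.append((owner, key, value))
    return missing


print(dangling_id_refs(CLIENT_CFG))  # []
```

Changing analytic_sender_id to a misspelled value such as "analytics_sender" makes the function report that one reference.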
The server config, config_fed_server.json, looks like this:
{
    "format_version": 2,

    "server": {
        "heart_beat_timeout": 600
    },
    "task_data_filters": [],
    "task_result_filters": [],
    "components": [
        {
            "id": "persistor",
            "name": "PTFileModelPersistor",
            "args": {
                "model": {
                    "path": "simple_network.SimpleNetwork"
                }
            }
        },
        {
            "id": "shareable_generator",
            "path": "nvflare.app_common.shareablegenerators.full_model_shareable_generator.FullModelShareableGenerator",
            "args": {}
        },
        {
            "id": "aggregator",
            "path": "nvflare.app_common.aggregators.intime_accumulate_model_aggregator.InTimeAccumulateWeightedAggregator",
            "args": {
                "expected_data_kind": "WEIGHTS"
            }
        },
        {
            "id": "model_locator",
            "path": "nvflare.app_common.pt.pt_file_model_locator.PTFileModelLocator",
            "args": {
                "pt_persistor_id": "persistor"
            }
        },
        {
            "id": "json_generator",
            "path": "nvflare.app_common.widgets.validation_json_generator.ValidationJsonGenerator",
            "args": {}
        },
        {
            "id": "tb_analytics_receiver",
            "name": "TBAnalyticsReceiver",
            "args": {"events": ["fed.analytix_log_stats"]}
        }
    ],
    "workflows": [
        {
            "id": "scatter_and_gather",
            "name": "ScatterAndGather",
            "args": {
                "min_clients": 2,
                "num_rounds": 1,
                "start_round": 0,
                "wait_time_after_min_received": 10,
                "aggregator_id": "aggregator",
                "persistor_id": "persistor",
                "shareable_generator_id": "shareable_generator",
                "train_task_name": "train",
                "train_timeout": 0
            }
        },
        {
            "id": "cross_site_validate",
            "name": "CrossSiteModelEval",
            "args": {
                "model_locator_id": "model_locator"
            }
        }
    ]
}
In the components section of the server config, we have the TBAnalyticsReceiver, which is an AnalyticsReceiver. This component receives TensorBoard events from the clients and saves them to a specified folder (tb_events by default) under the server’s run folder. Notice that the event type it accepts, "fed.analytix_log_stats", matches the event name produced by ConvertToFedEvent in the client config.
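Because this match is only a string comparison across two files, it can be checked mechanically. The sketch below uses trimmed copies of the two relevant config entries and compares the events the receiver accepts with the names ConvertToFedEvent produces, assuming (as described above) that a federated event name is fed_event_prefix followed by the local event name. The helper names produced_events and accepted_events are hypothetical; this is not an NVFlare utility.

```python
import json

# Trimmed copies of the relevant entries from both config files
client_cfg = json.loads("""
{"components": [
  {"id": "event_to_fed", "name": "ConvertToFedEvent",
   "args": {"events_to_convert": ["analytix_log_stats"], "fed_event_prefix": "fed."}}
]}
""")
server_cfg = json.loads("""
{"components": [
  {"id": "tb_analytics_receiver", "name": "TBAnalyticsReceiver",
   "args": {"events": ["fed.analytix_log_stats"]}}
]}
""")


def produced_events(cfg):
    """Federated event names the client emits (assumed rule: prefix + local name)."""
    out = set()
    for comp in cfg["components"]:
        if comp.get("name") == "ConvertToFedEvent":
            prefix = comp["args"]["fed_event_prefix"]
            out.update(prefix + e for e in comp["args"]["events_to_convert"])
    return out


def accepted_events(cfg):
    """Event names the server-side TBAnalyticsReceiver listens for."""
    out = set()
    for comp in cfg["components"]:
        if comp.get("name") == "TBAnalyticsReceiver":
            out.update(comp["args"]["events"])
    return out


unmatched = accepted_events(server_cfg) - produced_events(client_cfg)
print("unmatched receiver events:", sorted(unmatched))  # unmatched receiver events: []
```

A non-empty result would point to a receiver waiting for an event name that no client produces.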
Adding TensorBoard Streaming to your Code
In this exercise, all of the TensorBoard code changes are made in pt_learner.py.
First, we initialize our TensorBoard writer with the AnalyticsSender defined in the client config:
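from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_context import FLContext
from torch.utils.tensorboard import SummaryWriter

def initialize(self, parts: dict, fl_ctx: FLContext):
    # TensorBoard streaming setup: use the AnalyticsSender configured in config_fed_client.json
    self.writer = parts.get(self.analytic_sender_id)
    if not self.writer:
        # No sender configured: fall back to a local TensorBoard writer only
        self.writer = SummaryWriter(fl_ctx.get_prop(FLContextKey.APP_ROOT))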
The LearnerExecutor passes the component dictionary into the parts parameter of initialize(). We can then access the AnalyticsSender component defined in config_fed_client.json by using self.analytic_sender_id as the key into the parts dictionary. If no component with that id exists, the learner falls back to a local SummaryWriter, so metrics are written only on the client and are not streamed.
Note that self.analytic_sender_id defaults to "analytic_sender", but it can also be set in the client config and passed to the constructor.
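The lookup-with-fallback pattern can be exercised without NVFlare or PyTorch by using stand-in classes. In this sketch, RecordingSender and LocalWriter are hypothetical substitutes for AnalyticsSender and SummaryWriter, and pick_writer is a hypothetical helper that mirrors the body of initialize(); only the add_scalar(tag, scalar, global_step) call shape is taken from this guide.

```python
from typing import Any, Dict, List, Tuple


class RecordingSender:
    """Hypothetical stand-in for AnalyticsSender: records add_scalar calls."""

    def __init__(self) -> None:
        self.calls: List[Tuple[str, float, int]] = []

    def add_scalar(self, tag: str, scalar: float, global_step: int) -> None:
        self.calls.append((tag, scalar, global_step))


class LocalWriter(RecordingSender):
    """Hypothetical stand-in for a local-only SummaryWriter."""


def pick_writer(parts: Dict[str, Any], analytic_sender_id: str = "analytic_sender") -> Any:
    # Mirror initialize(): prefer the configured component, else fall back locally
    writer = parts.get(analytic_sender_id)
    if not writer:
        writer = LocalWriter()
    return writer


sender = RecordingSender()
writer = pick_writer({"analytic_sender": sender, "pt_learner": object()})
writer.add_scalar("train_loss", 0.5, 0)
print(writer is sender, sender.calls)  # True [('train_loss', 0.5, 0)]
print(type(pick_writer({})).__name__)  # LocalWriter
```

Calling pick_writer({"my_sender": sender}) without also passing "my_sender" as the id returns a LocalWriter, which is exactly how a misconfigured id quietly disables streaming.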
Now that our TensorBoard writer is the AnalyticsSender, we can write training metrics in local_train() and have them streamed to the server:
127     def local_train(self, fl_ctx, weights, abort_signal):
128         # Set the model weights
129         self.model.load_state_dict(state_dict=weights)
130
131         # Basic training
132         for epoch in range(self.epochs):
133             self.model.train()
134             running_loss = 0.0
135             for i, batch in enumerate(self.train_loader):
136                 if abort_signal.triggered:
137                     return
138
139                 images, labels = batch[0].to(self.device), batch[1].to(self.device)
140                 self.optimizer.zero_grad()
141
142                 predictions = self.model(images)
143                 cost = self.loss(predictions, labels)
144                 cost.backward()
145                 self.optimizer.step()
146
147                 running_loss += (cost.cpu().detach().numpy()/images.size()[0])
148                 if i % 3000 == 0:
149                     self.log_info(fl_ctx, f"Epoch: {epoch}/{self.epochs}, Iteration: {i}, "
150                                           f"Loss: {running_loss/3000}")
151                     running_loss = 0.0
152
153                 # Stream training loss at each step
154                 current_step = len(self.train_loader) * epoch + i
155                 self.writer.add_scalar("train_loss", cost.item(), current_step)
156
157             # Stream validation accuracy at the end of each epoch
158             metric = self.local_validate(self.test_loader, abort_signal)
159             self.writer.add_scalar("validation_accuracy", metric, epoch)
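We use add_scalar(tag, scalar, global_step) on line 155 to stream the training loss at every step, and on line 159 to stream the validation accuracy at the end of each epoch. The training step on line 154 is computed as len(self.train_loader) * epoch + i, so it keeps increasing across epochs instead of restarting at zero. You can learn more about the other supported writer methods in AnalyticsSender.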
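A quick way to check the step arithmetic is to evaluate it for a few batches. The sketch below restates the current_step expression from line 154 as a standalone function; global_step is a hypothetical helper name, and the 100 batches per epoch are an arbitrary example rather than the real size of train_loader.

```python
def global_step(epoch: int, batch_index: int, batches_per_epoch: int) -> int:
    """Step index that keeps increasing across epochs (same formula as line 154)."""
    return batches_per_epoch * epoch + batch_index


BATCHES = 100  # hypothetical len(train_loader)
steps = [global_step(e, i, BATCHES) for e in range(2) for i in (0, 99)]
print(steps)  # [0, 99, 100, 199]
```

Because TensorBoard plots each tag against its step, using i alone would place every epoch's points on the same x-range, 0 to len(train_loader) - 1. The validation accuracy needs no such offset because it is logged once per epoch with epoch as its step.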
Viewing the TensorBoard Dashboard during Training
You can now use admin commands to upload, deploy, and start this example app. To do this on a local proof-of-concept (POC) FL system, first follow the sections Setting Up the Application Environment in POC Mode and Starting the Application Environment in POC Mode if you have not already done so.
Log in to the Admin client by entering admin as both the username and the password.
Then, use these Admin commands to run the experiment:
> set_run_number 1
> upload_app hello-pt-tb
> deploy_app hello-pt-tb
> start_app all
On the client side, the AnalyticsSender acts as a TensorBoard SummaryWriter. However, instead of writing TensorBoard (TB) files, it generates NVFLARE events of type analytix_log_stats.
The ConvertToFedEvent widget turns the event analytix_log_stats into the federated event fed.analytix_log_stats, which is delivered to the server side.
On the server side, the TBAnalyticsReceiver is configured to process fed.analytix_log_stats events and writes the received TB data into TB files on the server (by default under server/run_1/tb_events).
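The three steps above can be simulated in memory to see how data flows from a client's add_scalar call to per-client series on the server. Everything here is hypothetical: the event dictionaries, the "origin" field, the client names site-1 and site-2, and the functions and class are illustrative stand-ins, not NVFlare's event objects or APIs. Only the event names and the "fed." prefix come from the configs in this guide.

```python
from collections import defaultdict
from typing import Any, Dict, List, Tuple

Event = Dict[str, Any]


def sender_add_scalar(tag: str, value: float, step: int) -> Event:
    """Client side: instead of writing a TB file, produce a local event (hypothetical shape)."""
    return {"type": "analytix_log_stats", "tag": tag, "value": value, "step": step}


def convert_to_fed(event: Event, client: str, prefix: str = "fed.",
                   to_convert: Tuple[str, ...] = ("analytix_log_stats",)) -> Event:
    """Client side: rename selected local events to federated events and record the origin."""
    if event["type"] not in to_convert:
        return event
    return {**event, "type": prefix + event["type"], "origin": client}


class ReceiverSketch:
    """Server side: keep only accepted event types, grouped by client and tag."""

    def __init__(self, events: List[str]) -> None:
        self.events = set(events)
        self.data: Dict[str, Dict[str, List[Tuple[int, float]]]] = defaultdict(
            lambda: defaultdict(list))

    def handle(self, event: Event) -> None:
        if event["type"] in self.events:
            self.data[event["origin"]][event["tag"]].append(
                (int(event["step"]), float(event["value"])))


receiver = ReceiverSketch(["fed.analytix_log_stats"])
for client in ("site-1", "site-2"):
    for step, loss in enumerate([0.9, 0.7]):
        local = sender_add_scalar("train_loss", loss, step)
        receiver.handle(convert_to_fed(local, client))

print(dict(receiver.data["site-1"]))  # {'train_loss': [(0, 0.9), (1, 0.7)]}
```

Events that are never converted keep the local type analytix_log_stats, so this receiver ignores them; that mirrors why the receiver's events list must use the prefixed name.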
To view training metrics that are being streamed to the server, run:
tensorboard --logdir=poc/server/run_1/tb_events
Note: if the server is running on a remote machine, use SSH port forwarding so that the TensorBoard dashboard (port 6006 by default) is reachable at http://localhost:{local_machine_port} in your local browser. For example:
ssh -L {local_machine_port}:127.0.0.1:6006 user@server_ip
Congratulations! You can now see the live training metrics of each client from a central place on the server.