Client API

The FLARE Client API provides an easy way for users to convert their centralized, local training code into federated learning code with the following benefits:

  • Only requires a few lines of code changes, without the need to restructure the code or implement a new class

  • Reduces the number of new FLARE-specific concepts exposed to users

  • Easy adaptation from existing local training code using different frameworks (PyTorch, PyTorch Lightning, HuggingFace)

Core concept

The general structure of the popular federated learning (FL) workflow “FedAvg” is as follows:

  1. FL server initializes an initial model

  2. For each round (global iteration):

    1. FL server sends the global model to clients

    2. Each FL client starts with this global model and trains on their own data

    3. Each FL client sends back their trained model

    4. FL server aggregates all the models and produces a new global model
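
The round structure above can be sketched in plain Python. This is a toy model with scalar weights; `local_train` and `fed_avg` are illustrative stand-ins, not FLARE APIs:

```python
# Toy FedAvg sketch (hypothetical helpers, not FLARE APIs).
# Each client trains locally on its own data; the server averages the results.

def local_train(params, client_data):
    # stand-in for real training: nudge each weight toward the local data mean
    mean = sum(client_data) / len(client_data)
    return {k: 0.9 * v + 0.1 * mean for k, v in params.items()}

def fed_avg(client_params):
    # server-side aggregation: element-wise average over all client updates
    n = len(client_params)
    return {k: sum(p[k] for p in client_params) / n for k in client_params[0]}

global_model = {"w": 0.0}                       # step 1: server initializes a model
client_datasets = [[1.0, 2.0], [3.0, 5.0]]      # one dataset per client

for round_num in range(3):                      # step 2: global rounds
    # 2a-2c: each client receives the global model, trains, and sends back
    updates = [local_train(dict(global_model), data) for data in client_datasets]
    # 2d: server aggregates into a new global model
    global_model = fed_avg(updates)
```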

On the client side, the training workflow is as follows:

  1. Receive the model from the FL server

  2. Perform local training on the received global model and/or evaluate the received global model for model selection

  3. Send the new model back to the FL server

To convert centralized training code to federated learning, we need to adapt the code to perform the following steps:

  1. Obtain the required information from received FLModel

  2. Run local training

  3. Put the results in a new FLModel to be sent back

For a general use case, there are three essential methods for the Client API:

  • init(): Initializes the NVFlare Client API environment.

  • receive(): Receives a model from the NVFlare side.

  • send(): Sends a model to the NVFlare side.

Users can use the Client API to change their centralized training code to federated learning, for example:

import nvflare.client as flare

flare.init() # 1. Initializes NVFlare Client API environment.
input_model = flare.receive() # 2. Receives model from NVFlare side.
params = input_model.params # 3. Obtain the required information from received FLModel

# original local training code begins
new_params = local_train(params)
# original local training code ends

output_model = flare.FLModel(params=new_params) # 4. Put the results in a new FLModel
flare.send(output_model) # 5. Sends the model to NVFlare side.

With 5 lines of code changes, we convert the centralized training code to a federated learning setting.

After this, we can utilize the job templates and the NVIDIA FLARE Job CLI to generate a job that can be run using the NVIDIA FLARE FL Simulator or submitted to a deployed NVFlare system.

Below is a table overview of key Client APIs.

Client API

  API              Description
  ---------------  --------------------------------------------------------
  init             Initializes the NVFlare Client API environment.
  receive          Receives a model from the NVFlare side.
  send             Sends a model to the NVFlare side.
  system_info      Gets NVFlare system information.
  get_job_id       Gets the job ID.
  get_site_name    Gets the site name.
  is_running       Returns whether the NVFlare system is up and running.
  is_train         Returns whether the current task is a training task.
  is_evaluate      Returns whether the current task is an evaluate task.
  is_submit_model  Returns whether the current task is a submit_model task.
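
Putting the informational APIs together, a client script can branch on the task type it receives. A hedged sketch, assuming the API names above; `local_train`, `evaluate`, and `best_params` are hypothetical user-side pieces:

```python
import nvflare.client as flare

flare.init()
print(flare.get_site_name(), flare.get_job_id())  # informational APIs

while flare.is_running():
    input_model = flare.receive()
    if flare.is_train():
        new_params = local_train(input_model.params)          # hypothetical training step
        output_model = flare.FLModel(params=new_params)
    elif flare.is_evaluate():
        accuracy = evaluate(input_model.params)               # hypothetical evaluation step
        output_model = flare.FLModel(metrics={"accuracy": accuracy})
    else:  # e.g. a submit_model task
        output_model = flare.FLModel(params=best_params)      # hypothetical best local model
    flare.send(output_model)
```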

Decorator APIs

  API       Description
  --------  --------------------------------------------
  train     A decorator that wraps the training logic.
  evaluate  A decorator that wraps the evaluate logic.
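
With the decorators, the receive/send calls are handled implicitly around the wrapped function. A minimal sketch; `local_train` and `evaluate` are hypothetical user functions:

```python
import nvflare.client as flare

flare.init()

@flare.train
def train(input_model=None):
    # the decorator receives the FLModel before this call and
    # sends the returned FLModel back afterwards
    new_params = local_train(input_model.params)   # hypothetical training step
    return flare.FLModel(params=new_params)

@flare.evaluate
def fl_evaluate(input_model=None):
    # the returned metric is sent back for server-side model selection
    return evaluate(input_model.params)            # hypothetical evaluation step

fl_evaluate()
train()
```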

Lightning APIs

  API    Description
  -----  ------------------------------------------------------------
  patch  Patches the PyTorch Lightning Trainer for usage with FLARE.
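
A minimal sketch of the Lightning integration, assuming a LightningModule `model`, a DataModule `dm`, and a `trainer` are already defined as in ordinary Lightning code:

```python
import nvflare.client.lightning as flare
# model (LightningModule), dm (LightningDataModule), and trainer are
# assumed to be defined exactly as in the original centralized code

flare.patch(trainer)  # patches the Trainer so fit/validate exchange models with FLARE

while flare.is_running():
    input_model = flare.receive()            # global FLModel for the current round
    trainer.validate(model, datamodule=dm)   # evaluate the received global model
    trainer.fit(model, datamodule=dm)        # local training; results are sent back
```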

Metrics Logger

  API            Description
  -------------  ---------------------------------------------------
  SummaryWriter  Mimics the usage of TensorBoard's SummaryWriter.
  WandBWriter    Mimics the usage of Weights & Biases.
  MLflowWriter   Mimics the usage of MLflow.
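
The writers are intended as drop-in replacements for their framework counterparts. A sketch using SummaryWriter; `total_steps` and `train_step` are hypothetical placeholders:

```python
from nvflare.client.tracking import SummaryWriter

writer = SummaryWriter()                 # drop-in for TensorBoard's SummaryWriter
for step in range(total_steps):          # total_steps / train_step: hypothetical
    loss = train_step()
    writer.add_scalar("train_loss", loss, step)  # relayed to the server via the metrics pipe
```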

Please check Client API Module nvflare.client.api for more in-depth information about all of the Client API functionalities.

If you are using PyTorch Lightning in your training code, you can check the Lightning API Module nvflare.app_opt.lightning.api.

Configuration

In config_fed_client in the FLARE app, we use the SubprocessLauncher component to launch the training script. The defined script is invoked, and launch_once can be set either to launch once for the whole job or to launch a process for each task received from the server.

A corresponding LauncherExecutor is used as the executor to handle the tasks and perform the data exchange using a pipe. For the Pipe component, we provide FilePipe and CellPipe implementations.

{
  # version of the configuration
  format_version = 2

  # This is the application script that will be invoked. Clients can replace this with their own training script.
  app_script = "cifar10.py"

  # Additional arguments needed by the training code. For example, in lightning, these can be --trainer.batch_size=xxx.
  app_config = ""

  # Client Computing Executors.
  executors = [
    {
      # tasks the executors are defined to handle
      tasks = ["train"]

      # This particular executor
      executor {

        # This is an executor for the Client API. The underlying data exchange uses a Pipe.
        path = "nvflare.app_opt.pt.client_api_launcher_executor.PTClientAPILauncherExecutor"

        args {
          # launcher_id is used to locate the Launcher object in "components"
          launcher_id = "launcher"

          # pipe_id is used to locate the Pipe object in "components"
          pipe_id = "pipe"

          # Timeout in seconds for waiting for a heartbeat from the training script. Defaults to 30 seconds.
          # Please refer to the class docstring for all available arguments
          heartbeat_timeout = 60

          # format of the exchange parameters
          params_exchange_format =  "pytorch"

          # if the transfer_type is FULL, the parameters are sent directly;
          # if the transfer_type is DIFF, the difference vs. the received
          # parameters is calculated and sent
          params_transfer_type = "DIFF"

          # if train_with_evaluation is true, the executor expects
          # the custom code to send back both the trained parameters and the evaluation metric;
          # otherwise only trained parameters are expected
          train_with_evaluation = true
        }
      }
    }
  ],

  # this defines an array of task data filters. If provided, it controls the data sent from the server controller to the client executor
  task_data_filters =  []

  # this defines an array of task result filters. If provided, it controls the results sent from the client executor to the server controller
  task_result_filters = []

  components =  [
    {
      # component id is "launcher"
      id = "launcher"

      # the class path of this component
      path = "nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher"

      args {
        # the launcher will invoke the script
        script = "python3 custom/{app_script}  {app_config} "
        # if launch_once is true, the SubprocessLauncher will launch once for the whole job
        # if launch_once is false, the SubprocessLauncher will launch a process for each task it receives from server
        launch_once = true
      }
    }
    {
      id = "pipe"
      path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe"
      args {
        mode = "PASSIVE"
        site_name = "{SITE_NAME}"
        token = "{JOB_ID}"
        root_url = "{ROOT_URL}"
        secure_mode = "{SECURE_MODE}"
        workspace_dir = "{WORKSPACE}"
      }
    }
    {
      id = "metrics_pipe"
      path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe"
      args {
        mode = "PASSIVE"
        site_name = "{SITE_NAME}"
        token = "{JOB_ID}"
        root_url = "{ROOT_URL}"
        secure_mode = "{SECURE_MODE}"
        workspace_dir = "{WORKSPACE}"
      }
    },
    {
      id = "metric_relay"
      path = "nvflare.app_common.widgets.metric_relay.MetricRelay"
      args {
        pipe_id = "metrics_pipe"
        event_type = "fed.analytix_log_stats"
        # how fast should it read from the peer
        read_interval = 0.1
      }
    },
    {
      # we use this component so the client api `flare.init()` can get required information
      id = "config_preparer"
      path = "nvflare.app_common.widgets.external_configurator.ExternalConfigurator"
      args {
        component_ids = ["metric_relay"]
      }
    }
  ]
}

For example configurations, take a look at the job_templates directory for templates using the launcher and Client API.

Note

In the case that the user does not need to launch the process and instead has their own existing external training system, the 3rd-Party System Integration can be used; it is based on the same underlying mechanisms.

Examples

For examples of using Client API with different frameworks, please refer to examples/hello-world/ml-to-fl.

For additional examples, also take a look at the step-by-step series that use Client API to write the train script.