nvflare.app_opt.p2p.executors.distributed_gradient_descent module

class DGDExecutor(model: Module | None = None, loss: _Loss | None = None, train_dataloader: DataLoader | None = None, test_dataloader: DataLoader | None = None, val_dataloader: DataLoader | None = None)[source]

Bases: SyncAlgorithmExecutor

An executor that implements Stochastic Distributed Gradient Descent (DGD) in a peer-to-peer (P2P) learning setup.

Each client maintains its own local model and synchronously exchanges model parameters with its neighbors at each iteration. The model parameters are updated based on the neighbors’ parameters and local gradient descent steps. The executor also tracks and records training, validation and test losses over time.
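The update described above can be sketched in plain Python. This is a minimal illustration of the textbook DGD step (a weighted average of local and neighbor parameters, followed by a local gradient step). It is not the executor's actual code: the function name `dgd_step`, its arguments, and the mixing weights are all assumptions made for this example.

```python
def dgd_step(local, neighbors, self_weight, neighbor_weights, grad, stepsize):
    """One textbook DGD update on a flat parameter vector (list of floats).

    All names here are hypothetical; they are not part of the NVFlare API.
    """
    # Consensus step: weighted average of own and neighbors' parameters.
    mixed = [self_weight * x for x in local]
    for name, params in neighbors.items():
        w = neighbor_weights[name]
        mixed = [m + w * p for m, p in zip(mixed, params)]
    # Local step: move against the local gradient, scaled by the stepsize.
    return [m - stepsize * g for m, g in zip(mixed, grad)]


# One client with a single neighbor "b" and equal mixing weights.
updated = dgd_step(
    local=[1.0, 2.0],
    neighbors={"b": [3.0, 4.0]},
    self_weight=0.5,
    neighbor_weights={"b": 0.5},
    grad=[1.0, -1.0],
    stepsize=0.1,
)
```

In a real run the parameters would be the tensors of a torch.nn.Module and the gradient would come from backpropagating the loss on a training batch.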

The controller must provide the number of iterations and the learning rate when it assigns the algorithm run. Both can be set in the extra parameters of the controller’s config, under the “iterations” and “stepsize” keys.
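As a rough illustration of the two required keys, the sketch below validates a dictionary of extra parameters. Only the key names “iterations” and “stepsize” come from the description above. The helper `read_dgd_params`, the sample values, and the way the dictionary would reach the executor are assumptions for this example.

```python
def read_dgd_params(extra):
    """Return (iterations, stepsize) from an extra-parameters dict.

    Hypothetical helper, not part of the NVFlare API.
    """
    missing = [key for key in ("iterations", "stepsize") if key not in extra]
    if missing:
        raise ValueError(f"missing required extra parameters: {missing}")
    return int(extra["iterations"]), float(extra["stepsize"])


# Sample values chosen for illustration only.
extra = {"iterations": 100, "stepsize": 0.01}
iterations, stepsize = read_dgd_params(extra)
```

Failing early on a missing key keeps a misconfigured run from starting with an undefined learning rate.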

Note

Subclasses must implement the __init__ method to initialize the model, loss function, and data loaders.

Parameters:
  • model (torch.nn.Module, optional) – The neural network model used for training.

  • loss (torch.nn.modules.loss._Loss, optional) – The loss function used for training.

  • train_dataloader (torch.utils.data.DataLoader, optional) – DataLoader for the training dataset.

  • test_dataloader (torch.utils.data.DataLoader, optional) – DataLoader for the testing dataset.

  • val_dataloader (torch.utils.data.DataLoader, optional) – DataLoader for the validation dataset.

model

The neural network model.

Type:

torch.nn.Module

loss

The loss function.

Type:

torch.nn.modules.loss._Loss

train_dataloader

DataLoader for training data.

Type:

torch.utils.data.DataLoader

test_dataloader

DataLoader for testing data.

Type:

torch.utils.data.DataLoader

val_dataloader

DataLoader for validation data.

Type:

torch.utils.data.DataLoader

train_loss_sequence

Records of training loss over time.

Type:

list[tuple]

test_loss_sequence

Records of testing loss over time.

Type:

list[tuple]

Init FLComponent.

FLComponent is the base class of all FL components: executors, controllers, responders, filters, aggregators, and widgets are all FLComponents.

FLComponents have the capability to handle and fire events and contain various methods for logging.

run_algorithm(fl_ctx, shareable, abort_signal)[source]

Abstract method to execute the main P2P algorithm.

Subclasses must implement this method to define the algorithm logic.