nvflare.app_opt.pt.fedopt_ctl module

class FedOpt(*args, source_model: str | Module, optimizer_args: dict = {'args': {'lr': 1.0, 'momentum': 0.6}, 'path': 'torch.optim.SGD'}, lr_scheduler_args: dict = {'args': {'T_max': 3, 'eta_min': 0.9}, 'path': 'torch.optim.lr_scheduler.CosineAnnealingLR'}, device=None, **kwargs)

Bases: FedAvg

Implements the FedOpt algorithm on top of the FedAvg ModelController.

The algorithm was proposed in Reddi, Sashank, et al. “Adaptive federated optimization.” arXiv preprint arXiv:2003.00295 (2020). After each round, the global model is updated using the specified PyTorch optimizer and learning rate scheduler. Note: this class uses FedOpt to optimize the global trainable parameters (i.e., self.torch_model.named_parameters()) but uses FedAvg to update any other layers, such as batch norm statistics.
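A worked form of the server-side update may help here. Assume the server treats the negated aggregated model difference as a pseudo-gradient, as in Reddi et al., and uses the default torch.optim.SGD settings from the signature (momentum μ = 0.6, no dampening or Nesterov) with a learning rate η_t supplied by the scheduler. Let Δ_t be the aggregated client model difference in round t and m_0 = 0. One round of the update on the trainable parameters x then reads:

```latex
\begin{aligned}
g_t     &= -\Delta_t \\
m_t     &= \mu \, m_{t-1} + g_t \\
x_{t+1} &= x_t - \eta_t \, m_t
\end{aligned}
```

With η_t = 1 and μ = 0 this reduces to x_{t+1} = x_t + Δ_t, which is the plain FedAvg update. The server momentum and the learning rate schedule are what FedOpt adds on top of FedAvg.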

Parameters:
  • source_model – component id of torch model object or a valid torch model object

  • optimizer_args – dictionary of optimizer arguments, with keys ‘path’ (import path of the optimizer class, e.g. ‘torch.optim.SGD’) and ‘args’ (keyword arguments passed to it).

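  • lr_scheduler_args – dictionary of server-side learning rate scheduler arguments, with keys ‘path’ (import path of the scheduler class, e.g. ‘torch.optim.lr_scheduler.CosineAnnealingLR’) and ‘args’ (keyword arguments passed to it).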

  • device – the device on which to run server-side optimization, e.g. “cpu” or “cuda:0”. If no device is specified, CUDA is used when available.
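To make the default schedule concrete, the sketch below evaluates the closed-form cosine annealing expression from the PyTorch CosineAnnealingLR documentation using the defaults in the signature: base learning rate 1.0 from optimizer_args, T_max=3, and eta_min=0.9. It assumes the scheduler is stepped once per round. cosine_annealing_lr is a hypothetical helper written for illustration; it is not part of NVFlare or PyTorch.

```python
import math

# Defaults copied from the FedOpt signature above
optimizer_args = {"path": "torch.optim.SGD", "args": {"lr": 1.0, "momentum": 0.6}}
lr_scheduler_args = {
    "path": "torch.optim.lr_scheduler.CosineAnnealingLR",
    "args": {"T_max": 3, "eta_min": 0.9},
}


def cosine_annealing_lr(base_lr: float, eta_min: float, t_max: int, step: int) -> float:
    """Hypothetical helper: closed-form cosine annealing learning rate at a given step."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2


base_lr = optimizer_args["args"]["lr"]
sched = lr_scheduler_args["args"]
# Learning rate used in rounds 1..7, assuming one scheduler step per round
schedule = [round(cosine_annealing_lr(base_lr, sched["eta_min"], sched["T_max"], t), 4) for t in range(7)]
print(schedule)  # [1.0, 0.975, 0.925, 0.9, 0.925, 0.975, 1.0]
```

The closed form keeps following the cosine past T_max instead of holding at eta_min. Under the one-step-per-round assumption, these defaults therefore make the server learning rate oscillate between 1.0 and 0.9 with a period of six rounds, rather than decaying once and stopping.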

Raises:

TypeError – when any input argument does not have the correct type.

optimizer_update(model_diff)

Updates the global model using the specified optimizer.

Parameters:

model_diff – the aggregated model differences from clients.

Returns:

The updated PyTorch model state dictionary.
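Here is a minimal sketch of such an optimizer update in plain PyTorch. It assumes the FedOpt convention of using the negated aggregated difference as each trainable parameter's gradient, and it steps the scheduler once per call. optimizer_update_sketch is a hypothetical stand-in written for illustration, not NVFlare's implementation. The model, the diff values, and the two-round usage are also made up.

```python
import torch
from torch import nn


def optimizer_update_sketch(model, optimizer, lr_scheduler, model_diff, device="cpu"):
    """Hypothetical stand-in for FedOpt.optimizer_update (not NVFlare's code).

    Sets each trainable parameter's gradient to minus its aggregated diff,
    takes one optimizer step, then advances the LR scheduler.
    Returns the model's state dict and the parameter names the optimizer updated.
    """
    optimizer.zero_grad()
    updated = []
    for name, param in model.named_parameters():
        if name in model_diff:
            # FedOpt pseudo-gradient: minus the averaged client update
            param.grad = -torch.as_tensor(model_diff[name], dtype=param.dtype, device=device)
            updated.append(name)
    optimizer.step()  # parameters without a gradient (grad is None) are skipped
    lr_scheduler.step()
    return model.state_dict(), updated


# Usage with the default optimizer/scheduler settings from the signature
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(2, 2), nn.BatchNorm1d(2))
before = {k: v.clone() for k, v in model.state_dict().items()}
# Pretend every floating-point entry, including BN running stats, moved by +0.5
diff = {k: torch.full_like(v, 0.5) for k, v in before.items() if v.is_floating_point()}
optimizer = torch.optim.SGD(model.parameters(), lr=1.0, momentum=0.6)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=3, eta_min=0.9)

state, updated = optimizer_update_sketch(model, optimizer, scheduler, diff)  # round 1
state, updated = optimizer_update_sketch(model, optimizer, scheduler, diff)  # round 2
```

Only entries returned by named_parameters() receive a gradient, so the batch-norm running statistics in diff are untouched here and are left for the FedAvg-style path described in the class note. In round 1 each trainable entry moves by +0.5. In round 2 the momentum buffer is 0.6 × (−0.5) + (−0.5) = −0.8 and the learning rate has annealed to 0.975, so each entry moves by a further 0.975 × 0.8 = 0.78.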

run()

Main run routine for the controller workflow.

update_model(global_model: FLModel, aggr_result: FLModel)

Called by the run routine to update the current global model (self.model) given the aggregated result.

Parameters:
  • global_model – FLModel to be updated.

  • aggr_result – aggregated FLModel.

Returns: None.
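The class note says that entries outside the trainable parameters are updated FedAvg-style. Below is a minimal sketch of that merge. It assumes the optimizer result and the aggregated difference are keyed by the same parameter names. merge_update and the example names are hypothetical. Plain floats stand in for the tensors or arrays a real state dictionary would hold, and the same `+` applies elementwise to those.

```python
def merge_update(global_params, model_diff, optimizer_params, optimizer_updated):
    """Hypothetical sketch of combining the FedOpt and FedAvg update paths.

    Keys the server optimizer updated take the optimizer's new value;
    any other key in model_diff (e.g. batch-norm running statistics) gets
    the FedAvg-style update: old value + aggregated difference.
    """
    new_params = dict(global_params)  # copy so the input is not mutated
    for name, diff in model_diff.items():
        if name in optimizer_updated:
            new_params[name] = optimizer_params[name]
        else:
            new_params[name] = global_params[name] + diff
    return new_params


global_params = {"fc.weight": 1.0, "bn.running_mean": 0.25}
model_diff = {"fc.weight": 0.5, "bn.running_mean": 0.5}
optimizer_params = {"fc.weight": 1.5}  # value after the server optimizer step
new_params = merge_update(global_params, model_diff, optimizer_params, {"fc.weight"})
print(new_params)  # {'fc.weight': 1.5, 'bn.running_mean': 0.75}
```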