nvflare.app_opt.pt.fedopt_ctl module
- class FedOpt(*args, source_model: str | Module, optimizer_args: dict = {'args': {'lr': 1.0, 'momentum': 0.6}, 'path': 'torch.optim.SGD'}, lr_scheduler_args: dict = {'args': {'T_max': 3, 'eta_min': 0.9}, 'path': 'torch.optim.lr_scheduler.CosineAnnealingLR'}, device=None, **kwargs)
Bases: FedAvg
Implements the FedOpt algorithm, based on the FedAvg ModelController.
The algorithm was proposed in Reddi, Sashank, et al. “Adaptive federated optimization.” arXiv preprint arXiv:2003.00295 (2020). After each round, the global model is updated using the specified PyTorch optimizer and learning rate scheduler. Note: this class uses FedOpt to optimize the global trainable parameters (i.e. self.torch_model.named_parameters()) but uses FedAvg to update any other layers, such as batch norm statistics.
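To make the per-round server update concrete, here is a minimal sketch (an illustration only, not the actual NVFlare implementation; the helper name server_update is hypothetical): the difference between the current global weights and the aggregated client weights is treated as a pseudo-gradient, which the configured torch optimizer then applies.

```python
# Illustrative sketch only, not the NVFlare implementation.
# `server_update` and its signature are hypothetical.
import torch

def server_update(torch_model, aggregated_weights, optimizer, lr_scheduler=None):
    optimizer.zero_grad()
    for name, param in torch_model.named_parameters():
        # Pseudo-gradient: global weight minus aggregated client weight,
        # so optimizer.step() moves the global model toward the aggregate
        # (with the default SGD at lr=1.0 and momentum=0.6, this is a
        # momentum-accelerated step toward the FedAvg result).
        param.grad = param.data - aggregated_weights[name]
    optimizer.step()
    if lr_scheduler is not None:
        lr_scheduler.step()  # advance the server-side learning rate schedule
    # Non-trainable state such as batch norm statistics is not touched here;
    # per the note above, it is updated via plain FedAvg.
```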
- Parameters:
source_model – component id of a torch model object, or a valid torch model object
optimizer_args – dictionary of optimizer arguments, with keys 'path' (the optimizer class) and 'args' (its constructor arguments), as in the default shown above.
lr_scheduler_args – dictionary of server-side learning rate scheduler arguments, with keys 'path' and 'args', as in the default shown above.
device – the device on which to run server-side optimization, e.g. “cpu” or “cuda:0” (defaults to cuda if available and no device is specified).
- Raises:
TypeError – when any of the input arguments does not have the correct type
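A minimal instantiation sketch follows, mirroring the defaults from the signature above. The component id “model” and the inherited FedAvg arguments num_clients and num_rounds are assumptions for illustration, not part of this class’s own signature.

```python
# Minimal usage sketch, assuming a torch model registered elsewhere under
# the component id "model"; num_clients/num_rounds are assumed inherited
# FedAvg arguments passed through **kwargs.
from nvflare.app_opt.pt.fedopt_ctl import FedOpt

controller = FedOpt(
    source_model="model",  # component id of the torch model (assumption)
    optimizer_args={
        "path": "torch.optim.SGD",             # server-side optimizer class
        "args": {"lr": 1.0, "momentum": 0.6},  # its constructor arguments
    },
    lr_scheduler_args={
        "path": "torch.optim.lr_scheduler.CosineAnnealingLR",
        "args": {"T_max": 3, "eta_min": 0.9},
    },
    device="cuda:0",  # device for server-side optimization
    num_clients=2,    # inherited FedAvg argument (assumption)
    num_rounds=3,     # inherited FedAvg argument (assumption)
)
```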