nvflare.app_opt.pt.fedopt module
class nvflare.app_opt.pt.fedopt.PTFedOptModelShareableGenerator
Bases: FullModelShareableGenerator
Implement the FedOpt algorithm.
The algorithm is proposed in Reddi, Sashank, et al. "Adaptive federated optimization." arXiv preprint arXiv:2003.00295 (2020). This ShareableGenerator updates the global model using the specified PyTorch optimizer and learning rate scheduler.
Note: This class uses FedOpt to optimize the global trainable parameters (i.e. self.model.named_parameters()) but uses FedAvg to update any other layers, such as batch norm statistics.
- Parameters:
optimizer_args – dictionary of optimizer arguments, e.g. {'path': 'torch.optim.SGD', 'args': {'lr': 1.0}} (default).
lr_scheduler_args – dictionary of server-side learning rate scheduler arguments, e.g. {'path': 'torch.optim.lr_scheduler.CosineAnnealingLR', 'args': {'T_max': 100}} (default: None).
source_model – either a valid torch model object or a component ID of a torch model object.
device – the device on which to run server-side optimization, e.g. "cpu" or "cuda:0" (defaults to cuda if available and no device is specified).
- Raises:
TypeError – when any of the input arguments does not have the correct type.
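As a usage illustration, the generator can be constructed with these arguments spelled out explicitly. This is a minimal sketch: the class name follows this module's API, while the momentum value and the "model" component ID are illustrative choices, not requirements.

from nvflare.app_opt.pt.fedopt import PTFedOptModelShareableGenerator

# Minimal sketch: the optimizer/scheduler specs mirror the documented
# defaults; the momentum value is an illustrative (assumed) choice.
shareable_generator = PTFedOptModelShareableGenerator(
    source_model="model",  # component ID of the torch model registered on the server
    device="cpu",          # run the server-side optimizer step on CPU
    optimizer_args={"path": "torch.optim.SGD", "args": {"lr": 1.0, "momentum": 0.9}},
    lr_scheduler_args={"path": "torch.optim.lr_scheduler.CosineAnnealingLR", "args": {"T_max": 100}},
)

Note that with the default SGD at lr 1.0 and no momentum, the server step reduces to plain FedAvg on the trainable parameters; a non-trivial server optimizer (momentum, Adam, etc.) is what makes this FedOpt rather than FedAvg.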
handle_event(event_type: str, fl_ctx: FLContext)
Handles events.
- Parameters:
event_type (str) – event type fired by the workflow.
fl_ctx (FLContext) – FLContext information.
server_update(model_diff)
Updates the global model using the specified optimizer.
- Parameters:
model_diff – the aggregated model differences from clients.
- Returns:
The updated PyTorch model state dictionary.
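Conceptually, the update treats the aggregated client difference as a pseudo-gradient for the server optimizer. The following standalone sketch illustrates the idea; it is not this class's actual implementation, and global_model, model_diff, and fedopt_step are assumed names:

import torch

def fedopt_step(global_model, model_diff, optimizer):
    # One FedOpt server update. model_diff maps parameter names to the
    # aggregated (new - old) client differences, so -diff acts as the
    # gradient: SGD's "param -= lr * grad" then moves the global model
    # toward the aggregated client update.
    optimizer.zero_grad()
    for name, param in global_model.named_parameters():
        param.grad = -torch.as_tensor(model_diff[name], device=param.device, dtype=param.dtype)
    optimizer.step()
    return global_model.state_dict()

Non-trainable entries such as batch norm running statistics do not appear in named_parameters(), which is why this class falls back to a FedAvg-style direct update for them.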
shareable_to_learnable(shareable: Shareable, fl_ctx: FLContext)
Converts a Shareable to a Learnable while performing a FedOpt update step.
Supports data_kind == DataKind.WEIGHT_DIFF.
- Parameters:
shareable (Shareable) – Shareable to be converted
fl_ctx (FLContext) – FL context
- Returns:
Updated global ModelLearnable.
- Return type:
ModelLearnable
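In a typical server workflow such as scatter-and-gather, this generator sits between aggregation and the broadcast of the next round's weights. A hedged sketch of that flow, where aggregator and generator stand in for the configured components (the workflow, not user code, normally drives these calls):

# shareable: aggregated WEIGHT_DIFF result collected from the clients
shareable = aggregator.aggregate(fl_ctx)
# FedOpt step: applies the aggregated difference to the global model
learnable = generator.shareable_to_learnable(shareable, fl_ctx)
# Package the updated global weights for the next training round
next_task_data = generator.learnable_to_shareable(learnable, fl_ctx)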