nvflare.app_opt.pt.fedopt module

class PTFedOptModelShareableGenerator(optimizer_args: dict | None = None, lr_scheduler_args: dict | None = None, source_model='model', device=None)

Bases: FullModelShareableGenerator

Implements the FedOpt algorithm.

The algorithm was proposed in Reddi, Sashank, et al. “Adaptive federated optimization.” arXiv preprint arXiv:2003.00295 (2020). This ShareableGenerator updates the global model using the specified PyTorch optimizer and learning rate scheduler. Note: this class uses FedOpt to optimize the global trainable parameters (i.e. self.model.named_parameters()) but uses FedAvg to update any other layers, such as batch norm statistics.
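For intuition, the server step on the trainable parameters can be written as below. The notation is ours, not the library's: x_t is the global trainable parameters at round t, x_t^{(i)} is client i's locally trained copy, S_t is the set of participating clients, η is the server learning rate, and we assume the aggregated difference Δ_t is a weighted average of client updates with weights p_i that sum to one. The server treats the negated difference as a gradient.

```latex
\begin{aligned}
\Delta_t &= \sum_{i \in S_t} p_i \bigl(x_t^{(i)} - x_t\bigr), \qquad \sum_{i \in S_t} p_i = 1 \\
g_t &= -\Delta_t \\
x_{t+1} &= x_t - \eta\, g_t = x_t + \eta\, \Delta_t \quad \text{(plain SGD, no momentum)} \\
\eta = 1:\quad x_{t+1} &= x_t + \Delta_t = \sum_{i \in S_t} p_i\, x_t^{(i)} \quad \text{(FedAvg)}
\end{aligned}
```

Under these assumptions, plain SGD with lr 1.0 gives the trainable parameters exactly the FedAvg update. Adaptive server optimizers such as Adam replace the third line with their own update rule applied to g_t.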

Parameters:
  • optimizer_args – dictionary of server-side optimizer arguments, e.g. {'path': 'torch.optim.SGD', 'args': {'lr': 1.0}}, which is also the default when None is passed.

  • lr_scheduler_args – dictionary of server-side learning rate scheduler arguments, e.g. {'path': 'torch.optim.lr_scheduler.CosineAnnealingLR', 'args': {'T_max': 100}} (default: None).

  • source_model – either a valid torch model object or the component ID of a torch model object (default: 'model').

  • device – the device on which to run server-side optimization, e.g. "cpu" or "cuda:0". If no device is specified, CUDA is used when available, otherwise CPU.
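The optimizer_args and lr_scheduler_args dictionaries name a class by its dotted 'path' and pass 'args' as keyword arguments. The sketch below shows one plausible way such a dictionary could be turned into an object. The helper build_from_config is hypothetical and is not NVFlare's API; the library's own resolution logic may differ, for example in how the model parameters are passed to the optimizer.

```python
import importlib
from typing import Any


def build_from_config(config: dict, *positional: Any) -> Any:
    """Instantiate the object described by a {'path': ..., 'args': ...} dict.

    Hypothetical helper for illustration only: it imports the module part of
    config["path"], looks up the final attribute, and calls it with any
    positional arguments followed by config["args"] as keyword arguments.
    """
    module_name, _, attr_name = config["path"].rpartition(".")
    if not module_name:
        # A bare name such as "SGD" cannot be imported without its module.
        raise ValueError(f"expected a dotted path, got {config['path']!r}")
    factory = getattr(importlib.import_module(module_name), attr_name)
    return factory(*positional, **config.get("args", {}))
```

With PyTorch installed, build_from_config({'path': 'torch.optim.SGD', 'args': {'lr': 1.0}}, model.parameters()) would yield an SGD optimizer. Likewise, build_from_config({'path': 'torch.optim.lr_scheduler.CosineAnnealingLR', 'args': {'T_max': 100}}, optimizer) would yield a scheduler bound to that optimizer. Passing the parameters or the optimizer positionally is an assumption of this sketch.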

Raises:

TypeError – when any of the input arguments does not have the correct type.

handle_event(event_type: str, fl_ctx: FLContext)

Handles events.

Parameters:
  • event_type (str) – event type fired by workflow.

  • fl_ctx (FLContext) – FLContext information.

server_update(model_diff)

Updates the global model using the specified optimizer.

Parameters:

model_diff – the aggregated model differences from clients.

Returns:

The updated PyTorch model state dictionary.
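Below is a minimal sketch of a single server step, written directly against PyTorch. It assumes model_diff maps parameter names to arrays with the same shape as the corresponding parameters, and it follows the FedOpt convention of using the negated aggregated update as the gradient. The function name fedopt_server_step is hypothetical; it is not NVFlare's server_update, whose details may differ.

```python
import torch
from torch import nn


def fedopt_server_step(model, optimizer, model_diff, lr_scheduler=None):
    """One FedOpt server step (illustrative sketch, not NVFlare's code).

    model_diff maps parameter names to the aggregated (new - old) client
    update. Its negation is treated as a gradient, so an optimizer that
    steps against the gradient moves the weights toward the clients' models.
    """
    model.train()
    optimizer.zero_grad(set_to_none=True)  # params without a diff keep grad=None
    updated = []
    for name, param in model.named_parameters():
        if name in model_diff:
            diff = torch.as_tensor(model_diff[name], dtype=param.dtype, device=param.device)
            param.grad = -diff.reshape(param.shape)  # pseudo-gradient
            updated.append(name)
    optimizer.step()  # built-in optimizers skip params whose grad is None
    if lr_scheduler is not None:
        lr_scheduler.step()  # advance the server-side LR schedule once per round
    return model.state_dict(), updated


if __name__ == "__main__":
    model = nn.Linear(2, 1)
    with torch.no_grad():
        model.weight.fill_(1.0)
        model.bias.fill_(0.0)
    opt = torch.optim.SGD(model.parameters(), lr=1.0)
    state, names = fedopt_server_step(model, opt, {"weight": [[0.5, -0.5]], "bias": [0.25]})
    print(state["weight"], state["bias"], names)
```

With torch.optim.SGD at lr=1.0, the step simply adds the aggregated difference (here the weight becomes [[1.5, 0.5]] and the bias [0.25]). Substituting torch.optim.Adam gives a FedAdam-style server update without any other change.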

shareable_to_learnable(shareable: Shareable, fl_ctx: FLContext) → Learnable

Converts a Shareable to a Learnable while performing a FedOpt update step.

Supports data_kind == DataKind.WEIGHT_DIFF.

Parameters:
  • shareable (Shareable) – Shareable to be converted

  • fl_ctx (FLContext) – FL context

Returns:

Updated global ModelLearnable.

Return type:

Model
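The class note says only self.model.named_parameters() go through the optimizer, while other layers are updated with FedAvg. That split can be sketched as a post-processing step. The sketch assumes the optimizer step has already reported which parameter names it updated. Every other entry present in both the global state dict and the diff, such as BatchNorm's running_mean, running_var and num_batches_tracked buffers, then receives new = old + diff. The function fedavg_remaining is hypothetical and only illustrates the idea.

```python
import torch
from torch import nn


def fedavg_remaining(global_state, model_diff, optimized_names):
    """Add the aggregated diff to every state entry the optimizer skipped.

    Illustrative sketch: entries in optimized_names (trainable parameters
    already updated by the server optimizer) are left alone; everything else
    that appears in both dicts gets a FedAvg-style update new = old + diff.
    Returns a new dict and does not modify global_state in place.
    """
    new_state = dict(global_state)
    for key, diff in model_diff.items():
        if key in optimized_names or key not in global_state:
            continue
        old = global_state[key]
        # Match dtype/device, e.g. int64 for BatchNorm's num_batches_tracked.
        new_state[key] = old + torch.as_tensor(diff, dtype=old.dtype, device=old.device)
    return new_state


if __name__ == "__main__":
    bn = nn.BatchNorm1d(2)
    trainable = {name for name, _ in bn.named_parameters()}  # {"weight", "bias"}
    diff = {"weight": [9.0, 9.0], "running_mean": [0.5, -0.5]}
    state = fedavg_remaining(bn.state_dict(), diff, trainable)
    print(state["weight"], state["running_mean"])
```

In this example the BatchNorm weight is left for the optimizer and stays at its initial value of ones. The running mean, which is a buffer and not a parameter, moves by the aggregated difference.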