nvflare.app_opt.pt.ditto module

class PTDittoHelper(criterion, model, optimizer, device, app_dir: str, ditto_lambda: float = 0.1, model_epochs: int = 1)

Bases: object

Helper to be used with Ditto components. Implements, in PyTorch, the functions needed for the algorithm proposed by Li et al. in “Ditto: Fair and Robust Federated Learning Through Personalization” (https://arxiv.org/abs/2012.04221).
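For reference, the personalized objective from the cited paper is reproduced below. Here F_k is client k's local empirical risk, v_k its personalized model, w* the global model, and G the global aggregation objective. Reading `criterion` as F_k and `ditto_lambda` as λ is an interpretation of the parameter list, not something this page states explicitly.

```latex
\min_{v_k} \; h_k(v_k; w^*) = F_k(v_k) + \frac{\lambda}{2} \left\lVert v_k - w^* \right\rVert_2^2,
\qquad \text{s.t. } w^* \in \arg\min_{w} G\big(F_1(w), \dots, F_K(w)\big)

\nabla_{v_k} h_k(v_k; w^*) = \nabla F_k(v_k) + \lambda \, (v_k - w^*)
```

The proximal term pulls the personalized model toward the global one. λ = 0 yields purely local training, and a large λ keeps v_k close to w*.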

Parameters:
  • criterion – base loss criterion

  • model – the personalized model of the Ditto method

  • optimizer – optimizer used to train the personalized model

  • device – device on which the personalized model is trained

  • app_dir – app directory, needed for saving the local personalized model

  • ditto_lambda – weight (λ) of the Ditto proximal loss term that is added to the base loss, defaults to 0.1

  • model_epochs – number of training epochs for the personalized model, defaults to 1
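The sketch below shows, framework-free, how a proximal term weighted by `ditto_lambda` combines with a base loss and how it changes a gradient step on the personalized weights. Weights are plain lists of floats rather than PyTorch tensors. The function names and the toy quadratic base loss are hypothetical illustrations, not part of NVFlare's API, and the (λ/2)‖v − w‖² form follows the Ditto paper rather than a confirmed reading of this class's source.

```python
# Hypothetical, framework-free sketch of the Ditto personalized loss:
#   total = base_loss(v) + (ditto_lambda / 2) * ||v - w_global||^2
# Plain float lists stand in for PyTorch parameters (assumption).
from typing import Callable, List

Vector = List[float]


def ditto_prox_term(v: Vector, w_global: Vector, ditto_lambda: float) -> float:
    """(lambda / 2) * squared L2 distance between personalized and global weights."""
    return 0.5 * ditto_lambda * sum((vi - wi) ** 2 for vi, wi in zip(v, w_global))


def ditto_loss(base_loss: float, v: Vector, w_global: Vector, ditto_lambda: float) -> float:
    """Base criterion loss plus the Ditto proximal penalty."""
    return base_loss + ditto_prox_term(v, w_global, ditto_lambda)


def ditto_sgd_step(
    v: Vector,
    w_global: Vector,
    base_grad: Callable[[Vector], Vector],
    ditto_lambda: float,
    lr: float,
) -> Vector:
    """One plain SGD step on the personalized weights.

    The prox term's gradient w.r.t. v is ditto_lambda * (v - w_global);
    the global weights are treated as constants and are never updated here.
    """
    g = base_grad(v)
    return [vi - lr * (gi + ditto_lambda * (vi - wi)) for vi, gi, wi in zip(v, g, w_global)]


# Usage: toy base loss F(v) = 0.5 * ||v - target||^2, whose gradient is v - target.
target = [3.0]
w_global = [0.0]
v = list(w_global)  # assumption: start the personalized model from the global weights
for _ in range(200):
    v = ditto_sgd_step(
        v, w_global, lambda x: [xi - t for xi, t in zip(x, target)], ditto_lambda=1.0, lr=0.1
    )
# v[0] approaches (3.0 + 1.0 * 0.0) / (1 + 1.0) = 1.5, between the local optimum and w_global
```

With λ = 1 the personalized weight settles halfway between the purely local optimum (3.0) and the global weight (0.0), which illustrates how `ditto_lambda` trades off personalization against agreement with the global model.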

Returns:

None

load_model(global_weights)
abstract local_train(train_loader, model_global, abort_signal: Signal, writer)
save_model(is_best=False)
update_metric_save_model(metric)
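The sketch below illustrates one plausible per-round call order for these methods, using a hypothetical stand-in class `ToyDittoHelper` with the same method names. The method bodies are assumptions inferred from the names and parameters, not NVFlare's implementation: the personalized model is initialized from the global weights only on first use, every round saves the latest model, and a "best" copy is saved when the metric improves (higher is better). `local_train` is abstract in `PTDittoHelper`, so a subclass must supply it; here it is only marked by a comment.

```python
# Hypothetical stand-in mirroring the PTDittoHelper method names.
# Bodies are illustrative assumptions, not NVFlare's implementation.
from typing import Dict, List, Optional


class ToyDittoHelper:
    def __init__(self) -> None:
        self.personal_weights: Optional[Dict[str, float]] = None
        self.best_metric: Optional[float] = None
        self.saved: List[str] = []  # records "last" / "best" saves for illustration

    def load_model(self, global_weights: Dict[str, float]) -> None:
        # Assumption: initialize from the global weights only in the first round;
        # afterwards keep the locally stored personalized model.
        if self.personal_weights is None:
            self.personal_weights = dict(global_weights)

    def save_model(self, is_best: bool = False) -> None:
        # Stand-in for writing a checkpoint under app_dir.
        self.saved.append("best" if is_best else "last")

    def update_metric_save_model(self, metric: float) -> None:
        # Assumption: always save the latest model; also save it as "best"
        # when the validation metric improves (higher is better).
        self.save_model(is_best=False)
        if self.best_metric is None or metric > self.best_metric:
            self.best_metric = metric
            self.save_model(is_best=True)


# Usage: three rounds with made-up global weights and validation metrics.
helper = ToyDittoHelper()
for global_w, val_metric in [({"w": 0.0}, 0.6), ({"w": 0.5}, 0.4), ({"w": 0.9}, 0.7)]:
    helper.load_model(global_w)
    # A subclass's local_train(train_loader, model_global, abort_signal, writer) would run here.
    helper.update_metric_save_model(val_metric)
```

After these rounds the helper has kept its first-round personalized weights, recorded 0.7 as the best metric, and saved a "best" copy only in the rounds where the metric improved.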