nvflare.app_opt.pt.ditto module
- class PTDittoHelper(criterion, model, optimizer, device, app_dir: str, ditto_lambda: float = 0.1, model_epochs: int = 1)
Bases: object
Helper for use with Ditto components. Implements, in PyTorch, the functions needed for the algorithm proposed by Li et al. in “Ditto: Fair and Robust Federated Learning Through Personalization” (https://arxiv.org/abs/2012.04221).
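For reference, the per-client objective from the Ditto paper can be written as below. F_k is client k's local loss (computed here with criterion), v_k is the personalized model, w* is the global model trained by federated averaging, and λ plays the role of ditto_lambda. This mapping of paper symbols to the constructor parameters is our reading of the signature, not something this page states explicitly.

```latex
\min_{v_k} \; h_k(v_k; w^*) = F_k(v_k) + \frac{\lambda}{2}\,\lVert v_k - w^* \rVert_2^2,
\qquad \text{s.t. } w^* \in \arg\min_{w} G\big(F_1(w), \dots, F_K(w)\big)
```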
- Parameters:
criterion – base loss criterion
model – the personalized model of the Ditto method
optimizer – training optimizer for the personalized model
device – device on which the personalized model is trained
app_dir – application directory, used to save the personalized model locally
ditto_lambda – lambda weight of the Ditto proximal loss term that is added to the base loss, defaults to 0.1
model_epochs – number of training epochs for the personalized model, defaults to 1
- Returns:
None
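The sketch below illustrates, in plain PyTorch, how the parameters above fit together during one round of personalized training: the base loss from criterion is combined with a proximal term weighted by ditto_lambda, and the loop runs for model_epochs epochs. It does not call PTDittoHelper or any other NVFlare API; the functions ditto_prox_term and train_personalized are hypothetical names for illustration only. The λ/2 scaling follows the paper's objective; whether NVFlare's implementation uses exactly this factor is an assumption.

```python
import copy

import torch
import torch.nn as nn


def ditto_prox_term(model: nn.Module, global_model: nn.Module) -> torch.Tensor:
    """Squared L2 distance between personalized and global weights (hypothetical helper).

    The global weights are detached so no gradient flows into the global model.
    """
    return sum(
        torch.sum((p - g.detach()) ** 2)
        for p, g in zip(model.parameters(), global_model.parameters())
    )


def train_personalized(model, global_model, loader, criterion, optimizer,
                       device, ditto_lambda=0.1, model_epochs=1):
    """Train the personalized model against a fixed global model (hypothetical helper).

    Returns the loss of the last batch as a Python float.
    """
    model.to(device).train()
    global_model.to(device).eval()
    last_loss = 0.0
    for _ in range(model_epochs):
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            base_loss = criterion(model(inputs), labels)
            # Ditto objective (assumed scaling): base loss + (lambda / 2) * ||v - w||^2
            loss = base_loss + 0.5 * ditto_lambda * ditto_prox_term(model, global_model)
            loss.backward()
            optimizer.step()
            last_loss = loss.item()
    return last_loss


# Usage: toy data, with the personalized model initialized from the global weights.
torch.manual_seed(0)
global_model = nn.Linear(4, 2)
model = copy.deepcopy(global_model)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data = [(torch.randn(8, 4), torch.randint(0, 2, (8,))) for _ in range(3)]
loss = train_personalized(model, global_model, data, nn.CrossEntropyLoss(),
                          optimizer, torch.device("cpu"), ditto_lambda=0.1, model_epochs=1)
print(f"last batch loss: {loss:.4f}")
```

Only the personalized model's parameters are passed to the optimizer, so the global model stays fixed during the local round; a larger ditto_lambda pulls the personalized weights more strongly toward the global ones.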