nvflare.app_opt.pt.fedproxloss module
- class PTFedProxLoss(mu: float = 0.01)
Bases: _Loss
Computes the FedProx loss term, which penalizes the deviation of the local model from the global model.
- Parameters:
mu (float) – weight applied to the FedProx penalty term; defaults to 0.01
- forward(input, target) → Tensor
Forward pass in training.
- Parameters:
input (nn.Module) – the local PyTorch model
target (nn.Module) – a copy of the global PyTorch model as the local client received it at the beginning of the current local round
- Returns:
FedProx loss term
- reduction: str
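The FedProx penalty is usually written as (mu / 2) * ||w - w_global||^2, summed over all model parameters. The sketch below illustrates that formulation in plain PyTorch. It is not NVFlare's implementation: `fedprox_term` is a hypothetical helper written for this example, and the small `nn.Linear` model and random batch are placeholders.

```python
import copy

import torch
from torch import nn


def fedprox_term(local_model: nn.Module, global_model: nn.Module, mu: float = 0.01) -> torch.Tensor:
    """Hypothetical helper: (mu / 2) * sum of squared parameter differences."""
    total = torch.zeros(())
    for w, w_global in zip(local_model.parameters(), global_model.parameters()):
        # detach() keeps gradients from flowing into the global snapshot.
        total = total + torch.sum((w - w_global.detach()) ** 2)
    return (mu / 2) * total


torch.manual_seed(0)
local_model = nn.Linear(4, 2)

# Frozen snapshot of the weights received at the start of the round.
global_model = copy.deepcopy(local_model)
for p in global_model.parameters():
    p.requires_grad_(False)

# Identical models, so the penalty starts at zero.
penalty_before = fedprox_term(local_model, global_model, mu=0.01)

# Shift every local parameter by 1.0 to simulate drift during local training.
with torch.no_grad():
    for p in local_model.parameters():
        p.add_(1.0)

penalty_after = fedprox_term(local_model, global_model, mu=0.01)

# In a training step the penalty is added to the task loss before backward().
x, y = torch.randn(8, 4), torch.randn(8, 2)
loss = nn.functional.mse_loss(local_model(x), y) + penalty_after
loss.backward()
```

Under these assumptions, `PTFedProxLoss` would take the place of the helper in the last step: construct it once with the chosen `mu`, call it with the local model and the saved global copy, and add the returned term to the task loss.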