nvflare.app_opt.pt.fedproxloss module

class PTFedProxLoss(mu: float = 0.01)

Bases: _Loss

Compute the FedProx loss: a proximal term that penalizes the deviation of the local model from the global model.

Parameters:

mu (float) – weight of the proximal term (default: 0.01); larger values keep the local model closer to the global model
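For reference, the proximal term from the original FedProx formulation is shown below. It assumes that PTFedProxLoss follows that formulation, with $w$ the local model parameters (`input`), $w^{t}$ the global parameters received at the start of round $t$ (`target`), and $k$ indexing the parameter tensors.

```latex
\mathcal{L}_{\text{prox}}(w; w^{t})
  = \frac{\mu}{2}\,\lVert w - w^{t} \rVert_2^2
  = \frac{\mu}{2} \sum_{k} \sum_{j} \bigl( w_{k,j} - w^{t}_{k,j} \bigr)^2
```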

forward(input, target) → Tensor

Compute the FedProx loss term during local training.

Parameters:
  • input (nn.Module) – the local PyTorch model being trained

  • target (nn.Module) – a copy of the global PyTorch model, as received by the local client at the beginning of each local training round

Returns:

The FedProx loss term, which is added to the local training loss
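The computation that forward is documented to return can be sketched without any framework. This is a minimal sketch, not NVFlare's source: the helper `fedprox_term` and the `Params` dictionaries (parameter name mapped to a flat list of floats) are hypothetical stand-ins for the two models' parameters, and rejecting a negative `mu` is a choice made only in this example. A PyTorch version would typically pair the tensors from the two models' `named_parameters()` and sum `torch.sum((p - p_g) ** 2)` instead.

```python
from typing import Dict, List

# Hypothetical stand-in for a model's named parameters: name -> flat list of floats.
Params = Dict[str, List[float]]


def fedprox_term(local: Params, global_ref: Params, mu: float = 0.01) -> float:
    """Return (mu / 2) times the squared L2 distance between matching parameters."""
    if mu < 0.0:
        # This sketch treats a negative weight as invalid (an assumption, not NVFlare behavior).
        raise ValueError("mu must be non-negative")
    total = 0.0
    for name, local_values in local.items():
        ref_values = global_ref[name]  # pair each local parameter with its global copy by name
        total += sum((w - w_g) ** 2 for w, w_g in zip(local_values, ref_values))
    return (mu / 2.0) * total


# Usage: the proximal term is added to the ordinary task loss during local training.
global_model = {"weight": [0.5, -1.0], "bias": [0.0]}
local_model = {"weight": [0.7, -1.0], "bias": [0.1]}
task_loss = 0.25
total_loss = task_loss + fedprox_term(local_model, global_model, mu=0.1)
print(round(total_loss, 6))  # → 0.2525
```

In an actual PyTorch training loop, the documented signature suggests a call such as `loss = task_loss + PTFedProxLoss(mu=0.01)(model, global_model)` before `loss.backward()`, where `global_model` is the unmodified copy kept from the start of the round.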

reduction: str