# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The SCAFFOLD-related functions are based on https://github.com/Xtra-Computing/NIID-Bench
# MIT License
#
# Copyright (c) 2021 Yiqun Diao, Qinbin Li
#
# Copyright (c) 2020 International Business Machines
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import copy
import torch
from torch.optim import Optimizer
[docs]
def get_lr_values(optimizer: Optimizer):
"""
This function is used to get the learning rates of the optimizer.
"""
return [group["lr"] for group in optimizer.state_dict()["param_groups"]]
[docs]
class PTScaffoldHelper(object):
"""Helper to be used with SCAFFOLD components.
Implements the functions used for the algorithm proposed in
Karimireddy et al. "SCAFFOLD: Stochastic Controlled Averaging for Federated Learning"
(https://arxiv.org/abs/1910.06378) using PyTorch.
SCAFFOLD-related functions are based on https://github.com/Xtra-Computing/NIID-Bench.
See also Li et al. "Federated Learning on Non-IID Data Silos: An Experimental Study"
(https://arxiv.org/abs/2102.02079).
"""
def __init__(self):
# SCAFFOLD control terms
self.cnt = 0
self.c_global = None
self.c_local = None
self.c_delta_para = None
[docs]
def init(self, model):
# create models for SCAFFOLD correction terms
self.c_global = copy.deepcopy(model)
self.c_local = copy.deepcopy(model)
# Initialize correction term with zeros
c_init_para = model.state_dict()
for k in c_init_para.keys():
c_init_para[k] = torch.zeros_like(c_init_para[k])
self.c_global.load_state_dict(c_init_para)
self.c_local.load_state_dict(c_init_para)
[docs]
def get_params(self):
self.cnt = 0
# Adapted from https://github.com/Xtra-Computing/NIID-Bench/blob/main/experiments.py#L371
c_global_para = self.c_global.state_dict()
c_local_para = self.c_local.state_dict()
return c_global_para, c_local_para
[docs]
def model_update(self, model, curr_lr, c_global_para, c_local_para):
# Update model using scaffold controls
# See https://github.com/Xtra-Computing/NIID-Bench/blob/main/experiments.py#L391
net_para = model.state_dict()
for key in net_para:
net_para[key] = net_para[key] - curr_lr * (c_global_para[key] - c_local_para[key])
model.load_state_dict(net_para)
self.cnt += 1
[docs]
def terms_update(self, model, curr_lr, c_global_para, c_local_para, model_global):
# Update the local scaffold controls
# See https://github.com/Xtra-Computing/NIID-Bench/blob/main/experiments.py#L403
c_new_para = self.c_local.state_dict()
self.c_delta_para = copy.deepcopy(self.c_local.state_dict())
global_model_para = model_global.state_dict()
net_para = model.state_dict()
for key in net_para:
c_new_para[key] = (
c_new_para[key] - c_global_para[key] + (global_model_para[key] - net_para[key]) / (self.cnt * curr_lr)
)
self.c_delta_para[key] = (c_new_para[key] - c_local_para[key]).cpu().numpy()
self.c_local.load_state_dict(c_new_para)
[docs]
def load_global_controls(self, weights):
self.c_global.load_state_dict(weights)
[docs]
def get_delta_controls(self):
if self.c_delta_para is None:
raise ValueError("c_delta_para hasn't been computed yet!")
return self.c_delta_para