# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import gc
import numpy as np
from nvflare.apis.dxo import DXO, DataKind, from_shareable
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import Task
from nvflare.apis.signal import Signal
from nvflare.app_common.abstract.model import model_learnable_to_dxo
from nvflare.app_common.app_constant import AlgorithmConstants, AppConstants
from nvflare.app_common.app_event_type import AppEventType
from nvflare.app_common.workflows.scatter_and_gather import ScatterAndGather
from nvflare.security.logging import secure_format_exception
class ScatterAndGatherScaffold(ScatterAndGather):
    """Scatter-and-gather controller that additionally exchanges SCAFFOLD control terms.

    Every train task carries the global SCAFFOLD controls next to the global model
    weights, and the aggregated control differences found in the aggregation result
    are accumulated into the global controls after each round.
    """

    def __init__(
        self,
        min_clients: int = 1000,
        num_rounds: int = 5,
        start_round: int = 0,
        wait_time_after_min_received: int = 10,
        aggregator_id: str = AppConstants.DEFAULT_AGGREGATOR_ID,
        persistor_id: str = AppConstants.DEFAULT_PERSISTOR_ID,
        shareable_generator_id: str = AppConstants.DEFAULT_SHAREABLE_GENERATOR_ID,
        train_task_name: str = AppConstants.TASK_TRAIN,
        train_timeout: int = 0,
        ignore_result_error: bool = False,
        task_check_period: float = 0.5,
        persist_every_n_rounds: int = 1,
        snapshot_every_n_rounds: int = 1,
    ):
        """The controller for ScatterAndGatherScaffold workflow.

        The model persistor (persistor_id) is used to load the initial global model which is sent to all clients.
        Each client sends its updated weights after local training which is aggregated (aggregator_id). The
        shareable generator is used to convert the aggregated weights to shareable and shareable back to weight.
        The model_persistor also saves the model after training.

        Args:
            min_clients (int, optional): The minimum number of clients responses before
                SAG starts to wait for `wait_time_after_min_received`. Note that SAG will move forward when all
                available clients have responded regardless of this value. Defaults to 1000.
            num_rounds (int, optional): The total number of training rounds. Defaults to 5.
            start_round (int, optional): Start round for training. Defaults to 0.
            wait_time_after_min_received (int, optional): Time to wait before beginning aggregation after
                minimum number of clients responses has been received. Defaults to 10.
            aggregator_id (str, optional): ID of the aggregator component. Defaults to "aggregator".
            persistor_id (str, optional): ID of the persistor component. Defaults to "persistor".
            shareable_generator_id (str, optional): ID of the shareable generator. Defaults to "shareable_generator".
            train_task_name (str, optional): Name of the train task. Defaults to "train".
            train_timeout (int, optional): Time to wait for clients to do local training. Defaults to 0.
            ignore_result_error (bool, optional): whether this controller can proceed if client result has errors.
                Defaults to False.
            task_check_period (float, optional): interval for checking status of tasks. Defaults to 0.5.
            persist_every_n_rounds (int, optional): persist the global model every n rounds. Defaults to 1.
                If n is 0 then no persist.
            snapshot_every_n_rounds (int, optional): persist the server state every n rounds. Defaults to 1.
                If n is 0 then no persist.
        """
        super().__init__(
            min_clients=min_clients,
            num_rounds=num_rounds,
            start_round=start_round,
            wait_time_after_min_received=wait_time_after_min_received,
            aggregator_id=aggregator_id,
            persistor_id=persistor_id,
            shareable_generator_id=shareable_generator_id,
            train_task_name=train_task_name,
            train_timeout=train_timeout,
            ignore_result_error=ignore_result_error,
            task_check_period=task_check_period,
            persist_every_n_rounds=persist_every_n_rounds,
            snapshot_every_n_rounds=snapshot_every_n_rounds,
        )

        # for SCAFFOLD
        self.aggregator_ctrl = None
        # Global control terms, keyed like the global model weights; created in start_controller().
        self._global_ctrl_weights = None

    def start_controller(self, fl_ctx: FLContext) -> None:
        """Start the base workflow and create zero-valued global SCAFFOLD controls.

        Triggers a system panic and returns without creating the controls when no
        global weights are available after the base controller has started.

        Args:
            fl_ctx (FLContext): FL context passed to the base controller, logging and panic calls.
        """
        super().start_controller(fl_ctx=fl_ctx)
        self.log_info(fl_ctx, "Initializing ScatterAndGatherScaffold workflow.")

        # for SCAFFOLD
        if not self._global_weights:
            self.system_panic("Global weights not available!", fl_ctx)
            return

        # Initialize correction term with zeros: one zero array per global weight entry.
        # The zeros are built directly so no throwaway deep copy of the model is allocated.
        self._global_ctrl_weights = {
            name: np.zeros_like(value) for name, value in self._global_weights["weights"].items()
        }
        # TODO: Print some stats of the correction magnitudes

    def control_flow(self, abort_signal: Signal, fl_ctx: FLContext) -> None:
        """Run the training rounds of the SCAFFOLD workflow.

        For every round the global model and global controls are broadcast as one train
        task, the responses are aggregated, the global model is rebuilt through the
        shareable generator and the aggregated control differences are added to the
        global controls. The model and the server snapshot are persisted according to
        `persist_every_n_rounds` and `snapshot_every_n_rounds`.

        The method returns early when the abort check succeeds, or (after logging an
        error) when the aggregation result lacks the model weights or the control
        differences. Any exception is logged and escalated with a system panic.

        Args:
            abort_signal (Signal): signal checked between the steps of a round.
            fl_ctx (FLContext): FL context used for properties, events and logging.
        """
        try:
            self.log_info(fl_ctx, "Beginning ScatterAndGatherScaffold training phase.")
            self._phase = AppConstants.PHASE_TRAIN

            fl_ctx.set_prop(AppConstants.PHASE, self._phase, private=True, sticky=False)
            fl_ctx.set_prop(AppConstants.NUM_ROUNDS, self._num_rounds, private=True, sticky=False)
            self.fire_event(AppEventType.TRAINING_STARTED, fl_ctx)

            # Keep an already-set round; otherwise begin at the configured start round.
            if self._current_round is None:
                self._current_round = self._start_round

            while self._current_round < self._start_round + self._num_rounds:
                if self._check_abort_signal(fl_ctx, abort_signal):
                    return

                self.log_info(fl_ctx, f"Round {self._current_round} started.")
                fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, self._global_weights, private=True, sticky=True)
                fl_ctx.set_prop(AppConstants.CURRENT_ROUND, self._current_round, private=True, sticky=True)
                self.fire_event(AppEventType.ROUND_STARTED, fl_ctx)

                # Create train_task
                train_task = Task(
                    name=self.train_task_name,
                    data=self._create_scaffold_task_data(),
                    props={},
                    timeout=self._train_timeout,
                    before_task_sent_cb=self._prepare_train_task_data,
                    result_received_cb=self._process_train_result,
                )

                self.broadcast_and_wait(
                    task=train_task,
                    min_responses=self._min_clients,
                    wait_time_after_min_received=self._wait_time_after_min_received,
                    fl_ctx=fl_ctx,
                    abort_signal=abort_signal,
                )

                if self._check_abort_signal(fl_ctx, abort_signal):
                    return

                self.fire_event(AppEventType.BEFORE_AGGREGATION, fl_ctx)
                aggr_result = self.aggregator.aggregate(fl_ctx)

                # extract aggregated weights and controls
                # NOTE(review): a missing entry only logs an error and ends the workflow without
                # a system panic, unlike start_controller() - confirm this is intended.
                collection_dxo = from_shareable(aggr_result)
                dxo_aggr_result = collection_dxo.data.get(AppConstants.MODEL_WEIGHTS)
                if not dxo_aggr_result:
                    self.log_error(fl_ctx, "Aggregated model weights are missing!")
                    return
                dxo_ctrl_aggr_result = collection_dxo.data.get(AlgorithmConstants.SCAFFOLD_CTRL_DIFF)
                if not dxo_ctrl_aggr_result:
                    self.log_error(fl_ctx, "Aggregated model weight controls are missing!")
                    return

                fl_ctx.set_prop(AppConstants.AGGREGATION_RESULT, aggr_result, private=True, sticky=False)
                self.fire_event(AppEventType.AFTER_AGGREGATION, fl_ctx)

                if self._check_abort_signal(fl_ctx, abort_signal):
                    return

                # update global model using shareable generator
                self.fire_event(AppEventType.BEFORE_SHAREABLE_TO_LEARNABLE, fl_ctx)
                self._global_weights = self.shareable_gen.shareable_to_learnable(dxo_aggr_result.to_shareable(), fl_ctx)

                # update SCAFFOLD global controls
                self._accumulate_scaffold_ctrl_diff(dxo_ctrl_aggr_result.data)
                fl_ctx.set_prop(
                    AlgorithmConstants.SCAFFOLD_CTRL_GLOBAL, self._global_ctrl_weights, private=True, sticky=True
                )

                fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, self._global_weights, private=True, sticky=True)
                fl_ctx.sync_sticky()
                self.fire_event(AppEventType.AFTER_SHAREABLE_TO_LEARNABLE, fl_ctx)

                if self._check_abort_signal(fl_ctx, abort_signal):
                    return

                # Rounds are counted from the one that just finished, hence the "+ 1".
                if self._persist_every_n_rounds != 0 and (self._current_round + 1) % self._persist_every_n_rounds == 0:
                    self.log_info(fl_ctx, "Start persist model on server.")
                    self.fire_event(AppEventType.BEFORE_LEARNABLE_PERSIST, fl_ctx)
                    self.persistor.save(self._global_weights, fl_ctx)
                    self.fire_event(AppEventType.AFTER_LEARNABLE_PERSIST, fl_ctx)
                    self.log_info(fl_ctx, "End persist model on server.")

                self.fire_event(AppEventType.ROUND_DONE, fl_ctx)
                self.log_info(fl_ctx, f"Round {self._current_round} finished.")
                self._current_round += 1

                # need to persist snapshot after round increased because the global weights should be set to
                # the last finished round's result
                if self._snapshot_every_n_rounds != 0 and self._current_round % self._snapshot_every_n_rounds == 0:
                    self._engine.persist_components(fl_ctx, completed=False)

                gc.collect()

            self._phase = AppConstants.PHASE_FINISHED
            self.log_info(fl_ctx, "Finished ScatterAndGatherScaffold Training.")
        except Exception as e:
            # Top-level boundary of the workflow: log the failure and escalate it.
            error_msg = f"Exception in ScatterAndGatherScaffold control_flow: {secure_format_exception(e)}"
            self.log_exception(fl_ctx, error_msg)
            self.system_panic(error_msg, fl_ctx)

    def _create_scaffold_task_data(self):
        """Package the global model and the global SCAFFOLD controls into one task shareable."""
        # get DXO with global model weights
        dxo_global_weights = model_learnable_to_dxo(self._global_weights)
        # add global SCAFFOLD controls using a DXO collection
        dxo_global_ctrl = DXO(data_kind=DataKind.WEIGHT_DIFF, data=self._global_ctrl_weights)
        dxo_collection = DXO(
            data_kind=DataKind.COLLECTION,
            data={
                AppConstants.MODEL_WEIGHTS: dxo_global_weights,
                AlgorithmConstants.SCAFFOLD_CTRL_GLOBAL: dxo_global_ctrl,
            },
        )
        data_shareable = dxo_collection.to_shareable()

        # add meta information
        data_shareable.set_header(AppConstants.CURRENT_ROUND, self._current_round)
        data_shareable.set_header(AppConstants.NUM_ROUNDS, self._num_rounds)
        data_shareable.add_cookie(AppConstants.CONTRIBUTION_ROUND, self._current_round)
        return data_shareable

    def _accumulate_scaffold_ctrl_diff(self, ctrl_diff) -> None:
        """Add every aggregated control difference in place to the matching global control entry."""
        for v_name, v_value in ctrl_diff.items():
            self._global_ctrl_weights[v_name] += v_value