Source code for nvflare.app_common.ccwf.cyclic_client_ctl

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random

from nvflare.apis.fl_constant import FLContextKey, ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.ccwf.client_ctl import ClientSideController
from nvflare.app_common.ccwf.common import Constant, CyclicOrder, ResultType, rotate_to_front
from nvflare.fuel.utils.validation_utils import check_non_empty_str


class CyclicClientController(ClientSideController):
    def __init__(
        self,
        task_name_prefix=Constant.TN_PREFIX_CYCLIC,
        learn_task_name=AppConstants.TASK_TRAIN,
        persistor_id=AppConstants.DEFAULT_PERSISTOR_ID,
        shareable_generator_id=AppConstants.DEFAULT_SHAREABLE_GENERATOR_ID,
        learn_task_check_interval=Constant.LEARN_TASK_CHECK_INTERVAL,
        learn_task_abort_timeout=Constant.LEARN_TASK_ABORT_TIMEOUT,
        learn_task_ack_timeout=Constant.LEARN_TASK_ACK_TIMEOUT,
        final_result_ack_timeout=Constant.FINAL_RESULT_ACK_TIMEOUT,
    ):
        check_non_empty_str("learn_task_name", learn_task_name)
        check_non_empty_str("persistor_id", persistor_id)
        check_non_empty_str("shareable_generator_id", shareable_generator_id)

        super().__init__(
            task_name_prefix=task_name_prefix,
            learn_task_name=learn_task_name,
            persistor_id=persistor_id,
            shareable_generator_id=shareable_generator_id,
            learn_task_check_interval=learn_task_check_interval,
            learn_task_abort_timeout=learn_task_abort_timeout,
            learn_task_ack_timeout=learn_task_ack_timeout,
            final_result_ack_timeout=final_result_ack_timeout,
            allow_busy_task=False,
        )

    @staticmethod
    def _set_task_headers(task_data: Shareable, num_rounds, current_round, client_order):
        task_data.set_header(AppConstants.NUM_ROUNDS, num_rounds)
        task_data.set_header(AppConstants.CURRENT_ROUND, current_round)
        task_data.set_header(Constant.CLIENT_ORDER, client_order)

    def start_workflow(self, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
        clients = self.get_config_prop(Constant.CLIENTS)

        # make sure the starting client is the 1st
        rotate_to_front(self.me, clients)
        cyclic_order = self.get_config_prop(Constant.ORDER)
        self.log_info(fl_ctx, f"Starting cyclic workflow on clients {clients} with order {cyclic_order}")
        self._set_task_headers(
            task_data=shareable,
            num_rounds=self.get_config_prop(AppConstants.NUM_ROUNDS),
            current_round=self.get_config_prop(Constant.START_ROUND, 0),
            client_order=clients,
        )
        self.set_learn_task(task_data=shareable, fl_ctx=fl_ctx)
        return make_reply(ReturnCode.OK)

    def do_learn_task(self, name: str, data: Shareable, fl_ctx: FLContext, abort_signal: Signal):
        # set status report of starting task
        current_round = data.get_header(AppConstants.CURRENT_ROUND)
        self.update_status(
            last_round=current_round,
            action="start_learn_task",
        )

        # need to prepare the GLOBAL_MODEL prop in case the shareable generator needs it
        # for shareable_to_learnable after training.
        # Note: the "data" shareable contains full weights before training.
        # However, the training process may only return weight diffs. To convert back to full weights,
        # the original weights (GLOBAL_MODEL prop) are needed.
        global_weights = self.shareable_generator.shareable_to_learnable(data, fl_ctx)
        fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, global_weights, private=True, sticky=True)

        data.set_header(FLContextKey.TASK_NAME, name)

        # execute the task
        result = self.execute_learn_task(data, fl_ctx, abort_signal)

        rc = result.get_return_code(ReturnCode.OK)
        if rc != ReturnCode.OK:
            self.log_error(fl_ctx, f"learn executor failed: {rc}")
            self.update_status(action="do_learn_task", error=rc)
            return

        self.last_result = result
        self.last_round = current_round

        # see whether we need to send to next leg
        num_rounds = data.get_header(AppConstants.NUM_ROUNDS)
        current_round = data.get_header(AppConstants.CURRENT_ROUND)
        client_order = data.get_header(Constant.CLIENT_ORDER)

        all_done = False
        assert isinstance(client_order, list)
        my_idx = client_order.index(self.me)

        if my_idx == len(client_order) - 1:
            # I'm the last leg
            num_rounds_done = current_round - self.get_config_prop(Constant.START_ROUND, 0) + 1
            if num_rounds_done >= num_rounds:
                # The RR is done!
                self.log_info(fl_ctx, f"Cyclic Done: number of rounds completed {num_rounds_done}")
                all_done = True
            else:
                # decide the next round order
                cyclic_order = self.get_config_prop(Constant.ORDER)
                if cyclic_order == CyclicOrder.RANDOM:
                    random.shuffle(client_order)
                    # make sure I'm not the first in the new order
                    if client_order[0] == self.me:
                        # put me at the end
                        client_order.pop(0)
                        client_order.append(self.me)
                    result.set_header(Constant.CLIENT_ORDER, client_order)
                current_round += 1
                self.log_info(fl_ctx, f"Starting new round {current_round} on clients: {client_order}")

        last_learnable = self.shareable_generator.shareable_to_learnable(result, fl_ctx)
        if all_done:
            self.record_last_result(fl_ctx, self.last_round, last_learnable)
            self.broadcast_final_result(fl_ctx, ResultType.LAST, last_learnable, round_num=self.last_round)
            return

        # send to next leg
        if my_idx < len(client_order) - 1:
            next_client = client_order[my_idx + 1]
        else:
            next_client = client_order[0]

        next_task_data = self.shareable_generator.learnable_to_shareable(last_learnable, fl_ctx)
        self._set_task_headers(next_task_data, num_rounds, current_round, client_order)
        sent = self.send_learn_task(
            targets=[next_client],
            request=next_task_data,
            fl_ctx=fl_ctx,
        )
        if sent:
            self.log_info(fl_ctx, f"sent learn request to next client {next_client}")
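For orientation, below is a minimal, hypothetical sketch of constructing this controller directly in Python. It uses only the constructor parameters and default values visible in __init__ above; in an actual NVFlare job the controller is normally declared in the client-side job configuration rather than built by hand, and the values shown are illustrative assumptions, not recommendations.

from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.ccwf.common import Constant
from nvflare.app_common.ccwf.cyclic_client_ctl import CyclicClientController

# Hypothetical, hand-built instance: each argument mirrors a default shown in __init__ above.
controller = CyclicClientController(
    task_name_prefix=Constant.TN_PREFIX_CYCLIC,  # prefix used to derive the workflow's task names
    learn_task_name=AppConstants.TASK_TRAIN,  # name of the local training task
    persistor_id=AppConstants.DEFAULT_PERSISTOR_ID,  # component id of the model persistor
    shareable_generator_id=AppConstants.DEFAULT_SHAREABLE_GENERATOR_ID,  # converts Shareable <-> Learnable
    learn_task_check_interval=Constant.LEARN_TASK_CHECK_INTERVAL,
    learn_task_abort_timeout=Constant.LEARN_TASK_ABORT_TIMEOUT,
    learn_task_ack_timeout=Constant.LEARN_TASK_ACK_TIMEOUT,
    final_result_ack_timeout=Constant.FINAL_RESULT_ACK_TIMEOUT,
)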