# Source code for nvflare.app_common.response_processors.global_weights_initializer

# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nvflare.apis.client import Client
from nvflare.apis.dxo import DataKind, from_shareable
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.app_common.abstract.model import make_model_learnable
from nvflare.app_common.abstract.response_processor import ResponseProcessor
from nvflare.app_common.app_constant import AppConstants


class WeightMethod:
    """Names of the strategies for choosing which client-reported weights become the global weights."""

    # Take the weights from the first client response that is accepted.
    FIRST = "first"

    # Take the weights only from one specific, named client.
    CLIENT = "client"
class GlobalWeightsInitializer(ResponseProcessor):
    """Response processor that picks global model weights from the weights reported by clients.

    Each client response is validated: it must be a Shareable containing a DXO of kind
    DataKind.WEIGHTS with non-empty data. One response is selected according to
    ``weight_method``. In ``final_process``, the selected weights are wrapped with
    ``make_model_learnable`` and stored in the FL context under ``weights_prop_name``.
    """

    def __init__(
        self,
        weights_prop_name: str = AppConstants.GLOBAL_MODEL,
        weight_method: str = WeightMethod.FIRST,
        client_name: str = None,
    ):
        """Set global model weights based on specified weight setting method.

        Args:
            weights_prop_name: name of the prop to be set into fl_ctx for the determined global weights
            weight_method: the method to select final weights: one of "first", "client"
            client_name: the name of the client to be used as the weight provider

        If weight_method is "first", then use the weights reported from the first client;
        If weight_method is "client", then only use the weights reported from the specified client.

        Raises:
            ValueError: if weight_method is not one of the supported methods, or if weight_method
                is "client" and client_name is missing/empty or is not a string.
        """
        if weight_method not in [WeightMethod.FIRST, WeightMethod.CLIENT]:
            raise ValueError(f"invalid weight_method '{weight_method}'")
        if weight_method == WeightMethod.CLIENT and not client_name:
            raise ValueError(f"client name not provided for weight method '{WeightMethod.CLIENT}'")
        if weight_method == WeightMethod.CLIENT and not isinstance(client_name, str):
            raise ValueError(
                f"client name should be a single string for weight method '{WeightMethod.CLIENT}' "
                f"but it is {client_name}"
            )

        super().__init__()
        self.weights_prop_name = weights_prop_name
        self.weight_method = weight_method
        self.client_name = client_name
        # None means "no weights selected yet"; reset by create_task_data.
        self.final_weights = None

    def create_task_data(self, task_name: str, fl_ctx: FLContext) -> Shareable:
        """Create the data for the task to be sent to clients to collect their weights.

        Args:
            task_name: name of the task
            fl_ctx: the FL context

        Returns:
            an empty Shareable as task data
        """
        # reset internal state in case this processor is used multiple times
        self.final_weights = None
        return Shareable()

    def _is_weight_provider(self, client_name: str) -> bool:
        """Return True if weights reported by the named client may be selected as the final weights."""
        if self.weight_method == WeightMethod.FIRST:
            return True
        return self.weight_method == WeightMethod.CLIENT and client_name == self.client_name

    def process_client_response(self, client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool:
        """Process the weights submitted by a client.

        Every response is validated, even one from a client whose weights will not be selected.

        Args:
            client: the client that submitted the response
            task_name: name of the task
            response: submitted data from the client
            fl_ctx: FLContext

        Returns:
            boolean to indicate if the client data is acceptable.
            If not acceptable, the control flow will exit.
        """
        if not isinstance(response, Shareable):
            self.log_error(
                fl_ctx,
                f"bad response from client {client.name}: " f"response must be Shareable but got {type(response)}",
            )
            return False

        try:
            dxo = from_shareable(response)
        except Exception:
            self.log_exception(fl_ctx, f"bad response from client {client.name}: " f"it does not contain DXO")
            return False

        if dxo.data_kind != DataKind.WEIGHTS:
            self.log_error(
                fl_ctx,
                f"bad response from client {client.name}: "
                f"data_kind should be DataKind.WEIGHTS but got {dxo.data_kind}",
            )
            return False

        weights = dxo.data
        if not weights:
            self.log_error(fl_ctx, f"No model weights found from client {client.name}")
            return False

        # Keep only the first acceptable weights from an eligible provider; later ones are ignored.
        if self.final_weights is None and self._is_weight_provider(client.name):
            self.final_weights = weights

        return True

    def final_process(self, fl_ctx: FLContext) -> bool:
        """Perform the final check on all the received weights from the clients.

        On success, the selected weights are wrapped via make_model_learnable and set into
        fl_ctx under ``weights_prop_name`` as a private, sticky prop.

        Args:
            fl_ctx: FLContext

        Returns:
            boolean indicating whether the final response processing is successful.
            If not successful, the control flow will exit.
        """
        if self.final_weights is None:
            if self.weight_method == WeightMethod.CLIENT:
                # Other clients may have responded; only the designated provider is missing.
                self.log_error(fl_ctx, f"no weights available from designated client '{self.client_name}'")
            else:
                self.log_error(fl_ctx, "no weights available from clients")
            return False

        # must set sticky to True so other controllers can get it!
        fl_ctx.set_prop(self.weights_prop_name, make_model_learnable(self.final_weights, {}), private=True, sticky=True)
        return True