Source code for nvflare.app_common.filters.convert_weights

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

from nvflare.apis.dxo import DXO, DataKind, MetaKey, from_shareable
from nvflare.apis.dxo_filter import DXOFilter
from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable


[docs]class ConvertWeights(DXOFilter): WEIGHTS_TO_DIFF = "weights_to_diff" DIFF_TO_WEIGHTS = "diff_to_weights" def __init__(self, direction: str): """Convert WEIGHTS to WEIGHT_DIFF or vice versa. Args: direction (str): control conversion direction. Either weights_to_diff or diff_to_weights. Raises: ValueError: when the direction string is neither weights_to_diff nor diff_to_weights """ DXOFilter.__init__( self, supported_data_kinds=[DataKind.WEIGHT_DIFF, DataKind.WEIGHTS], data_kinds_to_filter=None ) if direction not in (self.WEIGHTS_TO_DIFF, self.DIFF_TO_WEIGHTS): raise ValueError( f"invalid convert direction {direction}: must be in {(self.WEIGHTS_TO_DIFF, self.DIFF_TO_WEIGHTS)}" ) self.direction = direction def _get_base_weights(self, fl_ctx: FLContext): task_data = fl_ctx.get_prop(FLContextKey.TASK_DATA, None) if not isinstance(task_data, Shareable): self.log_error(fl_ctx, f"invalid task data: expect Shareable but got {type(task_data)}") return None try: dxo = from_shareable(task_data) except ValueError: self.log_error(fl_ctx, "invalid task data: no DXO") return None if dxo.data_kind != DataKind.WEIGHTS: self.log_info(fl_ctx, f"ignored task: expect data to be WEIGHTS but got {dxo.data_kind}") return None processed_algo = dxo.get_meta_prop(MetaKey.PROCESSED_ALGORITHM, None) if processed_algo: self.log_info(fl_ctx, f"ignored task since its processed by {processed_algo}") return None return dxo.data
[docs] def process_dxo(self, dxo: DXO, shareable: Shareable, fl_ctx: FLContext) -> Union[None, DXO]: """Called by runners to perform weight conversion. Args: dxo (DXO): dxo to be processed. shareable: the shareable that the dxo belongs to fl_ctx (FLContext): this context must include TASK_DATA, which is another shareable containing base weights. If not, the input shareable will be returned. Returns: filtered result """ base_weights = self._get_base_weights(fl_ctx) if not base_weights: return None processed_algo = dxo.get_meta_prop(MetaKey.PROCESSED_ALGORITHM, None) if processed_algo: self.log_info(fl_ctx, f"cannot process task result since its processed by {processed_algo}") return None if self.direction == self.WEIGHTS_TO_DIFF: if dxo.data_kind != DataKind.WEIGHTS: self.log_warning(fl_ctx, f"cannot process task result: expect WEIGHTS but got {dxo.data_kind}") return None new_weights = dxo.data for k, _ in new_weights.items(): if k in base_weights: new_weights[k] -= base_weights[k] dxo.data_kind = DataKind.WEIGHT_DIFF else: # diff to weights if dxo.data_kind != DataKind.WEIGHT_DIFF: self.log_warning(fl_ctx, f"cannot process task result: expect WEIGHT_DIFF but got {dxo.data_kind}") return None new_weights = dxo.data for k, _ in new_weights.items(): if k in base_weights: new_weights[k] += base_weights[k] dxo.data_kind = DataKind.WEIGHTS return dxo