Source code for nvflare.app_common.filters.exclude_vars

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import List, Union

from nvflare.apis.dxo import DataKind
from nvflare.apis.dxo_filter import DXO, DXOFilter
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable


[docs]class ExcludeVars(DXOFilter): def __init__(self, exclude_vars: Union[List[str], str, None] = None, data_kinds: List[str] = None): """Exclude/Remove variables from Shareable. Args: exclude_vars (Union[List[str], str, None] , optional): variables/layer names to be excluded. data_kinds: kinds of DXO object to filter Notes: Based on different types of exclude_vars, this filter has different behavior: if a list of variable/layer names, only specified variables will be excluded. if a string, it will be converted into a regular expression, only matched variables will be excluded. if not provided or other formats the Shareable remains unchanged. """ if not data_kinds: data_kinds = [DataKind.WEIGHT_DIFF, DataKind.WEIGHTS] super().__init__(supported_data_kinds=[DataKind.WEIGHTS, DataKind.WEIGHT_DIFF], data_kinds_to_filter=data_kinds) self.exclude_vars = exclude_vars self.skip = False if self.exclude_vars is not None: if not (isinstance(self.exclude_vars, list) or isinstance(self.exclude_vars, str)): self.skip = True self.logger.debug( "Need to provide a list of layer names or a string for regex matching, but got {}".format( type(self.exclude_vars) ) ) return if isinstance(self.exclude_vars, list): for var in self.exclude_vars: if not isinstance(var, str): self.skip = True self.logger.debug( "encrypt_layers needs to be a list of layer names to encrypt, but contains element of type {}".format( type(var) ) ) return self.logger.debug(f"Excluding {self.exclude_vars} from shareable") elif isinstance(self.exclude_vars, str): self.exclude_vars = re.compile(self.exclude_vars) if self.exclude_vars else None if self.exclude_vars is None: self.skip = True self.logger.debug(f'Excluding all layers based on regex matches with "{self.exclude_vars}"') else: self.logger.debug("Not excluding anything") self.skip = True
[docs] def process_dxo(self, dxo: DXO, shareable: Shareable, fl_ctx: FLContext) -> Union[None, DXO]: """Called by upper layer to remove variables in weights/weight_diff dictionary. When the return code of shareable is not ReturnCode.OK, this function will not perform any process and returns the shareable back. Args: dxo (DXO): DXO to be filtered. shareable: that the dxo belongs to fl_ctx (FLContext): only used for logging. Returns: filtered dxo """ if self.skip: return None weights = dxo.data # remove variables n_excluded = 0 var_names = list(weights.keys()) # make a copy of keys n_vars = len(var_names) for var_name in var_names: if (isinstance(self.exclude_vars, re.Pattern) and self.exclude_vars.search(var_name)) or ( isinstance(self.exclude_vars, list) and var_name in self.exclude_vars ): self.log_debug(fl_ctx, f"Excluding {var_name}") weights.pop(var_name, None) n_excluded += 1 if isinstance(self.exclude_vars, re.Pattern) and n_excluded == 0: self.log_warning(fl_ctx, f"No matching layers found with regex {self.exclude_vars}") self.log_debug(fl_ctx, f"Excluded {n_excluded} of {n_vars} variables. {len(weights.keys())} remaining.") dxo.data = weights return dxo