# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import time
import numpy as np
import tenseal as ts
from tenseal.tensors.ckksvector import CKKSVector
import nvflare.app_common.homomorphic_encryption.he_constant as he
from nvflare.apis.dxo import DXO, MetaKey, from_shareable
from nvflare.apis.event_type import EventType
from nvflare.apis.filter import Filter
from nvflare.apis.fl_constant import ReservedKey, ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.app_common.homomorphic_encryption.homomorphic_encrypt import (
count_encrypted_layers,
load_tenseal_context_from_workspace,
)
[docs]class HEModelEncryptor(Filter):
def __init__(
self,
tenseal_context_file="client_context.tenseal",
encrypt_layers=None,
aggregation_weights=None,
weigh_by_local_iter=True,
):
"""Filter to encrypt Shareable object using homomorphic encryption (HE) with TenSEAL https://github.com/OpenMined/TenSEAL.
Args:
tenseal_context_file: tenseal context files containing encryption keys and parameters
encrypt_layers: if not specified (None), all layers are being encrypted;
if list of variable/layer names, only specified variables are encrypted;
if string containing regular expression (e.g. "conv"), only matched variables are being encrypted.
aggregation_weights: dictionary of client aggregation `{"client1": 1.0, "client2": 2.0, "client3": 3.0}`;
defaults to a weight of 1.0 if not specified.
weigh_by_local_iter: If true, multiply client weights on first before encryption (default: `True` which is recommended for HE)
"""
super().__init__()
self.logger.info("Using HE model encryptor.")
self.tenseal_context = None
self.tenseal_context_file = tenseal_context_file
self.aggregation_weights = aggregation_weights or {}
self.logger.info(f"client weights control: {self.aggregation_weights}")
self.weigh_by_local_iter = weigh_by_local_iter
self.n_iter = None
self.client_name = None
self.aggregation_weight = None
# choose which layers to encrypt
if encrypt_layers is not None:
if not (isinstance(encrypt_layers, list) or isinstance(encrypt_layers, str)):
raise ValueError(
"Must provide a list of layer names or a string for regex matching, but got {}".format(
type(encrypt_layers)
)
)
if isinstance(encrypt_layers, list):
for encrypt_layer in encrypt_layers:
if not isinstance(encrypt_layer, str):
raise ValueError(
"encrypt_layers needs to be a list of layer names to encrypt, but found element of type {}".format(
type(encrypt_layer)
)
)
self.encrypt_layers = encrypt_layers
self.logger.info(f"Encrypting {len(encrypt_layers)} layers")
elif isinstance(encrypt_layers, str):
self.encrypt_layers = re.compile(encrypt_layers) if encrypt_layers else None
self.logger.info(f'Encrypting all layers based on regex matches with "{encrypt_layers}"')
else:
self.encrypt_layers = [True] # needs to be list for logic in encryption()
self.logger.info("Encrypting all layers")
[docs] def handle_event(self, event_type: str, fl_ctx: FLContext):
if event_type == EventType.START_RUN:
self.tenseal_context = load_tenseal_context_from_workspace(self.tenseal_context_file, fl_ctx)
elif event_type == EventType.END_RUN:
self.tenseal_context = None
[docs] def encryption(self, params, fl_ctx: FLContext):
n_params = len(params.keys())
self.log_info(fl_ctx, f"Running HE Encryption algorithm on {n_params} variables")
# parse regex encrypt layers
if isinstance(self.encrypt_layers, re.Pattern):
re_pattern = self.encrypt_layers
self.encrypt_layers = []
for var_name in params:
if re_pattern.search(var_name):
self.encrypt_layers.append(var_name)
self.log_info(fl_ctx, f"Regex found {self.encrypt_layers} matching layers.")
if len(self.encrypt_layers) == 0:
raise ValueError(f"No matching layers found with regex {re_pattern}")
start_time = time.time()
n_encrypted, n_total = 0, 0
encryption_dict = {}
vmins, vmaxs = [], []
for i, param_name in enumerate(params.keys()):
values = params[param_name].ravel()
_n = np.size(values)
n_total += _n
# weigh before encryption
if self.aggregation_weight:
values = values * np.float(self.aggregation_weight)
if self.weigh_by_local_iter:
values = values * np.float(self.n_iter)
if param_name in self.encrypt_layers or self.encrypt_layers[0] is True:
self.log_info(fl_ctx, f"Encrypting vars {i+1} of {n_params}: {param_name} with {_n} values")
vmin = np.min(params[param_name])
vmax = np.max(params[param_name])
vmins.append(vmin)
vmaxs.append(vmax)
params[param_name] = ts.ckks_vector(self.tenseal_context, values).serialize()
encryption_dict[param_name] = True
n_encrypted += _n
elif isinstance(values, CKKSVector):
self.log_error(
fl_ctx, f"{i} of {n_params}: {param_name} = {np.shape(params[param_name])} already encrypted!"
)
raise ValueError("This should not happen!")
else:
params[param_name] = values
encryption_dict[param_name] = False
end_time = time.time()
if n_encrypted == 0:
raise ValueError("Nothing has been encrypted! Check provided encrypt_layers list of layer names or regex.")
self.log_info(
fl_ctx,
f"Encryption time for {n_encrypted} of {n_total} params"
f" (encrypted value range [{np.min(vmins)}, {np.max(vmaxs)}])"
f" {end_time - start_time} seconds.",
)
# params is a dictionary. keys are layer names. values are either weights or serialized ckks_vector of weights.
# encryption_dict: keys are layer names. values are True for serialized ckks_vectors, False elsewhere.
return params, encryption_dict
[docs] def process(self, shareable: Shareable, fl_ctx: FLContext) -> Shareable:
"""Filter process apply to the Shareable object.
Args:
shareable: shareable
fl_ctx: FLContext
Returns:
a Shareable object with encrypted model weights
"""
rc = shareable.get_return_code()
if rc != ReturnCode.OK:
# don't process if RC not OK
return shareable
dxo = from_shareable(shareable)
if self.aggregation_weights:
self.client_name = shareable.get_peer_prop(ReservedKey.IDENTITY_NAME, default="?")
self.aggregation_weight = self.aggregation_weights.get(self.client_name, 1.0)
self.log_info(fl_ctx, f"weighting {self.client_name} by aggregation weight {self.aggregation_weight}")
if self.weigh_by_local_iter:
self.n_iter = dxo.get_meta_prop(MetaKey.NUM_STEPS_CURRENT_ROUND, None)
if self.n_iter is None:
raise ValueError("DXO data does not have local iterations for weighting!")
self.log_info(fl_ctx, f"weighting by local iter before encryption with {self.n_iter}")
try:
new_dxo = self._process(dxo, fl_ctx)
new_dxo.update_shareable(shareable)
except BaseException as e:
self.log_exception(fl_ctx, f"Exception occurred: {e}")
return shareable
def _process(self, dxo: DXO, fl_ctx: FLContext) -> DXO:
self.log_info(fl_ctx, "Running HE encryption...")
encrypted_params, encryption_dict = self.encryption(params=dxo.data, fl_ctx=fl_ctx)
new_dxo = DXO(data_kind=dxo.data_kind, data=encrypted_params, meta=dxo.meta)
new_dxo.set_meta_prop(key=MetaKey.PROCESSED_KEYS, value=encryption_dict)
new_dxo.set_meta_prop(key=MetaKey.PROCESSED_ALGORITHM, value=he.HE_ALGORITHM_CKKS)
n_encrypted, n_total = count_encrypted_layers(encryption_dict)
self.log_info(fl_ctx, f"{n_encrypted} of {n_total} layers encrypted")
return new_dxo