# Source code for nvflare.app_common.homomorphic_encryption.he_model_decryptor

# Copyright (c) 2021-2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time

import numpy as np
import tenseal as ts
from tenseal.tensors.ckksvector import CKKSVector

import nvflare.app_common.homomorphic_encryption.he_constant as he
from nvflare.apis.dxo import MetaKey, from_shareable
from nvflare.apis.event_type import EventType
from nvflare.apis.filter import Filter
from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.app_common.homomorphic_encryption.homomorphic_encrypt import (
    count_encrypted_layers,
    load_tenseal_context_from_workspace,
)


class HEModelDecryptor(Filter):
    def __init__(self, tenseal_context_file="client_context.tenseal"):
        """Filter to decrypt Shareable object using homomorphic encryption (HE) with TenSEAL https://github.com/OpenMined/TenSEAL.

        Args:
            tenseal_context_file: tenseal context files containing decryption keys and parameters
        """
        super().__init__()
        self.logger.info("Using HE model decryptor.")
        # Populated on START_RUN and cleared on END_RUN (see handle_event).
        self.tenseal_context = None
        self.tenseal_context_file = tenseal_context_file

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        """Load the TenSEAL context when a run starts and release it when the run ends.

        Args:
            event_type: the event being fired.
            fl_ctx: FLContext, passed to the workspace loader to locate the context file.
        """
        if event_type == EventType.START_RUN:
            self.tenseal_context = load_tenseal_context_from_workspace(self.tenseal_context_file, fl_ctx)
        elif event_type == EventType.END_RUN:
            self.tenseal_context = None

    def decryption(self, params, encrypted_layers, fl_ctx: FLContext):
        """Decrypt, in place, every entry of ``params`` flagged as encrypted.

        Args:
            params: dict from parameter name to value; entries flagged in
                ``encrypted_layers`` must be ``CKKSVector`` instances.
            encrypted_layers: mapping from parameter name to a truthy flag telling
                whether that parameter is encrypted.
            fl_ctx: FLContext used for logging.

        Returns:
            The same ``params`` dict, with each flagged entry replaced by the
            result of ``CKKSVector.decrypt()``. Unflagged entries are left as-is.

        Raises:
            ValueError: if ``encrypted_layers`` is None, or if a flagged entry is
                not a ``CKKSVector``.
        """
        n_params = len(params.keys())
        self.log_info(fl_ctx, f"Running HE Decryption algorithm {n_params} variables")
        if encrypted_layers is None:
            raise ValueError("encrypted_layers is None!")
        start_time = time.time()

        n_decrypted, n_total = 0, 0
        for i, param_name in enumerate(params.keys()):
            if not encrypted_layers[param_name]:
                # Plain-text parameter: nothing to decrypt.
                continue
            values = params[param_name]
            # Check the type BEFORE calling ``size()``: numpy arrays expose ``size``
            # as an int attribute, so calling it first would raise an opaque
            # TypeError instead of the intended "should be encrypted" error.
            if not isinstance(values, CKKSVector):
                self.log_info(
                    fl_ctx,
                    f"{i+1} of {n_params}: {param_name} = {np.shape(values)} already decrypted (RAW)!",
                )
                raise ValueError("Should be encrypted at this point!")
            _n = values.size()
            n_total += _n
            self.log_info(fl_ctx, f"Decrypting vars {i+1} of {n_params}: {param_name} with {_n} values")
            # Replacing the value of an existing key is safe while iterating keys.
            params[param_name] = values.decrypt()
            n_decrypted += _n

        end_time = time.time()
        self.log_info(fl_ctx, f"Decryption time for {n_decrypted} of {n_total} params {end_time - start_time} seconds.")
        return params

    def to_ckks_vector(self, params, encrypted_layers, fl_ctx: FLContext):
        """Convert encrypted arrays to CKKS vector.

        Entries flagged in ``encrypted_layers`` are passed to
        ``ts.ckks_vector_from`` together with the loaded TenSEAL context
        (presumably they hold serialized CKKS vectors — confirm against the
        encryptor). All other entries are copied through unchanged.

        Args:
            params: dict from parameter name to value.
            encrypted_layers: mapping from parameter name to a truthy flag telling
                whether that parameter is encrypted.
            fl_ctx: FLContext used for logging.

        Returns:
            A new dict with the same keys as ``params``.

        Raises:
            ValueError: if ``encrypted_layers`` is None, or if an encrypted entry
                has more than one element according to ``np.size``.
        """
        if encrypted_layers is None:
            raise ValueError("encrypted_layers is None!")
        start_time = time.time()

        result = {}
        n_total = 0
        self.log_info(fl_ctx, f"params {len(params)} {type(params)}")
        for v, ndarray in params.items():
            if encrypted_layers[v]:
                # An encrypted entry is expected to be a single opaque object, not an array of values.
                if np.size(ndarray) > 1:
                    raise ValueError(f"size of {v} should not be larger 1 but is {np.size(ndarray)}!")
                result[v] = ts.ckks_vector_from(self.tenseal_context, ndarray)
                n = result[v].size()
            else:
                result[v] = ndarray
                n = np.size(ndarray)
            n_total += n

        end_time = time.time()
        self.log_info(fl_ctx, f"to_ckks_vector time for {n_total} values: {end_time - start_time} seconds.")
        return result

    def process(self, shareable: Shareable, fl_ctx: FLContext) -> Shareable:
        """Filter process apply to the Shareable object.

        Shareables whose return code is not ``ReturnCode.OK`` are returned
        untouched.

        Args:
            shareable: shareable
            fl_ctx: FLContext

        Returns:
            a Shareable object with decrypted model weights

        Raises:
            ValueError: wrapping any error raised during decryption; the original
                exception is kept as ``__cause__``.
        """
        rc = shareable.get_return_code()
        if rc != ReturnCode.OK:
            # don't process if RC not OK
            return shareable

        try:
            return self._process(shareable, fl_ctx)
        except Exception as e:
            # Catch Exception (not BaseException) so KeyboardInterrupt/SystemExit
            # still propagate, and chain the cause so the root traceback survives.
            self.log_exception(fl_ctx, "error performing HE decryption")
            raise ValueError(f"HEModelDecryptor Exception {e}") from e

    def _process(self, shareable: Shareable, fl_ctx: FLContext):
        """Decrypt the DXO carried by ``shareable`` and write it back in place.

        The shareable is returned unchanged if its DXO has no PROCESSED_KEYS
        meta property, or if its PROCESSED_ALGORITHM is not CKKS. Otherwise the
        DXO data is decrypted, the two processing meta properties are removed,
        and the DXO is written back into ``shareable``.
        """
        self.log_info(fl_ctx, "Running decryption...")
        dxo = from_shareable(shareable)

        encrypted_layers = dxo.get_meta_prop(key=MetaKey.PROCESSED_KEYS, default=None)
        if not encrypted_layers:
            self.log_warning(fl_ctx, "dxo does not contain PROCESSED_KEYS (do nothing)")
            return shareable

        encrypted_algo = dxo.get_meta_prop(key=MetaKey.PROCESSED_ALGORITHM, default=None)
        if encrypted_algo != he.HE_ALGORITHM_CKKS:
            self.log_error(fl_ctx, "shareable is not HE CKKS encrypted")
            return shareable

        n_encrypted, n_total = count_encrypted_layers(encrypted_layers)
        self.log_info(fl_ctx, f"{n_encrypted} of {n_total} layers encrypted")

        ckks_params = self.to_ckks_vector(params=dxo.data, encrypted_layers=encrypted_layers, fl_ctx=fl_ctx)
        decrypted_params = self.decryption(
            params=ckks_params,
            encrypted_layers=encrypted_layers,
            fl_ctx=fl_ctx,
        )

        dxo.data = decrypted_params
        # The payload is plain text now, so drop the markers describing its encryption.
        dxo.remove_meta_props([MetaKey.PROCESSED_ALGORITHM, MetaKey.PROCESSED_KEYS])
        dxo.update_shareable(shareable)

        return shareable