Source code for nvflare.lighter.impl.he

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import tenseal as ts

from nvflare.lighter.constants import ProvFileName
from nvflare.lighter.spec import Builder, Project, ProvisionContext


class HEBuilder(Builder):
    """Provisioning builder that creates one TenSEAL context and writes it to every kit.

    The context is created in ``initialize`` and serialized in ``build``: the
    server kit receives a serialization without public/secret keys, each
    client kit receives one that includes them.
    """

    def __init__(
        self,
        poly_modulus_degree=8192,
        coeff_mod_bit_sizes=None,
        scale_bits=40,
        scheme="CKKS",
    ):
        """Build Homomorphic related contents.

        Generates Tenseal homomorphic encryption context for server and client
        and writes them to server and client participant folders.

        Args:
            poly_modulus_degree: defaults to 8192.
            coeff_mod_bit_sizes: defaults to [60, 40, 40].
            scale_bits: defaults to 40.
            scheme: defaults to "CKKS". Must be a key of ``scheme_type_mapping``
                ("CKKS" or "BFV").

        Raises:
            KeyError: if ``scheme`` is not a key of ``scheme_type_mapping``.
        """
        # A falsy value (None or an empty list) selects the default sizes.
        if not coeff_mod_bit_sizes:
            coeff_mod_bit_sizes = [60, 40, 40]

        self._context = None
        self.scheme_type_mapping = {
            "CKKS": ts.SCHEME_TYPE.CKKS,
            "BFV": ts.SCHEME_TYPE.BFV,
        }
        self.poly_modulus_degree = poly_modulus_degree
        self.coeff_mod_bit_sizes = coeff_mod_bit_sizes
        self.scale_bits = scale_bits

        # Resolve the scheme name now so a bad value fails at construction
        # time, with a message that lists the accepted names.
        try:
            self.scheme_type = self.scheme_type_mapping[scheme]
        except KeyError:
            supported = ", ".join(self.scheme_type_mapping)
            raise KeyError(f"unsupported HE scheme {scheme!r}; expected one of: {supported}") from None

        self.serialized = None

    def initialize(self, project: Project, ctx: ProvisionContext):
        """Create the TenSEAL context that ``build`` later serializes.

        The context is created with symmetric encryption using the parameters
        given to the constructor; relinearization keys are then generated and
        the global scale is set to ``2**scale_bits``.

        Args:
            project: the project being provisioned (not used here).
            ctx: the provision context (not used here).
        """
        self._context = ts.context(
            self.scheme_type,
            poly_modulus_degree=self.poly_modulus_degree,
            coeff_mod_bit_sizes=self.coeff_mod_bit_sizes,
            encryption_type=ts.ENCRYPTION_TYPE.SYMMETRIC,
        )
        self._context.generate_relin_keys()
        self._context.global_scale = 2**self.scale_bits

    def build(self, project: Project, ctx: ProvisionContext):
        """Write the serialized context into the server kit and every client kit.

        Args:
            project: supplies the server (if any) and the clients.
            ctx: supplies the kit directory of each participant.
        """
        server = project.get_server()
        if server:
            # Serialize before opening so that a serialization failure does
            # not leave a truncated context file behind.
            server_blob = self.get_serialized_context()
            server_path = os.path.join(ctx.get_kit_dir(server), ProvFileName.SERVER_CONTEXT_TENSEAL)
            with open(server_path, "wb") as f:
                f.write(server_blob)

        # Every client receives the same bytes: serialize once, on the first
        # client, and reuse the result for the rest.
        client_blob = None
        for client in project.get_clients():
            if client_blob is None:
                client_blob = self.get_serialized_context(is_client=True)
            client_path = os.path.join(ctx.get_kit_dir(client), ProvFileName.CLIENT_CONTEXT_TENSEAL)
            with open(client_path, "wb") as f:
                f.write(client_blob)

    def get_serialized_context(self, is_client=False):
        """Serialize the context created by ``initialize``.

        Relinearization keys are always saved and Galois keys never are. The
        public and secret keys are saved only when ``is_client`` is true.

        Args:
            is_client: whether to include the public and secret keys.

        Returns:
            The value returned by the context's ``serialize`` call.
        """
        return self._context.serialize(
            save_public_key=is_client,
            save_secret_key=is_client,
            save_galois_keys=False,
            save_relin_keys=True,
        )