# Source code for nvflare.app_opt.psi.dh_psi.dh_psi_server

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

# version >=1.0.3
import private_set_intersection.python as psi


class PSIServer:
    """Server side of a two-party (client/server) private set intersection (PSI).

    Wraps a ``private_set_intersection`` server object. The server holds its own
    item set, builds a setup message from those items for the client, and answers
    the client's request so that the client can compute the intersection.
    """

    def __init__(self, items: List[str], fpr: float = 1e-9):
        """
        Args:
            items: the items provided by the server; must be non-empty.
            fpr: the false-positive rate passed to the setup message. Note: if the
                fpr is very small, such as 1e-11, the PSI algorithm can fail due to
                a known bug (https://github.com/OpenMined/PSI/issues/143).

        Raises:
            ValueError: if ``items`` is empty.
        """
        # len() rather than truthiness so sequence types with ambiguous truth
        # values (e.g. arrays) are still accepted.
        if len(items) == 0:
            raise ValueError("Server items cannot be empty")
        # Passed to CreateWithNewKey; presumably lets the client learn the
        # intersection elements rather than only its size — confirm in PSI docs.
        self.reveal_intersection = True
        self.psi_server = psi.server.CreateWithNewKey(self.reveal_intersection)
        self.items = items
        self.fpr = fpr

    def setup(self, client_items_size: int) -> bytes:
        """Build and serialize the PSI setup message for the client.

        The setup message is created from the server's items using a Bloom filter.

        Args:
            client_items_size: the number of items the client holds.

        Returns:
            The serialized ``ServerSetup`` protobuf message.
        """
        # API form used here requires private_set_intersection version >= 1.0.3.
        setup_msg = self.psi_server.CreateSetupMessage(
            self.fpr, client_items_size, self.items, psi.DataStructure.BLOOM_FILTER
        )
        return setup_msg.SerializeToString()

    def process_request(self, client_request_msg: bytes) -> bytes:
        """Answer a client request so the client can compute the intersection.

        Args:
            client_request_msg: the client's serialized ``Request`` protobuf message.

        Returns:
            The serialized ``Response`` protobuf message.
        """
        request = psi.Request()
        request.ParseFromString(client_request_msg)
        response = self.psi_server.ProcessRequest(request)
        return response.SerializeToString()