Source code for nvflare.app_common.psi.dh_psi.dh_psi_workflow

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, NamedTuple, Set

from nvflare.apis.dxo import DXO
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal
from nvflare.app_common.app_constant import PSIConst
from nvflare.app_common.psi.psi_workflow_spec import PSIWorkflow
from nvflare.app_common.workflows.broadcast_operator import BroadcastAndWait
from nvflare.utils.decorators import measure_time


class SiteSize(NamedTuple):
    """A participating site and the number of items it contributes to the PSI computation."""

    # site (client) name as reported in the task results
    name: str
    # item count for the site; after a forward-pass round this is the site's intersection size
    size: int
class DhPSIWorkFlow(PSIWorkflow):
    """DH-based private set intersection (PSI) workflow, driven from the controller side.

    Phases:

    * ``pre_process``: broadcast a prepare task to all clients, collect their item counts
      and keep the sites that have items, ordered by ascending count.
    * ``run``:
      - a *forward pass* pairs sites up and intersects them round after round until a
        single site holds the intersection of all sites;
      - a *backward pass* in which that site exchanges setup/request/response messages
        with every other site so each of them can compute its intersection;
      - finally, all backward-pass intersection sizes must be equal.
    """

    def __init__(self, bloom_filter_fpr: float = 1e-11):
        """
        Args:
            bloom_filter_fpr: value sent to the clients under ``PSIConst.BLOOM_FILTER_FPR``
                in the prepare task.
        """
        super().__init__()
        self.task_name = PSIConst.TASK
        self.bloom_filter_fpr: float = bloom_filter_fpr
        # NOTE(review): not read anywhere in this class; kept for backward compatibility
        self.wait_time_after_min_received = 0
        self.abort_signal = None
        self.fl_ctx = None
        self.controller = None
        # sites that reported at least one item, ordered by ascending item count
        self.ordered_sites: List[SiteSize] = []
        # site name -> intersection size computed during the forward pass
        self.forward_processed: Dict[str, int] = {}
        # site name -> intersection size computed during the backward pass
        self.backward_processed: Dict[str, int] = {}

    def initialize(self, fl_ctx: FLContext, **kwargs):
        """Store the FL context and the controller (``kwargs["controller"]``) used to send tasks."""
        self.fl_ctx = fl_ctx
        self.controller = kwargs["controller"]

    def pre_process(self, abort_signal: Signal) -> bool:
        """Ask every client for its item count and order the sites by ascending size.

        Returns:
            False if the job is already aborted, True once ``ordered_sites`` is populated.

        Raises:
            RuntimeError: if no client reported any item.
        """
        self.log_info(self.fl_ctx, f"pre_process on task {self.task_name}")
        if abort_signal.triggered:
            return False
        self.abort_signal = abort_signal
        self.prepare_sites(abort_signal)
        return True

    def run(self, abort_signal: Signal) -> bool:
        """Run the forward pass, then the backward pass, and validate the intersection sizes.

        Returns:
            False if the job is aborted before or during either pass, True when the final
            intersection sizes were validated.

        Raises:
            RuntimeError: if the forward-pass or backward-pass sizes are inconsistent.
        """
        if abort_signal.triggered:
            return False
        self.abort_signal = abort_signal
        self.log_info(self.fl_ctx, f"order sites = {self.ordered_sites}")

        intersect_site = self.forward_pass(self.ordered_sites, self.forward_processed)
        # on abort, forward_pass returns a placeholder site; validating it would be meaningless
        if abort_signal.triggered:
            return False
        self.log_info(
            self.fl_ctx,
            f"forward_processed sites {self.forward_processed},\n"
            f"intersect_sites={intersect_site}\n"
            f"ordered sites = {self.ordered_sites}\n",
        )
        self.check_processed_sites(intersect_site, self.forward_processed)

        self.backward_processed.update(self.backward_pass(self.ordered_sites, intersect_site))
        # an aborted backward pass may return empty/partial results, which would otherwise
        # be reported as a successful intersection
        if abort_signal.triggered:
            return False
        self.log_info(
            self.fl_ctx,
            f"backward_processed sites {self.backward_processed},\n"
            f"intersect_sites={intersect_site}\n"
            f"ordered sites = {self.ordered_sites}\n",
        )
        self.check_final_intersection_sizes(intersect_site)
        self.log_pass_time_taken()
        return True

    def check_processed_sites(self, last_site: SiteSize, processed_sites: Dict[str, int]):
        """Raise if any forward-pass intersection is smaller than the final site's size.

        Raises:
            RuntimeError: when some value in ``processed_sites`` is below ``last_site.size``.
        """
        valid = all(value >= last_site.size for value in processed_sites.values())
        if not valid:
            raise RuntimeError(
                f"Intersection calculation failed:\n"
                f"processed sites :{processed_sites},\n"
                f"last_site ={last_site} \n"
                f"ordered sites = {self.ordered_sites} \n"
            )

    def check_final_intersection_sizes(self, intersect_site: SiteSize):
        """Raise unless every backward-pass intersection size equals ``intersect_site.size``.

        Raises:
            RuntimeError: when any site's backward-pass size differs.
        """
        all_equal = all(value == intersect_site.size for value in self.backward_processed.values())
        if not all_equal:
            raise RuntimeError(
                f"Intersection calculation failed: the intersection sizes from all sites must be equal.\n"
                f"backward processed sites:{self.backward_processed},\n"
                f"intersect sites ={intersect_site} \n"
                f"ordered sites = {self.ordered_sites} \n"
            )
        self.log_info(self.fl_ctx, "Intersection calculation succeed")

    def log_pass_time_taken(self):
        """Log the ``time_taken`` recorded by ``measure_time`` for both passes."""
        self.log_info(self.fl_ctx, f"'forward_pass' took {self.forward_pass.time_taken} ms.")
        self.log_info(self.fl_ctx, f"'backward_pass' took {self.backward_pass.time_taken} ms.")

    def post_process(self, abort_signal: Signal):
        """No post-processing is needed for this workflow."""
        pass

    def finalize(self, fl_ctx: FLContext):
        """Nothing to release for this workflow."""
        pass

    @staticmethod
    def get_ordered_sites(results: Dict[str, DXO]) -> List[SiteSize]:
        """Return the sites that reported items, sorted by ascending item count.

        A site whose result data lacks ``PSIConst.ITEMS_SIZE``, or reports 0 items, is left out.
        NOTE(review): an item-less site is excluded instead of forcing an empty intersection —
        confirm this is intended.
        """
        site_sizes = []
        for site_name in results:
            data = results[site_name].data
            size = data[PSIConst.ITEMS_SIZE] if PSIConst.ITEMS_SIZE in data else 0
            if size > 0:
                site_sizes.append(SiteSize(site_name, size))
        # list.sort is stable: equally sized sites keep their result order
        site_sizes.sort(key=lambda site: site.size)
        return site_sizes

    @measure_time
    def forward_pass(self, ordered_sites: List[SiteSize], processed: Dict[str, int]) -> SiteSize:
        """Intersect sites pairwise, round after round, until one site holds the overall intersection.

        Args:
            ordered_sites: participating sites, ordered by ascending item count (must be non-empty).
            processed: updated in place with ``site name -> intersection size``.

        Returns:
            The site holding the final intersection; ``ordered_sites[0]`` when aborted or
            when there is only one site.
        """
        if self.abort_signal.triggered or len(ordered_sites) <= 1:
            return ordered_sites[0]
        return self.parallel_forward_pass(ordered_sites, processed)

    @staticmethod
    def _site_pairs(ordered_sites: List[SiteSize]) -> list:
        """Pair site ``i`` with site ``i + n`` (``n = len // 2``); with an odd count the last site is unpaired."""
        n = len(ordered_sites) // 2
        return list(zip(ordered_sites[:n], ordered_sites[n : 2 * n]))

    @staticmethod
    def _task(task_key, payload_key, payload) -> Shareable:
        """Build a task Shareable holding ``PSIConst.TASK_KEY`` plus a single payload entry."""
        inputs = Shareable()
        inputs[PSIConst.TASK_KEY] = task_key
        inputs[payload_key] = payload
        return inputs

    def _multicast_and_collect(self, task_inputs: Dict[str, Shareable], result_key) -> dict:
        """Send each named site its own task, wait, and return ``site name -> data[result_key]`` per reply."""
        bop = BroadcastAndWait(self.fl_ctx, self.controller)
        results = bop.multicasts_and_wait(
            task_name=self.task_name, task_inputs=task_inputs, fl_ctx=self.fl_ctx, abort_signal=self.abort_signal
        )
        return {site_name: results[site_name].data[result_key] for site_name in results}

    def _ask_site(self, site: SiteSize, inputs: Shareable, result_key):
        """Send ``inputs`` to a single site, wait, and return ``data[result_key]`` from its reply."""
        bop = BroadcastAndWait(self.fl_ctx, self.controller)
        results = bop.broadcast_and_wait(
            task_name=self.task_name,
            task_input=inputs,
            fl_ctx=self.fl_ctx,
            targets=[site.name],
            abort_signal=self.abort_signal,
        )
        return results[site.name].data[result_key]

    def pairwise_setup(self, ordered_sites: List[SiteSize]):
        """Ask the first site of each pair for a setup message sized for its partner.

        Returns:
            ``first-site name -> setup message``.
        """
        task_inputs = {
            s.name: self._task(PSIConst.TASK_SETUP, PSIConst.ITEMS_SIZE, c.size)
            for s, c in self._site_pairs(ordered_sites)
        }
        return self._multicast_and_collect(task_inputs, PSIConst.SETUP_MSG)

    def pairwise_requests(self, ordered_sites: List[SiteSize], setup_msgs: Dict[str, str]):
        """Hand each pair's setup message to the second site and collect its request.

        Returns:
            ``second-site name -> request message``.
        """
        task_inputs = {
            c.name: self._task(PSIConst.TASK_REQUEST, PSIConst.SETUP_MSG, setup_msgs[s.name])
            for s, c in self._site_pairs(ordered_sites)
        }
        return self._multicast_and_collect(task_inputs, PSIConst.REQUEST_MSG)

    def pairwise_responses(self, ordered_sites: List[SiteSize], request_msgs: Dict[str, str]):
        """Hand each second site's request to the first site of its pair and collect the response.

        Returns:
            ``first-site name -> response message``.
        """
        task_inputs = {
            s.name: self._task(PSIConst.TASK_RESPONSE, PSIConst.REQUEST_MSG, request_msgs[c.name])
            for s, c in self._site_pairs(ordered_sites)
        }
        return self._multicast_and_collect(task_inputs, PSIConst.RESPONSE_MSG)

    def pairwise_intersect(self, ordered_sites: List[SiteSize], response_msg: Dict[str, str]):
        """Hand each first site's response to the second site, which computes the pair intersection.

        Returns:
            ``second-site name -> intersection size``.
        """
        task_inputs = {
            c.name: self._task(PSIConst.TASK_INTERSECT, PSIConst.RESPONSE_MSG, response_msg[s.name])
            for s, c in self._site_pairs(ordered_sites)
        }
        return self._multicast_and_collect(task_inputs, PSIConst.ITEMS_SIZE)

    def parallel_forward_pass(self, target_sites, processed: dict):
        """Run one pairwise round over ``target_sites`` and recurse on the survivors.

        Args:
            target_sites: non-empty list of sites for this round.
            processed: updated in place with ``site name -> intersection size``.

        Returns:
            The single remaining site; ``target_sites[0]`` if the job was aborted during a round.
        """
        self.log_info(self.fl_ctx, f"target_sites: {target_sites}")
        total_sites = len(target_sites)
        if total_sites < 2:
            final_site = target_sites[0]
            processed.update({final_site.name: final_site.size})
            return final_site

        setup_msgs = self.pairwise_setup(target_sites)
        request_msgs = self.pairwise_requests(target_sites, setup_msgs)
        response_msgs = self.pairwise_responses(target_sites, request_msgs)
        it_sites = self.pairwise_intersect(target_sites, response_msgs)
        processed.update(it_sites)
        if self.abort_signal.triggered:
            # do not start another round after an abort
            return target_sites[0]

        # the second site of each pair now holds the pair's intersection and advances
        new_targets = [SiteSize(site.name, it_sites[site.name]) for site in target_sites if site.name in it_sites]
        if total_sites % 2 == 1:
            # the unpaired last site advances unchanged
            new_targets.append(target_sites[total_sites - 1])
        return self.parallel_forward_pass(new_targets, processed)

    @measure_time
    def backward_pass(self, ordered_clients: list, intersect_site: SiteSize) -> dict:
        """Let every other site compute its intersection with ``intersect_site``.

        Returns:
            ``site name -> intersection size``; empty when aborted or with at most one site.
        """
        processed = {}
        if self.abort_signal.triggered:
            return processed
        if len(ordered_clients) <= 1:
            return processed
        status = self.parallel_backward_pass(ordered_clients, intersect_site)
        time_taken = self.parallel_backward_pass.time_taken
        self.log_info(self.fl_ctx, f"parallel_back_pass took {time_taken} (ms)")
        return status

    @measure_time
    def parallel_backward_pass(self, ordered_clients: list, intersect_site: SiteSize):
        """Run the setup/request/response/intersect exchange between ``intersect_site`` and all other sites."""
        # other sites, with their forward-pass intersection size where one was recorded
        other_sites = self.get_updated_site_sizes(
            [site for site in ordered_clients if site.name != intersect_site.name]
        )
        # one setup message per distinct size, keyed by str(size); equally sized sites share it
        setup_msgs: Dict[str, str] = self.prepare_setup_messages(
            intersect_site, {site.size for site in other_sites}
        )
        site_setup_msgs = {site.name: setup_msgs[str(site.size)] for site in other_sites}
        request_msgs: Dict[str, str] = self.create_requests(site_setup_msgs)
        response_msgs: Dict[str, str] = self.process_requests(intersect_site, request_msgs)
        return self.calculate_intersections(response_msgs)

    def calculate_intersections(self, response_msg) -> Dict[str, int]:
        """Send each site its response message and collect ``site name -> intersection size``."""
        task_inputs = {
            client_name: self._task(PSIConst.TASK_INTERSECT, PSIConst.RESPONSE_MSG, response_msg[client_name])
            for client_name in response_msg
        }
        intersects = self._multicast_and_collect(task_inputs, PSIConst.ITEMS_SIZE)
        self.log_info(self.fl_ctx, f"received intersections : {intersects} ")
        return intersects

    def process_requests(self, s: SiteSize, request_msgs: Dict[str, str]) -> Dict[str, str]:
        """Send all request messages to site ``s`` and return its per-site response messages."""
        inputs = self._task(PSIConst.TASK_RESPONSE, PSIConst.REQUEST_MSG_SET, request_msgs)
        return self._ask_site(s, inputs, PSIConst.RESPONSE_MSG)

    def create_requests(self, site_setup_msgs) -> Dict[str, str]:
        """Send each site its setup message and collect ``site name -> request message``."""
        task_inputs = {
            client_name: self._task(PSIConst.TASK_REQUEST, PSIConst.SETUP_MSG, site_setup_msgs[client_name])
            for client_name in site_setup_msgs
        }
        return self._multicast_and_collect(task_inputs, PSIConst.REQUEST_MSG)

    def get_updated_site_sizes(self, ordered_sites):
        """Replace each site's size with its forward-pass intersection size when one was recorded."""
        return [SiteSize(site.name, self.forward_processed.get(site.name, site.size)) for site in ordered_sites]

    def prepare_sites(self, abort_signal):
        """Broadcast the prepare task and store the participating sites in ``self.ordered_sites``.

        Raises:
            RuntimeError: (after triggering ``abort_signal``) when no result was returned or no
                site reported any item.
        """
        inputs = self._task(PSIConst.TASK_PREPARE, PSIConst.BLOOM_FILTER_FPR, self.bloom_filter_fpr)
        # require a response from every client registered with the engine
        min_responses = len(self.fl_ctx.get_engine().get_clients())
        bop = BroadcastAndWait(self.fl_ctx, self.controller)
        results = bop.broadcast_and_wait(
            task_name=self.task_name,
            task_input=inputs,
            fl_ctx=self.fl_ctx,
            targets=None,  # presumably "all clients" — see BroadcastAndWait
            min_responses=min_responses,
            abort_signal=abort_signal,
        )
        self.log_info(self.fl_ctx, f"{PSIConst.TASK_PREPARE} results = {results}")
        ordered_sites = self.get_ordered_sites(results) if results else []
        # an empty list would make forward_pass fail on ordered_sites[0]
        if not ordered_sites:
            abort_signal.trigger("no items to perform PSI")
            raise RuntimeError("There is no item to perform PSI calculation")
        self.ordered_sites = ordered_sites

    def prepare_setup_messages(self, s: SiteSize, other_site_sizes: Set[int]) -> Dict[str, str]:
        """Ask site ``s`` for one setup message per size in ``other_site_sizes`` (keyed by ``str(size)``)."""
        inputs = self._task(PSIConst.TASK_SETUP, PSIConst.ITEMS_SIZE_SET, other_site_sizes)
        return self._ask_site(s, inputs, PSIConst.SETUP_MSG)