Source code for nvflare.app_common.workflows.broadcast_and_process

# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

from nvflare.apis.client import Client
from nvflare.apis.controller_spec import ClientTask, Task
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import Controller
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal
from nvflare.app_common.abstract.response_processor import ResponseProcessor


class BroadcastAndProcess(Controller):
    def __init__(
        self,
        processor: Union[str, ResponseProcessor],
        task_name: str,
        min_responses_required: int = 0,
        wait_time_after_min_received: int = 10,
        timeout: int = 0,
        clients=None,
    ):
        """This controller broadcasts a task to specified clients to collect responses, and uses the
        ResponseProcessor object to process the client responses.

        Args:
            processor: the processor that implements logic for client responses and final check.
                It must be a component id (str), or a ResponseProcessor object.
            task_name: name of the task to be sent to clients to collect responses
            min_responses_required: minimum number of responses required from clients. 0 means all.
            wait_time_after_min_received: how long to wait after min responses are received from clients
            timeout: timeout of the task. 0 means never time out
            clients: list of clients to send the task to. None means all clients.

        Raises:
            TypeError: if processor is neither a str nor a ResponseProcessor.
        """
        Controller.__init__(self)
        if not isinstance(processor, (str, ResponseProcessor)):
            raise TypeError(f"value of processor must be a str or ResponseProcessor but got {type(processor)}")

        # May still be a component id (str) here; start_controller resolves it to the actual object.
        self.processor = processor
        self.task_name = task_name
        self.min_responses_required = min_responses_required
        self.wait_time_after_min_received = wait_time_after_min_received
        self.timeout = timeout
        self.clients = clients

    def start_controller(self, fl_ctx: FLContext) -> None:
        """Resolve the processor if it was given as a component id.

        When ``self.processor`` is a str, it is looked up via ``engine.get_component``. The controller
        panics (via ``system_panic``) if no engine is available or the looked-up component is not a
        ResponseProcessor.
        """
        self.log_info(fl_ctx, "Initializing BroadcastAndProcess.")
        if not isinstance(self.processor, str):
            return

        checker_id = self.processor

        # the processor is a component id - get the processor component
        engine = fl_ctx.get_engine()
        if not engine:
            self.system_panic("Engine not found. BroadcastAndProcess exiting.", fl_ctx)
            return

        self.processor = engine.get_component(checker_id)
        if not isinstance(self.processor, ResponseProcessor):
            self.system_panic(
                f"component {checker_id} must be a ResponseProcessor type object but got {type(self.processor)}",
                fl_ctx,
            )

    def control_flow(self, abort_signal: Signal, fl_ctx: FLContext) -> None:
        """Broadcast the task, wait for client responses, then run the processor's final check.

        Task data comes from ``processor.create_task_data``; each client result is handed to
        ``_process_client_response`` as it arrives. After ``broadcast_and_wait`` returns,
        ``processor.final_process`` is called and a falsy result triggers ``system_panic``.
        """
        # start_controller leaves self.processor unresolved (still a str, or a non-ResponseProcessor
        # component) when it fails. Its system_panic call does not guarantee this method is skipped,
        # so guard here instead of failing with an AttributeError on the processor calls below.
        if not isinstance(self.processor, ResponseProcessor):
            self.system_panic(
                f"BroadcastAndProcess cannot run: processor must be a ResponseProcessor but got {type(self.processor)}",
                fl_ctx,
            )
            return

        task_data = self.processor.create_task_data(self.task_name, fl_ctx)
        if not isinstance(task_data, Shareable):
            self.system_panic(
                f"ResponseProcessor {type(self.processor)} failed to return valid task data: "
                f"expect Shareable but got {type(task_data)}",
                fl_ctx,
            )
            return

        task = Task(
            name=self.task_name,
            data=task_data,
            timeout=self.timeout,
            result_received_cb=self._process_client_response,
        )

        self.broadcast_and_wait(
            task=task,
            wait_time_after_min_received=self.wait_time_after_min_received,
            fl_ctx=fl_ctx,
            abort_signal=abort_signal,
            targets=self.clients,
            min_responses=self.min_responses_required,
        )

        success = self.processor.final_process(fl_ctx)
        if not success:
            self.system_panic(reason=f"ResponseProcessor {type(self.processor)} failed final check!", fl_ctx=fl_ctx)

    def _process_client_response(self, client_task: ClientTask, fl_ctx: FLContext) -> None:
        """Result callback: pass one client's response to the processor.

        Panics if ``processor.process_client_response`` returns a falsy value.
        """
        task = client_task.task
        response = client_task.result
        client = client_task.client

        ok = self.processor.process_client_response(
            client=client, task_name=task.name, response=response, fl_ctx=fl_ctx
        )

        # Cleanup task result: the processor has consumed it, so drop the reference.
        client_task.result = None

        if not ok:
            self.system_panic(
                reason=f"ResponseProcessor {type(self.processor)} failed to check client {client.name}", fl_ctx=fl_ctx
            )

    def stop_controller(self, fl_ctx: FLContext):
        """No resources to release; intentionally a no-op."""
        pass

    def process_result_of_unknown_task(
        self, client: Client, task_name: str, client_task_id: str, result: Shareable, fl_ctx: FLContext
    ):
        """Handle a result for a task this controller does not recognize.

        Such results are dropped. A warning is logged so they do not disappear silently.
        """
        self.log_warning(
            fl_ctx,
            f"ignoring result of unknown task '{task_name}' (client_task_id={client_task_id}) "
            f"from client {client.name}",
        )