Source code for nvflare.app_common.hub.hub_controller

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import time
from typing import Union

from nvflare.apis.client import Client
from nvflare.apis.controller_spec import (
    ClientTask,
    ControllerSpec,
    OperatorConfigKey,
    OperatorMethod,
    Task,
    TaskOperatorKey,
)
from nvflare.apis.event_type import EventType
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_constant import FLContextKey, ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import Controller
from nvflare.apis.operator_spec import OperatorSpec
from nvflare.apis.shareable import ReservedHeaderKey, Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.app_common.abstract.aggregator import Aggregator
from nvflare.app_common.abstract.learnable_persistor import LearnablePersistor
from nvflare.app_common.abstract.shareable_generator import ShareableGenerator
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.app_event_type import AppEventType
from nvflare.fuel.utils.constants import PipeChannelName
from nvflare.fuel.utils.pipe.pipe import Message, Pipe
from nvflare.fuel.utils.pipe.pipe_handler import PipeHandler, Topic
from nvflare.fuel.utils.validation_utils import check_object_type, check_positive_number, check_str


class BcastOperator(OperatorSpec, FLComponent):
    """Operator that broadcasts a task to clients and aggregates the results.

    The aggregator to use is looked up by the component id found under
    ``TaskOperatorKey.AGGREGATOR`` in the operation description. While a
    broadcast is in progress, the aggregator is also kept in
    ``self.current_aggregator`` so that results routed through
    ``process_result_of_unknown_task`` can still be accepted.
    """

    # name of the Task prop that carries the aggregator to the result callback
    _PROP_AGGR = "aggr"

    def __init__(self):
        OperatorSpec.__init__(self)
        FLComponent.__init__(self)
        # aggregator of the broadcast currently in progress; None when idle
        self.current_aggregator = None

    @staticmethod
    def _get_aggregator(op_description: dict, fl_ctx: FLContext):
        """Resolve the aggregator component named in the operation description.

        Args:
            op_description: the operation description; must contain a non-empty
                ``TaskOperatorKey.AGGREGATOR`` entry.
            fl_ctx: FL context used to reach the engine.

        Returns:
            The ``Aggregator`` component.

        Raises:
            RuntimeError: if the id is missing, no component is registered
                under that id, or the component is not an ``Aggregator``.
        """
        aggr_id = op_description.get(TaskOperatorKey.AGGREGATOR, "")
        if not aggr_id:
            raise RuntimeError("missing aggregator component id")

        engine = fl_ctx.get_engine()
        aggr = engine.get_component(aggr_id)
        if not aggr:
            raise RuntimeError(f"no aggregator defined for component id {aggr_id}")

        if not isinstance(aggr, Aggregator):
            raise RuntimeError(f"component {aggr_id} must be Aggregator but got {type(aggr)}")
        return aggr

    def operate(
        self,
        op_description: dict,
        controller: ControllerSpec,
        task_name: str,
        task_data: Shareable,
        abort_signal: Signal,
        fl_ctx: FLContext,
    ) -> Union[Shareable, None]:
        """Broadcast the task, wait for responses, and return the aggregate.

        Args:
            op_description: operation description; ``TIMEOUT``,
                ``WAIT_TIME_AFTER_MIN_RESPS``, ``MIN_TARGETS`` and ``TARGETS``
                (all ``TaskOperatorKey`` entries) are read from it, in addition
                to the aggregator id.
            controller: controller used to perform the broadcast.
            task_name: name of the task to broadcast.
            task_data: data of the task to broadcast.
            abort_signal: signal passed through to the broadcast.
            fl_ctx: FL context.

        Returns:
            Whatever the aggregator's ``aggregate`` returns.

        Raises:
            RuntimeError: if the aggregator cannot be resolved.
        """
        aggr = self._get_aggregator(op_description, fl_ctx)

        # reset the internal state of the aggregator for next round of aggregation
        self.current_aggregator = aggr
        try:
            aggr.reset(fl_ctx)

            engine = fl_ctx.get_engine()
            total_num_clients = len(engine.get_clients())

            timeout = op_description.get(TaskOperatorKey.TIMEOUT, 0)
            wait_time_after_min_resps = op_description.get(TaskOperatorKey.WAIT_TIME_AFTER_MIN_RESPS, 5)
            min_clients = op_description.get(TaskOperatorKey.MIN_TARGETS, 0)
            if min_clients > total_num_clients:
                # cannot require more responses than there are clients; in that
                # case there is also nobody left to wait for afterwards
                min_clients = total_num_clients
                wait_time_after_min_resps = 0
            targets = op_description.get(TaskOperatorKey.TARGETS, None)

            # data is from T1
            train_task = Task(
                name=task_name,
                data=task_data,
                props={self._PROP_AGGR: aggr},
                timeout=timeout,
                result_received_cb=self._process_bcast_result,
            )

            controller.broadcast_and_wait(
                task=train_task,
                targets=targets,
                min_responses=min_clients,
                wait_time_after_min_received=wait_time_after_min_resps,
                fl_ctx=fl_ctx,
                abort_signal=abort_signal,
            )

            return aggr.aggregate(fl_ctx)
        finally:
            # Always forget the aggregator, even when the broadcast or the
            # aggregation fails, so that results arriving later are not fed
            # into an aggregator that no longer belongs to an active operation.
            self.current_aggregator = None

    def _process_bcast_result(self, client_task: ClientTask, fl_ctx: FLContext) -> None:
        """Result callback of the broadcast task: hand the result to the aggregator."""
        result = client_task.result
        aggr = client_task.task.get_prop(self._PROP_AGGR)
        aggr.accept(result, fl_ctx)

        # Cleanup task result
        client_task.result = None

    def process_result_of_unknown_task(
        self, client: Client, task_name: str, client_task_id: str, result: Shareable, fl_ctx: FLContext
    ):
        """Accept a result into the current aggregator, if a broadcast is in progress.

        The result is silently ignored when no aggregator is currently set.
        """
        aggr = self.current_aggregator
        if aggr:
            aggr.accept(result, fl_ctx)
class RelayOperator(OperatorSpec, FLComponent):
    """Operator that relays a task from one target to the next.

    After each target submits its result, that result (optionally converted
    through a ``ShareableGenerator``) becomes the task data sent to the next
    target. The result of the last target is returned by ``operate``.
    """

    # names of the Task props used to pass state to the result callback
    _PROP_LAST_RESULT = "last_result"
    _PROP_SHAREABLE_GEN = "shareable_generator"

    def __init__(self):
        OperatorSpec.__init__(self)
        FLComponent.__init__(self)

    @staticmethod
    def _get_shareable_generator(op_description: dict, fl_ctx: FLContext):
        """Resolve the optional shareable generator named in the operation description.

        Returns:
            The ``ShareableGenerator`` component, or None when no id is given
            under ``TaskOperatorKey.SHAREABLE_GENERATOR``.

        Raises:
            RuntimeError: if an id is given but no component is registered
                under it, or the component is not a ``ShareableGenerator``.
        """
        engine = fl_ctx.get_engine()
        comp_id = op_description.get(TaskOperatorKey.SHAREABLE_GENERATOR, "")
        if not comp_id:
            return None

        shareable_generator = engine.get_component(comp_id)
        if not shareable_generator:
            raise RuntimeError(f"no shareable generator defined for component id {comp_id}")

        if not isinstance(shareable_generator, ShareableGenerator):
            raise RuntimeError(f"component {comp_id} must be ShareableGenerator but got {type(shareable_generator)}")
        return shareable_generator

    @staticmethod
    def _get_persistor(op_description: dict, fl_ctx: FLContext):
        """Resolve the optional persistor named in the operation description.

        Returns:
            The ``LearnablePersistor`` component, or None when no id is given
            under ``TaskOperatorKey.PERSISTOR``.

        Raises:
            RuntimeError: if an id is given but no component is registered
                under it, or the component is not a ``LearnablePersistor``.
        """
        persistor_id = op_description.get(TaskOperatorKey.PERSISTOR, "")
        if not persistor_id:
            return None

        engine = fl_ctx.get_engine()
        persistor = engine.get_component(persistor_id)
        if not persistor:
            raise RuntimeError(f"no persistor defined for component id {persistor_id}")

        if not isinstance(persistor, LearnablePersistor):
            raise RuntimeError(f"component {persistor_id} must be LearnablePersistor but got {type(persistor)}")
        return persistor

    def operate(
        self,
        op_description: dict,
        controller: ControllerSpec,
        task_name: str,
        task_data: Shareable,
        abort_signal: Signal,
        fl_ctx: FLContext,
    ) -> Union[None, Shareable]:
        """Relay the task through the targets and return the last result.

        Args:
            op_description: operation description; ``SHAREABLE_GENERATOR``,
                ``PERSISTOR``, ``TARGETS`` and ``TASK_ASSIGNMENT_TIMEOUT``
                (all ``TaskOperatorKey`` entries) are read from it.
            controller: controller used to perform the relay.
            task_name: name of the task to relay.
            task_data: data sent to the first target; its ``CURRENT_ROUND``
                header, if any, is carried along the relay.
            abort_signal: signal passed through to the relay.
            fl_ctx: FL context.

        Returns:
            The result most recently recorded by the result callback, or None
            if the abort signal was triggered or no result was recorded.

        Raises:
            RuntimeError: if a configured shareable generator or persistor
                cannot be resolved.
        """
        current_round = task_data.get_header(AppConstants.CURRENT_ROUND, None)
        shareable_generator = self._get_shareable_generator(op_description, fl_ctx)
        persistor = self._get_persistor(op_description, fl_ctx)
        if persistor:
            # The persistor should convert the TASK_DATA in the fl_ctx into a learnable
            # This learnable is the base for the relay
            learnable_base = persistor.load(fl_ctx)
            fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, learnable_base, private=True, sticky=False)

        task = Task(
            name=task_name,
            data=task_data,
            props={
                AppConstants.CURRENT_ROUND: current_round,
                self._PROP_LAST_RESULT: None,
                self._PROP_SHAREABLE_GEN: shareable_generator,
            },
            result_received_cb=self._process_relay_result,
        )

        targets = op_description.get(TaskOperatorKey.TARGETS, None)
        task_assignment_timeout = op_description.get(TaskOperatorKey.TASK_ASSIGNMENT_TIMEOUT, 0)

        controller.relay_and_wait(
            task=task,
            targets=targets,
            task_assignment_timeout=task_assignment_timeout,
            fl_ctx=fl_ctx,
            dynamic_targets=True,
            abort_signal=abort_signal,
        )
        if abort_signal.triggered:
            return None

        return task.get_prop(self._PROP_LAST_RESULT)

    def _process_relay_result(self, client_task: ClientTask, fl_ctx: FLContext):
        """Result callback of the relay task: prepare the data for the next target."""
        # submitted shareable is stored in client_task.result
        # we need to update task.data with that shareable so the next target
        # will get the updated shareable
        task = client_task.task
        current_round = task.get_prop(AppConstants.CURRENT_ROUND)
        task.set_prop(self._PROP_LAST_RESULT, client_task.result)

        task_data = client_task.result
        shareable_generator = task.get_prop(self._PROP_SHAREABLE_GEN)
        if shareable_generator:
            # turn received result (a Shareable) to learnable (i.e. weight diff => weight)
            learnable = shareable_generator.shareable_to_learnable(client_task.result, fl_ctx)

            # turn the learnable to task data for the next leg (i.e. weight Learnable to weight Shareable)
            task_data = shareable_generator.learnable_to_shareable(learnable, fl_ctx)

        # Compare against None rather than truthiness: round 0 is a valid round
        # and must be propagated to the next leg like any other round.
        if current_round is not None:
            task_data.set_header(AppConstants.CURRENT_ROUND, current_round)

        task.data = task_data
        client_task.result = None
class HubController(Controller):
    """Controller that executes tasks received through a pipe from a peer.

    The peer is called "T1" in this file's comments and log messages. For each
    request received from the pipe, the controller determines the operator to
    use, invokes it, and sends the operator's result back as the reply.
    """

    def __init__(
        self,
        pipe_id: str,
        task_wait_time=None,
        task_data_poll_interval: float = 0.1,
    ):
        """Create the controller.

        Args:
            pipe_id: component id of the ``Pipe`` used to talk to the peer.
            task_wait_time: if set, the number of seconds to wait for the next
                task data before aborting; must be a positive number. When not
                set (None), there is no such timeout.
            task_data_poll_interval: seconds to sleep between polls of the
                pipe; must be a positive number.
        """
        Controller.__init__(self)

        check_positive_number("task_data_poll_interval", task_data_poll_interval)
        check_str("pipe_id", pipe_id)
        if task_wait_time is not None:
            check_positive_number("task_wait_time", task_wait_time)

        self.pipe_id = pipe_id
        self.operator_descs = None
        self.task_wait_time = task_wait_time
        self.task_data_poll_interval = task_data_poll_interval
        self.pipe = None
        self.pipe_handler = None
        self.run_ended = False
        self.task_abort_signal = None
        self.current_task_name = None
        self.current_task_id = None
        self.current_operator = None
        self.builtin_operators = {OperatorMethod.BROADCAST: BcastOperator(), OperatorMethod.RELAY: RelayOperator()}
        self.project_name = ""

    def start_controller(self, fl_ctx: FLContext) -> None:
        """Load the operator descriptions from the job's server app config file."""
        self.project_name = fl_ctx.get_identity_name()

        # get operators
        engine = fl_ctx.get_engine()
        job_id = fl_ctx.get_job_id()
        workspace = engine.get_workspace()
        app_config_file = workspace.get_server_app_config_file_path(job_id)
        # read as UTF-8 explicitly rather than with the platform default encoding
        with open(app_config_file, encoding="utf-8") as file:
            app_config = json.load(file)
            self.operator_descs = app_config.get(OperatorConfigKey.OPERATORS, {})
            self.log_debug(fl_ctx, f"Got operator descriptions: {self.operator_descs}")

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        """Open the pipe on START_RUN; remember that the run ended on END_RUN."""
        if event_type == EventType.START_RUN:
            engine = fl_ctx.get_engine()
            pipe = engine.get_component(self.pipe_id)
            check_object_type("pipe", pipe, Pipe)
            pipe.open(name=PipeChannelName.TASK)
            self.pipe_handler = PipeHandler(pipe)
        elif event_type == EventType.END_RUN:
            self.run_ended = True

    def _abort(self, reason: str, abort_signal: Signal, fl_ctx: FLContext):
        """Notify the peer of the abort, log the reason (if any), and trigger the signal."""
        handler = self.pipe_handler
        if handler is not None:
            # no handler exists if START_RUN was never processed; there is
            # nobody to notify in that case
            handler.notify_abort(reason)
        if reason:
            self.log_error(fl_ctx, reason)
        if abort_signal:
            abort_signal.trigger(True)

    def _get_operator(self, task_name: str, op_desc: dict, fl_ctx: FLContext):
        """Find the operator for the method named in the operation description.

        A component registered under the method name takes precedence over the
        built-in operators.

        Returns:
            A tuple ``(operator, error)``: on success the operator and an empty
            string; on failure None and the error message.
        """
        method_name = op_desc.get(TaskOperatorKey.METHOD)
        if not method_name:
            return None, f"bad operator in task '{task_name}' from T1 - missing method name"

        # see whether an Operator is defined for the method
        engine = fl_ctx.get_engine()
        operator = engine.get_component(method_name)
        if not operator:
            operator = self.builtin_operators.get(method_name, None)

        if not operator:
            return None, f"bad task '{task_name}' from T1 - no operator for '{method_name}'"

        if not isinstance(operator, OperatorSpec):
            return None, f"operator for '{method_name}' must be OperatorSpec but got {type(operator)}"
        return operator, ""

    def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
        """Run the task loop; abort on any exception and always stop the pipe handler."""
        handler = self.pipe_handler
        try:
            handler.start()
            self._control_flow(abort_signal, fl_ctx)
        except Exception as ex:
            self.log_exception(fl_ctx, "control flow exception")
            self._abort(f"control_flow exception {ex}", abort_signal, fl_ctx)
        finally:
            # Stop the handler on every exit path, not just the normal one.
            if handler is not None:
                try:
                    handler.stop()
                except Exception:
                    self.log_exception(fl_ctx, "exception stopping pipe handler")

    def _control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
        """Poll the pipe and process task requests until the run ends or is aborted.

        The loop returns when the run has ended, the abort signal is triggered,
        no task data arrives within ``task_wait_time`` (if set), the peer
        reports ABORT/END/PEER_GONE, or a request is malformed.
        """
        control_flow_start = time.time()
        task_start = control_flow_start

        while True:
            if self.run_ended:
                # tell T1 to end the run
                self._abort(reason="", abort_signal=abort_signal, fl_ctx=fl_ctx)
                return

            if abort_signal.triggered:
                # tell T1 to end the run
                self._abort(reason="", abort_signal=abort_signal, fl_ctx=fl_ctx)
                return

            msg = self.pipe_handler.get_next()
            if not msg:
                if self.task_wait_time and time.time() - task_start > self.task_wait_time:
                    # timed out - tell T1 to end the RUN
                    self._abort(
                        reason=f"task data timeout after {self.task_wait_time} secs",
                        abort_signal=abort_signal,
                        fl_ctx=fl_ctx,
                    )
                    return
            else:
                if msg.topic in [Topic.ABORT, Topic.END, Topic.PEER_GONE]:
                    # the T1 peer is gone
                    self.log_info(fl_ctx, f"T1 stopped: '{msg.topic}'")
                    return

                if msg.msg_type != Message.REQUEST:
                    self.log_info(fl_ctx, f"ignored '{msg.topic}' from T1 - not a request!")
                    continue

                self.log_info(fl_ctx, f"got data for task '{msg.topic}' from T1")
                if not isinstance(msg.data, Shareable):
                    self._abort(
                        reason=f"bad data for task '{msg.topic}' from T1 - must be Shareable but got {type(msg.data)}",
                        abort_signal=abort_signal,
                        fl_ctx=fl_ctx,
                    )
                    return

                task_data = msg.data
                task_name = task_data.get_header(ReservedHeaderKey.TASK_NAME)
                if not task_name:
                    self._abort(
                        reason=f"bad data for task '{msg.topic}' from T1 - missing task name",
                        abort_signal=abort_signal,
                        fl_ctx=fl_ctx,
                    )
                    return

                task_id = task_data.get_header(ReservedHeaderKey.TASK_ID)
                if not task_id:
                    self._abort(
                        reason=f"bad data for task '{msg.topic}' from T1 - missing task id",
                        abort_signal=abort_signal,
                        fl_ctx=fl_ctx,
                    )
                    return

                op_desc = task_data.get_header(ReservedHeaderKey.TASK_OPERATOR, {})
                if not isinstance(op_desc, dict):
                    # validate like the other headers instead of failing later
                    # with an AttributeError on a non-dict value
                    self._abort(
                        reason=f"bad data for task '{msg.topic}' from T1 - "
                        f"operator description must be dict but got {type(op_desc)}",
                        abort_signal=abort_signal,
                        fl_ctx=fl_ctx,
                    )
                    return

                op_id = op_desc.get(TaskOperatorKey.OP_ID)
                if not op_id:
                    # use task_name as the operation id
                    op_desc[TaskOperatorKey.OP_ID] = task_name

                self._resolve_op_desc(op_desc, fl_ctx)
                operator, err = self._get_operator(task_name, op_desc, fl_ctx)
                if not operator:
                    self._abort(reason=err, abort_signal=abort_signal, fl_ctx=fl_ctx)
                    return

                operator_name = operator.__class__.__name__
                self.log_info(fl_ctx, f"Invoking Operator {operator_name} for task {task_name}")
                try:
                    current_round = task_data.get_header(AppConstants.CURRENT_ROUND, 0)
                    fl_ctx.set_prop(AppConstants.CURRENT_ROUND, current_round, private=True, sticky=True)
                    contrib_round = task_data.get_cookie(AppConstants.CONTRIBUTION_ROUND)
                    if contrib_round is None:
                        self.log_warning(fl_ctx, "CONTRIBUTION_ROUND Not Set!")

                    self.fire_event(AppEventType.ROUND_STARTED, fl_ctx)
                    fl_ctx.set_prop(key=FLContextKey.TASK_DATA, value=task_data, private=True, sticky=False)
                    self.current_task_name = task_name
                    self.current_task_id = task_id
                    self.task_abort_signal = abort_signal
                    self.current_operator = operator
                    result = operator.operate(
                        task_name=task_name,
                        task_data=task_data,
                        op_description=op_desc,
                        controller=self,
                        abort_signal=abort_signal,
                        fl_ctx=fl_ctx,
                    )
                except Exception:
                    # An operator failure is reported to T1 as an error reply
                    # below; KeyboardInterrupt/SystemExit are deliberately not
                    # swallowed here.
                    self.log_exception(fl_ctx, f"exception processing '{task_name}' from operator '{operator_name}'")
                    result = None
                finally:
                    self.task_abort_signal = None
                    self.current_task_id = None
                    self.current_operator = None
                    self.fire_event(AppEventType.ROUND_DONE, fl_ctx)

                if not result:
                    self.log_error(fl_ctx, f"no result from operator '{operator_name}'")
                    result = make_reply(ReturnCode.EXECUTION_EXCEPTION)
                elif not isinstance(result, Shareable):
                    self.log_error(
                        fl_ctx, f"bad result from operator '{operator_name}': expect Shareable but got {type(result)}"
                    )
                    result = make_reply(ReturnCode.EXECUTION_EXCEPTION)

                reply = Message.new_reply(topic=msg.topic, data=result, req_msg_id=msg.msg_id)
                self.pipe_handler.send_to_peer(reply)
                # the wait for the next task data starts now
                task_start = time.time()

            time.sleep(self.task_data_poll_interval)

    def _resolve_op_desc(self, op_desc: dict, fl_ctx: FLContext):
        """Determine the correct operation description.

        There may be "operators" in job's config_fed_server.json.
        If present, it describes the operations for tasks, and its descriptions override op_desc that comes from task!
        It may specify a different method than the one in op_desc!
        For example, the op_desc may specify the method 'bcast', but the config could specify 'relay'.
        In this case, the 'relay' method will be used.

        A project-specific entry (keyed ``<project_name>.<op_id>``) is looked
        up first; the general entry (keyed ``<op_id>``) is used otherwise.

        Args:
            op_desc: the op description that comes from the task data; updated
                in place with the configured description, if one is found.
            fl_ctx: FL context, used for logging.

        Returns: None

        """
        op_id = op_desc.get(TaskOperatorKey.OP_ID, None)
        if op_id:
            # see whether config is set up for this op
            # if so, the info in the config overrides op_desc!
            # first try to find project-specific definition
            op_config = self.operator_descs.get(f"{self.project_name}.{op_id}", None)
            if op_config:
                self.log_debug(fl_ctx, f"Use CONFIGURED OPERATORS for {self.project_name}.{op_id}")
            else:
                # try to find general definition
                op_config = self.operator_descs.get(op_id, None)
                if op_config:
                    self.log_debug(fl_ctx, f"Use CONFIGURED OPERATORS for {op_id}")

            if op_config:
                op_desc.update(op_config)
            else:
                self.log_debug(fl_ctx, "OPERATORS NOT CONFIGURED")

    def process_result_of_unknown_task(
        self, client: Client, task_name: str, client_task_id: str, result: Shareable, fl_ctx: FLContext
    ):
        """Forward a late client reply to the current operator, or drop it."""
        # A late reply is received from client.
        # We'll include the late reply into the aggregation only if it's for the same type of tasks (i.e.
        # same task name). Note that the same task name could be used many times (rounds).
        self.log_info(fl_ctx, f"Late response received from client {client.name} for task '{task_name}'")
        operator = self.current_operator
        if task_name == self.current_task_name and operator:
            operator.process_result_of_unknown_task(
                client=client, task_name=task_name, client_task_id=client_task_id, result=result, fl_ctx=fl_ctx
            )
        else:
            self.log_warning(fl_ctx, f"Dropped late response received from client {client.name} for task '{task_name}'")

    def stop_controller(self, fl_ctx: FLContext):
        """Nothing to do when the controller stops."""
        pass