Source code for nvflare.app_common.executors.multi_process_executor

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import shlex
import subprocess
import threading
import time
from abc import abstractmethod

from nvflare.apis.event_type import EventType
from nvflare.apis.executor import Executor
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_constant import FLContextKey, ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.apis.utils.fl_context_utils import get_serializable_data
from nvflare.fuel.common.multi_process_executor_constants import (
    CommunicateData,
    CommunicationMetaData,
    MultiProcessCommandNames,
)
from nvflare.fuel.f3.cellnet.core_cell import Message as CellMessage
from nvflare.fuel.f3.cellnet.core_cell import MessageHeaderKey
from nvflare.fuel.f3.cellnet.core_cell import ReturnCode as F3ReturnCode
from nvflare.fuel.f3.cellnet.core_cell import make_reply as F3make_reply
from nvflare.fuel.f3.cellnet.fqcn import FQCN
from nvflare.fuel.utils.class_utils import ModuleScanner
from nvflare.fuel.utils.component_builder import ComponentBuilder
from nvflare.private.defs import CellChannel, CellChannelTopic, new_cell_message
from nvflare.security.logging import secure_format_exception


class WorkerComponentBuilder(ComponentBuilder):

    FL_PACKAGES = ["nvflare"]
    FL_MODULES = ["client", "app"]

    def __init__(self) -> None:
        """Component builder for the worker processes."""
        super().__init__()
        self.module_scanner = ModuleScanner(WorkerComponentBuilder.FL_PACKAGES, WorkerComponentBuilder.FL_MODULES, True)

    def get_module_scanner(self):
        return self.module_scanner


class MultiProcessExecutor(Executor):
    def __init__(self, executor_id=None, num_of_processes=1, components=None):
        """Manage the multi-process execution life cycle.

        Arguments:
            executor_id: executor component ID
            num_of_processes: number of processes to create
            components: a list of component config dicts (each with an "id"), mapping component classes to their arguments
        """
        super().__init__()
        self.executor_id = executor_id

        self.components_conf = components
        self.components = {}
        self.handlers = []
        self._build_components(components)

        if not isinstance(num_of_processes, int):
            raise TypeError("num_of_processes must be an instance of int but got {}".format(type(num_of_processes)))
        if num_of_processes < 1:
            raise ValueError(f"num_of_processes must be >= 1 but got {num_of_processes}.")
        self.num_of_processes = num_of_processes

        self.executor = None
        self.execute_result = None
        self.execute_complete = None
        self.engine = None
        self.logger = logging.getLogger(self.__class__.__name__)
        self.conn_clients = []
        self.exe_process = None

        self.stop_execute = False
        self.relay_threads = []
        self.finalized = False
        self.event_lock = threading.Lock()
        self.relay_lock = threading.Lock()

    @abstractmethod
    def get_multi_process_command(self) -> str:
        """Provide the command for starting multi-process execution.

        Returns:
            multi-process starting command
        """
        return ""

    def _build_components(self, components):
        component_builder = WorkerComponentBuilder()
        for item in components:
            cid = item.get("id", None)
            if not cid:
                raise TypeError("missing component id")

            self.components[cid] = component_builder.build_component(item)
            if isinstance(self.components[cid], FLComponent):
                self.handlers.append(self.components[cid])

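    # Note (illustrative, not part of the original source): _build_components() expects each item in
    # ``components`` to be a component config dict that ComponentBuilder can consume. A hypothetical
    # entry might look like:
    #     {"id": "my_metrics_sender", "path": "my_pkg.my_module.MyMetricsSender", "args": {}}
    # The exact keys besides "id" depend on ComponentBuilder; any built component that is an
    # FLComponent is also registered as an event handler for this executor.
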
    def handle_event(self, event_type: str, fl_ctx: FLContext):
        if event_type == EventType.START_RUN:
            self.initialize(fl_ctx)
        elif event_type == EventType.END_RUN:
            self.finalize(fl_ctx)
        self._pass_event_to_rank_processes(event_type, fl_ctx)

    def _pass_event_to_rank_processes(self, event_type: str, fl_ctx: FLContext):
        event_site = fl_ctx.get_prop(FLContextKey.EVENT_ORIGIN_SITE)

        if self.engine:
            if event_site != CommunicateData.SUB_WORKER_PROCESS:
                with self.event_lock:
                    try:
                        data = {
                            CommunicationMetaData.COMMAND: CommunicateData.HANDLE_EVENT,
                            CommunicationMetaData.FL_CTX: get_serializable_data(fl_ctx),
                            CommunicationMetaData.EVENT_TYPE: event_type,
                        }
                        # relay the event to all the child (rank) processes
                        request = new_cell_message({}, data)
                        self.engine.client.cell.fire_and_forget(
                            targets=self.targets,
                            channel=CellChannel.CLIENT_SUB_WORKER_COMMAND,
                            topic=MultiProcessCommandNames.FIRE_EVENT,
                            message=request,
                        )
                    except Exception:
                        # Warning: fire_event must be False here; otherwise handling this warning would
                        # fire another event and cause an infinite event-handling loop.
                        self.log_warning(
                            fl_ctx,
                            f"Failed to relay the event to child processes. Event: {event_type}",
                            fire_event=False,
                        )

    def initialize(self, fl_ctx: FLContext):
        self.executor = self.components.get(self.executor_id, None)
        if not isinstance(self.executor, Executor):
            raise ValueError(
                "invalid executor {}: expect Executor but got {}".format(self.executor_id, type(self.executor))
            )

        self._initialize_multi_process(fl_ctx)

    def _initialize_multi_process(self, fl_ctx: FLContext):
        try:
            client_name = fl_ctx.get_identity_name()
            job_id = fl_ctx.get_job_id()

            self.engine = fl_ctx.get_engine()
            simulate_mode = fl_ctx.get_prop(FLContextKey.SIMULATE_MODE, False)
            cell = self.engine.client.cell
            # create the internal listener for the grandchild (sub-worker) processes
            cell.make_internal_listener()
            command = (
                self.get_multi_process_command()
                + " -m nvflare.private.fed.app.client.sub_worker_process"
                + " -m "
                + fl_ctx.get_prop(FLContextKey.ARGS).workspace
                + " -c "
                + client_name
                + " -n "
                + job_id
                + " --num_processes "
                + str(self.num_of_processes)
                + " --simulator_engine "
                + str(simulate_mode)
                + " --parent_pid "
                + str(os.getpid())
                + " --root_url "
                + str(cell.get_root_url_for_child())
                + " --parent_url "
                + str(cell.get_internal_listener_url())
            )
            self.logger.info(f"multi_process_executor command: {command}")
            # use os.setsid to create a new process group ID
            self.exe_process = subprocess.Popen(shlex.split(command), preexec_fn=os.setsid, env=os.environ.copy())

            # register the callbacks that receive execution results and relayed events from the sub-worker processes
            cell.register_request_cb(
                channel=CellChannel.MULTI_PROCESS_EXECUTOR,
                topic=CellChannelTopic.EXECUTE_RESULT,
                cb=self.receive_execute_result,
            )
            cell.register_request_cb(
                channel=CellChannel.MULTI_PROCESS_EXECUTOR,
                topic=CellChannelTopic.FIRE_EVENT,
                cb=self._relay_fire_event,
            )

            # wait until the cell of every rank process is reachable
            self.targets = []
            for i in range(self.num_of_processes):
                fqcn = FQCN.join([cell.get_fqcn(), str(i)])
                start = time.time()
                while not cell.is_cell_reachable(fqcn):
                    time.sleep(1.0)
                    if time.time() - start > 60.0:
                        raise RuntimeError(f"Could not reach the communication cell: {fqcn}")
                self.targets.append(fqcn)

            # send the init data to all the child processes
            request = new_cell_message(
                {},
                {
                    CommunicationMetaData.FL_CTX: get_serializable_data(fl_ctx),
                    CommunicationMetaData.COMPONENTS: self.components_conf,
                    CommunicationMetaData.LOCAL_EXECUTOR: self.executor_id,
                },
            )
            replies = cell.broadcast_request(
                targets=self.targets,
                channel=CellChannel.CLIENT_SUB_WORKER_COMMAND,
                topic=MultiProcessCommandNames.INITIALIZE,
                request=request,
            )
            for name, reply in replies.items():
                if reply.get_header(MessageHeaderKey.RETURN_CODE) != F3ReturnCode.OK:
                    self.log_exception(fl_ctx, "error initializing multi_process executor")
                    raise ValueError(reply.get_header(MessageHeaderKey.ERROR))
        except Exception as e:
            self.log_exception(fl_ctx, f"error initializing multi_process executor: {secure_format_exception(e)}")

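    # Note (illustrative, hypothetical values): if get_multi_process_command() returned "python",
    # the command assembled above would look roughly like
    #   python -m nvflare.private.fed.app.client.sub_worker_process -m /path/to/workspace -c site-1
    #       -n <job_id> --num_processes 2 --simulator_engine False --parent_pid 12345
    #       --root_url <root_url> --parent_url <parent_url>
    # i.e. the subclass-provided launcher prefix, followed by the sub-worker module and its arguments.
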
    def receive_execute_result(self, request: CellMessage) -> CellMessage:
        return_data = request.payload
        with self.engine.new_context() as fl_ctx:
            fl_ctx.props.update(return_data[CommunicationMetaData.FL_CTX].props)
            self.execute_result = return_data[CommunicationMetaData.SHAREABLE]
        # signal execute() that the result has arrived
        self.execute_complete = True
        return F3make_reply(ReturnCode.OK, "", None)

    def _relay_fire_event(self, request: CellMessage) -> CellMessage:
        data = request.payload
        with self.engine.new_context() as fl_ctx:
            event_type = data[CommunicationMetaData.EVENT_TYPE]
            rank_number = data[CommunicationMetaData.RANK_NUMBER]

            with self.relay_lock:
                fl_ctx.props.update(data[CommunicationMetaData.FL_CTX].props)
                fl_ctx.set_prop(FLContextKey.FROM_RANK_NUMBER, rank_number, private=True, sticky=False)
                fl_ctx.set_prop(
                    FLContextKey.EVENT_ORIGIN_SITE,
                    CommunicateData.SUB_WORKER_PROCESS,
                    private=True,
                    sticky=False,
                )
                self.engine.fire_event(event_type, fl_ctx)
                return_data = {CommunicationMetaData.FL_CTX: get_serializable_data(fl_ctx)}
            return F3make_reply(ReturnCode.OK, "", return_data)

    def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
        if not self.executor:
            raise RuntimeError("There's no executor for task {}".format(task_name))

        self.execute_complete = False
        self._execute_multi_process(task_name=task_name, shareable=shareable, fl_ctx=fl_ctx, abort_signal=abort_signal)

        # wait until receive_execute_result() reports completion from the sub-worker processes
        while not self.execute_complete:
            time.sleep(0.2)
        return self.execute_result

    def _execute_multi_process(
        self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal
    ) -> Shareable:
        if abort_signal.triggered:
            self.finalize(fl_ctx)
            return make_reply(ReturnCode.OK)

        self.engine = fl_ctx.get_engine()
        try:
            data = {
                CommunicationMetaData.COMMAND: CommunicateData.EXECUTE,
                CommunicationMetaData.TASK_NAME: task_name,
                CommunicationMetaData.SHAREABLE: shareable,
                CommunicationMetaData.FL_CTX: get_serializable_data(fl_ctx),
            }

            request = new_cell_message({}, data)
            self.engine.client.cell.fire_and_forget(
                targets=self.targets,
                channel=CellChannel.CLIENT_SUB_WORKER_COMMAND,
                topic=MultiProcessCommandNames.TASK_EXECUTION,
                message=request,
            )
        except Exception:
            self.log_error(fl_ctx, "Multi-Process Execution error.")
            return make_reply(ReturnCode.EXECUTION_RESULT_ERROR)

    def finalize(self, fl_ctx: FLContext):
        """This is called when exiting/aborting the executor."""
        if self.finalized:
            return
        self.finalized = True
        self.stop_execute = True

        request = new_cell_message({}, None)
        self.engine.client.cell.fire_and_forget(
            targets=self.targets,
            channel=CellChannel.CLIENT_SUB_WORKER_COMMAND,
            topic=MultiProcessCommandNames.CLOSE,
            message=request,
        )

        try:
            os.killpg(os.getpgid(self.exe_process.pid), 9)
            self.logger.debug("kill signal sent")
        except Exception:
            pass

        if self.exe_process:
            self.exe_process.terminate()

        # wait for all relay threads to join
        for t in self.relay_threads:
            if t.is_alive():
                t.join()

        self.log_info(fl_ctx, "Multi-Process Executor finalized!", fire_event=False)

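
# ---------------------------------------------------------------------------
# Illustrative sketch (an assumption, not part of the nvflare module above):
# MultiProcessExecutor is abstract, so a job must use a concrete subclass that
# supplies the launcher prefix via get_multi_process_command(). The subclass
# name below and the plain-interpreter launcher are hypothetical; a real
# implementation would typically return a distributed launcher command capable
# of spawning num_of_processes ranked workers.
# ---------------------------------------------------------------------------
import sys


class ExampleMultiProcessExecutor(MultiProcessExecutor):
    """Hypothetical subclass shown only to illustrate the extension point."""

    def get_multi_process_command(self) -> str:
        # The base class appends " -m nvflare.private.fed.app.client.sub_worker_process"
        # plus workspace, client name, job id, num_processes and cell URLs to this prefix,
        # then launches the result with subprocess.Popen.
        return sys.executable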