# Source code for nvflare.private.fed.app.deployer.simulator_deployer

# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import shutil
import tempfile

from nvflare.fuel.f3.cellnet.cell import Cell
from nvflare.fuel.f3.mpm import MainProcessMonitor as mpm
from nvflare.fuel.utils.dict_utils import augment
from nvflare.fuel.utils.network_utils import get_open_ports
from nvflare.private.fed.app.utils import create_admin_server
from nvflare.private.fed.simulator.simulator_client_engine import SimulatorParentClientEngine
from nvflare.private.fed.simulator.simulator_server import SimulatorServer
from nvflare.security.logging import secure_format_exception

from .base_client_deployer import BaseClientDeployer
from .server_deployer import ServerDeployer


class SimulatorDeployer(ServerDeployer):
    """Deployer that builds the server and the clients of a simulator run.

    Both sides are configured against ``localhost``: ``open_ports[0]`` is used
    for the FL service target and ``open_ports[1]`` for the admin port. A
    temporary directory is created as the admin storage and removed again by
    :meth:`close`.
    """

    def __init__(self):
        """Obtain two ports and create the temporary admin storage directory."""
        super().__init__()
        # Index 0 -> FL service target, index 1 -> admin port (see the config builders).
        self.open_ports = get_open_ports(2)
        self.admin_storage = tempfile.mkdtemp()

    def create_fl_server(self, args, secure_train=False):
        """Create, deploy and start the simulator server with its admin server.

        Args:
            args: parsed arguments; ``args.max_clients`` is read here and the
                object is forwarded to the server and the admin server.
            secure_train: forwarded to ``SimulatorServer``. The admin server is
                always created with ``secure_train=False``.

        Returns:
            A tuple ``(simulator_server, services)``: the server config dict
            and the deployed ``SimulatorServer`` instance.
        """
        simulator_server = self._create_simulator_server_config(self.admin_storage, args.max_clients)

        heart_beat_timeout = simulator_server.get("heart_beat_timeout", 600)
        services = SimulatorServer(
            project_name=simulator_server.get("name", ""),
            max_num_clients=simulator_server.get("max_num_clients", 100),
            cmd_modules=self.cmd_modules,
            args=args,
            secure_train=secure_train,
            snapshot_persistor=self.snapshot_persistor,
            overseer_agent=self.overseer_agent,
            heart_beat_timeout=heart_beat_timeout,
        )
        services.deploy(args, grpc_args=simulator_server)

        admin_server = create_admin_server(
            services,
            server_conf=simulator_server,
            args=args,
            secure_train=False,
        )
        admin_server.start()
        services.set_admin_server(admin_server)
        # mpm.add_cleanup_cb(admin_server.stop)

        return simulator_server, services

    def create_fl_client(self, client_name, args):
        """Build a federated client, attach its cell and engine, and register it.

        Args:
            client_name: name of the client; also used as the cell's FQCN.
            args: parsed arguments; ``args.workspace`` is read to locate the
                optional ``local/resources.json`` file.

        Returns:
            A tuple ``(federated_client, client_config, args, build_ctx)``.

        Raises:
            RuntimeError: if ``local/resources.json`` exists but cannot be
                loaded or merged into the client config.
        """
        client_config, build_ctx = self._create_simulator_client_config(client_name, args)

        deployer = BaseClientDeployer()
        deployer.build(build_ctx)
        federated_client = deployer.create_fed_client(args)

        self._create_client_cell(client_config, client_name, federated_client)

        client_engine = SimulatorParentClientEngine(federated_client, args)
        federated_client.set_client_engine(client_engine)
        with client_engine.new_context() as fl_ctx:
            federated_client.register(fl_ctx=fl_ctx)
        # federated_client.start_heartbeat()
        federated_client.run_manager = None

        return federated_client, client_config, args, build_ctx

    def _create_client_cell(self, client_config, client_name, federated_client):
        """Start a ``Cell`` for the client and attach it to the client objects.

        The cell connects to the first server listed in ``client_config`` and
        its ``stop`` method is registered as a main-process cleanup callback.
        """
        # Look the service entry up once; both target and scheme come from it.
        service = client_config["servers"][0].get("service")
        target = service.get("target")
        scheme = service.get("scheme", "grpc")

        credentials = {}
        parent_url = None
        cell = Cell(
            fqcn=client_name,
            root_url=scheme + "://" + target,
            secure=self.secure_train,
            credentials=credentials,
            create_internal_listener=False,
            parent_url=parent_url,
        )
        cell.start()

        federated_client.cell = cell
        federated_client.communicator.cell = cell
        # if self.engine:
        #     self.engine.admin_agent.register_cell_cb()

        mpm.add_cleanup_cb(cell.stop)

    def _create_service_config(self):
        """Return a fresh ``service`` entry for the simulator server endpoint.

        Shared by the server and client config builders so that both always
        agree on target and scheme. A new dict is returned on every call.
        """
        return {
            "target": "localhost:" + str(self.open_ports[0]),
            "scheme": "tcp",
        }

    def _create_simulator_server_config(self, admin_storage, max_clients):
        """Build the server config dict used to deploy the simulator server.

        Args:
            admin_storage: value stored under the ``admin_storage`` key.
            max_clients: value stored under the ``max_num_clients`` key.

        Returns:
            The server config dict.
        """
        simulator_server = {
            "name": "simulator_server",
            "service": self._create_service_config(),
            "admin_host": "localhost",
            "admin_port": self.open_ports[1],
            "max_num_clients": max_clients,
            "heart_beat_timeout": 600,
            "num_server_workers": 4,
            "compression": "Gzip",
            "admin_storage": admin_storage,
            "download_job_url": "http://download.server.com/",
        }
        return simulator_server

    def _create_simulator_client_config(self, client_name, args):
        """Build the client config and the build context for one client.

        If ``<args.workspace>/local/resources.json`` exists, its content is
        merged into the client config via ``augment`` with
        ``from_override_to=False``.

        Args:
            client_name: stored under ``client_name`` in the build context.
            args: parsed arguments; only ``args.workspace`` is read.

        Returns:
            A tuple ``(client_config, build_ctx)``.

        Raises:
            RuntimeError: if the resources file cannot be parsed or merged;
                the original exception is attached as the cause.
        """
        client_config = {
            "servers": [
                {
                    "name": "simulator_server",
                    "service": self._create_service_config(),
                }
            ],
            "client": {"retry_timeout": 30, "compression": "Gzip"},
        }

        resources = os.path.join(args.workspace, "local/resources.json")
        if os.path.exists(resources):
            # JSON is UTF-8 by specification; do not depend on the locale default.
            with open(resources, encoding="utf-8") as file:
                try:
                    data = json.load(file)
                    augment(to_dict=client_config, from_dict=data, from_override_to=False)
                except Exception as e:
                    raise RuntimeError(
                        f"Error processing config file {resources}: {secure_format_exception(e)}"
                    ) from e

        build_ctx = {
            "client_name": client_name,
            "server_config": client_config.get("servers", []),
            "client_config": client_config["client"],
            "server_host": None,
            "secure_train": False,
            "enable_byoc": True,
            "overseer_agent": None,
            "client_components": {},
            "client_handlers": None,
        }

        return client_config, build_ctx

    def close(self):
        """Remove the temporary admin storage, then close the parent deployer.

        A missing directory (for example when ``close`` is called twice) is
        tolerated. Any other removal error still propagates, but only after
        the parent's ``close`` has run.
        """
        try:
            shutil.rmtree(self.admin_storage)
        except FileNotFoundError:
            # Already removed; nothing left to clean up.
            pass
        finally:
            super().close()