# Source code for nvflare.dashboard.application.blob

# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import json
import os
import subprocess
import tempfile

from nvflare.lighter import tplt_utils, utils

from .cert import CertPair, Entity, deserialize_ca_key, make_cert
from .models import Client, Project, User

# Directory that contains the nvflare.lighter package (located via its utils module).
lighter_folder = os.path.dirname(utils.__file__)

# The master template lives under <lighter>/impl/ and is parsed once at import
# time; every generator in this module reads its entries from `template`.
_master_template_path = os.path.join(lighter_folder, "impl", "master_template.yml")
template = utils.load_yaml(_master_template_path)


def get_csp_template(csp, participant, template):
    """Look up the cloud start-script template for a provider/participant pair.

    Args:
        csp: provider prefix used in the template key (this module passes
            "azure" and "aws").
        participant: participant tag used in the template key (this module
            passes "svr" and "cln").
        template: mapping that holds the template entries.

    Returns:
        The entry stored under the key ``"{csp}_start_{participant}_sh"``.

    Raises:
        KeyError: if ``template`` is a dict and has no such key.
    """
    return template[f"{csp}_start_{participant}_sh"]
def get_csp_start_script_name(csp):
    """Return the file name of the start script for a cloud provider.

    Args:
        csp: provider prefix (this module passes "azure" and "aws").

    Returns:
        The string ``"{csp}_start.sh"``.
    """
    return f"{csp}_start.sh"
def _write(file_full_path, content, mode, exe=False):
    """Write ``content`` to ``file_full_path``, optionally marking it executable.

    Args:
        file_full_path: destination path; an existing file is truncated.
        content: ``str`` for text mode, ``bytes`` for binary mode.
        mode: ``"t"`` for text or ``"b"`` for binary; ``"w"`` is appended to
            form the mode passed to ``open``.
        exe: when true, the file permissions are set to ``0o755`` after writing.
    """
    # Text output is pinned to UTF-8 so the generated files do not depend on
    # the locale of the process; binary mode must not receive an encoding.
    encoding = None if "b" in mode else "utf-8"
    with open(file_full_path, mode + "w", encoding=encoding) as f:
        f.write(content)
    if exe:
        os.chmod(file_full_path, 0o755)
def gen_overseer(key):
    """Build the password-protected zip holding the overseer's startup folder.

    The first ``Project`` row supplies the overseer name and the root CA
    material. A certificate for the overseer is issued from the project root
    CA, the start script, gunicorn config and certificate files are written
    into ``<overseer>/startup`` inside a temporary directory, and that
    directory is zipped with ``key`` as the password.

    Args:
        key: password handed to ``zip -P``.

    Returns:
        A tuple ``(fileobj, filename)``: an ``io.BytesIO`` positioned at 0 with
        the zip bytes, and the name ``"<overseer name>.zip"``.

    Raises:
        subprocess.CalledProcessError: if ``zip`` exits with a non-zero status.
    """
    project = Project.query.first()
    entity = Entity(project.overseer)
    issuer = Entity(project.short_name)
    signing_cert_pair = CertPair(issuer, project.root_key, project.root_cert)
    cert_pair = make_cert(entity, signing_cert_pair)

    with tempfile.TemporaryDirectory() as tmp_dir:
        overseer_dir = os.path.join(tmp_dir, entity.name)
        dest_dir = os.path.join(overseer_dir, "startup")
        os.mkdir(overseer_dir)
        os.mkdir(dest_dir)

        _write(
            os.path.join(dest_dir, "start.sh"),
            template["start_ovsr_sh"],
            "t",
            exe=True,
        )
        _write(
            os.path.join(dest_dir, "gunicorn.conf.py"),
            utils.sh_replace(template["gunicorn_conf_py"], {"port": "8443"}),
            "t",
            exe=False,
        )
        _write(os.path.join(dest_dir, "overseer.crt"), cert_pair.ser_cert, "b", exe=False)
        _write(os.path.join(dest_dir, "overseer.key"), cert_pair.ser_pri_key, "b", exe=False)
        _write(os.path.join(dest_dir, "rootCA.pem"), project.root_cert, "b", exe=False)

        # NOTE(review): the password travels on the zip command line and is
        # therefore visible in the process list while zip runs.
        run_args = ["zip", "-rq", "-P", key, "tmp.zip", "."]
        # check=True surfaces a failing zip directly instead of letting it show
        # up later as a missing tmp.zip.
        subprocess.run(run_args, cwd=tmp_dir, check=True)

        # The archive must be read into memory before the temp dir is removed.
        with open(os.path.join(tmp_dir, "tmp.zip"), "rb") as fo:
            fileobj = io.BytesIO(fo.read())
    return fileobj, f"{entity.name}.zip"
def gen_server(key, first_server=True):
    """Build the password-protected zip holding one FL server's workspace.

    The first ``Project`` row supplies server names, HA mode, scheme, docker
    image and root CA material. The function issues a server certificate from
    the project root CA, renders ``fed_server.json`` and the shell scripts into
    ``<server>/startup``, signs that folder, adds the default ``local`` files
    and a readme, and zips the result with ``key`` as the password.

    Args:
        key: password handed to ``zip -P``.
        first_server: when true, use ``project.server1`` with ports 8002/8003;
            otherwise use ``project.server2`` with ports 8102/8103.

    Returns:
        A tuple ``(fileobj, filename)``: an ``io.BytesIO`` positioned at 0 with
        the zip bytes, and the name ``"<server name>.zip"``.

    Raises:
        subprocess.CalledProcessError: if ``zip`` exits with a non-zero status.
    """
    project = Project.query.first()
    if first_server:
        entity = Entity(project.server1)
        fl_port = 8002
        admin_port = 8003
    else:
        entity = Entity(project.server2)
        fl_port = 8102
        admin_port = 8103
    issuer = Entity(project.short_name)
    signing_cert_pair = CertPair(issuer, project.root_key, project.root_cert)
    cert_pair = make_cert(entity, signing_cert_pair)

    config = json.loads(template["fed_server"])
    server_0 = config["servers"][0]
    server_0["name"] = project.short_name
    server_0["service"]["target"] = f"{entity.name}:{fl_port}"
    # Projects without a `scheme` attribute fall back to grpc.
    server_0["service"]["scheme"] = project.scheme if hasattr(project, "scheme") else "grpc"
    server_0["admin_host"] = entity.name
    server_0["admin_port"] = admin_port

    if project.ha_mode:
        overseer_agent = {"path": "nvflare.ha.overseer_agent.HttpOverseerAgent"}
        overseer_agent["args"] = {
            "role": "server",
            "overseer_end_point": f"https://{project.overseer}:8443/api/v1",
            "project": project.short_name,
            "name": entity.name,
            "fl_port": str(fl_port),
            "admin_port": str(admin_port),
        }
    else:
        overseer_agent = {"path": "nvflare.ha.dummy_overseer_agent.DummyOverseerAgent"}
        overseer_agent["args"] = {"sp_end_point": f"{project.server1}:8002:8003"}
    config["overseer_agent"] = overseer_agent

    replacement_dict = {
        "admin_port": admin_port,
        "fed_learn_port": fl_port,
        "config_folder": "config",
        "ha_mode": "true" if project.ha_mode else "false",
        # The image is taken as the last space-separated token of app_location.
        "docker_image": project.app_location.split(" ")[-1] if project.app_location else "nvflare/nvflare",
        "org_name": "",
    }
    tplt = tplt_utils.Template(template)

    with tempfile.TemporaryDirectory() as tmp_dir:
        server_dir = os.path.join(tmp_dir, entity.name)
        dest_dir = os.path.join(server_dir, "startup")
        os.mkdir(server_dir)
        os.mkdir(dest_dir)

        _write(os.path.join(dest_dir, "fed_server.json"), json.dumps(config, indent=2), "t")
        _write(
            os.path.join(dest_dir, "docker.sh"),
            utils.sh_replace(template["docker_svr_sh"], replacement_dict),
            "t",
            exe=True,
        )
        _write(
            os.path.join(dest_dir, "start.sh"),
            utils.sh_replace(template["start_svr_sh"], replacement_dict),
            "t",
            exe=True,
        )
        _write(
            os.path.join(dest_dir, "sub_start.sh"),
            utils.sh_replace(template["sub_start_svr_sh"], replacement_dict),
            "t",
            exe=True,
        )
        _write(
            os.path.join(dest_dir, "stop_fl.sh"),
            template["stop_fl_sh"],
            "t",
            exe=True,
        )
        _write(os.path.join(dest_dir, "server.crt"), cert_pair.ser_cert, "b", exe=False)
        _write(os.path.join(dest_dir, "server.key"), cert_pair.ser_pri_key, "b", exe=False)
        _write(os.path.join(dest_dir, "rootCA.pem"), project.root_cert, "b", exe=False)

        # Cloud start scripts are only generated for non-HA projects.
        if not project.ha_mode:
            for csp in ("azure", "aws"):
                _write(
                    os.path.join(dest_dir, get_csp_start_script_name(csp)),
                    utils.sh_replace(
                        tplt.get_cloud_script_header() + get_csp_template(csp, "svr", template),
                        {"server_name": entity.name, "ORG": ""},
                    ),
                    "t",
                    exe=True,
                )

        # Sign everything written to startup so far; signature.json itself is
        # written afterwards. The handle is closed explicitly so the file is
        # fully flushed before the directory is zipped.
        signatures = utils.sign_all(dest_dir, deserialize_ca_key(project.root_key))
        with open(os.path.join(dest_dir, "signature.json"), "wt") as sig_file:
            json.dump(signatures, sig_file)

        # local folder creation
        dest_dir = os.path.join(server_dir, "local")
        os.mkdir(dest_dir)
        local_files = (
            ("log.config.default", "log_config"),
            ("resources.json.default", "local_server_resources"),
            ("privacy.json.sample", "sample_privacy"),
            ("authorization.json.default", "default_authz"),
        )
        for file_name, template_key in local_files:
            _write(os.path.join(dest_dir, file_name), template[template_key], "t")

        # workspace folder file
        _write(
            os.path.join(server_dir, "readme.txt"),
            template["readme_fs"],
            "t",
        )

        # NOTE(review): the password travels on the zip command line and is
        # therefore visible in the process list while zip runs.
        run_args = ["zip", "-rq", "-P", key, "tmp.zip", "."]
        # check=True surfaces a failing zip directly instead of letting it show
        # up later as a missing tmp.zip.
        subprocess.run(run_args, cwd=tmp_dir, check=True)

        # The archive must be read into memory before the temp dir is removed.
        with open(os.path.join(tmp_dir, "tmp.zip"), "rb") as fo:
            fileobj = io.BytesIO(fo.read())
    return fileobj, f"{entity.name}.zip"
def gen_client(key, id):
    """Build the password-protected zip holding one FL client's workspace.

    The first ``Project`` row and the ``Client`` row with primary key ``id``
    supply the names, organization, capacity and root CA material. The function
    issues a client certificate from the project root CA, renders
    ``fed_client.json``, the shell scripts and the azure/aws start scripts into
    ``<client>/startup``, signs that folder, adds the default ``local`` files
    and a readme, and zips the result with ``key`` as the password.

    Args:
        key: password handed to ``zip -P``.
        id: primary key passed to ``Client.query.get`` (the name shadows the
            builtin but is kept for interface compatibility).

    Returns:
        A tuple ``(fileobj, filename)``: an ``io.BytesIO`` positioned at 0 with
        the zip bytes, and the name ``"<client name>.zip"``.

    Raises:
        subprocess.CalledProcessError: if ``zip`` exits with a non-zero status.
    """
    project = Project.query.first()
    client = Client.query.get(id)
    entity = Entity(client.name, client.organization.name)
    issuer = Entity(project.short_name)
    signing_cert_pair = CertPair(issuer, project.root_key, project.root_cert)
    cert_pair = make_cert(entity, signing_cert_pair)

    config = json.loads(template["fed_client"])
    config["servers"][0]["name"] = project.short_name
    # Projects without a `scheme` attribute fall back to grpc.
    config["servers"][0]["service"]["scheme"] = project.scheme if hasattr(project, "scheme") else "grpc"

    replacement_dict = {
        "client_name": entity.name,
        "config_folder": "config",
        # The image is taken as the last space-separated token of app_location.
        "docker_image": project.app_location.split(" ")[-1] if project.app_location else "nvflare/nvflare",
        "org_name": entity.org,
    }

    if project.ha_mode:
        overseer_agent = {"path": "nvflare.ha.overseer_agent.HttpOverseerAgent"}
        overseer_agent["args"] = {
            "role": "client",
            "overseer_end_point": f"https://{project.overseer}:8443/api/v1",
            "project": project.short_name,
            "name": entity.name,
        }
    else:
        overseer_agent = {"path": "nvflare.ha.dummy_overseer_agent.DummyOverseerAgent"}
        overseer_agent["args"] = {"sp_end_point": f"{project.server1}:8002:8003"}
    config["overseer_agent"] = overseer_agent

    tplt = tplt_utils.Template(template)

    with tempfile.TemporaryDirectory() as tmp_dir:
        client_dir = os.path.join(tmp_dir, entity.name)
        dest_dir = os.path.join(client_dir, "startup")
        os.mkdir(client_dir)
        os.mkdir(dest_dir)

        _write(os.path.join(dest_dir, "fed_client.json"), json.dumps(config, indent=2), "t")
        _write(
            os.path.join(dest_dir, "docker.sh"),
            utils.sh_replace(template["docker_cln_sh"], replacement_dict),
            "t",
            exe=True,
        )
        _write(
            os.path.join(dest_dir, "start.sh"),
            template["start_cln_sh"],
            "t",
            exe=True,
        )
        _write(
            os.path.join(dest_dir, "sub_start.sh"),
            utils.sh_replace(template["sub_start_cln_sh"], replacement_dict),
            "t",
            exe=True,
        )
        _write(
            os.path.join(dest_dir, "stop_fl.sh"),
            template["stop_fl_sh"],
            "t",
            exe=True,
        )
        _write(os.path.join(dest_dir, "client.crt"), cert_pair.ser_cert, "b", exe=False)
        _write(os.path.join(dest_dir, "client.key"), cert_pair.ser_pri_key, "b", exe=False)
        _write(os.path.join(dest_dir, "rootCA.pem"), project.root_cert, "b", exe=False)

        for csp in ("azure", "aws"):
            _write(
                os.path.join(dest_dir, get_csp_start_script_name(csp)),
                utils.sh_replace(
                    tplt.get_cloud_script_header() + get_csp_template(csp, "cln", template),
                    {"SITE": entity.name, "ORG": entity.org},
                ),
                "t",
                exe=True,
            )

        # Sign everything written to startup so far; signature.json itself is
        # written afterwards. The handle is closed explicitly so the file is
        # fully flushed before the directory is zipped.
        signatures = utils.sign_all(dest_dir, deserialize_ca_key(project.root_key))
        with open(os.path.join(dest_dir, "signature.json"), "wt") as sig_file:
            json.dump(signatures, sig_file)

        # local folder creation
        dest_dir = os.path.join(client_dir, "local")
        os.mkdir(dest_dir)
        _write(
            os.path.join(dest_dir, "log.config.default"),
            template["log_config"],
            "t",
        )

        # Replace the args of the first GPUResourceManager component with the
        # capacity JSON stored for this client.
        resources = json.loads(template["local_client_resources"])
        gpu_manager_path = "nvflare.app_common.resource_managers.gpu_resource_manager.GPUResourceManager"
        for component in resources["components"]:
            if component["path"] == gpu_manager_path:
                component["args"] = json.loads(client.capacity.capacity)
                break
        _write(
            os.path.join(dest_dir, "resources.json.default"),
            json.dumps(resources, indent=2),
            "t",
        )

        for file_name, template_key in (
            ("privacy.json.sample", "sample_privacy"),
            ("authorization.json.default", "default_authz"),
        ):
            _write(os.path.join(dest_dir, file_name), template[template_key], "t")

        # workspace folder file
        _write(
            os.path.join(client_dir, "readme.txt"),
            template["readme_fc"],
            "t",
        )

        # NOTE(review): the password travels on the zip command line and is
        # therefore visible in the process list while zip runs.
        run_args = ["zip", "-rq", "-P", key, "tmp.zip", "."]
        # check=True surfaces a failing zip directly instead of letting it show
        # up later as a missing tmp.zip.
        subprocess.run(run_args, cwd=tmp_dir, check=True)

        # The archive must be read into memory before the temp dir is removed.
        with open(os.path.join(tmp_dir, "tmp.zip"), "rb") as fo:
            fileobj = io.BytesIO(fo.read())
    return fileobj, f"{entity.name}.zip"
def gen_user(key, id):
    """Build the password-protected zip holding one admin user's workspace.

    The first ``Project`` row and the ``User`` row with primary key ``id``
    supply the names, organization, role and root CA material. The function
    issues a certificate for the user from the project root CA, renders
    ``fed_admin.json`` and ``fl_admin.sh`` into ``<user>/startup``, signs that
    folder, creates an empty ``local`` folder, adds a readme and a notebook,
    and zips the result with ``key`` as the password.

    Args:
        key: password handed to ``zip -P``.
        id: primary key passed to ``User.query.get`` (the name shadows the
            builtin but is kept for interface compatibility).

    Returns:
        A tuple ``(fileobj, filename)``: an ``io.BytesIO`` positioned at 0 with
        the zip bytes, and the name ``"<user email>.zip"``.

    Raises:
        subprocess.CalledProcessError: if ``zip`` exits with a non-zero status.
    """
    project = Project.query.first()
    server_name = project.server1
    user = User.query.get(id)
    entity = Entity(user.email, user.organization.name, user.role.name)
    issuer = Entity(project.short_name)
    signing_cert_pair = CertPair(issuer, project.root_key, project.root_cert)
    cert_pair = make_cert(entity, signing_cert_pair)

    config = json.loads(template["fed_admin"])
    replacement_dict = {"admin_name": entity.name, "cn": server_name, "admin_port": "8003", "docker_image": ""}

    if project.ha_mode:
        overseer_agent = {"path": "nvflare.ha.overseer_agent.HttpOverseerAgent"}
        overseer_agent["args"] = {
            "role": "admin",
            "overseer_end_point": f"https://{project.overseer}:8443/api/v1",
            "project": project.short_name,
            "name": entity.name,
        }
    else:
        overseer_agent = {"path": "nvflare.ha.dummy_overseer_agent.DummyOverseerAgent"}
        overseer_agent["args"] = {"sp_end_point": f"{project.server1}:8002:8003"}
    config["admin"].update({"overseer_agent": overseer_agent})

    with tempfile.TemporaryDirectory() as tmp_dir:
        user_dir = os.path.join(tmp_dir, entity.name)
        dest_dir = os.path.join(user_dir, "startup")
        os.mkdir(user_dir)
        os.mkdir(dest_dir)

        _write(os.path.join(dest_dir, "fed_admin.json"), json.dumps(config, indent=2), "t")
        _write(
            os.path.join(dest_dir, "fl_admin.sh"),
            utils.sh_replace(template["fl_admin_sh"], replacement_dict),
            "t",
            exe=True,
        )
        _write(os.path.join(dest_dir, "client.crt"), cert_pair.ser_cert, "b", exe=False)
        _write(os.path.join(dest_dir, "client.key"), cert_pair.ser_pri_key, "b", exe=False)
        _write(os.path.join(dest_dir, "rootCA.pem"), project.root_cert, "b", exe=False)

        # Sign everything written to startup so far; signature.json itself is
        # written afterwards. The handle is closed explicitly so the file is
        # fully flushed before the directory is zipped.
        signatures = utils.sign_all(dest_dir, deserialize_ca_key(project.root_key))
        with open(os.path.join(dest_dir, "signature.json"), "wt") as sig_file:
            json.dump(signatures, sig_file)

        # local folder creation
        dest_dir = os.path.join(user_dir, "local")
        os.mkdir(dest_dir)

        # workspace folder file
        _write(
            os.path.join(user_dir, "readme.txt"),
            template["readme_am"],
            "t",
        )
        _write(
            os.path.join(user_dir, "system_info.ipynb"),
            utils.sh_replace(template["adm_notebook"], replacement_dict),
            "t",
        )

        # NOTE(review): the password travels on the zip command line and is
        # therefore visible in the process list while zip runs.
        run_args = ["zip", "-rq", "-P", key, "tmp.zip", "."]
        # check=True surfaces a failing zip directly instead of letting it show
        # up later as a missing tmp.zip.
        subprocess.run(run_args, cwd=tmp_dir, check=True)

        # The archive must be read into memory before the temp dir is removed.
        with open(os.path.join(tmp_dir, "tmp.zip"), "rb") as fo:
            fileobj = io.BytesIO(fo.read())
    return fileobj, f"{entity.name}.zip"