Source code for nvflare.dashboard.application.blob

# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import json
import os
import subprocess
import tempfile

from nvflare.lighter import tplt_utils, utils

from .cert import CertPair, Entity, deserialize_ca_key, make_cert
from .models import Client, Project, User

lighter_folder = os.path.dirname(utils.__file__)
template = utils.load_yaml(os.path.join(lighter_folder, "impl", "master_template.yml"))
supported_csps = ["aws", "azure"]
for csp in supported_csps:
    csp_template_file = os.path.join(lighter_folder, "impl", f"{csp}_template.yml")
    if os.path.exists(csp_template_file):
        template.update(utils.load_yaml(csp_template_file))


[docs]def get_csp_start_script_name(csp): return f"{csp}_start.sh"
[docs]def gen_overseer(key): project = Project.query.first() entity = Entity(project.overseer) issuer = Entity(project.short_name) signing_cert_pair = CertPair(issuer, project.root_key, project.root_cert) cert_pair = make_cert(entity, signing_cert_pair) with tempfile.TemporaryDirectory() as tmp_dir: overseer_dir = os.path.join(tmp_dir, entity.name) dest_dir = os.path.join(overseer_dir, "startup") os.mkdir(overseer_dir) os.mkdir(dest_dir) utils._write( os.path.join(dest_dir, "start.sh"), template["start_ovsr_sh"], "t", exe=True, ) utils._write( os.path.join(dest_dir, "gunicorn.conf.py"), utils.sh_replace(template["gunicorn_conf_py"], {"port": "8443"}), "t", exe=False, ) utils._write_pki(type="overseer", dest_dir=dest_dir, cert_pair=cert_pair, root_cert=project.root_cert) run_args = ["zip", "-rq", "-P", key, "tmp.zip", "."] subprocess.run(run_args, cwd=tmp_dir) fileobj = io.BytesIO() with open(os.path.join(tmp_dir, "tmp.zip"), "rb") as fo: fileobj.write(fo.read()) fileobj.seek(0) return fileobj, f"{entity.name}.zip"
[docs]def gen_server(key, first_server=True): project = Project.query.first() if first_server: entity = Entity(project.server1) fl_port = 8002 admin_port = 8003 else: entity = Entity(project.server2) fl_port = 8102 admin_port = 8103 issuer = Entity(project.short_name) signing_cert_pair = CertPair(issuer, project.root_key, project.root_cert) cert_pair = make_cert(entity, signing_cert_pair) config = json.loads(template["fed_server"]) server_0 = config["servers"][0] server_0["name"] = project.short_name server_0["service"]["target"] = f"{entity.name}:{fl_port}" server_0["service"]["scheme"] = project.scheme if hasattr(project, "scheme") else "grpc" server_0["admin_host"] = entity.name server_0["admin_port"] = admin_port if project.ha_mode: overseer_agent = {"path": "nvflare.ha.overseer_agent.HttpOverseerAgent"} overseer_agent["args"] = { "role": "server", "overseer_end_point": f"https://{project.overseer}:8443/api/v1", "project": project.short_name, "name": entity.name, "fl_port": str(fl_port), "admin_port": str(admin_port), } else: overseer_agent = {"path": "nvflare.ha.dummy_overseer_agent.DummyOverseerAgent"} overseer_agent["args"] = {"sp_end_point": f"{project.server1}:8002:8003"} config["overseer_agent"] = overseer_agent replacement_dict = { "admin_port": admin_port, "fed_learn_port": fl_port, "config_folder": "config", "ha_mode": "true" if project.ha_mode else "false", "docker_image": project.app_location.split(" ")[-1] if project.app_location else "nvflare/nvflare", "org_name": "", "type": "server", "cln_uid": "", } tplt = tplt_utils.Template(template) with tempfile.TemporaryDirectory() as tmp_dir: server_dir = os.path.join(tmp_dir, entity.name) dest_dir = os.path.join(server_dir, "startup") os.mkdir(server_dir) os.mkdir(dest_dir) utils._write_common( type="server", dest_dir=dest_dir, template=template, tplt=tplt, replacement_dict=replacement_dict, config=config, ) utils._write_pki(type="server", dest_dir=dest_dir, cert_pair=cert_pair, root_cert=project.root_cert) if not project.ha_mode: for csp in supported_csps: utils._write( os.path.join(dest_dir, get_csp_start_script_name(csp)), tplt.get_start_sh(csp=csp, type="server", entity=entity), "t", exe=True, ) signatures = utils.sign_all(dest_dir, deserialize_ca_key(project.root_key)) json.dump(signatures, open(os.path.join(dest_dir, "signature.json"), "wt")) # local folder creation dest_dir = os.path.join(server_dir, "local") os.mkdir(dest_dir) utils._write_local(type="server", dest_dir=dest_dir, template=template) # workspace folder file utils._write( os.path.join(server_dir, "readme.txt"), template["readme_fs"], "t", ) run_args = ["zip", "-rq", "-P", key, "tmp.zip", "."] subprocess.run(run_args, cwd=tmp_dir) fileobj = io.BytesIO() with open(os.path.join(tmp_dir, "tmp.zip"), "rb") as fo: fileobj.write(fo.read()) fileobj.seek(0) return fileobj, f"{entity.name}.zip"
[docs]def gen_client(key, id): project = Project.query.first() client = Client.query.get(id) entity = Entity(client.name, client.organization.name) issuer = Entity(project.short_name) signing_cert_pair = CertPair(issuer, project.root_key, project.root_cert) cert_pair = make_cert(entity, signing_cert_pair) config = json.loads(template["fed_client"]) config["servers"][0]["name"] = project.short_name config["servers"][0]["service"]["scheme"] = project.scheme if hasattr(project, "scheme") else "grpc" replacement_dict = { "client_name": entity.name, "config_folder": "config", "docker_image": project.app_location.split(" ")[-1] if project.app_location else "nvflare/nvflare", "org_name": entity.org, "type": "client", "cln_uid": f"uid={entity.name}", } if project.ha_mode: overseer_agent = {"path": "nvflare.ha.overseer_agent.HttpOverseerAgent"} overseer_agent["args"] = { "role": "client", "overseer_end_point": f"https://{project.overseer}:8443/api/v1", "project": project.short_name, "name": entity.name, } else: overseer_agent = {"path": "nvflare.ha.dummy_overseer_agent.DummyOverseerAgent"} overseer_agent["args"] = {"sp_end_point": f"{project.server1}:8002:8003"} config["overseer_agent"] = overseer_agent tplt = tplt_utils.Template(template) with tempfile.TemporaryDirectory() as tmp_dir: client_dir = os.path.join(tmp_dir, entity.name) dest_dir = os.path.join(client_dir, "startup") os.mkdir(client_dir) os.mkdir(dest_dir) utils._write_pki(type="client", dest_dir=dest_dir, cert_pair=cert_pair, root_cert=project.root_cert) utils._write_common( type="client", dest_dir=dest_dir, template=template, tplt=tplt, replacement_dict=replacement_dict, config=config, ) for csp in supported_csps: utils._write( os.path.join(dest_dir, get_csp_start_script_name(csp)), tplt.get_start_sh(csp=csp, type="client", entity=entity), "t", exe=True, ) signatures = utils.sign_all(dest_dir, deserialize_ca_key(project.root_key)) json.dump(signatures, open(os.path.join(dest_dir, "signature.json"), "wt")) # local folder creation dest_dir = os.path.join(client_dir, "local") os.mkdir(dest_dir) utils._write_local(type="client", dest_dir=dest_dir, template=template, capacity=client.capacity.capacity) # workspace folder file utils._write( os.path.join(client_dir, "readme.txt"), template["readme_fc"], "t", ) run_args = ["zip", "-rq", "-P", key, "tmp.zip", "."] subprocess.run(run_args, cwd=tmp_dir) fileobj = io.BytesIO() with open(os.path.join(tmp_dir, "tmp.zip"), "rb") as fo: fileobj.write(fo.read()) fileobj.seek(0) return fileobj, f"{entity.name}.zip"
[docs]def gen_user(key, id): project = Project.query.first() server_name = project.server1 user = User.query.get(id) entity = Entity(user.email, user.organization.name, user.role.name) issuer = Entity(project.short_name) signing_cert_pair = CertPair(issuer, project.root_key, project.root_cert) cert_pair = make_cert(entity, signing_cert_pair) config = json.loads(template["fed_admin"]) replacement_dict = {"admin_name": entity.name, "cn": server_name, "admin_port": "8003", "docker_image": ""} if project.ha_mode: overseer_agent = {"path": "nvflare.ha.overseer_agent.HttpOverseerAgent"} overseer_agent["args"] = { "role": "admin", "overseer_end_point": f"https://{project.overseer}:8443/api/v1", "project": project.short_name, "name": entity.name, } else: overseer_agent = {"path": "nvflare.ha.dummy_overseer_agent.DummyOverseerAgent"} overseer_agent["args"] = {"sp_end_point": f"{project.server1}:8002:8003"} config["admin"].update({"overseer_agent": overseer_agent}) with tempfile.TemporaryDirectory() as tmp_dir: user_dir = os.path.join(tmp_dir, entity.name) dest_dir = os.path.join(user_dir, "startup") os.mkdir(user_dir) os.mkdir(dest_dir) utils._write(os.path.join(dest_dir, "fed_admin.json"), json.dumps(config, indent=2), "t") utils._write( os.path.join(dest_dir, "fl_admin.sh"), utils.sh_replace(template["fl_admin_sh"], replacement_dict), "t", exe=True, ) utils._write_pki(type="client", dest_dir=dest_dir, cert_pair=cert_pair, root_cert=project.root_cert) signatures = utils.sign_all(dest_dir, deserialize_ca_key(project.root_key)) json.dump(signatures, open(os.path.join(dest_dir, "signature.json"), "wt")) # local folder creation dest_dir = os.path.join(user_dir, "local") os.mkdir(dest_dir) # workspace folder file utils._write( os.path.join(user_dir, "readme.txt"), template["readme_am"], "t", ) utils._write( os.path.join(user_dir, "system_info.ipynb"), utils.sh_replace(template["adm_notebook"], replacement_dict), "t", ) run_args = ["zip", "-rq", "-P", key, "tmp.zip", "."] subprocess.run(run_args, cwd=tmp_dir) fileobj = io.BytesIO() with open(os.path.join(tmp_dir, "tmp.zip"), "rb") as fo: fileobj.write(fo.read()) fileobj.seek(0) return fileobj, f"{entity.name}.zip"