# Source code for nvflare.dashboard.application.blob

# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import json
import os
import re
import subprocess
import tempfile

from nvflare.dashboard.config import PropertyManager
from nvflare.lighter.constants import PropKey
from nvflare.lighter.entity import Project as ProvProject
from nvflare.lighter.impl.aws import AWSBuilder
from nvflare.lighter.impl.azure import AzureBuilder
from nvflare.lighter.impl.cert import CertBuilder
from nvflare.lighter.impl.signature import SignatureBuilder
from nvflare.lighter.impl.static_file import StaticFileBuilder
from nvflare.lighter.impl.workspace import WorkspaceBuilder
from nvflare.lighter.provisioner import Provisioner

from .cert import deserialize_ca_key
from .models import Client, Project, User
from .store import Store, inc_dl


class DummyLogger:
    """Quiet logger handed to the Provisioner.

    Info, debug and warning messages are dropped silently. Error messages are
    the only ones surfaced; they are printed to stdout prefixed with ``ERROR:``.
    """

    def debug(self, msg: str):
        """Drop a debug message."""
        return None

    def info(self, msg: str):
        """Drop an informational message."""
        return None

    def warning(self, msg: str):
        """Drop a warning message."""
        return None

    def error(self, msg: str):
        """Print an error message to stdout."""
        print(f"ERROR: {msg}")
def _get_provisioner(prop_mgr: PropertyManager, root_dir: str, scheme, docker_image=None):
    """Create a Provisioner that writes its output under ``root_dir``.

    Args:
        prop_mgr: property source; its project-level ``"scheme"`` property, when
            present, overrides the ``scheme`` argument.
        root_dir: directory the Provisioner is rooted at.
        scheme: communication scheme used when the project does not define one.
        docker_image: image name forwarded to the StaticFileBuilder.

    Returns:
        A Provisioner configured with the workspace, static-file, AWS, Azure,
        cert and signature builders (in that order) and no packager.
    """
    # NOTE(review): the overseer endpoint is fixed at server:8002:8003 and does not
    # follow the configured FL/admin ports -- confirm this is intended.
    dummy_overseer = {
        "path": "nvflare.ha.dummy_overseer_agent.DummyOverseerAgent",
        "overseer_exists": False,
        "args": {"sp_end_point": "server:8002:8003"},
    }
    effective_scheme = prop_mgr.get_project_prop("scheme", scheme)

    # Builders run in list order, so the workspace builder must come first.
    builder_chain = [WorkspaceBuilder()]
    builder_chain.append(
        StaticFileBuilder(
            config_folder="config",
            scheme=effective_scheme,
            docker_image=docker_image,
            overseer_agent=dummy_overseer,
        )
    )
    builder_chain.extend([AWSBuilder(), AzureBuilder(), CertBuilder(), SignatureBuilder()])

    # TODO: a Packager object still needs to be added to the provisioner.
    # Should the Provisioner be created from a project.yml file instead?
    return Provisioner(root_dir, builder_chain, None)
def gen_server_blob(key):
    """Build the server's startup kit.

    Args:
        key: download key, used as the zip password.

    Returns:
        ``(fileobj, filename)`` as produced by ``_gen_kit`` for the server.
    """
    # With no target callback, _gen_kit packages the server itself.
    return _gen_kit(key, prepare_target_cb=None)
def _gen_kit(download_key, prepare_target_cb=None, **cb_kwargs): # validate download_key allowed_pattern = r"^[A-Za-z0-9]+$" if not re.match(allowed_pattern, download_key): raise RuntimeError(f"ERROR: detected unsafe download key: {download_key}") prop_mgr = PropertyManager() u = Store.get_user(1) super_user = u.get("user") fl_port = prop_mgr.get_server_prop(PropKey.FED_LEARN_PORT, 8002) admin_port = prop_mgr.get_server_prop(PropKey.ADMIN_PORT, fl_port) with tempfile.TemporaryDirectory() as tmp_dir: project = Project.query.first() scheme = project.scheme if hasattr(project, "scheme") else "grpc" docker_image = project.app_location.split(" ")[-1] if project.app_location else "nvflare/nvflare" provisioner = _get_provisioner(prop_mgr, tmp_dir, scheme, docker_image) # the root key is protected by password root_pri_key = deserialize_ca_key(project.root_key) proj_props = {PropKey.API_VERSION: 3} if project.project_props: proj_props.update(json.loads(project.project_props)) prov_project = ProvProject( project.short_name, project.description, props=proj_props, root_private_key=root_pri_key, serialized_root_cert=project.root_cert, ) # use org of superuser org = super_user.get("organization", "nvflare") server_name = project.server1 server_props = { PropKey.FED_LEARN_PORT: fl_port, PropKey.ADMIN_PORT: admin_port, PropKey.DEFAULT_HOST: server_name, } if project.server_props: server_props.update(json.loads(project.server_props)) extra = prop_mgr.get_server_props() if extra: server_props.update(extra) server = prov_project.set_server( name=server_name, org=org, props=server_props, ) target = server if prepare_target_cb is not None: target = prepare_target_cb(prop_mgr, prov_project, **cb_kwargs) ctx = provisioner.provision(prov_project, logger=DummyLogger()) result_dir = ctx.get_result_location() ent_dir = os.path.join(result_dir, target.name) subprocess.run(["zip", "-rq", "-P", download_key, "tmp.zip", "."], cwd=ent_dir) fileobj = io.BytesIO() with open(os.path.join(ent_dir, 
"tmp.zip"), "rb") as fo: fileobj.write(fo.read()) fileobj.seek(0) return fileobj, f"{target.name}.zip"
def gen_client_blob(key, id):
    """Build the startup kit for one client.

    Args:
        key: download key, used as the zip password.
        id: primary key of the Client record. The name shadows the builtin
            but is kept for caller compatibility.

    Returns:
        ``(fileobj, filename)`` as produced by ``_gen_kit`` for that client.
    """
    client_id = id
    return _gen_kit(key, prepare_target_cb=_prepare_client, client_id=client_id)
def _prepare_client(prop_mgr: PropertyManager, prov_project: ProvProject, client_id):
    """Add the client ``client_id`` to ``prov_project`` and return the new participant.

    Also bumps the client's download counter via ``inc_dl``. Participant props are
    assembled in this order, with later sources winning:

    1. the stored JSON ``client.props``;
    2. the ``capacity`` JSON, stored under ``PropKey.CAPACITY``;
    3. the PropertyManager's client props.
    """
    client = Client.query.get(client_id)
    inc_dl(Client, client_id)

    props = json.loads(client.props) if client.props else {}
    if client.capacity:
        props[PropKey.CAPACITY] = json.loads(client.capacity)

    overrides = prop_mgr.get_client_props()
    if overrides:
        props.update(overrides)

    return prov_project.add_client(
        name=client.name,
        org=client.organization.name,
        props=props,
    )
def gen_user_blob(key, id):
    """Build the admin startup kit for one user.

    Args:
        key: download key, used as the zip password.
        id: primary key of the User record. The name shadows the builtin
            but is kept for caller compatibility.

    Returns:
        ``(fileobj, filename)`` as produced by ``_gen_kit`` for that user.
    """
    user_id = id
    return _gen_kit(key, prepare_target_cb=_prepare_user, user_id=user_id)
def _prepare_user(prop_mgr: PropertyManager, prov_project: ProvProject, user_id):
    """Add the user ``user_id`` to ``prov_project`` as an admin and return it.

    Also bumps the user's download counter via ``inc_dl``. The admin is named by
    the user's email and belongs to the user's organization. Props start from the
    user's role, then the stored JSON ``user.props`` and finally the
    PropertyManager's admin props are layered on top, with later sources winning.
    """
    user = User.query.get(user_id)
    inc_dl(User, user_id)

    admin_props = {PropKey.ROLE: user.role.name}
    if user.props:
        admin_props.update(json.loads(user.props))
    admin_props.update(prop_mgr.get_admin_props() or {})

    return prov_project.add_admin(
        name=user.email,
        org=user.organization.name,
        props=admin_props,
    )