Source code for nvflare.lighter.provision

# Copyright (c) 2021-2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

import argparse
import os
import pathlib
import shutil
import sys
import webbrowser

from nvflare.fuel.utils.class_utils import instantiate_class
from nvflare.lighter.spec import Participant, Project, Provisioner
from nvflare.lighter.utils import load_yaml


[docs]def main(): parser = argparse.ArgumentParser() parser.add_argument("-p", "--project_file", type=str, default="project.yml", help="file to describe FL project") parser.add_argument("-w", "--workspace", type=str, default="workspace", help="directory used by provision") parser.add_argument("-c", "--custom_folder", type=str, default=".", help="additional folder to load python codes") parser.add_argument( "-u", "--ui_tool", action="store_true", help="Run provisioning UI tool to generate project.yml file", ) args = parser.parse_args() file_path = pathlib.Path(__file__).parent.absolute() current_path = os.getcwd() custom_folder_path = os.path.join(current_path, args.custom_folder) sys.path.append(custom_folder_path) print("Path list (sys.path) for python codes loading: {}".format(sys.path)) # main project file project_file = args.project_file current_project_yml = os.path.join(current_path, "project.yml") if len(sys.argv) == 1 and not os.path.exists(current_project_yml): answer = input( f"No project.yml found in current folder. Is it OK to generate one at {current_project_yml} for you? (y/N) " ) if answer.strip().upper() == "Y": shutil.copyfile(os.path.join(file_path, "project.yml"), current_project_yml) print(f"{current_project_yml} was created. Please edit it to fit your FL configuration.") exit(0) if args.ui_tool: ui_helper_path = os.path.join(file_path, "provision_helper.html") ui_helper_url = f"file://{ui_helper_path}" webbrowser.open_new_tab(ui_helper_url) print( "\n******\n" "Now launching provisioning UI tool.\n" "After generating project.yml in the browser and saving it to your local folder,\n" "please re-run provision with -p option, pointing to the generated project.yml, to generate all packages.\n******\n" ) exit(0) workspace = args.workspace workspace_full_path = os.path.join(current_path, workspace) project_full_path = os.path.join(current_path, project_file) print(f"Project yaml file: {project_full_path}.") project_dict = load_yaml(project_full_path) api_version = project_dict.get("api_version") if api_version not in [3]: raise ValueError(f"API version expected 3 but found {api_version}") project_name = project_dict.get("name") project_description = project_dict.get("description", "") participants = list() for p in project_dict.get("participants"): participants.append(Participant(**p)) project = Project(name=project_name, description=project_description, participants=participants) n_servers = len(project.get_participants_by_type("server", first_only=False)) if n_servers > 2: print( f"Configuration error: Expect 2 or 1 server to be provisioned. {project_full_path} contains {n_servers} servers." ) return builders = list() for b in project_dict.get("builders"): path = b.get("path") args = b.get("args") builders.append(instantiate_class(path, args)) provisioner = Provisioner(workspace_full_path, builders) provisioner.provision(project)
if __name__ == "__main__": main()