Source code for nvflare.tool.preflight_check

# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

from nvflare.tool.package_checker import ClientPackageChecker, NVFlareConsolePackageChecker, ServerPackageChecker
from nvflare.tool.package_checker.package_checker import CheckStatus

_preflight_parser = None


def define_preflight_check_parser(parser):
    """Register the ``preflight-check`` command-line options on *parser*.

    The parser object is also stored in the module-level ``_preflight_parser``;
    ``check_packages`` passes it to ``handle_schema_flag``.

    Args:
        parser: an ``argparse.ArgumentParser`` (or sub-parser) to extend.
    """
    global _preflight_parser
    _preflight_parser = parser

    # "--package_path" (underscore spelling) is accepted for backward compatibility.
    path_flags = ("-p", "--package-path", "--package_path")
    path_options = {
        "dest": "package_path",
        "required": True,
        "type": str,
        "help": "path to specific package",
    }
    parser.add_argument(*path_flags, **path_options)

    parser.add_argument("--schema", action="store_true", help="print command schema as JSON and exit")
def _run_package_checker(checker, package_path):
    """Run a single package checker and return its result entry.

    The entry has the form ``{"component": str, "status": "pass" | "fail"}``
    plus ``"details"`` when the checker exposes a non-empty ``last_error``
    string. A checker whose ``should_be_checked()`` is False is reported as
    ``"pass"``. Dry-run cleanup for the ``*_WITH_CLEANUP`` statuses runs in a
    ``finally`` block, so it still happens if reporting raises.
    """
    checker.init(package_path=package_path)
    check_status = CheckStatus.PASS
    try:
        if checker.should_be_checked():
            check_status = checker.check()
        checker.print_report()

        component_name = checker.__class__.__name__.replace("PackageChecker", "").lower()
        failed = check_status in (CheckStatus.FAIL, CheckStatus.FAIL_WITH_CLEANUP)
        check_result = {"component": component_name, "status": "fail" if failed else "pass"}
        details = getattr(checker, "last_error", None)
        if isinstance(details, str) and details:
            check_result["details"] = details
        return check_result
    finally:
        # These statuses require stopping the dry run; it is forced only after a failure.
        if check_status == CheckStatus.PASS_WITH_CLEANUP:
            checker.stop_dry_run(force=False)
        elif check_status == CheckStatus.FAIL_WITH_CLEANUP:
            checker.stop_dry_run(force=True)


def check_packages(args):
    """Run the preflight checks on a startup-kit package and emit a summary.

    Args:
        args: parsed CLI namespace. ``package_path`` is required. ``_argv``
            (raw argument list, forwarded to ``handle_schema_flag``) and
            ``_raw_sub_command`` (used to detect the deprecated
            ``preflight_check`` spelling) are optional.

    If ``package_path`` is not a directory, or has no ``startup``
    sub-directory, an ``INVALID_ARGS`` error is emitted with exit code 4. Note:
    ``output_error`` presumably exits; the ``return`` is a fallback.

    Otherwise the server, client and console checkers run in that order. The
    result is emitted via ``output_ok`` as
    ``{"package", "checks", "overall"}``, with exit code 0 when every check
    passes and 1 otherwise.
    """
    from nvflare.tool.cli_output import output_error, output_ok, print_human
    from nvflare.tool.cli_schema import handle_schema_flag

    handle_schema_flag(
        _preflight_parser,
        "nvflare preflight-check",
        ["nvflare preflight-check -p /path/to/package"],
        getattr(args, "_argv", []),
    )
    if getattr(args, "_raw_sub_command", None) == "preflight_check":
        print_human("Note: 'preflight_check' is deprecated; use 'nvflare preflight-check' instead.")

    package_path = args.package_path
    if not os.path.isdir(package_path):
        output_error("INVALID_ARGS", exit_code=4, detail=f"package_path {package_path} is not a valid directory")
        return
    if not os.path.isdir(os.path.join(package_path, "startup")):
        output_error("INVALID_ARGS", exit_code=4, detail=f"package in {package_path} is not in the correct format")
        return

    package_checkers = [
        ServerPackageChecker(),
        ClientPackageChecker(),
        NVFlareConsolePackageChecker(),
    ]
    checks = [_run_package_checker(p, package_path) for p in package_checkers]
    overall_pass = all(c["status"] == "pass" for c in checks)

    output_ok(
        {
            "package": os.path.abspath(package_path),
            "checks": checks,
            "overall": "pass" if overall_pass else "fail",
        },
        exit_code=0 if overall_pass else 1,
    )
def main():
    """Standalone entry point: parse ``sys.argv`` and run the preflight check.

    ``check_packages`` forwards ``args._argv`` to ``handle_schema_flag``. This
    attribute is presumably set by the ``nvflare`` CLI dispatcher, but argparse
    never sets it. Without populating it here, ``--schema`` would be ignored
    when this module is run directly.
    """
    import sys

    # The first positional argument of ArgumentParser is ``prog``.
    parser = argparse.ArgumentParser("nvflare preflight check")
    define_preflight_check_parser(parser)
    args = parser.parse_args()
    # Provide the raw argument list in the attribute check_packages reads.
    args._argv = sys.argv[1:]
    check_packages(args)
if __name__ == "__main__":
    # Run the preflight check when this module is executed as a script.
    main()