Source code for nvflare.tool.package_checker.package_checker

# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import signal
from abc import ABC, abstractmethod
from collections import defaultdict
from subprocess import TimeoutExpired

from nvflare.tool.package_checker.check_rule import CHECK_PASSED, CheckResult, CheckRule
from nvflare.tool.package_checker.utils import run_command_in_subprocess, split_by_len


class PackageChecker(ABC):
    """Base class for checkers that validate a package and dry-run it.

    Subclasses provide the rules (``init_rules``), decide whether the checker
    applies (``should_be_checked``) and supply the command used for the dry
    run (``get_dry_run_command``). Results are collected in ``self.report``
    and rendered as a text table by ``print_report``.
    """

    def __init__(self):
        # Maps package path -> list of (check name, problem text, fix text).
        self.report = defaultdict(list)
        # Column widths of the printed table. The "Checks" and "How to fix"
        # columns grow in add_report(); the "Problems" column has a fixed
        # width and long texts are wrapped in print_report().
        self.check_len = len("Checks")
        self.problem_len = 80
        self.fix_len = len("How to fix")
        # Passed as ``timeout`` to ``process.communicate`` during the dry run.
        self.dry_run_timeout = 5
        self.package_path = None
        # Each entry is either a CheckRule or a list of rules run in order
        # (see check()).
        self.rules = []

    @abstractmethod
    def init_rules(self, package_path: str):
        """Populate ``self.rules`` for the package at ``package_path``."""
        pass

    def init(self, package_path: str):
        """Validate the package path, remember it and initialize the rules.

        Args:
            package_path: path to the package to check.

        Raises:
            RuntimeError: if ``package_path`` does not exist.
        """
        if not os.path.exists(package_path):
            raise RuntimeError(f"Package path: {package_path} does not exist.")
        self.package_path = os.path.abspath(package_path)
        # The rules receive the path exactly as given by the caller, while
        # self.package_path stores the absolute form.
        self.init_rules(package_path)

    @abstractmethod
    def should_be_checked(self) -> bool:
        """Check if this package should be checked by this checker."""
        pass

    @abstractmethod
    def get_dry_run_command(self) -> str:
        """Returns dry run command."""
        pass

    def get_dry_run_inputs(self):
        """Return the input fed to the dry run process, or None for no input."""
        return None

    def stop_dry_run(self, force: bool = True):
        """Kill the dry run process by matching its command line with ``pkill``.

        Args:
            force: currently unused; the process is always killed with ``-9``.
        """
        # todo: add gracefully shutdown command
        print("killing dry run process")
        command = self.get_dry_run_command()
        cmd = f"pkill -9 -f '{command}'"
        process = run_command_in_subprocess(cmd)
        out, err = process.communicate()
        print(f"killed dry run process output: {out}")
        print(f"killed dry run process err: {err}")

    def check(self) -> int:
        """Checks if the package is runnable on the current system.

        Runs every rule, records each result in the report and, when all
        required rules passed, runs the dry run. Any ``Exception`` raised
        while checking is recorded as a "Package Error" report entry.

        Returns:
            0: if no dry-run process started.
            1: if the dry-run process is started and return code is 0.
            2: if the dry-run process is started and return code is not 0.
        """
        ret_code = 0
        try:
            all_passed = True
            for rule in self.rules:
                if isinstance(rule, CheckRule):
                    result: CheckResult = rule(self.package_path, data=None)
                    self.add_report(rule.name, result.problem, result.solution)
                    if rule.required and result.problem != CHECK_PASSED:
                        all_passed = False
                elif isinstance(rule, list):
                    # Ordered rules: each rule receives the data produced by
                    # the previous one; a failed required rule stops the chain.
                    result = CheckResult()
                    for r in rule:
                        result = r(self.package_path, data=result.data)
                        self.add_report(r.name, result.problem, result.solution)
                        if r.required and result.problem != CHECK_PASSED:
                            all_passed = False
                            break

            # check dry run
            if all_passed:
                ret_code = self.check_dry_run()
        except Exception as e:
            self.add_report(
                "Package Error",
                f"Exception happens in checking: {e}, this package is not in correct format.",
                "Please download a new package.",
            )
        # Returning here rather than from a ``finally`` clause lets
        # KeyboardInterrupt / SystemExit propagate instead of being discarded.
        return ret_code

    def check_dry_run(self) -> int:
        """Runs dry run command.

        Returns:
            0: if no process started.
            1: if the process is started and return code is 0.
            2: if the process is started and return code is not 0.
        """
        command = self.get_dry_run_command()
        dry_run_input = self.get_dry_run_inputs()
        process = None
        try:
            process = run_command_in_subprocess(command)
            if dry_run_input is not None:
                out, _ = process.communicate(input=dry_run_input, timeout=self.dry_run_timeout)
            else:
                out, _ = process.communicate(timeout=self.dry_run_timeout)
            ret_code = process.returncode
            if ret_code == 0:
                self.add_report(
                    "Check dry run",
                    CHECK_PASSED,
                    "N/A",
                )
            else:
                self.add_report(
                    "Check dry run",
                    f"Can't start successfully: {out}",
                    "Please check the error message of dry run.",
                )
        except TimeoutExpired:
            try:
                os.killpg(process.pid, signal.SIGTERM)
            except OSError:
                # The process group may already be gone or not be signalable;
                # stopping it is best effort and must not hide the result.
                pass
            # Assumption, preflight check is focused on the connectivity, so we assume all sub-systems should
            # behave as designed if configured correctly.
            # In such case, a dry run for any of the sub systems (overseer, server(s), clients etc.) will
            # run as service forever once started, unless it is asked to stop. Therefore, we will get TimeoutExpired
            # with above assumption, we consider the sub-system as running in good condition if it is started running
            # in give timeout period
            self.add_report(
                "Check dry run",
                CHECK_PASSED,
                "N/A",
            )
        except Exception as e:
            # Launching or talking to the process failed; surface it in the
            # report instead of dropping the error silently.
            self.add_report(
                "Check dry run",
                f"Can't start successfully: {e}",
                "Please check the error message of dry run.",
            )

        # Computed after the try block (not in ``finally``) so that
        # KeyboardInterrupt / SystemExit are never swallowed.
        if not process:
            return 0
        return 1 if process.returncode == 0 else 2

    def add_report(self, check_name, problem_text: str, fix_text: str):
        """Record one check result and widen the table columns if needed.

        Args:
            check_name: name shown in the "Checks" column.
            problem_text: text shown in the "Problems" column.
            fix_text: text shown in the "How to fix" column.
        """
        self.report[self.package_path].append((check_name, problem_text, fix_text))
        self.check_len = max(self.check_len, len(check_name))
        self.fix_len = max(self.fix_len, len(fix_text))

    def _print_line(self):
        """Print a horizontal separator spanning the full table width."""
        print("|" + "-" * (self.check_len + self.problem_len + self.fix_len + 8) + "|")

    def _print_row(self, check, problem, fix):
        """Print one table row with each cell left-aligned to its column width."""
        print(
            "| {check:<{width1}s} | {problems:<{width2}s} | {fix:<{width3}s} |".format(
                check=check,
                problems=problem,
                fix=fix,
                width1=self.check_len,
                width2=self.problem_len,
                width3=self.fix_len,
            )
        )

    def print_report(self):
        """Print the collected results as one text table per package."""
        total_width = self.check_len + self.problem_len + self.fix_len + 10
        for package_path, results in self.report.items():
            # f-string so a report keyed by None (add_report before init)
            # still prints instead of raising TypeError.
            print(f"Checking Package: {package_path}")
            print("-" * total_width)
            if results:
                self._print_row("Checks", "Problems", "How to fix")
            else:
                print("| {:{}s} |".format("Passed", total_width - 4))
            for row in results:
                self._print_line()
                # Wrap the problem text to the column width; fall back to a
                # single empty line so an empty problem text cannot cause an
                # IndexError on lines[0].
                lines = split_by_len(row[1], max_len=self.problem_len) or [""]
                self._print_row(row[0], lines[0], row[2])
                for line in lines[1:]:
                    self._print_row("", line, "")
            print("-" * total_width)
            print()