Source code for nvflare.tool.package_checker.server_package_checker

# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import shutil
import sys

from .check_rule import CheckAddressBinding, CheckOverseerRunning, CheckWriting
from .package_checker import PackageChecker
from .utils import NVFlareConfig, NVFlareRole

SERVER_SCRIPT = "nvflare.private.fed.app.server.server_train"


def _get_server_fed_config(package_path: str):
    startup = os.path.join(package_path, "startup")
    fed_config_file = os.path.join(startup, NVFlareConfig.SERVER)
    with open(fed_config_file, "r") as f:
        fed_config = json.load(f)
    return fed_config


def _get_snapshot_storage_root(package_path: str) -> str:
    fed_config = _get_server_fed_config(package_path)
    snapshot_storage_root = ""
    if (
        fed_config.get("snapshot_persistor", {}).get("path")
        == "nvflare.app_common.state_persistors.storage_state_persistor.StorageStatePersistor"
    ):
        storage = fed_config["snapshot_persistor"].get("args", {}).get("storage")
        if storage["path"] == "nvflare.app_common.storages.filesystem_storage.FilesystemStorage":
            snapshot_storage_root = storage["args"]["root_dir"]
    return snapshot_storage_root


def _get_job_storage_root(package_path: str) -> str:
    fed_config = _get_server_fed_config(package_path)
    job_storage_root = ""
    for c in fed_config.get("components", []):
        if c.get("path") == "nvflare.apis.impl.job_def_manager.SimpleJobDefManager":
            job_storage_root = c["args"]["uri_root"]
    return job_storage_root


def _get_grpc_host_and_port(package_path: str) -> (str, int):
    fed_config = _get_server_fed_config(package_path)
    server_conf = fed_config["servers"][0]
    grpc_service_config = server_conf["service"]
    grpc_target_address = grpc_service_config["target"]
    _, port = grpc_target_address.split(":")
    return "localhost", int(port)


def _get_admin_host_and_port(package_path: str) -> (str, int):
    fed_config = _get_server_fed_config(package_path)
    server_conf = fed_config["servers"][0]
    return "localhost", int(server_conf["admin_port"])


[docs]class ServerPackageChecker(PackageChecker): def __init__(self): super().__init__() self.snapshot_storage_root = None self.job_storage_root = None
[docs] def init_rules(self, package_path): self.dry_run_timeout = 3 self.rules = [ CheckOverseerRunning(name="Check overseer running", role=NVFlareRole.SERVER), CheckAddressBinding(name="Check grpc port binding", get_host_and_port_from_package=_get_grpc_host_and_port), CheckAddressBinding( name="Check admin port binding", get_host_and_port_from_package=_get_admin_host_and_port ), CheckWriting(name="Check snapshot storage writable", get_filename_from_package=_get_snapshot_storage_root), CheckWriting(name="Check job storage writable", get_filename_from_package=_get_job_storage_root), ]
[docs] def should_be_checked(self) -> bool: startup = os.path.join(self.package_path, "startup") if os.path.exists(os.path.join(startup, NVFlareConfig.SERVER)): return True return False
[docs] def get_dry_run_command(self) -> str: command = ( f"{sys.executable} -m {SERVER_SCRIPT}" f" -m {self.package_path} -s {NVFlareConfig.SERVER}" " --set secure_train=true config_folder=config" ) self.snapshot_storage_root = _get_snapshot_storage_root(self.package_path) self.job_storage_root = _get_job_storage_root(self.package_path) return command
[docs] def stop_dry_run(self, force=True): super().stop_dry_run(force=force) if os.path.exists(self.snapshot_storage_root): shutil.rmtree(self.snapshot_storage_root) if os.path.exists(self.job_storage_root): shutil.rmtree(self.job_storage_root)