# Source code for nvflare.fuel.utils.config_service

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
from typing import Dict, List, Optional, Union

from nvflare.fuel.utils.config import Config, ConfigFormat
from nvflare.fuel.utils.config_factory import ConfigFactory

# Prefix of the OS environment variable names consulted as the last-resort source of a
# config var (see ConfigService._get_var_from_os_env); env var names are upper-case and
# start with this prefix.
ENV_VAR_PREFIX = "NVFLARE_"


def find_file_in_dir(file_basename, path) -> Union[None, str]:
    """
    Walk a directory tree and locate a file by its base name.

    The tree is traversed top-down with ``os.walk``; the first directory (in walk order)
    that directly contains ``file_basename`` wins.

    Args:
        file_basename: base name of the file to look for
        path: root directory of the search

    Returns:
        the full path of the matching file, or None when there is no match
    """
    # Lazy generator: os.walk stops being consumed as soon as the first match is produced.
    matches = (
        os.path.join(folder, file_basename)
        for folder, _subdirs, filenames in os.walk(path)
        if file_basename in filenames
    )
    return next(matches, None)
def search_file(file_basename: str, dirs: List[str]) -> Union[None, str]:
    """
    Find a file by searching a list of directories in order.

    Directories are searched in the order given and the first match is returned, so
    earlier directories take precedence over later ones. Each directory is searched
    recursively via ``find_file_in_dir``.

    Args:
        file_basename: base name of the file to be found
        dirs: list of directories to search; a single directory given as a str is also
            accepted, and None or an empty list means there is nothing to search

    Returns:
        the full path of the file, if found; None if not found
    """
    # Nothing to search: previously None raised TypeError when iterated.
    if not dirs:
        return None
    if isinstance(dirs, str):
        dirs = [dirs]
    for d in dirs:
        f = find_file_in_dir(file_basename, d)
        if f:
            return f
    return None
class ConfigService:
    """
    The ConfigService provides a global configuration service that can be used by any component at any layer.

    The ConfigService manages config information and makes it available to any component, in two ways:

    1. Config info is preloaded into predefined sections. Callers can get the config data by a section name.
    2. Manages config path (a list of directories) and loads files from the path.

    Config files are loaded through ``ConfigFactory``; besides JSON, the other formats it
    recognizes (e.g. Pyhocon ``.conf`` and YAML, plus their ``.default`` variants) are supported.

    Vars can be read with the ``get_*_var`` methods. Sources are consulted in this order:
    command args passed to ``initialize``, the ``var_dict`` passed to ``initialize``, the
    given config sources, and finally OS environment variables.

    All state lives in class attributes, so it is shared by every caller.
    """

    logger = logging.getLogger(__name__)

    _sections = {}  # section name => config dict (loaded by initialize or added by add_section)
    _config_path = []  # directories searched for config files
    _cmd_args = None  # dict copy of the parsed argparse.Namespace attributes, if any
    _var_dict = None  # additional vars passed to initialize()
    _var_values = {}  # var name => first non-None value returned by a get_*_var call

    @classmethod
    def initialize(cls, section_files: Dict[str, str], config_path: List[str], parsed_args=None, var_dict=None):
        """
        Initialize the ConfigService.

        Configuration is divided into sections, and each section must have a config file.
        Only specify the base name of the config file.
        Config path is provided to locate config files. Files are searched in the order of the
        directories in config_path. If multiple directories contain the same file name, then the
        first one is used.

        Args:
            section_files: dict: section name => config file base name
            config_path: list of config directories
            parsed_args: command args for starting the program (an argparse.Namespace)
            var_dict: dict for additional vars

        Raises:
            TypeError: if section_files is not a dict or config_path is not a list
            ValueError: if config_path is empty or has a non-str / non-directory entry, or if
                var_dict or parsed_args has the wrong type
            FileNotFoundError: if a section's config file cannot be found

        """
        if not isinstance(section_files, dict):
            raise TypeError(f"section_files must be dict but got {type(section_files)}")

        if not isinstance(config_path, list):
            raise TypeError(f"config_path must be list but got {type(config_path)}")

        if not config_path:
            raise ValueError("config_path is empty")

        if var_dict and not isinstance(var_dict, dict):
            raise ValueError(f"var_dict must be dict but got {type(var_dict)}")

        for d in config_path:
            if not isinstance(d, str):
                raise ValueError(f"config_path must contain str but got {type(d)}")
            if not os.path.exists(d):
                raise ValueError(f"directory '{d}' does not exist")
            if not os.path.isdir(d):
                raise ValueError(f"'{d}' is not a valid directory")

        cls._config_path = config_path

        for section, file_basename in section_files.items():
            # raises FileNotFoundError when the file is missing (raise_exception defaults to True)
            cls._sections[section] = cls.load_config_dict(file_basename, cls._config_path)

        cls._var_dict = var_dict

        if parsed_args:
            if not isinstance(parsed_args, argparse.Namespace):
                raise ValueError(f"parsed_args must be argparse.Namespace but got {type(parsed_args)}")
            # copy so later mutation of the Namespace does not leak into the service
            cls._cmd_args = dict(parsed_args.__dict__)

    @classmethod
    def get_section(cls, name: str):
        """Return the config data of the named section, or None if there is no such section."""
        return cls._sections.get(name)

    @classmethod
    def add_section(cls, section_name: str, data: dict, overwrite_existing: bool = True):
        """
        Add a section to the config data.

        Args:
            section_name: name of the section to be added
            data: data of the section
            overwrite_existing: if section already exists, whether to overwrite

        Raises:
            TypeError: if section_name is not a str or data is not a dict

        """
        if not isinstance(section_name, str):
            raise TypeError(f"section name must be str but got {type(section_name)}")
        if not isinstance(data, dict):
            raise TypeError(f"config data must be dict but got {type(data)}")

        if overwrite_existing or section_name not in cls._sections:
            cls._sections[section_name] = data

    @classmethod
    def load_configuration(cls, file_basename: str) -> Optional["Config"]:
        """Load a config file from the configured config path via ConfigFactory.load_config."""
        return ConfigFactory.load_config(file_basename, cls._config_path)

    @classmethod
    def load_config_dict(
        cls, file_basename: str, search_dirs: Optional[List] = None, raise_exception: bool = True
    ) -> Optional[Dict]:
        """
        Load a specified config file (ignoring the extension).

        Args:
            file_basename: base name of the config file to be loaded.
                For example, for file_basename = config_fed_server.json, the function searches
                for a config file matching
                config_fed_server.[json|json.default|conf|conf.default|yml|yml.default]
                in the given search directories.
                If json or json.default is not found, it switches to Pyhocon [.conf] or the
                corresponding default file; if still not found, it switches to YAML files.
                OmegaConf is used to load YAML.
            search_dirs: which directories to search.
            raise_exception: if True, raise an exception when the file cannot be found

        Returns:
            Dictionary from the configuration; None if not found and raise_exception is False.

        Raises:
            FileNotFoundError: if the file is not found and raise_exception is True

        """
        conf = ConfigFactory.load_config(file_basename, search_dirs)
        if conf:
            return conf.to_dict()
        if raise_exception:
            raise FileNotFoundError(cls.config_not_found_msg(file_basename, search_dirs))
        return None

    @classmethod
    def config_not_found_msg(cls, file_basename, search_dirs):
        """Build the 'config file not found' message, listing every supported extension."""
        basename = os.path.splitext(file_basename)[0]
        conf_exts = "|".join(ConfigFormat.config_ext_formats().keys())
        msg = f"cannot find file '{basename}[{conf_exts}]'"
        msg = f"{msg} from search paths: '{search_dirs}'" if search_dirs else msg
        return msg

    @classmethod
    def find_file(cls, file_basename: str) -> Union[None, str]:
        """
        Find specified file from the config path.

        Caller is responsible for loading/processing the file. This is useful for non-JSON files.

        Args:
            file_basename: base name of the file to be found

        Returns:
            full name of the file if found; None if not.

        Raises:
            TypeError: if file_basename is not a str

        """
        if not isinstance(file_basename, str):
            raise TypeError(f"file_basename must be str but got {type(file_basename)}")
        return search_file(file_basename, cls._config_path)

    @staticmethod
    def _get_var_from_os_env(name: str):
        """
        Look up a var in the OS environment.

        The env var name is the upper-cased var name, prefixed with ENV_VAR_PREFIX unless it
        already starts with it; e.g. "port", "NVFLARE_PORT" and "nvflare_port" all map to
        "NVFLARE_PORT".

        Returns:
            the env var's string value, or None if it is not set
        """
        prefix = ENV_VAR_PREFIX.upper()
        # upper-case before the prefix check so a lower-case prefix is not added twice
        env_var_name = name.upper()
        if not env_var_name.startswith(prefix):
            env_var_name = prefix + env_var_name
        return os.environ.get(env_var_name)

    @classmethod
    def _get_var_from_config_sources(cls, name: str, conf):
        """
        Look up a var in one or more config sources.

        Args:
            name: var name
            conf: a single source (section name or dict) or a list of sources; None means none

        Returns:
            the first non-None value found, checking sources in order; None if not found
        """
        if conf is None:
            return None

        # conf could be:
        #   a single config source (a section name or a dict)
        #   a list of config sources
        if not isinstance(conf, list):
            conf = [conf]

        # check each conf source until the var is found
        for src in conf:
            if isinstance(src, str):
                # this is a section name
                src = cls.get_section(src)

            if isinstance(src, dict):
                v = src.get(name)
                if v is not None:
                    return v

        # No source has this var
        return None

    @classmethod
    def _get_var_from_source(cls, name: str, conf):
        """
        Look up a var and report where it was found.

        Precedence: command args, var_dict, config sources, OS environment.

        Returns:
            tuple of (value, source label); value may be None if found nowhere

        Raises:
            ValueError: if name is not a str
        """
        if not isinstance(name, str):
            raise ValueError(f"var name must be str but got {type(name)}")

        # see whether command args have it
        if cls._cmd_args and name in cls._cmd_args:
            return cls._cmd_args.get(name), "cmd_args"

        if cls._var_dict and name in cls._var_dict:
            return cls._var_dict.get(name), "var_dict"

        value = cls._get_var_from_config_sources(name, conf)
        if value is not None:
            return value, "config"

        # finally check os env
        return cls._get_var_from_os_env(name), "env"

    @classmethod
    def _get_var(cls, name: str, conf):
        """Return the raw (unconverted) value of a var, or None if it is not found."""
        value, _src = cls._get_var_from_source(name, conf)
        return value

    @classmethod
    def _int_var(cls, name: str, conf=None, default=None):
        """Return the var converted to int, or default if not found; ValueError if not convertible."""
        v = cls._get_var(name, conf)
        if v is None:
            return default
        try:
            return int(v)
        except Exception as e:
            raise ValueError(f"var {name}'s value '{v}' cannot be converted to int: {e}") from e

    @classmethod
    def _any_var(cls, func, name, conf, default):
        """
        Return a typed var value via func, caching the first non-None result by name.

        NOTE(review): the cache key is the var name only. The conf sources, the requested
        type and the default are not part of the key, so a later call for the same name
        returns the cached value even if it asks for a different type or default. A
        default value is cached just like a looked-up one. Confirm this is intended.
        """
        if name in cls._var_values:
            return cls._var_values.get(name)
        v = func(name, conf, default)
        if v is not None:
            cls._var_values[name] = v
        return v

    @classmethod
    def get_int_var(cls, name: str, conf=None, default=None):
        """Get a var as int (see _any_var for caching behavior)."""
        return cls._any_var(cls._int_var, name, conf, default)

    @classmethod
    def _float_var(cls, name: str, conf=None, default=None):
        """Return the var converted to float, or default if not found; ValueError if not convertible."""
        v = cls._get_var(name, conf)
        if v is None:
            return default
        try:
            return float(v)
        except Exception as e:
            # narrowed from a bare except so KeyboardInterrupt/SystemExit propagate; keep the cause
            raise ValueError(f"var {name}'s value '{v}' cannot be converted to float") from e

    @classmethod
    def get_float_var(cls, name: str, conf=None, default=None):
        """Get a var as float (see _any_var for caching behavior)."""
        return cls._any_var(cls._float_var, name, conf, default)

    @classmethod
    def _bool_var(cls, name: str, conf=None, default=None):
        """
        Return the var converted to bool, or default if not found.

        bool values are returned as-is, ints are True when non-zero, and strings are True when
        (case-insensitively) one of "true", "t", "yes", "y", "1". Other types raise ValueError.
        """
        v = cls._get_var(name, conf)
        if v is None:
            return default
        if isinstance(v, bool):
            return v
        if isinstance(v, int):
            return v != 0
        if isinstance(v, str):
            return v.lower() in ["true", "t", "yes", "y", "1"]
        raise ValueError(f"var {name}'s value '{v}' cannot be converted to bool")

    @classmethod
    def get_bool_var(cls, name: str, conf=None, default=None):
        """Get a var as bool (see _any_var for caching behavior)."""
        return cls._any_var(cls._bool_var, name, conf, default)

    @classmethod
    def _str_var(cls, name: str, conf=None, default=None):
        """Return the var converted to str, or default if not found; ValueError if not convertible."""
        v = cls._get_var(name, conf)
        if v is None:
            return default
        try:
            return str(v)
        except Exception as e:
            raise ValueError(f"var {name}'s value '{v}' cannot be converted to str") from e

    @classmethod
    def get_str_var(cls, name: str, conf=None, default=None):
        """Get a var as str (see _any_var for caching behavior)."""
        return cls._any_var(cls._str_var, name, conf, default)

    @classmethod
    def get_var_values(cls):
        """Return the cache of var values resolved so far (the live dict, not a copy)."""
        return cls._var_values