# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Dict, List, Optional, Tuple
from pyhocon import ConfigFactory, ConfigTree
from nvflare.fuel.utils.config import ConfigFormat
from nvflare.lighter.tool_consts import NVFLARE_PREFIX
from nvflare.tool.job.config.config_indexer import KeyIndex, build_reverse_order_index
from nvflare.tool.job.job_client_const import (
APP_CONFIG_DIR,
APP_CONFIG_FILE_BASE_NAMES,
APP_CONFIG_KEY,
APP_SCRIPT_KEY,
CONFIG_FED_CLIENT_CONF,
DEFAULT_APP_NAME,
JOB_META_BASE_NAME,
META_APP_NAME,
)
def merge_configs_from_cli(cmd_args, app_names: List[str]) -> Tuple[Dict[str, Dict[str, tuple]], bool]:
    """Index the job's config files and apply any overrides given on the command line.

    Args:
        cmd_args: parsed command-line arguments; ``job_folder`` is read here and the
            whole object is handed to ``get_cli_config``.
        app_names: application names of the job.

    Returns:
        A tuple ``(app_indices, config_modified)``. ``config_modified`` is True when
        CLI overrides were found, in which case ``app_indices`` is the merged result;
        otherwise it is the untouched index built from the job folder.
    """
    indexed_configs: Dict[str, Dict[str, Tuple]] = build_config_file_indices(cmd_args.job_folder, app_names)
    cli_overrides: Dict[str, Dict[str, Dict[str, str]]] = get_cli_config(cmd_args, app_names)

    # Nothing supplied on the CLI: hand back the index as built.
    if not cli_overrides:
        return indexed_configs, False

    return merge_configs(indexed_configs, cli_overrides), True
def filter_indices(app_indices_configs: Dict[str, Dict[str, Tuple]]) -> Dict[str, Dict[str, Dict[str, KeyIndex]]]:
    """Reduce every indexed config file to its non-excluded keys.

    Args:
        app_indices_configs: ``{app_name: {file: (config, excluded_key_list, key_indices)}}``.

    Returns:
        ``{app_name: {file: {key: KeyIndex}}}`` where each inner mapping is produced by
        ``filter_config_name_and_values`` for that file.
    """
    return {
        app_name: {
            file: filter_config_name_and_values(excluded_key_list, key_indices)
            # the parsed config itself is not needed for filtering
            for file, (_config, excluded_key_list, key_indices) in per_file.items()
        }
        for app_name, per_file in app_indices_configs.items()
    }
def filter_config_name_and_values(
    excluded_key_list: List[str], key_indices: Dict[str, List[KeyIndex]]
) -> Dict[str, KeyIndex]:
    """Pick one key index per key, skipping excluded keys and excluded values.

    Args:
        excluded_key_list: names that must not appear, either as a key or as the
            ``value`` of a key index.
        key_indices: ``{key: [key_index, ...]}``.

    Returns:
        ``{key: key_index}``. When a key has several acceptable entries, the last
        one in its list is kept.
    """
    selected = {}
    for key, candidates in key_indices.items():
        if key in excluded_key_list:
            continue
        for candidate in candidates:
            if candidate.value in excluded_key_list:
                continue
            # later entries for the same key replace earlier ones
            selected[key] = candidate
    return selected
def _cast_type(key_index, cli_value):
"""Casts cli_value to correct type.
Since build_reverse_order_index is using pyhocon, we need to do the same here.
"""
if key_index.value is None:
return cli_value
new_value = ConfigFactory.parse_string(f"{key_index.key}={cli_value}")[key_index.key]
return new_value
def split_array_key(key: str) -> Tuple:
    """Split a key written in array notation into parent, index and sub-key.

    Examples:
        ``"components[1].args.lr"`` -> ``("components", 1, "args.lr")``
        ``"a.b"``                   -> ``(None, None, "a.b")``

    Args:
        key: config key, optionally of the form ``parent[index].rest``.

    Returns:
        ``(parent, index, sub_key)``. ``parent`` and ``index`` are None when the key
        contains no brackets, in which case ``sub_key`` is the key itself.

    Raises:
        ValueError: if the brackets are unbalanced or the index is not an integer.
            The message always quotes the key as it was passed in.
    """
    if "[" not in key and "]" not in key:
        return None, None, key

    parent, opening, remainder = key.partition("[")
    # A "]" without a preceding "[" (or the reverse) is malformed.
    if not opening or "]" not in remainder:
        raise ValueError(f"invalid key '{key}'")

    index_str, _, sub_key = remainder.partition("]")
    try:
        index = int(index_str)
    except ValueError:
        raise ValueError(f"invalid index '{index_str}' in key '{key}'") from None

    # "parent[0].name" leaves ".name" after the bracket; drop the separator.
    return parent, index, sub_key.lstrip(".")
def convert_to_number(value: str):
    """Convert a numeric-looking string to ``int`` or ``float``.

    Args:
        value: text taken from the command line. Falsy values and non-string
            values are returned unchanged.

    Returns:
        An ``int`` for an all-digit string, a ``float`` for digits with decimal
        point(s) that ``float`` accepts, otherwise ``value`` itself. A single
        leading ``-`` is accepted so negative numbers are converted too.
    """
    if not value or not isinstance(value, str):
        return value

    # Classify on the unsigned part, but convert the full text so the sign is kept.
    digits = value[1:] if value[0] == "-" else value
    try:
        if digits.isdigit():
            return int(value)
        if digits.replace(".", "").isdigit():
            return float(value)
    except ValueError:
        # Looks numeric but is not parseable, e.g. "1.2.3" or superscript digits.
        return value
    return value
def get_last_token(input_string):
    """Return the part of a dotted key that follows its last ``"."``.

    Args:
        input_string: dotted key such as ``"a.b.c"``.

    Returns:
        The last segment (``"c"`` for ``"a.b.c"``); the input itself when it is
        falsy or contains no dot.
    """
    if not input_string:
        return input_string
    # rsplit yields a single-element list when there is no dot, so [-1] covers both cases
    return input_string.rsplit(".", 1)[-1]
def handle_key_in_path_notation_or_new_key(file: str, key: str, cli_value: str, config: ConfigTree, key_indices: Dict):
    """Apply a CLI override whose key is not present in the key index.

    The key may be a new key, a dotted path, or use array notation such as
    ``components[1].args.lr``. The config and the key index are updated in place.

    Args:
        file: config file the override targets (used in error messages only).
        key: key as typed on the command line.
        cli_value: value to set. ``None`` or ``""`` means "remove the key";
            any other value, including ``0``, is stored.
        config: parsed config to modify.
        key_indices: ``{last_key_token: [KeyIndex, ...]}`` for ``config``.

    Raises:
        ValueError: if the key is malformed or its array parent is not a list.
    """
    parent, index, sub_key = split_array_key(key)

    # Compare explicitly: after number conversion a value of 0 is falsy but is still a value.
    remove = cli_value is None or cli_value == ""

    if parent is not None and index is not None:
        # key expressed in the form of array such as component[index]
        parent_config_list = config.get(parent)
        if not isinstance(parent_config_list, list):
            raise ValueError(f"invalid key '{key}' for file {file}")
        target = parent_config_list[index]
    else:
        # key has no array component
        target = config

    if remove:
        target.pop(sub_key)
    else:
        target.put(sub_key, cli_value)

    last_token = get_last_token(sub_key)
    if remove:
        # the index may have no entry for a key it never tracked
        key_indices.pop(last_token, None)
    else:
        key_indices[last_token] = [KeyIndex(sub_key, target.get(sub_key))]
def merge_configs(
    app_indices_configs: Dict[str, Dict[str, tuple]], app_cli_file_configs: Dict[str, Dict[str, Dict]]
) -> Dict[str, Dict[str, tuple]]:
    """Merges configurations from indices_configs and cli_file_configs.

    For each app with CLI overrides, every config file with a non-empty key index
    is updated in place and included in the result; files with an empty key index
    are left out for that app. Apps without overrides are passed through (the
    meta app as a shallow copy of its mapping).

    An override value of ``None`` or ``""`` removes the key from the index. Any
    other value, including ``0``, is applied.

    Args:
        app_indices_configs (Dict[str, Dict[str, tuple]]): A dictionary containing indices and configurations.
        app_cli_file_configs (Dict[str, Dict[str, Dict]]): A dictionary containing CLI configurations.

    Returns:
        Dict[str, Dict[str, Tuple]]: A dictionary of {app_name: merged configurations}.
        Each of the merged configurations can be expressed in a Tuple: config, excluded_key_List, key_indices

    Raises:
        ValueError: if an override key that is not in the index starts or ends with ".".
    """
    app_merged = {}
    for app_name, indices_configs in app_indices_configs.items():
        cli_file_configs = app_cli_file_configs.get(app_name, None)

        if not cli_file_configs:
            if app_name == META_APP_NAME:
                app_merged[app_name] = dict(indices_configs.items())
            else:
                app_merged[app_name] = indices_configs
            continue

        merged = {}
        for file, (config, excluded_key_list, key_indices) in indices_configs.items():
            if len(key_indices) == 0:
                continue

            cli_configs = cli_file_configs.get(file, None)
            if cli_configs:
                for key, cli_value in cli_configs.items():
                    cli_value = convert_to_number(cli_value)

                    if key not in key_indices:
                        # not every client has app_config, app_script
                        if key in (APP_SCRIPT_KEY, APP_CONFIG_KEY):
                            continue
                        if key.startswith(".") or key.endswith("."):
                            raise ValueError(f"invalid key {key} for file {file}")
                        handle_key_in_path_notation_or_new_key(file, key, cli_value, config, key_indices)
                        continue

                    # Explicit test: a converted value of 0 is falsy but must still be applied.
                    if cli_value is None or cli_value == "":
                        key_indices.pop(key)
                        continue

                    for key_index in key_indices.get(key):
                        new_value = _cast_type(key_index, cli_value)
                        key_index.value = new_value
                        parent_key = key_index.parent_key
                        # keep the parsed config tree in sync with the index
                        if parent_key and isinstance(parent_key.value, ConfigTree):
                            parent_key.value.put(key_index.key, new_value)

            merged[file] = (config, excluded_key_list, key_indices)

        app_merged[app_name] = merged
    return app_merged
def get_root_index(key_index: KeyIndex) -> Optional[KeyIndex]:
    """Walk up the ``parent_key`` chain to the top-most meaningful ancestor.

    Args:
        key_index: starting index entry, or None.

    Returns:
        ``key_index`` itself when it is None or has no parent. Otherwise the first
        ancestor whose own parent is either missing or has an empty ``key``.
    """
    if key_index is None or key_index.parent_key is None:
        return key_index

    ancestor = key_index.parent_key
    while True:
        above = ancestor.parent_key
        # stop below a missing parent or below the unnamed ("") parent
        if above is None or above.key == "":
            return ancestor
        ancestor = above
def get_cli_config(cmd_args: Any, app_names: List[str]) -> Dict[str, Dict[str, Dict[str, str]]]:
    """Collect config overrides from the command-line arguments.

    Args:
        cmd_args: Command-line arguments; ``config_file``, ``job_folder`` and,
            when present, ``script`` are read.
        app_names: application names

    Returns:
        ``{app_name: {config_file: {key: value}}}``. When a script is given, its
        base name is recorded under ``APP_SCRIPT_KEY`` for the client config of
        every app already present, or for the default app if there is none.
    """
    overrides = {}
    if cmd_args.config_file:
        overrides = _parse_cli_config(cmd_args.job_folder, cmd_args.config_file, app_names)

    # replace "script"
    if "script" in cmd_args and cmd_args.script:
        script_name = os.path.basename(cmd_args.script)
        if not overrides:
            return {DEFAULT_APP_NAME: {CONFIG_FED_CLIENT_CONF: {APP_SCRIPT_KEY: script_name}}}
        for per_file in overrides.values():
            per_file.setdefault(CONFIG_FED_CLIENT_CONF, {}).update({APP_SCRIPT_KEY: script_name})

    return overrides
def _is_meta_file(filename: str) -> bool:
    """Tell whether ``filename`` is the job meta file in any supported config format.

    Args:
        filename: bare file name (no directory part).

    Returns:
        True when the name equals ``JOB_META_BASE_NAME`` followed by one of
        ``ConfigFormat.extensions()``.
    """
    return any(filename == f"{JOB_META_BASE_NAME}{ext}" for ext in ConfigFormat.extensions())
def _parse_cli_config(
    job_folder: str, cli_configs: List[str], app_names: List[str]
) -> Dict[str, Dict[str, Dict[str, str]]]:
    """Extracts configurations from command-line arguments and return them in a dictionary.

    Each override is ``key=value`` to set a value, or ``key-`` (also ``key-=...``)
    to request removal, which is recorded as a value of None. A value that merely
    ends with ``-`` (``key=value-``) is kept as a value. When the same file is
    given more than once, its overrides are combined, later ones winning.

    Args:
        job_folder: job_folder directory
        app_names: application names
        cli_configs: Array of CLI config option in the format of
               -f filename.conf key1=v1 key2=v2
                  where <app_name>/config is omitted, default to app/config
               or
               -f <app_name>/filename.conf key1=v1 key2=v2
                  where config is omitted, default to <app_name>/config/filename.conf
               or
               -f <app_name>/config/filename.conf key1=v1 key2=v2
               or
               -f <app_name>/custom/filename.conf key1=v1 key2=v2
               if filename.conf is meta.conf, the <app_name> = __meta_app__
               separated by space

    Returns:
        A dictionary containing the configurations extracted from the command-line arguments,
        shaped as ``{app_name: {config_file_path: {key: value_or_None}}}``.

    Raises:
        ValueError: if the app name is not recognised or an override is neither
            ``key=value`` nor ``key-``.
    """
    app_cli_config_dict = {}
    for arr in cli_configs or []:
        file_arg, config_data = arr[0], arr[1:]

        app_name = get_app_name_from_path(file_arg)
        config_file = get_config_file_path(app_name, file_arg, job_folder)
        app_name = app_name or DEFAULT_APP_NAME
        if app_name not in app_names and app_name != DEFAULT_APP_NAME and app_name != META_APP_NAME:
            raise ValueError(
                f"Please specify one of the app names {app_names}. For example '<app_name>/xxx.conf k1=v1 k2=v2...'"
            )

        config_dict = {}
        for conf in config_data:
            conf_key, separator, conf_value = conf.partition("=")
            if separator:
                # "key-=anything" is a removal request; "key=value-" is an ordinary value
                if conf_key.endswith("-"):
                    conf_key = conf_key[:-1]
                    conf_value = None
            elif conf.endswith("-"):
                conf_key = conf[:-1]
                conf_value = None
            else:
                raise ValueError("Invalid config data, expecting key, value pair in the format key=value")
            config_dict[conf_key] = conf_value

        # merge rather than replace so a repeated "-f <same file>" keeps earlier overrides
        app_cli_config_dict.setdefault(app_name, {}).setdefault(config_file, {}).update(config_dict)

    return app_cli_config_dict
def get_config_file_path(app_name, input_file_path, job_folder):
    """Resolve a config file given on the command line to an absolute path.

    Accepted forms of ``input_file_path``:
        <config_file_name>                    -- resolved under the default app's "config" directory
        <app_name>/<config_file_name>         -- resolved under <app_name>/config
        <app_name>/config/<config_file_name>  -- used as given, relative to the job folder
        <app_name>/custom/<config_file_name>  -- used as given, relative to the job folder
    A file whose base name starts with the job meta base name is always resolved
    directly under the job folder.

    Args:
        app_name: application name the file belongs to.
        input_file_path: path as typed by the user.
        job_folder: job folder directory.

    Returns:
        Absolute path of the config file.

    Raises:
        ValueError: if a non-meta path starts with "/".
    """
    basename = os.path.basename(input_file_path)

    if basename.startswith(f"{JOB_META_BASE_NAME}."):
        return os.path.abspath(os.path.join(job_folder, basename))

    if input_file_path.strip().startswith("/"):
        raise ValueError(f"invalid config_file, {input_file_path}")

    dirname = os.path.dirname(input_file_path)
    if dirname == "":
        # bare file name
        relative_parts = (DEFAULT_APP_NAME, "config", basename)
    elif "/" not in dirname:
        # only the app name precedes the file name
        relative_parts = (app_name, "config", basename)
    else:
        # full path below the job folder
        relative_parts = (input_file_path,)

    return os.path.abspath(os.path.join(job_folder, *relative_parts))
def build_config_file_indices(job_folder: str, app_names: List[str]) -> Dict[str, Dict[str, Tuple]]:
    """Locate the job's config files and build a key index for each of them.

    Files considered:
        * the first existing job meta file in the job folder (by extension order),
          stored under ``META_APP_NAME``;
        * for each app, every existing ``<base><ext>`` in its config directory, for
          the base names in ``APP_CONFIG_FILE_BASE_NAMES``;
        * for each app, every file under its "custom" directory that ends with a
          known extension and does not start with ``NVFLARE_PREFIX``.

    Args:
        job_folder: job folder directory.
        app_names: application names of the job.

    Returns:
        ``{app_name: {real_path: (config, excluded_key_list, key_indices)}}`` with
        the values returned by ``build_reverse_order_index`` for each file.
    """
    extensions = ConfigFormat.extensions()
    files_by_app = {}

    # Job-level meta file: first extension that exists wins.
    meta_candidates = (os.path.join(job_folder, f"{JOB_META_BASE_NAME}{ext}") for ext in extensions)
    meta_file = next((path for path in meta_candidates if os.path.isfile(path)), None)
    if meta_file is not None:
        files_by_app[META_APP_NAME] = [meta_file]

    # Well-known config files of each app.
    for app_name in app_names:
        config_dir = os.path.join(job_folder, app_name, APP_CONFIG_DIR)
        for ext in extensions:
            for base in APP_CONFIG_FILE_BASE_NAMES:
                candidate = os.path.abspath(os.path.join(config_dir, f"{base}{ext}"))
                if os.path.isfile(candidate):
                    files_by_app.setdefault(app_name, []).append(candidate)

    # Extra config files shipped in each app's custom directory.
    for app_name in app_names:
        custom_dir = os.path.join(job_folder, app_name, "custom")
        for root, _dirs, names in os.walk(custom_dir):
            for name in names:
                for ext in extensions:
                    if name.endswith(ext) and not name.startswith(NVFLARE_PREFIX):
                        files_by_app.setdefault(app_name, []).append(os.path.join(root, name))

    index_by_app = {}
    for app_name, paths in files_by_app.items():
        for path in paths:
            real_path, config, excluded_key_list, key_indices = build_reverse_order_index(str(path))
            index_by_app.setdefault(app_name, {})[real_path] = (config, excluded_key_list, key_indices)
    return index_by_app
def get_app_name_from_path(path: str):
    """Derive the app name from a config file path given on the command line.

    Expected forms of ``path``:
        xxx.conf
        app1/xxx.conf
        app1/config/xxx.conf
        app1/custom/xxx.conf

    Args:
        path: config file path as typed by the user.

    Returns:
        ``META_APP_NAME`` for the job meta file, ``DEFAULT_APP_NAME`` when the path
        has no directory part, otherwise the first path segment.

    Raises:
        ValueError: if a non-meta path is absolute.
    """
    if _is_meta_file(os.path.basename(path)):
        return META_APP_NAME

    if os.path.isabs(path):
        raise ValueError(f"Expecting <config file> or <app_name>/xxx/<config file>, but '{path}' is given.")

    first_segment, separator, _rest = path.partition(os.path.sep)
    # no separator means the user typed just a file name
    return first_segment if separator else DEFAULT_APP_NAME