# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import builtins
import importlib
import json
import os
import re
import sys
from typing import Any, Type
from nvflare.edge.simulation.device_task_processor import DeviceTaskProcessor
from nvflare.fuel.utils.validation_utils import check_positive_int, check_positive_number, check_str
# Matches "{name}" placeholders in a string; group 1 captures the (non-greedy)
# text between the braces, which is used as the variable name to look up.
VAR_PATTERN = re.compile(r"\{(.*?)}")
[docs]
def load_class(class_path) -> Type:
    """Resolve a class (or any attribute) from its path.

    Args:
        class_path: either a dotted path such as ``"package.module.ClassName"``,
            which is split on the last dot into a module to import and an
            attribute to fetch from it, or a bare name such as ``"int"``,
            which is looked up in ``builtins``.

    Returns:
        The object found at ``class_path``.

    Raises:
        TypeError: if the module cannot be imported, the attribute does not
            exist, or ``class_path`` is not usable as a path. The original
            exception is attached as ``__cause__``.
    """
    try:
        if "." in class_path:
            module_name, class_name = class_path.rsplit(".", 1)
            module = importlib.import_module(module_name)
            return getattr(module, class_name)
        return getattr(builtins, class_path)
    except Exception as ex:
        # Callers only need to handle one exception type; chain explicitly so
        # the underlying import/attribute error stays visible in tracebacks.
        raise TypeError(f"Can't load class {class_path}: {ex}") from ex
[docs]
class ConfigParser:
    """Parse a JSON simulation config file and expose its settings.

    Top-level keys read by ``parse``:

    - ``processor`` (required): a dict with
        - ``path`` (required): class path of a ``DeviceTaskProcessor`` subclass,
          resolved with ``load_class``;
        - ``python_path`` (optional): folder appended to ``sys.path`` before the
          class is loaded; defaults to the folder containing the config file;
        - ``args`` (optional): keyword arguments for the processor constructor.
          String values may contain ``{name}`` placeholders that are filled in
          by ``get_processor``.
    - ``job_name``: checked with ``check_str``.
    - ``endpoint`` (optional): checked with ``check_str`` when present.
    - ``num_devices`` (optional, default 100): checked with ``check_positive_int``.
    - ``num_workers`` (optional, default 10): checked with ``check_positive_int``.
    - ``get_job_timeout`` (optional, default 60.0): checked with
      ``check_positive_number``.
    """

    def __init__(self, config_file: str):
        """Initialize defaults, then load settings from ``config_file``.

        Args:
            config_file: path to the JSON config file.
        """
        self.job_name = None
        self.get_job_timeout = None
        self.processor = None
        self.endpoint = None
        self.num_devices = 100
        self.num_workers = 10
        self.processor_class = None
        self.processor_args = None

        self.parse(config_file)

    def get_processor(self, variables: dict = None) -> DeviceTaskProcessor:
        """Create a new processor instance from the configured class and args.

        Args:
            variables: optional mapping of placeholder name to value. Every
                ``{name}`` found in a string argument whose name is a key of
                this mapping is replaced by the corresponding value.

        Returns:
            The object produced by calling the configured processor class with
            the substituted arguments (no arguments if none were configured).
        """
        if self.processor_args:
            args = self._variable_substitution(self.processor_args, variables)
        else:
            args = {}
        return self.processor_class(**args)

    def get_endpoint(self):
        """Return the configured ``endpoint``, or None if it was not set."""
        return self.endpoint

    def get_num_devices(self):
        """Return the configured number of devices (default 100)."""
        return self.num_devices

    def get_num_workers(self):
        """Return the configured number of workers (default 10)."""
        return self.num_workers

    def get_job_name(self):
        """Return the configured ``job_name``."""
        return self.job_name

    def parse(self, config_file: str):
        """Load and validate the JSON config, storing the results on ``self``.

        Args:
            config_file: path to the JSON config file.

        Raises:
            ValueError: if the ``processor`` section or its ``path`` is missing.
            TypeError: if the processor class cannot be loaded or is not a
                subclass of ``DeviceTaskProcessor``.

        The ``check_*`` validation helpers are also applied to the scalar
        settings; whatever they raise on invalid values propagates unchanged.
        """
        # JSON text is UTF-8 by specification; do not depend on the locale default.
        with open(config_file, "r", encoding="utf-8") as f:
            config = json.load(f)

        # Load processor
        processor_config = config.get("processor", None)
        if processor_config is None:
            raise ValueError("processor is not defined in config file")

        python_path = processor_config.get("python_path", None)
        if not python_path:
            # If no python_path defined, use the folder where the config file is
            python_path = os.path.abspath(os.path.dirname(config_file))
        # Avoid piling up duplicate entries when several configs share a folder.
        if python_path not in sys.path:
            sys.path.append(python_path)

        class_path = processor_config.get("path")
        if class_path is None:
            raise ValueError("path for processor is not defined in config file")

        processor_class = load_class(class_path)
        # load_class can return any attribute (e.g. a function); issubclass()
        # would fail on a non-class with an unrelated message, so check first.
        if not (isinstance(processor_class, type) and issubclass(processor_class, DeviceTaskProcessor)):
            raise TypeError(f"Processor {class_path} is not a subclass of DeviceTaskProcessor")
        self.processor_args = processor_config.get("args", {})
        self.processor_class = processor_class

        self.endpoint = config.get("endpoint", None)
        if self.endpoint is not None:
            check_str("endpoint", self.endpoint)

        self.job_name = config.get("job_name", None)
        check_str("job_name", self.job_name)

        # Use "is not None" rather than truthiness so that an explicit invalid
        # value such as 0 is validated instead of being silently replaced by
        # the default.
        n = config.get("num_devices", None)
        if n is not None:
            check_positive_int("num_devices", n)
            self.num_devices = n

        n = config.get("num_workers", None)
        if n is not None:
            check_positive_int("num_workers", n)
            self.num_workers = n

        n = config.get("get_job_timeout", 60.0)
        check_positive_number("get_job_timeout", n)
        self.get_job_timeout = n

    def _variable_substitution(self, args: Any, variables: dict) -> Any:
        """Recursively replace ``{name}`` placeholders in ``args``.

        Dicts and lists are rebuilt with their values/items substituted;
        strings have each placeholder whose name is a key of ``variables``
        replaced by ``str()`` of the value; placeholders with unknown names are
        left as they are. Any other type is returned unchanged.

        Args:
            args: the value (possibly nested) to process.
            variables: mapping of placeholder name to value; may be None or
                empty, in which case strings are returned unchanged.

        Returns:
            A new structure with the substitutions applied.
        """
        if isinstance(args, dict):
            return {k: self._variable_substitution(v, variables) for k, v in args.items()}
        if isinstance(args, list):
            return [self._variable_substitution(v, variables) for v in args]
        if not isinstance(args, str):
            return args
        if not variables:
            # Nothing to substitute (also covers variables=None).
            return args

        def _replace(match):
            name = match.group(1)
            if name in variables:
                # str() so that non-string values (e.g. ints) can be embedded.
                return str(variables[name])
            # Unknown variable: keep the placeholder text untouched.
            return match.group(0)

        # A callable replacement is used verbatim (no backslash processing) and
        # substituted text is not rescanned for further placeholders.
        return VAR_PATTERN.sub(_replace, args)