# Source code for nvflare.edge.device.config

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, List


class ConfigError(Exception):
    """Raised when a training config is invalid or cannot be turned into objects."""
class ConfigKey:
    """String keys recognized in a device training config dict."""

    # keys inside a single component spec
    NAME = "name"
    TYPE = "type"
    ARGS = "args"

    # component name looked up by TrainConfig.find_executor when no executors are configured
    TRAINER = "trainer"

    # top-level sections of the training config
    EXECUTORS = "executors"
    COMPONENTS = "components"
    IN_FILTERS = "in_filters"
    OUT_FILTERS = "out_filters"
    HANDLERS = "handlers"
class TrainConfig:
    """Result of processing a training config: created objects plus their roles.

    Attributes:
        objects: component name => native object created for that component
        in_filters: objects referenced by the config's in_filters list (or the raw value if absent/empty)
        out_filters: objects referenced by the config's out_filters list (or the raw value if absent/empty)
        event_handlers: objects referenced by the config's handlers list (or the raw value if absent/empty)
        executors: task name => executor object; the "*" entry is the fallback
    """

    def __init__(self, objects: dict, in_filters, out_filters, event_handlers, executors: dict):
        self.objects = objects
        self.in_filters = in_filters
        self.out_filters = out_filters
        self.event_handlers = event_handlers
        self.executors = executors

    def find_executor(self, task_name: str):
        """Return the executor that should handle the named task.

        Args:
            task_name: name of the task to execute

        Returns:
            If no executors are configured, the object of the "trainer" component
            (None if there is none). Otherwise the executor registered for
            task_name, falling back to the "*" entry (which may be None) when the
            task has no entry or its entry is falsy.
        """
        if not self.executors:
            return self.objects.get(ConfigKey.TRAINER)
        # "or" falls through to the wildcard exactly when the task lookup is falsy
        return self.executors.get(task_name) or self.executors.get("*")
class ComponentResolver:
    """A ComponentResolver resolves component spec into a device-native object.

    The default implementation instantiates ``obj_class`` with the component args
    as keyword arguments. Subclasses may override ``resolve`` to create objects
    differently.
    """

    def __init__(self, comp_type, name, args, obj_class=None):
        """
        Args:
            comp_type: the component type from the component spec
            name: the component name; other specs reference it as "@<name>"
            args: args for creating the object; a falsy value is replaced with {}
            obj_class: class of the native object; may be None for subclasses
                that override resolve()
        """
        self.comp_type = comp_type
        self.comp_name = name
        if not args:
            args = {}
        self.comp_args = args
        self.obj_class = obj_class

    def resolve(self) -> Any:
        """Resolve the component spec and create device-native object.

        Returns:
            a device-native object or None if failed (including when no
            obj_class was provided to create the object from).
        """
        if self.obj_class is None:
            # nothing to instantiate: report failure as documented instead of
            # raising "'NoneType' object is not callable"
            return None
        return self.obj_class(**self.comp_args)
def _determine_value(item: Any, resolvers: dict) -> Any:
    """Determine value of the specified item: recursively replace component refs
    with the ComponentResolver objects of the referenced components.

    Lists and dicts are rebuilt rather than modified in place, so the caller's
    config structure is never changed.

    Args:
        item: the item whose value is to be determined
        resolvers: table of resolvers (component name => ComponentResolver)

    Returns:
        the item with every "@<name>" string replaced by the ComponentResolver of
        component <name>; all other values are returned unchanged.

    Raises:
        ConfigError: if a referenced component does not exist.
    """
    if isinstance(item, list):
        return [_determine_value(v, resolvers) for v in item]
    if isinstance(item, dict):
        return {k: _determine_value(v, resolvers) for k, v in item.items()}
    if not isinstance(item, str) or not item.startswith("@"):
        return item
    referenced_name = item[1:]
    c = resolvers.get(referenced_name)
    if c is None:
        raise ConfigError(f"referenced component '{referenced_name}' does not exist")
    return c


def _find_obj(item, obj_table):
    """Try to find the native object(s) for the item: recursively process all ComponentResolvers
    and replace them with their native objects following the structure of the item (list or dict).

    Lists and dicts are updated in place; they are the copies produced by
    _determine_value, not the caller's config.

    Args:
        item: the item to be processed
        obj_table: the object table that contains objects already created

    Returns: the item itself (with referenced components replaced with objects); or in case that the item is
        a ComponentResolver, the native object created by it (or the resolver itself if not resolved yet)
    """
    if isinstance(item, ComponentResolver):
        obj = obj_table.get(item.comp_name)
        # keep the resolver in place until its component has been resolved
        return item if obj is None else obj
    if isinstance(item, list):
        for i, v in enumerate(item):
            item[i] = _find_obj(v, obj_table)
        return item
    if isinstance(item, dict):
        for k, v in item.items():
            item[k] = _find_obj(v, obj_table)
        return item
    return item


def _contains_resolver(item) -> bool:
    """Return True if item is, or (recursively) contains in a list/dict, a ComponentResolver."""
    if isinstance(item, ComponentResolver):
        return True
    if isinstance(item, list):
        return any(_contains_resolver(v) for v in item)
    if isinstance(item, dict):
        return any(_contains_resolver(v) for v in item.values())
    return False


def _try_to_resolve(resolver: ComponentResolver, obj_table: dict) -> Any:
    """Try to create device-native object. If created, place the obj in the obj_table.

    Args:
        resolver: the ComponentResolver that will try to resolve its component
        obj_table: object table that keeps objects of resolved components

    Returns: the resolved object, or None if the resolver is not ready

    For the resolver to be ready, all of its args must be resolved already, meaning that if an arg
    references another component (directly or nested inside a list/dict arg), the referenced
    component must be resolved.

    Raises:
        ConfigError: if the resolver is ready but its resolve() returns None.
    """
    if isinstance(resolver.comp_args, dict):
        for k, v in resolver.comp_args.items():
            v = _find_obj(v, obj_table)
            resolver.comp_args[k] = v
            if _contains_resolver(v):
                # a referenced component has not been resolved yet - retry in a later pass.
                # The check must be recursive: a ref nested in a list/dict arg would
                # otherwise be passed to the constructor as a ComponentResolver object.
                return None

    obj = resolver.resolve()
    if obj is None:
        raise ConfigError(f"failed to resolve component {resolver.comp_name}")
    obj_table[resolver.comp_name] = obj
    return obj


def _process_components(component_config: List[dict], resolver_registry: dict) -> dict:
    """Create native objects for all component specs, honoring "@<name>" references.

    Args:
        component_config: list of component specs (dicts with name/type/args keys)
        resolver_registry: component type => ComponentResolver subclass, or the
            native object class to instantiate directly

    Returns:
        object table: component name => native object

    Raises:
        ConfigError: on unknown types, invalid registry entries, duplicate names,
            missing referenced components, or references that cannot be resolved
            (e.g. cyclic references).
    """
    # Step 1: create a ComponentResolver for each component spec in the config
    resolvers = {}  # name => ComponentResolver
    for c in component_config:
        name = c.get(ConfigKey.NAME)
        comp_type = c.get(ConfigKey.TYPE)
        clazz = resolver_registry.get(comp_type)
        if not clazz:
            raise ConfigError(f"no ComponentResolver registered for component type {comp_type}")

        # must be checked before issubclass(), which raises TypeError for non-class arguments
        if not inspect.isclass(clazz):
            raise ConfigError(f"resolver for component {comp_type} is not a valid class")

        comp_args = c.get(ConfigKey.ARGS)
        if isinstance(comp_args, dict):
            # shallow copy: resolution replaces arg values, which must not leak into the caller's config
            comp_args = dict(comp_args)

        if issubclass(clazz, ComponentResolver):
            resolver = clazz(comp_type, name, comp_args)
        else:
            # the clazz is the native object's class
            resolver = ComponentResolver(comp_type, name, comp_args, clazz)

        if not resolver:
            raise ConfigError(f"cannot make resolver for component {name} of type {comp_type}")

        if name in resolvers:
            raise ConfigError(f"duplicate component definition for '{name}'")
        resolvers[name] = resolver

    # Step 2: replace "@<name>" refs in args with the ComponentResolvers of the referenced components
    for resolver in resolvers.values():
        if not isinstance(resolver.comp_args, dict):
            # the args could be None
            continue
        for k, v in resolver.comp_args.items():
            resolver.comp_args[k] = _determine_value(v, resolvers)

    # Step 3: repeatedly try to resolve components until all are done.
    # the "resolve" method of ComponentResolver creates device objects based on the args.
    obj_table = {}
    while resolvers:
        resolved = []
        for name, resolver in resolvers.items():
            if _try_to_resolve(resolver, obj_table) is not None:
                resolved.append(name)

        if not resolved:
            # nothing is resolved in a full pass - there are cyclic refs
            raise ConfigError(f"cannot resolve components {list(resolvers)}")

        for n in resolved:
            resolvers.pop(n)
    return obj_table


def _resolve_ref(ref, obj_table: dict):
    """Look up the object for a "@<name>" ref; returns None if no such object exists.

    Raises:
        ConfigError: if ref is not a string starting with "@".
    """
    if not isinstance(ref, str) or not ref.startswith("@"):
        raise ConfigError(f"invalid ref {ref}")
    return obj_table.get(ref[1:])


def _process_refs(refs: List[str], obj_table: dict):
    """Replace, in place, each "@<name>" ref in refs with its object.

    Raises:
        ConfigError: if a ref is invalid or no object exists for it.
    """
    for i, r in enumerate(refs):
        obj = _resolve_ref(r, obj_table)
        # obj_table never stores None, so None means "not found"; falsy objects are valid
        if obj is None:
            raise ConfigError(f"cannot find object for reference {r}")
        refs[i] = obj
def _process_ref_list(config: dict, key: str, obj_table: dict):
    """Resolve the optional list of "@<name>" refs stored under config[key].

    Returns:
        a new list of objects (the config's own list is left unchanged), or the
        config value itself when it is missing/empty.

    Raises:
        ConfigError: if the value is not a list or a ref cannot be resolved.
    """
    refs = config.get(key)
    if not refs:
        return refs
    if not isinstance(refs, list):
        raise ConfigError(f"{key} should be list of str but got {type(refs)}")
    # _process_refs replaces entries in place - work on a copy so the config can be reused
    refs = list(refs)
    _process_refs(refs, obj_table)
    return refs


def process_train_config(config: dict, resolver_registry: dict) -> TrainConfig:
    """Create a TrainConfig from a raw training config.

    The config itself is not modified, so the same config can be processed again.

    Args:
        config: training config dict with "components" and optional "in_filters",
            "out_filters", "handlers" and "executors" sections
        resolver_registry: component type => ComponentResolver subclass or native class

    Returns:
        a TrainConfig holding the created objects and their roles.

    Raises:
        ConfigError: if the config is malformed or refers to unknown components.
    """
    components = config.get(ConfigKey.COMPONENTS)
    if not components:
        raise ConfigError(f"missing {ConfigKey.COMPONENTS} in config")
    if not isinstance(components, list):
        raise ConfigError(f"{ConfigKey.COMPONENTS} should be list but got {type(components)}")

    obj_table = _process_components(components, resolver_registry)

    in_filters = _process_ref_list(config, ConfigKey.IN_FILTERS, obj_table)
    out_filters = _process_ref_list(config, ConfigKey.OUT_FILTERS, obj_table)
    handlers = _process_ref_list(config, ConfigKey.HANDLERS, obj_table)

    # process executors: task name => "@<component name>"
    executor_config = config.get(ConfigKey.EXECUTORS)
    executors = {}
    if executor_config:
        if not isinstance(executor_config, dict):
            raise ConfigError(f"{ConfigKey.EXECUTORS} should be dict but got {type(executor_config)}")
        for task_name, ref in executor_config.items():
            executor = _resolve_ref(ref, obj_table)
            if executor is None:
                # fail loudly instead of silently falling back to the "*" executor later
                raise ConfigError(f"cannot find object for reference {ref}")
            executors[task_name] = executor
    return TrainConfig(obj_table, in_filters, out_filters, handlers, executors)