Source code for nvflare.fuel.utils.tree_utils

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List

from nvflare.fuel.common.fqn import FQN


def _build_path(obj: Any, path: list, get_name_f, get_parent_f, **kwargs):
    name = get_name_f(obj, **kwargs)
    if name in path:
        return f"circular parent ref {name}"

    path.insert(0, name)
    parent = get_parent_f(obj, **kwargs)
    if not parent:
        return ""

    return _build_path(parent, path, get_name_f, get_parent_f, **kwargs)


[docs] def build_path(obj: Any, get_name_f, get_parent_f, **kwargs): """Build tree path for the specified object, which can be of any type. It is assumed that the object is already in an object tree. The tree path of an object is the list of objects that are traversed from the root to the object. Objects must have unique names. Args: obj: the object that the path will be built for get_name_f: the function that returns object name get_parent_f: the function that returns the object's parent object **kwargs: kwargs to be passed to the get_name_f and get_parent_f Returns: a tuple of (error, path). """ if not callable(get_name_f): raise ValueError("get_name_f is not callable") if not callable(get_parent_f): raise ValueError("get_parent_f is not callable") path = [] err = _build_path(obj, path, get_name_f, get_parent_f, **kwargs) return err, path
[docs] class Node: """Definition of Nodes that are elements of trees. The "self.obj" is the object associated with the node; The "self.parent" is the parent Node; The "self.children" is a list of child Nodes. """ def __init__(self, obj: Any, parent): """Constructor of Node Args: obj: the object to be associated with the node parent: parent node of the node """ self.obj = obj self.parent = parent # a Node self.children = [] # child nodes
[docs] class Forest: """A forest contains one or more trees. Each entry in "self.roots" is the unique name of the root object. The "self.nodes" a dict of: name => Node. You can get the Node from an object name. """ def __init__(self): """Constructor of Forest""" self.roots = [] # one or more names of the root nodes self.nodes = {} # name => Node
[docs] def build_forest(objs: List[Any], get_name_f, get_fqn_f, **kwargs) -> Forest: """Builds a Forest from a list of objects. Each object must have a unique "name" and unique "fqn". This function builds trees based on the FQNs of the objects. Args: objs: objects for which to build the forest get_name_f: the function to get object name get_fqn_f: the function to get object fqn **kwargs: kwargs to be passed to get_name_f and get_fqn_f Returns: a Forest object """ forest = Forest() f2n = {} # fqn => name n2f = {} # name => fqn f2o = {} # fqn => Obj for obj in objs: fqn = get_fqn_f(obj, **kwargs) n = get_name_f(obj, **kwargs) n2 = f2n.get(fqn) if n2: raise ValueError(f"two names ({n} and {n2}) have the same FQN {fqn}") fqn2 = n2f.get(n) if fqn2: raise ValueError(f"two FQNs ({fqn} and {fqn2}) have the same name {n}") f2n[fqn] = n f2o[fqn] = obj n2f[n] = fqn fqn_to_nodes = {} for fqn, obj in f2o.items(): parent_fqn = FQN.get_parent(fqn) node = Node(obj, parent_fqn) fqn_to_nodes[fqn] = node name = f2n[fqn] forest.nodes[name] = node # resolve fqns to nodes for name, node in forest.nodes.items(): assert isinstance(node, Node) if node.parent: # node.parent is the fqn of the parent parent_node = fqn_to_nodes.get(node.parent) if not parent_node: raise ValueError(f"missing node definition for FQN {node.parent}, which is the parent of {name}") node.parent = parent_node parent_node.children.append(node) else: # this node has no parent - it's a root forest.roots.append(name) return forest
def _dump_one(node: Node, get_name_f, **kwargs): name = get_name_f(node.obj, **kwargs) if not node.children: return name children = [] for n in node.children: children.append(_dump_one(n, get_name_f, **kwargs)) return {name: children}
[docs] def forest_to_dict(forest: Forest, get_name_f, **kwargs) -> dict: """Output the forest into a dict that shows the structure of the forest. Args: forest: the forest to output get_name_f: function to get object name **kwargs: kwargs passed to get_name_f Returns: a dict that shows the structure of the forest """ trees = [] for r in forest.roots: node = forest.nodes[r] trees.append(_dump_one(node, get_name_f, **kwargs)) return {"roots": forest.roots, "trees": trees}