Source code for nvflare.app_common.utils.math_utils

# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import operator
from typing import Callable, Optional, Tuple

# Lookup table from the textual comparison symbol used in a compare
# expression (e.g. the ">=" in "accuracy >= 0.5") to the two-argument
# function from the ``operator`` module that performs that comparison.
# Note that equality is spelled with a single "=".
operator_mapping = dict(
    (
        (">=", operator.ge),
        ("<=", operator.le),
        (">", operator.gt),
        ("<", operator.lt),
        ("=", operator.eq),
    )
)


def parse_compare_criteria(compare_expr: Optional[str] = None) -> Tuple[str, float, Callable]:
    """Parse a compare expression into its individual components.

    The compare expression is a string literal of the form
    ``"<key> <op> <value>"``, such as

    .. code-block::

        "accuracy >= 0.5"
        "loss > 2.4"

    The three tokens are separated by whitespace; leading, trailing and
    repeated whitespace is ignored. ``<op>`` must be one of the symbols
    registered in ``operator_mapping`` and ``<value>`` must be parseable
    by ``float()``.

    Args:
        compare_expr: string literal in the format of "<key> <op> <value>"

    Returns:
        A Tuple of (key, value, operator): the key as a string, the target
        value converted to ``float``, and the comparison function looked up
        from ``operator_mapping``.

    Raises:
        ValueError: if ``compare_expr`` is not a string (including the
            default ``None``), does not consist of exactly three
            whitespace-separated tokens, uses an unknown operator symbol,
            or has a value that is not a number.
    """
    # The default of None (or any other non-string) cannot be parsed; report
    # it as a ValueError like every other malformed input instead of letting
    # an AttributeError escape from the ``split`` call below.
    if not isinstance(compare_expr, str):
        raise ValueError(
            f"Invalid early_stop_condition, expecting form of '<key> <op> <value>' but got '{compare_expr}'"
        )

    # split() without an argument collapses runs of whitespace and never
    # yields empty tokens, so key, op and target are all non-empty here.
    tokens = compare_expr.split()
    if len(tokens) != 3:
        raise ValueError(
            f"Invalid early_stop_condition, expecting form of '<key> <op> <value>' but got '{compare_expr}'"
        )
    key, op, target = tokens

    op_fn = operator_mapping.get(op)
    if op_fn is None:
        raise ValueError("Invalid operator symbol: expecting one of <=, =, >=, <, > ")

    try:
        target_value = float(target)
    except ValueError as e:
        # Keep the original conversion error attached as the cause.
        raise ValueError(f"expect a number, but get '{target}' in '{compare_expr}'") from e

    return key, target_value, op_fn