Source code for nvflare.fuel.utils.import_utils

# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


# Part of code is Adapted from from https://github.com/Project-MONAI/MONAI/blob/dev/monai/utils/module.py#L282
# which has the following license
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Part of code is Adapted from from https://github.com/Project-MONAI/MONAI/blob/dev/monai/utils/module.py#L282
"""
from importlib import import_module
from typing import Any, Tuple

from nvflare.security.logging import secure_format_exception

OPTIONAL_IMPORT_MSG_FMT = "{}"
OPS = ["==", ">=", ">", "<", "<="]


[docs]def get_module_version(this_pkg): return this_pkg.__version__.split(".")[:2]
[docs]def get_module_version_str(the_module): if the_module: module_version = ".".join(get_module_version(the_module)) else: module_version = "" return module_version
[docs]def check_version(that_pkg, version: str = "", op: str = "==") -> bool: """ compare module version with provided version """ if not version or not hasattr(that_pkg, "__version__"): return True # always valid version mod_version = tuple(int(x) for x in get_module_version(that_pkg)) required = tuple(int(x) for x in version.split(".")) result = True if op == "==": result = mod_version == required elif op == ">=": result = mod_version >= required elif op == ">": result = mod_version > required elif op == "<": result = mod_version < required elif op == "<=": result = mod_version <= required return result
[docs]class LazyImportError(ImportError): """ Could not import APIs from an optional dependency. """
[docs]def optional_import( module: str, op: str = "==", version: str = "", name: str = "", descriptor: str = OPTIONAL_IMPORT_MSG_FMT, allow_namespace_pkg: bool = False, ) -> Tuple[Any, bool]: """ Imports an optional module specified by `module` string. Any importing related exceptions will be stored, and exceptions raise lazily when attempting to use the failed-to-import module. Args: module: name of the module to be imported. op: version op it should be one of the followings: ==, <=, <, >, >= version: version string of the module if specified, it is used to the module version such that is satisfy the condition: <module>.__version__ <op> <version>. name: a non-module attribute (such as method/class) to import from the imported module. descriptor: a format string for the final error message when using a not imported module. allow_namespace_pkg: whether importing a namespace package is allowed. Defaults to False. Returns: The imported module and a boolean flag indicating whether the import is successful. Examples:: >>> torch, flag = optional_import('torch') >>> print(torch, flag) <module 'torch' from '/..../lib/python3.8/site-packages/torch/__init__.py'> True >>> torch, flag = optional_import('torch', '1.1') >>> print(torch, flag) <module 'torch' from 'python/lib/python3.6/site-packages/torch/__init__.py'> True >>> the_module, flag = optional_import('unknown_module') >>> print(flag) False >>> the_module.method # trying to access a module which is not imported OptionalImportError: import unknown_module (No module named 'unknown_module'). >>> torch, flag = optional_import('torch', '42') >>> torch.nn # trying to access a module for which there isn't a proper version imported OptionalImportError: import torch (requires version=='42'). >>> conv, flag = optional_import('torch.nn.functional', ">=", '1.0', name='conv1d') >>> print(conv) <built-in method conv1d of type object at 0x11a49eac0> >>> conv, flag = optional_import('torch.nn.functional', ">=", '42', name='conv1d') >>> conv() # trying to use a function from the not successfully imported module (due to unmatched version) OptionalImportError: from torch.nn.functional import conv1d (requires version>='42'). """ tb = None exception_str = "" if name: actual_cmd = f"from {module} import {name}" else: actual_cmd = f"import {module}" pkg = None try: if op not in OPS: raise ValueError(f"invalid op {op}, must be one of {OPS}") pkg = __import__(module) # top level module the_module = import_module(module) if not allow_namespace_pkg: is_namespace = getattr(the_module, "__file__", None) is None and hasattr(the_module, "__path__") if is_namespace: raise AssertionError if name: # user specified to load class/function/... from the module the_module = getattr(the_module, name) except Exception as import_exception: # any exceptions during import tb = import_exception.__traceback__ exception_str = secure_format_exception(import_exception) else: # found the module if check_version(pkg, version, op): return the_module, True # preparing lazy error message msg = descriptor.format(actual_cmd) if version and tb is None: # a pure version issue msg += f": requires '{module}{op}{version}'" if pkg: module_version = get_module_version_str(pkg) msg += f", current '{module}=={module_version}' " if exception_str: msg += f" ({exception_str})" class _LazyRaise: def __init__(self, attr_name, *_args, **_kwargs): self.attr_name = attr_name _default_msg = f"{msg}." if tb is None: self._exception = LazyImportError(_default_msg) else: self._exception = LazyImportError(_default_msg).with_traceback(tb) def __getattr__(self, attr_name): """ Raises: OptionalImportError: When you call this method. """ raise self._exception def __call__(self, *_args, **_kwargs): """ Raises: OptionalImportError: When you call this method. """ raise self._exception return _LazyRaise(name), False