Source code for nvflare.fuel.utils.fobs.fobs

# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import inspect
import logging
import os
from enum import Enum
from os.path import dirname, join
from typing import Any, BinaryIO, Dict, Type, TypeVar, Union

import msgpack

from nvflare.fuel.utils.fobs.datum import DatumManager
from nvflare.fuel.utils.fobs.decomposer import DataClassDecomposer, Decomposer, EnumTypeDecomposer

__all__ = [
    "register",
    "register_data_classes",
    "register_enum_types",
    "auto_register_enum_types",
    "register_folder",
    "num_decomposers",
    "serialize",
    "serialize_stream",
    "deserialize",
    "deserialize_stream",
    "reset",
]

from nvflare.security.logging import secure_format_exception

FOBS_TYPE = "__fobs_type__"
FOBS_DATA = "__fobs_data__"
MAX_CONTENT_LEN = 128
MSGPACK_TYPES = (None, bool, int, float, str, bytes, bytearray, memoryview, list, dict)
T = TypeVar("T")

log = logging.getLogger(__name__)
_decomposers: Dict[str, Decomposer] = {}
_decomposers_registered = False
_enum_auto_register = True


class Packer:
    def __init__(self, manager: DatumManager):
        self.manager = manager

    def pack(self, obj: Any) -> dict:

        if type(obj) in MSGPACK_TYPES:
            return obj

        type_name = _get_type_name(obj.__class__)
        if type_name not in _decomposers:
            if _enum_auto_register and isinstance(obj, Enum):
                register_enum_types(type(obj))
            else:
                return obj

        decomposed = _decomposers[type_name].decompose(obj, self.manager)
        if self.manager:
            decomposed = self.manager.externalize(decomposed)

        return {FOBS_TYPE: type_name, FOBS_DATA: decomposed}

    def unpack(self, obj: Any) -> Any:

        if type(obj) is not dict or FOBS_TYPE not in obj:
            return obj

        type_name = obj[FOBS_TYPE]
        if type_name not in _decomposers:
            error = True
            if _enum_auto_register:
                cls = self._load_class(type_name)
                if issubclass(cls, Enum):
                    register_enum_types(cls)
                    error = False
            if error:
                raise TypeError(f"Unknown type {type_name}, caused by mismatching decomposers")

        data = obj[FOBS_DATA]
        if self.manager:
            data = self.manager.internalize(data)

        decomposer = _decomposers[type_name]
        return decomposer.recompose(data, self.manager)

    @staticmethod
    def _load_class(type_name: str):
        parts = type_name.split(".")
        if len(parts) == 1:
            parts = ["builtins", type_name]

        mod = __import__(parts[0])
        for comp in parts[1:]:
            mod = getattr(mod, comp)

        return mod


def _get_type_name(cls: Type) -> str:
    module = cls.__module__
    if module == "builtins":
        return cls.__qualname__
    return module + "." + cls.__qualname__


[docs]def register(decomposer: Union[Decomposer, Type[Decomposer]]) -> None: """Register a decomposer. It does nothing if decomposer is already registered for the type Args: decomposer: The decomposer type or instance """ global _decomposers if inspect.isclass(decomposer): instance = decomposer() else: instance = decomposer name = _get_type_name(instance.supported_type()) if name in _decomposers: return if not isinstance(instance, Decomposer): log.error(f"Class {instance.__class__} is not a decomposer") return _decomposers[name] = instance
[docs]def register_data_classes(*data_classes: Type[T]) -> None: """Register generic decomposers for data classes Args: data_classes: The classes to be registered """ for data_class in data_classes: decomposer = DataClassDecomposer(data_class) register(decomposer)
[docs]def register_enum_types(*enum_types: Type[Enum]) -> None: """Register generic decomposers for enum classes Args: enum_types: The enum classes to be registered """ for enum_type in enum_types: if not issubclass(enum_type, Enum): raise TypeError(f"Can't register class {enum_type}, which is not a subclass of Enum") decomposer = EnumTypeDecomposer(enum_type) register(decomposer)
[docs]def auto_register_enum_types(enabled=True) -> None: """Enable or disable auto registering of enum classes Args: enabled: Auto-registering of enum classes is enabled if True """ global _enum_auto_register _enum_auto_register = enabled
[docs]def register_folder(folder: str, package: str): """Scan the folder and register all decomposers found. Args: folder: The folder to scan package: The package to import the decomposers from """ for module in os.listdir(folder): if module != "__init__.py" and module[-3:] == ".py": decomposers = package + "." + module[:-3] imported = importlib.import_module(decomposers, __package__) for _, cls_obj in inspect.getmembers(imported, inspect.isclass): spec = inspect.getfullargspec(cls_obj.__init__) # classes who are abstract or take extra args in __init__ can't be auto-registered if issubclass(cls_obj, Decomposer) and not inspect.isabstract(cls_obj) and len(spec.args) == 1: register(cls_obj)
def _register_decomposers(): global _decomposers_registered if _decomposers_registered: return register_folder(join(dirname(__file__), "decomposers"), ".decomposers") _decomposers_registered = True
[docs]def num_decomposers() -> int: """Returns the number of decomposers registered. Returns: The number of decomposers """ return len(_decomposers)
[docs]def serialize(obj: Any, manager: DatumManager = None, **kwargs) -> bytes: """Serialize object into bytes. Args: obj: Object to be serialized manager: Datum manager used to externalize datum kwargs: Arguments passed to msgpack.packb Returns: Serialized data """ _register_decomposers() packer = Packer(manager) try: return msgpack.packb(obj, default=packer.pack, strict_types=True, **kwargs) except ValueError as ex: content = str(obj) if len(content) > MAX_CONTENT_LEN: content = content[:MAX_CONTENT_LEN] + " ..." raise ValueError(f"Object {type(obj)} is not serializable: {secure_format_exception(ex)}: {content}")
[docs]def serialize_stream(obj: Any, stream: BinaryIO, manager: DatumManager = None, **kwargs): """Serialize object and write the data to a stream. Args: obj: Object to be serialized stream: Stream to write the result to manager: Datum manager to externalize datum kwargs: Arguments passed to msgpack.packb """ data = serialize(obj, manager, **kwargs) stream.write(data)
[docs]def deserialize(data: bytes, manager: DatumManager = None, **kwargs) -> Any: """Deserialize bytes into an object. Args: data: Serialized data manager: Datum manager to internalize datum kwargs: Arguments passed to msgpack.unpackb Returns: Deserialized object """ _register_decomposers() packer = Packer(manager) return msgpack.unpackb(data, strict_map_key=False, object_hook=packer.unpack, **kwargs)
[docs]def deserialize_stream(stream: BinaryIO, manager: DatumManager = None, **kwargs) -> Any: """Deserialize bytes from stream into an object. Args: stream: Stream to write serialized data to manager: Datum manager to internalize datum kwargs: Arguments passed to msgpack.unpackb Returns: Deserialized object """ data = stream.read() return deserialize(data, manager, **kwargs)
[docs]def reset(): """Reset FOBS to initial state. Used for unit test""" # global _decomposers, _decomposers_registered # _decomposers.clear() # _decomposers_registered = False pass