Source code for nvflare.fuel.utils.fobs.datum

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import uuid
from enum import Enum
from typing import Any, Callable, Dict, Union

TEN_MEGA = 10 * 1024 * 1024
MIN_THRESHOLD = 1024


[docs] class DatumType(Enum): TEXT = 1 # for text string BLOB = 2 # for binary bytes FILE = 3 # for file name
[docs] class Datum: """Datum is a class that holds information for externalized data""" def __init__(self, datum_type: DatumType, value: Any, dot=0): """Constructor of Datum object Args: datum_type: type of the datum. value: value of the datum dot: the Object Type of the datum """ self.datum_id = str(uuid.uuid4()) self.datum_type = datum_type self.dot = dot self.value = value self.restore_func = None # func to restore original object. self.restore_func_data = None # arg to the restore func
[docs] def set_restore_func(self, func, func_data): """Set the restore function and func data. Restore func is set during the serialization process. If set, the func will be called after the serialization to restore the serialized object back to its original state. Args: func: the restore function func_data: arg passed to the restore func when called Returns: None """ if not callable(func): raise ValueError(f"func must be callable but got {type(func)}") self.restore_func = func self.restore_func_data = func_data
[docs] @staticmethod def blob_datum(blob: Union[bytes, bytearray, memoryview], dot=0): """Factory method to create a BLOB datum""" return Datum(DatumType.BLOB, blob, dot)
[docs] @staticmethod def text_datum(text: str, dot=0): """Factory method to create a TEXT datum""" return Datum(DatumType.TEXT, text, dot)
[docs] @staticmethod def file_datum(path: str, dot=0): """Factory method to crate a file datum""" return Datum(DatumType.FILE, path, dot)
[docs] class DatumRef: """A reference to externalized datum. If unwrap is true, the reference will be removed and replaced with the content of the datum""" def __init__(self, datum_id: str, unwrap=False): self.datum_id = datum_id self.unwrap = unwrap
[docs] class DatumManager: def __init__(self, threshold=None, fobs_ctx: dict = None): if not threshold: threshold = TEN_MEGA if not isinstance(threshold, int): raise TypeError(f"threshold must be int but got {type(threshold)}") if threshold < MIN_THRESHOLD: raise ValueError(f"threshold must be at least {MIN_THRESHOLD} but got {threshold}") if not fobs_ctx: fobs_ctx = {} self.threshold = threshold self.datums: Dict[str, Datum] = {} self.fobs_ctx = fobs_ctx # some decomposers (e.g. Shareable, Learnable, etc.) make a shallow copy of the original object before # serialization. After serialization, only the values in the copy are restored. We need to keep a ref # from the copy to the original object so that values in the original are also restored. self.obj_copies = {} # copy id => original object # Post CBs are called after the serialize process is done # Post CBs could be used, for example, to prepare files to be downloaded by the message receiver self.post_cbs = [] self.error = None # save error text
[docs] def add_datum(self, d: Datum): self.datums[d.datum_id] = d
[docs] def get_fobs_context(self): """Get the FOBS Context associated with the manager. The context is available during the whole process of serialization/deserialization of a single message. Since Decomposers are singleton objects that could be used by multiple decomposition processes concurrently, processing state data must not be stored in the decomposer! Instead, such data should be stored in the FOBS context. Returns: """ return self.fobs_ctx
[docs] def register_post_cb(self, cb: Callable[["DatumManager"], None], **cb_kwargs): """Register a callback that will be called after the decomposition is done during serialization process. The callback is typically registered during decomposition by decomposers. Note that the callback itself could also call this method to register additional callbacks. These callbacks will be appended to the callback list. The manager's post CB processing continues until all registered callbacks are invoked. Args: cb: the callback to be registered **cb_kwargs: kwargs to be passed to the callback when invoked Returns: """ if not callable(cb): raise ValueError("cb is not callable") self.post_cbs.append((cb, cb_kwargs))
[docs] def set_error(self, error: str): """Set an error with the manager. The manager will eventually raise RuntimeError at the end of serialization if any error is set. Args: error: the error to be set Returns: None """ if error and not self.error: self.error = error
[docs] def get_error(self): """Get the error set with the manager Returns: the error set with the manager """ return self.error
[docs] def post_process(self): """Invoke all post serialization callbacks. Called during serialization after all objects are decomposed. Returns: None """ # must guarantee that all post_cbs are called! i = 0 while True: # we cannot use a simple for-loop here since a cb could register additional CBs during processing! if i >= len(self.post_cbs): return cb, cb_kwargs = self.post_cbs[i] i += 1 try: cb(self, **cb_kwargs) except Exception as ex: self.set_error(f"exception from post_cb {cb.__name__}: {type(ex)}")
[docs] def register_copy(self, obj_copy, original_obj): """Register the object_copy => original object Args: obj_copy: a copy of the original object original_obj: the original object Returns: None """ self.obj_copies[id(obj_copy)] = original_obj
[docs] def get_original(self, obj_copy) -> Any: """Get the registered original object from the object copy. Args: obj_copy: a copy of the original object Returns: the original object if found; None otherwise. """ return self.obj_copies.get(id(obj_copy))
[docs] def get_datums(self): return self.datums
[docs] def get_datum(self, datum_id: str): return self.datums.get(datum_id)
[docs] def externalize(self, data: Any): if not isinstance(data, (bytes, bytearray, memoryview, Datum, str)): return data if isinstance(data, Datum): # this is an app-defined datum. we need to keep it as is when deserialized. # hence unwrap is set to False in the DatumRef. self.add_datum(data) return DatumRef(data.datum_id, False) if len(data) >= self.threshold: # turn it to Datum if isinstance(data, str): d = Datum.text_datum(data) else: d = Datum.blob_datum(data) self.add_datum(d) return DatumRef(d.datum_id, True) else: return data
[docs] def internalize(self, data: Any) -> Any: if not isinstance(data, DatumRef): return data d = self.get_datum(data.datum_id) if not d: raise RuntimeError(f"can't find datum for {data.datum_id}") if d.datum_type == DatumType.FILE: return d elif data.unwrap: return d.value else: return d