# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import uuid
from enum import Enum
from typing import Any, Dict, Union
# Fallback size threshold (10 * 1024 * 1024) that DatumManager uses when the caller does not supply one.
TEN_MEGA = 10 * 1024 * 1024
class DatumType(Enum):
    """Kinds of payload that a Datum can carry."""

    TEXT = 1  # for text string
    BLOB = 2  # for binary bytes
    FILE = 3  # for file name
class Datum:
    """Datum is a class that holds information for externalized data"""

    def __init__(self, datum_type: DatumType, value: Any):
        """Constructor of Datum object

        Each datum gets a freshly generated UUID4 string as its datum_id.

        Args:
            datum_type: type of the datum.
            value: value of the datum
        """
        self.datum_id = str(uuid.uuid4())
        self.datum_type = datum_type
        self.value = value
        self.restore_func = None  # func to restore original object.
        self.restore_func_data = None  # arg to the restore func

    def set_restore_func(self, func, func_data):
        """Set the restore function and func data.

        Restore func is set during the serialization process. If set, the func will be called after the serialization
        to restore the serialized object back to its original state.

        Args:
            func: the restore function
            func_data: arg passed to the restore func when called

        Returns: None

        Raises:
            ValueError: if func is not callable
        """
        if not callable(func):
            raise ValueError(f"func must be callable but got {type(func)}")
        self.restore_func = func
        self.restore_func_data = func_data

    @staticmethod
    def blob_datum(blob: Union[bytes, bytearray, memoryview]):
        """Factory method to create a BLOB datum"""
        return Datum(DatumType.BLOB, blob)

    @staticmethod
    def text_datum(text: str):
        """Factory method to create a TEXT datum"""
        return Datum(DatumType.TEXT, text)

    @staticmethod
    def file_datum(path: str):
        """Factory method to create a FILE datum"""
        return Datum(DatumType.FILE, path)
class DatumRef:
    """A reference to externalized datum.

    If unwrap is true, the reference will be removed and replaced with the content of the datum.
    """

    def __init__(self, datum_id: str, unwrap=False):
        """Constructor of DatumRef object

        Args:
            datum_id: id of the datum this reference points to
            unwrap: whether the reference should be replaced by the datum's content (default False)
        """
        self.datum_id = datum_id
        self.unwrap = unwrap
class DatumManager:
    """Keeps track of datums that are externalized from, and later internalized back into, a payload.

    externalize() replaces Datum objects, and str/bytes-like values whose size reaches the threshold, with
    DatumRef placeholders and records the corresponding Datum; internalize() resolves such placeholders.
    """

    def __init__(self, threshold=None):
        """Constructor of DatumManager object

        Args:
            threshold: size at or above which externalize() turns a str/bytes-like value into a Datum.
                Any falsy value (e.g. None or 0) selects the default TEN_MEGA.

        Raises:
            TypeError: if threshold is not an int
            ValueError: if threshold is not positive
        """
        # NOTE: falsy values (including 0) fall back to the default rather than failing validation;
        # this is kept for backward compatibility.
        if not threshold:
            threshold = TEN_MEGA

        if not isinstance(threshold, int):
            raise TypeError(f"threshold must be int but got {type(threshold)}")

        if threshold <= 0:
            raise ValueError(f"threshold must > 0 but got {threshold}")

        self.threshold = threshold
        self.datums: Dict[str, Datum] = {}

        # some decomposers (e.g. Shareable, Learnable, etc.) make a shallow copy of the original object before
        # serialization. After serialization, only the values in the copy are restored. We need to keep a ref
        # from the copy to the original object so that values in the original are also restored.
        self.obj_copies = {}  # copy id => original object

    def register_copy(self, obj_copy, original_obj):
        """Register the object_copy => original object

        The mapping is keyed by id(obj_copy).

        Args:
            obj_copy: a copy of the original object
            original_obj: the original object

        Returns: None
        """
        self.obj_copies[id(obj_copy)] = original_obj

    def get_original(self, obj_copy) -> Any:
        """Get the registered original object from the object copy.

        Args:
            obj_copy: a copy of the original object

        Returns: the original object if found; None otherwise.
        """
        return self.obj_copies.get(id(obj_copy))

    def get_datums(self):
        """Get all recorded datums.

        Returns: the internal dict of datum_id => Datum (not a copy).
        """
        return self.datums

    def get_datum(self, datum_id: str):
        """Look up a recorded datum by id.

        Args:
            datum_id: id of the datum

        Returns: the Datum if found; None otherwise.
        """
        return self.datums.get(datum_id)

    def externalize(self, data: Any):
        """Replace the data with a DatumRef if it is a Datum or is large enough to be externalized.

        Size is measured in bytes for bytes, bytearray and memoryview, and in characters for str.

        Args:
            data: the value to examine

        Returns: a DatumRef if the data was recorded as a datum; otherwise the data itself, unchanged.
        """
        if isinstance(data, Datum):
            # this is an app-defined datum. we need to keep it as is when deserialized.
            # hence unwrap is set to False in the DatumRef.
            self.datums[data.datum_id] = data
            return DatumRef(data.datum_id, False)

        if not isinstance(data, (bytes, bytearray, memoryview, str)):
            return data

        # len() of a memoryview counts items along its first dimension (and fails for 0-dim views),
        # so use nbytes to get the real payload size.
        size = data.nbytes if isinstance(data, memoryview) else len(data)
        if size < self.threshold:
            return data

        # large enough: turn it to Datum; the ref is unwrapped back to the raw value when internalized
        d = Datum.text_datum(data) if isinstance(data, str) else Datum.blob_datum(data)
        self.datums[d.datum_id] = d
        return DatumRef(d.datum_id, True)

    def internalize(self, data: Any) -> Any:
        """Resolve a DatumRef back to its datum or the datum's value.

        Args:
            data: the value to examine

        Returns: data itself if it is not a DatumRef; the datum's value if the ref is marked unwrap and the
            datum is not a FILE datum; otherwise the Datum object.

        Raises:
            RuntimeError: if the referenced datum is not recorded in this manager
        """
        if not isinstance(data, DatumRef):
            return data

        d = self.get_datum(data.datum_id)
        if d is None:
            raise RuntimeError(f"can't find datum for {data.datum_id}")

        # FILE datums are always handed back as Datum objects, regardless of the unwrap flag
        if data.unwrap and d.datum_type != DatumType.FILE:
            return d.value
        return d