# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional
import torch
from safetensors.torch import _remove_duplicate_names, load, load_file, save, save_file
import nvflare.fuel.utils.fobs.dots as dots
from nvflare.fuel.utils.fobs.datum import DatumManager
from nvflare.fuel.utils.fobs.decomposers.via_file import ViaFileDecomposer
# Passed as the first argument to ViaFileDecomposer.__init__ by TensorDecomposer.
# NOTE(review): presumably the minimum payload size at which the file-based
# path is used instead of in-memory bytes - confirm against ViaFileDecomposer.
MIN_SIZE_FOR_FILE = 0  # The default value
[docs]
class SerializationModule(torch.nn.Module):
    """Minimal ``torch.nn.Module`` that holds a single tensor as a buffer.

    The tensor is registered under the buffer name ``saved_tensor``.
    """

    def __init__(self, tensor):
        """Register ``tensor`` as the ``saved_tensor`` buffer of this module.

        Args:
            tensor: the tensor to register as a buffer
        """
        super(SerializationModule, self).__init__()
        self.register_buffer(name="saved_tensor", tensor=tensor)
def _safe_save(state_dict, filename: str) -> Optional[dict]:
    """Write model weights to a safetensors file, dropping shared-memory duplicates.

    The model weights may contain tensors with shared memory, in which case
    save_file won't work. Such duplicate tensors are found and removed first,
    the remaining tensors are saved with save_file, and the information about
    the removed tensors is returned as a dict.

    The key of the returned dict is the name of the tensor kept in the weights.
    The value is a list of tensor names that are to be substituted by the kept tensor.

    For example, given a state_dict with four tensors:
        {
            "t1": t1,
            "t2": t2,
            "t3": t3,
            "t4": t4
        }
    and supposing t1, t2 and t3 are shared, the tensors that get saved are:
        {
            "t1": t1,
            "t4": t4
        }
    and the returned dict is:
        {
            "t1": ["t2", "t3"]
        }

    Note: the duplicate entries are deleted from the ``state_dict`` object that
    was passed in, so the caller's dict is modified.

    Args:
        state_dict: the model weights to be saved
        filename: name of the file

    Returns: a dict that contains removed tensor info, or None if nothing was removed
    """
    shared_groups = _remove_duplicate_names(state_dict)

    # drop every duplicate name; only the kept name of each group stays behind
    for duplicates in shared_groups.values():
        for name in duplicates:
            del state_dict[name]

    # save_file is given contiguous versions of the remaining tensors
    contiguous_tensors = {}
    for name, tensor in state_dict.items():
        contiguous_tensors[name] = tensor.contiguous()
    save_file(contiguous_tensors, filename)

    if not shared_groups:
        return None

    # shared_groups is dict-like but not a simple dict, so copy it into a plain one
    return dict(shared_groups.items())
[docs]
class TensorDecomposer(ViaFileDecomposer):
    """Decomposer for ``torch.Tensor`` objects based on the safetensors format.

    A single tensor is converted to/from bytes by ``native_decompose`` and
    ``native_recompose``; a dict of tensors is written to/read from a
    safetensors file by ``dump_to_file`` and ``load_from_file``.
    """

    def __init__(self):
        ViaFileDecomposer.__init__(self, MIN_SIZE_FOR_FILE, "tensor_")

    def supported_type(self):
        """Return the type handled by this decomposer (``torch.Tensor``)."""
        return torch.Tensor

    def dump_to_file(self, items: dict, path: str, fobs_ctx: dict):
        """Save a dict of tensors to a safetensors file.

        Args:
            items: the tensors to save. It is handed to ``_safe_save``, which
                deletes duplicate (memory-sharing) entries from it in place.
            path: path of the file to write
            fobs_ctx: not used by this method

        Returns:
            A tuple ``(path, meta)``. ``meta`` is the dict returned by
            ``_safe_save``: it maps each kept tensor name to the list of names
            removed in its favour, or is None when nothing was removed.

        Raises:
            Exception: any error raised while saving is logged and re-raised.
        """
        try:
            meta = _safe_save(items, path)
            # each meta value is a list of removed names, so the number of removed
            # duplicates is the total length of those lists, not the number of keys
            removed = sum(len(group) for group in meta.values()) if meta else 0
            self.logger.info(f"Saving {len(items)} tensors to file {path}: Number of duplicate removed: {removed}")
            return path, meta
        except Exception as e:
            self.logger.error(f"exception saving tensors to file: {e}")
            raise

    def load_from_file(self, path: str, fobs_ctx: dict, meta: Optional[dict] = None) -> Any:
        """Load tensors from a safetensors file and restore removed duplicates.

        Args:
            path: path of the file to read
            fobs_ctx: not used by this method
            meta: optional dict mapping a kept tensor name to the list of names
                that were removed when the file was written

        Returns:
            The dict of loaded tensors. Every name listed in ``meta`` is added
            back and refers to the same tensor object as its kept name.
        """
        items = load_file(path)
        self.logger.info(f"loaded {len(items)} tensor(s) from file {path}")
        if meta:
            # the meta keeps names of removed tensors and the name of the tensor for them
            for kept, removed_group in meta.items():
                for r in removed_group:
                    items[r] = items[kept]
        return items

    def get_bytes_dot(self) -> int:
        """Return ``dots.TENSOR_BYTES``."""
        return dots.TENSOR_BYTES

    def get_file_dot(self) -> int:
        """Return ``dots.TENSOR_FILE``."""
        return dots.TENSOR_FILE

    def native_decompose(self, target: torch.Tensor, manager: DatumManager = None) -> bytes:
        """Serialize a single tensor to bytes using safetensors.

        Args:
            target: the tensor to serialize
            manager: not used by this method

        Returns:
            The safetensors-encoded bytes holding the tensor under the key "t".
        """
        # safetensors refuses to serialize non-contiguous tensors (e.g. transposed
        # views), so make the tensor contiguous first, as _safe_save does for files.
        # contiguous() returns the tensor itself when it is already contiguous.
        dummy = {"t": target.contiguous()}
        return save(dummy)

    def native_recompose(self, data: bytes, manager: DatumManager = None) -> torch.Tensor:
        """Rebuild a tensor from bytes produced by ``native_decompose``.

        Args:
            data: safetensors-encoded bytes
            manager: not used by this method

        Returns:
            The tensor stored under the key "t" (None if that key is absent).

        Raises:
            ValueError: if the loaded data is not a dict.
        """
        # load safetensors generated bytes
        dummy = load(data)
        if not isinstance(dummy, dict):
            raise ValueError(f"failed to load data: should be dict but got {type(dummy)}")
        return dummy.get("t")