# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PT lazy tensor references used by tensor disk offload.
When `enable_tensor_disk_offload=True`, incoming streamed tensor payloads are written
to temporary safetensors files instead of being fully deserialized into memory.
`LazyTensorDict` maps item IDs to on-disk files, and `_LazyRef` defers loading until
`materialize()` is called by aggregation code.
This keeps peak memory lower for large models while still allowing deterministic
explicit cleanup via `cleanup()`, with GC as a fallback through `_TempDirRef`.
"""
import logging
import shutil
from safetensors import safe_open
logger = logging.getLogger(__name__)
def _cleanup_temp_dir(path: str) -> None:
    """Best-effort recursive removal of a tensor-offload temp directory.

    A directory that no longer exists counts as success. Any other failure is
    logged as a warning and swallowed, so this function never raises to its
    caller (it is also invoked from ``__del__``).

    Args:
        path: Directory to remove.
    """
    try:
        shutil.rmtree(path)
    except FileNotFoundError:
        # Already removed (or never created): nothing left to do.
        pass
    except Exception as err:
        logger.warning("failed to cleanup tensor offload temp dir '%s': %s", path, err)
class _TempDirRef:
    """Shared handle that owns a tensor-offload temp directory.

    A single instance is shared by a LazyTensorDict and every _LazyRef created
    from it. The directory is removed on the first call to ``cleanup()``; if
    nobody calls it explicitly, ``__del__`` performs the same removal when this
    object is finalized, i.e. once no holder references it any more.
    """

    def __init__(self, temp_dir: str):
        self.path = temp_dir
        self._deleted = False

    def cleanup(self):
        """Remove the temp directory; every call after the first is a no-op."""
        if self._deleted:
            return
        # Flip the flag before deleting so a repeated call never retries.
        self._deleted = True
        _cleanup_temp_dir(self.path)

    def __del__(self):
        # GC fallback for holders that never called cleanup() explicitly.
        self.cleanup()
class _LazyRef:
    """Placeholder for one tensor stored in an on-disk safetensors file.

    Only the file path and the tensor key are stored; nothing is read from disk
    until ``materialize()`` is called. The instance also keeps a reference to
    the shared _TempDirRef so that the GC-driven directory cleanup cannot run
    while this placeholder is still alive.
    """

    def __init__(self, file_path: str, key: str, temp_ref: "_TempDirRef"):
        self.file_path = file_path
        self.key = key
        self._temp_ref = temp_ref

    def materialize(self):
        """Open the safetensors file and return the tensor stored under ``self.key``.

        The file handle is closed again before the tensor is returned.
        """
        with safe_open(self.file_path, framework="pt") as handle:
            tensor = handle.get_tensor(self.key)
        return tensor

    def __repr__(self):
        return f"_LazyRef({self.file_path!r}, key={self.key!r})"
class LazyTensorDict:
    """Read-only, dict-like mapping of FOBS item_ids to on-disk safetensors tensors.

    Each entry maps an item_id to a ``(file_path, key)`` pair. No tensor is held
    in memory by this object: every access opens the safetensors file through
    ``safe_open`` and loads just the requested tensor.

    Args:
        key_to_file: Mapping of item_id -> (safetensors file path, tensor key
            inside that file).
        temp_dir: Directory handed to the shared _TempDirRef; it is removed by
            ``cleanup()`` or, as a fallback, when the _TempDirRef is finalized.
    """

    def __init__(self, key_to_file: dict[str, tuple[str, str]], temp_dir: str):
        self._key_to_file = key_to_file
        # Shared with every _LazyRef produced by make_lazy_ref().
        self._temp_ref = _TempDirRef(temp_dir)

    def __getitem__(self, key):
        """Load and return the tensor for ``key``.

        Raises:
            KeyError: If ``key`` is not a known item_id.
        """
        # Go through the lazy-ref path so there is a single loading
        # implementation shared with _LazyRef.materialize().
        return self.make_lazy_ref(key).materialize()

    def get(self, key, default=None):
        """Return the loaded tensor for ``key``, or ``default`` on KeyError."""
        try:
            return self[key]
        except KeyError:
            return default

    def keys(self):
        """Return the item_ids as a dict keys view; loads nothing from disk."""
        return self._key_to_file.keys()

    def __iter__(self):
        return iter(self._key_to_file)

    def items(self):
        """Yield ``(item_id, tensor)`` pairs, loading each tensor as it is reached."""
        for key in self._key_to_file:
            yield key, self[key]

    def values(self):
        """Yield tensors one at a time, loading each as it is reached."""
        for key in self._key_to_file:
            yield self[key]

    def __len__(self):
        return len(self._key_to_file)

    def __contains__(self, key):
        return key in self._key_to_file

    def make_lazy_ref(self, key) -> "_LazyRef":
        """Return a _LazyRef for ``key`` without reading the tensor.

        The returned reference shares this dict's _TempDirRef.

        Raises:
            KeyError: If ``key`` is not a known item_id.
        """
        file_path, st_key = self._key_to_file[key]
        return _LazyRef(file_path=file_path, key=st_key, temp_ref=self._temp_ref)

    def cleanup(self):
        """Remove the temp directory now.

        The removal is immediate, even if lazy refs created from this dict are
        still alive. Repeated calls are no-ops.
        """
        self._temp_ref.cleanup()