|
""" |
|
Pure python version of Safetensors safe_open |
|
From https://gist.github.com/Narsil/3edeec2669a5e94e4707aa0f901d2282 |
|
""" |
|
|
|
import json |
|
import mmap |
|
import os |
|
|
|
import torch |
|
|
|
|
|
class SafetensorsWrapper:
    """Read-only handle over the metadata and tensors loaded from one file.

    Exposes ``metadata()``, ``keys()`` and ``get_tensor()`` and can be used as
    a context manager, so it works in ``with safe_open(...) as f:`` code.

    Args:
        metadata: Value returned unchanged by ``metadata()``.
        tensors: Mapping from tensor name to tensor; ``keys()`` and
            ``get_tensor()`` delegate to it.
    """

    def __init__(self, metadata, tensors):
        self._metadata = metadata
        self._tensors = tensors

    def __enter__(self):
        """Return the wrapper itself so it can be bound by ``with ... as f``."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Do nothing on exit; exceptions from the ``with`` body propagate."""
        # This object holds no file handle of its own, so there is nothing to
        # release here. Returning None lets any exception pass through.
        return None

    def metadata(self):
        """Return the metadata object given at construction."""
        return self._metadata

    def keys(self):
        """Return the keys view of the underlying tensor mapping."""
        return self._tensors.keys()

    def get_tensor(self, k):
        """Return the tensor stored under name ``k``.

        Raises:
            KeyError: If ``k`` is not a key of the tensor mapping (for a dict).
        """
        return self._tensors[k]
|
|
|
|
|
# Maps the dtype codes found in a safetensors header to the torch dtype used
# to reinterpret the raw bytes. Only dtypes available in every torch release
# are listed; a code missing from this table raises KeyError on lookup.
DTYPES = {
    "F64": torch.float64,
    "F32": torch.float32,
    "F16": torch.float16,
    "BF16": torch.bfloat16,
    "I64": torch.int64,
    "I32": torch.int32,
    "I16": torch.int16,
    "I8": torch.int8,
    "U8": torch.uint8,
    "BOOL": torch.bool,
}
|
|
|
|
|
def create_tensor(storage, info, offset):
    """Build one tensor from a slice of ``storage`` described by ``info``.

    Args:
        storage: Sliceable byte container accepted by ``torch.asarray``
            (presumably the untyped storage of the whole file; confirm
            against the caller).
        info: Header entry with the keys ``"dtype"`` (a code looked up in
            ``DTYPES``), ``"shape"`` and ``"data_offsets"`` (a
            ``(start, stop)`` pair).
        offset: Value added to both data offsets before slicing.

    Returns:
        The sliced bytes read as ``uint8``, viewed as the target dtype and
        reshaped to ``info["shape"]``.

    Raises:
        KeyError: If a required key is missing from ``info`` or the dtype
            code is not in ``DTYPES``.
    """
    target_dtype = DTYPES[info["dtype"]]
    dims = info["shape"]
    begin, end = info["data_offsets"]

    # Shift the header-relative offsets by ``offset`` before slicing.
    window = storage[begin + offset : end + offset]
    raw = torch.asarray(window, dtype=torch.uint8)

    # Reinterpret the bytes as the target dtype, then apply the shape.
    typed = raw.view(dtype=target_dtype)
    return typed.reshape(dims)
|
|
|
|
|
def safe_open(filename, framework="pt", device="cpu"):
    """Load all tensors of a safetensors file and return a wrapper over them.

    The file is read as: an 8-byte little-endian integer ``n``, then ``n``
    bytes of JSON header, then the tensor data. Each header entry other than
    ``"__metadata__"`` is passed to ``create_tensor`` and the result is moved
    to ``device``.

    Args:
        filename: Path of the file to read.
        framework: Must be ``"pt"``; any other value is rejected.
        device: Target passed to ``Tensor.to`` for every loaded tensor.

    Returns:
        A ``SafetensorsWrapper`` holding the header's ``"__metadata__"`` entry
        (or ``{}`` when absent) and a dict mapping tensor name to tensor.

    Raises:
        ValueError: If ``framework`` is not ``"pt"``, the file is too short
            to hold the length prefix or the declared header, or the header
            is not valid JSON or not a JSON object.
        OSError: If the file cannot be opened or inspected.
    """
    if framework != "pt":
        raise ValueError("`framework` must be 'pt'")

    # The file is binary, so read raw bytes instead of going through a text
    # decoder. A plain buffered read is enough for the small header.
    with open(filename, mode="rb") as file_obj:
        size = os.fstat(file_obj.fileno()).st_size

        prefix = file_obj.read(8)
        if len(prefix) != 8:
            raise ValueError(
                f"{filename!r} is too short to contain a safetensors header length"
            )
        n = int.from_bytes(prefix, "little")

        # Reject a declared header that overruns the file; otherwise a short
        # read would be parsed as if it were the complete header.
        if n > size - 8:
            raise ValueError(
                f"{filename!r} declares a {n}-byte header but only "
                f"{size - 8} bytes follow the length prefix"
            )
        metadata = json.loads(file_obj.read(n))

    if not isinstance(metadata, dict):
        raise ValueError(f"{filename!r} header must be a JSON object")

    storage = torch.ByteStorage.from_file(filename, shared=False, size=size).untyped()
    # Tensor data offsets in the header are relative to the end of the header.
    offset = n + 8

    return SafetensorsWrapper(
        metadata=metadata.get("__metadata__", {}),
        tensors={
            name: create_tensor(storage, info, offset).to(device)
            for name, info in metadata.items()
            if name != "__metadata__"
        },
    )
|
|