Skip to content

Commit 85edafe

Browse files
committed
convert : parse safetensors directly
1 parent 8993982 commit 85edafe

File tree

2 files changed

+93
-9
lines changed

2 files changed

+93
-9
lines changed

convert_hf_to_gguf.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,7 @@ def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Call
187187
logger.info(f"gguf: indexing model part '{part_name}'")
188188
ctx: ContextManager[Any]
189189
if is_safetensors:
190-
from safetensors import safe_open
191-
ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
190+
ctx = cast(ContextManager[Any], gguf.utility.SafetensorsLocal(self.dir_model / part_name))
192191
else:
193192
ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
194193

@@ -197,18 +196,18 @@ def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Call
197196

198197
for name in model_part.keys():
199198
if is_safetensors:
199+
data: gguf.utility.LocalTensor = model_part[name]
200200
if self.lazy:
201-
data = model_part.get_slice(name)
202-
data_gen = lambda data=data: LazyTorchTensor.from_safetensors_slice(data) # noqa: E731
201+
data_gen = lambda data=data: LazyTorchTensor.from_safetensors_meta(data) # noqa: E731
203202
else:
204-
data = model_part.get_tensor(name)
205-
data_gen = lambda data=data: data # noqa: E731
203+
dtype = LazyTorchTensor._dtype_str_map[data.dtype]
204+
data_gen = lambda data=data: torch.from_numpy(data.mmap_bytes()).view(dtype).reshape(data.shape) # noqa: E731
206205
else:
207-
data = model_part[name]
206+
data_torch: Tensor = model_part[name]
208207
if self.lazy:
209-
data_gen = lambda data=data: LazyTorchTensor.from_eager(data) # noqa: E731
208+
data_gen = lambda data=data_torch: LazyTorchTensor.from_eager(data) # noqa: E731
210209
else:
211-
data_gen = lambda data=data: data # noqa: E731
210+
data_gen = lambda data=data_torch: data # noqa: E731
212211
tensors[name] = data_gen
213212

214213
# verify tensor name presence and identify potentially missing files
@@ -8614,6 +8613,16 @@ def from_safetensors_slice(cls, st_slice: Any) -> Tensor:
86148613
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:])
86158614
return cast(torch.Tensor, lazy)
86168615

8616+
@classmethod
def from_safetensors_meta(cls, t: gguf.utility.LocalTensor) -> Tensor:
    """Wrap a locally parsed safetensors tensor in a lazy tensor.

    The lazy tensor's meta is built from the dtype and shape recorded in
    ``t``. The tensor bytes are only memory-mapped when the stored ``func``
    is called with ``t``.
    """
    def load_tensor(tensor: gguf.utility.LocalTensor) -> Tensor:
        # Reinterpret the raw bytes as the mapped torch dtype, then apply the shape.
        dtype = cls._dtype_str_map[tensor.dtype]
        return torch.from_numpy(tensor.mmap_bytes()).view(dtype).reshape(tensor.shape)

    dtype = cls._dtype_str_map[t.dtype]
    # Pass the loader directly; wrapping it in a lambda added nothing.
    lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, t.shape), args=(t,), func=load_tensor)
    return cast(torch.Tensor, lazy)
8625+
86178626
@classmethod
86188627
def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor):
86198628
dtype = cls._dtype_str_map[remote_tensor.dtype]

gguf-py/gguf/utility.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from __future__ import annotations
22

33
from dataclasses import dataclass
4+
from pathlib import Path
45
from typing import Literal
56

67
import os
78
import json
9+
import numpy as np
810

911

1012
def fill_templated_filename(filename: str, output_type: str | None) -> str:
@@ -266,3 +268,76 @@ def _get_request_headers(cls) -> dict[str, str]:
266268
if os.environ.get("HF_TOKEN"):
267269
headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
268270
return headers
271+
272+
273+
@dataclass
class LocalTensorRange:
    """Byte range of one tensor's raw data inside a local file."""

    filename: Path  # file that holds the tensor data
    offset: int  # byte offset of the first data byte, passed to np.memmap as the file offset
    size: int  # length of the data in bytes


@dataclass
class LocalTensor:
    """Metadata of one tensor stored in a local file, plus where its bytes live."""

    dtype: str  # dtype name as read from the file's metadata
    shape: tuple[int, ...]
    data_range: LocalTensorRange

    def mmap_bytes(self) -> np.ndarray:
        """Return the tensor's raw bytes as a 1-D ``uint8`` array.

        The file is mapped copy-on-write (``mode='c'``): the returned array is
        writable, but writes stay in memory and never reach the file. The
        default ``mode='r+'`` would open the file read-write, which fails on
        read-only files and lets in-place operations modify the file on disk.

        A zero-byte tensor yields an empty array without mapping anything.
        """
        if self.data_range.size == 0:
            # Nothing to map; a zero-length mapping is not reliably supported.
            return np.empty(0, dtype=np.uint8)
        return np.memmap(
            self.data_range.filename,
            dtype=np.uint8,
            mode="c",
            offset=self.data_range.offset,
            shape=self.data_range.size,
        )
288+
289+
290+
class SafetensorsLocal:
    """
    Read a safetensors file from the local filesystem.

    Custom parsing gives a bit more control over the memory usage.
    The official safetensors library doesn't expose file ranges.

    Only the header is read at construction time. Used as a context manager,
    the object yields a dict mapping tensor names to LocalTensor entries,
    ordered by their offset in the file.
    """
    ALIGNMENT = 8  # bytes

    tensors: dict[str, LocalTensor]

    def __init__(self, filename: Path):
        """Parse the header of ``filename`` and index its tensors.

        Raises:
            ValueError: if the file is shorter than its declared header, the
                header is not valid JSON, or a tensor's byte range does not
                fit inside the file.
        """
        with open(filename, "rb") as f:
            # Take the size from the open handle so it describes the file being read.
            file_size = os.fstat(f.fileno()).st_size
            metadata_length = int.from_bytes(f.read(8), byteorder='little')
            if file_size < 8 + metadata_length:
                raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {file_size}")

            metadata_str = f.read(metadata_length).decode('utf-8')
            try:
                metadata = json.loads(metadata_str)
            except json.JSONDecodeError as e:
                raise ValueError(f"Failed to parse safetensors metadata as JSON: {e}") from e

            data_start_offset = f.tell()

        # NOTE(review): the data start is rounded up to ALIGNMENT. This is a
        # no-op when the header is already padded; confirm against the
        # safetensors format that files with an unpadded header do not start
        # their byte buffer directly after the header.
        alignment = self.ALIGNMENT
        if data_start_offset % alignment != 0:
            data_start_offset += alignment - (data_start_offset % alignment)

        tensors: dict[str, LocalTensor] = {}
        for name, meta in metadata.items():
            if name == "__metadata__":
                # ignore metadata, it's not a tensor
                continue

            offsets = meta["data_offsets"]
            begin, end = offsets[0], offsets[1]
            # Reject truncated or corrupt files here instead of failing later at mmap time.
            if not 0 <= begin <= end or data_start_offset + end > file_size:
                raise ValueError(
                    f"Tensor '{name}' data range [{begin}, {end}) does not fit in the file ({file_size} bytes)"
                )

            tensors[name] = LocalTensor(
                dtype=meta["dtype"],
                shape=tuple(meta["shape"]),
                data_range=LocalTensorRange(
                    filename,
                    data_start_offset + begin,
                    end - begin,
                ),
            )

        # order by offset
        self.tensors = dict(sorted(tensors.items(), key=lambda t: t[1].data_range.offset))

    def __enter__(self, *args, **kwargs):
        """Return the name -> LocalTensor mapping built by ``__init__``."""
        del args, kwargs  # unused
        return self.tensors

    def __exit__(self, *args, **kwargs):
        """Do nothing; returns None, so exceptions from the ``with`` body propagate."""
        del args, kwargs  # unused

0 commit comments

Comments
 (0)