 import gguf
 import argparse
 import concurrent.futures
+from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
 import copy
 import enum
 import faulthandler
 import signal
 import struct
 import sys
+import time
 import zipfile
 import numpy as np

 from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass
 from pathlib import Path
-from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, TypeVar, Union)
+from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Literal, Optional, Sequence, Set, Tuple, TypeVar, Union)
 from sentencepiece import SentencePieceProcessor  # type: ignore

 if TYPE_CHECKING:
 ARCH = gguf.MODEL_ARCH.LLAMA
 NAMES = gguf.MODEL_TENSOR_NAMES[ARCH]

+DEFAULT_CONCURRENCY = 8
 #
 # data types
 #

 @dataclass(frozen=True)
-class UnquantizedDataType:
+class DataType:
     name: str
+    dtype: 'np.dtype[Any]'
+    valid_conversions: List[str]

-DT_F16 = UnquantizedDataType('F16')
-DT_F32 = UnquantizedDataType('F32')
-DT_I32 = UnquantizedDataType('I32')
-DT_BF16 = UnquantizedDataType('BF16')
+    def elements_to_bytes(self, n_elements: int) -> int:
+        return n_elements * self.dtype.itemsize

-DataType = Union[UnquantizedDataType]
+@dataclass(frozen=True)
+class UnquantizedDataType(DataType):
+    pass

-DATA_TYPE_TO_NUMPY: Dict[DataType, 'np.dtype[Any]'] = {
-    DT_BF16: np.dtype(np.uint16),
-    DT_F16:  np.dtype(np.float16),
-    DT_F32:  np.dtype(np.float32),
-    DT_I32:  np.dtype(np.int32),
-}
+DT_F16  = UnquantizedDataType('F16',  dtype=np.dtype(np.float16), valid_conversions=['F32', 'Q8_0'])
+DT_F32  = UnquantizedDataType('F32',  dtype=np.dtype(np.float32), valid_conversions=['F16', 'Q8_0'])
+DT_I32  = UnquantizedDataType('I32',  dtype=np.dtype(np.int32),  valid_conversions=[])
+DT_BF16 = UnquantizedDataType('BF16', dtype=np.dtype(np.uint16), valid_conversions=['F32', 'F16', 'Q8_0'])
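+# Note: numpy has no native bfloat16, so DT_BF16 holds the raw bits in uint16
+# and relies on bf16_to_fp32() to widen them to float32 before any conversion.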
+
+@dataclass(frozen=True)
+class QuantizedDataType(DataType):
+    block_size: int
+    quantized_dtype: 'np.dtype[Any]'
+    ggml_type: gguf.GGMLQuantizationType

-NUMPY_TYPE_TO_DATA_TYPE: Dict['np.dtype[Any]', DataType] = \
-    {dtype: data_type for (data_type, dtype) in DATA_TYPE_TO_NUMPY.items()}
+    def quantize(self, arr: NDArray) -> NDArray:
+        raise NotImplementedError(f'Quantization for {self.name} not implemented')
+
+    def elements_to_bytes(self, n_elements: int) -> int:
+        assert n_elements % self.block_size == 0, f'Invalid number of elements {n_elements} for {self.name} with block size {self.block_size}'
+        return self.quantized_dtype.itemsize * (n_elements // self.block_size)
+
+@dataclass(frozen=True)
+class Q8_0QuantizedDataType(QuantizedDataType):
+    # Mini Q8_0 quantization in Python!
+    def quantize(self, arr: NDArray) -> NDArray:
+        assert arr.size % self.block_size == 0 and arr.size != 0, f'Bad array size {arr.size}'
+        assert arr.dtype == np.float32, f'Bad array type {arr.dtype}'
+        n_blocks = arr.size // self.block_size
+        blocks = arr.reshape((n_blocks, self.block_size))
+        # Much faster implementation of block quantization contributed by @Cebtenzzre
+        def quantize_blocks_q8_0(blocks: NDArray) -> Iterable[Tuple[Any, Any]]:
+            d = abs(blocks).max(axis=1) / np.float32(127)
+            with np.errstate(divide='ignore'):
+                qs = (blocks / d[:, None]).round()
+            qs[d == 0] = 0
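+            # An all-zero block yields d == 0 and NaNs from the division above;
+            # those entries are reset to 0 here. Worked example: a block whose
+            # max |value| is 63.5 gets d = 63.5/127 = 0.5, so a value of 10.0
+            # is stored as round(10.0/0.5) = 20.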
+            yield from zip(d, qs)
+        return np.fromiter(quantize_blocks_q8_0(blocks), count=n_blocks, dtype=self.quantized_dtype)
+
+DT_Q8_0 = Q8_0QuantizedDataType('Q8_0',
+                                dtype=np.dtype(np.float32), valid_conversions=[],
+                                ggml_type=gguf.GGMLQuantizationType.Q8_0, block_size=32,
+                                quantized_dtype=np.dtype([('d', '<f2'), ('qs', 'i1', (32,))]))
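+# Each Q8_0 block is one little-endian float16 scale 'd' plus 32 int8 values,
+# i.e. quantized_dtype.itemsize = 2 + 32 = 34 bytes per 32 elements, which is
+# exactly what elements_to_bytes() reports above.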
+
+# Quantized types skipped here because they may also map to np.float32
+NUMPY_TYPE_TO_DATA_TYPE: Dict['np.dtype[Any]', DataType] = {}
+for dt in (DT_BF16, DT_F16, DT_F32, DT_I32):
+    if dt.dtype in NUMPY_TYPE_TO_DATA_TYPE:
+        raise ValueError(f'Invalid duplicate data type {dt}')
+    NUMPY_TYPE_TO_DATA_TYPE[dt.dtype] = dt

 SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {
     'BF16': DT_BF16,
@@ -73,20 +115,22 @@ class UnquantizedDataType:
 # TODO: rename to LLAMAFileType
 # TODO: move to `gguf.py`
 class GGMLFileType(enum.IntEnum):
-    AllF32     = 0
-    MostlyF16  = 1  # except 1d tensors
+    AllF32     = 0
+    MostlyF16  = 1  # except 1d tensors
+    MostlyQ8_0 = 7  # except 1d tensors
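+    # The value 7 matches LLAMA_FTYPE_MOSTLY_Q8_0 in llama.h; 2-6 are other
+    # quantized file types not produced by this script.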

     def type_for_tensor(self, name: str, tensor: 'LazyTensor') -> DataType:
-        if len(tensor.shape) == 1:
-            # 1D tensors are always F32.
-            return DT_F32
-        elif self == GGMLFileType.AllF32:
-            return DT_F32
-        elif self == GGMLFileType.MostlyF16:
-            return DT_F16
-        else:
+        dt = GGML_FILE_TYPE_TO_DATA_TYPE.get(self)
+        if dt is None:
             raise ValueError(self)
+        # 1D tensors are always F32.
+        return dt if len(tensor.shape) > 1 else DT_F32

+GGML_FILE_TYPE_TO_DATA_TYPE: Dict[GGMLFileType, DataType] = {
+    GGMLFileType.AllF32:     DT_F32,
+    GGMLFileType.MostlyF16:  DT_F16,
+    GGMLFileType.MostlyQ8_0: DT_Q8_0,
+}

 #
 # hparams loading
@@ -415,7 +459,7 @@ def __init__(self, ndarray: NDArray) -> None:
         self.data_type = NUMPY_TYPE_TO_DATA_TYPE[ndarray.dtype]

     def astype(self, data_type: DataType) -> Tensor:
-        dtype = DATA_TYPE_TO_NUMPY[data_type]
+        dtype = data_type.dtype
         if self.data_type == DT_BF16:
             self.ndarray = bf16_to_fp32(self.ndarray)
         return UnquantizedTensor(self.ndarray.astype(dtype))
@@ -454,22 +498,6 @@ def load_unquantized(lazy_tensor: 'LazyTensor', expected_dtype: Any = None, conv
 GGMLCompatibleTensor = Union[UnquantizedTensor]


-class DeferredPermutedTensor(Tensor):
-    def __init__(self, base: Tensor, n_head: int, n_head_kv: int) -> None:
-        self.base = base
-        self.n_head = n_head
-        self.data_type = self.base.data_type
-
-    def astype(self, data_type: DataType) -> Tensor:
-        return self.base.astype(data_type).permute(self.n_head, self.n_head_kv)
-
-    def to_ggml(self) -> GGMLCompatibleTensor:
-        return self.base.to_ggml().permute(self.n_head, self.n_head_kv)
-
-    def permute(self, n_head: int, n_head_kv: int) -> Tensor:
-        raise Exception("shouldn't permute twice")
-
-
 @dataclass
 class LazyTensor:
     _load: Callable[[], Tensor]
@@ -479,7 +507,9 @@ class LazyTensor:

     def load(self) -> Tensor:
         ret = self._load()
-        assert ret.data_type == self.data_type, (self.data_type, ret.data_type, self.description)
+        # Should be okay if it maps to the same numpy type?
+        assert ret.data_type == self.data_type or (self.data_type.dtype == ret.data_type.dtype), \
+            (self.data_type, ret.data_type, self.description)
         return ret

     def astype(self, data_type: DataType) -> 'LazyTensor':
@@ -490,8 +520,8 @@ def load() -> Tensor:
         return LazyTensor(load, self.shape, data_type, f'convert({data_type}) {self.description}')

     def validate_conversion_to(self, data_type: DataType) -> None:
-        if data_type == self.data_type:
-            return
+        if data_type != self.data_type and data_type.name not in self.data_type.valid_conversions:
+            raise ValueError(f'Cannot validate conversion from {self.data_type} to {data_type}.')
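+        # e.g. F16 may become F32 or Q8_0 (see valid_conversions above),
+        # while I32 converts to nothing else.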


 LazyModel = Dict[str, LazyTensor]
@@ -617,9 +647,7 @@ def persistent_load(self, pid: Any) -> Any:
         info = self.zip_file.getinfo(filename)

         def load(offset: int, elm_count: int) -> NDArray:
-            dtype = DATA_TYPE_TO_NUMPY.get(data_type)
-            if dtype is None:
-                raise Exception("tensor stored in unsupported format")
+            dtype = data_type.dtype
             fp = self.zip_file.open(info)
             fp.seek(offset * dtype.itemsize)
             size = elm_count * dtype.itemsize
@@ -683,7 +711,7 @@ def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus:

     def convert(info: Dict[str, Any]) -> LazyTensor:
         data_type = SAFETENSORS_DATA_TYPES[info['dtype']]
-        numpy_dtype = DATA_TYPE_TO_NUMPY[data_type]
+        numpy_dtype = data_type.dtype
         shape: List[int] = info['shape']
         begin, end = info['data_offsets']
         assert 0 <= begin <= end <= len(byte_buf)
@@ -723,23 +751,35 @@ def lazy_load_file(path: Path) -> ModelPlus:
 In = TypeVar('In')
 Out = TypeVar('Out')

-def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int) -> Iterable[Out]:
+def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int, max_workers: Optional[int] = None, factory: Callable = ThreadPoolExecutor) -> Iterable[Out]:
     '''Parallel map, but with backpressure. If the caller doesn't call `next`
     fast enough, this will stop calling `func` at some point rather than
     letting results pile up in memory. Specifically, there is a max of one
     output value buffered per thread.'''
-    with concurrent.futures.ThreadPoolExecutor() as executor:
+    if concurrency < 2:
+        yield from map(func, iterable)
+        return
+    iterable = iter(iterable)
+    with factory(max_workers=max_workers) as executor:
         futures: List[concurrent.futures.Future[Out]] = []
-        items_rev = list(iterable)[::-1]
-        for i in range(min(concurrency, len(items_rev))):
-            futures.append(executor.submit(func, items_rev.pop()))
+        done = False
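+        # Prime the pool with up to `concurrency` tasks before yielding anything.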
+        for _ in range(concurrency):
+            try:
+                futures.append(executor.submit(func, next(iterable)))
+            except StopIteration:
+                done = True
+                break
+
         while futures:
             result = futures.pop(0).result()
-            if items_rev:
-                futures.append(executor.submit(func, items_rev.pop()))
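+            # Refill before yielding, so up to `concurrency` tasks stay in flight.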
+            while not done and len(futures) < concurrency:
+                try:
+                    futures.append(executor.submit(func, next(iterable)))
+                except StopIteration:
+                    done = True
+                    break
             yield result

-
 def check_vocab_size(params: Params, vocab: Vocab) -> None:
     if params.n_vocab != vocab.vocab_size:
         assert isinstance(vocab, BpeVocab) or isinstance(vocab, SentencePieceVocab)
@@ -804,12 +844,11 @@ def add_meta_vocab(self, vocab: Vocab) -> None:
         self.gguf.add_token_types(toktypes)

     def add_tensor_info(self, name: str, tensor: LazyTensor) -> None:
-        n_elements = 1
-        for dim in tensor.shape:
-            n_elements *= dim
-        data_type = DATA_TYPE_TO_NUMPY[tensor.data_type]
-        data_nbytes = n_elements * data_type.itemsize
-        self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes)
+        n_elements = int(np.prod(tensor.shape))
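+        # Only QuantizedDataType defines ggml_type/quantized_dtype; for
+        # unquantized tensors the getattrs below fall back to None and the
+        # plain numpy dtype respectively.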
+        raw_dtype = getattr(tensor.data_type, 'ggml_type', None)
+        data_type = getattr(tensor.data_type, 'quantized_dtype', None) or tensor.data_type.dtype
+        data_nbytes = tensor.data_type.elements_to_bytes(n_elements)
+        self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes, raw_dtype=raw_dtype)

     def write_meta(self) -> None:
         self.gguf.write_header_to_file()
@@ -835,7 +874,20 @@ def write_vocab_only(fname_out: Path, params: Params, vocab: Vocab) -> None:
         of.close()

     @staticmethod
-    def write_all(fname_out: Path, params: Params, model: LazyModel, vocab: Vocab) -> None:
+    def do_item(item: Tuple[str, LazyTensor]) -> Tuple[DataType, NDArray]:
+        name, lazy_tensor = item
+        tensor = lazy_tensor.load().to_ggml()
+        return (lazy_tensor.data_type, tensor.ndarray)
+
+    @staticmethod
+    def maybe_do_quantize(item: Tuple[DataType, NDArray]) -> NDArray:
+        dt, arr = item
+        if not isinstance(dt, QuantizedDataType):
+            return arr
+        return dt.quantize(arr)
+
+    @staticmethod
+    def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, concurrency: int = DEFAULT_CONCURRENCY) -> None:
         check_vocab_size(params, vocab)

         of = OutputFile(fname_out)
@@ -851,16 +903,19 @@ def write_all(fname_out: Path, params: Params, model: LazyModel, vocab: Vocab) -
         of.write_meta()
         of.write_tensor_info()

-        def do_item(item: Tuple[str, LazyTensor]) -> NDArray:
-            name, lazy_tensor = item
-            return lazy_tensor.load().to_ggml().ndarray
-
         # tensor data
-        ndarrays = bounded_parallel_map(do_item, model.items(), concurrency=8)
+        ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency=concurrency)
+        if ftype == GGMLFileType.MostlyQ8_0:
+            ndarrays = bounded_parallel_map(OutputFile.maybe_do_quantize, ndarrays_inner, concurrency=concurrency, max_workers=concurrency, factory=ProcessPoolExecutor)
+        else:
+            ndarrays = map(OutputFile.maybe_do_quantize, ndarrays_inner)
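+        # Two-stage pipeline: do_item loads tensors in threads (I/O-bound),
+        # while Q8_0 quantization is CPU-bound and is farmed out to a
+        # ProcessPoolExecutor to sidestep the GIL. For other ftypes,
+        # maybe_do_quantize just unwraps the array unchanged.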
+
+        start = time.time()
         for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)):
+            elapsed = time.time() - start
             size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape)
             padi = len(str(len(model)))
-            print(f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type}")
+            print(f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}")
             of.gguf.write_tensor_data(ndarray)

         of.close()
@@ -872,6 +927,8 @@ def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> GGMLFi
         return GGMLFileType.AllF32
     if output_type_str == "f16" or (output_type_str is None and wq_type in (DT_F16, DT_BF16)):
         return GGMLFileType.MostlyF16
+    if output_type_str == "q8_0":
+        return GGMLFileType.MostlyQ8_0

     name_to_type = {name: lazy_tensor.data_type for (name, lazy_tensor) in model.items()}

@@ -918,7 +975,7 @@ def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
             print(f"skipping tensor {name_new}")
             continue
         else:
-            print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type} | {lazy_tensor.shape}")
+            print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type.name:6s} | {lazy_tensor.shape}")
         out[name_new] = lazy_tensor

     return out
@@ -1023,6 +1080,7 @@ def default_outfile(model_paths: List[Path], file_type: GGMLFileType) -> Path:
     namestr = {
         GGMLFileType.AllF32:     "f32",
         GGMLFileType.MostlyF16:  "f16",
+        GGMLFileType.MostlyQ8_0: "q8_0",
     }[file_type]
     ret = model_paths[0].parent / f"ggml-model-{namestr}.gguf"
     if ret in model_paths:
@@ -1046,12 +1104,13 @@ def main(args_in: Optional[List[str]] = None) -> None:
     parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model")
     parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file")
     parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
-    parser.add_argument("--outtype", choices=["f32", "f16"], help="output format (default: based on input)")
+    parser.add_argument("--outtype", choices=["f32", "f16", "q8_0"], help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)")
     parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file")
     parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
     parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)")
     parser.add_argument("--vocabtype", choices=["spm", "bpe"], help="vocab format (default: spm)", default="spm")
     parser.add_argument("--ctx", type=int, help="model training context (default: based on input)")
+    parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default=DEFAULT_CONCURRENCY)
     args = parser.parse_args(args_in)

     if args.dump_single:
@@ -1073,6 +1132,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
         params.ftype = {
             "f32": GGMLFileType.AllF32,
             "f16": GGMLFileType.MostlyF16,
+            "q8_0": GGMLFileType.MostlyQ8_0,
         }[args.outtype]

     print(f"params = {params}")
@@ -1104,7 +1164,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
         params.ftype = ftype
         print(f"Writing {outfile}, format {ftype}")

-    OutputFile.write_all(outfile, params, model, vocab)
+    OutputFile.write_all(outfile, ftype, params, model, vocab, concurrency=args.concurrency)
     print(f"Wrote {outfile}")

