+from pathlib import Path
 
-import unittest
+import pytest
 import torch
-import os
-from pathlib import Path
-from torchao._models.llama.tokenizer import get_tokenizer
-from torchao._models.llama.model import Transformer, prepare_inputs_for_model
-from torchao.quantization.GPTQ_MT import Int4WeightOnlyGPTQQuantizer, MultiTensor
-import sys
-from safetensors.torch import load_file  # Import safetensors loader
 import torch.nn.functional as F
 
+from torchao._models.llama.model import Transformer, prepare_inputs_for_model
+from torchao._models.llama.tokenizer import get_tokenizer
+from torchao.quantization.GPTQ_MT import Int4WeightOnlyGPTQQuantizer, MultiTensor
 from torchao.quantization.utils import _lm_eval_available
+
 if _lm_eval_available:
-
+    hqq_core = pytest.importorskip("hqq.core", reason="requires hqq")
     import lm_eval
+
     try:  # lm_eval version 0.4
         from lm_eval.evaluator import evaluate
         from lm_eval.models.huggingface import HFLM as eval_wrapper
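For context on the new guard above: `pytest.importorskip` skips the entire module at collection time when an optional dependency is missing, and returns the imported module otherwise, while the `try`/`except` that follows selects the lm_eval 0.4 import paths when they exist. A minimal standalone sketch of the skip pattern (the `hqq.core` name mirrors the line above; everything else is illustrative):

    import pytest

    # Skips every test in this file when hqq is absent; on success,
    # hqq_core is the imported hqq.core module.
    hqq_core = pytest.importorskip("hqq.core", reason="requires hqq")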
@@ -49,8 +48,7 @@ def __init__(
         self.calibration_seq_length = calibration_seq_length
 
         self.input_prep_func = (
-            input_prep_func if input_prep_func is not None
-            else lambda x: (x,)
+            input_prep_func if input_prep_func is not None else lambda x: (x,)
         )
 
         self.pad_calibration_inputs = pad_calibration_inputs
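The default `input_prep_func` above just wraps the recorded token tensor in a one-element tuple, so downstream code can always unpack it into model arguments. A tiny illustration (shapes and vocab size are arbitrary; a custom prep function can instead return a richer argument tuple, e.g. with position ids):

    import torch

    default_prep = lambda x: (x,)  # default: the model takes tokens only
    tokens = torch.randint(0, 32000, (1, 8))
    args = default_prep(tokens)
    assert len(args) == 1 and args[0] is tokens  # later called as model(*args)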
@@ -164,13 +162,9 @@ class TransformerEvalWrapper(InputRecorder):
     """
     A wrapper class for GPTFast, providing integration with the lm-evaluation-harness library.
     """
+
     def __init__(
-        self,
-        model,
-        tokenizer,
-        max_seq_length,
-        input_prep_func=None,
-        device="cuda"
+        self, model, tokenizer, max_seq_length, input_prep_func=None, device="cuda"
     ):
         super().__init__(tokenizer, None)
         self._model = model
@@ -181,41 +175,38 @@ def __init__(
         # need to take inps and convert to correct input
         # for model
         self.input_prep_func = (
-            input_prep_func if input_prep_func is not None
-            else lambda x: (x,)
+            input_prep_func if input_prep_func is not None else lambda x: (x,)
         )
 
     def _model_call(self, inps):
         # print("Entering _model_call")
         # print(f"Input shape: {inps.shape}")
-
+
         input = self.input_prep_func(inps)
         # print(f"Processed input shapes: {[x.shape for x in input]}")
-
+
         input = [x.to(self._device) for x in input]
         # print(f"Inputs moved to device: {self._device}")
-
+
         max_seq_length = min(max(inps.size()), self.max_length)
         # print(f"Max sequence length: {max_seq_length}")
-
+
         # print("Setting up caches")
         with torch.device(self._device):
             # print(f"Device: {self._device}")
             # print(f"Batch size: {self.batch_size}")
             # print(f"Max sequence length: {max_seq_length}")
             self._model.setup_caches(self.batch_size, max_seq_length)
         # print("Caches set up")
-
+
         # print("Running model")
         # torch.save(input, "input.pt")
         logits = self._model(*input)
         # print(f"Model run complete. Logits shape: {logits.shape}")
         return logits
-
-
 
     def _model_generate(self, context, max_length, eos_token_id):
-        raise Exception('unimplemented')
+        raise Exception("unimplemented")
 
     def run_eval(self, tasks, limit):
         logger.info(f"Starting evaluation on tasks: {tasks}")
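`_model_call` above sizes the KV cache per call: the largest dimension of the incoming token batch, capped at the wrapper's `max_length`. A minimal sketch of that arithmetic (shapes illustrative):

    import torch

    def cache_len(inps: torch.Tensor, max_length: int) -> int:
        # max(inps.size()) is the largest dimension of the batch,
        # i.e. the sequence length for a (batch, seq) tensor.
        return min(max(inps.size()), max_length)

    batch = torch.zeros(2, 300, dtype=torch.long)
    print(cache_len(batch, 2048))  # 300: cache sized to the batch
    print(cache_len(batch, 128))   # 128: capped at max_length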
@@ -238,26 +229,21 @@ def run_eval(self, tasks, limit):
 
         logger.info("Starting evaluation")
         start_time = time.time()
-
+
         try:
             with torch.no_grad():
-                result = evaluate(
-                    self,
-                    task_dict,
-                    limit=limit,
-                    verbosity="DEBUG"
-                )
+                result = evaluate(self, task_dict, limit=limit, verbosity="DEBUG")
         except Exception as e:
             logger.error(f"Evaluation failed: {e}")
             raise
-
+
         end_time = time.time()
         logger.info(f"Evaluation completed in {end_time - start_time:.2f} seconds")
 
         logger.info("Evaluation results:")
         for task, res in result["results"].items():
             print(f"{task}: {res}")
-
+
         return result
 
 
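The harness's `evaluate` returns a nested dict whose "results" entry maps task names to metric dicts; that is what the loop above prints. An illustrative stub shaped like the wikitext output quoted at the end of this file (values truncated):

    result = {
        "results": {
            "wikitext": {"word_perplexity,none": 12.52, "bits_per_byte,none": 0.68},
        }
    }
    for task, res in result["results"].items():
        print(f"{task}: {res}")  # wikitext: {'word_perplexity,none': 12.52, ...}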
@@ -289,34 +275,41 @@ def run_eval(self, tasks, limit):
 input_prep_func = prepare_inputs_for_model
 pad_calibration_inputs = False
 print("Recording inputs")
-inputs = InputRecorder(
+inputs = (
+    InputRecorder(
         tokenizer,
         calibration_seq_length,
         input_prep_func,
         pad_calibration_inputs,
         model.config.vocab_size,
         device="cpu",
-).record_inputs(
+    )
+    .record_inputs(
         calibration_tasks,
         calibration_limit,
-).get_inputs()
+    )
+    .get_inputs()
+)
 print("Inputs recorded")
 quantizer = Int4WeightOnlyGPTQQuantizer(
-    blocksize,
-    percdamp,
-    groupsize,
-)
-
+    blocksize,
+    percdamp,
+    groupsize,
+)
+
 model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
-multi = [MultiTensor([inp for inp, _ in inputs]),MultiTensor([inds for _, inds in inputs])]
+multi = [
+    MultiTensor([inp for inp, _ in inputs]),
+    MultiTensor([inds for _, inds in inputs]),
+]
 print("Quantizing model")
 model = quantizer.quantize(model, multi).cuda()
 print("Model quantized")
 print("Saving model and fixing state dict")
-regular_state_dict = model.state_dict()# defaultdict(torch.tensor)
+regular_state_dict = model.state_dict()  # defaultdict(torch.tensor)
 for key, value in model.state_dict().items():
     if isinstance(value, MultiTensor):
-        regular_state_dict[key] = value.values[0]
+        regular_state_dict[key] = value.values[0]
     else:
         regular_state_dict[key] = value
 
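After quantization, parameters that are still `MultiTensor`s hold one tensor per calibration input, so the loop above collapses each to its first copy before saving. A self-contained sketch of that fix-up, using a stand-in class (the real `MultiTensor` comes from `torchao.quantization.GPTQ_MT`; that `.values` holds a tensor list is assumed from the code above):

    import torch

    class FakeMultiTensor:
        # stand-in for MultiTensor: a list of per-sample tensors
        def __init__(self, values):
            self.values = values

    state = {
        "w": FakeMultiTensor([torch.ones(2), torch.ones(2)]),
        "b": torch.zeros(2),
    }
    regular = {
        k: v.values[0] if isinstance(v, FakeMultiTensor) else v
        for k, v in state.items()
    }
    assert all(isinstance(v, torch.Tensor) for v in regular.values())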
@@ -326,16 +319,16 @@ def run_eval(self, tasks, limit):
     del regular_state_dict[k]
 
 model.load_state_dict(regular_state_dict, assign=True)
-torch.save(model.state_dict(), 'model.pth')
+torch.save(model.state_dict(), "model.pth")
 print("Running evaluation")
 result = TransformerEvalWrapper(
-    model.to(device),  # quantized model needs to run on cuda
-    tokenizer,
-    model.config.block_size,
-    prepare_inputs_for_model,
-).run_eval(
-    ["wikitext"],
-    None,
-)
+    model.to(device),  # quantized model needs to run on cuda
+    tokenizer,
+    model.config.block_size,
+    prepare_inputs_for_model,
+).run_eval(
+    ["wikitext"],
+    None,
+)
 
 # wikitext: {'word_perplexity,none': 12.523175352665858, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.6042723245990418, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 0.681919059499152, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}
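A note on the reload step above: `load_state_dict(..., assign=True)` assigns the incoming tensors to the module rather than copying into the existing parameters, which is what lets the GPTQ-produced tensors replace the originals wholesale. A minimal illustration on a toy module (a recent PyTorch release is required for `assign`):

    import torch

    lin = torch.nn.Linear(2, 2)
    new_sd = {"weight": torch.zeros(2, 2), "bias": torch.zeros(2)}
    lin.load_state_dict(new_sd, assign=True)
    # the parameter now shares storage with the loaded tensor
    assert lin.weight.data_ptr() == new_sd["weight"].data_ptr()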