
Commit 5b8c308

Move test file to right location (#1503)
1 parent de3f812 commit 5b8c308

File tree

1 file changed: +49 additions, -56 deletions


test_gptq_mt.py renamed to test/quantization/test_gptq_mt.py

Lines changed: 49 additions & 56 deletions
@@ -1,19 +1,18 @@
+from pathlib import Path
 
-import unittest
+import pytest
 import torch
-import os
-from pathlib import Path
-from torchao._models.llama.tokenizer import get_tokenizer
-from torchao._models.llama.model import Transformer, prepare_inputs_for_model
-from torchao.quantization.GPTQ_MT import Int4WeightOnlyGPTQQuantizer, MultiTensor
-import sys
-from safetensors.torch import load_file # Import safetensors loader
 import torch.nn.functional as F
 
+from torchao._models.llama.model import Transformer, prepare_inputs_for_model
+from torchao._models.llama.tokenizer import get_tokenizer
+from torchao.quantization.GPTQ_MT import Int4WeightOnlyGPTQQuantizer, MultiTensor
 from torchao.quantization.utils import _lm_eval_available
+
 if _lm_eval_available:
-    hqq_core = pytest.importorskip("hqq.core", reason="requires hqq")
+    hqq_core = pytest.importorskip("hqq.core", reason="requires hqq")
     import lm_eval
+
     try:  # lm_eval version 0.4
         from lm_eval.evaluator import evaluate
         from lm_eval.models.huggingface import HFLM as eval_wrapper
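
Aside (editor's note, not part of the commit): the new import block gates the optional hqq dependency with pytest.importorskip, which returns the imported module when it is available and otherwise skips the whole test module at collection time instead of raising ImportError. A minimal, self-contained sketch of the same pattern:

    import pytest

    # Returns the imported module if "hqq.core" is importable; otherwise
    # pytest marks everything in this file as skipped, with the given reason.
    hqq_core = pytest.importorskip("hqq.core", reason="requires hqq")

    def test_hqq_available():
        # Only runs when the importorskip above succeeded.
        assert hqq_core is not None
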
@@ -49,8 +48,7 @@ def __init__(
         self.calibration_seq_length = calibration_seq_length
 
         self.input_prep_func = (
-            input_prep_func if input_prep_func is not None
-            else lambda x: (x,)
+            input_prep_func if input_prep_func is not None else lambda x: (x,)
         )
 
         self.pad_calibration_inputs = pad_calibration_inputs
@@ -164,13 +162,9 @@ class TransformerEvalWrapper(InputRecorder):
     """
     A wrapper class for GPTFast, providing integration with the lm-evaluation-harness library.
     """
+
     def __init__(
-        self,
-        model,
-        tokenizer,
-        max_seq_length,
-        input_prep_func=None,
-        device="cuda"
+        self, model, tokenizer, max_seq_length, input_prep_func=None, device="cuda"
     ):
         super().__init__(tokenizer, None)
         self._model = model
@@ -181,41 +175,38 @@ def __init__(
         # need to take inps and convert to corrent input
         # for model
         self.input_prep_func = (
-            input_prep_func if input_prep_func is not None
-            else lambda x: (x,)
+            input_prep_func if input_prep_func is not None else lambda x: (x,)
         )
 
     def _model_call(self, inps):
         # print("Entering _model_call")
         # print(f"Input shape: {inps.shape}")
-
+
         input = self.input_prep_func(inps)
         # print(f"Processed input shapes: {[x.shape for x in input]}")
-
+
         input = [x.to(self._device) for x in input]
         # print(f"Inputs moved to device: {self._device}")
-
+
         max_seq_length = min(max(inps.size()), self.max_length)
         # print(f"Max sequence length: {max_seq_length}")
-
+
         # print("Setting up caches")
         with torch.device(self._device):
             # print(f"Device: {self._device}")
             # print(f"Batch size: {self.batch_size}")
             # print(f"Max sequence length: {max_seq_length}")
             self._model.setup_caches(self.batch_size, max_seq_length)
         # print("Caches set up")
-
+
         # print("Running model")
         # torch.save(input, "input.pt")
         logits = self._model(*input)
         # print(f"Model run complete. Logits shape: {logits.shape}")
         return logits
-
-
 
     def _model_generate(self, context, max_length, eos_token_id):
-        raise Exception('unimplemented')
+        raise Exception("unimplemented")
 
     def run_eval(self, tasks, limit):
         logger.info(f"Starting evaluation on tasks: {tasks}")
@@ -238,26 +229,21 @@ def run_eval(self, tasks, limit):
 
         logger.info("Starting evaluation")
         start_time = time.time()
-
+
         try:
             with torch.no_grad():
-                result = evaluate(
-                    self,
-                    task_dict,
-                    limit=limit,
-                    verbosity= "DEBUG"
-                )
+                result = evaluate(self, task_dict, limit=limit, verbosity="DEBUG")
         except Exception as e:
             logger.error(f"Evaluation failed: {e}")
             raise
-
+
         end_time = time.time()
         logger.info(f"Evaluation completed in {end_time - start_time:.2f} seconds")
 
         logger.info("Evaluation results:")
         for task, res in result["results"].items():
             print(f"{task}: {res}")
-
+
         return result
 
 
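For context (a sketch, not from the commit): the collapsed call above uses the lm-eval 0.4 evaluate entry point imported in the first hunk. Assuming lm-eval 0.4's get_task_dict helper and a hypothetical HFLM-compatible wrapper instance named my_wrapper:

    from lm_eval.evaluator import evaluate  # lm-eval 0.4 layout, as above
    from lm_eval.tasks import get_task_dict  # assumed lm-eval 0.4 helper

    # my_wrapper stands in for an instance of a wrapper like the
    # TransformerEvalWrapper above (hypothetical).
    task_dict = get_task_dict(["wikitext"])
    result = evaluate(my_wrapper, task_dict, limit=None)
    for task, res in result["results"].items():
        print(f"{task}: {res}")
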
@@ -289,34 +275,41 @@ def run_eval(self, tasks, limit):
 input_prep_func = prepare_inputs_for_model
 pad_calibration_inputs = False
 print("Recording inputs")
-inputs = InputRecorder(
+inputs = (
+    InputRecorder(
         tokenizer,
         calibration_seq_length,
         input_prep_func,
         pad_calibration_inputs,
         model.config.vocab_size,
         device="cpu",
-).record_inputs(
+    )
+    .record_inputs(
         calibration_tasks,
         calibration_limit,
-).get_inputs()
+    )
+    .get_inputs()
+)
 print("Inputs recorded")
 quantizer = Int4WeightOnlyGPTQQuantizer(
-        blocksize,
-        percdamp,
-        groupsize,
-)
-
+    blocksize,
+    percdamp,
+    groupsize,
+)
+
 model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
-multi = [MultiTensor([ inp for inp, _ in inputs]),MultiTensor([ inds for _, inds in inputs])]
+multi = [
+    MultiTensor([inp for inp, _ in inputs]),
+    MultiTensor([inds for _, inds in inputs]),
+]
 print("Quantizing model")
 model = quantizer.quantize(model, multi).cuda()
 print("Model quantized")
 print("Saving model and fixing state dict")
-regular_state_dict = model.state_dict()#defaultdict(torch.tensor)
+regular_state_dict = model.state_dict()  # defaultdict(torch.tensor)
 for key, value in model.state_dict().items():
     if isinstance(value, MultiTensor):
-            regular_state_dict[key] = value.values[0]
+        regular_state_dict[key] = value.values[0]
     else:
         regular_state_dict[key] = value
 
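Aside (a sketch under stated assumptions, not from the commit): the reshaped multi block above packs the recorded calibration pairs into two parallel MultiTensor containers, one per model argument, so quantize() can replay every example through the model in lockstep. Assuming only what this diff shows about MultiTensor, namely a constructor taking a list of tensors and a .values attribute holding them (per the state-dict loop above):

    import torch

    from torchao.quantization.GPTQ_MT import MultiTensor

    # Dummy stand-in for InputRecorder output: (input_ids, input_pos) pairs.
    inputs = [(torch.randint(0, 100, (1, 16)), torch.arange(16)) for _ in range(4)]

    # One MultiTensor per argument position: all first elements together,
    # all second elements together.
    multi = [
        MultiTensor([inp for inp, _ in inputs]),
        MultiTensor([inds for _, inds in inputs]),
    ]
    assert len(multi[0].values) == 4  # assumes .values holds the wrapped list
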
@@ -326,16 +319,16 @@ def run_eval(self, tasks, limit):
         del regular_state_dict[k]
 
 model.load_state_dict(regular_state_dict, assign=True)
-torch.save(model.state_dict(), 'model.pth')
+torch.save(model.state_dict(), "model.pth")
 print("Running evaluation")
 result = TransformerEvalWrapper(
-        model.to(device), # quantized model needs to run on cuda
-        tokenizer,
-        model.config.block_size,
-        prepare_inputs_for_model,
-).run_eval(
-        ["wikitext"],
-        None,
-)
+    model.to(device),  # quantized model needs to run on cuda
+    tokenizer,
+    model.config.block_size,
+    prepare_inputs_for_model,
+).run_eval(
+    ["wikitext"],
+    None,
+)
 
 # wikitext: {'word_perplexity,none': 12.523175352665858, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.6042723245990418, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 0.681919059499152, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}
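
With the file now under test/quantization/, a standard pytest invocation (not part of the commit) would collect it from its new path:

    python -m pytest test/quantization/test_gptq_mt.py

Thanks to the pytest.importorskip gate added in the first hunk, the module is reported as skipped rather than erroring when the optional hqq dependency is missing.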
