22 | 22 | from torch.testing._internal.common_utils import TestCase |
23 | 23 |
24 | 24 | from torchao import quantize_ |
25 | | -from benchmarks._models.llama.model import Transformer, prepare_inputs_for_model |
26 | | -from benchmarks._models.llama.tokenizer import get_tokenizer |
| 25 | +from torchao._models.model import Transformer, prepare_inputs_for_model |
| 26 | +from torchao._models.tokenizer import get_tokenizer |
27 | 27 | from torchao.dtypes import AffineQuantizedTensor |
28 | 28 | from torchao.quantization import LinearActivationQuantizedTensor |
29 | 29 | from torchao.quantization.quant_api import ( |
@@ -278,7 +278,7 @@ def test_8da4w_quantizer(self): |
278 | 278 | # https://github.com/pytorch-labs/gpt-fast/blob/6253c6bb054e658d67566150f87329b87815ae63/scripts/convert_hf_checkpoint.py |
279 | 279 | @unittest.skip("skipping until we get checkpoints for gpt-fast") |
280 | 280 | def test_8da4w_gptq_quantizer(self): |
281 | | - from benchmarks._models._eval import InputRecorder, TransformerEvalWrapper |
| 281 | + from torchao._models._eval import InputRecorder, TransformerEvalWrapper |
282 | 282 | from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer |
283 | 283 |
284 | 284 | # should be similar to TorchCompileDynamicQuantizer |
@@ -348,7 +348,7 @@ def test_8da4w_gptq_quantizer(self): |
348 | 348 | not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is lower than 2.4"
349 | 349 | ) |
350 | 350 | def test_8da4w_quantizer_eval(self): |
351 | | - from benchmarks._models._eval import TransformerEvalWrapper |
| 351 | + from torchao._models._eval import TransformerEvalWrapper |
352 | 352 | from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer |
353 | 353 |
354 | 354 | precision = torch.bfloat16 |
@@ -384,7 +384,7 @@ def test_8da4w_quantizer_eval(self): |
384 | 384 |
385 | 385 | @unittest.skip("skipping until we get checkpoints for gpt-fast") |
386 | 386 | def test_gptq_quantizer_int4_weight_only(self): |
387 | | - from benchmarks._models._eval import ( |
| 387 | + from torchao._models._eval import ( |
388 | 388 | MultiTensorInputRecorder, |
389 | 389 | TransformerEvalWrapper, |
390 | 390 | ) |
@@ -454,7 +454,7 @@ def test_gptq_quantizer_int4_weight_only(self): |
454 | 454 |
455 | 455 | @unittest.skip("skipping until we get checkpoints for gpt-fast") |
456 | 456 | def test_quantizer_int4_weight_only(self): |
457 | | - from benchmarks._models._eval import TransformerEvalWrapper |
| 457 | + from torchao._models._eval import TransformerEvalWrapper |
458 | 458 | from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer |
459 | 459 |
460 | 460 | precision = torch.bfloat16 |
@@ -492,7 +492,7 @@ def test_quantizer_int4_weight_only(self): |
492 | 492 |
493 | 493 | @unittest.skip("skipping until we get checkpoints for gpt-fast") |
494 | 494 | def test_eval_wrapper(self): |
495 | | - from benchmarks._models._eval import TransformerEvalWrapper |
| 495 | + from torchao._models._eval import TransformerEvalWrapper |
496 | 496 |
497 | 497 | precision = torch.bfloat16 |
498 | 498 | device = "cuda" |
@@ -525,7 +525,7 @@ def test_eval_wrapper(self): |
525 | 525 | # EVAL IS CURRENTLY BROKEN FOR LLAMA 3, VERY LOW ACCURACY |
526 | 526 | @unittest.skip("skipping until we get checkpoints for gpt-fast") |
527 | 527 | def test_eval_wrapper_llama3(self): |
528 | | - from benchmarks._models._eval import TransformerEvalWrapper |
| 528 | + from torchao._models._eval import TransformerEvalWrapper |
529 | 529 |
530 | 530 | precision = torch.bfloat16 |
531 | 531 | device = "cuda" |