@@ -22,20 +22,20 @@
     PerRow,
     PerToken,
 )
-from torchao.quantization.prototype.qat.api import (
+from torchao.quantization.qat.api import (
     ComposableQATQuantizer,
     FakeQuantizeConfig,
 )
-from torchao.quantization.prototype.qat.fake_quantizer import (
+from torchao.quantization.qat.fake_quantizer import (
     FakeQuantizer,
 )
-from torchao.quantization.prototype.qat.embedding import (
+from torchao.quantization.qat.embedding import (
     FakeQuantizedEmbedding,
 )
-from torchao.quantization.prototype.qat.linear import (
+from torchao.quantization.qat.linear import (
     FakeQuantizedLinear,
 )
-from torchao.quantization.prototype.qat.utils import (
+from torchao.quantization.qat.utils import (
     _choose_qparams_per_token_asymmetric,
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
@@ -175,7 +175,7 @@ def _set_ptq_weight(
             Int8DynActInt4WeightLinear,
             WeightOnlyInt4Linear,
         )
-        from torchao.quantization.prototype.qat.linear import (
+        from torchao.quantization.qat.linear import (
             Int8DynActInt4WeightQATLinear,
             Int4WeightOnlyQATLinear,
         )
@@ -207,7 +207,7 @@ def _set_ptq_weight(
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_linear(self):
-        from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATLinear
+        from torchao.quantization.qat.linear import Int8DynActInt4WeightQATLinear
         from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
 
         group_size = 128
@@ -232,7 +232,7 @@ def test_qat_8da4w_linear(self):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer(self):
-        from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+        from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
         from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer
 
         group_size = 16
@@ -266,7 +266,7 @@ def test_qat_8da4w_quantizer(self):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer_meta_weights(self):
-        from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+        from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
 
         with torch.device("meta"):
             m = M()
@@ -281,7 +281,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
         """
         Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward.
         """
-        from torchao.quantization.prototype.qat import (
+        from torchao.quantization.qat.linear import (
             Int8DynActInt4WeightQATQuantizer,
             disable_8da4w_fake_quant,
             enable_8da4w_fake_quant,
@@ -340,7 +340,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
         """
         Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward.
         """
-        from torchao.quantization.prototype.qat import (
+        from torchao.quantization.qat.linear import (
             Int8DynActInt4WeightQATQuantizer,
             disable_8da4w_fake_quant,
         )
@@ -422,7 +422,7 @@ def _test_qat_quantized_gradients(self, quantizer):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer_gradients(self):
-        from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+        from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
         quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16)
         self._test_qat_quantized_gradients(quantizer)
 
@@ -512,7 +512,7 @@ def test_qat_4w_primitives(self):
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     def test_qat_4w_linear(self):
-        from torchao.quantization.prototype.qat.linear import Int4WeightOnlyQATLinear
+        from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear
         from torchao.quantization.GPTQ import WeightOnlyInt4Linear
 
         group_size = 128
@@ -539,14 +539,14 @@ def test_qat_4w_linear(self):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_4w_quantizer_gradients(self):
-        from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
+        from torchao.quantization.qat import Int4WeightOnlyQATQuantizer
         quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8)
         self._test_qat_quantized_gradients(quantizer)
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     def test_qat_4w_quantizer(self):
-        from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
+        from torchao.quantization.qat import Int4WeightOnlyQATQuantizer
         from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
 
         group_size = 32
@@ -624,7 +624,7 @@ def test_composable_qat_quantizer(self):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_4w_embedding(self):
-        from torchao.quantization.prototype.qat import Int4WeightOnlyEmbeddingQATQuantizer
+        from torchao.quantization.qat import Int4WeightOnlyEmbeddingQATQuantizer
         model = M2()
         x = model.example_inputs()
         out = model(*x)
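
Taken together, the hunks above are a pure import-path rename: everything under `torchao.quantization.prototype.qat` now resolves under `torchao.quantization.qat` (with the 8da4w fake-quant toggles landing in `qat.linear`). A minimal sketch of how calling code changes, assuming the prepare/convert quantizer flow these tests exercise; `ToyModel` is a hypothetical stand-in, not part of this diff:

```python
import torch

# New import path after this change; previously this class was imported from
# torchao.quantization.prototype.qat (assumption: only the path moved, the API did not).
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

class ToyModel(torch.nn.Module):
    """Hypothetical stand-in model, not part of the diff."""
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(64, 64)

    def forward(self, x):
        return self.linear(x)

model = ToyModel()
# groupsize=16 mirrors the value used in test_qat_8da4w_quantizer above.
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16)
model = quantizer.prepare(model)   # swap in fake-quantized linears for QAT
# ... fine-tune the model as usual ...
model = quantizer.convert(model)   # lower to real int8-dynamic-act / int4-weight linears
```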