 from torchao.quantization.prototype.qat.api import (
     ComposableQATQuantizer,
 )
-from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
-    AffineFakeQuantizedTensor,
-)
 from torchao.quantization.prototype.qat.utils import (
     _choose_qparams_per_token_asymmetric,
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
     _GenericFakeQuantize,
-    _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK,
 )
 from torchao.quantization.quant_api import (
     int4_weight_only,
@@ -164,7 +160,7 @@ def _set_ptq_weight(
             Int8DynActInt4WeightLinear,
             WeightOnlyInt4Linear,
         )
-        from torchao.quantization.prototype.qat._module_swap_api import (
+        from torchao.quantization.prototype.qat.linear import (
             Int8DynActInt4WeightQATLinear,
             Int4WeightOnlyQATLinear,
         )
@@ -196,7 +192,7 @@ def _set_ptq_weight(
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_linear(self):
-        from torchao.quantization.prototype.qat._module_swap_api import Int8DynActInt4WeightQATLinear
+        from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATLinear
         from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
 
         group_size = 128
@@ -219,45 +215,17 @@ def test_qat_8da4w_linear(self):
         ptq_out = ptq_linear(x2)
         torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)
 
-    # TODO: compare against quantize_ API instead
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer(self):
         from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
-        from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer
-
-        group_size = 16
-        torch.manual_seed(self.SEED)
-        m = M()
-        m2 = copy.deepcopy(m)
-        qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
-        ptq_quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size)
-        qat_model = qat_quantizer.prepare(m)
-        ptq_model = ptq_quantizer.quantize(m2)
-
-        # Compare model values
-        torch.manual_seed(self.SEED)
-        x = m.example_inputs()
-        x2 = copy.deepcopy(x)
-        qat_out = qat_model(*x)
-        ptq_out = ptq_model(*x2)
-        torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)
-
-        # Convert QAT model and compare model values
-        converted_model = qat_quantizer.convert(qat_model)
-        converted_out = converted_model(*x)
-        torch.testing.assert_close(ptq_out, converted_out, atol=0, rtol=0)
-
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
-    def test_qat_8da4w_quantizer_module_swap(self):
-        from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
-        from torchao.quantization.prototype.qat._module_swap_api import Int8DynActInt4WeightQATQuantizerModuleSwap
+        from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATQuantizer
 
         group_size = 16
         torch.manual_seed(self.SEED)
         m = M()
         m2 = copy.deepcopy(m)
         subclass_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
-        module_swap_quantizer = Int8DynActInt4WeightQATQuantizerModuleSwap(groupsize=group_size)
+        module_swap_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
         subclass_model = subclass_quantizer.prepare(m)
         module_swap_model = module_swap_quantizer.prepare(m2)
 
@@ -288,20 +256,6 @@ def test_qat_8da4w_quantizer_meta_weights(self):
         qat_model = qat_quantizer.prepare(m)
         self.assertTrue(all(v.is_meta for v in qat_model.state_dict().values()))
 
-    def _copy_subclass_weights(
-        self,
-        nn_linear: torch.nn.Linear,
-        subclass_linear: AffineFakeQuantizedTensor,
-    ):
-        nn_linear.weight = torch.nn.Parameter(subclass_linear.weight.original_tensor)
-
-    def _assert_matches_subclass_weights(
-        self,
-        nn_linear: torch.nn.Linear,
-        subclass_linear: AffineFakeQuantizedTensor,
-    ):
-        torch.testing.assert_close(nn_linear.weight, subclass_linear.weight.original_tensor, atol=0, rtol=0)
-
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer_disable_fake_quant(self):
         """
@@ -313,16 +267,6 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
             enable_8da4w_fake_quant,
         )
 
-        def assert_fake_quant_enabled(m: torch.nn.Linear, enabled: bool):
-            self.assertTrue(isinstance(m.weight, AffineFakeQuantizedTensor))
-            self.assertEqual(m.weight.fake_quant_enabled, enabled)
-            self.assertTrue(hasattr(m, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK))
-            (_, handle) = getattr(m, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK)
-            if enabled:
-                self.assertIsNotNone(handle)
-            else:
-                self.assertIsNone(handle)
-
         group_size = 16
         torch.manual_seed(self.SEED)
         m = M()
@@ -331,14 +275,14 @@ def assert_fake_quant_enabled(m: torch.nn.Linear, enabled: bool):
         quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
         qat_model = quantizer.prepare(m)
         qat_model.apply(disable_8da4w_fake_quant)
-        assert_fake_quant_enabled(qat_model.linear1, enabled=False)
-        assert_fake_quant_enabled(qat_model.linear2, enabled=False)
-        assert_fake_quant_enabled(qat_model.sub.linear, enabled=False)
+        self.assertFalse(qat_model.linear1._fake_quant_enabled)
+        self.assertFalse(qat_model.linear2._fake_quant_enabled)
+        self.assertFalse(qat_model.sub.linear._fake_quant_enabled)
 
         # Disabled fake quant is just a normal linear
-        self._copy_subclass_weights(m2.linear1, qat_model.linear1)
-        self._copy_subclass_weights(m2.linear2, qat_model.linear2)
-        self._copy_subclass_weights(m2.sub.linear, qat_model.sub.linear)
+        m2.linear1.weight = torch.nn.Parameter(qat_model.linear1.weight)
+        m2.linear2.weight = torch.nn.Parameter(qat_model.linear2.weight)
+        m2.sub.linear.weight = torch.nn.Parameter(qat_model.sub.linear.weight)
         torch.manual_seed(self.SEED)
         x = m.example_inputs()
         x2 = copy.deepcopy(x)
@@ -348,16 +292,16 @@ def assert_fake_quant_enabled(m: torch.nn.Linear, enabled: bool):
 
         # Renable fake quant
         qat_model.apply(enable_8da4w_fake_quant)
-        assert_fake_quant_enabled(qat_model.linear1, enabled=True)
-        assert_fake_quant_enabled(qat_model.linear2, enabled=True)
-        assert_fake_quant_enabled(qat_model.sub.linear, enabled=True)
+        self.assertTrue(qat_model.linear1._fake_quant_enabled)
+        self.assertTrue(qat_model.linear2._fake_quant_enabled)
+        self.assertTrue(qat_model.sub.linear._fake_quant_enabled)
 
         # Fake quant should be applied as normal
         quantizer2 = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
         qat_model2 = quantizer2.prepare(m3)
-        qat_model2.linear1.weight.original_tensor = qat_model.linear1.weight.original_tensor
-        qat_model2.linear2.weight.original_tensor = qat_model.linear2.weight.original_tensor
-        qat_model2.sub.linear.weight.original_tensor = qat_model.sub.linear.weight.original_tensor
+        qat_model2.linear1.weight = qat_model.linear1.weight
+        qat_model2.linear2.weight = qat_model.linear2.weight
+        qat_model2.sub.linear.weight = qat_model.sub.linear.weight
         torch.manual_seed(self.SEED)
         x = m.example_inputs()
         x2 = copy.deepcopy(x)
@@ -382,9 +326,9 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
         quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
         qat_model = quantizer.prepare(m)
         qat_model.apply(disable_8da4w_fake_quant)
-        self._copy_subclass_weights(nn_model.linear1, qat_model.linear1)
-        self._copy_subclass_weights(nn_model.linear2, qat_model.linear2)
-        self._copy_subclass_weights(nn_model.sub.linear, qat_model.sub.linear)
+        nn_model.linear1.weight = torch.nn.Parameter(qat_model.linear1.weight)
+        nn_model.linear2.weight = torch.nn.Parameter(qat_model.linear2.weight)
+        nn_model.sub.linear.weight = torch.nn.Parameter(qat_model.sub.linear.weight)
 
         # Simulate training for both models
         optimizer1 = torch.optim.SGD(nn_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
@@ -406,9 +350,9 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
         optimizer2.step()
 
         # After 1 training step, weights should match exactly
-        self._assert_matches_subclass_weights(nn_model.linear1, qat_model.linear1)
-        self._assert_matches_subclass_weights(nn_model.linear2, qat_model.linear2)
-        self._assert_matches_subclass_weights(nn_model.sub.linear, qat_model.sub.linear)
+        torch.testing.assert_close(nn_model.linear1.weight, qat_model.linear1.weight, atol=0, rtol=0)
+        torch.testing.assert_close(nn_model.linear2.weight, qat_model.linear2.weight, atol=0, rtol=0)
+        torch.testing.assert_close(nn_model.sub.linear.weight, qat_model.sub.linear.weight, atol=0, rtol=0)
 
     def _test_qat_quantized_gradients(self, quantizer):
         """
@@ -542,7 +486,7 @@ def test_qat_4w_primitives(self):
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     def test_qat_4w_linear(self):
-        from torchao.quantization.prototype.qat._module_swap_api import Int4WeightOnlyQATLinear
+        from torchao.quantization.prototype.qat.linear import Int4WeightOnlyQATLinear
         from torchao.quantization.GPTQ import WeightOnlyInt4Linear
 
         group_size = 128
@@ -567,39 +511,6 @@ def test_qat_4w_linear(self):
         ptq_out = ptq_linear(x2)
         self._assert_close_4w(qat_out, ptq_out)
 
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
-    def test_qat_4w_quantizer(self):
-        from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
-        from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
-
-        group_size = 32
-        inner_k_tiles = 8
-        device = torch.device("cuda")
-        dtype = torch.bfloat16
-        torch.manual_seed(self.SEED)
-        m = M().to(device).to(dtype)
-        m2 = copy.deepcopy(m)
-        qat_quantizer = Int4WeightOnlyQATQuantizer(
-            groupsize=group_size, inner_k_tiles=inner_k_tiles,
-        )
-        qat_model = qat_quantizer.prepare(m)
-        ptq_model = m2
-        quantize_(ptq_model, int4_weight_only(group_size, TensorCoreTiledLayoutType(inner_k_tiles)))
-
-        # Compare model values
-        torch.manual_seed(self.SEED)
-        x = [i.to(device).to(dtype) for i in m.example_inputs()]
-        x2 = copy.deepcopy(x)
-        qat_out = qat_model(*x)
-        ptq_out = ptq_model(*x2)
-        self._assert_close_4w(qat_out, ptq_out)
-
-        # Convert QAT model and compare model values
-        converted_model = qat_quantizer.convert(qat_model)
-        converted_out = converted_model(*x)
-        torch.testing.assert_close(converted_out, ptq_out, atol=0, rtol=0)
-
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_4w_quantizer_gradients(self):
         from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
@@ -608,9 +519,9 @@ def test_qat_4w_quantizer_gradients(self):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
-    def test_qat_4w_quantizer_module_swap(self):
+    def test_qat_4w_quantizer(self):
         from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
-        from torchao.quantization.prototype.qat._module_swap_api import Int4WeightOnlyQATQuantizerModuleSwap
+        from torchao.quantization.prototype.qat.linear import Int4WeightOnlyQATQuantizer
 
         group_size = 32
         inner_k_tiles = 8
@@ -622,7 +533,7 @@ def test_qat_4w_quantizer_module_swap(self):
         subclass_quantizer = Int4WeightOnlyQATQuantizer(
             groupsize=group_size, inner_k_tiles=inner_k_tiles,
         )
-        module_swap_quantizer = Int4WeightOnlyQATQuantizerModuleSwap(
+        module_swap_quantizer = Int4WeightOnlyQATQuantizer(
             groupsize=group_size, inner_k_tiles=inner_k_tiles,
         )
         subclass_model = subclass_quantizer.prepare(m)
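For context on the renamed imports above (torchao.quantization.prototype.qat._module_swap_api -> torchao.quantization.prototype.qat.linear), below is a minimal sketch of the prepare/convert flow these tests exercise. The toy model and group size are illustrative assumptions and not part of the diff; the quantizer API (prepare/convert, groupsize) is the one shown in the tests.

import torch
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

# Toy model (assumption for illustration); groupsize must divide the linear's in_features.
model = torch.nn.Sequential(torch.nn.Linear(256, 256))
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16)

# prepare() inserts fake quantization (int8 dynamic activations, int4 grouped weights) for training.
model = quantizer.prepare(model)
# ... fine-tune the model as usual ...
# convert() swaps in the real quantized linear modules after training.
model = quantizer.convert(model)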