 from torchao.float8.float8_utils import compute_error
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
+    Float8RowwiseTensor,
     float8_dynamic_activation_float8_weight,
     float8_weight_only,
     quantize_,
@@ -324,19 +325,15 @@ def test_mm_float8dq_per_row(
 
         quant_weight = test_linear.weight
 
-        self.assertTrue(hasattr(quant_weight, "original_weight_tensor"))
-        weight_impl = quant_weight.original_weight_tensor.tensor_impl
-
-        self.assertTrue(hasattr(weight_impl, "float8_data"))
-        self.assertTrue(hasattr(weight_impl, "scale"))
-        self.assertFalse(weight_impl.transposed)
+        self.assertTrue(hasattr(quant_weight, "float8_data"))
+        self.assertTrue(hasattr(quant_weight, "scale"))
 
         # Verify scale shape for row-wise quantization
         expected_scale_shape = (out_features, 1)
-        actual_scale_shape = weight_impl.scale.shape
+        actual_scale_shape = quant_weight.scale.shape
         self.assertEqual(actual_scale_shape, expected_scale_shape)
 
-        self.assertEqual(weight_impl.float8_data.shape, (out_features, in_features))
+        self.assertEqual(quant_weight.float8_data.shape, (out_features, in_features))
 
         input_tensor = torch.randn(*input_shape, device=device, dtype=dtype)
 
@@ -357,7 +354,7 @@ def test_mm_float8dq_per_row(
     @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
     @common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16])
     @common_utils.parametrize("block_size", [None, (1, 32), (2, 16), (4, 8)])
-    def test_dequantize_affine_float8(self, float8_dtype, output_dtype, block_size):
+    def test__dequantize_affine_float8(self, float8_dtype, output_dtype, block_size):
         """Test _dequantize_affine_float8 with various configurations"""
 
         device = "cuda"
@@ -387,7 +384,7 @@ def test_dequantize_affine_float8(self, float8_dtype, output_dtype, block_size):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    def test_dequantize_affine_float8_scale_broadcasting(self):
+    def test__dequantize_affine_float8_scale_broadcasting(self):
         """Test that scale broadcasting works correctly for block-wise quantization"""
         device = "cuda"
         # Create input tensor with known block structure
@@ -419,11 +416,11 @@ def test_dequantize_affine_float8_scale_broadcasting(self):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
-    def test_float8_tensor_slicing_basic(self, granularity):
+    def test_float8_tensor_slicing_basic_per_tensor(self):
         """Test basic slicing operations on Float8 tensors"""
         device = "cuda"
         dtype = torch.bfloat16
+        granularity = PerTensor()
 
         # Create and quantize a model
         model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
@@ -450,6 +447,44 @@ def test_float8_tensor_slicing_basic(self, granularity):
         self.assertTrue(isinstance(sliced_1, Float8AQTTensorImpl))
         self.assertTrue(isinstance(sliced_both, Float8AQTTensorImpl))
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
+    def test_float8_tensor_slicing_basic_per_row(self):
455+ """Test basic slicing operations on Float8 tensors"""
+        device = "cuda"
+        dtype = torch.bfloat16
+        granularity = PerRow()
+
+        # Create and quantize a model
+        model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
+        quantize_(
+            model, Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
+        )
+
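+        # quantize_ replaces the weight in place with a quantized tensor subclass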
+        weight = model.weight
+
+        # Test dimension 0 slicing (rows)
+        sliced_0 = weight[10:20]
+        self.assertEqual(sliced_0.shape, (10, 64))
+
+        # Test dimension 1 slicing (columns)
+        sliced_1 = weight[:, 20:40]
+        self.assertEqual(sliced_1.shape, (32, 20))
+
+        # Test combined slicing
+        sliced_both = weight[5:15, 10:30]
+        self.assertEqual(sliced_both.shape, (10, 20))
+
+        # Verify the sliced tensors are still Float8 tensors
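+        # (the per-tensor test above yields Float8AQTTensorImpl slices; per-row
+        # weights slice to the new Float8RowwiseTensor subclass instead)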
+        self.assertTrue(isinstance(sliced_0, Float8RowwiseTensor))
+        self.assertTrue(isinstance(sliced_1, Float8RowwiseTensor))
+        self.assertTrue(isinstance(sliced_both, Float8RowwiseTensor))
+
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
@@ -497,27 +529,26 @@ def test_float8_tensor_slicing_per_row(self):
     )
 
         original_weight = model.weight  # Shape: (32, 64)
-        original_impl = original_weight.original_weight_tensor.tensor_impl
-        original_scale = original_impl.scale  # Shape: (32, 1)
+        original_scale = model.weight.scale  # Shape: (32, 1)
 
         # Test row slicing (dimension 0)
         sliced_rows = original_weight[10:20]  # Shape: (10, 64)
-        sliced_impl = sliced_rows.original_weight_tensor.tensor_impl
+        sliced_scale = sliced_rows.scale
 
         # Scale should be sliced to match the rows
         expected_scale_shape = (10, 1)
-        self.assertEqual(sliced_impl.scale.shape, expected_scale_shape)
+        self.assertEqual(sliced_scale.shape, expected_scale_shape)
 
         # Verify the scale values are correct (should be subset of original)
-        self.assertTrue(torch.equal(sliced_impl.scale, original_scale[10:20]))
+        self.assertTrue(torch.equal(sliced_scale, original_scale[10:20]))
 
         # Test column slicing (dimension 1) - scale should not change for per-row
         sliced_cols = original_weight[:, 20:40]  # Shape: (32, 20)
-        sliced_cols_impl = sliced_cols.original_weight_tensor.tensor_impl
+        sliced_cols_scale = sliced_cols.scale
 
         # Scale shape should remain the same since we're not changing rows
-        self.assertEqual(sliced_cols_impl.scale.shape, (32, 1))
-        self.assertTrue(torch.equal(sliced_cols_impl.scale, original_scale))
+        self.assertEqual(sliced_cols_scale.shape, (32, 1))
+        self.assertTrue(torch.equal(sliced_cols_scale, original_scale))
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(
@@ -552,15 +583,15 @@ def test_float8_tensor_slicing_edge_cases(self):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     @unittest.skipIf(
         is_sm_version(8, 9),
         "TODO: AssertionError: tensor(-2.1562, device='cuda:0', dtype=torch.bfloat16) not greater than 15",
     )
-    def test_float8_tensor_slicing_functional_correctness(self, granularity):
+    def test_float8_tensor_slicing_functional_correctness_per_tensor(self):
         """Test that sliced tensors produce correct results in computations"""
         device = "cuda"
         dtype = torch.bfloat16
+        granularity = PerTensor()
 
         # Create reference and quantized models with dimensions that are multiples of 16
         ref_model = (
@@ -630,6 +661,79 @@ def test_float8_tensor_slicing_functional_correctness(self, granularity):
         error = compute_error(ref_output, quant_output)
         self.assertGreater(error, 15, f"Quantization SQNR too low: {error}")
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
+    @unittest.skipIf(
+        is_sm_version(8, 9),
+        "TODO: AssertionError: tensor(-2.1562, device='cuda:0', dtype=torch.bfloat16) not greater than 15",
+    )
+    def test_float8_tensor_slicing_functional_correctness_per_row(self):
673+ """Test that sliced tensors produce correct results in computations"""
+        device = "cuda"
+        dtype = torch.bfloat16
+        granularity = PerRow()
+
+        # Create reference and quantized models with dimensions that are multiples of 16
+        ref_model = (
+            torch.nn.Linear(64, 48, bias=False).to(device).to(dtype)
+        )  # 48 is divisible by 16
+        quant_model = copy.deepcopy(ref_model)
+        quantize_(
+            quant_model,
+            Float8DynamicActivationFloat8WeightConfig(granularity=granularity),
+        )
+
+        # Create input with batch size that works well with slicing
+        input_tensor = torch.randn(8, 64, device=device, dtype=dtype)
+
+        ref_weight_slice = ref_model.weight[0:16, 0:32]
+        quant_weight_slice = quant_model.weight[0:16, 0:32]
+
+        # Verify that the sliced weights maintain Float8 properties
+        self.assertTrue(hasattr(quant_weight_slice, "float8_data"))
+        self.assertTrue(hasattr(quant_weight_slice, "scale"))
+        self.assertTrue(isinstance(quant_weight_slice, Float8RowwiseTensor))
+
+        # Verify sliced weight shapes
+        self.assertEqual(quant_weight_slice.float8_data.shape, (16, 32))
+
+        # Per-row: the scale should be sliced to match the selected rows (0:16)
+        expected_scale_shape = (16, 1)
+        self.assertEqual(quant_weight_slice.scale.shape, expected_scale_shape)
+        # Verify the scale values are the correct slice from the original
+        self.assertTrue(
+            torch.equal(quant_weight_slice.scale, quant_model.weight.scale[0:16])
+        )
+
+        # Verify that sliced quantized data matches the correct slice from original
+        original_float8_data_slice = quant_model.weight.float8_data[0:16, 0:32]
+        self.assertTrue(
+            torch.equal(quant_weight_slice.float8_data, original_float8_data_slice)
+        )
+
+        # Verify that sliced weights can be converted back to float with correct values
+        sliced_float_weight = quant_weight_slice.to(dtype)
+        self.assertEqual(sliced_float_weight.shape, (16, 32))
+        self.assertEqual(sliced_float_weight.dtype, dtype)
+
+        input_slice = input_tensor[:, 0:32]  # (8, 32) to match sliced weight
+
+        # Compute with sliced weights
+        with torch.no_grad():
+            ref_output = torch.nn.functional.linear(input_slice, ref_weight_slice)
+            quant_output = torch.nn.functional.linear(input_slice, quant_weight_slice)
+
+        # Verify shapes
+        expected_shape = (8, 16)  # batch_size x out_features_sliced
+        self.assertEqual(ref_output.shape, expected_shape)
+        self.assertEqual(quant_output.shape, expected_shape)
+
+        # Verify reasonable quantization error
+        error = compute_error(ref_output, quant_output)
+        self.assertGreater(error, 15, f"Quantization SQNR too low: {error}")
+
     def test_preprocess_scale_3d_reshape(self):
         """Test that preprocess_scale correctly handles 3D scale tensors"""
         device = "cpu"  # Use CPU for basic functionality test
@@ -675,46 +789,6 @@ def test_preprocess_scale_3d_reshape(self):
         expected_shape = (8, 1)  # Flattened (2*2*2, 1)
         self.assertEqual(result.shape, expected_shape)
 
-    @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
-    @common_utils.parametrize("hp_dtype", [torch.float32, torch.bfloat16])
-    def test_quantize_dequantize_fp8_inductor(self, float8_dtype, hp_dtype):
-        quantize_affine_float8 = torch.ops.torchao.quantize_affine_float8
-        dequantize_affine_float8 = torch.ops.torchao.dequantize_affine_float8
-        input = torch.randn(10, 10)
-        with torch.no_grad():
-            torch._dynamo.reset()
-            expected_scale = torch.tensor(2.0)
-            expected_quantized = quantize_affine_float8(
-                input,
-                expected_scale,
-                float8_dtype=float8_dtype,
-            )
-            expected_dequantized = dequantize_affine_float8(
-                expected_quantized,
-                expected_scale,
-                output_dtype=hp_dtype,
-            )
-            test_q, (code_q,) = torch._inductor.utils.run_and_get_code(
-                torch.compile(quantize_affine_float8),
-                input,
-                expected_scale,
-                float8_dtype=float8_dtype,
-            )
-            torch.testing.FileCheck().check(
-                "torch.ops.torchao.quantize_affine_float8.default"
-            ).run(code_q)
-            test_dq, (code_dq,) = torch._inductor.utils.run_and_get_code(
-                torch.compile(dequantize_affine_float8),
-                test_q,
-                expected_scale,
-                hp_dtype,
-            )
-            torch.testing.FileCheck().check(
-                "torch.ops.torchao.dequantize_affine_float8.default"
-            ).run(code_dq)
-            torch.testing.assert_close(expected_quantized, test_q)
-            torch.testing.assert_close(expected_dequantized, test_dq)
-
 
 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
 