     run_tests,
 )

-from torchao.float8.config import e4m3_dtype
 from torchao.quantization import (
     FbgemmConfig,
     quantize_,
@@ -29,15 +28,10 @@
 @unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
 class TestFbgemmFp8Tensor(TestCase):
     def setUp(self):
+        self.e4m3_dtype = torch.float8_e4m3fn
         self.config = FbgemmConfig(
-            input_dtype=e4m3_dtype,
-            weight_dtype=e4m3_dtype,
+            input_dtype=self.e4m3_dtype,
+            weight_dtype=self.e4m3_dtype,
             output_dtype=torch.bfloat16,
         )
-        self.bmm_config = FbgemmConfig(
-            input_dtype=e4m3_dtype,
-            weight_dtype=e4m3_dtype,
-            output_dtype=torch.bfloat16,
-            transpose_input=True,
-        )
         self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

     def test_linear(self):
@@ -128,7 +122,9 @@ def forward(self, x):
         weight = torch.randn(10, 128, 256, dtype=dtype, device=device)
         m = M(weight).eval()
         original = m(input)
-        quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
+        # we need to transpose the weight first for bmm
+        m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
+        quantize_(m, self.config, filter_fn=lambda x, fqn: True)
         quantized = m(input)
         self.assertTrue(compute_error(original, quantized) > 20)

@@ -146,6 +142,54 @@ def test_to_device(self):
             quantize_(linear, self.config)
             linear.to(device)

+    def test_cat(self):
+        dtype = torch.bfloat16
+        device = "cuda"
+        # weight: (256, 128)
+        linear1 = torch.nn.Linear(128, 256, dtype=dtype)
+        # weight: (256, 128)
+        linear2 = torch.nn.Linear(128, 256, dtype=dtype)
+
+        cat_weight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+        dummy1 = torch.nn.Linear(128, 512, bias=False, dtype=dtype, device=device)
+
+        dummy1.weight = torch.nn.Parameter(cat_weight1)
+        quantize_(dummy1, self.config)
+
+        quantize_(linear1, self.config)
+        quantize_(linear2, self.config)
+
+        cat_qweight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+        self.assertEqual(cat_qweight1.shape, (512, 128))
+        self.assertEqual(dummy1.weight.float8_data, cat_qweight1.float8_data)
+        self.assertEqual(dummy1.weight.scale, cat_qweight1.scale)
+
+        # concat with dim == 1 is not really correct and will be fixed later
+        # when we support distributed checkpointing
+        cat_qweight2 = torch.cat([linear1.weight, linear2.weight], dim=1)
+        self.assertEqual(cat_qweight2.shape, (256, 256))
+        ref_float8_data = torch.cat(
+            [linear1.weight.float8_data, linear2.weight.float8_data], dim=1
+        )
+        ref_scale = linear1.weight.scale
+        self.assertEqual(cat_qweight2.float8_data, ref_float8_data)
+        self.assertEqual(cat_qweight2.scale, ref_scale)
+
+    def test_transpose(self):
+        dtype = torch.bfloat16
+        device = "cuda"
+        # weight: (256, 128)
+        linear1 = torch.nn.Linear(128, 256, dtype=dtype, device=device)
+        quantize_(linear1, self.config)
+        linear1.weight = torch.nn.Parameter(linear1.weight.transpose(0, 1).contiguous())
+        linear1.bias = torch.nn.Parameter(torch.randn(128, dtype=dtype, device=device))
+        self.assertEqual(linear1.weight.shape, (128, 256))
+
+        input = torch.randn(32, 256, dtype=dtype, device=device)
+        # make sure it runs
+        res = linear1(input)
+        self.assertEqual(res.shape, (32, 128))
+

 if __name__ == "__main__":
     run_tests()
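
A minimal standalone sketch (not part of the diff) of the bmm pattern the updated test exercises: instead of a separate bmm_config with transpose_input=True, the weight is transposed up front and the single float8 config is reused. The BMM wrapper and the shapes below are illustrative stand-ins mirroring the test; running it assumes an sm90+ GPU with the fbgemm kernels installed.

    import torch
    from torchao.quantization import FbgemmConfig, quantize_

    class BMM(torch.nn.Module):
        # illustrative stand-in for the M module used in the test above
        def __init__(self, weight):
            super().__init__()
            self.weight = torch.nn.Parameter(weight)

        def forward(self, x):
            return torch.bmm(x, self.weight)

    config = FbgemmConfig(
        input_dtype=torch.float8_e4m3fn,
        weight_dtype=torch.float8_e4m3fn,
        output_dtype=torch.bfloat16,
    )
    m = BMM(torch.randn(10, 128, 256, dtype=torch.bfloat16, device="cuda")).eval()
    # transpose the bmm weight first, replacing the old transpose_input=True option
    m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
    quantize_(m, config, filter_fn=lambda mod, fqn: True)
    out = m(torch.randn(10, 32, 128, dtype=torch.bfloat16, device="cuda"))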