 # LICENSE file in the root directory of this source tree.
 import copy
 import random
-from typing import List, Tuple
 import sys
 import unittest
 from io import StringIO

 import torch
 import torch.nn as nn
+from torch._dynamo.test_case import TestCase as DynamoTestCase
+from torch._dynamo.testing import CompileCounterWithBackend
+
 from torchao.float8.config import (
     CastConfig,
     Float8LinearConfig,
-    ScalingType,
     Float8LinearRecipeName,
+    ScalingType,
     recipe_name_to_linear_config,
 )
 from torchao.float8.float8_linear import Float8Linear
     hp_tensor_to_float8_dynamic,
 )
 from torchao.float8.float8_tensor import (
-    LinearMMConfig,
     GemmInputRole,
+    LinearMMConfig,
     ScaledMMConfig,
 )
 from torchao.float8.float8_utils import e4m3_dtype
 from torchao.testing.float8.test_utils import get_test_float8_linear_config

-from torch._dynamo.test_case import TestCase as DynamoTestCase
-from torch._dynamo.testing import CompileCounterWithBackend
-
 # TODO(future PR): standardize IS_H100 with the rest of the codebase
 is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
 is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)

+
 def _test_compile_base(
     backend: str,
     fullgraph: bool,
@@ -92,10 +92,12 @@ def _test_compile_base(
9292 "scaling_type_input" , [ScalingType .DELAYED , ScalingType .DYNAMIC , ScalingType .STATIC ]
9393)
9494@pytest .mark .parametrize (
95- "scaling_type_weight" , [ScalingType .DELAYED , ScalingType .DYNAMIC , ScalingType .STATIC ]
95+ "scaling_type_weight" ,
96+ [ScalingType .DELAYED , ScalingType .DYNAMIC , ScalingType .STATIC ],
9697)
9798@pytest .mark .parametrize (
98- "scaling_type_grad_output" , [ScalingType .DELAYED , ScalingType .DYNAMIC , ScalingType .STATIC ]
99+ "scaling_type_grad_output" ,
100+ [ScalingType .DELAYED , ScalingType .DYNAMIC , ScalingType .STATIC ],
99101)
100102@pytest .mark .parametrize ("emulate" , [False , True ] if is_cuda_8_9 else [True ])
101103@pytest .mark .parametrize ("dtype" , [torch .bfloat16 , torch .float32 ])
@@ -129,10 +131,12 @@ def test_eager_only(
129131 "scaling_type_input" , [ScalingType .DELAYED , ScalingType .DYNAMIC , ScalingType .STATIC ]
130132)
131133@pytest .mark .parametrize (
132- "scaling_type_weight" , [ScalingType .DELAYED , ScalingType .DYNAMIC , ScalingType .STATIC ]
134+ "scaling_type_weight" ,
135+ [ScalingType .DELAYED , ScalingType .DYNAMIC , ScalingType .STATIC ],
133136)
134137@pytest .mark .parametrize (
135- "scaling_type_grad_output" , [ScalingType .DELAYED , ScalingType .DYNAMIC , ScalingType .STATIC ]
138+ "scaling_type_grad_output" ,
139+ [ScalingType .DELAYED , ScalingType .DYNAMIC , ScalingType .STATIC ],
136140)
137141@pytest .mark .parametrize ("dtype" , [torch .bfloat16 , torch .float32 ])
138142@unittest .skipIf (not torch .cuda .is_available (), "CUDA not available" )
@@ -165,12 +169,17 @@ def test_aot_eager(
165169 "scaling_type_input" , [ScalingType .DELAYED , ScalingType .DYNAMIC , ScalingType .STATIC ]
166170)
167171@pytest .mark .parametrize (
168- "scaling_type_weight" , [ScalingType .DELAYED , ScalingType .DYNAMIC , ScalingType .STATIC ]
172+ "scaling_type_weight" ,
173+ [ScalingType .DELAYED , ScalingType .DYNAMIC , ScalingType .STATIC ],
169174)
170175@pytest .mark .parametrize (
171- "scaling_type_grad_output" , [ScalingType .DELAYED , ScalingType .DYNAMIC , ScalingType .STATIC ]
176+ "scaling_type_grad_output" ,
177+ [ScalingType .DELAYED , ScalingType .DYNAMIC , ScalingType .STATIC ],
178+ )
179+ @unittest .skipIf (
180+ not torch .cuda .is_available () or not is_cuda_8_9 ,
181+ "CUDA with float8 support not available" ,
172182)
173- @unittest .skipIf (not torch .cuda .is_available () or not is_cuda_8_9 , "CUDA with float8 support not available" )
174183@pytest .mark .parametrize ("dtype" , [torch .bfloat16 , torch .float32 ])
175184def test_inductor_from_config_params (
176185 fullgraph ,
@@ -194,13 +203,17 @@ def test_inductor_from_config_params(
         dtype,
     )

+
 # Note: there are now too many config combinations to test all of
 # them, so this function factors out some of the recipes which are annoying
 # to combine with the main testing function.
 # TODO(future PR): make this cleaner.
 @pytest.mark.parametrize(
     "recipe_name",
-    [Float8LinearRecipeName.ALL_AXISWISE, Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP],
+    [
+        Float8LinearRecipeName.ALL_AXISWISE,
+        Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP,
+    ],
 )
 @unittest.skipIf(not is_H100, "CUDA with capability 9.0 or greater not available")
 def test_inductor_from_recipe(recipe_name):
@@ -239,7 +252,10 @@ def forward(self, x):
             return x_fp8

     # TODO(future): figure out why the test below fails on CUDA capability 8.9
-    @unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with capability 9.0 or greater not available")
+    @unittest.skipIf(
+        not torch.cuda.is_available() or not is_H100,
+        "CUDA with capability 9.0 or greater not available",
+    )
     def test_float8_with_graph_break_in_the_middle(self):
         """Test that having Float8Tensor object at the boundary of a subgraph"""
         cnts = CompileCounterWithBackend("inductor")
@@ -252,7 +268,10 @@ def test_float8_with_graph_break_in_the_middle(self):
         self.assertEqual(cnts.frame_count, 2, "Compiled graph should have 2 frames!")
         torch.testing.assert_close(y_eager, y_compiled)

-    @unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available")
+    @unittest.skipIf(
+        not torch.cuda.is_available() or not is_cuda_8_9,
+        "CUDA with float8 support not available",
+    )
     def test_float8_graph_input(self):
         """Test that having Float8Tensor object as a graph input"""

@@ -273,7 +292,10 @@ def to_float(x):
         )
         torch.testing.assert_close(y2_eager, y2_compiled)

-    @unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available")
+    @unittest.skipIf(
+        not torch.cuda.is_available() or not is_cuda_8_9,
+        "CUDA with float8 support not available",
+    )
     def test_float8_graph_output(self):
         """Test that having Float8Tensor object as a graph output works"""
         cnts = CompileCounterWithBackend("inductor")
@@ -300,7 +322,10 @@ def test_float8_graph_output(self):
             )


-@unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available")
+@unittest.skipIf(
+    not torch.cuda.is_available() or not is_cuda_8_9,
+    "CUDA with float8 support not available",
+)
 def test_sync_amax_func():
     torch._dynamo.reset()
     cnts = CompileCounterWithBackend("inductor")
@@ -338,7 +363,10 @@ def __exit__(self, *args):
         sys.stderr = self.sys_stderr


-@unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available")
+@unittest.skipIf(
+    not torch.cuda.is_available() or not is_cuda_8_9,
+    "CUDA with float8 support not available",
+)
 def test_sync_amax_func_cuda_graph_success():
     torch._dynamo.reset()
     with capture_stderr() as stderr:
@@ -368,9 +396,9 @@ def test_sync_amax_func_cuda_graph_success():


 @unittest.skipIf(
-        not is_cuda_8_9,
-        "CUDA not available",
-    )
+    not is_cuda_8_9,
+    "CUDA not available",
+)
 @pytest.mark.parametrize(
     "dtype",
     [