Commit 6234116

Lint fixes for test/float8 (#1303)
1 parent bce2abb · commit 6234116

9 files changed: +123 / -99 lines changed

9 files changed

+123
-99
lines changed

ruff.toml

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ include = [
     "torchao/dtypes/**/*.py",
     "torchao/sparsity/**/*.py",
     "torchao/prototype/low_bit_optim/**.py",
+    "test/float8/**/*.py",
     "test/quantization/test_observer.py",
     "test/dtypes/test_affine_quantized_float.py",
     "test/dtypes/test_nf4.py",

test/float8/test_base.py

Lines changed: 6 additions & 9 deletions
@@ -4,16 +4,13 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 import copy
-import io
 import itertools
 import random
 import re
 import unittest
 import warnings
-from typing import List, Tuple

 import pytest
-
 import torch
 import torch.nn as nn

@@ -27,9 +24,9 @@
     CastConfig,
     Float8LinearConfig,
     Float8LinearRecipeName,
-    recipe_name_to_linear_config,
     ScalingGranularity,
     ScalingType,
+    recipe_name_to_linear_config,
 )
 from torchao.float8.float8_linear import Float8Linear
 from torchao.float8.float8_linear_utils import (
@@ -45,16 +42,16 @@
 from torchao.float8.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
-    hp_tensor_and_scale_to_float8,
     LinearMMConfig,
     ScaledMMConfig,
+    hp_tensor_and_scale_to_float8,
 )
 from torchao.float8.float8_utils import (
+    FP8_TYPES,
     compute_error,
     e4m3_dtype,
     e5m2_dtype,
     fp8_tensor_statistics,
-    FP8_TYPES,
     tensor_to_scale,
 )
 from torchao.testing.float8.test_utils import get_test_float8_linear_config
@@ -186,7 +183,7 @@ def test_axiswise_reshape(self):
             rtol=0,
         )
         with pytest.raises(RuntimeError):
-            a_fp8_d0_r2 = a_fp8_d0.reshape(-1, 7)
+            a_fp8_d0.reshape(-1, 7)

         # if we scale across dim2, we can only reshape to [-1, 7]
         a_fp8_d2 = hp_tensor_to_float8_dynamic(
@@ -210,7 +207,7 @@ def test_axiswise_reshape(self):
             rtol=0,
         )
         with pytest.raises(RuntimeError):
-            a_fp8_d2_r2 = a_fp8_d2.reshape(3, -1)
+            a_fp8_d2.reshape(3, -1)

     @pytest.mark.parametrize("a_shape", [(16, 32), (2, 16, 32), (1, 2, 16, 32)])
     @pytest.mark.parametrize(
@@ -528,7 +525,7 @@ def test_inference_mode(self):
         m = nn.Sequential(nn.Linear(32, 32)).cuda()
         m = convert_to_float8_training(m)
         with torch.inference_mode(mode=True):
-            y = m(x)
+            m(x)


 class TestScaledMM:
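
The hunks above that touch test bodies drop assignments whose values are never read; the calls are kept only for the exception or side effect they produce. A standalone sketch of the same pattern, using a hypothetical test name and a plain tensor instead of the float8 types in this file:

import pytest
import torch


def test_invalid_reshape_raises():
    t = torch.randn(3, 7)  # 21 elements, not divisible by 5
    with pytest.raises(RuntimeError):
        t.reshape(5, -1)  # no assignment needed; only the raised error is checked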

test/float8/test_compile.py

Lines changed: 50 additions & 22 deletions
@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.
 import copy
 import random
-from typing import List, Tuple
 import sys
 import unittest
 from io import StringIO
@@ -19,11 +18,14 @@

 import torch
 import torch.nn as nn
+from torch._dynamo.test_case import TestCase as DynamoTestCase
+from torch._dynamo.testing import CompileCounterWithBackend
+
 from torchao.float8.config import (
     CastConfig,
     Float8LinearConfig,
-    ScalingType,
     Float8LinearRecipeName,
+    ScalingType,
     recipe_name_to_linear_config,
 )
 from torchao.float8.float8_linear import Float8Linear
@@ -37,20 +39,18 @@
     hp_tensor_to_float8_dynamic,
 )
 from torchao.float8.float8_tensor import (
-    LinearMMConfig,
     GemmInputRole,
+    LinearMMConfig,
     ScaledMMConfig,
 )
 from torchao.float8.float8_utils import e4m3_dtype
 from torchao.testing.float8.test_utils import get_test_float8_linear_config

-from torch._dynamo.test_case import TestCase as DynamoTestCase
-from torch._dynamo.testing import CompileCounterWithBackend
-
 # TODO(future PR): standardize IS_H100 with the rest of the codebase
 is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
 is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)

+
 def _test_compile_base(
     backend: str,
     fullgraph: bool,
@@ -92,10 +92,12 @@ def _test_compile_base(
     "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
 )
 @pytest.mark.parametrize(
-    "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
+    "scaling_type_weight",
+    [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
 )
 @pytest.mark.parametrize(
-    "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
+    "scaling_type_grad_output",
+    [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
 )
 @pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True])
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
@@ -129,10 +131,12 @@ def test_eager_only(
     "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
 )
 @pytest.mark.parametrize(
-    "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
+    "scaling_type_weight",
+    [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
 )
 @pytest.mark.parametrize(
-    "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
+    "scaling_type_grad_output",
+    [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
 )
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
 @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@@ -165,12 +169,17 @@ def test_aot_eager(
     "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
 )
 @pytest.mark.parametrize(
-    "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
+    "scaling_type_weight",
+    [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
 )
 @pytest.mark.parametrize(
-    "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
+    "scaling_type_grad_output",
+    [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
+)
+@unittest.skipIf(
+    not torch.cuda.is_available() or not is_cuda_8_9,
+    "CUDA with float8 support not available",
 )
-@unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available")
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
 def test_inductor_from_config_params(
     fullgraph,
@@ -194,13 +203,17 @@ def test_inductor_from_config_params(
         dtype,
     )

+
 # Note: there are now too many config combinations to test all of
 # them, so this function factors out some of the recipes which are annoying
 # to combine with the main testing function.
 # TODO(future PR): make this cleaner.
 @pytest.mark.parametrize(
     "recipe_name",
-    [Float8LinearRecipeName.ALL_AXISWISE, Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP],
+    [
+        Float8LinearRecipeName.ALL_AXISWISE,
+        Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP,
+    ],
 )
 @unittest.skipIf(not is_H100, "CUDA with capability 9.0 or greater not available")
 def test_inductor_from_recipe(recipe_name):
@@ -239,7 +252,10 @@ def forward(self, x):
             return x_fp8

     # TODO(future): figure out why the test below fails on CUDA capability 8.9
-    @unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with capability 9.0 or greater not available")
+    @unittest.skipIf(
+        not torch.cuda.is_available() or not is_H100,
+        "CUDA with capability 9.0 or greater not available",
+    )
     def test_float8_with_graph_break_in_the_middle(self):
         """Test that having Float8Tensor object at the boundary of a subgraph"""
         cnts = CompileCounterWithBackend("inductor")
@@ -252,7 +268,10 @@ def test_float8_with_graph_break_in_the_middle(self):
         self.assertEqual(cnts.frame_count, 2, "Compiled graph should have 2 frames!")
         torch.testing.assert_close(y_eager, y_compiled)

-    @unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available")
+    @unittest.skipIf(
+        not torch.cuda.is_available() or not is_cuda_8_9,
+        "CUDA with float8 support not available",
+    )
     def test_float8_graph_input(self):
         """Test that having Float8Tensor object as a graph input"""

@@ -273,7 +292,10 @@ def to_float(x):
         )
         torch.testing.assert_close(y2_eager, y2_compiled)

-    @unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available")
+    @unittest.skipIf(
+        not torch.cuda.is_available() or not is_cuda_8_9,
+        "CUDA with float8 support not available",
+    )
     def test_float8_graph_output(self):
         """Test that having Float8Tensor object as a graph output works"""
         cnts = CompileCounterWithBackend("inductor")
@@ -300,7 +322,10 @@ def test_float8_graph_output(self):
         )


-@unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available")
+@unittest.skipIf(
+    not torch.cuda.is_available() or not is_cuda_8_9,
+    "CUDA with float8 support not available",
+)
 def test_sync_amax_func():
     torch._dynamo.reset()
     cnts = CompileCounterWithBackend("inductor")
@@ -338,7 +363,10 @@ def __exit__(self, *args):
         sys.stderr = self.sys_stderr


-@unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available")
+@unittest.skipIf(
+    not torch.cuda.is_available() or not is_cuda_8_9,
+    "CUDA with float8 support not available",
+)
 def test_sync_amax_func_cuda_graph_success():
     torch._dynamo.reset()
     with capture_stderr() as stderr:
@@ -368,9 +396,9 @@ def test_sync_amax_func_cuda_graph_success():


 @unittest.skipIf(
-        not is_cuda_8_9,
-        "CUDA not available",
-    )
+    not is_cuda_8_9,
+    "CUDA not available",
+)
 @pytest.mark.parametrize(
     "dtype",
     [
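
Many of the hunks above only re-wrap the same long @unittest.skipIf condition to satisfy the line-length rule. A hypothetical alternative, not used by this commit, is to bind the repeated skip to a module-level name once and reuse it:

import unittest

import torch

is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)

# unittest.skipIf returns a decorator, so the repeated condition can be named once.
requires_float8_cuda = unittest.skipIf(
    not torch.cuda.is_available() or not is_cuda_8_9,
    "CUDA with float8 support not available",
)


@requires_float8_cuda
def test_placeholder():
    assert torch.cuda.is_available()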

test/float8/test_dtensor.py

Lines changed: 14 additions & 19 deletions
@@ -13,43 +13,42 @@
 import copy
 import os

+import pytest
 import torch
 import torch.nn as nn
 import torch.nn.functional as F

-import pytest
-
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

 if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)

-from torchao.float8 import Float8LinearConfig
-from torchao.float8.float8_linear_utils import convert_to_float8_training
+from torch.distributed._tensor import DTensor, Replicate, Shard, distribute_tensor
+from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
+from torch.distributed.tensor.parallel import parallelize_module
+from torch.testing._internal.distributed._tensor.common_dtensor import (
+    ModelArgs,
+    Transformer,
+)
+from tqdm import tqdm

+from torchao.float8 import Float8LinearConfig
 from torchao.float8.config import CastConfig, ScalingType
+from torchao.float8.float8_linear_utils import convert_to_float8_training
 from torchao.float8.float8_scaling_utils import NoopFwToFloat8E5M2BwDynamic
 from torchao.float8.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
-    hp_tensor_and_scale_to_float8,
     LinearMMConfig,
+    hp_tensor_and_scale_to_float8,
 )
 from torchao.float8.float8_tensor_parallel import (
     Float8ColwiseParallel,
     Float8RowwiseParallel,
     PrepareFloat8ModuleInput,
 )
 from torchao.float8.float8_utils import e4m3_dtype, tensor_to_scale
-from torch.distributed._tensor import distribute_tensor, DTensor, Replicate, Shard
-from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
-from torch.distributed.tensor.parallel import parallelize_module
 from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
-from torch.testing._internal.distributed._tensor.common_dtensor import (
-    ModelArgs,
-    Transformer,
-)
-from tqdm import tqdm


 def setup_distributed():
@@ -325,19 +324,15 @@ def _test_distribute_fsdp_tensor_subclass(tp_mesh: DeviceMesh):
     )
     assert (
         isinstance(colwise_param, DTensor)
-        and isinstance(
-            colwise_param._local_tensor, WeightWithDynamicFloat8CastTensor
-        )
+        and isinstance(colwise_param._local_tensor, WeightWithDynamicFloat8CastTensor)
     ), f"expect DTensor(local_tensor={WeightWithDynamicFloat8CastTensor}) but got {colwise_param}"
     # test Float8RowwiseParallel
     rowwise_param = distribute_tensor(
         model.layers[0].attention.wo.weight, tp_mesh, [Shard(1)]
     )
     assert (
         isinstance(rowwise_param, DTensor)
-        and isinstance(
-            rowwise_param._local_tensor, WeightWithDynamicFloat8CastTensor
-        )
+        and isinstance(rowwise_param._local_tensor, WeightWithDynamicFloat8CastTensor)
     ), f"expect DTensor(local_tensor={WeightWithDynamicFloat8CastTensor}) but got {colwise_param}"


test/float8/test_fsdp.py

Lines changed: 9 additions & 6 deletions
@@ -13,10 +13,10 @@

 import copy
 import os
-import pytest
 import warnings

 import fire
+import pytest

 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

@@ -27,18 +27,21 @@
 import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
+from torch.distributed.fsdp import (
+    FullStateDictConfig,
+    StateDictType,
+)
+from torch.distributed.fsdp import (
+    FullyShardedDataParallel as FSDP,
+)
+
 from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
 from torchao.float8.float8_linear_utils import (
     convert_to_float8_training,
     linear_requires_sync,
     sync_float8_amax_and_scale_history,
 )
 from torchao.float8.float8_utils import compute_error
-from torch.distributed.fsdp import (
-    FullStateDictConfig,
-    FullyShardedDataParallel as FSDP,
-    StateDictType,
-)

 torch.manual_seed(0)