
Commit 134a82f

fixing formatting issue
1 parent cae1cce commit 134a82f

1 file changed

torchao/float8/config.py

Lines changed: 15 additions & 13 deletions
@@ -92,9 +92,9 @@ def short_str(self):
 
     def __post_init__(self):
         if self.scaling_granularity is ScalingGranularity.AXISWISE:
-            assert (
-                self.scaling_type is ScalingType.DYNAMIC
-            ), "only dynamic scaling type is supported for axiswise scaling granularity"
+            assert self.scaling_type is ScalingType.DYNAMIC, (
+                "only dynamic scaling type is supported for axiswise scaling granularity"
+            )
         assert self.target_dtype is None or (
             self.target_dtype.is_floating_point and self.target_dtype.itemsize == 1
         ), "must specify a 8-bit floating-point dtype"
@@ -240,7 +240,9 @@ def __post_init__(self):
 
         # float8 all-gather only supports tensorwise, in the future may support blockwise
         if self.cast_config_weight.scaling_granularity != ScalingGranularity.TENSORWISE:
-            assert not self.enable_fsdp_float8_all_gather, f"enable_fsdp_float8_all_gather only supports tensorwise scaling granularity, got {self.cast_config_weight.scaling_granularity}"
+            assert not self.enable_fsdp_float8_all_gather, (
+                f"enable_fsdp_float8_all_gather only supports tensorwise scaling granularity, got {self.cast_config_weight.scaling_granularity}"
+            )
 
         # save some characters in the compatibility checks below
         cc_i = self.cast_config_input
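
The long single-line assert becomes a parenthesized message, but the guard is unchanged: float8 all-gather is rejected whenever the weight cast is not tensorwise. A hypothetical sketch of triggering it, assuming `Float8LinearConfig` is the class this `__post_init__` belongs to and that `cast_config_weight` and `enable_fsdp_float8_all_gather` are constructor fields (not part of this commit):

# Hypothetical sketch (not part of this commit): non-tensorwise weight scaling
# combined with float8 all-gather should fail the guard above.
from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingGranularity

try:
    Float8LinearConfig(
        cast_config_weight=CastConfig(scaling_granularity=ScalingGranularity.AXISWISE),
        enable_fsdp_float8_all_gather=True,
    )
except AssertionError as err:
    print(err)  # expected: "enable_fsdp_float8_all_gather only supports tensorwise scaling granularity, got ..."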
@@ -259,9 +261,9 @@ def __post_init__(self):
         ):
             is_disabled_1 = cc1.scaling_type is ScalingType.DISABLED
             is_disabled_2 = cc1.scaling_type is ScalingType.DISABLED
-            assert (
-                is_disabled_1 == is_disabled_2
-            ), f"incompatible operand precision for {gemm_name}"
+            assert is_disabled_1 == is_disabled_2, (
+                f"incompatible operand precision for {gemm_name}"
+            )
 
         for cc1, cc2, operand_name, default_dtype in [
             (cc_i, cc_i_gw, "input", e4m3_dtype),
@@ -273,9 +275,9 @@ def __post_init__(self):
                 object.__setattr__(cc1, "target_dtype", default_dtype)
             if cc2.target_dtype is None:
                 object.__setattr__(cc2, "target_dtype", default_dtype)
-            assert (
-                cc1.target_dtype == cc2.target_dtype
-            ), f"{operand_name} must be cast to the same dtype in both matmuls it's used in"
+            assert cc1.target_dtype == cc2.target_dtype, (
+                f"{operand_name} must be cast to the same dtype in both matmuls it's used in"
+            )
 
         # See the comments around `force_recompute_fp8_weight_in_bwd` for more details of this warning.
         if (
@@ -296,9 +298,9 @@ def from_recipe_name(
         """
         if type(recipe_name) == str:
             valid_names = [n.value for n in Float8LinearRecipeName]
-            assert (
-                recipe_name in valid_names
-            ), f"recipe_name {recipe_name} not in valid names {valid_names}"
+            assert recipe_name in valid_names, (
+                f"recipe_name {recipe_name} not in valid names {valid_names}"
+            )
             recipe_name = Float8LinearRecipeName(recipe_name)
 
         if recipe_name is Float8LinearRecipeName.TENSORWISE:
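
The same reformatting applies to the recipe-name validation in `from_recipe_name`: a string argument must match one of the `Float8LinearRecipeName` values before it is converted to the enum. A hypothetical sketch, assuming `from_recipe_name` is a classmethod on `Float8LinearConfig` and that the `TENSORWISE` member's string value is "tensorwise" (not part of this commit):

# Hypothetical sketch (not part of this commit): valid recipe strings are
# converted to Float8LinearRecipeName, invalid ones hit the assert.
from torchao.float8.config import Float8LinearConfig

config = Float8LinearConfig.from_recipe_name("tensorwise")

try:
    Float8LinearConfig.from_recipe_name("not_a_recipe")
except AssertionError as err:
    print(err)  # expected: "recipe_name not_a_recipe not in valid names [...]"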
