@@ -92,9 +92,9 @@ def short_str(self):
 
     def __post_init__(self):
         if self.scaling_granularity is ScalingGranularity.AXISWISE:
-            assert (
-                self.scaling_type is ScalingType.DYNAMIC
-            ), "only dynamic scaling type is supported for axiswise scaling granularity"
+            assert self.scaling_type is ScalingType.DYNAMIC, (
+                "only dynamic scaling type is supported for axiswise scaling granularity"
+            )
         assert self.target_dtype is None or (
             self.target_dtype.is_floating_point and self.target_dtype.itemsize == 1
         ), "must specify an 8-bit floating-point dtype"
@@ -240,7 +240,9 @@ def __post_init__(self):
 
         # float8 all-gather only supports tensorwise, in the future may support blockwise
         if self.cast_config_weight.scaling_granularity != ScalingGranularity.TENSORWISE:
-            assert not self.enable_fsdp_float8_all_gather, f"enable_fsdp_float8_all_gather only supports tensorwise scaling granularity, got {self.cast_config_weight.scaling_granularity}"
+            assert not self.enable_fsdp_float8_all_gather, (
+                f"enable_fsdp_float8_all_gather only supports tensorwise scaling granularity, got {self.cast_config_weight.scaling_granularity}"
+            )
 
         # save some characters in the compatibility checks below
         cc_i = self.cast_config_input
@@ -259,9 +261,9 @@ def __post_init__(self):
         ):
             is_disabled_1 = cc1.scaling_type is ScalingType.DISABLED
             is_disabled_2 = cc2.scaling_type is ScalingType.DISABLED
-            assert (
-                is_disabled_1 == is_disabled_2
-            ), f"incompatible operand precision for {gemm_name}"
+            assert is_disabled_1 == is_disabled_2, (
+                f"incompatible operand precision for {gemm_name}"
+            )
 
         for cc1, cc2, operand_name, default_dtype in [
             (cc_i, cc_i_gw, "input", e4m3_dtype),
@@ -273,9 +275,9 @@ def __post_init__(self):
                 object.__setattr__(cc1, "target_dtype", default_dtype)
             if cc2.target_dtype is None:
                 object.__setattr__(cc2, "target_dtype", default_dtype)
-            assert (
-                cc1.target_dtype == cc2.target_dtype
-            ), f"{operand_name} must be cast to the same dtype in both matmuls it's used in"
+            assert cc1.target_dtype == cc2.target_dtype, (
+                f"{operand_name} must be cast to the same dtype in both matmuls it's used in"
+            )
 
         # See the comments around `force_recompute_fp8_weight_in_bwd` for more details of this warning.
         if (
@@ -296,9 +298,9 @@ def from_recipe_name(
         """
         if type(recipe_name) == str:
             valid_names = [n.value for n in Float8LinearRecipeName]
-            assert (
-                recipe_name in valid_names
-            ), f"recipe_name {recipe_name} not in valid names {valid_names}"
+            assert recipe_name in valid_names, (
+                f"recipe_name {recipe_name} not in valid names {valid_names}"
+            )
             recipe_name = Float8LinearRecipeName(recipe_name)
 
         if recipe_name is Float8LinearRecipeName.TENSORWISE:
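
Every hunk above applies the same mechanical restructuring: the condition moves up onto the `assert` line and only the (often long) message is parenthesized, so wrapping the message no longer forces the condition itself across multiple lines. A minimal standalone sketch of the before/after pattern, using a hypothetical `BucketConfig` class that is not part of this file:

from dataclasses import dataclass


@dataclass(frozen=True)
class BucketConfig:
    # hypothetical field, for illustration only
    num_buckets: int

    def __post_init__(self):
        # before: parenthesized condition, message dangling after the paren
        # assert (
        #     self.num_buckets > 0
        # ), f"num_buckets must be positive, got {self.num_buckets}"
        #
        # after, as in this diff: condition on the assert line, message in parens
        assert self.num_buckets > 0, (
            f"num_buckets must be positive, got {self.num_buckets}"
        )


BucketConfig(num_buckets=4)  # validates silently

Keeping the comma at the top level also avoids the classic `assert (condition, message)` pitfall, where the parenthesized tuple is always truthy and the assert can never fire.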
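
The last hunk touches `from_recipe_name`, which validates a string against the `Float8LinearRecipeName` values before coercing it to the enum. A hedged usage sketch; the `torchao.float8.config` import path and the `"tensorwise"` enum value are assumptions, not shown in this diff:

# assumed import path; only the class and enum names appear in the diff
from torchao.float8.config import Float8LinearConfig, Float8LinearRecipeName

# a string is checked against the enum values, then converted to the enum
config = Float8LinearConfig.from_recipe_name("tensorwise")

# an enum member can be passed directly and skips the string branch entirely
config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.TENSORWISE)

# an unknown string trips the reformatted assert
try:
    Float8LinearConfig.from_recipe_name("not_a_recipe")
except AssertionError as err:
    print(err)  # recipe_name not_a_recipe not in valid names [...]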