1414
1515import torch
1616
17- from float8_experimental .config import Float8LinearConfig , TensorScalingType
17+ from float8_experimental .config import Float8LinearConfig , ScalingType
1818
1919from float8_experimental .float8_dynamic_utils import (
2020 cast_to_float8_e4m3_dynamic ,
@@ -215,9 +215,9 @@ def __init__(self, *args, **kwargs):
215215 self .scaling_type_grad_output = config .cast_config_grad_output .scaling_type
216216 # Convenience flag to skip code related to delayed scaling
217217 self .has_any_delayed_scaling = (
218- self .scaling_type_input is TensorScalingType .DELAYED
219- or self .scaling_type_weight is TensorScalingType .DELAYED
220- or self .scaling_type_grad_output is TensorScalingType .DELAYED
218+ self .scaling_type_input is ScalingType .DELAYED
219+ or self .scaling_type_weight is ScalingType .DELAYED
220+ or self .scaling_type_grad_output is ScalingType .DELAYED
221221 )
222222
223223 self .config = config
@@ -340,7 +340,7 @@ def cast_input_to_float8(
340340 autocast_dtype = torch .get_autocast_gpu_dtype ()
341341 input = input .to (autocast_dtype )
342342
343- if self .scaling_type_input is TensorScalingType .DELAYED :
343+ if self .scaling_type_input is ScalingType .DELAYED :
344344 scale_fn_name = self .config .delayed_scaling_config .scale_fn_name
345345 _maybe_initialize_amaxes_scales_for_float8_cast (
346346 input ,
@@ -361,14 +361,14 @@ def cast_input_to_float8(
361361 gemm_input_role = GemmInputRole .INPUT ,
362362 )
363363 else :
364- assert self .scaling_type_input is TensorScalingType .DYNAMIC
364+ assert self .scaling_type_input is ScalingType .DYNAMIC
365365 input_fp8 = cast_to_float8_e4m3_dynamic (input , self .linear_mm_config )
366366 return input_fp8
367367
368368 def cast_weight_to_float8 (
369369 self , weight : torch .Tensor , is_amax_initialized : bool
370370 ) -> torch .Tensor :
371- if self .scaling_type_weight is TensorScalingType .DELAYED :
371+ if self .scaling_type_weight is ScalingType .DELAYED :
372372 if isinstance (self .weight , Float8Tensor ): # cast by FSDP
373373 weight_fp8 = self .weight
374374 else :
@@ -393,7 +393,7 @@ def cast_weight_to_float8(
393393 gemm_input_role = GemmInputRole .WEIGHT ,
394394 )
395395 else :
396- assert self .scaling_type_weight is TensorScalingType .DYNAMIC
396+ assert self .scaling_type_weight is ScalingType .DYNAMIC
397397 if isinstance (self .weight , Float8Tensor ): # cast by FSDP
398398 weight_fp8 = self .weight
399399 else :
@@ -405,7 +405,7 @@ def cast_weight_to_float8(
405405 return weight_fp8
406406
407407 def cast_output_to_float8_in_bw (self , output : torch .Tensor ) -> torch .Tensor :
408- if self .scaling_type_grad_output is TensorScalingType .DELAYED :
408+ if self .scaling_type_grad_output is ScalingType .DELAYED :
409409 scale_fn_name = self .config .delayed_scaling_config .scale_fn_name
410410 output = NoopFwToFloat8E5M2Bw .apply (
411411 output ,
@@ -417,7 +417,7 @@ def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
417417 self .linear_mm_config ,
418418 )
419419 else :
420- assert self .scaling_type_grad_output is TensorScalingType .DYNAMIC
420+ assert self .scaling_type_grad_output is ScalingType .DYNAMIC
421421 output = cast_to_float8_e5m2_dynamic_bw (output , self .linear_mm_config )
422422 return output
423423
@@ -504,17 +504,15 @@ def from_float(
504504 # 2. buffers need to be already created for the delayed scaling version
505505 # of the weight wrapper to be initialized
506506 if config .enable_fsdp_float8_all_gather :
507- if config .cast_config_weight .scaling_type is TensorScalingType .DYNAMIC :
507+ if config .cast_config_weight .scaling_type is ScalingType .DYNAMIC :
508508 new_mod .weight = torch .nn .Parameter (
509509 WeightWithDynamicFloat8CastTensor (
510510 new_mod .weight ,
511511 new_mod .linear_mm_config ,
512512 )
513513 )
514514 else :
515- assert (
516- config .cast_config_weight .scaling_type is TensorScalingType .DELAYED
517- )
515+ assert config .cast_config_weight .scaling_type is ScalingType .DELAYED
518516 new_mod .weight = torch .nn .Parameter (
519517 WeightWithDelayedFloat8CastTensor (
520518 new_mod .weight ,
0 commit comments