    tensor_already_casted_to_fp8,
    to_fp8_no_autograd,
)
-from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale
+from float8_experimental.float8_utils import (
+    e4m3_dtype,
+    get_supported_granularity,
+    tensor_to_scale,
+)
+
+SUPPORTED_GRANULARITY = get_supported_granularity()


class ActivationCasting(Enum):
@@ -75,7 +81,7 @@ def __init__(
        # FP8 specific arguments
        quant_config: QuantConfig,
        forward_config: ScaledMMConfig,
-        scaling_granularity: ScalingGranularity,
+        scaling_granularity: Optional[ScalingGranularity],
        # nn.Linear arguments
        in_features: int,
        out_features: int,
@@ -86,7 +92,26 @@ def __init__(
        # Construct the superclass; this will create dummy weights and biases
        super().__init__(in_features, out_features, bias, device, dtype)
        self.forward_config = forward_config
-        self.scaling_granularity = scaling_granularity
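+        # Default selection: AxisWise scaling for bf16 with dynamic (non-static) quantization, TensorWise otherwise.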
+        if scaling_granularity is None:
+            self.scaling_granularity = (
+                ScalingGranularity.AxisWise
+                if dtype == torch.bfloat16
+                and quant_config.static_quantization_scale is None
+                else ScalingGranularity.TensorWise
+            )
+        else:
+            assert (
+                scaling_granularity in SUPPORTED_GRANULARITY
+            ), f"scaling_granularity must be in {SUPPORTED_GRANULARITY} but got {scaling_granularity}"
+            if (
+                scaling_granularity == ScalingGranularity.AxisWise
+                and dtype != torch.bfloat16
+            ):
+                raise ValueError(
+                    "AxisWise scaling granularity is only supported for bfloat16."
+                )
+            self.scaling_granularity = scaling_granularity
+
        self.activation_casting = quant_config.activation_casting
        if self.activation_casting == ActivationCasting.STATIC:
            self.register_buffer(
@@ -101,13 +126,22 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
                input, self.weight.to_original_precision()
            )

+        # TODO: we aren't folding leading dims yet, but we need to in order to calculate the proper scale
+        original_m = input.shape[:-1]
+        input = input.view(-1, input.shape[-1])
+
        x_fp8 = cast_to_float8_e4m3_inference(
            input,
            self.forward_config,
            static_quantization_scale=self.static_quantization_scale,
            scaling_granularity=self.scaling_granularity,
        )
-        return torch.nn.functional.linear(x_fp8, self.weight, self.bias)
+        return torch.nn.functional.linear(x_fp8, self.weight, self.bias).view(
+            *original_m, -1
+        )
+
+    def extra_repr(self):
+        return f"{super().extra_repr()},activation_casting={self.activation_casting.name},scaling_granularity={self.scaling_granularity.name}"

    # Builder functions for Float8InferenceLinear
    def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None:
@@ -124,7 +158,12 @@ def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None:
        assert not isinstance(
            self.weight, Float8Tensor
        ), "Weight has already been quantized, cannot quantize again."
-        scale = tensor_to_scale(self.weight, dtype, self.scaling_granularity)
+
+        # For weight tensors + AxisWise we calculate scales along columns
+        dim = None
+        if self.scaling_granularity == ScalingGranularity.AxisWise:
+            dim = 1
+        scale = tensor_to_scale(self.weight, dtype, self.scaling_granularity, dim=dim)
        quantized_weight = to_fp8_no_autograd(
            self.weight, scale, dtype, self.forward_config
        )
@@ -143,19 +182,20 @@ def from_float(
        module: nn.Module,
        quant_config: QuantConfig,
        use_fast_accum: bool,
+        scaling_granularity: Optional[ScalingGranularity],
    ) -> "Float8InferenceLinear":
        """
        Create an nn.Linear with fp8 compute from another nn.Linear

        Args:
            module (torch.nn.Linear): nn.Linear to convert
            quant_config (QuantConfig): Configuration for the weight and activation casting
+            use_fast_accum (bool): Whether to enable fast accumulation for the Float8InferenceLinear.
+            scaling_granularity: The granularity of the scale. See ScalingGranularity for more details.
        """
        forward_config = ScaledMMConfig(
            False, use_fast_accum, pad_inner_dim=config.pad_inner_dim
        )
-        # TODO: For now hardcode TensorWise scaling
-        scaling_granularity = ScalingGranularity.TensorWise
        linear = cls(
            quant_config,
            forward_config,
@@ -164,6 +204,7 @@ def from_float(
            module.out_features,
            False,
            device=torch.device("meta"),
+            dtype=module.weight.dtype,
        )
        linear.set_weight_and_bias(module.weight, module.bias)
        linear.quantize_weight()
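
For orientation, a minimal sketch of calling from_float directly (this is what swap_linear_layers does per layer, see below); the QuantConfig(ActivationCasting.DYNAMIC) constructor form and the availability of a CUDA device with fp8 support are assumptions here, not part of this diff:

lin = torch.nn.Linear(512, 512, device="cuda", dtype=torch.bfloat16)
fp8_lin = Float8InferenceLinear.from_float(
    lin,
    QuantConfig(ActivationCasting.DYNAMIC),
    use_fast_accum=True,
    scaling_granularity=None,  # None now resolves to AxisWise for dynamic bf16, TensorWise otherwise
)
print(fp8_lin)  # extra_repr reports activation_casting and scaling_granularity
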
@@ -194,18 +235,29 @@ def cast_to_float8_e4m3_inference(
    """
    if tensor_already_casted_to_fp8(inpt_tensor):
        return inpt_tensor
+
+    # For input tensors + AxisWise we calculate scales along rows
+    dim = None
+    if scaling_granularity == ScalingGranularity.AxisWise:
+        dim = 1
+
    scale = (
        static_quantization_scale
        if static_quantization_scale is not None
        else tensor_to_scale(
-            inpt_tensor, e4m3_dtype, scaling_granularity, reduce_amax=reduce_amax
+            inpt_tensor,
+            e4m3_dtype,
+            scaling_granularity,
+            dim=dim,
+            reduce_amax=reduce_amax,
        )
    )
    return Float8Tensor.to_float8(
        inpt_tensor,
        scale,
        e4m3_dtype,
        mm_config=mm_config,
+        scaling_granularity=scaling_granularity,
    )


@@ -215,6 +267,7 @@ def quantize_to_float8(
    *,
    skip_fqn_list: Optional[List[str]] = None,
    use_fast_accum: bool = True,
+    scaling_granularity: Optional[ScalingGranularity] = None,
) -> Optional[nn.Module]:
    """
    Converts torch.nn.Linear layers in the given module to Float8InferenceLinear.
@@ -228,6 +281,7 @@ def quantize_to_float8(
        quant_config (QuantConfig): Quantization configuration for Float8 conversion.
        skip_fqn_list (List[str], optional): List of module FQNs to skip during conversion.
        use_fast_accum: Whether to enable fast accumulation for the Float8InferenceLinear. Defaults to True.
+        scaling_granularity: The granularity of the scale. See ScalingGranularity for more details.

    Returns:
        nn.Module: The modified module with applicable Linear layers converted to Float8.
@@ -237,6 +291,8 @@ def quantize_to_float8(
    """
    return swap_linear_layers(
        module,
-        lambda m: Float8InferenceLinear.from_float(m, quant_config, use_fast_accum),
+        lambda m: Float8InferenceLinear.from_float(
+            m, quant_config, use_fast_accum, scaling_granularity
+        ),
        skip_fqn_list=skip_fqn_list,
    )
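
A hedged end-to-end usage sketch of the new scaling_granularity argument; the import paths and the QuantConfig constructor shape are assumptions (adjust to the actual package layout), and a CUDA device with fp8 support is required:

import torch
import torch.nn as nn

from float8_experimental.inference import (  # assumed module path
    ActivationCasting,
    QuantConfig,
    quantize_to_float8,
)
from float8_experimental.float8_utils import ScalingGranularity  # assumed path

model = nn.Sequential(nn.Linear(1024, 4096), nn.Linear(4096, 1024))
model = model.to(device="cuda", dtype=torch.bfloat16)

# Request AxisWise (row/column-wise) scales explicitly; passing None (the default)
# picks AxisWise for dynamic bf16 quantization and TensorWise otherwise.
model = quantize_to_float8(
    model,
    QuantConfig(ActivationCasting.DYNAMIC),
    scaling_granularity=ScalingGranularity.AxisWise,
)

with torch.inference_mode():
    out = model(torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16))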