@@ -122,7 +122,7 @@ def forward(
             round_scales_to_power_of_2=True,
         )
         A_scaled = A.to(torch.float32) * A_scales
-        A_fp8_row_major = to_fp8_saturated(A_scaled, torch.float8_e4m3fn)
+        A_data_row_major = to_fp8_saturated(A_scaled, torch.float8_e4m3fn)

         # Convert B to float8, column-major for right operand of grouped GEMM.
         # B_t shape: (E, K, N)
@@ -136,18 +136,18 @@ def forward(
             round_scales_to_power_of_2=True,
         )
         B_t_scaled = B_t.to(torch.float32) * B_t_scales
-        B_t_fp8_col_major = to_fp8_saturated(B_t_scaled, torch.float8_e4m3fn)
+        B_t_data_col_major = to_fp8_saturated(B_t_scaled, torch.float8_e4m3fn)

         # Store what we need for backward.
         ctx.save_for_backward(A, B_t, offs)
         ctx.out_dtype = out_dtype

         # Perform scaled grouped GEMM and return result.
         # output shape: scaled grouped mm of (M,K) @ (B,K,N) = (M,N)
-        assert not _is_column_major(A_fp8_row_major), (
+        assert not _is_column_major(A_data_row_major), (
             "A must be row-major for output = A @ B"
         )
-        assert _is_column_major(B_t_fp8_col_major), (
+        assert _is_column_major(B_t_data_col_major), (
             "B must be column-major for output = A @ B"
         )

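The two assertions above gate operand layouts with `_is_column_major`, which is a stride check. A minimal sketch of what such a check can look like (illustrative only, not necessarily the exact helper used in this file):

```python
import torch

def is_column_major(x: torch.Tensor) -> bool:
    # Column-major means elements within a column sit next to each other in
    # memory, i.e. the stride across rows is 1.
    assert x.ndim in (2, 3), "expected a 2D or 3D (batched) matrix"
    return x.stride(-2) == 1 and x.stride(-1) > 1

a = torch.randn(4, 8)           # freshly allocated tensors are row-major
b = a.t().contiguous().t()      # same values, column-major storage
print(is_column_major(a), is_column_major(b))  # False True
```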
@@ -157,8 +157,8 @@ def forward(
         A_scales = A_scales.squeeze(-1)
         B_t_scales = B_t_scales.squeeze(1)
         return torch._scaled_grouped_mm(
-            A_fp8_row_major,
-            B_t_fp8_col_major,
+            A_data_row_major,
+            B_t_data_col_major,
             A_scales.reciprocal(),  # Reciprocals are needed for rescaling the output.
             B_t_scales.reciprocal(),
             offs,
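In this forward path each row of A (and each row of every expert weight) is scaled so its absolute max lands at the float8 maximum, and the reciprocal scales are passed to `torch._scaled_grouped_mm` so the kernel can rescale the accumulated output back to the original magnitude. A single-tensor sketch of that arithmetic (illustrative helper name; the power-of-two scale rounding is omitted):

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def rowwise_fp8_quantize(x: torch.Tensor):
    # One scale per row, mapping the row's absolute max onto the fp8 range.
    amax = x.abs().amax(dim=-1, keepdim=True).to(torch.float32).clamp(min=1e-12)
    scales = FP8_MAX / amax
    x_fp8 = (x.to(torch.float32) * scales).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return x_fp8, scales

x = torch.randn(4, 16, dtype=torch.bfloat16)
x_fp8, scales = rowwise_fp8_quantize(x)
# The GEMM applies scales.reciprocal() to its output, which is equivalent to:
x_approx = x_fp8.to(torch.float32) * scales.reciprocal()
print((x.to(torch.float32) - x_approx).abs().max())  # small quantization error
```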
@@ -184,13 +184,13 @@ def backward(ctx, grad_output: torch.Tensor):
             round_scales_to_power_of_2=True,
         )
         grad_output_scaled = grad_output.to(torch.float32) * grad_output_scales
-        grad_output_fp8_row_major = to_fp8_saturated(
+        grad_output_data_row_major = to_fp8_saturated(
             grad_output_scaled, torch.float8_e4m3fn
         )

         # Compute B fp8 column-major for right operand of grouped GEMM:
         # grad_A = grad_output @ B.
-        B_fp8_col_major, B_scales = triton_fp8_rowwise_3d_transpose_rhs(
+        B_data_col_major, B_scales = triton_fp8_rowwise_3d_transpose_rhs(
             B_t._data if hasattr(B_t, "_data") else B_t,
             output_dtype=torch.float8_e4m3fn,
             round_scales_to_power_of_2=True,
@@ -199,10 +199,10 @@ def backward(ctx, grad_output: torch.Tensor):
         # Compute grad_A.
         # grad_A = grad_output @ B
         # grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K)
-        assert not _is_column_major(grad_output_fp8_row_major), (
+        assert not _is_column_major(grad_output_data_row_major), (
             "grad_output must be row-major for grad_A = grad_output @ B"
         )
-        assert _is_column_major(B_fp8_col_major), (
+        assert _is_column_major(B_data_col_major), (
             "B must be column-major for grad_A = grad_output @ B"
         )

@@ -212,8 +212,8 @@ def backward(ctx, grad_output: torch.Tensor):
         grad_output_scales = grad_output_scales.squeeze(-1)
         B_scales = B_scales.squeeze(1)
         grad_A = torch._scaled_grouped_mm(
-            grad_output_fp8_row_major,
-            B_fp8_col_major,
+            grad_output_data_row_major,
+            B_data_col_major,
             grad_output_scales.reciprocal(),
             B_scales.reciprocal(),
             offs,
@@ -227,18 +227,18 @@ def backward(ctx, grad_output: torch.Tensor):
         # Convert transpose of grad_output to float8, row-major for left operand of grouped GEMM
         # needed for grad_B: grad_output_t @ A
         # Use transpose method to avoid uncoalesced memory accesses.
-        grad_out_fp8_colwise, grad_out_scales = triton_fp8_per_group_colwise_scales(
+        grad_out_data_colwise, grad_out_scales = triton_fp8_per_group_colwise_scales(
             grad_output.t()
             .contiguous()
             .t(),  # Quantization is over 2x faster when input is col major, even with this transformation
             offs,
             torch.float8_e4m3fn,
             round_scales_to_power_of_2=True,
         )
-        grad_output_t_fp8_row_major = grad_out_fp8_colwise.t()
+        grad_output_t_data_row_major = grad_out_data_colwise.t()
         grad_output_t_scales = grad_out_scales.t()

-        A_fp8_col_major, A_scales = triton_fp8_per_group_colwise_scales(
+        A_data_col_major, A_scales = triton_fp8_per_group_colwise_scales(
             A.t()
             .contiguous()
             .t(),  # Quantization is over 2x faster when input is col major, even with this transformation
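The `.t().contiguous().t()` chain does not change any values; it only rewrites the storage so the 2D tensor is column-major, which lets the column-wise quantization kernel read coalesced memory (hence the "over 2x faster" note in the comment). A quick plain-PyTorch illustration of that layout change:

```python
import torch

x = torch.randn(6, 4)
x_col = x.t().contiguous().t()      # same logical tensor, column-major storage

print(torch.equal(x, x_col))        # True: identical values
print(x.stride(), x_col.stride())   # (4, 1) vs. (1, 6)
```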
@@ -249,19 +249,19 @@ def backward(ctx, grad_output: torch.Tensor):

         # Compute grad_B = grad_output_t @ A.
         # grad_B = grad_output_t @ A
-        assert not _is_column_major(grad_output_t_fp8_row_major), (
+        assert not _is_column_major(grad_output_t_data_row_major), (
             "grad_output_t must be row-major for grad_B = grad_output_t @ A"
         )
-        assert _is_column_major(A_fp8_col_major), (
+        assert _is_column_major(A_data_col_major), (
             "A must be column-major for grad_B = grad_output_t @ A"
         )

         # Per-token group scales computed via triton kernels above do not have
         # the empty dim like the scales computed via tensor_to_scale, so we
         # don't need to squeeze here.
         grad_B = torch._scaled_grouped_mm(
-            grad_output_t_fp8_row_major,
-            A_fp8_col_major,
+            grad_output_t_data_row_major,
+            A_data_col_major,
             grad_output_t_scales.reciprocal(),
             A_scales.reciprocal(),
             offs,
@@ -295,13 +295,13 @@ def forward(
         ctx.out_dtype = out_dtype
         ctx.emulated = emulated

-        # A_mx shape: (M, K)
+        # A_fp8 shape: (M, K)
         # A_scale shape: (M, K//block_size)
-        A_scale, A_mx = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
+        A_scale, A_fp8 = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)

-        # B_mx shape: (E, N, K)
+        # B_fp8 shape: (E, N, K)
         # B_scale shape: (E, N, K//block_size)
-        B_scales, B_mx = to_mx(
+        B_scales, B_fp8 = to_mx(
             B_t.transpose(-2, -1),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
@@ -315,9 +315,9 @@ def forward(
             else fbgemm_mxfp8_grouped_mm_2d_3d
         )
         out = mxfp8_2d_3d_grouped_mm(
-            A_mx,
+            A_fp8,
             A_scale,
-            B_mx,
+            B_fp8,
             B_scales,
             offs=offs,
             block_size=block_size,
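`to_mx` quantizes along the last dimension in blocks of `block_size` elements and returns one scale per block alongside the low-precision data. A simplified, self-contained emulation of that idea (this is not torchao's `to_mx`; real MX scales are E8M0 exponents, which the power-of-two multipliers below only loosely mimic):

```python
import torch

def blockwise_fp8_quantize(x: torch.Tensor, block_size: int = 32):
    # x: (..., K) with K divisible by block_size.
    orig_shape = x.shape
    blocks = x.reshape(*orig_shape[:-1], orig_shape[-1] // block_size, block_size)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    amax = blocks.abs().amax(dim=-1, keepdim=True).to(torch.float32).clamp(min=1e-12)
    # Power-of-two quantization multiplier per block.
    q_scales = torch.exp2(torch.floor(torch.log2(fp8_max / amax)))
    x_fp8 = (blocks.to(torch.float32) * q_scales).to(torch.float8_e4m3fn)
    # Return dequantization scales, matching the "data * scale" convention used
    # by the emulated matmuls later in this diff.
    return q_scales.reciprocal().squeeze(-1), x_fp8.reshape(orig_shape)

x = torch.randn(4, 64, dtype=torch.bfloat16)
scales, x_fp8 = blockwise_fp8_quantize(x)
print(scales.shape, x_fp8.shape)  # torch.Size([4, 2]) torch.Size([4, 64])
```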
@@ -332,15 +332,15 @@ def backward(ctx, grad_out: torch.Tensor):
         out_dtype = ctx.out_dtype
         emulated = ctx.emulated

-        # grad_out_mx shape: (M, N)
+        # grad_out_fp8 shape: (M, N)
         # grad_out_scale shape: (M, N//block_size)
-        grad_out_scale, grad_out_mx = to_mx(
+        grad_out_scale, grad_out_fp8 = to_mx(
             grad_out, elem_dtype=torch.float8_e4m3fn, block_size=block_size
         )

-        # B_mx shape: (E, K, N)
+        # B_fp8 shape: (E, K, N)
         # B_scale shape: (E, K, N//block_size)
-        B_scales, B_mx = to_mx(
+        B_scales, B_fp8 = to_mx(
             # TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency?
             B_t.contiguous(),
             elem_dtype=torch.float8_e4m3fn,
@@ -354,43 +354,43 @@ def backward(ctx, grad_out: torch.Tensor):
             else fbgemm_mxfp8_grouped_mm_2d_3d
         )
         grad_A = mxfp8_2d_3d_grouped_mm(
-            grad_out_mx,
+            grad_out_fp8,
             grad_out_scale,
-            B_mx,
+            B_fp8,
             B_scales,
             offs=offs,
             out_dtype=out_dtype,
         )

-        # grad_out_t_mx shape: (N, M)
+        # grad_out_t_fp8 shape: (N, M)
         # grad_out_t_scales shape: (N, M//block_size)
-        grad_out_t_scales, grad_out_t_mx = to_mx(
+        grad_out_t_scales, grad_out_t_fp8 = to_mx(
             # TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency?
             grad_out.transpose(-2, -1).contiguous(),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
         )

         # Transpose A so we can scale along the M dimension, then un-transpose.
-        # A_t_mx shape: (K, M)
+        # A_t_fp8 shape: (K, M)
         # A_t_scales shape: (K, M//block_size)
-        A_t_scales, A_t_mx = to_mx(
+        A_t_scales, A_t_fp8 = to_mx(
             A.transpose(-2, -1).contiguous(),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
         )

-        # A_mx shape = (M, K)
-        A_mx = A_t_mx.transpose(-2, -1)
+        # A_fp8 shape = (M, K)
+        A_fp8 = A_t_fp8.transpose(-2, -1)

         # A_scales shape = (M//block_size, K)
         A_scales = A_t_scales.transpose(-2, -1)

         # grad_B_t = scaled grouped mm of (N,M) @ (M,K) = (E,N,K)
         grad_B = _emulated_mxfp8_scaled_grouped_mm_2d_2d(
-            grad_out_t_mx,
+            grad_out_t_fp8,
             grad_out_t_scales,
-            A_mx,
+            A_fp8,
             A_scales,
             offs=offs,
         )
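Because `to_mx` scales along the last dimension, getting scales along M for the grad_B operand requires quantizing the transpose and then transposing both data and scales back, which is why `A_scales` ends up shaped (M//block_size, K). The shape bookkeeping in isolation, with a plain per-block amax standing in for the real scale computation:

```python
import torch

M, K, block_size = 64, 16, 32
A = torch.randn(M, K)

# Make M the last dim so blocks run along M.
A_t = A.transpose(-2, -1).contiguous()                    # (K, M)
A_t_blocks = A_t.reshape(K, M // block_size, block_size)  # (K, M//bs, bs)
A_t_scales = A_t_blocks.abs().amax(dim=-1)                # (K, M//bs): one value per block of M

# Un-transpose so the scales line up with the (M, K) operand.
A_scales = A_t_scales.transpose(-2, -1)                   # (M//bs, K)
print(A_scales.shape)                                     # torch.Size([2, 16])
```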
@@ -402,64 +402,64 @@ def backward(ctx, grad_out: torch.Tensor):


 def _emulated_mxfp8_scaled_grouped_mm_2d_3d(
-    A_mx: torch.Tensor,
+    A_fp8: torch.Tensor,
     A_scale: torch.Tensor,
-    B_mx: torch.Tensor,
+    B_fp8: torch.Tensor,
     B_scale: torch.Tensor,
     offs: Optional[torch.Tensor] = None,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
     block_size: int = 32,
 ) -> torch.Tensor:
-    assert A_mx.ndim == 2, f"A must be 2D, got {A_mx.ndim}"
-    assert B_mx.ndim == 3, f"B must be 3D, got {B_mx.ndim}"
-    assert A_scale.shape[0] == A_mx.shape[0], (
-        f"A_scale must have same M dim as A_mx, got A={A_mx.shape} and A_scale={A_scale.shape}"
+    assert A_fp8.ndim == 2, f"A must be 2D, got {A_fp8.ndim}"
+    assert B_fp8.ndim == 3, f"B must be 3D, got {B_fp8.ndim}"
+    assert A_scale.shape[0] == A_fp8.shape[0], (
+        f"A_scale must have same M dim as A_fp8, got A={A_fp8.shape} and A_scale={A_scale.shape}"
     )
-    assert A_scale.shape[1] == A_mx.shape[1] // block_size, (
-        f"A_scale dim1 should be size K//block_size, got A={A_mx.shape} and A_scale={A_scale.shape}"
+    assert A_scale.shape[1] == A_fp8.shape[1] // block_size, (
+        f"A_scale dim1 should be size K//block_size, got A={A_fp8.shape} and A_scale={A_scale.shape}"
     )
-    assert B_scale.shape[0] == B_mx.shape[0], (
-        f"B_scale must have same E dim as B_mx, got B={B_mx.shape} and B_scale={B_scale.shape}"
+    assert B_scale.shape[0] == B_fp8.shape[0], (
+        f"B_scale must have same E dim as B_fp8, got B={B_fp8.shape} and B_scale={B_scale.shape}"
     )
-    assert B_scale.shape[1] == B_mx.shape[1], (
-        f"B_scale must have same N dim as B_mx, got B={B_mx.shape} and B_scale={B_scale.shape}"
+    assert B_scale.shape[1] == B_fp8.shape[1], (
+        f"B_scale must have same N dim as B_fp8, got B={B_fp8.shape} and B_scale={B_scale.shape}"
     )
-    assert B_scale.shape[2] == B_mx.shape[2] // block_size, (
-        f"B_scale dim2 should be size K//block_size, got B={B_mx.shape} and B_scale={B_scale.shape}"
+    assert B_scale.shape[2] == B_fp8.shape[2] // block_size, (
+        f"B_scale dim2 should be size K//block_size, got B={B_fp8.shape} and B_scale={B_scale.shape}"
     )

     # Dequantize input
-    # A_mx shape: (M, K)
+    # A_fp8 shape: (M, K)
     # A_scale shape: (M, K//block_size)
-    A_orig_shape = A_mx.shape
+    A_orig_shape = A_fp8.shape

     # Reshape to be able to do per-scaling group multiplication
-    # A_mx shape: (M, K//block_size, block_size)
+    # A_fp8 shape: (M, K//block_size, block_size)
     # A_scale shape: (M, K//block_size, 1)
-    A_mx = A_mx.reshape(*A_mx.shape[:-1], A_mx.shape[-1] // block_size, block_size)
+    A_fp8 = A_fp8.reshape(*A_fp8.shape[:-1], A_fp8.shape[-1] // block_size, block_size)
     A_scale = A_scale.unsqueeze(-1)

     # Rescale and cast to bfloat16
-    A = A_mx.to(torch.bfloat16) * A_scale.to(torch.bfloat16)
+    A = A_fp8.to(torch.bfloat16) * A_scale.to(torch.bfloat16)

     # Reshape back to original shape
     # A shape: (M, K)
     A = A.reshape(A_orig_shape)

     # Dequantize weights
     # Transpose to get block_size on rightmost dim
-    # B_mx shape: (E, N, K)
+    # B_fp8 shape: (E, N, K)
     # B_scale shape: (E, N, K//block_size)
-    E, N, K = B_mx.shape
+    E, N, K = B_fp8.shape

     # Reshape to be able to do per-scaling group multiplication
-    # B_mx shape: (E, N, K//block_size, block_size)
+    # B_fp8 shape: (E, N, K//block_size, block_size)
     # B_scale shape: (E, N, K//block_size, 1)
-    B_mx = B_mx.reshape(*B_mx.shape[:-1], B_mx.shape[-1] // block_size, block_size)
+    B_fp8 = B_fp8.reshape(*B_fp8.shape[:-1], B_fp8.shape[-1] // block_size, block_size)
     B_scale = B_scale.unsqueeze(-1)

     # Rescale and cast to bfloat16
-    B = B_mx.to(torch.bfloat16) * B_scale.to(torch.bfloat16)
+    B = B_fp8.to(torch.bfloat16) * B_scale.to(torch.bfloat16)

     # Reshape back to original shape
     # B shape: (E, K, N)
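The dequantization above follows one pattern: view the last dimension as (K//block_size, block_size), broadcast one scale per block, multiply, and reshape back. In isolation (illustrative names, random stand-in scales):

```python
import torch

M, K, block_size = 4, 64, 32
x_fp8 = torch.randn(M, K).to(torch.float8_e4m3fn)
scales = torch.rand(M, K // block_size) + 0.5              # one dequant scale per block

x_blocks = x_fp8.reshape(M, K // block_size, block_size)   # (M, K//bs, bs)
x_deq = x_blocks.to(torch.bfloat16) * scales.unsqueeze(-1).to(torch.bfloat16)
x_deq = x_deq.reshape(M, K)                                # back to (M, K)
print(x_deq.shape)                                         # torch.Size([4, 64])
```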
@@ -471,27 +471,27 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_3d(


 def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
-    A_mx: torch.Tensor,  # (M, K)
+    A_fp8: torch.Tensor,  # (M, K)
     A_scale: torch.Tensor,  # (M, K//block_size)
-    B_mx: torch.Tensor,  # (K, N)
+    B_fp8: torch.Tensor,  # (K, N)
     B_scale: torch.Tensor,  # (K//block_size, N)
     offs: torch.Tensor,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
     block_size: int = 32,
 ) -> torch.Tensor:
-    assert A_mx.ndim == 2, "A must be 2D"
-    assert B_mx.ndim == 2, "B must be 2D"
+    assert A_fp8.ndim == 2, "A must be 2D"
+    assert B_fp8.ndim == 2, "B must be 2D"
     A = torch.zeros(
-        A_mx.shape,
+        A_fp8.shape,
         dtype=torch.bfloat16,
-        device=A_mx.device,
-        requires_grad=A_mx.requires_grad,
+        device=A_fp8.device,
+        requires_grad=A_fp8.requires_grad,
     )
     B = torch.zeros(
-        B_mx.shape,
+        B_fp8.shape,
         dtype=torch.bfloat16,
-        device=B_mx.device,
-        requires_grad=B_mx.requires_grad,
+        device=B_fp8.device,
+        requires_grad=B_fp8.requires_grad,
     )

     # Dequantize input per each scaling group
@@ -507,7 +507,7 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
         # -- Dequantize A tensor
         # A_group shape: (M, group_size)
         # A_scale shape: (M, group_size//block_size)
-        A_group = A_mx[:, group_start_idx:group_end_idx]
+        A_group = A_fp8[:, group_start_idx:group_end_idx]
         A_group_shape = A_group.shape

         # Get scales for this group.
@@ -532,7 +532,7 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_2d(

         # -- Dequantize B tensor
         # B_group shape is (group_size, N)
-        B_group = B_mx[group_start_idx:group_end_idx, :]
+        B_group = B_fp8[group_start_idx:group_end_idx, :]
         B_group_shape = B_group.shape

         # Scales shape is (group_size//block_size, N)
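In this 2D-2D emulated path, `offs` holds the cumulative end index of each token group along the shared dimension, and the loop dequantizes and multiplies one `[group_start_idx:group_end_idx]` slice per expert, stacking the results into a 3D output. A toy version of just the loop structure, with the dequantization elided:

```python
import torch

M, K, N, E = 6, 10, 8, 3
A = torch.randn(M, K, dtype=torch.bfloat16)   # sliced along its last dim
B = torch.randn(K, N, dtype=torch.bfloat16)   # sliced along its first dim
offs = torch.tensor([3, 7, 10])               # cumulative group boundaries along K

out = torch.empty(E, M, N, dtype=torch.bfloat16)
group_start_idx = 0
for e, group_end_idx in enumerate(offs.tolist()):
    A_group = A[:, group_start_idx:group_end_idx]   # (M, group_size)
    B_group = B[group_start_idx:group_end_idx, :]   # (group_size, N)
    out[e] = A_group @ B_group                      # one (M, N) result per group
    group_start_idx = group_end_idx
print(out.shape)  # torch.Size([3, 6, 8])
```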