@@ -291,13 +291,13 @@ def forward(
         ctx.out_dtype = out_dtype
         ctx.emulated = emulated

-        # A_mx shape: (M, K)
+        # A_fp8 shape: (M, K)
         # A_scale shape: (M, K//block_size)
-        A_scale, A_mx = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
+        A_scale, A_fp8 = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)

-        # B_mx shape: (E, N, K)
+        # B_fp8 shape: (E, N, K)
         # B_scale shape: (E, N, K//block_size)
-        B_scales, B_mx = to_mx(
+        B_scales, B_fp8 = to_mx(
             B_t.transpose(-2, -1),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
@@ -311,9 +311,9 @@ def forward(
             else fbgemm_mxfp8_grouped_mm_2d_3d
         )
         out = mxfp8_2d_3d_grouped_mm(
-            A_mx,
+            A_fp8,
             A_scale,
-            B_mx,
+            B_fp8,
             B_scales,
             offs=offs,
             block_size=block_size,
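
For orientation: `to_mx` returns a pair of (per-block scales, fp8 data) for each operand, so the forward pass above quantizes A along K and B along its K dimension (hence the transpose of `B_t`) before dispatching to the emulated or fbgemm grouped GEMM. The sketch below is a simplified illustration of block-wise float8_e4m3fn quantization with that output layout; it is not torchao's `to_mx` (real MX scales are E8M0 powers of two), and the helper name is hypothetical.

```python
import torch

def naive_blockwise_fp8_quant(x: torch.Tensor, block_size: int = 32):
    # x: (M, K) with K divisible by block_size.
    M, K = x.shape
    assert K % block_size == 0
    x_blocks = x.reshape(M, K // block_size, block_size)

    # One scale per block so the block fits the e4m3 range (max ~448);
    # real MX scales would additionally be rounded to E8M0 powers of two.
    amax = x_blocks.abs().amax(dim=-1, keepdim=True).float()
    scale = (amax / 448.0).clamp(min=1e-12)

    x_fp8 = (x_blocks / scale).to(torch.float8_e4m3fn)

    # Shapes mirror the comments in the diff: data (M, K), scales (M, K//block_size).
    return scale.squeeze(-1), x_fp8.reshape(M, K)
```
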
@@ -328,15 +328,15 @@ def backward(ctx, grad_out: torch.Tensor):
         out_dtype = ctx.out_dtype
         emulated = ctx.emulated

-        # grad_out_mx shape: (M, N)
+        # grad_out_fp8 shape: (M, N)
         # grad_out_scale shape: (M, N//block_size)
-        grad_out_scale, grad_out_mx = to_mx(
+        grad_out_scale, grad_out_fp8 = to_mx(
             grad_out, elem_dtype=torch.float8_e4m3fn, block_size=block_size
         )

-        # B_mx shape: (E, K, N)
+        # B_fp8 shape: (E, K, N)
         # B_scale shape: (E, K, N//block_size)
-        B_scales, B_mx = to_mx(
+        B_scales, B_fp8 = to_mx(
             # TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency?
             B_t.contiguous(),
             elem_dtype=torch.float8_e4m3fn,
@@ -350,43 +350,43 @@ def backward(ctx, grad_out: torch.Tensor):
             else fbgemm_mxfp8_grouped_mm_2d_3d
         )
         grad_A = mxfp8_2d_3d_grouped_mm(
-            grad_out_mx,
+            grad_out_fp8,
             grad_out_scale,
-            B_mx,
+            B_fp8,
             B_scales,
             offs=offs,
             out_dtype=out_dtype,
         )

-        # grad_out_t_mx shape: (N, M)
+        # grad_out_t_fp8 shape: (N, M)
         # grad_out_t_scales shape: (N, M//block_size)
-        grad_out_t_scales, grad_out_t_mx = to_mx(
+        grad_out_t_scales, grad_out_t_fp8 = to_mx(
             # TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency?
             grad_out.transpose(-2, -1).contiguous(),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
         )

         # Transpose A so we can scale along the M dimension, then un-transpose.
-        # A_t_mx shape: (K, M)
+        # A_t_fp8 shape: (K, M)
         # A_t_scales shape: (K, M//block_size)
-        A_t_scales, A_t_mx = to_mx(
+        A_t_scales, A_t_fp8 = to_mx(
             A.transpose(-2, -1).contiguous(),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
         )

-        # A_mx shape = (M, K)
-        A_mx = A_t_mx.transpose(-2, -1)
+        # A_fp8 shape = (M, K)
+        A_fp8 = A_t_fp8.transpose(-2, -1)

         # A_scales shape = (M//block_size, K)
         A_scales = A_t_scales.transpose(-2, -1)

         # grad_B_t = scaled grouped mm of (N,M) @ (M,K) = (E,N,K)
         grad_B = _emulated_mxfp8_scaled_grouped_mm_2d_2d(
-            grad_out_t_mx,
+            grad_out_t_fp8,
             grad_out_t_scales,
-            A_mx,
+            A_fp8,
             A_scales,
             offs=offs,
         )
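
As a cross-check on the two grouped matmuls above, these are the plain bf16 gradient identities they emulate for a single expert group (a sketch only; `A_g` is (m, K), `B_t_g` is (K, N), `grad_out_g` is (m, N), and the names are illustrative). The grad_B result is laid out as (N, K), matching the `(N,M) @ (M,K)` comment in the diff.

```python
import torch

def reference_group_grads(A_g: torch.Tensor, B_t_g: torch.Tensor, grad_out_g: torch.Tensor):
    # Forward for one group: out_g = A_g @ B_t_g, shape (m, N).
    grad_A_g = grad_out_g @ B_t_g.transpose(-2, -1)  # (m, N) @ (N, K) -> (m, K)
    grad_B_g = grad_out_g.transpose(-2, -1) @ A_g    # (N, m) @ (m, K) -> (N, K)
    return grad_A_g, grad_B_g
```
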
@@ -398,64 +398,64 @@ def backward(ctx, grad_out: torch.Tensor):


 def _emulated_mxfp8_scaled_grouped_mm_2d_3d(
-    A_mx: torch.Tensor,
+    A_fp8: torch.Tensor,
     A_scale: torch.Tensor,
-    B_mx: torch.Tensor,
+    B_fp8: torch.Tensor,
     B_scale: torch.Tensor,
     offs: Optional[torch.Tensor] = None,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
     block_size: int = 32,
 ) -> torch.Tensor:
-    assert A_mx.ndim == 2, f"A must be 2D, got {A_mx.ndim}"
-    assert B_mx.ndim == 3, f"B must be 3D, got {B_mx.ndim}"
-    assert A_scale.shape[0] == A_mx.shape[0], (
-        f"A_scale must have same M dim as A_mx, got A={A_mx.shape} and A_scale={A_scale.shape}"
+    assert A_fp8.ndim == 2, f"A must be 2D, got {A_fp8.ndim}"
+    assert B_fp8.ndim == 3, f"B must be 3D, got {B_fp8.ndim}"
+    assert A_scale.shape[0] == A_fp8.shape[0], (
+        f"A_scale must have same M dim as A_fp8, got A={A_fp8.shape} and A_scale={A_scale.shape}"
     )
-    assert A_scale.shape[1] == A_mx.shape[1] // block_size, (
-        f"A_scale dim1 should be size K//block_size, got A={A_mx.shape} and A_scale={A_scale.shape}"
+    assert A_scale.shape[1] == A_fp8.shape[1] // block_size, (
+        f"A_scale dim1 should be size K//block_size, got A={A_fp8.shape} and A_scale={A_scale.shape}"
     )
-    assert B_scale.shape[0] == B_mx.shape[0], (
-        f"B_scale must have same E dim as B_mx, got B={B_mx.shape} and B_scale={B_scale.shape}"
+    assert B_scale.shape[0] == B_fp8.shape[0], (
+        f"B_scale must have same E dim as B_fp8, got B={B_fp8.shape} and B_scale={B_scale.shape}"
     )
-    assert B_scale.shape[1] == B_mx.shape[1], (
-        f"B_scale must have same N dim as B_mx, got B={B_mx.shape} and B_scale={B_scale.shape}"
+    assert B_scale.shape[1] == B_fp8.shape[1], (
+        f"B_scale must have same N dim as B_fp8, got B={B_fp8.shape} and B_scale={B_scale.shape}"
     )
-    assert B_scale.shape[2] == B_mx.shape[2] // block_size, (
-        f"B_scale dim2 should be size K//block_size, got B={B_mx.shape} and B_scale={B_scale.shape}"
+    assert B_scale.shape[2] == B_fp8.shape[2] // block_size, (
+        f"B_scale dim2 should be size K//block_size, got B={B_fp8.shape} and B_scale={B_scale.shape}"
     )

     # Dequantize input
-    # A_mx shape: (M, K)
+    # A_fp8 shape: (M, K)
     # A_scale shape: (M, K//block_size)
-    A_orig_shape = A_mx.shape
+    A_orig_shape = A_fp8.shape

     # Reshape to be able to do per-scaling group multiplication
-    # A_mx shape: (M, K//block_size, block_size)
+    # A_fp8 shape: (M, K//block_size, block_size)
     # A_scale shape: (M, K//block_size, 1)
-    A_mx = A_mx.reshape(*A_mx.shape[:-1], A_mx.shape[-1] // block_size, block_size)
+    A_fp8 = A_fp8.reshape(*A_fp8.shape[:-1], A_fp8.shape[-1] // block_size, block_size)
     A_scale = A_scale.unsqueeze(-1)

     # Rescale and cast to bfloat16
-    A = A_mx.to(torch.bfloat16) * A_scale.to(torch.bfloat16)
+    A = A_fp8.to(torch.bfloat16) * A_scale.to(torch.bfloat16)

     # Reshape back to original shape
     # A shape: (M, K)
     A = A.reshape(A_orig_shape)

     # Dequantize weights
     # Transpose to get block_size on rightmost dim
-    # B_mx shape: (E, N, K)
+    # B_fp8 shape: (E, N, K)
     # B_scale shape: (E, N, K//block_size)
-    E, N, K = B_mx.shape
+    E, N, K = B_fp8.shape

     # Reshape to be able to do per-scaling group multiplication
-    # B_mx shape: (E, N, K//block_size, block_size)
+    # B_fp8 shape: (E, N, K//block_size, block_size)
     # B_scale shape: (E, N, K//block_size, 1)
-    B_mx = B_mx.reshape(*B_mx.shape[:-1], B_mx.shape[-1] // block_size, block_size)
+    B_fp8 = B_fp8.reshape(*B_fp8.shape[:-1], B_fp8.shape[-1] // block_size, block_size)
     B_scale = B_scale.unsqueeze(-1)

     # Rescale and cast to bfloat16
-    B = B_mx.to(torch.bfloat16) * B_scale.to(torch.bfloat16)
+    B = B_fp8.to(torch.bfloat16) * B_scale.to(torch.bfloat16)

     # Reshape back to original shape
     # B shape: (E, K, N)
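
After both operands are dequantized to bf16 as above, the emulated 2d-3d path amounts to a token-grouped matmul. Below is a minimal reference of that final step, assuming `offs` holds cumulative end-row offsets of each expert's tokens in A; the actual path (emulated or fbgemm) would use a fused kernel rather than this Python loop, so treat it purely as a shape-level sketch.

```python
import torch

def grouped_mm_2d_3d_reference(A: torch.Tensor, B: torch.Tensor, offs: torch.Tensor) -> torch.Tensor:
    # A: (M, K) bf16, B: (E, K, N) bf16, offs: (E,) cumulative end indices into M.
    out = torch.empty(A.shape[0], B.shape[-1], dtype=A.dtype, device=A.device)
    start = 0
    for e in range(B.shape[0]):
        end = int(offs[e])
        out[start:end] = A[start:end] @ B[e]  # (m_e, K) @ (K, N) -> (m_e, N)
        start = end
    return out
```
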
@@ -467,27 +467,27 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_3d(


 def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
-    A_mx: torch.Tensor,  # (M, K)
+    A_fp8: torch.Tensor,  # (M, K)
     A_scale: torch.Tensor,  # (M, K//block_size)
-    B_mx: torch.Tensor,  # (K, N)
+    B_fp8: torch.Tensor,  # (K, N)
     B_scale: torch.Tensor,  # (K//block_size, N)
     offs: torch.Tensor,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
     block_size: int = 32,
 ) -> torch.Tensor:
-    assert A_mx.ndim == 2, "A must be 2D"
-    assert B_mx.ndim == 2, "B must be 2D"
+    assert A_fp8.ndim == 2, "A must be 2D"
+    assert B_fp8.ndim == 2, "B must be 2D"
     A = torch.zeros(
-        A_mx.shape,
+        A_fp8.shape,
         dtype=torch.bfloat16,
-        device=A_mx.device,
-        requires_grad=A_mx.requires_grad,
+        device=A_fp8.device,
+        requires_grad=A_fp8.requires_grad,
     )
     B = torch.zeros(
-        B_mx.shape,
+        B_fp8.shape,
         dtype=torch.bfloat16,
-        device=B_mx.device,
-        requires_grad=B_mx.requires_grad,
+        device=B_fp8.device,
+        requires_grad=B_fp8.requires_grad,
     )

     # Dequantize input per each scaling group
@@ -503,7 +503,7 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
         # -- Dequantize A tensor
         # A_group shape: (M, group_size)
         # A_scale shape: (M, group_size//block_size)
-        A_group = A_mx[:, group_start_idx:group_end_idx]
+        A_group = A_fp8[:, group_start_idx:group_end_idx]
         A_group_shape = A_group.shape

         # Get scales for this group.
@@ -528,7 +528,7 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_2d(

         # -- Dequantize B tensor
         # B_group shape is (group_size, N)
-        B_group = B_mx[group_start_idx:group_end_idx, :]
+        B_group = B_fp8[group_start_idx:group_end_idx, :]
         B_group_shape = B_group.shape

         # Scales shape is (group_size//block_size, N)
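
The 2d-2d variant applies the same dequantize-then-matmul pattern per group, but here the `offs` boundaries slice the shared contraction dimension. A compact sketch of the A-side per-group dequantize, under the assumption that every group boundary is a multiple of `block_size`; the helper name is illustrative and not from the source.

```python
import torch

def dequant_a_group(A_fp8: torch.Tensor, A_scale: torch.Tensor,
                    start: int, end: int, block_size: int = 32) -> torch.Tensor:
    # A_fp8: (M, K) fp8 data, A_scale: (M, K//block_size) per-block scales.
    A_group = A_fp8[:, start:end]                                 # (M, group_size)
    scales = A_scale[:, start // block_size : end // block_size]  # (M, group_size//block_size)
    M, g = A_group.shape
    blocks = A_group.reshape(M, g // block_size, block_size).to(torch.bfloat16)
    return (blocks * scales.unsqueeze(-1).to(torch.bfloat16)).reshape(M, g)
```
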