Commit c7e1840

[mxfp8 moe training] refactor all var names with suffix _mx to _data for clarity
stack-info: PR: #2879, branch: danielvegamyhre/stack/60
1 parent 15a6de6 commit c7e1840

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 75 additions & 75 deletions
@@ -122,7 +122,7 @@ def forward(
             round_scales_to_power_of_2=True,
         )
         A_scaled = A.to(torch.float32) * A_scales
-        A_fp8_row_major = to_fp8_saturated(A_scaled, torch.float8_e4m3fn)
+        A_data_row_major = to_fp8_saturated(A_scaled, torch.float8_e4m3fn)
 
         # Convert B to float8, column-major for right operand of grouped GEMM.
         # B_t shape: (E, K, N)
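
Note: the `to_fp8_saturated` call above casts an already-rescaled tensor to float8 with saturation. As a mental model only, here is a minimal sketch of that rowwise scale-and-cast recipe; `rowwise_scale_and_cast` is an illustrative name, not the torchao helper that produces `A_scales` upstream of this hunk.

import torch

def rowwise_scale_and_cast(x: torch.Tensor, float8_dtype=torch.float8_e4m3fn):
    # One scale per row: fp8_max / amax(row), so x * scale fits the fp8 range.
    fp8_max = torch.finfo(float8_dtype).max
    amax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12).to(torch.float32)
    scales = fp8_max / amax
    # Saturated cast: clamp first so out-of-range values map to +/- fp8_max.
    x_data = (x.to(torch.float32) * scales).clamp(-fp8_max, fp8_max).to(float8_dtype)
    return x_data, scales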
@@ -136,18 +136,18 @@ def forward(
             round_scales_to_power_of_2=True,
         )
         B_t_scaled = B_t.to(torch.float32) * B_t_scales
-        B_t_fp8_col_major = to_fp8_saturated(B_t_scaled, torch.float8_e4m3fn)
+        B_t_data_col_major = to_fp8_saturated(B_t_scaled, torch.float8_e4m3fn)
 
         # Store what we need for backward.
         ctx.save_for_backward(A, B_t, offs)
         ctx.out_dtype = out_dtype
 
         # Perform scaled grouped GEMM and return result.
         # output shape: scaled grouped mm of (M,K) @ (B,K,N) = (M,N)
-        assert not _is_column_major(A_fp8_row_major), (
+        assert not _is_column_major(A_data_row_major), (
             "A must be row-major for output = A @ B"
         )
-        assert _is_column_major(B_t_fp8_col_major), (
+        assert _is_column_major(B_t_data_col_major), (
             "B must be column-major for output = A @ B"
         )
 
@@ -157,8 +157,8 @@ def forward(
         A_scales = A_scales.squeeze(-1)
         B_t_scales = B_t_scales.squeeze(1)
         return torch._scaled_grouped_mm(
-            A_fp8_row_major,
-            B_t_fp8_col_major,
+            A_data_row_major,
+            B_t_data_col_major,
             A_scales.reciprocal(),  # Reciprocals are needed for rescaling the output.
             B_t_scales.reciprocal(),
             offs,
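
Note on the `reciprocal()` calls: the quantized operands were produced as `A_scaled = A * A_scales`, so the output must be multiplied by `1 / scale` to undo that scaling. A toy scalar version of the bookkeeping (illustrative only, not the grouped GEMM itself):

import torch

a, b = torch.tensor(0.03), torch.tensor(2.5)
a_scale, b_scale = 448.0 / a.abs(), 448.0 / b.abs()  # rowwise-style scales (fp8 e4m3fn max is 448)
a_q = (a * a_scale).to(torch.float8_e4m3fn)
b_q = (b * b_scale).to(torch.float8_e4m3fn)
# The GEMM computes a_q * b_q; rescaling by the reciprocals recovers a * b.
out = a_q.to(torch.float32) * b_q.to(torch.float32) * a_scale.reciprocal() * b_scale.reciprocal()
print(out, a * b)  # equal up to fp8 rounding error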
@@ -184,13 +184,13 @@ def backward(ctx, grad_output: torch.Tensor):
             round_scales_to_power_of_2=True,
         )
         grad_output_scaled = grad_output.to(torch.float32) * grad_output_scales
-        grad_output_fp8_row_major = to_fp8_saturated(
+        grad_output_data_row_major = to_fp8_saturated(
             grad_output_scaled, torch.float8_e4m3fn
         )
 
         # Compute B fp8 column-major for right operand of grouped GEMM:
         # grad_A = grad_output @ B.
-        B_fp8_col_major, B_scales = triton_fp8_rowwise_3d_transpose_rhs(
+        B_data_col_major, B_scales = triton_fp8_rowwise_3d_transpose_rhs(
             B_t._data if hasattr(B_t, "_data") else B_t,
             output_dtype=torch.float8_e4m3fn,
             round_scales_to_power_of_2=True,
@@ -199,10 +199,10 @@ def backward(ctx, grad_output: torch.Tensor):
         # Compute grad_A.
         # grad_A = grad_output @ B
         # grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K)
-        assert not _is_column_major(grad_output_fp8_row_major), (
+        assert not _is_column_major(grad_output_data_row_major), (
             "grad_output must be row-major for grad_A = grad_output @ B"
         )
-        assert _is_column_major(B_fp8_col_major), (
+        assert _is_column_major(B_data_col_major), (
             "B must be column-major for grad_A = grad_output @ B"
         )
 
@@ -212,8 +212,8 @@ def backward(ctx, grad_output: torch.Tensor):
         grad_output_scales = grad_output_scales.squeeze(-1)
         B_scales = B_scales.squeeze(1)
         grad_A = torch._scaled_grouped_mm(
-            grad_output_fp8_row_major,
-            B_fp8_col_major,
+            grad_output_data_row_major,
+            B_data_col_major,
             grad_output_scales.reciprocal(),
             B_scales.reciprocal(),
             offs,
@@ -227,18 +227,18 @@ def backward(ctx, grad_output: torch.Tensor):
         # Convert transpose of grad_output to float8, row-major for left operand of grouped GEMM
         # needed for grad_B: grad_output_t @ A
         # Use transpose method to avoid uncoalesced memory accesses.
-        grad_out_fp8_colwise, grad_out_scales = triton_fp8_per_group_colwise_scales(
+        grad_out_data_colwise, grad_out_scales = triton_fp8_per_group_colwise_scales(
             grad_output.t()
             .contiguous()
             .t(),  # Quantization is over 2x faster when input is col major, even with this transformation
             offs,
             torch.float8_e4m3fn,
             round_scales_to_power_of_2=True,
         )
-        grad_output_t_fp8_row_major = grad_out_fp8_colwise.t()
+        grad_output_t_data_row_major = grad_out_data_colwise.t()
         grad_output_t_scales = grad_out_scales.t()
 
-        A_fp8_col_major, A_scales = triton_fp8_per_group_colwise_scales(
+        A_data_col_major, A_scales = triton_fp8_per_group_colwise_scales(
             A.t()
             .contiguous()
             .t(),  # Quantization is over 2x faster when input is col major, even with this transformation
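
Note: the `.t().contiguous().t()` idiom called out in the inline comments rewrites the tensor into column-major memory without changing its shape or values, which is what makes the colwise-scaling kernel faster. A quick illustration:

import torch

x = torch.randn(128, 64)
x_col = x.t().contiguous().t()      # same shape and values, column-major storage
print(torch.equal(x, x_col))        # True
print(x.stride(), x_col.stride())   # (64, 1) vs (1, 128)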
@@ -249,19 +249,19 @@ def backward(ctx, grad_output: torch.Tensor):
 
         # Compute grad_B = grad_output_t @ A.
         # grad_B = grad_output_t @ A
-        assert not _is_column_major(grad_output_t_fp8_row_major), (
+        assert not _is_column_major(grad_output_t_data_row_major), (
             "grad_output_t must be row-major for grad_B = grad_output_t @ A"
         )
-        assert _is_column_major(A_fp8_col_major), (
+        assert _is_column_major(A_data_col_major), (
             "A must be column-major for grad_B = grad_output_t @ A"
         )
 
         # Per-token group scales computed via triton kernels above do not have
         # the empty dim like the scales computed via tensor_to_scale, so we
         # don't need to squeeze here.
         grad_B = torch._scaled_grouped_mm(
-            grad_output_t_fp8_row_major,
-            A_fp8_col_major,
+            grad_output_t_data_row_major,
+            A_data_col_major,
             grad_output_t_scales.reciprocal(),
             A_scales.reciprocal(),
             offs,
@@ -295,13 +295,13 @@ def forward(
         ctx.out_dtype = out_dtype
         ctx.emulated = emulated
 
-        # A_mx shape: (M, K)
+        # A_fp8 shape: (M, K)
         # A_scale shape: (M, K//block_size)
-        A_scale, A_mx = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
+        A_scale, A_fp8 = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
 
-        # B_mx shape: (E, N, K)
+        # B_fp8 shape: (E, N, K)
         # B_scale shape: (E, N, K//block_size)
-        B_scales, B_mx = to_mx(
+        B_scales, B_fp8 = to_mx(
             B_t.transpose(-2, -1),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
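
Note: the shape comments above encode the `to_mx` contract used in this file: the fp8 data keeps the input shape, while the scales shrink by `block_size` along the last dim. Below is a hedged sketch of that blockwise scaling; it uses plain float scales rather than the E8M0 exponent scales of the real MX format, and `blockwise_quant_sketch` is an illustrative name, not the torchao API.

import torch

def blockwise_quant_sketch(x: torch.Tensor, block_size: int = 32):
    # x: (M, K) with K divisible by block_size; one scale per block of block_size values.
    M, K = x.shape
    blocks = x.reshape(M, K // block_size, block_size)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scales = (blocks.abs().amax(dim=-1) / fp8_max).clamp(min=1e-12)  # (M, K//block_size)
    data = (blocks / scales.unsqueeze(-1)).to(torch.float8_e4m3fn)   # (M, K//block_size, block_size)
    return scales, data.reshape(M, K)                                # shapes mirror the comments above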
@@ -315,9 +315,9 @@ def forward(
             else fbgemm_mxfp8_grouped_mm_2d_3d
         )
         out = mxfp8_2d_3d_grouped_mm(
-            A_mx,
+            A_fp8,
             A_scale,
-            B_mx,
+            B_fp8,
             B_scales,
             offs=offs,
             block_size=block_size,
@@ -332,15 +332,15 @@ def backward(ctx, grad_out: torch.Tensor):
         out_dtype = ctx.out_dtype
         emulated = ctx.emulated
 
-        # grad_out_mx shape: (M, N)
+        # grad_out_fp8 shape: (M, N)
         # grad_out_scale shape: (M, N//block_size)
-        grad_out_scale, grad_out_mx = to_mx(
+        grad_out_scale, grad_out_fp8 = to_mx(
             grad_out, elem_dtype=torch.float8_e4m3fn, block_size=block_size
         )
 
-        # B_mx shape: (E, K, N)
+        # B_fp8 shape: (E, K, N)
         # B_scale shape: (E, K, N//block_size)
-        B_scales, B_mx = to_mx(
+        B_scales, B_fp8 = to_mx(
             # TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency?
             B_t.contiguous(),
             elem_dtype=torch.float8_e4m3fn,
@@ -354,43 +354,43 @@ def backward(ctx, grad_out: torch.Tensor):
             else fbgemm_mxfp8_grouped_mm_2d_3d
         )
         grad_A = mxfp8_2d_3d_grouped_mm(
-            grad_out_mx,
+            grad_out_fp8,
             grad_out_scale,
-            B_mx,
+            B_fp8,
             B_scales,
             offs=offs,
             out_dtype=out_dtype,
         )
 
-        # grad_out_t_mx shape: (N, M)
+        # grad_out_t_fp8 shape: (N, M)
         # grad_out_t_scales shape: (N, M//block_size)
-        grad_out_t_scales, grad_out_t_mx = to_mx(
+        grad_out_t_scales, grad_out_t_fp8 = to_mx(
             # TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency?
             grad_out.transpose(-2, -1).contiguous(),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
         )
 
         # Transpose A so we can scale along the M dimension, then un-transpose.
-        # A_t_mx shape: (K, M)
+        # A_t_fp8 shape: (K, M)
         # A_t_scales shape: (K, M//block_size)
-        A_t_scales, A_t_mx = to_mx(
+        A_t_scales, A_t_fp8 = to_mx(
             A.transpose(-2, -1).contiguous(),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
         )
 
-        # A_mx shape = (M, K)
-        A_mx = A_t_mx.transpose(-2, -1)
+        # A_fp8 shape = (M, K)
+        A_fp8 = A_t_fp8.transpose(-2, -1)
 
         # A_scales shape = (M//block_size, K)
         A_scales = A_t_scales.transpose(-2, -1)
 
         # grad_B_t = scaled grouped mm of (N,M) @ (M,K) = (E,N,K)
         grad_B = _emulated_mxfp8_scaled_grouped_mm_2d_2d(
-            grad_out_t_mx,
+            grad_out_t_fp8,
             grad_out_t_scales,
-            A_mx,
+            A_fp8,
             A_scales,
             offs=offs,
         )
@@ -402,64 +402,64 @@ def backward(ctx, grad_out: torch.Tensor):
 
 
 def _emulated_mxfp8_scaled_grouped_mm_2d_3d(
-    A_mx: torch.Tensor,
+    A_fp8: torch.Tensor,
     A_scale: torch.Tensor,
-    B_mx: torch.Tensor,
+    B_fp8: torch.Tensor,
     B_scale: torch.Tensor,
     offs: Optional[torch.Tensor] = None,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
     block_size: int = 32,
 ) -> torch.Tensor:
-    assert A_mx.ndim == 2, f"A must be 2D, got {A_mx.ndim}"
-    assert B_mx.ndim == 3, f"B must be 3D, got {B_mx.ndim}"
-    assert A_scale.shape[0] == A_mx.shape[0], (
-        f"A_scale must have same M dim as A_mx, got A={A_mx.shape} and A_scale={A_scale.shape}"
+    assert A_fp8.ndim == 2, f"A must be 2D, got {A_fp8.ndim}"
+    assert B_fp8.ndim == 3, f"B must be 3D, got {B_fp8.ndim}"
+    assert A_scale.shape[0] == A_fp8.shape[0], (
+        f"A_scale must have same M dim as A_fp8, got A={A_fp8.shape} and A_scale={A_scale.shape}"
     )
-    assert A_scale.shape[1] == A_mx.shape[1] // block_size, (
-        f"A_scale dim1 should be size K//block_size, got A={A_mx.shape} and A_scale={A_scale.shape}"
+    assert A_scale.shape[1] == A_fp8.shape[1] // block_size, (
+        f"A_scale dim1 should be size K//block_size, got A={A_fp8.shape} and A_scale={A_scale.shape}"
     )
-    assert B_scale.shape[0] == B_mx.shape[0], (
-        f"B_scale must have same E dim as B_mx, got B={B_mx.shape} and B_scale={B_scale.shape}"
+    assert B_scale.shape[0] == B_fp8.shape[0], (
+        f"B_scale must have same E dim as B_fp8, got B={B_fp8.shape} and B_scale={B_scale.shape}"
     )
-    assert B_scale.shape[1] == B_mx.shape[1], (
-        f"B_scale must have same N dim as B_mx, got B={B_mx.shape} and B_scale={B_scale.shape}"
+    assert B_scale.shape[1] == B_fp8.shape[1], (
+        f"B_scale must have same N dim as B_fp8, got B={B_fp8.shape} and B_scale={B_scale.shape}"
    )
-    assert B_scale.shape[2] == B_mx.shape[2] // block_size, (
-        f"B_scale dim2 should be size K//block_size, got B={B_mx.shape} and B_scale={B_scale.shape}"
+    assert B_scale.shape[2] == B_fp8.shape[2] // block_size, (
+        f"B_scale dim2 should be size K//block_size, got B={B_fp8.shape} and B_scale={B_scale.shape}"
    )
 
     # Dequantize input
-    # A_mx shape: (M, K)
+    # A_fp8 shape: (M, K)
     # A_scale shape: (M, K//block_size)
-    A_orig_shape = A_mx.shape
+    A_orig_shape = A_fp8.shape
 
     # Reshape to be able to do per-scaling group multiplication
-    # A_mx shape: (M, K//block_size, block_size)
+    # A_fp8 shape: (M, K//block_size, block_size)
     # A_scale shape: (M, K//block_size, 1)
-    A_mx = A_mx.reshape(*A_mx.shape[:-1], A_mx.shape[-1] // block_size, block_size)
+    A_fp8 = A_fp8.reshape(*A_fp8.shape[:-1], A_fp8.shape[-1] // block_size, block_size)
     A_scale = A_scale.unsqueeze(-1)
 
     # Rescale and cast to bfloat16
-    A = A_mx.to(torch.bfloat16) * A_scale.to(torch.bfloat16)
+    A = A_fp8.to(torch.bfloat16) * A_scale.to(torch.bfloat16)
 
     # Reshape back to original shape
     # A shape: (M, K)
     A = A.reshape(A_orig_shape)
 
     # Dequantize weights
     # Transpose to get block_size on rightmost dim
-    # B_mx shape: (E, N, K)
+    # B_fp8 shape: (E, N, K)
     # B_scale shape: (E, N, K//block_size)
-    E, N, K = B_mx.shape
+    E, N, K = B_fp8.shape
 
     # Reshape to be able to do per-scaling group multiplication
-    # B_mx shape: (E, N, K//block_size, block_size)
+    # B_fp8 shape: (E, N, K//block_size, block_size)
     # B_scale shape: (E, N, K//block_size, 1)
-    B_mx = B_mx.reshape(*B_mx.shape[:-1], B_mx.shape[-1] // block_size, block_size)
+    B_fp8 = B_fp8.reshape(*B_fp8.shape[:-1], B_fp8.shape[-1] // block_size, block_size)
     B_scale = B_scale.unsqueeze(-1)
 
     # Rescale and cast to bfloat16
-    B = B_mx.to(torch.bfloat16) * B_scale.to(torch.bfloat16)
+    B = B_fp8.to(torch.bfloat16) * B_scale.to(torch.bfloat16)
 
     # Reshape back to original shape
     # B shape: (E, K, N)
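
Note: the dequantize path in this hunk reshapes the fp8 data to (M, K//block_size, block_size), multiplies by the unsqueezed scales, and reshapes back. A self-contained round-trip sketch of that pattern follows (plain float scales stand in for real E8M0 MX scales):

import torch

M, K, block_size = 4, 64, 32
A = torch.randn(M, K)

# Quantize: one scale per block of 32 values along K.
blocks = A.reshape(M, K // block_size, block_size)
scales = (blocks.abs().amax(dim=-1) / torch.finfo(torch.float8_e4m3fn).max).clamp(min=1e-12)
A_fp8 = (blocks / scales.unsqueeze(-1)).to(torch.float8_e4m3fn).reshape(M, K)

# Dequantize as in the hunk: reshape, multiply by unsqueezed scale, reshape back.
A_deq = (
    A_fp8.reshape(M, K // block_size, block_size).to(torch.bfloat16)
    * scales.unsqueeze(-1).to(torch.bfloat16)
).reshape(M, K)
print((A - A_deq.float()).abs().max())  # small quantization error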
@@ -471,27 +471,27 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_3d(
 
 
 def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
-    A_mx: torch.Tensor,  # (M, K)
+    A_fp8: torch.Tensor,  # (M, K)
     A_scale: torch.Tensor,  # (M, K//block_size)
-    B_mx: torch.Tensor,  # (K, N)
+    B_fp8: torch.Tensor,  # (K, N)
     B_scale: torch.Tensor,  # (K//block_size, N)
     offs: torch.Tensor,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
     block_size: int = 32,
 ) -> torch.Tensor:
-    assert A_mx.ndim == 2, "A must be 2D"
-    assert B_mx.ndim == 2, "B must be 2D"
+    assert A_fp8.ndim == 2, "A must be 2D"
+    assert B_fp8.ndim == 2, "B must be 2D"
     A = torch.zeros(
-        A_mx.shape,
+        A_fp8.shape,
         dtype=torch.bfloat16,
-        device=A_mx.device,
-        requires_grad=A_mx.requires_grad,
+        device=A_fp8.device,
+        requires_grad=A_fp8.requires_grad,
     )
     B = torch.zeros(
-        B_mx.shape,
+        B_fp8.shape,
         dtype=torch.bfloat16,
-        device=B_mx.device,
-        requires_grad=B_mx.requires_grad,
+        device=B_fp8.device,
+        requires_grad=B_fp8.requires_grad,
     )
 
     # Dequantize input per each scaling group
@@ -507,7 +507,7 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
         # -- Dequantize A tensor
         # A_group shape: (M, group_size)
         # A_scale shape: (M, group_size//block_size)
-        A_group = A_mx[:, group_start_idx:group_end_idx]
+        A_group = A_fp8[:, group_start_idx:group_end_idx]
         A_group_shape = A_group.shape
 
         # Get scales for this group.
@@ -532,7 +532,7 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
 
         # -- Dequantize B tensor
         # B_group shape is (group_size, N)
-        B_group = B_mx[group_start_idx:group_end_idx, :]
+        B_group = B_fp8[group_start_idx:group_end_idx, :]
        B_group_shape = B_group.shape
 
         # Scales shape is (group_size//block_size, N)
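
Note: the per-group slicing in this hunk treats `offs` as cumulative end indices of each expert's token group along the shared dim, consistent with how the grouped GEMM comments describe (N,M) @ (M,K) = (E,N,K). A hedged, bf16-only sketch of that grouping (illustrative names, not the file's emulation function):

import torch

def grouped_mm_2d_2d_bf16_sketch(lhs, rhs, offs):
    # lhs: (N, M), rhs: (M, K); offs holds cumulative group-end indices along M (assumption).
    outs, start = [], 0
    for end in offs.tolist():
        outs.append(lhs[:, start:end] @ rhs[start:end, :])  # (N, K) for this expert's tokens
        start = end
    return torch.stack(outs)  # (E, N, K)

offs = torch.tensor([3, 7, 12])
out = grouped_mm_2d_2d_bf16_sketch(torch.randn(5, 12), torch.randn(12, 6), offs)
print(out.shape)  # torch.Size([3, 5, 6])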
