14 | 14 | from torchao.prototype.moe_training.kernels import ( |
15 | 15 | triton_fp8_col_major_jagged_colwise_scales, |
16 | 16 | triton_fp8_row_major_jagged_rowwise_scales, |
| 17 | + triton_fp8_rowwise_3d_transpose_rhs, |
17 | 18 | ) |
18 | 19 | from torchao.prototype.moe_training.utils import ( |
19 | 20 | _is_column_major, |
@@ -44,7 +45,7 @@ def _scaled_grouped_mm( |
44 | 45 | out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported. |
45 | 46 | """ |
46 | 47 | # TODO: Remove once prototype is more mature. This is currently very useful for development and debugging. |
47 | | - logger.info("Using scaled_grouped_mm") |
| 48 | + # logger.info("Using scaled_grouped_mm") |
48 | 49 | return _Float8GroupedMM.apply( |
49 | 50 | A, |
50 | 51 | B_t, |
@@ -127,20 +128,11 @@ def forward( |
127 | 128 | # Precompute non-transposed B column-major for backward, to save memory by storing the |
128 | 129 | # low precision B tensor instead of the high precision B tensor. |
129 | 130 | # In the backward this is needed for grad_A: grad_output @ B. |
130 | | - B = B_t.contiguous().transpose(-2, -1) |
131 | | - |
132 | | - # - B shape: (E, N, K) |
133 | | - # - B scales must be computed rowwise keeping the outer/final dim, so: |
134 | | - # - B_scale shape: (E, 1, K) |
135 | | - B_scales = tensor_to_scale( |
136 | | - B, |
137 | | - torch.float8_e4m3fn, |
138 | | - scaling_granularity=ScalingGranularity.AXISWISE, |
139 | | - axiswise_dim=-2, |
| 131 | + B_fp8_col_major, B_scales = triton_fp8_rowwise_3d_transpose_rhs( |
| 132 | + B_t, |
| 133 | + output_dtype=torch.float8_e4m3fn, |
140 | 134 | round_scales_to_power_of_2=True, |
141 | 135 | ) |
142 | | - B_scaled = B.to(torch.float32) * B_scales |
143 | | - B_fp8_col_major = to_fp8_saturated(B_scaled, torch.float8_e4m3fn) |
144 | 136 |
|
145 | 137 | # Store what we need for backward. |
146 | 138 | ctx.save_for_backward(A, B_fp8_col_major, B_scales, offs) |
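
For reference, here is a minimal eager-mode sketch of the quantization that the new fused `triton_fp8_rowwise_3d_transpose_rhs` call replaces, reconstructed from the removed eager path above (transpose, rowwise axiswise scales of shape (E, 1, K), scale, saturate, cast). The helper name and the power-of-two rounding via `log2`/`exp2` are illustrative assumptions for this sketch, not the Triton kernel's actual implementation:

```python
import torch

def fp8_rowwise_transpose_rhs_reference(
    B_t: torch.Tensor,
    float8_dtype: torch.dtype = torch.float8_e4m3fn,
    round_scales_to_power_of_2: bool = True,
):
    # B_t: (E, K, N) high-precision expert weights (transposed).
    # Transposing the contiguous copy back to (E, N, K) leaves the last two
    # dims in column-major memory layout.
    B = B_t.contiguous().transpose(-2, -1)

    fp8_max = torch.finfo(float8_dtype).max

    # Rowwise amax over dim -2 -> scales of shape (E, 1, K).
    amax = B.abs().amax(dim=-2, keepdim=True).to(torch.float32)
    scales = fp8_max / amax.clamp(min=1e-12)
    if round_scales_to_power_of_2:
        # Round each scale down to the nearest power of two (illustrative
        # stand-in for the exponent-bit rounding used in torchao utilities).
        scales = torch.exp2(torch.floor(torch.log2(scales)))

    # Scale, saturate to the fp8 representable range, and cast.
    B_scaled = B.to(torch.float32) * scales
    B_fp8_col_major = B_scaled.clamp(-fp8_max, fp8_max).to(float8_dtype)
    return B_fp8_col_major, scales
```

Fusing these steps into a single Triton kernel presumably also avoids materializing the intermediate contiguous high-precision copy of B that this eager sketch creates.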
|