@@ -204,38 +204,21 @@ func.func @scaled_mfma_less_than_4(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4
204204 return %res_0 : vector <4 xf32 >
205205}
206206
207-
208207// -----
209208
210209// CHECK-LABEL: func @scaled_mfma_ugly_shapes
211- // CHECK: amdgpu.scaled_mfma(%{{.*}}[0] * %{{.*}}) * (%{{.*}}[2] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
212- // CHECK: amdgpu.scaled_mfma(%{{.*}}[1] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
213- // CHECK: amdgpu.scaled_mfma(%{{.*}}[2] * %{{.*}}) * (%{{.*}}[0] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
214- // CHECK: amdgpu.scaled_mfma(%{{.*}}[3] * %{{.*}}) * (%{{.*}}[1] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
215210// CHECK: amdgpu.scaled_mfma(%{{.*}}[0] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
216211// CHECK: amdgpu.scaled_mfma(%{{.*}}[1] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
217212// CHECK: amdgpu.scaled_mfma(%{{.*}}[2] * %{{.*}}) * (%{{.*}}[2] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
218213// CHECK: amdgpu.scaled_mfma(%{{.*}}[3] * %{{.*}}) * (%{{.*}}[1] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
219- func.func @scaled_mfma_ugly_shapes (%opA: vector <32 xf4 E2 M1 FN>, %opB: vector <32 xf4 E2 M1 FN>, %scalesA: vector <5 x5 xf8 E8 M0 FNU>, %scalesB: vector <7 x23 xf8 E8 M0 FNU>) -> (vector <4 xf32 >, vector <4 xf32 >, vector <4 xf32 >, vector <4 xf32 >, vector < 4 x f32 >, vector < 4 x f32 >, vector < 4 x f32 >, vector < 4 x f32 > ) {
214+ func.func @scaled_mfma_ugly_shapes (%opA: vector <32 xf4 E2 M1 FN>, %opB: vector <32 xf4 E2 M1 FN>, %scalesA: vector <5 x5 xf8 E8 M0 FNU>, %scalesB: vector <7 x23 xf8 E8 M0 FNU>) -> (vector <4 xf32 >, vector <4 xf32 >, vector <4 xf32 >, vector <4 xf32 >) {
220215 %cst_0 = arith.constant dense <0.000000e+00 > : vector <4 xf32 >
221216 %cst_1 = arith.constant dense <5.877470e-39 > : vector <4 xf8 E8 M0 FNU>
222- %scaleA_0_0 = vector.extract %scalesA [0 , 0 ] : f8E8M0FNU from vector <5 x5 xf8 E8 M0 FNU>
223- %scaleA_0_1 = vector.extract %scalesA [1 , 0 ] : f8E8M0FNU from vector <5 x5 xf8 E8 M0 FNU>
224- %scaleA_0_2 = vector.extract %scalesA [2 , 0 ] : f8E8M0FNU from vector <5 x5 xf8 E8 M0 FNU>
225- %scaleA_0_3 = vector.extract %scalesA [3 , 0 ] : f8E8M0FNU from vector <5 x5 xf8 E8 M0 FNU>
226217 %scaleA_0_4 = vector.extract %scalesA [4 , 0 ] : f8E8M0FNU from vector <5 x5 xf8 E8 M0 FNU>
227218 %scaleA_0_5 = vector.extract %scalesA [4 , 1 ] : f8E8M0FNU from vector <5 x5 xf8 E8 M0 FNU>
228219 %scaleA_0_6 = vector.extract %scalesA [4 , 2 ] : f8E8M0FNU from vector <5 x5 xf8 E8 M0 FNU>
229220 %scaleA_0_7 = vector.extract %scalesA [4 , 3 ] : f8E8M0FNU from vector <5 x5 xf8 E8 M0 FNU>
230221
231- // idx = 138 + 8 = 146 => opsel = 2
232- %scaleB_6_8 = vector.extract %scalesB [6 , 8 ] : f8E8M0FNU from vector <7 x23 xf8 E8 M0 FNU>
233- // idx = 147 => opsel = 3
234- %scaleB_6_9 = vector.extract %scalesB [6 , 9 ] : f8E8M0FNU from vector <7 x23 xf8 E8 M0 FNU>
235- // idx = 148 => opsel = 0
236- %scaleB_6_10 = vector.extract %scalesB [6 , 10 ] : f8E8M0FNU from vector <7 x23 xf8 E8 M0 FNU>
237- // idx = 149 => opsel = 1
238- %scaleB_6_11 = vector.extract %scalesB [6 , 11 ] : f8E8M0FNU from vector <7 x23 xf8 E8 M0 FNU>
239222 // idx = 160 => opsel = 3 (last idx of last 4 bytes)
240223 %scaleB_6_22 = vector.extract %scalesB [6 , 22 ] : f8E8M0FNU from vector <7 x23 xf8 E8 M0 FNU>
241224 // idx = 159 => opsel = 3
@@ -245,31 +228,19 @@ func.func @scaled_mfma_ugly_shapes(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4
245228 // idx = 157 => opsel = 1
246229 %scaleB_6_19 = vector.extract %scalesB [6 , 19 ] : f8E8M0FNU from vector <7 x23 xf8 E8 M0 FNU>
247230
248- %sA_0_0 = vector.insert %scaleA_0_0 , %cst_1 [0 ] : f8E8M0FNU into vector <4 xf8 E8 M0 FNU>
249- %sA_0_1 = vector.insert %scaleA_0_1 , %cst_1 [0 ] : f8E8M0FNU into vector <4 xf8 E8 M0 FNU>
250- %sA_0_2 = vector.insert %scaleA_0_2 , %cst_1 [0 ] : f8E8M0FNU into vector <4 xf8 E8 M0 FNU>
251- %sA_0_3 = vector.insert %scaleA_0_3 , %cst_1 [0 ] : f8E8M0FNU into vector <4 xf8 E8 M0 FNU>
252231 %sA_0_4 = vector.insert %scaleA_0_4 , %cst_1 [0 ] : f8E8M0FNU into vector <4 xf8 E8 M0 FNU>
253232 %sA_0_5 = vector.insert %scaleA_0_5 , %cst_1 [0 ] : f8E8M0FNU into vector <4 xf8 E8 M0 FNU>
254233 %sA_0_6 = vector.insert %scaleA_0_6 , %cst_1 [0 ] : f8E8M0FNU into vector <4 xf8 E8 M0 FNU>
255234 %sA_0_7 = vector.insert %scaleA_0_7 , %cst_1 [0 ] : f8E8M0FNU into vector <4 xf8 E8 M0 FNU>
256235
257- %sB_6_8 = vector.insert %scaleB_6_8 , %cst_1 [0 ] : f8E8M0FNU into vector <4 xf8 E8 M0 FNU>
258- %sB_6_9 = vector.insert %scaleB_6_9 , %cst_1 [0 ] : f8E8M0FNU into vector <4 xf8 E8 M0 FNU>
259- %sB_6_10 = vector.insert %scaleB_6_10 , %cst_1 [0 ] : f8E8M0FNU into vector <4 xf8 E8 M0 FNU>
260- %sB_6_11 = vector.insert %scaleB_6_11 , %cst_1 [0 ] : f8E8M0FNU into vector <4 xf8 E8 M0 FNU>
261236 %sB_6_22 = vector.insert %scaleB_6_22 , %cst_1 [0 ] : f8E8M0FNU into vector <4 xf8 E8 M0 FNU>
262237 %sB_6_21 = vector.insert %scaleB_6_21 , %cst_1 [0 ] : f8E8M0FNU into vector <4 xf8 E8 M0 FNU>
263238 %sB_6_20 = vector.insert %scaleB_6_20 , %cst_1 [0 ] : f8E8M0FNU into vector <4 xf8 E8 M0 FNU>
264239 %sB_6_19 = vector.insert %scaleB_6_19 , %cst_1 [0 ] : f8E8M0FNU into vector <4 xf8 E8 M0 FNU>
265240
266- %res_0 = amdgpu.scaled_mfma (%sA_0_0 [0 ] * %opA ) * (%sB_6_8 [0 ] * %opB ) + %cst_0 {k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf32 >
267- %res_1 = amdgpu.scaled_mfma (%sA_0_1 [0 ] * %opA ) * (%sB_6_9 [0 ] * %opB ) + %cst_0 {k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf32 >
268- %res_2 = amdgpu.scaled_mfma (%sA_0_2 [0 ] * %opA ) * (%sB_6_10 [0 ] * %opB ) + %cst_0 {k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf32 >
269- %res_3 = amdgpu.scaled_mfma (%sA_0_3 [0 ] * %opA ) * (%sB_6_11 [0 ] * %opB ) + %cst_0 {k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf32 >
270241 %res_4 = amdgpu.scaled_mfma (%sA_0_4 [0 ] * %opA ) * (%sB_6_22 [0 ] * %opB ) + %cst_0 {k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf32 >
271242 %res_5 = amdgpu.scaled_mfma (%sA_0_5 [0 ] * %opA ) * (%sB_6_21 [0 ] * %opB ) + %cst_0 {k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf32 >
272243 %res_6 = amdgpu.scaled_mfma (%sA_0_6 [0 ] * %opA ) * (%sB_6_20 [0 ] * %opB ) + %cst_0 {k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf32 >
273244 %res_7 = amdgpu.scaled_mfma (%sA_0_7 [0 ] * %opA ) * (%sB_6_19 [0 ] * %opB ) + %cst_0 {k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf32 >
274- return %res_0 , %res_1 , %res_2 , %res_3 , % res_4 , %res_5 , %res_6 , %res_7 : vector < 4 x f32 >, vector < 4 x f32 >, vector < 4 x f32 >, vector < 4 x f32 >, vector <4 xf32 >, vector <4 xf32 >, vector <4 xf32 >, vector <4 xf32 >
245+ return %res_4 , %res_5 , %res_6 , %res_7 : vector <4 xf32 >, vector <4 xf32 >, vector <4 xf32 >, vector <4 xf32 >
275246}
0 commit comments