@@ -490,3 +490,30 @@ func.func @fold_transpose_into_transfer_read(%alloc: memref<64x128xf16>, %vector
490490}
491491
492492// -----
493+
494+ #map1 = affine_map<(d0, d1, d2) -> (d0, d2)>  // first contraction operand: indexed by (d0, d2)
495+ #map2 = affine_map<(d0, d1, d2) -> (d1, d2)>  // second contraction operand: indexed by (d1, d2)
496+ #map3 = affine_map<(d0, d1, d2) -> (d0, d1)>  // accumulator/result: indexed by (d0, d1)
497+
498+ // CHECK-LABEL: func @cast_f16_to_f32_read
499+ // CHECK: %[[A:.+]] = gpu.subgroup_mma_load_matrix {{.+}} {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
500+ // CHECK: %[[C:.+]] = gpu.subgroup_mma_load_matrix {{.+}} {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
501+ // CHECK: %[[AE:.+]] = gpu.subgroup_mma_elementwise extf %[[A]] : (!gpu.mma_matrix<16x16xf16, "AOp">) -> !gpu.mma_matrix<16x16xf32, "AOp">
502+ // CHECK: %[[CE:.+]] = gpu.subgroup_mma_elementwise extf %[[C]] : (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
503+ // CHECK: %[[B:.+]] = gpu.subgroup_mma_load_matrix {{.+}} {leadDimension = 16 : index, transpose} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
504+ // CHECK: %[[BE:.+]] = gpu.subgroup_mma_elementwise extf %[[B]] : (!gpu.mma_matrix<16x16xf16, "BOp">) -> !gpu.mma_matrix<16x16xf32, "BOp">
505+ // CHECK: gpu.subgroup_mma_compute %[[AE]], %[[BE]], %[[CE]]
506+ func.func @cast_f16_to_f32_read(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<16x16xf16>, %arg3: memref<16x16xf32>) {  // reads three f16 tiles, widens them to f32, contracts, stores to %arg3
507+   %c0 = arith.constant 0 : index  // start offset for every transfer below
508+   %cst = arith.constant 0.000000e+00 : f16  // padding value required by vector.transfer_read
509+   %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
510+   %B = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
511+   %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
512+   %Aext = arith.extf %A : vector<16x16xf16> to vector<16x16xf32>  // each operand is extended f16 -> f32 between its read and the contraction
513+   %Bext = arith.extf %B : vector<16x16xf16> to vector<16x16xf32>
514+   %Cext = arith.extf %C : vector<16x16xf16> to vector<16x16xf32>
515+   %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}  // D(d0, d1) = C(d0, d1) + sum over d2 of A(d0, d2) * B(d1, d2)
516+     %Aext, %Bext, %Cext : vector<16x16xf32>, vector<16x16xf32> into vector<16x16xf32>
517+   vector.transfer_write %D, %arg3[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<16x16xf32>  // the f32 result goes to the only f32 memref argument
518+   return
519+ }
0 commit comments