 
 gpu.module @create_nd_tdesc {
   // CHECK-LABEL: gpu.func @create_nd_tdesc
-  // CHECK-SAME: %[[ARG0:.*]]: memref<8x16xf32, 1>, %[[ARG1:.*]]: ui64,
+  // CHECK-SAME: %[[ARG0:.*]]: memref<16x32xf32, 1>, %[[ARG1:.*]]: ui64,
   // CHECK-SAME: %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index, %[[ARG7:.*]]: index
-  gpu.func @create_nd_tdesc(%src: memref<8x16xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
+  gpu.func @create_nd_tdesc(%src: memref<16x32xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
       %stride1: index, %stride2: index, %offset1: index, %offset2: index) kernel {
     // CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
     // CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64
@@ -23,35 +23,35 @@ gpu.module @create_nd_tdesc {
     %ptr_tdesc = xegpu.create_nd_tdesc %ptr, shape:[%shape1, %shape2], strides:[%stride1, %stride2]
         : ui64 -> !xegpu.tensor_desc<8x16xf32>
 
-    // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<8x16xf32, 1> to memref<8x16xf32>
-    %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
+    // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<16x32xf32, 1> to memref<16x32xf32>
+    %srcce = memref.memory_space_cast %src : memref<16x32xf32, 1> to memref<16x32xf32>
 
     // CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32>
-    // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
+    // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
     // CHECK: %[[OFFSET_W2:.*]] = arith.constant 0 : i32
     // CHECK: %[[OFFSET_H2:.*]] = arith.constant 0 : i32
+    // CHECK: %[[C32_I64:.*]] = arith.constant 32 : i64
+    // CHECK: %[[SHAPE_W2:.*]] = arith.trunci %[[C32_I64]] : i64 to i32
     // CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
-    // CHECK: %[[SHAPE_W2:.*]] = arith.trunci %c16_i64 : i64 to i32
-    // CHECK: %[[C8_I64:.*]] = arith.constant 8 : i64
-    // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %c8_i64 : i64 to i32
+    // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
     // CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
     // CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
     // CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
     // CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32>
     // CHECK: %[[VAR17:.*]] = vector.insert %[[SHAPE_W2]], %[[VAR16]] [2] : i32 into vector<8xi32>
     // CHECK: %[[VAR18:.*]] = vector.insert %[[SHAPE_H2]], %[[VAR17]] [3] : i32 into vector<8xi32>
     // CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32>
-    // CHECK: %[[VAR20:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
-    %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+    // CHECK: %[[PAYLOAD:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
+    %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
 
     // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
-    // CHECK: %[[INTPTR_2:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
+    // CHECK: %[[INTPTR_2:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
     // CHECK: %[[OFFSET_W3:.*]] = arith.index_cast %[[ARG7]] : index to i32
     // CHECK: %[[OFFSET_H3:.*]] = arith.index_cast %[[ARG6]] : index to i32
-    // CHECK: %[[C16_I64_6:.*]] = arith.constant 16 : i64
-    // CHECK: %[[SHAPE_W3:.*]] = arith.trunci %[[C16_I64_6]] : i64 to i32
-    // CHECK: %[[C8_I64_7:.*]] = arith.constant 8 : i64
-    // CHECK: %[[SHAPE_H3:.*]] = arith.trunci %[[C8_I64_7]] : i64 to i32
+    // CHECK: %[[C32_I64_6:.*]] = arith.constant 32 : i64
+    // CHECK: %[[SHAPE_W3:.*]] = arith.trunci %[[C32_I64_6]] : i64 to i32
+    // CHECK: %[[C16_I64_7:.*]] = arith.constant 16 : i64
+    // CHECK: %[[SHAPE_H3:.*]] = arith.trunci %[[C16_I64_7]] : i64 to i32
     // CHECK: %[[BASE_ADDR3:.*]] = arith.index_castui %[[INTPTR_2]] : index to i64
     // CHECK: %[[VAR26:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
     // CHECK: %[[VAR27:.*]] = vector.insert %[[BASE_ADDR3]], %[[VAR26]] [0] : i64 into vector<4xi64>
@@ -60,7 +60,21 @@ gpu.module @create_nd_tdesc {
     // CHECK: %[[VAR30:.*]] = vector.insert %[[SHAPE_H3]], %[[VAR29]] [3] : i32 into vector<8xi32>
     // CHECK: %[[VAR31:.*]] = vector.insert %[[OFFSET_W3]], %[[VAR30]] [4] : i32 into vector<8xi32>
     // CHECK: %[[VAR32:.*]] = vector.insert %[[OFFSET_H3]], %[[VAR31]] [5] : i32 into vector<8xi32>
-    %src_tdesc2 = xegpu.create_nd_tdesc %srcce[%offset1, %offset2] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+    %src_tdesc2 = xegpu.create_nd_tdesc %srcce[%offset1, %offset2] : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+    // CHECK: %[[C8:.*]] = arith.constant 8 : index
+    %c8 = arith.constant 8 : index
+    // CHECK: %[[C16:.*]] = arith.constant 16 : index
+    %c16 = arith.constant 16 : index
+    // CHECK: %[[VAR33:.*]] = arith.index_cast %[[C8]] : index to i32
+    // CHECK: %[[OLD_OFFSET_H:.*]] = vector.extract %[[PAYLOAD]][5] : i32 from vector<8xi32>
+    // CHECK: %[[NEW_OFFSET_H:.*]] = arith.addi %[[OLD_OFFSET_H]], %[[VAR33]] : i32
+    // CHECK: %[[NEW_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_H]], %[[PAYLOAD]] [5] : i32 into vector<8xi32>
+    // CHECK: %[[VAR37:.*]] = arith.index_cast %[[C16]] : index to i32
+    // CHECK: %[[OLD_OFFSET_W:.*]] = vector.extract %[[NEW_PAYLOAD]][4] : i32 from vector<8xi32>
+    // CHECK: %[[NEW_OFFSET_W:.*]] = arith.addi %[[OLD_OFFSET_W]], %[[VAR37]] : i32
+    // CHECK: %[[FINAL_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_W]], %[[NEW_PAYLOAD]] [4] : i32 into vector<8xi32>
+    %updated_tdesc = xegpu.update_nd_offset %src_tdesc, [%c8, %c16] : !xegpu.tensor_desc<8x16xf32>
     gpu.return
   }
 }
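
Note on the expected lowering: the vector<8xi32> payload these CHECK lines decode uses lanes [0:1] for the 64-bit base address (written through a bitcast to vector<4xi64>), lane [2] for the surface width, lane [3] for the surface height, lane [4] for the X (width) offset, and lane [5] for the Y (height) offset. The sketch below restates that construction in isolation; the lane indices are read off the CHECK lines above, while the SSA names are hypothetical and not taken from the lowering pass.

// Illustrative payload construction (names hypothetical; lane indices from the test):
%zero    = arith.constant dense<0> : vector<8xi32>
%as_i64  = vector.bitcast %zero : vector<8xi32> to vector<4xi64>      // view lanes [0:1] as one i64
%w_base  = vector.insert %base_addr, %as_i64 [0] : i64 into vector<4xi64>
%payload = vector.bitcast %w_base : vector<4xi64> to vector<8xi32>
%p0 = vector.insert %shape_w,  %payload [2] : i32 into vector<8xi32>  // surface width
%p1 = vector.insert %shape_h,  %p0 [3] : i32 into vector<8xi32>       // surface height
%p2 = vector.insert %offset_w, %p1 [4] : i32 into vector<8xi32>       // X offset
%p3 = vector.insert %offset_h, %p2 [5] : i32 into vector<8xi32>       // Y offset

xegpu.update_nd_offset then lowers to a vector.extract / arith.addi / vector.insert round trip on lanes [5] and [4], as the final CHECK block verifies.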