@@ -19,8 +19,9 @@ gpu.func @load_1D_vector(%source: memref<8x16x32xf32>,
 // CHECK-COUNT2: arith.addi {{.*}} : index
 // CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
 // CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
-// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
-// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : memref<4096xf32>, vector<8xindex>, vector<8xi1> -> vector<8xf32>
+// CHECK: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
+// CHECK: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf32>
 // CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8xi1>, vector<8xf32>
 // CHECK: gpu.return %[[RES]] : vector<8xf32>
 }
@@ -45,8 +46,9 @@ gpu.func @load_2D_memref(%source: memref<8x32xf32>,
 // CHECK-COUNT1: arith.addi {{.*}} : index
 // CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
 // CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
-// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1]{{\]}} : memref<8x32xf32> into memref<256xf32>
-// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : memref<256xf32>, vector<8xindex>, vector<8xi1> -> vector<8xf32>
+// CHECK: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x32xf32> -> index
+// CHECK: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf32>
 // CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8xi1>, vector<8xf32>
 // CHECK: gpu.return %[[RES]] : vector<8xf32>
 }
@@ -71,8 +73,9 @@ gpu.func @load_2D_vector(%source: memref<8x16x32xf32>,
 // CHECK-COUNT2: arith.addi {{.*}} : index
 // CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
 // CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8x16xindex>
-// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
-// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : memref<4096xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// CHECK: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
+// CHECK: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
 // CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8x16xi1>, vector<8x16xf32>
 // CHECK: gpu.return %[[RES]] : vector<8x16xf32>
 }
@@ -98,8 +101,9 @@ gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
 // CHECK-COUNT2: arith.addi {{.*}} : index
 // CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
 // CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8x16xindex>
-// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<?x?x?xf32> into memref<?xf32>
-// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// CHECK: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<?x?x?xf32> -> index
+// CHECK: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
 // CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8x16xi1>, vector<8x16xf32>
 // CHECK: gpu.return %[[RES]] : vector<8x16xf32>
 }
@@ -125,8 +129,9 @@ gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
 // CHECK-COUNT2: arith.addi {{.*}} : index
 // CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
 // CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8x16xindex>
-// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<?x8x16xf32> into memref<?xf32>
-// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// CHECK: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<?x8x16xf32> -> index
+// CHECK: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
 // CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8x16xi1>, vector<8x16xf32>
 // CHECK: gpu.return %[[RES]] : vector<8x16xf32>
 }
@@ -146,42 +151,37 @@ gpu.func @no_load_tensor(%source: tensor<32x64xf32>,
 
 // -----
 gpu.module @xevm_module {
-gpu.func @no_load_non_unit_inner_stride(
-    %source: memref<32xf32, strided<[?], offset: ?>>,
-    %off: index, %indices: vector<8xindex>, %mask: vector<8xi1>,
-    %pass_thru: vector<8xf32>) -> vector<8xf32> {
-  %0 = vector.gather %source[%off][%indices], %mask, %pass_thru
-    : memref<32xf32, strided<[?], offset: ?>>, vector<8xindex>, vector<8xi1>, vector<8xf32> into vector<8xf32>
-  gpu.return %0 : vector<8xf32>
-}
-// CHECK-LABEL: @no_load_non_unit_inner_stride(
-// CHECK: vector.gather
+gpu.func @gather_from_subview(%source: memref<4096x4096xf16>,
+    %off1: index, %off2: index,
+    %indices: vector<8xindex>,
+    %mask: vector<8xi1>,
+    %pass_thru: vector<8xf16>) -> vector<8xf16> {
+  %subview = memref.subview %source[%off1, %off2] [256, 256] [1, 1]
+    : memref<4096x4096xf16>
+      to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+  %0 = vector.gather %subview[%off1, %off2][%indices], %mask, %pass_thru
+    : memref<256x256xf16, strided<[4096, 1], offset: ?>>,
+      vector<8xindex>, vector<8xi1>, vector<8xf16>
+      into vector<8xf16>
+  gpu.return %0 : vector<8xf16>
+}
+// CHECK-LABEL: @gather_from_subview(
+// CHECK-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
+// CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index,
+// CHECK-SAME: %[[INDICES:.+]]: vector<8xindex>,
+// CHECK-SAME: %[[MASK:.+]]: vector<8xi1>,
+// CHECK-SAME: %[[PASS:.+]]: vector<8xf16>) -> vector<8xf16> {
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1]
+// CHECK: %[[BB:.+]], %[[OFFSET:.+]],{{.*}},{{.*}} = memref.extract_strided_metadata %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref<f16>, index, index, index, index, index
+// CHECK: arith.muli {{.*}} : index
+// CHECK: arith.addi %[[OFFSET]]{{.*}} : index
+// CHECK: %[[BASE_OFF:.+]] = arith.addi {{.*}} : index
+// CHECK: %[[SPLAT:.+]] = vector.broadcast %[[BASE_OFF]] : index to vector<8xindex>
+// CHECK: %[[LIN:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
+// CHECK: %[[BASE_IDX:.+]] = memref.extract_aligned_pointer_as_index %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index
+// CHECK: %[[BASE_I64:.+]] = arith.index_cast %[[BASE_IDX]] : index to i64
+// CHECK: %[[VEC:.+]] = xegpu.load %[[BASE_I64]]{{\[}}%[[LIN]]{{\]}}, %[[MASK]]
+// CHECK-SAME: : i64, vector<8xindex>, vector<8xi1> -> vector<8xf16>
+// CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS]] : vector<8xi1>, vector<8xf16>
+// CHECK: gpu.return %[[RES]] : vector<8xf16>
 }
-
-// -----
-gpu.module @xevm_module {
-gpu.func @load_1D_aligned(%source: memref<8x16x32xf32>,
-    %off1: index, %off2: index, %off3: index,
-    %indices: vector<8xindex>, %mask: vector<8xi1>,
-    %pass_thru: vector<8xf32>) -> vector<8xf32> {
-  %0 = vector.gather %source[%off1, %off2, %off3][%indices], %mask,
-    %pass_thru {alignment = 256} : memref<8x16x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32> into vector<8xf32>
-  gpu.return %0 : vector<8xf32>
-}
-// CHECK-LABEL: @load_1D_aligned(
-// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
-// CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
-// CHECK-SAME: %[[INDICES:.+]]: vector<8xindex>
-// CHECK-SAME: %[[MASK:.+]]: vector<8xi1>
-// CHECK-SAME: %[[PASS_THRU:.+]]: vector<8xf32>) -> vector<8xf32> {
-// CHECK-COUNT2: arith.muli {{.*}} : index
-// CHECK-COUNT2: arith.addi {{.*}} : index
-// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
-// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
-// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
-// CHECK: %[[COLLAPSE_ALIGN:.+]] = memref.assume_alignment %[[COLLAPSE]], 256 : memref<4096xf32>
-// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_ALIGN]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : memref<4096xf32>, vector<8xindex>, vector<8xi1> -> vector<8xf32>
-// CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8xi1>, vector<8xf32>
-// CHECK: gpu.return %[[RES]] : vector<8xf32>
-}
-