@@ -232,6 +232,14 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
232232 return static_cast<unsigned>(MemorySpace::Global);
233233 }
234234
235+ xegpu::DistributeLayoutAttr getLayoutAttr() {
236+ return dyn_cast_if_present<xegpu::DistributeLayoutAttr>(getType().getLayout());
237+ }
238+
239+ ArrayRef<int64_t> getDataShape() {
240+ return getTensorDescShape();
241+ }
242+
235243 }];
236244}
237245
@@ -262,6 +270,23 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
262270 xegpu::TensorDescType getTensorDescType() {
263271 return getTensorDesc().getType();
264272 }
273+
274+ SmallVector<OpFoldResult> getMixedOffsets() {
275+ auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
276+ auto dynamics = getOffsets();
277+ if (statics.size() == 0 && dynamics.size() == 0)
278+ return {};
279+ return getMixedValues(statics, dynamics, getContext());
280+ }
281+
282+ xegpu::DistributeLayoutAttr getLayoutAttr() {
283+ return dyn_cast_if_present<xegpu::DistributeLayoutAttr>(getTensorDescType().getLayout());
284+ }
285+
286+ ArrayRef<int64_t> getDataShape() {
287+ return getTensorDescType().getShape();
288+ }
289+
265290 }];
266291
267292 let assemblyFormat = [{
@@ -343,6 +368,24 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
343368 xegpu::TensorDescType getTensorDescType() {
344369 return getTensorDesc().getType();
345370 }
371+
372+ SmallVector<OpFoldResult> getMixedOffsets() {
373+ auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
374+ auto dynamics = getOffsets();
375+ if (statics.size() == 0 && dynamics.size() == 0)
376+ return {};
377+ return getMixedValues(statics, dynamics, getContext());
378+ }
379+
380+ xegpu::DistributeLayoutAttr getLayoutAttr() {
381+ return dyn_cast_if_present<xegpu::DistributeLayoutAttr>(getTensorDescType().getLayout());
382+ }
383+
384+ ArrayRef<int64_t> getDataShape() {
385+ return getTensorDescType().getShape();
386+ }
387+
388+
346389 }];
347390
348391 let assemblyFormat = [{
@@ -417,6 +460,23 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
417460 xegpu::TensorDescType getTensorDescType() {
418461 return getTensorDesc().getType();
419462 }
463+
464+ SmallVector<OpFoldResult> getMixedOffsets() {
465+ auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
466+ auto dynamics = getOffsets();
467+ if (statics.size() == 0 && dynamics.size() == 0)
468+ return {};
469+ return getMixedValues(statics, dynamics, getContext());
470+ }
471+
472+ xegpu::DistributeLayoutAttr getLayoutAttr() {
473+ return dyn_cast_if_present<xegpu::DistributeLayoutAttr>(getTensorDescType().getLayout());
474+ }
475+
476+ ArrayRef<int64_t> getDataShape() {
477+ return getTensorDescType().getShape();
478+ }
479+
420480 }];
421481
422482 let assemblyFormat = [{
@@ -640,6 +700,7 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
640700 xegpu::TensorDescType getTensorDescType() {
641701 return dyn_cast<xegpu::TensorDescType>(getSourceType());
642702 }
703+
643704 }];
644705
645706 let assemblyFormat = [{
@@ -1150,7 +1211,7 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
11501211 let arguments = (ins XeGPU_MemDesc:$mem_desc,
11511212 Variadic<Index>: $offsets,
11521213 DenseI64ArrayAttr: $const_offsets,
1153- OptionalAttr<LayoutTrait >:$layout
1214+ OptionalAttr<DistributeLayoutAttr >:$layout
11541215 );
11551216 let results = (outs XeGPU_ValueType:$res);
11561217 let assemblyFormat = [{
@@ -1175,12 +1236,16 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
11751236
11761237 let builders = [
11771238 OpBuilder<(ins "Type":$res, "TypedValue<MemDescType>": $mem_desc,
1178- "llvm::ArrayRef<OpFoldResult>": $offsets, "LayoutTrait ": $layout)>,
1239+ "llvm::ArrayRef<OpFoldResult>": $offsets, "DistributeLayoutAttr ": $layout)>,
11791240 ];
11801241 let extraClassDeclaration = [{
11811242 SmallVector<OpFoldResult> getMixedOffsets() {
11821243 return getMixedValues(getConstOffsets(), getOffsets(), getContext());
11831244 }
1245+
1246+ ArrayRef<int64_t> getDataShape() {
1247+ return getRes().getType().getShape();
1248+ }
11841249 }];
11851250
11861251 let hasVerifier = 1;
@@ -1194,7 +1259,7 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
11941259 XeGPU_MemDesc:$mem_desc,
11951260 Variadic<Index>: $offsets,
11961261 DenseI64ArrayAttr: $const_offsets,
1197- OptionalAttr<LayoutTrait >:$layout
1262+ OptionalAttr<DistributeLayoutAttr >:$layout
11981263 );
11991264 let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
12001265 prop-dict attr-dict `` `:` type(operands)}];
@@ -1213,12 +1278,17 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
12131278 }];
12141279 let builders = [
12151280 OpBuilder<(ins "Value" : $data, "TypedValue<MemDescType>": $mem_desc,
1216- "llvm::ArrayRef<OpFoldResult>": $offsets, "LayoutTrait ": $layout)>,
1281+ "llvm::ArrayRef<OpFoldResult>": $offsets, "DistributeLayoutAttr ": $layout)>,
12171282 ];
12181283 let extraClassDeclaration = [{
12191284 SmallVector<OpFoldResult> getMixedOffsets() {
12201285 return getMixedValues(getConstOffsets(), getOffsets(), getContext());
12211286 }
1287+
1288+ ArrayRef<int64_t> getDataShape() {
1289+ return getData().getType().getShape();
1290+ }
1291+
12221292 }];
12231293
12241294 let hasVerifier = 1;
0 commit comments