Skip to content

Commit 36fbc6a

Browse files
authored
[MLIR][XeGPU] Remove the transpose attribute from Gather/Scatter ops and Cleanup the documents (#145389)
1 parent 0faa181 commit 36fbc6a

File tree

11 files changed

+207
-289
lines changed

11 files changed

+207
-289
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 36 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,6 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
8080
information e.g., memref<?x?xf16>, the strides information has to be explicitly
8181
passed via the "strides" and "const_strides" argument.
8282

83-
In SIMT mode, tensor descriptor is augmented with `LayoutAttr` which describes the
84-
mapping of the tensor descriptor to the work items.
85-
8683
Example 1 (suppose the tensor shape inferred by the compiler is 8x16):
8784
```mlir
8885
%0 = memref.alloc() : memref<1024x1024xf32>
@@ -106,15 +103,6 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
106103
%c1 = arith.constant 1 : index
107104
%1 = xegpu.create_nd_tdesc %0[%c0, %c0], [%h, %w], [%w, %c1]: ui64 -> TensorDesc<8x16xf32>
108105
```
109-
110-
Example 4 (SIMT mode):
111-
```mlir
112-
%0 = memref.alloc() : memref<1024x1024xf32>
113-
%c0 = arith.constant 0 : index
114-
%c1 = arith.constant 8 : index
115-
%1 = xegpu.create_nd_tdesc %0[%c0, %c0] : memref<1024x1024xf32>
116-
-> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
117-
```
118106
}];
119107

120108
let arguments = (ins
@@ -301,9 +289,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
301289
fp32 or fp64. It implies that vnni and transpose cannot exist at the
302290
same time.
303291

304-
In SIMT mode, LoadNdOp expects the tensor descriptor to be augmented with `LayoutAttr`
305-
which describes the mapping of the tensor to the work items. In this case, result
306-
vector represents the data to be loaded by each work-item.
292+
In SIMT mode, result vector represents the data to be loaded by each work-item.
307293

308294
Example 1:
309295
```mlir
@@ -317,8 +303,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
317303
```mlir
318304
xegpu.load_nd %1 {l1_hint = #xegpu.cache_hint<cached>,
319305
l2_hint = #xegpu.cache_hint<uncached>}>
320-
: !xegpu.tensor_desc<8x16xf32,
321-
#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x1xf32>
306+
: !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
322307
```
323308

324309

@@ -359,9 +344,7 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
359344
of cache, L1, L2 and L3. If hardware does not have a corresponding cache,
360345
Corresponding cache hint attribute will be masked.
361346

362-
In SIMT mode, StoreNdOp expects the tensor descriptor to be augmented with `LayoutAttr`
363-
which describes the mapping of the tensor to the work items. In this case, input
364-
vector represents the data to be stored by each work-item.
347+
In SIMT mode, the input vector represents the data to be stored by each work-item.
365348

366349
Example 1:
367350
```mlir
@@ -375,8 +358,7 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
375358
xegpu.store_nd %3, %2 {l1_hint = #xegpu.cache_hint<uncached>,
376359
l2_hint = #xegpu.cache_hint<write_back>,
377360
l3_hint = #xegpu.cache_hint<write_through>}
378-
: vector<8x1xf16>, !xegpu.tensor_desc<8x16xf16,
379-
#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
361+
: vector<8xf16>, !xegpu.tensor_desc<8x16xf16>
380362
```
381363

382364

@@ -410,15 +392,10 @@ def XeGPU_UpdateNdOffsetOp : XeGPU_Op<"update_nd_offset",
410392
The offsets are relative offset to the current position in the number
411393
of elements. It will result in a same type TensorDesc as the input.
412394

413-
Example 1:
395+
Example:
414396
```
415397
%2 = xegpu.update_nd_offset %1, [0, 16]: !xegpu.tensor_desc<8x16xf32>
416398
```
417-
Example 2 (SIMT mode):
418-
```
419-
%2 = xegpu.update_nd_offset %1, [0, 16]:
420-
!xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
421-
```
422399
}];
423400

424401
let arguments = (ins
@@ -476,11 +453,6 @@ def XeGPU_CreateDescOp: XeGPU_Op<"create_tdesc", [Pure, ViewLikeOpInterface]> {
476453
match the dimension of offsets. It may also has a second dimension corresponding to
477454
the chunk_size if the chunk size is larger than 1.
478455

479-
In SIMT mode, similar to `create_nd_tdesc` the resulting tensor descriptor is augmented
480-
with `LayoutAttr` which describes the mapping of the tensor descriptor to the work items.
481-
In this case, the first dimension of the tensor descriptor represents the work-items, and
482-
the second dimension represents the chunk size.
483-
484456
Example 1: It assumes subgroup size is 4, and accesses a[0], a[16], a[32], a[64]
485457
```mlir
486458
%a = memref.alloc() : memref<1024xf32>
@@ -505,15 +477,6 @@ def XeGPU_CreateDescOp: XeGPU_Op<"create_tdesc", [Pure, ViewLikeOpInterface]> {
505477
%1 = xegpu.create_tdesc %0, %off : memref<1024xf32>, vector<4xindex>
506478
-> TensorDesc<4x8xf32, #xegpu.scattered_tdesc_attr<chunk_size = 8>>
507479
```
508-
509-
Example 4: SIMT mode
510-
```mlir
511-
%0 = memref.alloc() : memref<1024xf32>
512-
%off = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
513-
%1 = xegpu.create_tdesc %0, %off : memref<1024xf32>, vector<4xindex>
514-
-> TensorDesc<4x8xf32, #xegpu.scattered_tdesc_attr<chunk_size = 8>,
515-
#xegpu.layout<lane_layout = [4, 1], lane_data = [1, 1]>>
516-
```
517480
}];
518481

519482
let arguments = (ins XeGPU_BaseAddrType: $source,
@@ -609,54 +572,44 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
609572
let description = [{ It (aka. load) load data per each work-item. The output
610573
describes the data being loaded at the subgroup level, so its size is
611574
consistent with the number of work-items in a subgroup. When the chunk size
612-
is larger than 2, the output vector is a 2D vector, with dim-1 correspoding
613-
to work-items, and dim-0 corresponding to the chunk size loaded by each work-item.
614-
Specially, there is a transpose effect on the result (as compared to the TensorDesc)
615-
due to the hardware implementation. Therefore, a transpose attribute is introduced
616-
on purpose, making sure users are aware of this implicit transformation.
617-
575+
is larger than 2, the output vector is a 2D vector, with dim-0 corresponding
576+
to work-items, and dim-1 corresponding to the chunk size loaded by each work-item.
618577
The mask operand masks out memory access so that it is safe to pass out-of-boundary
619578
addresses/offsets as long as they are masked. It applies to slots of SIMD lanes.
620579

621-
In SIMT mode, LoadGatherOp expects the tensor descriptor to be augmented with `LayoutAttr`
622-
which describes the mapping of the tensor to the work items. In this case, result vector
623-
represents the data to be loaded by each work-item. Each work-item recieves a `chunk_size`
624-
number of elements.
580+
In SIMT mode, the result vector represents the data to be loaded by each work-item.
581+
Each work-item receives a `chunk_size` number of elements.
625582

626583
Example 1:
627584
```mlir
628-
%2 = xegpu.load %1, %0 {l1_hint = #xegpu.cache_hint<cached>,
629-
l2_hint = #xegpu.cache_hint<uncached>,
630-
l3_hint = #xegpu.cache_hint<uncached>}
585+
%2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>,
586+
l2_hint = #xegpu.cache_hint<uncached>,
587+
l3_hint = #xegpu.cache_hint<uncached>}>
631588
: !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space=global>>,
632589
vector<16xi1> -> vector<16xf32>
633590
```
634591

635592
Example 2:
636593
```mlir
637-
%2 = xegpu.load %1, %0 {transpose,
638-
l1_hint = #xegpu.cache_hint<cached>,
639-
l2_hint = #xegpu.cache_hint<uncached>,
640-
l3_hint = #xegpu.cache_hint<uncached>}
594+
%2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>,
595+
l2_hint = #xegpu.cache_hint<uncached>,
596+
l3_hint = #xegpu.cache_hint<uncached>}>
641597
: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>,
642-
vector<16xi1> -> vector<8x16xf32>
598+
vector<16xi1> -> vector<16x8xf32>
643599
```
644600
Example 3 (SIMT mode):
645601
```mlir
646-
%2 = xegpu.load %1, %0 {transpose,
647-
l1_hint = #xegpu.cache_hint<cached>,
648-
l2_hint = #xegpu.cache_hint<uncached>,
649-
l3_hint = #xegpu.cache_hint<uncached>}
650-
: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>,
651-
!xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
652-
vector<16xi1> -> vector<8x1xf32>
602+
%2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>,
603+
l2_hint = #xegpu.cache_hint<uncached>,
604+
l3_hint = #xegpu.cache_hint<uncached>}>
605+
: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>
606+
vector<16xi1> -> vector<8xf32>
653607
```
654608

655609
}];
656610

657611
let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
658612
XeGPU_MaskType: $mask,
659-
OptionalAttr<UnitAttr>: $transpose,
660613
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
661614
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
662615
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
@@ -699,44 +652,38 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
699652
has transpose effect, which is similar to `load_gather`. Therefore, a transpose attribute is
700653
introduced on purpose, making sure users are aware of this implicit transformation.
701654

702-
In SIMT mode, StoreScatterOp expects the tensor descriptor to be augmented with `LayoutAttr`
703-
which describes the mapping of the tensor to the work items. In this case, input vector
704-
represents the data to be stored by each work-item. Each work-item recieves a `chunk_size`
705-
number of elements.
655+
In SIMT mode, the input vector represents the data to be stored by each work-item.
656+
Each work-item stores a `chunk_size` number of elements.
706657

707658
Example 1:
708659
```mlir
709-
xegpu.store %0, %1, %2 {l1_hint = #xegpu.cache_hint<uncached>,
710-
l2_hint = #xegpu.cache_hint<write_back>,
711-
l3_hint = #xegpu.cache_hint<write_through>}
660+
xegpu.store %0, %1, %2 <{l1_hint = #xegpu.cache_hint<uncached>,
661+
l2_hint = #xegpu.cache_hint<write_back>,
662+
l3_hint = #xegpu.cache_hint<write_through>}>
712663
: vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scattered_tdesc_attr<>>, vector<16xi1>
713664
```
714665

715666
Example 2:
716667
```mlir
717-
xegpu.store %0, %1, %2 {transpose,
718-
l1_hint = #xegpu.cache_hint<uncached>,
719-
l2_hint = #xegpu.cache_hint<write_back>,
720-
l3_hint = #xegpu.cache_hint<write_through>}
721-
: vector<8x16xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>>, vector<16xi1>
668+
xegpu.store %0, %1, %2 <{l1_hint = #xegpu.cache_hint<uncached>,
669+
l2_hint = #xegpu.cache_hint<write_back>,
670+
l3_hint = #xegpu.cache_hint<write_through>}>
671+
: vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>>, vector<16xi1>
722672
```
673+
723674
Example 3 (SIMT mode):
724675
```mlir
725-
xegpu.store %0, %1, %2 {transpose,
726-
l1_hint = #xegpu.cache_hint<uncached>,
727-
l2_hint = #xegpu.cache_hint<write_back>,
728-
l3_hint = #xegpu.cache_hint<write_through>}
729-
: vector<8x1xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>,
730-
!xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> vector<16xi1>
676+
xegpu.store %0, %1, %2 <{l1_hint = #xegpu.cache_hint<uncached>,
677+
l2_hint = #xegpu.cache_hint<write_back>,
678+
l3_hint = #xegpu.cache_hint<write_through>}>
679+
: vector<8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>> vector<16xi1>
731680
```
732-
733681
}];
734682

735683
let arguments = (ins
736684
XeGPU_ValueType: $value,
737685
XeGPU_TensorDesc: $TensorDesc,
738686
XeGPU_MaskType: $mask,
739-
OptionalAttr<UnitAttr>: $transpose,
740687
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
741688
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
742689
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
@@ -773,20 +720,13 @@ def XeGPU_UpdateOffsetOp: XeGPU_Op<"update_offset",
773720
update the offset per work-item, so its offsets contains values representing
774721
shifts for each work-item.
775722

776-
Example 1:
723+
Example:
777724
```mlir
778725
%off = arith.constant dense<[32, 32, 32, 32]> : vector<4xindex>
779726
%2 = xegpu.update_offset %1, %off :
780727
!xegpu.tensor_desc<4x2xf32, #xegpu.scattered_tdesc_attr<chunk_size=2>>, vector<4xindex>
781728
```
782729

783-
Example 2 (SIMT mode):
784-
```mlir
785-
%off = arith.constant dense<[32, 32, 32, 32]> : vector<4xindex>
786-
%2 = xegpu.update_offset %1, %off :
787-
!xegpu.tensor_desc<4x2xf32, #xegpu.scattered_tdesc_attr<chunk_size=2>,
788-
#xegpu.layout<lane_layout = [4, 1], lane_data = [1, 1]>>, vector<4xindex>
789-
```
790730
}];
791731

792732
let arguments = (ins XeGPU_TensorDesc: $TensorDesc,

mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ constexpr unsigned packedSizeInBitsForDefault =
3535
16; // Minimum packing size per register for DPAS A.
3636
constexpr unsigned packedSizeInBitsForDpasB =
3737
32; // Minimum packing size per register for DPAS B.
38+
constexpr unsigned packedSizeInBitsForGatherScatter =
39+
32; // Minimum packing size per register for Gather and Scatter ops.
3840
} // namespace targetinfo
3941
} // namespace xegpu
4042

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "mlir/Dialect/Utils/IndexingUtils.h"
1010
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
11+
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
1112
#include "mlir/IR/Builders.h"
1213
#include "mlir/IR/DialectImplementation.h"
1314
#include "llvm/ADT/TypeSwitch.h"
@@ -309,11 +310,23 @@ LogicalResult TensorDescType::verify(
309310
llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
310311
mlir::Attribute encoding, mlir::Attribute layout) {
311312
size_t rank = shape.size();
312-
// Low-precision types are packed in 32-bit units.
313-
int32_t packingFactor = 32 / elementType.getIntOrFloatBitWidth();
314313
if (rank != 1 && rank != 2)
315314
return emitError() << "expected 1D or 2D tensor";
316315

316+
auto blockAttr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
317+
if (blockAttr) {
318+
MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
319+
if (rank == 2 && memorySpaceAttr &&
320+
memorySpaceAttr.getValue() == MemorySpace::SLM)
321+
return emitError() << "SLM is not supported for 2D block tensor";
322+
}
323+
324+
// for gather and scatter ops, Low-precision types are packed in 32-bit units.
325+
unsigned bitWidth = elementType.getIntOrFloatBitWidth();
326+
int chunkAlignmentFactor =
327+
bitWidth < targetinfo::packedSizeInBitsForGatherScatter
328+
? targetinfo::packedSizeInBitsForGatherScatter / bitWidth
329+
: 1;
317330
auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
318331
if (scatterAttr) {
319332
// Expected tensor ranks for scattered data:
@@ -329,21 +342,13 @@ LogicalResult TensorDescType::verify(
329342
if (chunkSize > 1) {
330343
if (shape.back() != chunkSize)
331344
return emitError() << "expected tensor shape[1] to match chunk size";
332-
if (shape.back() % packingFactor != 0)
333-
return emitError()
334-
<< "expected tensor shape[1] to be a multiple of packing factor "
335-
<< packingFactor;
345+
if (shape.back() % chunkAlignmentFactor != 0)
346+
return emitError() << "expected tensor shape[1] to be a multiple of "
347+
"chunk alignment factor "
348+
<< chunkAlignmentFactor;
336349
}
337350
}
338351

339-
auto blockAttr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
340-
if (blockAttr) {
341-
MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
342-
if (rank == 2 && memorySpaceAttr &&
343-
memorySpaceAttr.getValue() == MemorySpace::SLM)
344-
return emitError() << "SLM is not supported for 2D block tensor";
345-
}
346-
347352
auto layoutAttr = llvm::dyn_cast_if_present<LayoutAttr>(layout);
348353
if (layoutAttr) {
349354
if (rank != (size_t)layoutAttr.getRank())
@@ -360,7 +365,7 @@ LogicalResult TensorDescType::verify(
360365
if (rank > 1 && laneData[0] != 1)
361366
return emitError()
362367
<< "cannot map over non-contiguous scattered row elements";
363-
if (laneData[rank - 1] != packingFactor)
368+
if (laneData[rank - 1] != chunkAlignmentFactor)
364369
return emitError() << "work item data mapping must match the number of "
365370
"contiguous elements";
366371
}

0 commit comments

Comments
 (0)