@@ -45,6 +45,10 @@ using namespace mlir;
 #define DBGSNL() (llvm::dbgs() << "\n")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
+//===----------------------------------------------------------------------===//
+// Utils
+//===----------------------------------------------------------------------===//
+
 /// Returns a compressed mask for the emulated vector. For example, when
 /// emulating an eight-element `i8` vector with `i32` (i.e. when the source
 /// elements span two dest elements), this method compresses `vector<8xi1>`
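For illustration, a sketch of the compression this helper performs in the i8-in-i32 case described above (four source elements per destination element). The mask values and the any-bit-set reduction are assumptions based on the description, not lifted from this patch:

    // Source mask, one bit per i8 element:
    //   %mask  = [1, 1, 1, 1, 1, 0, 0, 0] : vector<8xi1>
    // Compressed mask, one bit per i32 container; bit i covers source
    // elements 4*i .. 4*i+3 and is set if any of them is set:
    //   %cmask = [1, 1] : vector<2xi1>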
@@ -300,6 +304,7 @@ namespace {
 // ConvertVectorStore
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -370,6 +375,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 // ConvertVectorMaskedStore
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorMaskedStore final
     : OpConversionPattern<vector::MaskedStoreOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -481,6 +487,7 @@ struct ConvertVectorMaskedStore final
 // ConvertVectorLoad
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -536,7 +543,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     // compile time as they must be constants.
 
     auto origElements = op.getVectorType().getNumElements();
-    bool isUnalignedEmulation = origElements % elementsPerContainerType != 0;
+    // Note: per-element alignment was already verified above.
+    bool isFullyAligned = origElements % elementsPerContainerType == 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
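For illustration, a worked instance of the alignment predicate above, assuming i4 elements emulated on i8 containers (elementsPerContainerType == 2):

    // vector<8xi4>: 8 % 2 == 0  -> isFullyAligned == true
    // vector<3xi4>: 3 % 2 == 1  -> isFullyAligned == false (unaligned path)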
@@ -552,9 +560,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
             getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedIntraVectorOffset =
-        isUnalignedEmulation
-            ? getConstantIntValue(linearizedInfo.intraDataOffset)
-            : 0;
+        isFullyAligned ? 0
+                       : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     // Always load enough elements which can cover the original elements.
     int64_t maxintraDataOffset =
@@ -571,7 +578,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
       result = dynamicallyExtractSubVector(
           rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
           linearizedInfo.intraDataOffset, origElements);
-    } else if (isUnalignedEmulation) {
+    } else if (!isFullyAligned) {
       result =
           staticallyExtractSubvector(rewriter, loc, op.getType(), result,
                                      *foldedIntraVectorOffset, origElements);
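For illustration, the rough shape of IR this static-offset branch yields; a sketch assuming a vector<3xi4> load at intra-container offset 1 from an i8 memref (names hypothetical, the exact emitted sequence may differ):

    // %c = vector.load %base[%i] : memref<?xi8>, vector<2xi8>  // 2 containers cover 1 + 3 lanes
    // %v = vector.bitcast %c : vector<2xi8> to vector<4xi4>    // back to i4 lanes
    // %r = vector.extract_strided_slice %v
    //        {offsets = [1], sizes = [3], strides = [1]}
    //        : vector<4xi4> to vector<3xi4>                    // drop the offset lane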
@@ -585,6 +592,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
 // ConvertVectorMaskedLoad
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorMaskedLoad final
     : OpConversionPattern<vector::MaskedLoadOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -749,6 +757,7 @@ struct ConvertVectorMaskedLoad final
 // ConvertVectorTransferRead
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorTransferRead final
     : OpConversionPattern<vector::TransferReadOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -777,7 +786,8 @@ struct ConvertVectorTransferRead final
 
     auto origElements = op.getVectorType().getNumElements();
 
-    bool isUnalignedEmulation = origElements % elementsPerContainerType != 0;
+    // Note: per-element alignment was already verified above.
+    bool isFullyAligned = origElements % elementsPerContainerType == 0;
 
     auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
                                                       adaptor.getPadding());
@@ -796,9 +806,8 @@ struct ConvertVectorTransferRead final
         getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedIntraVectorOffset =
-        isUnalignedEmulation
-            ? getConstantIntValue(linearizedInfo.intraDataOffset)
-            : 0;
+        isFullyAligned ? 0
+                       : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     int64_t maxIntraDataOffset =
         foldedIntraVectorOffset.value_or(elementsPerContainerType - 1);
@@ -822,7 +831,7 @@ struct ConvertVectorTransferRead final
       result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
                                            linearizedInfo.intraDataOffset,
                                            origElements);
-    } else if (isUnalignedEmulation) {
+    } else if (!isFullyAligned) {
       result =
           staticallyExtractSubvector(rewriter, loc, op.getType(), result,
                                      *foldedIntraVectorOffset, origElements);
@@ -1506,33 +1515,34 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
 /// LLVM to scramble with peephole optimizations. Templated to choose between
 /// signed and unsigned conversions.
 ///
-/// For example (signed):
+/// EXAMPLE 1 (signed):
 ///   arith.extsi %in : vector<8xi4> to vector<8xi32>
-/// is rewriten as
-///   %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
-///   %1 = arith.shli %0, 4 : vector<4xi8>
-///   %2 = arith.shrsi %1, 4 : vector<4xi8>
-///   %3 = arith.shrsi %0, 4 : vector<4xi8>
-///   %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
-///   %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
+/// is rewritten as:
+///    %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///    %1 = arith.shli %0, 4 : vector<4xi8>
+///    %2 = arith.shrsi %1, 4 : vector<4xi8>
+///    %3 = arith.shrsi %0, 4 : vector<4xi8>
+///    %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
+///    %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
 ///
+/// EXAMPLE 2 (fp):
 ///   arith.sitofp %in : vector<8xi4> to vector<8xf32>
-/// is rewriten as
-///   %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
-///   %1 = arith.shli %0, 4 : vector<4xi8>
-///   %2 = arith.shrsi %1, 4 : vector<4xi8>
-///   %3 = arith.shrsi %0, 4 : vector<4xi8>
-///   %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
-///   %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
+/// is rewritten as:
+///    %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///    %1 = arith.shli %0, 4 : vector<4xi8>
+///    %2 = arith.shrsi %1, 4 : vector<4xi8>
+///    %3 = arith.shrsi %0, 4 : vector<4xi8>
+///    %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
+///    %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
 ///
-/// Example (unsigned):
+/// EXAMPLE 3 (unsigned):
 ///   arith.extui %in : vector<8xi4> to vector<8xi32>
-/// is rewritten as
-///   %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
-///   %1 = arith.andi %0, 15 : vector<4xi8>
-///   %2 = arith.shrui %0, 4 : vector<4xi8>
-///   %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
-///   %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
+/// is rewritten as:
+///    %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///    %1 = arith.andi %0, 15 : vector<4xi8>
+///    %2 = arith.shrui %0, 4 : vector<4xi8>
+///    %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
+///    %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
 ///
 template <typename ConversionOpType, bool isSigned>
 struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
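Why the shli/shrsi pair in EXAMPLE 1 sign-extends correctly: shifting a nibble into the high half of the byte moves its sign bit into the byte's sign bit, so the arithmetic shift back replicates it. A worked byte (values illustrative):

    // Packed byte 0xA7 holds two i4 values: low = 0x7 (+7), high = 0xA (-6).
    //   shli  0xA7, 4 -> 0x70;  shrsi 0x70, 4 -> 0x07   (low nibble:  +7)
    //   shrsi 0xA7, 4 -> 0xFA                           (high nibble: -6)
    // vector.interleave then restores the original element order.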
@@ -1542,8 +1552,8 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
                                 PatternRewriter &rewriter) const override {
     // Verify the preconditions.
     Value srcValue = conversionOp.getIn();
-    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
-    auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
+    VectorType srcVecType = dyn_cast<VectorType>(srcValue.getType());
+    VectorType dstVecType = dyn_cast<VectorType>(conversionOp.getType());
 
     if (failed(
             commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
@@ -1583,15 +1593,16 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
 ///
 /// For example:
 ///   arith.trunci %in : vector<8xi32> to vector<8xi4>
-/// is rewriten as
 ///
-///   %cst = arith.constant dense<15> : vector<4xi8>
-///   %cst_0 = arith.constant dense<4> : vector<4xi8>
-///   %0, %1 = vector.deinterleave %in : vector<8xi8>, vector<8xi8>
-///   %2 = arith.andi %0, %cst : vector<4xi8>
-///   %3 = arith.shli %1, %cst_0 : vector<4xi8>
-///   %4 = arith.ori %2, %3 : vector<4xi8>
-///   %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
+/// is rewritten as:
+///
+///    %cst = arith.constant dense<15> : vector<4xi8>
+///    %cst_0 = arith.constant dense<4> : vector<4xi8>
+///    %0, %1 = vector.deinterleave %in : vector<8xi8>, vector<8xi8>
+///    %2 = arith.andi %0, %cst : vector<4xi8>
+///    %3 = arith.shli %1, %cst_0 : vector<4xi8>
+///    %4 = arith.ori %2, %3 : vector<4xi8>
+///    %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
 ///
 struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
   using OpRewritePattern<arith::TruncIOp>::OpRewritePattern;
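This is the inverse packing of the extension patterns above: the andi keeps the low nibble of the even elements, the shli positions the odd elements in the high nibble, and the ori merges them. A worked byte, mirroring the example after the previous hunk (values illustrative):

    // Pack +7 (0x07) and -6 (0xFA, already truncated to a byte) into one byte:
    //   andi 0x07, 15   -> 0x07   (keep the low nibble)
    //   shli 0xFA, 4    -> 0xA0   (move into the high nibble; overflow discarded)
    //   ori  0x07, 0xA0 -> 0xA7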
@@ -1635,10 +1646,11 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
 
 /// Rewrite a sub-byte vector transpose into a sequence of instructions that
 /// perform the transpose on wider (byte) element types.
-/// For example:
+///
+/// EXAMPLE:
 ///   %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
 ///
-/// is rewritten as:
+///   is rewritten as:
 ///
 ///   %0 = arith.extsi %arg0 : vector<8x16xi4> to vector<8x16xi8>
 ///   %1 = vector.transpose %0, [1, 0] : vector<8x16xi8> to vector<16x8xi8>
@@ -1686,6 +1698,7 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
 // Public Interface Definition
 //===----------------------------------------------------------------------===//
 
+// The emulated type is inferred from the converted memref type.
 void vector::populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
     RewritePatternSet &patterns) {
@@ -1698,22 +1711,26 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
 
 void vector::populateVectorNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
+  // TODO: Document what the emulated type is.
   patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
                RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
                                                     benefit);
 
   // Patterns for aligned cases. We set higher priority as they are expected to
   // generate better performance for aligned cases.
+  // The emulated type is always i8.
   patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
                RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
                RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
                                               benefit.getBenefit() + 1);
+  // The emulated type is always i8.
   patterns
       .add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>,
            RewriteAlignedSubByteIntExt<arith::UIToFPOp, /*isSigned=*/false>>(
           patterns.getContext(), benefit.getBenefit() + 1);
 }
 
+// The emulated type is always i8.
 void vector::populateVectorTransposeNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
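For context, a minimal sketch of how these entry points are typically wired into a greedy rewrite driver. The wrapper function and its name are assumptions for illustration; only the populate calls come from this file:

    #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    // Hypothetical helper: run the narrow-type rewrites over `root`.
    static void runNarrowTypeRewrites(mlir::Operation *root) {
      mlir::RewritePatternSet patterns(root->getContext());
      // Sub-byte ext/trunc rewrites (emulated type is always i8).
      mlir::vector::populateVectorNarrowTypeRewritePatterns(patterns);
      // Sub-byte transpose rewrite (emulated type is always i8).
      mlir::vector::populateVectorTransposeNarrowTypeRewritePatterns(patterns);
      (void)mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
    }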