@@ -48,6 +48,10 @@ using namespace mlir;
 using VectorValue = TypedValue<VectorType>;
 using MemRefValue = TypedValue<MemRefType>;
 
+//===----------------------------------------------------------------------===//
+// Utils
+//===----------------------------------------------------------------------===//
+
 /// Returns a compressed mask for the emulated vector. For example, when
 /// emulating an eight-element `i8` vector with `i32` (i.e. when the source
 /// elements span two dest elements), this method compresses `vector<8xi1>`
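To make the compression concrete, here is a hedged sketch of the behaviour this utility's doc-comment describes (illustrative values only; the helper's signature is not shown in this hunk):

```mlir
// Hypothetical input mask: 5 of 8 i8 elements are active.
%c5 = arith.constant 5 : index
%mask = vector.create_mask %c5 : vector<8xi1>

// With an i32 container, 4 source elements share one container element,
// so the compressed mask enables ceil(5 / 4) = 2 of the 2 containers.
%c2 = arith.constant 2 : index
%compressed = vector.create_mask %c2 : vector<2xi1>
```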
@@ -407,6 +411,7 @@ namespace {
 // ConvertVectorStore
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -632,6 +637,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 // ConvertVectorMaskedStore
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorMaskedStore final
     : OpConversionPattern<vector::MaskedStoreOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -745,6 +751,7 @@ struct ConvertVectorMaskedStore final
 // ConvertVectorLoad
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -802,7 +809,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     // compile time as they must be constants.
 
     auto origElements = op.getVectorType().getNumElements();
-    bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0;
+    // Note, per-element-alignment was already verified above.
+    bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
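To see when the renamed `isFullyAligned` flag holds: with an i8 container and i2 elements, `emulatedPerContainerElem` is 4, so (illustrative types, not taken from this patch):

```mlir
// Fully aligned: 8 % 4 == 0; no sub-vector extraction is needed.
%a = vector.load %src[%idx] : memref<?xi2>, vector<8xi2>

// Not fully aligned: 3 % 4 != 0; the emulation over-reads whole i8
// containers and then extracts the 3 requested elements.
%b = vector.load %src[%idx] : memref<?xi2>, vector<3xi2>
```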
@@ -818,9 +826,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
             getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedIntraVectorOffset =
-        isAlignedEmulation
-            ? 0
-            : getConstantIntValue(linearizedInfo.intraDataOffset);
+        isFullyAligned ? 0
+                       : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     // Always load enough elements which can cover the original elements.
     int64_t maxintraDataOffset =
@@ -834,10 +841,10 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     if (!foldedIntraVectorOffset) {
       auto resultVector = rewriter.create<arith::ConstantOp>(
           loc, op.getType(), rewriter.getZeroAttr(op.getType()));
-      result = dynamicallyExtractSubVector(rewriter, loc, result, resultVector,
-                                           linearizedInfo.intraDataOffset,
-                                           origElements);
-    } else if (!isAlignedEmulation) {
+      result = dynamicallyExtractSubVector(
+          rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
+          linearizedInfo.intraDataOffset, origElements);
+    } else if (!isFullyAligned) {
       result = staticallyExtractSubvector(
           rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
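The two branches above differ only in how the requested elements are carved out of the over-read vector. When the offset is known at compile time, a single strided-slice extraction suffices; a sketch of that static case (my example, assuming the usual `vector.extract_strided_slice` form; the helper bodies are not part of this hunk):

```mlir
// Offset 1 known statically: take elements [1, 4) in one strided slice.
%sub = vector.extract_strided_slice %loaded
    {offsets = [1], sizes = [3], strides = [1]}
    : vector<8xi2> to vector<3xi2>
```

When the offset is only known at runtime, `dynamicallyExtractSubVector` instead assembles the result into the zero-initialized `resultVector` destination created above.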
@@ -850,6 +857,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
 // ConvertVectorMaskedLoad
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorMaskedLoad final
     : OpConversionPattern<vector::MaskedLoadOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -1016,6 +1024,7 @@ struct ConvertVectorMaskedLoad final
 // ConvertVectorTransferRead
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorTransferRead final
     : OpConversionPattern<vector::TransferReadOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -1046,7 +1055,8 @@ struct ConvertVectorTransferRead final
 
     auto origElements = op.getVectorType().getNumElements();
 
-    bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0;
+    // Note, per-element-alignment was already verified above.
+    bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
 
     auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
                                                       adaptor.getPadding());
@@ -1065,9 +1075,8 @@ struct ConvertVectorTransferRead final
             getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedIntraVectorOffset =
-        isAlignedEmulation
-            ? 0
-            : getConstantIntValue(linearizedInfo.intraDataOffset);
+        isFullyAligned ? 0
+                       : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     int64_t maxIntraDataOffset =
         foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
@@ -1091,7 +1100,7 @@ struct ConvertVectorTransferRead final
       result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
                                            linearizedInfo.intraDataOffset,
                                            origElements);
-    } else if (!isAlignedEmulation) {
+    } else if (!isFullyAligned) {
       result = staticallyExtractSubvector(
           rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
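The `newPadding` context line earlier in this pattern reflects that the padding value has to be widened to the container element type before the emulated read; roughly (illustrative types and names):

```mlir
// The original transfer_read pads with an i4 value...
%pad = arith.constant 0 : i4
// ...but the emulated read operates on i8 containers, so the padding is
// zero-extended first and the data is bitcast back afterwards.
%pad8 = arith.extui %pad : i4 to i8
%raw = vector.transfer_read %src8[%j], %pad8 : memref<?xi8>, vector<2xi8>
%v = vector.bitcast %raw : vector<2xi8> to vector<4xi4>
```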
@@ -1774,33 +1783,34 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
 /// LLVM to scramble with peephole optimizations. Templated to choose between
 /// signed and unsigned conversions.
 ///
-/// For example (signed):
+/// EXAMPLE 1 (signed):
 ///   arith.extsi %in : vector<8xi4> to vector<8xi32>
-/// is rewriten as
-///   %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
-///   %1 = arith.shli %0, 4 : vector<4xi8>
-///   %2 = arith.shrsi %1, 4 : vector<4xi8>
-///   %3 = arith.shrsi %0, 4 : vector<4xi8>
-///   %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
-///   %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
+/// is rewritten as:
+///    %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///    %1 = arith.shli %0, 4 : vector<4xi8>
+///    %2 = arith.shrsi %1, 4 : vector<4xi8>
+///    %3 = arith.shrsi %0, 4 : vector<4xi8>
+///    %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
+///    %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
 ///
+/// EXAMPLE 2 (fp):
 ///   arith.sitofp %in : vector<8xi4> to vector<8xf32>
-/// is rewriten as
-///   %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
-///   %1 = arith.shli %0, 4 : vector<4xi8>
-///   %2 = arith.shrsi %1, 4 : vector<4xi8>
-///   %3 = arith.shrsi %0, 4 : vector<4xi8>
-///   %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
-///   %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
+/// is rewritten as:
+///    %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///    %1 = arith.shli %0, 4 : vector<4xi8>
+///    %2 = arith.shrsi %1, 4 : vector<4xi8>
+///    %3 = arith.shrsi %0, 4 : vector<4xi8>
+///    %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
+///    %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
 ///
-/// Example (unsigned):
+/// EXAMPLE 3 (unsigned):
 ///   arith.extui %in : vector<8xi4> to vector<8xi32>
-/// is rewritten as
-///   %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
-///   %1 = arith.andi %0, 15 : vector<4xi8>
-///   %2 = arith.shrui %0, 4 : vector<4xi8>
-///   %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
-///   %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
+/// is rewritten as:
+///    %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///    %1 = arith.andi %0, 15 : vector<4xi8>
+///    %2 = arith.shrui %0, 4 : vector<4xi8>
+///    %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
+///    %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
 ///
 template <typename ConversionOpType, bool isSigned>
 struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
@@ -1810,8 +1820,8 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
                                 PatternRewriter &rewriter) const override {
     // Verify the preconditions.
     Value srcValue = conversionOp.getIn();
-    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
-    auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
+    VectorType srcVecType = dyn_cast<VectorType>(srcValue.getType());
+    VectorType dstVecType = dyn_cast<VectorType>(conversionOp.getType());
 
     if (failed(
             commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
@@ -1851,15 +1861,16 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
 ///
 /// For example:
 ///   arith.trunci %in : vector<8xi32> to vector<8xi4>
-/// is rewriten as
 ///
-///   %cst = arith.constant dense<15> : vector<4xi8>
-///   %cst_0 = arith.constant dense<4> : vector<4xi8>
-///   %0, %1 = vector.deinterleave %in : vector<8xi8>, vector<8xi8>
-///   %2 = arith.andi %0, %cst : vector<4xi8>
-///   %3 = arith.shli %1, %cst_0 : vector<4xi8>
-///   %4 = arith.ori %2, %3 : vector<4xi8>
-///   %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
+/// is rewritten as:
+///
+///    %cst = arith.constant dense<15> : vector<4xi8>
+///    %cst_0 = arith.constant dense<4> : vector<4xi8>
+///    %0, %1 = vector.deinterleave %in : vector<8xi8>, vector<8xi8>
+///    %2 = arith.andi %0, %cst : vector<4xi8>
+///    %3 = arith.shli %1, %cst_0 : vector<4xi8>
+///    %4 = arith.ori %2, %3 : vector<4xi8>
+///    %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
 ///
 struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
   using OpRewritePattern<arith::TruncIOp>::OpRewritePattern;
@@ -1903,10 +1914,11 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
 
 /// Rewrite a sub-byte vector transpose into a sequence of instructions that
 /// perform the transpose on wider (byte) element types.
-/// For example:
+///
+/// EXAMPLE:
 ///   %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
 ///
-/// is rewritten as:
+///   is rewritten as:
 ///
 ///   %0 = arith.extsi %arg0 : vector<8x16xi4> to vector<8x16xi8>
 ///   %1 = vector.transpose %0, [1, 0] : vector<8x16xi8> to vector<16x8xi8>
@@ -1954,6 +1966,7 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
 // Public Interface Definition
 //===----------------------------------------------------------------------===//
 
+// The emulated type is inferred from the converted memref type.
 void vector::populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
     RewritePatternSet &patterns) {
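As a rough before/after picture of what the emulation patterns registered here produce (assuming an i8 container type inferred from the converted memref; indices and shapes are illustrative):

```mlir
// Before: a sub-byte load.
%v = vector.load %src[%i] : memref<8xi4>, vector<4xi4>

// After: the access is expressed on i8 containers and bitcast back to the
// original sub-byte element type.
%w = vector.load %src8[%j] : memref<4xi8>, vector<2xi8>
%v2 = vector.bitcast %w : vector<2xi8> to vector<4xi4>
```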
@@ -1966,22 +1979,26 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
 
 void vector::populateVectorNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
+  // TODO: Document what the emulated type is.
   patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
                RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
                                                     benefit);
 
   // Patterns for aligned cases. We set higher priority as they are expected to
   // generate better performance for aligned cases.
+  // The emulated type is always i8.
   patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
                RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
                RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
                                               benefit.getBenefit() + 1);
+  // The emulated type is always i8.
   patterns
       .add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>,
            RewriteAlignedSubByteIntExt<arith::UIToFPOp, /*isSigned=*/false>>(
           patterns.getContext(), benefit.getBenefit() + 1);
 }
 
+// The emulated type is always i8.
 void vector::populateVectorTransposeNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);