
Commit b8732cc
[mlir][Vector] Update VectorEmulateNarrowType.cpp (3/N)
This is PR 3 in a series of N patches aimed at improving
"VectorEmulateNarrowType.cpp". This is mainly minor refactoring; no major
functional changes are made or added.

1. Replaces `isUnalignedEmulation` with `isFullyAligned`.

   Note, `isUnalignedEmulation` is always computed following a
   "per-element-alignment" check:

   ```cpp
   // Check per-element alignment.
   if (containerBits % emulatedBits != 0) {
     return rewriter.notifyMatchFailure(
         op, "impossible to pack emulated elements into container elements "
             "(bit-wise misalignment)");
   }
   // (...)
   bool isUnalignedEmulation = origElements % emulatedPerContainerElem != 0;
   ```

   Given that `isUnalignedEmulation` captures only one of the two conditions
   required for "full alignment", it should have been named
   `isPartiallyUnalignedEmulation`. Instead, I've flipped the condition and
   renamed it `isFullyAligned`:

   ```cpp
   bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
   ```

2. In addition:

   * Unifies various comments throughout the file (for consistency).
   * Adds new comments throughout the file and adds TODOs where high-level
     comments are missing.
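To make the two alignment conditions concrete: with `i4` elements emulated in an `i32` container there are 8 emulated elements per container. The following self-contained sketch (illustrative values only, not code from the patch) shows a vector shape that passes the per-element check yet is not fully aligned:

```cpp
#include <cassert>

int main() {
  // i4 elements emulated inside an i32 container.
  int containerBits = 32;
  int emulatedBits = 4;

  // Condition 1: per-element alignment -- the container width must be a
  // multiple of the emulated width. The patterns verify this first.
  assert(containerBits % emulatedBits == 0);
  int emulatedPerContainerElem = containerBits / emulatedBits; // 8

  // Condition 2: full alignment -- the source vector must fill a whole
  // number of containers. vector<16xi4> satisfies it (16 % 8 == 0) ...
  bool aligned16 = 16 % emulatedPerContainerElem == 0; // true

  // ... but vector<12xi4> does not (12 % 8 != 0): it passes the
  // per-element check while its tail only partially fills a container.
  bool aligned12 = 12 % emulatedPerContainerElem == 0; // false

  assert(aligned16 && !aligned12);
  return 0;
}
```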
1 parent: aaeb0fb

File tree: 1 file changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+64 additions, −47 deletions)
```diff
@@ -48,6 +48,10 @@ using namespace mlir;
 using VectorValue = TypedValue<VectorType>;
 using MemRefValue = TypedValue<MemRefType>;
 
+//===----------------------------------------------------------------------===//
+// Utils
+//===----------------------------------------------------------------------===//
+
 /// Returns a compressed mask for the emulated vector. For example, when
 /// emulating an eight-element `i8` vector with `i32` (i.e. when the source
 /// elements span two dest elements), this method compresses `vector<8xi1>`
```
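As a side note, the compression factor described in this comment is just the container-to-emulated width ratio. A minimal sketch of the arithmetic (the helper name `compressedMaskNumElems` is hypothetical; `llvm::divideCeil` is the real LLVM utility):

```cpp
#include "llvm/Support/MathExtras.h"

// Hypothetical helper: number of i1 elements in the compressed mask when
// each container element packs `numSrcElemsPerDest` emulated elements.
// For the example above (i8 emulated in i32): numSrcElemsPerDest = 32 / 8
// = 4, so the 8 elements of vector<8xi1> compress to vector<2xi1>.
int64_t compressedMaskNumElems(int64_t numSrcElems,
                               int64_t numSrcElemsPerDest) {
  return llvm::divideCeil(numSrcElems, numSrcElemsPerDest);
}
```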
```diff
@@ -407,6 +411,7 @@ namespace {
 // ConvertVectorStore
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -632,6 +637,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 // ConvertVectorMaskedStore
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorMaskedStore final
     : OpConversionPattern<vector::MaskedStoreOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -745,6 +751,7 @@ struct ConvertVectorMaskedStore final
 // ConvertVectorLoad
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -802,7 +809,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     // compile time as they must be constants.
 
     auto origElements = op.getVectorType().getNumElements();
-    bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0;
+    // Note, per-element-alignment was already verified above.
+    bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -818,9 +826,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
         getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedIntraVectorOffset =
-        isAlignedEmulation
-            ? 0
-            : getConstantIntValue(linearizedInfo.intraDataOffset);
+        isFullyAligned ? 0
+                       : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     // Always load enough elements which can cover the original elements.
     int64_t maxintraDataOffset =
@@ -834,10 +841,10 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     if (!foldedIntraVectorOffset) {
       auto resultVector = rewriter.create<arith::ConstantOp>(
           loc, op.getType(), rewriter.getZeroAttr(op.getType()));
-      result = dynamicallyExtractSubVector(rewriter, loc, result, resultVector,
-                                           linearizedInfo.intraDataOffset,
-                                           origElements);
-    } else if (!isAlignedEmulation) {
+      result = dynamicallyExtractSubVector(
+          rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
+          resultVector, linearizedInfo.intraDataOffset, origElements);
+    } else if (!isFullyAligned) {
       result = staticallyExtractSubvector(
           rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
@@ -850,6 +857,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
 // ConvertVectorMaskedLoad
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorMaskedLoad final
     : OpConversionPattern<vector::MaskedLoadOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -1016,6 +1024,7 @@ struct ConvertVectorMaskedLoad final
 // ConvertVectorTransferRead
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorTransferRead final
     : OpConversionPattern<vector::TransferReadOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -1046,7 +1055,8 @@ struct ConvertVectorTransferRead final
 
     auto origElements = op.getVectorType().getNumElements();
 
-    bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0;
+    // Note, per-element-alignment was already verified above.
+    bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
 
     auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
                                                       adaptor.getPadding());
@@ -1065,9 +1075,8 @@ struct ConvertVectorTransferRead final
         getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedIntraVectorOffset =
-        isAlignedEmulation
-            ? 0
-            : getConstantIntValue(linearizedInfo.intraDataOffset);
+        isFullyAligned ? 0
+                       : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     int64_t maxIntraDataOffset =
         foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
@@ -1091,7 +1100,7 @@ struct ConvertVectorTransferRead final
       result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
                                            linearizedInfo.intraDataOffset,
                                            origElements);
-    } else if (!isAlignedEmulation) {
+    } else if (!isFullyAligned) {
       result = staticallyExtractSubvector(
           rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
@@ -1774,33 +1783,34 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
 /// LLVM to scramble with peephole optimizations. Templated to choose between
 /// signed and unsigned conversions.
 ///
-/// For example (signed):
+/// EXAMPLE 1 (signed):
 ///    arith.extsi %in : vector<8xi4> to vector<8xi32>
-///   is rewriten as
-///       %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
-///       %1 = arith.shli %0, 4 : vector<4xi8>
-///       %2 = arith.shrsi %1, 4 : vector<4xi8>
-///       %3 = arith.shrsi %0, 4 : vector<4xi8>
-///       %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
-///       %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
+/// is rewriten as:
+///    %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///    %1 = arith.shli %0, 4 : vector<4xi8>
+///    %2 = arith.shrsi %1, 4 : vector<4xi8>
+///    %3 = arith.shrsi %0, 4 : vector<4xi8>
+///    %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
+///    %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
 ///
+/// EXAMPLE 2 (fp):
 ///    arith.sitofp %in : vector<8xi4> to vector<8xf32>
-///   is rewriten as
-///       %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
-///       %1 = arith.shli %0, 4 : vector<4xi8>
-///       %2 = arith.shrsi %1, 4 : vector<4xi8>
-///       %3 = arith.shrsi %0, 4 : vector<4xi8>
-///       %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
-///       %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
+/// is rewriten as:
+///    %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///    %1 = arith.shli %0, 4 : vector<4xi8>
+///    %2 = arith.shrsi %1, 4 : vector<4xi8>
+///    %3 = arith.shrsi %0, 4 : vector<4xi8>
+///    %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
+///    %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
 ///
-/// Example (unsigned):
+/// EXAMPLE 3 (unsigned):
 ///    arith.extui %in : vector<8xi4> to vector<8xi32>
-///   is rewritten as
-///       %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
-///       %1 = arith.andi %0, 15 : vector<4xi8>
-///       %2 = arith.shrui %0, 4 : vector<4xi8>
-///       %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
-///       %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
+/// is rewritten as:
+///    %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///    %1 = arith.andi %0, 15 : vector<4xi8>
+///    %2 = arith.shrui %0, 4 : vector<4xi8>
+///    %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
+///    %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
 ///
 template <typename ConversionOpType, bool isSigned>
 struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
@@ -1810,8 +1820,8 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
                   PatternRewriter &rewriter) const override {
     // Verify the preconditions.
     Value srcValue = conversionOp.getIn();
-    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
-    auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
+    VectorType srcVecType = dyn_cast<VectorType>(srcValue.getType());
+    VectorType dstVecType = dyn_cast<VectorType>(conversionOp.getType());
 
     if (failed(
             commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
@@ -1851,15 +1861,16 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
 ///
 /// For example:
 ///    arith.trunci %in : vector<8xi32> to vector<8xi4>
-///   is rewriten as
 ///
-///     %cst = arith.constant dense<15> : vector<4xi8>
-///     %cst_0 = arith.constant dense<4> : vector<4xi8>
-///     %0, %1 = vector.deinterleave %in : vector<8xi8>, vector<8xi8>
-///     %2 = arith.andi %0, %cst : vector<4xi8>
-///     %3 = arith.shli %1, %cst_0 : vector<4xi8>
-///     %4 = arith.ori %2, %3 : vector<4xi8>
-///     %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
+/// is rewriten as:
+///
+///    %cst = arith.constant dense<15> : vector<4xi8>
+///    %cst_0 = arith.constant dense<4> : vector<4xi8>
+///    %0, %1 = vector.deinterleave %in : vector<8xi8>, vector<8xi8>
+///    %2 = arith.andi %0, %cst : vector<4xi8>
+///    %3 = arith.shli %1, %cst_0 : vector<4xi8>
+///    %4 = arith.ori %2, %3 : vector<4xi8>
+///    %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
 ///
 struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
   using OpRewritePattern<arith::TruncIOp>::OpRewritePattern;
@@ -1903,10 +1914,11 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
 
 /// Rewrite a sub-byte vector transpose into a sequence of instructions that
 /// perform the transpose on wider (byte) element types.
-/// For example:
+///
+/// EXAMPLE:
 ///   %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
 ///
-/// is rewritten as:
+///   is rewritten as:
 ///
 ///   %0 = arith.extsi %arg0 : vector<8x16xi4> to vector<8x16xi8>
 ///   %1 = vector.transpose %0, [1, 0] : vector<8x16xi8> to vector<16x8xi8>
@@ -1954,6 +1966,7 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
 // Public Interface Definition
 //===----------------------------------------------------------------------===//
 
+// The emulated type is inferred from the converted memref type.
 void vector::populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
     RewritePatternSet &patterns) {
@@ -1966,22 +1979,26 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
 
 void vector::populateVectorNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
+  // TODO: Document what the emulated type is.
   patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
                RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
                                                     benefit);
 
   // Patterns for aligned cases. We set higher priority as they are expected to
   // generate better performance for aligned cases.
+  // The emulated type is always i8.
   patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
               RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
               RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
                                              benefit.getBenefit() + 1);
+  // The emulated type is always i8.
   patterns
      .add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>,
           RewriteAlignedSubByteIntExt<arith::UIToFPOp, /*isSigned=*/false>>(
          patterns.getContext(), benefit.getBenefit() + 1);
 }
 
+// The emulated type is always i8.
 void vector::populateVectorTransposeNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
```
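For context, these populate functions are typically driven with MLIR's greedy pattern rewrite driver. A hedged sketch of how a pass might wire them up (the wrapper function is hypothetical and not part of the patch):

```cpp
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Sketch: collect the narrow-type rewrite patterns registered above and
// apply them greedily. `op` is assumed to be supplied by the surrounding
// pass (e.g. the function being transformed).
static LogicalResult applyNarrowTypeRewrites(Operation *op) {
  RewritePatternSet patterns(op->getContext());
  vector::populateVectorNarrowTypeRewritePatterns(patterns,
                                                  /*benefit=*/1);
  vector::populateVectorTransposeNarrowTypeRewritePatterns(patterns,
                                                           /*benefit=*/1);
  return applyPatternsAndFoldGreedily(op, std::move(patterns));
}
```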
