@@ -896,17 +896,17 @@ static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
896896 auto i8VecType = VectorType::get (i8VecShape, rewriter.getI8Type ());
897897 Value i8Vector = rewriter.create <vector::BitCastOp>(loc, i8VecType, srcValue);
898898
899- // 2 Extend the i4 elements using shifts & masking. Low i4 elemens of each
900- // byte are place in one vector and the high i4 elements in another vector.
901- constexpr unsigned char lowBitsMask = 15 ; // Equivalent to [0000IIII ] bit mask
899+ // 2 Extend the i4 elements using shifts & masking. Low i4 elements of each
900+ // byte are placed in one vector and the high i4 elements in another vector.
901+ constexpr uint8_t lowBitsMask = 15 ; // Equivalent to [00001111 ] bit mask
902902 auto lowBitsMaskValues = rewriter.create <arith::ConstantOp>(
903903 loc, DenseElementsAttr::get (i8VecType, lowBitsMask));
904- Value low = rewriter.create <arith::AndIOp>(loc, i8Vector. getType () , i8Vector,
904+ Value low = rewriter.create <arith::AndIOp>(loc, i8VecType , i8Vector,
905905 lowBitsMaskValues);
906906 constexpr int8_t highBitsToShift = 4 ;
907907 auto highShiftValues = rewriter.create <arith::ConstantOp>(
908908 loc, DenseElementsAttr::get (i8VecType, highBitsToShift));
909- Value high = rewriter.create <arith::ShRSIOp >(loc, i8Vector, highShiftValues);
909+ Value high = rewriter.create <arith::ShRUIOp >(loc, i8Vector, highShiftValues);
910910
911911 // 3. Interleave low and high i8 elements.
912912 return rewriter.create <vector::InterleaveOp>(loc, low, high);
@@ -1080,9 +1080,10 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
10801080
10811081// / Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
10821082// / bitwise ops that take advantage of high-level information to avoid leaving
1083- // / LLVM to scramble with peephole optimizations.
1083+ // / LLVM to scramble with peephole optimizations. Templated to choose between
1084+ // / signed and unsigned conversions.
10841085// /
1085- // / For example:
1086+ // / For example (signed) :
10861087// / arith.extsi %in : vector<8xi4> to vector<8xi32>
10871088// / is rewritten as
10881089// / %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
@@ -1101,60 +1102,25 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
11011102// / %4 = vector.interleave %2, %3 : vector<4xi8>
11021103// / %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
11031104// /
1104- template <typename ConversionOpType>
1105- struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
1106- using OpRewritePattern<ConversionOpType>::OpRewritePattern;
1107-
1108- LogicalResult matchAndRewrite (ConversionOpType conversionOp,
1109- PatternRewriter &rewriter) const override {
1110- // Verify the preconditions.
1111- Value srcValue = conversionOp.getIn ();
1112- auto srcVecType = dyn_cast<VectorType>(srcValue.getType ());
1113- auto dstVecType = dyn_cast<VectorType>(conversionOp.getType ());
1114- if (failed (
1115- commonConversionPrecondition (rewriter, dstVecType, conversionOp)))
1116- return failure ();
1117-
1118- // Check general alignment preconditions.
1119- if (failed (alignedConversionPrecondition (rewriter, srcVecType, dstVecType,
1120- conversionOp)))
1121- return failure ();
1122-
1123- // Perform the rewrite.
1124- Value subByteExt =
1125- rewriteI4ToI8SignedExt (rewriter, conversionOp.getLoc (), srcValue);
1126-
1127- // Finalize the rewrite.
1128- rewriter.replaceOpWithNewOp <ConversionOpType>(
1129- conversionOp, conversionOp.getType (), subByteExt);
1130- return success ();
1131- }
1132- };
1133-
1134- // / Rewrite the i4 -> i8 part of any unsigned conversion into a sequence of
1135- // / shuffles and bitwise ops that take advantage of high-level information to
1136- // / avoid leaving LLVM to scramble with peephole optimizations.
1137- // /
1138- // / For example:
1105+ // / Example (unsigned):
11391106// / arith.extui %in : vector<8xi4> to vector<8xi32>
11401107// / is rewritten as
11411108// / %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
11421109// / %1 = arith.andi %0, 15 : vector<4xi8>
1143- // / %2 = arith.shrsi %0, 4 : vector<4xi8>
1110+ // / %2 = arith.shrui %0, 4 : vector<4xi8>
11441111// / %3 = vector.interleave %1, %2 : vector<4xi8>
1145- // / %4 = arith.extsi %3 : vector<8xi8> to vector<8xi32>
1112+ // / %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
11461113// /
1147- template <typename ConversionOpType>
1148- struct RewriteAlignedSubByteIntUnsignedExt
1149- : OpRewritePattern<ConversionOpType> {
1114+ template <typename ConversionOpType, bool isSigned>
1115+ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
11501116 using OpRewritePattern<ConversionOpType>::OpRewritePattern;
11511117
11521118 LogicalResult matchAndRewrite (ConversionOpType conversionOp,
11531119 PatternRewriter &rewriter) const override {
11541120 // Verify the preconditions.
11551121 Value srcValue = conversionOp.getIn ();
1156- auto srcVecType = dyn_cast <VectorType>(srcValue.getType ());
1157- auto dstVecType = dyn_cast <VectorType>(conversionOp.getType ());
1122+ auto srcVecType = cast <VectorType>(srcValue.getType ());
1123+ auto dstVecType = cast <VectorType>(conversionOp.getType ());
11581124 if (failed (
11591125 commonConversionPrecondition (rewriter, dstVecType, conversionOp)))
11601126 return failure ();
@@ -1165,8 +1131,14 @@ struct RewriteAlignedSubByteIntUnsignedExt
11651131 return failure ();
11661132
11671133 // Perform the rewrite.
1168- Value subByteExt =
1169- rewriteI4ToI8UnsignedExt (rewriter, conversionOp.getLoc (), srcValue);
1134+ Value subByteExt;
1135+ if (isSigned) {
1136+ subByteExt =
1137+ rewriteI4ToI8SignedExt (rewriter, conversionOp.getLoc (), srcValue);
1138+ } else {
1139+ subByteExt =
1140+ rewriteI4ToI8UnsignedExt (rewriter, conversionOp.getLoc (), srcValue);
1141+ }
11701142
11711143 // Finalize the rewrite.
11721144 rewriter.replaceOpWithNewOp <ConversionOpType>(
@@ -1305,11 +1277,11 @@ void vector::populateVectorNarrowTypeRewritePatterns(
13051277
13061278 // Patterns for aligned cases. We set higher priority as they are expected to
13071279 // generate better performance for aligned cases.
1308- patterns.add <RewriteAlignedSubByteIntSignedExt <arith::ExtSIOp>,
1309- RewriteAlignedSubByteIntSignedExt <arith::SIToFPOp>,
1280+ patterns.add <RewriteAlignedSubByteIntExt <arith::ExtSIOp, /* isSigned= */ true >,
1281+ RewriteAlignedSubByteIntExt <arith::SIToFPOp, /* isSigned= */ true >,
13101282 RewriteAlignedSubByteIntTrunc>(patterns.getContext (),
13111283 benefit.getBenefit () + 1 );
1312- patterns.add <RewriteAlignedSubByteIntUnsignedExt <arith::ExtUIOp>>(
1284+ patterns.add <RewriteAlignedSubByteIntExt <arith::ExtUIOp, /* isSigned= */ false >>(
13131285 patterns.getContext (), benefit.getBenefit () + 1 );
13141286}
13151287
0 commit comments