@@ -880,6 +880,38 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
   return rewriter.create<vector::InterleaveOp>(loc, low, high);
 }
 
+/// Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
+                                      Value srcValue) {
+  VectorType srcVecType = cast<VectorType>(srcValue.getType());
+  assert(srcVecType.getElementType().isSignlessInteger(4) &&
+         "Expected i4 type");
+
+  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
+  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
+  constexpr int64_t i4Toi8BitwidthFactor = 2;
+  i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
+  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
+  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+
+  // 2. Extend the i4 elements using shifts & masking. The low i4 elements of
+  //    each byte are placed in one vector and the high i4 elements in another.
+  constexpr unsigned char lowBitsMask = 15; // Equivalent to [00001111] bit mask
+  auto lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(i8VecType, lowBitsMask));
+  Value low = rewriter.create<arith::AndIOp>(loc, i8Vector.getType(), i8Vector,
+                                             lowBitsMaskValues);
+  constexpr int8_t highBitsToShift = 4;
+  auto highShiftValues = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(i8VecType, highBitsToShift));
+  Value high = rewriter.create<arith::ShRUIOp>(loc, i8Vector, highShiftValues);
+
+  // 3. Interleave low and high i8 elements.
+  return rewriter.create<vector::InterleaveOp>(loc, low, high);
+}
+
 /// Rewrite the i8 -> i4 truncation into a sequence of shuffles and bitwise ops
 /// that take advantage of high-level information to avoid leaving LLVM to
 /// scramble with peephole optimizations.
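Aside on the shift choice above: the unsigned path must use a logical shift (`arith.shrui`) so the high nibble is zero-filled; an arithmetic shift (`arith.shrsi`) would smear the sign bit and is only correct for the signed extension. A minimal stand-alone sketch of the per-byte arithmetic the rewrite relies on (the demo itself is illustrative and not part of the patch):

```cpp
#include <cstdint>
#include <cstdio>

// Each packed byte holds two i4 values. The low i4 is isolated with an AND
// mask (arith.andi %0, 15) and the high i4 with a *logical* shift right
// (arith.shrui %0, 4), so both lanes come out zero-extended, as extui needs.
int main() {
  uint8_t packed = 0xAB;       // high nibble 0xA, low nibble 0xB
  uint8_t low = packed & 0x0F; // -> 0x0B
  uint8_t high = packed >> 4;  // -> 0x0A (zero-filled from the top)
  // An arithmetic shift on the same byte, int8_t(0xAB) >> 4, would yield
  // 0xFA on typical targets -- correct for extsi, wrong for extui.
  std::printf("low=0x%02X high=0x%02X\n", low, high);
  return 0;
}
```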
@@ -1099,6 +1131,50 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
   }
 };
 
+/// Rewrite the i4 -> i8 part of any unsigned conversion into a sequence of
+/// shuffles and bitwise ops that take advantage of high-level information to
+/// avoid leaving LLVM to scramble with peephole optimizations.
+///
+/// For example:
+///    arith.extui %in : vector<8xi4> to vector<8xi32>
+///      is rewritten as
+///        %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///        %1 = arith.andi %0, 15 : vector<4xi8>
+///        %2 = arith.shrui %0, 4 : vector<4xi8>
+///        %3 = vector.interleave %1, %2 : vector<4xi8>
+///        %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
+
+template <typename ConversionOpType>
+struct RewriteAlignedSubByteIntUnsignedExt
+    : OpRewritePattern<ConversionOpType> {
+  using OpRewritePattern<ConversionOpType>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConversionOpType conversionOp,
+                                PatternRewriter &rewriter) const override {
+    // Verify the preconditions.
+    Value srcValue = conversionOp.getIn();
+    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
+    auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
+    if (failed(
+            commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
+      return failure();
+
+    // Check general alignment preconditions.
+    if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
+                                             conversionOp)))
+      return failure();
+
+    // Perform the rewrite.
+    Value subByteExt =
+        rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
+
+    // Finalize the rewrite.
+    rewriter.replaceOpWithNewOp<ConversionOpType>(
+        conversionOp, conversionOp.getType(), subByteExt);
+    return success();
+  }
+};
+
11021178// / Rewrite the i8 -> i4 part of any truncation into a sequence of shuffles and
11031179// / bitwise ops that take advantage of high-level information to avoid leaving
11041180// / LLVM to scramble with peephole optimizations.
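For orientation, a minimal sketch of how this pattern could be exercised in isolation with MLIR's greedy rewrite driver; `RewriteAlignedSubByteIntUnsignedExt` comes from this patch, while the helper name and surrounding plumbing are assumptions for illustration:

```cpp
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical test helper: apply only the unsigned sub-byte extension
// rewrite over `root` (e.g. a func::FuncOp) until fixpoint.
static LogicalResult runUnsignedSubByteExtRewrite(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  patterns.add<RewriteAlignedSubByteIntUnsignedExt<arith::ExtUIOp>>(
      patterns.getContext());
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```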
@@ -1233,6 +1309,8 @@ void vector::populateVectorNarrowTypeRewritePatterns(
                RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>,
                RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
                                               benefit.getBenefit() + 1);
+  patterns.add<RewriteAlignedSubByteIntUnsignedExt<arith::ExtUIOp>>(
+      patterns.getContext(), benefit.getBenefit() + 1);
 }
 
 void vector::populateVectorTransposeNarrowTypeRewritePatterns(
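Note on the registration: the unsigned pattern is added at `benefit.getBenefit() + 1`, matching the signed and truncation patterns just above. In MLIR, patterns with higher `PatternBenefit` are attempted first, so these aligned, shuffle-based lowerings take precedence over the generic narrow-type emulation patterns registered at the base benefit.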