@@ -642,19 +642,19 @@ struct BitCastRewriter {
 
   BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);
 
-  /// Verify that the preconditions for the rewrite are met.
-  LogicalResult precondition(PatternRewriter &rewriter,
-                             VectorType preconditionVectorType, Operation *op);
+  /// Verify that general preconditions for the rewrite are met.
+  LogicalResult commonPrecondition(PatternRewriter &rewriter,
+                                   VectorType preconditionType, Operation *op);
 
   /// Precompute the metadata for the rewrite.
   SmallVector<BitCastRewriter::Metadata>
   precomputeMetadata(IntegerType shuffledElementType);
 
   /// Rewrite one step of the sequence:
   ///   `(shuffle -> and -> shiftright -> shiftleft -> or)`.
-  Value rewriteStep(PatternRewriter &rewriter, Location loc, Value initialValue,
-                    Value runningResult,
-                    const BitCastRewriter::Metadata &metadata);
+  Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
+                           Value initialValue, Value runningResult,
+                           const BitCastRewriter::Metadata &metadata);
 
 private:
   /// Underlying enumerator that encodes the provenance of the bits in each
@@ -719,21 +719,54 @@ BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
   LDBG("\n" << enumerator.sourceElementRanges);
 }
 
-LogicalResult BitCastRewriter::precondition(PatternRewriter &rewriter,
-                                            VectorType precondition,
-                                            Operation *op) {
-  if (precondition.getRank() != 1 || precondition.isScalable())
+/// Verify that the precondition type meets the common preconditions for any
+/// conversion.
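+/// E.g., a 1-D, fixed-length vector<8xi32> qualifies, while a scalable
+/// vector<[4]xi32> or a sub-byte vector<8xi4> is rejected by the checks below.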
+static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
+                                                  VectorType preconditionType,
+                                                  Operation *op) {
+  if (preconditionType.getRank() != 1 || preconditionType.isScalable())
     return rewriter.notifyMatchFailure(op, "scalable or >1-D vector");
 
   // TODO: consider relaxing this restriction in the future if we find ways
   // to really work with subbyte elements across the MLIR/LLVM boundary.
-  int64_t resultBitwidth = precondition.getElementTypeBitWidth();
+  unsigned resultBitwidth = preconditionType.getElementTypeBitWidth();
   if (resultBitwidth % 8 != 0)
     return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");
 
   return success();
 }
 
+LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
+                                                  VectorType preconditionType,
+                                                  Operation *op) {
+  if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
+    return rewriter.notifyMatchFailure(op, "types are not vector");
+
+  return commonConversionPrecondition(rewriter, preconditionType, op);
+}
+
+/// Verify that source and destination element types meet the precondition for
+/// the supported aligned conversion cases. Alignment means that either the
+/// source element type is a multiple of the destination element type or the
+/// other way around.
+///
+/// NOTE: This method assumes that common conversion preconditions are met.
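+/// E.g., i4 -> i8 and i4 -> i32 are supported aligned cases, whereas
+/// i4 -> i6 is not.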
+static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
+                                                   VectorType srcType,
+                                                   VectorType dstType,
+                                                   Operation *op) {
+  unsigned srcElemBitwidth = srcType.getElementTypeBitWidth();
+  unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
+  unsigned byteBitwidth = 8;
+
+  // Only {s}i4 -> (size_of({{s}i/f}) >= 8) are supported for now.
+  if (srcElemBitwidth != 4 || dstElemBitwidth < byteBitwidth ||
+      (dstElemBitwidth % srcElemBitwidth) != 0)
+    return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
+
+  return success();
+}
+
 SmallVector<BitCastRewriter::Metadata>
 BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
   SmallVector<BitCastRewriter::Metadata> result;
@@ -775,9 +808,9 @@ BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
   return result;
 }
 
-Value BitCastRewriter::rewriteStep(PatternRewriter &rewriter, Location loc,
-                                   Value initialValue, Value runningResult,
-                                   const BitCastRewriter::Metadata &metadata) {
+Value BitCastRewriter::genericRewriteStep(
+    PatternRewriter &rewriter, Location loc, Value initialValue,
+    Value runningResult, const BitCastRewriter::Metadata &metadata) {
   // Create vector.shuffle from the metadata.
   auto shuffleOp = rewriter.create<vector::ShuffleOp>(
       loc, initialValue, initialValue, metadata.shuffles);
@@ -810,6 +843,44 @@ Value BitCastRewriter::rewriteStep(PatternRewriter &rewriter, Location loc,
   return runningResult;
 }
 
+/// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
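+///
+/// For example, a vector<8xi4> input is lowered to roughly the following
+/// (illustrative; SSA names abridged):
+///   %bc   = vector.bitcast %src : vector<8xi4> to vector<4xi8>
+///   %cst4 = arith.constant dense<4> : vector<4xi8>
+///   %shl  = arith.shli %bc, %cst4 : vector<4xi8>
+///   %low  = arith.shrsi %shl, %cst4 : vector<4xi8>
+///   %high = arith.shrsi %bc, %cst4 : vector<4xi8>
+///   %res  = vector.shuffle %low, %high [0, 4, 1, 5, 2, 6, 3, 7]
+///           : vector<4xi8>, vector<4xi8>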
+static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
+                                    Value srcValue) {
+  VectorType srcVecType = cast<VectorType>(srcValue.getType());
+  assert(srcVecType.getElementType().isSignlessInteger(4) &&
+         "Expected i4 type");
+
+  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
+  int64_t vecDimSize = srcVecType.getShape().back();
+  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
+  constexpr int64_t i4Toi8BitwidthFactor = 2;
+  i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
+  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
+  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+
+  // 2. Extend i4 elements to i8 elements using shifts. Low i4 elements of each
+  // byte are placed in one vector and the high i4 elements in another vector.
+  constexpr int8_t bitsToShift = 4;
+  auto shiftValues = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(i8VecType, bitsToShift));
+  Value shl = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues);
+  Value low = rewriter.create<arith::ShRSIOp>(loc, shl, shiftValues);
+  Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues);
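+  // E.g., for a byte holding 0xA3, `shl` yields 0x30, so `low` becomes 0x03
+  // (3, the sign-extended low nibble) and `high` becomes 0xFA (-6, the
+  // sign-extended high nibble).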
+
+  // 3. Interleave low and high i8 elements using a shuffle.
+  SmallVector<int64_t> interleaveMaskValues;
+  interleaveMaskValues.reserve(vecDimSize);
+  for (int i = 0, end = vecDimSize / 2; i < end; ++i) {
+    interleaveMaskValues.push_back(i);
+    interleaveMaskValues.push_back(i + (vecDimSize / 2));
+  }
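+  // E.g., for vecDimSize == 8 the mask is [0, 4, 1, 5, 2, 6, 3, 7], picking
+  // elements alternately from `low` (indices 0..3) and `high` (indices 4..7).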
+
+  return rewriter.create<vector::ShuffleOp>(
+      loc, low, high, rewriter.getI64ArrayAttr(interleaveMaskValues));
+}
+
 namespace {
 /// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
 /// advantage of high-level information to avoid leaving LLVM to scramble with
@@ -829,7 +900,7 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
     VectorType sourceVectorType = bitCastOp.getSourceVectorType();
     VectorType targetVectorType = bitCastOp.getResultVectorType();
     BitCastRewriter bcr(sourceVectorType, targetVectorType);
-    if (failed(bcr.precondition(rewriter, targetVectorType, bitCastOp)))
+    if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
       return failure();
 
     // Perform the rewrite.
@@ -839,8 +910,8 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
     Value runningResult;
     for (const BitCastRewriter::Metadata &metadata :
          bcr.precomputeMetadata(shuffledElementType)) {
-      runningResult = bcr.rewriteStep(rewriter, bitCastOp->getLoc(), truncValue,
-                                      runningResult, metadata);
+      runningResult = bcr.genericRewriteStep(
+          rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
     }
 
     // Finalize the rewrite.
@@ -885,7 +956,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
     VectorType sourceVectorType = bitCastOp.getSourceVectorType();
     VectorType targetVectorType = bitCastOp.getResultVectorType();
     BitCastRewriter bcr(sourceVectorType, targetVectorType);
-    if (failed(bcr.precondition(
+    if (failed(bcr.commonPrecondition(
             rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
       return failure();
 
@@ -896,8 +967,8 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
         cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
     for (const BitCastRewriter::Metadata &metadata :
          bcr.precomputeMetadata(shuffledElementType)) {
-      runningResult = bcr.rewriteStep(rewriter, bitCastOp->getLoc(),
-                                      sourceValue, runningResult, metadata);
+      runningResult = bcr.genericRewriteStep(
+          rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
     }
 
     // Finalize the rewrite.
@@ -915,6 +986,52 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
     return success();
   }
 };
+
+/// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+///
+/// For example:
+///    extsi vector<8xi4> -> vector<8xi32>
+///      is rewritten as
+///        sequence of shuffles and bitwise ops for i4 -> i8
+///        extsi vector<8xi8> -> vector<8xi32>
+///
+///    sitofp vector<8xi4> -> vector<8xf32>
+///      is rewritten as
+///        sequence of shuffles and bitwise ops for i4 -> i8
+///        sitofp vector<8xi8> -> vector<8xf32>
+///
+template <typename ConversionOpType>
+struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
+  using OpRewritePattern<ConversionOpType>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConversionOpType conversionOp,
+                                PatternRewriter &rewriter) const override {
+    // Verify the preconditions. Bail out on scalar (non-vector) conversions.
+    Value srcValue = conversionOp.getIn();
+    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
+    auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
+    if (!srcVecType || !dstVecType)
+      return failure();
+    if (failed(
+            commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
+      return failure();
+
+    // Check general alignment preconditions.
+    if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
+                                             conversionOp)))
+      return failure();
+
+    // Perform the rewrite.
+    Value subByteExt =
+        rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+
+    // Finalize the rewrite.
+    rewriter.replaceOpWithNewOp<ConversionOpType>(
+        conversionOp, conversionOp.getType(), subByteExt);
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -936,4 +1053,10 @@ void vector::populateVectorNarrowTypeRewritePatterns(
   patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
                RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
                                                     benefit);
+
+  // Patterns for aligned cases. We set a higher benefit, as they are expected
+  // to generate better performance for the aligned cases.
+  patterns.add<RewriteAlignedSubByteIntSignedExt<arith::ExtSIOp>,
+               RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>>(
+      patterns.getContext(), benefit.getBenefit() + 1);
 }