@@ -642,19 +642,19 @@ struct BitCastRewriter {

  BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);

-  /// Verify that the preconditions for the rewrite are met.
-  LogicalResult precondition(PatternRewriter &rewriter,
-                             VectorType preconditionVectorType, Operation *op);
+  /// Verify that general preconditions for the rewrite are met.
+  LogicalResult commonPrecondition(PatternRewriter &rewriter,
+                                   VectorType preconditionType, Operation *op);

  /// Precompute the metadata for the rewrite.
  SmallVector<BitCastRewriter::Metadata>
  precomputeMetadata(IntegerType shuffledElementType);

  /// Rewrite one step of the sequence:
  ///   `(shuffle -> and -> shiftright -> shiftleft -> or)`.
-  Value rewriteStep(PatternRewriter &rewriter, Location loc, Value initialValue,
-                    Value runningResult,
-                    const BitCastRewriter::Metadata &metadata);
+  Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
+                           Value initialValue, Value runningResult,
+                           const BitCastRewriter::Metadata &metadata);

private:
  /// Underlying enumerator that encodes the provenance of the bits in each
@@ -719,21 +719,57 @@ BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
  LDBG("\n" << enumerator.sourceElementRanges);
}

-LogicalResult BitCastRewriter::precondition(PatternRewriter &rewriter,
-                                            VectorType precondition,
-                                            Operation *op) {
-  if (precondition.getRank() != 1 || precondition.isScalable())
+/// Verify that the precondition type meets the common preconditions for any
+/// conversion.
+static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
+                                                  VectorType preconditionType,
+                                                  Operation *op) {
+  if (!preconditionType || preconditionType.getRank() != 1 ||
+      preconditionType.isScalable())
    return rewriter.notifyMatchFailure(op, "scalable or >1-D vector");

  // TODO: consider relaxing this restriction in the future if we find ways
  // to really work with subbyte elements across the MLIR/LLVM boundary.
-  int64_t resultBitwidth = precondition.getElementTypeBitWidth();
+  unsigned resultBitwidth = preconditionType.getElementTypeBitWidth();
  if (resultBitwidth % 8 != 0)
    return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");

  return success();
}

+LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
+                                                  VectorType preconditionType,
+                                                  Operation *op) {
+  if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
+    return rewriter.notifyMatchFailure(op, "types are not vector");
+
+  return commonConversionPrecondition(rewriter, preconditionType, op);
+}
+
+/// Verify that the source and destination element types meet the precondition
+/// for the supported aligned conversion cases. Alignment means that either the
+/// source element type is a multiple of the destination element type or the
+/// other way around.
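+/// For instance, i4 -> i8, i4 -> i16, and i4 -> f32 meet this precondition,
+/// whereas i3 -> i8 or i4 -> i6 do not.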
+///
+/// NOTE: This method assumes that common conversion preconditions are met.
+static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
+                                                   VectorType srcType,
+                                                   VectorType dstType,
+                                                   Operation *op) {
+  if (!srcType || !dstType)
+    return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
+  unsigned srcElemBitwidth = srcType.getElementTypeBitWidth();
+  unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
+  unsigned byteBitwidth = 8;
+
+  // Only {s}i4 -> (size_of({{s}i/f}) >= 8) are supported for now.
+  if (srcElemBitwidth != 4 || dstElemBitwidth < 8 ||
+      (dstElemBitwidth % srcElemBitwidth) != 0)
+    return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
+
+  return success();
+}
+
SmallVector<BitCastRewriter::Metadata>
BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
  SmallVector<BitCastRewriter::Metadata> result;
@@ -775,9 +811,9 @@ BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
  return result;
}

-Value BitCastRewriter::rewriteStep(PatternRewriter &rewriter, Location loc,
-                                   Value initialValue, Value runningResult,
-                                   const BitCastRewriter::Metadata &metadata) {
+Value BitCastRewriter::genericRewriteStep(
+    PatternRewriter &rewriter, Location loc, Value initialValue,
+    Value runningResult, const BitCastRewriter::Metadata &metadata) {
  // Create vector.shuffle from the metadata.
  auto shuffleOp = rewriter.create<vector::ShuffleOp>(
      loc, initialValue, initialValue, metadata.shuffles);
@@ -810,6 +846,44 @@ Value BitCastRewriter::rewriteStep(PatternRewriter &rewriter, Location loc,
  return runningResult;
}

+/// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
+                                    Value srcValue) {
+  VectorType srcVecType = cast<VectorType>(srcValue.getType());
+  assert(srcVecType.getElementType().isSignlessInteger(4) &&
+         "Expected i4 type");
+
+  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
+  int64_t vecDimSize = srcVecType.getShape().back();
+  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
+  constexpr int64_t i4Toi8BitwidthFactor = 2;
+  i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
+  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
+  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+
+  // 2. Extend i4 elements to i8 elements using shifts. Low i4 elements of each
+  //    byte are placed in one vector and the high i4 elements in another
+  //    vector.
+  constexpr int8_t bitsToShift = 4;
+  auto shiftValues = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(i8VecType, bitsToShift));
+  Value shl = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues);
+  Value low = rewriter.create<arith::ShRSIOp>(loc, shl, shiftValues);
+  Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues);
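+  // For instance, for the byte 0xA3 (i4 values 3 and -6), shl by 4 gives 0x30
+  // and shrsi by 4 then yields 0x03 (3); shrsi on the original byte yields
+  // 0xFA (-6).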
+
+  // 3. Interleave low and high i8 elements using a shuffle.
+  SmallVector<int64_t> interleaveMaskValues;
+  interleaveMaskValues.reserve(vecDimSize);
+  for (int i = 0, end = vecDimSize / 2; i < end; ++i) {
+    interleaveMaskValues.push_back(i);
+    interleaveMaskValues.push_back(i + (vecDimSize / 2));
+  }
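+  // E.g., for an 8-element i4 source this produces the mask
+  // [0, 4, 1, 5, 2, 6, 3, 7] over the two 4-element low/high vectors.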
+
+  return rewriter.create<vector::ShuffleOp>(
+      loc, low, high, rewriter.getI64ArrayAttr(interleaveMaskValues));
+}
+
namespace {
/// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
/// advantage of high-level information to avoid leaving LLVM to scramble with
@@ -829,7 +903,7 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
    VectorType targetVectorType = bitCastOp.getResultVectorType();
    BitCastRewriter bcr(sourceVectorType, targetVectorType);
-    if (failed(bcr.precondition(rewriter, targetVectorType, bitCastOp)))
+    if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
      return failure();

    // Perform the rewrite.
@@ -839,8 +913,8 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
    Value runningResult;
    for (const BitCastRewriter::Metadata &metadata :
         bcr.precomputeMetadata(shuffledElementType)) {
-      runningResult = bcr.rewriteStep(rewriter, bitCastOp->getLoc(), truncValue,
-                                      runningResult, metadata);
+      runningResult = bcr.genericRewriteStep(
+          rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
    }

    // Finalize the rewrite.
@@ -893,7 +967,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
    VectorType targetVectorType = bitCastOp.getResultVectorType();
    BitCastRewriter bcr(sourceVectorType, targetVectorType);
-    if (failed(bcr.precondition(
+    if (failed(bcr.commonPrecondition(
            rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
      return failure();

@@ -904,8 +978,8 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
        cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
    for (const BitCastRewriter::Metadata &metadata :
         bcr.precomputeMetadata(shuffledElementType)) {
-      runningResult = bcr.rewriteStep(rewriter, bitCastOp->getLoc(),
-                                      sourceValue, runningResult, metadata);
+      runningResult = bcr.genericRewriteStep(
+          rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
    }

    // Finalize the rewrite.
@@ -923,6 +997,62 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
    return success();
  }
};
+
+/// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+///
+/// For example:
+///    arith.extsi %in : vector<8xi4> to vector<8xi32>
+///      is rewritten as
+///        %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///        %1 = arith.shli %0, 4 : vector<4xi8>
+///        %2 = arith.shrsi %1, 4 : vector<4xi8>
+///        %3 = arith.shrsi %0, 4 : vector<4xi8>
+///        %4 = vector.shuffle %2, %3 [0, 4, 1, 5, 2, 6, 3, 7]
+///               : vector<4xi8>, vector<4xi8>
+///        %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
+///
+///    arith.sitofp %in : vector<8xi4> to vector<8xf32>
+///      is rewritten as
+///        %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///        %1 = arith.shli %0, 4 : vector<4xi8>
+///        %2 = arith.shrsi %1, 4 : vector<4xi8>
+///        %3 = arith.shrsi %0, 4 : vector<4xi8>
+///        %4 = vector.shuffle %2, %3 [0, 4, 1, 5, 2, 6, 3, 7]
+///               : vector<4xi8>, vector<4xi8>
+///        %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
+///
+template <typename ConversionOpType>
+struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
+  using OpRewritePattern<ConversionOpType>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConversionOpType conversionOp,
+                                PatternRewriter &rewriter) const override {
+    // Verify the conversion preconditions.
+    Value srcValue = conversionOp.getIn();
+    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
+    auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
+    if (failed(
+            commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
+      return failure();
+
+    // Check general alignment preconditions.
+    if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
+                                             conversionOp)))
+      return failure();
+
+    // Perform the rewrite.
+    Value subByteExt =
+        rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+
+    // Finalize the rewrite.
+    rewriter.replaceOpWithNewOp<ConversionOpType>(
+        conversionOp, conversionOp.getType(), subByteExt);
+    return success();
+  }
+};
+
} // namespace

//===----------------------------------------------------------------------===//
@@ -944,4 +1074,10 @@ void vector::populateVectorNarrowTypeRewritePatterns(
  patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
               RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
                                                    benefit);
+
+  // Patterns for aligned cases. We set a higher priority as they are expected
+  // to generate better performance for aligned cases.
+  patterns.add<RewriteAlignedSubByteIntSignedExt<arith::ExtSIOp>,
+               RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>>(
+      patterns.getContext(), benefit.getBenefit() + 1);
}