@@ -449,7 +449,7 @@ LogicalResult
 ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
                                            PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
-  constexpr int64_t opWidth = 2;
+  constexpr int64_t opOutWidth = 2;
 
   Value in = op.getIn();
   Value scale = op.getScale();
@@ -460,6 +460,8 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
   Type scaleType = getElementTypeOrSelf(scale);
   Type outType = getElementTypeOrSelf(out);
 
+  int64_t opInWidth = 32 / inType.getIntOrFloatBitWidth();
+
   VectorType outVecType = dyn_cast<VectorType>(out.getType());
   VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
 
@@ -473,7 +475,7 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
   else if (scaleType.getIntOrFloatBitWidth() > 32)
     scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale);
 
-  VectorType extScaleResultType = VectorType::get(opWidth, outType);
+  VectorType extScaleResultType = VectorType::get(opOutWidth, outType);
 
   if (!outVecType) {
     Value inCast = vector::BroadcastOp::create(rewriter, loc,
@@ -487,10 +489,11 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
 
   VectorType inVecType = cast<VectorType>(in.getType());
   Value origScale = getOriginalVectorValue(op.getScale());
+  VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType());
 
   ArrayRef<int64_t> inShape = inVecType.getShape();
   SmallVector<int64_t> originalScaleShape;
-  if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType()))
+  if (origScaleVecType)
     llvm::append_range(originalScaleShape, origScaleVecType.getShape());
 
   originalScaleShape.insert(originalScaleShape.end(),
@@ -524,19 +527,26 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
     Value blockResult =
         rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
 
-    for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
+    for (int64_t i = 0, inSliceWidth = std::min(opInWidth, blockSize - i);
          i < blockSize;
-         i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
-      Value slice = vector::ExtractStridedSliceOp::create(
-          rewriter, loc, block1D, i, sliceWidth, 1);
-      // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
-      Value scaleExt = amdgpu::ScaledExtPackedOp::create(
-          rewriter, loc, extScaleResultType, slice, uniformScale, 0);
-      if (sliceWidth != opWidth)
-        scaleExt = vector::ExtractStridedSliceOp::create(
-            rewriter, loc, scaleExt, 0, sliceWidth, 1);
-      blockResult = vector::InsertStridedSliceOp::create(
-          rewriter, loc, scaleExt, blockResult, i, 1);
+         i += inSliceWidth, inSliceWidth = std::min(opInWidth, blockSize - i)) {
+      Value inSlice = vector::ExtractStridedSliceOp::create(
+          rewriter, loc, block1D, i, inSliceWidth, 1);
+      for (int64_t j = 0,
+                   outSliceWidth = std::min(opOutWidth, inSliceWidth - j);
+           j < inSliceWidth; j += outSliceWidth,
+                   outSliceWidth = std::min(opOutWidth, inSliceWidth - j)) {
+        // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
+        Value scaleExt = amdgpu::ScaledExtPackedOp::create(
+            rewriter, loc, extScaleResultType, inSlice, uniformScale,
+            j / opOutWidth);
+        if (outSliceWidth < opOutWidth) {
+          scaleExt = vector::ExtractStridedSliceOp::create(
+              rewriter, loc, scaleExt, 0, outSliceWidth, 1);
+        }
+        blockResult = vector::InsertStridedSliceOp::create(
+            rewriter, loc, scaleExt, blockResult, i + j, 1);
+      }
     }
 
     VectorType resultType = VectorType::get(ratio, outType);
@@ -555,7 +565,7 @@ LogicalResult
 ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
                                              PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
-  constexpr int64_t opWidth = 2;
+  constexpr int64_t opInWidth = 2;
 
   Value in = op.getIn();
   Value scale = op.getScale();
@@ -568,7 +578,6 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
 
   VectorType outVecType = dyn_cast<VectorType>(out.getType());
   VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
-
   if (outVecType && outVecType.isScalable())
     return failure();
 
@@ -581,8 +590,8 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
 
   Value zero = arith::ConstantOp::create(rewriter, loc, outType,
                                          rewriter.getFloatAttr(outType, 0.0));
-  unsigned numPackedElem = 32 / outType.getIntOrFloatBitWidth();
-  VectorType truncScaleResultType = VectorType::get(numPackedElem, outType);
+  int64_t opOutWidth = 32 / outType.getIntOrFloatBitWidth();
+  VectorType truncScaleResultType = VectorType::get(opOutWidth, outType);
 
   if (!outVecType) {
     Type inVecType = VectorType::get(1, inType);
@@ -598,16 +607,16 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
 
   VectorType inVecType = cast<VectorType>(in.getType());
   Value origScale = getOriginalVectorValue(op.getScale());
+  VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType());
 
   ArrayRef<int64_t> inShape = inVecType.getShape();
-  SmallVector<int64_t> originalScaleShape;
-  if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType()))
-    llvm::append_range(originalScaleShape, origScaleVecType.getShape());
+  SmallVector<int64_t> scaleShape;
+  if (origScaleVecType)
+    llvm::append_range(scaleShape, origScaleVecType.getShape());
 
-  originalScaleShape.insert(originalScaleShape.end(),
-                            inShape.size() - originalScaleShape.size(), 1);
+  scaleShape.insert(scaleShape.end(), inShape.size() - scaleShape.size(), 1);
 
-  auto maybeRatio = computeShapeRatio(inShape, originalScaleShape);
+  auto maybeRatio = computeShapeRatio(inShape, scaleShape);
   assert(maybeRatio &&
          "failed to derive block size from broadcast or splat operation");
 
@@ -633,20 +642,36 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
     Value blockResult =
         rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
 
-    for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
-         i < blockSize;
-         i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
-      Value slice = vector::ExtractStridedSliceOp::create(
-          rewriter, loc, block1D, i, sliceWidth, 1);
-      // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1
-      Value scaleTrunc = amdgpu::PackedScaledTruncOp::create(
-          rewriter, loc, truncScaleResultType, slice, uniformScale, 0,
-          /*existing=*/nullptr);
-      int64_t packedWidth =
-          cast<VectorType>(scaleTrunc.getType()).getNumElements();
-      if (packedWidth != opWidth)
+    for (int64_t i = 0, outSliceWidth = std::min(opOutWidth, blockSize - i);
+         i < blockSize; i += outSliceWidth,
+                        outSliceWidth = std::min(opOutWidth, blockSize - i)) {
+      Value scaleTrunc;
+      // Case where <= 2 elements are being truncated.
+      if (outSliceWidth <= opInWidth) {
+        Value slice = vector::ExtractStridedSliceOp::create(
+            rewriter, loc, block1D, i, outSliceWidth, 1);
+        // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1
+        scaleTrunc = amdgpu::PackedScaledTruncOp::create(
+            rewriter, loc, truncScaleResultType, slice, uniformScale, 0,
+            /*existing=*/nullptr);
+      } else {
+        scaleTrunc = vector::BroadcastOp::create(rewriter, loc,
+                                                 truncScaleResultType, zero);
+        for (int64_t j = 0,
+                     inSliceWidth = std::min(opInWidth, outSliceWidth - j);
+             j < outSliceWidth; j += opInWidth,
+                     inSliceWidth = std::min(opInWidth, outSliceWidth - j)) {
+          Value slice = vector::ExtractStridedSliceOp::create(
+              rewriter, loc, block1D, i + j, inSliceWidth, 1);
+          scaleTrunc = amdgpu::PackedScaledTruncOp::create(
+              rewriter, loc, truncScaleResultType, slice, uniformScale,
+              j / opInWidth, scaleTrunc);
+        }
+      }
+      if (outSliceWidth != opOutWidth) {
         scaleTrunc = vector::ExtractStridedSliceOp::create(
-            rewriter, loc, scaleTrunc, 0, sliceWidth, 1);
+            rewriter, loc, scaleTrunc, 0, outSliceWidth, 1);
+      }
       blockResult = vector::InsertStridedSliceOp::create(
           rewriter, loc, scaleTrunc, blockResult, i, 1);
     }