@@ -510,10 +510,11 @@ struct BubbleUpCollapseShapeThroughExtractSlice
510510 PatternRewriter &rewriter) const override {
511511 auto collapseShapeOp =
512512 sliceOp.getSource ().getDefiningOp <tensor::CollapseShapeOp>();
513- if (!collapseShapeOp)
513+ if (!collapseShapeOp) {
514514 return rewriter.notifyMatchFailure (
515515 sliceOp,
516516 " tensor.extract_slice source not produced by tensor.collapse_shape" );
517+ }
517518
518519 if (!sliceOp.hasUnitStride ()) {
519520 return rewriter.notifyMatchFailure (
@@ -530,9 +531,10 @@ struct BubbleUpCollapseShapeThroughExtractSlice
530531 SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes ();
531532
532533 if (static_cast <size_t >(sliceOp.getResultType ().getRank ()) !=
533- collapsedSizes.size ())
534+ collapsedSizes.size ()) {
534535 return rewriter.notifyMatchFailure (sliceOp,
535536 " unimplemented: rank reducing slice" );
537+ }
536538
537539 ArrayRef<int64_t > srcShape = collapseShapeOp.getSrcType ().getShape ();
538540 SmallVector<ReassociationIndices, 4 > reassociationIndices =
@@ -546,10 +548,9 @@ struct BubbleUpCollapseShapeThroughExtractSlice
546548 SmallVector<OpFoldResult> expandedStrides (srcShape.size (),
547549 rewriter.getIndexAttr (1 ));
548550
549- for (auto [groupIdx, reassocIndices] :
550- enumerate(collapseShapeOp.getReassociationIndices ())) {
551- OpFoldResult collapsedSize = collapsedSizes[groupIdx];
552- OpFoldResult collapsedOffset = collapsedOffsets[groupIdx];
551+ for (auto [collapsedSize, collapsedOffset, reassocIndices] :
552+ llvm::zip_equal (collapsedSizes, collapsedOffsets,
553+ collapseShapeOp.getReassociationIndices ())) {
553554 // CASE #1 - size and/or offset are dynamic.
554555 // In this case, the slice can be represented as a contiguous slice only
555556 // if there is a single dimension in the reassociation group that has a
@@ -614,10 +615,11 @@ struct BubbleUpCollapseShapeThroughExtractSlice
614615 // We need to make sure that the slice size can be set to the shape size
615616 // and the offset to 0.
616617 if ((currentCollapsedsize % expandedShapeSize) != 0 ||
617- (currentCollapsedOffset % expandedShapeSize) != 0 )
618+ (currentCollapsedOffset % expandedShapeSize) != 0 ) {
618619 return rewriter.notifyMatchFailure (
619620 sliceOp, " unsupported: cannot be extracted as a contiguous slice "
620621 " of the src of the collapse_shape" );
622+ }
621623
622624 groupExpandedSizes.push_back (rewriter.getIndexAttr (expandedShapeSize));
623625 groupExpandedOffsets.push_back (rewriter.getIndexAttr (0 ));
@@ -632,10 +634,11 @@ struct BubbleUpCollapseShapeThroughExtractSlice
632634 int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
633635 // We need to make sure that the slice size in this dim + offset will
634636 // not exceed the shape size.
635- if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize)
637+ if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) {
636638 return rewriter.notifyMatchFailure (
637639 sliceOp, " unsupported: slice cannot be extracted as a contiguous "
638640 " slice of the src of the collapse_shape" );
641+ }
639642
640643 groupExpandedSizes.push_back (
641644 rewriter.getIndexAttr (currentCollapsedsize));
0 commit comments