Skip to content

Commit 87c6047

Browse files
Make dilations/strides optional
1 parent f557fca commit 87c6047

File tree

1 file changed

+16
-21
lines changed

1 file changed

+16
-21
lines changed

mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -438,19 +438,6 @@ static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps,
438438
return true;
439439
}
440440

441-
/// Utility to update `dilations` and `strides` by copy the corresponding data
442-
/// from `tempDilations` and `tempStrides`.
443-
static void updateConvDilationsAndStrides(SmallVector<int64_t> *dilations,
444-
SmallVector<int64_t> *strides,
445-
ArrayRef<int64_t> tempDilations,
446-
ArrayRef<int64_t> tempStrides) {
447-
if (!(dilations && strides))
448-
return;
449-
*dilations = SmallVector<int64_t>(tempDilations);
450-
*strides = SmallVector<int64_t>(tempStrides);
451-
return;
452-
}
453-
454441
// ---------------------------------------------
455442
// Matchers for specific convolution operation.
456443
// ---------------------------------------------
@@ -497,7 +484,8 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
497484
// Match body
498485
if (!bodyMatcherForConvolutionOps(yieldVal, body))
499486
return false;
500-
updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
487+
*dilations = SmallVector<int64_t>(tempDilations);
488+
*strides = SmallVector<int64_t>(tempStrides);
501489
return true;
502490
}
503491

@@ -547,7 +535,8 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
547535
// Match body
548536
if (!bodyMatcherForConvolutionOps(yieldVal, body))
549537
return false;
550-
updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
538+
*dilations = SmallVector<int64_t>(tempDilations);
539+
*strides = SmallVector<int64_t>(tempStrides);
551540
return true;
552541
}
553542

@@ -608,7 +597,8 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
608597
// Match body
609598
if (!bodyMatcherForConvolutionOps(yieldVal, body))
610599
return false;
611-
updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
600+
*dilations = SmallVector<int64_t>(tempDilations);
601+
*strides = SmallVector<int64_t>(tempStrides);
612602
return true;
613603
}
614604

@@ -655,7 +645,8 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
655645
// Match body
656646
if (!bodyMatcherForMaxSignedPoolOps(yieldVal, body))
657647
return false;
658-
updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
648+
*dilations = SmallVector<int64_t>(tempDilations);
649+
*strides = SmallVector<int64_t>(tempStrides);
659650
return true;
660651
}
661652

@@ -702,7 +693,8 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
702693
// Match body
703694
if (!bodyMatcherForMinSignedPoolOps(yieldVal, body))
704695
return false;
705-
updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
696+
*dilations = SmallVector<int64_t>(tempDilations);
697+
*strides = SmallVector<int64_t>(tempStrides);
706698
return true;
707699
}
708700

@@ -749,7 +741,8 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
749741
// Match body
750742
if (!bodyMatcherForSumPoolOps(yieldVal, body))
751743
return false;
752-
updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
744+
*dilations = SmallVector<int64_t>(tempDilations);
745+
*strides = SmallVector<int64_t>(tempStrides);
753746
return true;
754747
}
755748

@@ -796,7 +789,8 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
796789
// Match body
797790
if (!bodyMatcherForMaxUnsignedPoolOps(yieldVal, body))
798791
return false;
799-
updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
792+
*dilations = SmallVector<int64_t>(tempDilations);
793+
*strides = SmallVector<int64_t>(tempStrides);
800794
return true;
801795
}
802796

@@ -843,7 +837,8 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
843837
// Match body
844838
if (!bodyMatcherForMinUnsignedPoolOps(yieldVal, body))
845839
return false;
846-
updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
840+
*dilations = SmallVector<int64_t>(tempDilations);
841+
*strides = SmallVector<int64_t>(tempStrides);
847842
return true;
848843
}
849844

0 commit comments

Comments
 (0)