@@ -268,44 +268,25 @@ specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp,
268268static FailureOr<LinalgOp> specializeLinalgConvolutions (RewriterBase &rewriter,
269269 GenericOp genericOp) {
270270 SmallVector<int64_t > dilations, strides;
271+ #define CONV_OP_SPECIALIZER (ConvOpTy ) \
272+ if (isaConvolutionOpOfType<ConvOpTy>(genericOp, &dilations, &strides)) \
273+ return specializeToConvOp<ConvOpTy>(rewriter, genericOp, dilations, \
274+ strides); \
271275 // -----------------------------
272276 // Depthwise Convolution ops.
273277 // -----------------------------
274- if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
275- genericOp, &dilations, &strides))
276- return specializeToConvOp<linalg::DepthwiseConv1DNwcWcOp>(
277- rewriter, genericOp, dilations, strides);
278- if (isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
279- genericOp, &dilations, &strides))
280- return specializeToConvOp<linalg::DepthwiseConv2DNchwChwOp>(
281- rewriter, genericOp, dilations, strides);
282- if (isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
283- genericOp, &dilations, &strides))
284- return specializeToConvOp<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
285- rewriter, genericOp, dilations, strides);
278+ CONV_OP_SPECIALIZER (linalg::DepthwiseConv1DNwcWcOp);
279+ CONV_OP_SPECIALIZER (linalg::DepthwiseConv2DNchwChwOp);
280+ CONV_OP_SPECIALIZER (linalg::DepthwiseConv3DNdhwcDhwcmOp);
286281 // -----------------------------
287282 // Pooling ops.
288283 // -----------------------------
289- if (isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(genericOp, &dilations,
290- &strides))
291- return specializeToConvOp<linalg::PoolingNhwcMaxOp>(rewriter, genericOp,
292- dilations, strides);
293- if (isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(genericOp, &dilations,
294- &strides))
295- return specializeToConvOp<linalg::PoolingNhwcMinOp>(rewriter, genericOp,
296- dilations, strides);
297- if (isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(genericOp, &dilations,
298- &strides))
299- return specializeToConvOp<linalg::PoolingNhwcSumOp>(rewriter, genericOp,
300- dilations, strides);
301- if (isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
302- genericOp, &dilations, &strides))
303- return specializeToConvOp<linalg::PoolingNhwcMaxUnsignedOp>(
304- rewriter, genericOp, dilations, strides);
305- if (isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
306- genericOp, &dilations, &strides))
307- return specializeToConvOp<linalg::PoolingNhwcMinUnsignedOp>(
308- rewriter, genericOp, dilations, strides);
284+ CONV_OP_SPECIALIZER (linalg::PoolingNhwcMaxOp);
285+ CONV_OP_SPECIALIZER (linalg::PoolingNhwcMinOp);
286+ CONV_OP_SPECIALIZER (linalg::PoolingNhwcSumOp);
287+ CONV_OP_SPECIALIZER (linalg::PoolingNhwcMaxUnsignedOp);
288+ CONV_OP_SPECIALIZER (linalg::PoolingNhwcMinUnsignedOp);
289+ #undef CONV_OP_SPECIALIZER
309290 return failure ();
310291}
311292
0 commit comments