diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
index d1eb27099db61..108abe800b13e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/AffineExpr.h"
@@ -50,28 +51,71 @@ static Value createMul(Location loc, Value x, Value y, Type accType,
   return arith::MulFOp::create(builder, loc, xConvert, yConvert);
 }

-// Delinearizes the given composite `index` by the basis specified in `factors`.
-static SmallVector<Value> unrollIndex(OpBuilder &b, Location loc, Value index,
-                                      ArrayRef<int64_t> factors) {
-  assert(!factors.empty() && "empty factor list");
-  SmallVector<Value> basis;
-  for (int64_t f : factors)
-    basis.push_back(arith::ConstantOp::create(b, loc, b.getIndexAttr(f)));
-  FailureOr<SmallVector<Value>> multiIndex =
-      affine::delinearizeIndex(b, loc, index, basis);
-  assert(!failed(multiIndex) && "Failed to linearize img2col index");
-  return *multiIndex;
+// Generate the affine expression to compute the convolved index
+// for the input as `oIndex * stride + fIndex`,
+// where oIndex: output iterator; fIndex: filter iterator.
+static AffineExpr getConvolvedExpr(OpBuilder &b, int64_t stride,
+                                   bool useSymbols = true) {
+  AffineExpr oExpr, fExpr;
+  if (useSymbols)
+    bindSymbols(b.getContext(), oExpr, fExpr);
+  else
+    bindDims(b.getContext(), oExpr, fExpr);
+  return AffineExpr(stride * oExpr + fExpr);
 }

-// Given indices corresponding to iterators in the output (oIndex) and filter
-// (fIndex) for a convolution, compute the convolved index for the
-// input as `oIndex * stride + fIndex`.
-static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex,
-                               Value fIndex, int64_t stride) {
-  AffineExpr oExpr, fExpr;
-  bindSymbols(b.getContext(), oExpr, fExpr);
-  AffineMap convMap = AffineMap::get(0, 2, stride * oExpr + fExpr);
-  return affine::makeComposedAffineApply(b, loc, convMap, {oIndex, fIndex});
+// Stores the affine expressions to map the iteration space of the im2col
+// matrix to the corresponding indices of the output and filter matrices.
+struct Im2ColToOperandsExprs {
+  AffineExpr fhIndex;
+  AffineExpr fwIndex;
+  AffineExpr icIndex;
+  AffineExpr ohIndex;
+  AffineExpr owIndex;
+};
+
+// Stores the affine expressions to map the iteration space of the im2col
+// matrix to the input matrix indices.
+struct Im2ColToInputDimsExprs {
+  AffineExpr bIndex;
+  AffineExpr hIndex;
+  AffineExpr wIndex;
+  AffineExpr cIndex;
+};
+
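[Note, not part of the patch] A minimal standalone sketch of how these pieces compose, assuming the NHWC sizes used by the updated test (oh = ow = 14, fh = fw = 3, ic = 4, unit stride). It relies only on delinearize from IndexingUtils, AffineMap::inferFromExprList, and AffineExpr::compose, mirroring the patch; the main() wrapper and variable names are illustrative.

#include <cstdint>
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  mlir::MLIRContext ctx;
  const int64_t ow = 14, fw = 3, ic = 4, strideH = 1;
  // d1 spans the flattened oh*ow output dimension, d2 the flattened fh*fw*ic
  // filter/channel dimension of the im2col matrix.
  auto mExprs = mlir::delinearize(mlir::getAffineDimExpr(1, &ctx),
                                  llvm::ArrayRef<int64_t>{ow, 1});
  auto kExprs = mlir::delinearize(mlir::getAffineDimExpr(2, &ctx),
                                  llvm::ArrayRef<int64_t>{fw * ic, ic, 1});
  // Convolved row index h = oh * stride + fh, built over two placeholder dims
  // and then composed with the map extracting (ohIndex, fhIndex).
  mlir::AffineExpr oExpr, fExpr;
  mlir::bindDims(&ctx, oExpr, fExpr);
  auto hMap = mlir::AffineMap::inferFromExprList(
      {llvm::ArrayRef<mlir::AffineExpr>{mExprs[0], kExprs[0]}}, &ctx)[0];
  mlir::AffineExpr hExpr = (strideH * oExpr + fExpr).compose(hMap);
  // Expected to print: d1 floordiv 14 + d2 floordiv 12
  llvm::outs() << hExpr << "\n";
  return 0;
}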
+/// Construct the affine expressions that map the indices of the im2col matrix
+/// to the corresponding input tensor indices for a 2D convolution with the
+/// provided strides.
+///
+/// @param exprs Affine expressions for output and filter indices.
+/// @param strides [height, width] stride values for the convolution.
+/// @param rewriter Pattern rewriter.
+/// @return Affine expressions mapping im2col matrix indices to input
+/// offsets.
+static Im2ColToInputDimsExprs
+getIm2ColInputExpressions(Im2ColToOperandsExprs exprs,
+                          ArrayRef<int64_t> strides, RewriterBase &rewriter) {
+  // Maps the iteration space of the im2col matrix to (output_y, filter_y).
+  auto hIndicesMap = AffineMap::inferFromExprList(
+      {ArrayRef{exprs.ohIndex, exprs.fhIndex}}, rewriter.getContext())[0];
+  // Maps the iteration space of the im2col matrix to (output_x, filter_x).
+  auto wIndicesMap = AffineMap::inferFromExprList(
+      {ArrayRef{exprs.owIndex, exprs.fwIndex}}, rewriter.getContext())[0];
+  // Compute the input indexing map that maps the indices of the im2col matrix
+  // to the original input offsets. Each element of the im2col matrix
+  // corresponds to a pair of (out_element, filter_element). First, build the
+  // expressions computing the input (ix, iy) indices from the
+  // [out_x/y, filter_x/y] pairs; then compose them with the maps that map the
+  // im2col matrix elements to the (out_element, filter_element) pairs.
+  auto bIndexExpr = rewriter.getAffineDimExpr(0U);
+  auto hIndexExpr = getConvolvedExpr(rewriter, strides[0],
+                                     /*useSymbols*/ false);
+  hIndexExpr = hIndexExpr.compose(hIndicesMap);
+  auto wIndexExpr = getConvolvedExpr(rewriter, strides[1],
+                                     /*useSymbols*/ false);
+  wIndexExpr = wIndexExpr.compose(wIndicesMap);
+  auto cIndexExpr = exprs.icIndex;
+  return {bIndexExpr, hIndexExpr, wIndexExpr, cIndexExpr};
 }

 FailureOr<std::pair<Operation *, Operation *>>
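[Note, not part of the patch] The net effect of getIm2ColInputExpressions for the NHWC case can be sanity-checked with plain integer arithmetic. The sketch below assumes the sizes from the updated test (oh = ow = 14, fh = fw = 3, ic = 4) and unit strides, and verifies that the floordiv/mod formulas appearing in the new CHECK'd maps recover the direct convolution indices (n, sh*oh + fh, sw*ow + fw, ic).

#include <cassert>
#include <cstdint>

int main() {
  const int64_t OH = 14, OW = 14, FH = 3, FW = 3, IC = 4, SH = 1, SW = 1;
  for (int64_t oh = 0; oh < OH; ++oh)
    for (int64_t ow = 0; ow < OW; ++ow)
      for (int64_t fh = 0; fh < FH; ++fh)
        for (int64_t fw = 0; fw < FW; ++fw)
          for (int64_t ic = 0; ic < IC; ++ic) {
            // Flattened im2col iterators: d1 linearizes (oh, ow),
            // d2 linearizes (fh, fw, ic).
            const int64_t d1 = oh * OW + ow;
            const int64_t d2 = (fh * FW + fw) * IC + ic;
            // Composed map: (d0, d1, d2) -> (d0, d1 floordiv 14 + d2 floordiv 12,
            //   d1 mod 14 + (d2 mod 12) floordiv 4, d2 mod 4).
            const int64_t h = d1 / OW + d2 / (FW * IC);
            const int64_t w = d1 % OW + (d2 % (FW * IC)) / IC;
            const int64_t c = d2 % IC;
            assert(h == SH * oh + fh && "row index mismatch");
            assert(w == SW * ow + fw && "column index mismatch");
            assert(c == ic && "channel index mismatch");
          }
  return 0;
}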
@@ -135,44 +179,37 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
   auto reduction = utils::IteratorType::reduction;
   SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);

+  // Given an index of the im2col matrix, retrieve the corresponding indices of
+  // the output and filter matrices.
+  auto mIndicesExprs =
+      delinearize(rewriter.getAffineDimExpr(1U), ArrayRef<int64_t>{ow, 1});
+  auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(2U),
+                                   ArrayRef<int64_t>{fw * ic, ic, 1});
+  Im2ColToOperandsExprs i2cToOperExprs;
+  i2cToOperExprs.fhIndex = kIndicesExprs[0];
+  i2cToOperExprs.fwIndex = kIndicesExprs[1];
+  i2cToOperExprs.icIndex = kIndicesExprs[2];
+  i2cToOperExprs.ohIndex = mIndicesExprs[0];
+  i2cToOperExprs.owIndex = mIndicesExprs[1];
+
+  // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
+  Im2ColToInputDimsExprs inExprs = getIm2ColInputExpressions(
+      i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
+      rewriter);
+  auto inMap =
+      AffineMap::inferFromExprList({ArrayRef{inExprs.bIndex, inExprs.hIndex,
+                                             inExprs.wIndex, inExprs.cIndex}},
+                                   rewriter.getContext())[0];
+
   SmallVector<AffineMap> img2colIndexingMaps = {
-      AffineMap::getMultiDimIdentityMap(nloops, context)};
+      inMap, AffineMap::getMultiDimIdentityMap(nloops, context)};

   auto img2ColTensor = linalg::GenericOp::create(
       rewriter, loc, colTensor.getType(),
-      /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
+      /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
       img2colIterators,
       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
-        // Get the iterators named based on the matmul (batch, m, k).
-        Value bIndex = linalg::IndexOp::create(nestedBuilder, loc, 0);
-        Value mIndex = linalg::IndexOp::create(nestedBuilder, loc, 1);
-        Value kIndex = linalg::IndexOp::create(nestedBuilder, loc, 2);
-
-        // Recover the original iteration indices from the problem/input sizes.
-        SmallVector<Value> mIndices = unrollIndex(
-            nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
-        auto ohIndex = mIndices[0];
-        auto owIndex = mIndices[1];
-
-        SmallVector<Value> kIndices = unrollIndex(
-            nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
-        auto fhIndex = kIndices[0];
-        auto fwIndex = kIndices[1];
-        auto icIndex = kIndices[2];
-
-        // Extract the input element corresponding to the expanded indices.
-        Value hIndex =
-            getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
-                              convOp.getStrides().getValues<int64_t>()[0]);
-        Value wIndex =
-            getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
-                              convOp.getStrides().getValues<int64_t>()[1]);
-
-        // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
-        SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
-        Value inputVal = tensor::ExtractOp::create(nestedBuilder, loc, input,
-                                                   extractionIndices);
-        linalg::YieldOp::create(nestedBuilder, nestedLoc, inputVal);
+        linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
       });

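[Note, not part of the patch] With the gather folded into the input indexing map, the generic above degenerates to a copy; its semantics match this reference loop nest for the NHWC layout (a sketch with illustrative names, row-major packing assumed):

#include <cstdint>

// im2col[n][oh*OW + ow][(fh*FW + fw)*IC + ic] =
//     input[n][SH*oh + fh][SW*ow + fw][ic]
void im2colNhwcReference(const float *input, float *im2col, int64_t N,
                         int64_t IH, int64_t IW, int64_t IC, int64_t OH,
                         int64_t OW, int64_t FH, int64_t FW, int64_t SH,
                         int64_t SW) {
  for (int64_t n = 0; n < N; ++n)
    for (int64_t oh = 0; oh < OH; ++oh)
      for (int64_t ow = 0; ow < OW; ++ow)
        for (int64_t fh = 0; fh < FH; ++fh)
          for (int64_t fw = 0; fw < FW; ++fw)
            for (int64_t ic = 0; ic < IC; ++ic) {
              const int64_t m = oh * OW + ow;             // flattened d1
              const int64_t k = (fh * FW + fw) * IC + ic; // flattened d2
              const int64_t h = SH * oh + fh;
              const int64_t w = SW * ow + fw;
              im2col[(n * OH * OW + m) * (FH * FW * IC) + k] =
                  input[((n * IH + h) * IW + w) * IC + ic];
            }
}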
   // Because the filter does not share the same batch dimension,
@@ -421,44 +458,36 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
   auto reduction = utils::IteratorType::reduction;
   SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);

-  SmallVector<AffineMap> img2colIndexingMaps = {
-      AffineMap::getMultiDimIdentityMap(nloops, context)};
+  // Recover the original iteration indices from the problem/input sizes:
+  // given an index of the im2col matrix, retrieve the corresponding indices of
+  // the output and filter matrices.
+  auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(1U),
+                                   ArrayRef<int64_t>{fh * fw, fw, 1});
+  auto mIndicesExprs =
+      delinearize(rewriter.getAffineDimExpr(2U), ArrayRef<int64_t>{ow, 1});
+  Im2ColToOperandsExprs i2cToOperExprs;
+  i2cToOperExprs.icIndex = kIndicesExprs[0];
+  i2cToOperExprs.fhIndex = kIndicesExprs[1];
+  i2cToOperExprs.fwIndex = kIndicesExprs[2];
+  i2cToOperExprs.ohIndex = mIndicesExprs[0];
+  i2cToOperExprs.owIndex = mIndicesExprs[1];
+  Im2ColToInputDimsExprs inExprs = getIm2ColInputExpressions(
+      i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
+      rewriter);
+  auto inMap =
+      AffineMap::inferFromExprList({ArrayRef{inExprs.bIndex, inExprs.cIndex,
+                                             inExprs.hIndex, inExprs.wIndex}},
+                                   rewriter.getContext())[0];
+  // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw]
+  SmallVector<AffineMap> img2colIndexingMaps = {
+      inMap, AffineMap::getMultiDimIdentityMap(nloops, context)};

   auto img2ColTensor = linalg::GenericOp::create(
       rewriter, loc, colTensor.getType(),
-      /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
+      /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
       img2colIterators,
       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
-        // Get the iterators named based on the matmul (batch, m, k).
-        Value bIndex = linalg::IndexOp::create(nestedBuilder, loc, 0);
-        Value kIndex = linalg::IndexOp::create(nestedBuilder, loc, 1);
-        Value nIndex = linalg::IndexOp::create(nestedBuilder, loc, 2);
-
-        // Recover the original iteration indices from the problem/input sizes.
-        SmallVector<Value> kIndices = unrollIndex(
-            nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{ic, fh, fw});
-        auto icIndex = kIndices[0];
-        auto fhIndex = kIndices[1];
-        auto fwIndex = kIndices[2];
-
-        SmallVector<Value> nIndices = unrollIndex(
-            nestedBuilder, nestedLoc, nIndex, ArrayRef<int64_t>{oh, ow});
-        auto ohIndex = nIndices[0];
-        auto owIndex = nIndices[1];
-
-        // Extract the input element corresponding to the expanded indices.
-        Value hIndex =
-            getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
-                              convOp.getStrides().getValues<int64_t>()[0]);
-        Value wIndex =
-            getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
-                              convOp.getStrides().getValues<int64_t>()[1]);
-
-        // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw]
-        SmallVector<Value> extractionIndices{bIndex, icIndex, hIndex, wIndex};
-        Value inputVal = tensor::ExtractOp::create(nestedBuilder, loc, input,
-                                                   extractionIndices);
-        linalg::YieldOp::create(nestedBuilder, nestedLoc, inputVal);
+        linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
       });

   // Because the filter does not share the same batch dimension,
@@ -545,6 +574,7 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
   Value reshapedOutput = tensor::CollapseShapeOp::create(
       rewriter, loc, reshapedOutputType, output, outputReassocIndices);

+  // Shape of the Toeplitz matrix produced by Im2col.
   SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
   Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
                                             inputType.getElementType());
@@ -556,44 +586,36 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
   auto reduction = utils::IteratorType::reduction;
   SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);

+  // Given an index of the im2col matrix, retrieve the corresponding indices of
+  // the output and filter matrices.
+  auto mIndicesExprs =
+      delinearize(rewriter.getAffineDimExpr(1U), ArrayRef<int64_t>{ow, 1});
+  auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(2U),
+                                   ArrayRef<int64_t>{fw * ic, ic, 1});
+  Im2ColToOperandsExprs i2cToOperExprs;
+  i2cToOperExprs.fhIndex = kIndicesExprs[0];
+  i2cToOperExprs.fwIndex = kIndicesExprs[1];
+  i2cToOperExprs.icIndex = kIndicesExprs[2];
+  i2cToOperExprs.ohIndex = mIndicesExprs[0];
+  i2cToOperExprs.owIndex = mIndicesExprs[1];
+
+  // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
+  Im2ColToInputDimsExprs inExprs = getIm2ColInputExpressions(
+      i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
+      rewriter);
+  auto inMap =
+      AffineMap::inferFromExprList({ArrayRef{inExprs.bIndex, inExprs.hIndex,
+                                             inExprs.wIndex, inExprs.cIndex}},
+                                   rewriter.getContext())[0];
   SmallVector<AffineMap> img2colIndexingMaps = {
-      AffineMap::getMultiDimIdentityMap(nloops, context)};
+      inMap, AffineMap::getMultiDimIdentityMap(nloops, context)};

   auto img2ColTensor = linalg::GenericOp::create(
       rewriter, loc, colTensor.getType(),
-      /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
+      /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
       img2colIterators,
       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
-        // Get the iterators named based on the matmul (batch, m, k).
-        Value bIndex = linalg::IndexOp::create(nestedBuilder, loc, 0);
-        Value mIndex = linalg::IndexOp::create(nestedBuilder, loc, 1);
-        Value kIndex = linalg::IndexOp::create(nestedBuilder, loc, 2);
-
-        // Recover the original iteration indices from the problem/input sizes.
-        SmallVector<Value> mIndices = unrollIndex(
-            nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
-        auto ohIndex = mIndices[0];
-        auto owIndex = mIndices[1];
-
-        SmallVector<Value> kIndices = unrollIndex(
-            nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
-        auto fhIndex = kIndices[0];
-        auto fwIndex = kIndices[1];
-        auto icIndex = kIndices[2];
-
-        // Extract the input element corresponding to the expanded indices.
-        Value hIndex =
-            getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
-                              convOp.getStrides().getValues<int64_t>()[0]);
-        Value wIndex =
-            getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
-                              convOp.getStrides().getValues<int64_t>()[1]);
-
-        // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
-        SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
-        Value inputVal = tensor::ExtractOp::create(nestedBuilder, loc, input,
-                                                   extractionIndices);
-        linalg::YieldOp::create(nestedBuilder, nestedLoc, inputVal);
+        linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
       });

   // Because we didn't transpose the filters we don't actually have a batched
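[Note, not part of the patch] In the NCHW test below, the k dimension is delinearized with basis {fh*fw, fw, 1} into (ic, fh, fw), which is where the constants 9 and 3 in the CHECK'd map come from, while the m dimension delinearizes over (oh, ow). A hypothetical standalone check of that map for ic = 4, fh = fw = 3, oh = ow = 14 and unit strides:

#include <cassert>
#include <cstdint>

int main() {
  const int64_t OH = 14, OW = 14, FH = 3, FW = 3, IC = 4;
  for (int64_t ic = 0; ic < IC; ++ic)
    for (int64_t fh = 0; fh < FH; ++fh)
      for (int64_t fw = 0; fw < FW; ++fw)
        for (int64_t oh = 0; oh < OH; ++oh)
          for (int64_t ow = 0; ow < OW; ++ow) {
            // Flattened im2col iterators: d1 linearizes (ic, fh, fw),
            // d2 linearizes (oh, ow).
            const int64_t d1 = (ic * FH + fh) * FW + fw;
            const int64_t d2 = oh * OW + ow;
            // CHECK'd map: (d0, d1, d2) -> (d0, d1 floordiv 9,
            //   d2 floordiv 14 + (d1 mod 9) floordiv 3, d2 mod 14 + d1 mod 3).
            assert(d1 / (FH * FW) == ic);
            assert(d2 / OW + (d1 % (FH * FW)) / FW == oh + fh);
            assert(d2 % OW + d1 % FW == ow + fw);
          }
  return 0;
}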
diff --git a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
index c17f20b2d03ab..8627fcd2576b9 100644
--- a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
+++ b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
@@ -34,40 +34,35 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
 // CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)

-// Collapsed indices.
-// CHECK: %[[BINDEX:.+]] = linalg.index 0 : index
-// CHECK: %[[MINDEX:.+]] = linalg.index 1 : index
-// CHECK: %[[KINDEX:.+]] = linalg.index 2 : index
-
-// Compute input channel/convolved indices.
-// CHECK: %[[ICINDEX:.+]] = affine.apply affine_map<()[s0] -> (s0 mod 4)>()[%[[KINDEX]]]
-// CHECK: %[[CONVH:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 floordiv 14 + s1 floordiv 12)>()[%[[MINDEX]], %[[KINDEX]]]
-// CHECK: %[[CONVW:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 mod 14 + (s1 mod 12) floordiv 4)>()[%[[MINDEX]], %[[KINDEX]]]
-
-// Extract from the input tensor.
-// CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract
-// CHECK-SAME: %{{.+}}{{\[}}%[[BINDEX]], %[[CONVH]], %[[CONVW]], %[[ICINDEX]]] : tensor<1x16x16x4xf32>
-// CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32
-
 // CHECK: IR printer: transformed
 // CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>

-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// Im2col maps
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 floordiv 14 + d2 floordiv 12, d1 mod 14 + (d2 mod 12) floordiv 4, d2 mod 4)>
+// CHECK-DAG: #[[MAPI2C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// Matmul maps
 // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
 // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
 // CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-// CHECK: @conv_16433136
-// CHECK-SAME: %[[INPUT:.+]]: tensor<1x16x16x4xf32>
-// CHECK-SAME: %[[FILTER:.+]]: tensor<3x3x4x16xf32>
-// CHECK-SAME: %[[OUTPUT:.+]]: tensor<1x14x14x16xf32>
+
+// CHECK: @conv_16433136
+// CHECK-SAME: %[[INPUT:.+]]: tensor<1x16x16x4xf32>
+// CHECK-SAME: %[[FILTER:.+]]: tensor<3x3x4x16xf32>
+// CHECK-SAME: %[[OUTPUT:.+]]: tensor<1x14x14x16xf32>
 // CHECK-DAG: %[[COLLAPSED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0, 1, 2], [3]] : tensor<3x3x4x16xf32> into tensor<36x16xf32>
 // CHECK-DAG: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
-// CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
-// CHECK: %[[COL_TENSOR:.+]] = linalg.generic
-// CHECK-SAME: #[[MAP0]]
-// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
-// CHECK: linalg.yield %{{.+}} : f32
-// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
+// CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
+
+// CHECK: %[[COL_TENSOR:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAPI2C]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[INPUT]] : tensor<1x16x16x4xf32>)
+// CHECK-SAME: outs(%[[INIT_COL_TENSOR]] : tensor<1x196x36xf32>)
+// CHECK: ^bb0(%[[IN:.+]]: f32, %out: f32):
+// CHECK: linalg.yield %[[IN]] : f32
+// CHECK: } -> tensor<1x196x36xf32>
+
+// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
 // CHECK-SAME: #[[MAP1]]
 // CHECK-SAME: #[[MAP2]]
 // CHECK-SAME: #[[MAP3]]
@@ -180,7 +175,10 @@ module attributes {transform.with_named_sequence} {

 // -----

-// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// Im2col maps
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 floordiv 14 + d2 floordiv 12, d1 mod 14 + (d2 mod 12) floordiv 4, d2 mod 4)>
+// CHECK-DAG: #[[MAPI2C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
 // CHECK-DAG: #[[LHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
 // CHECK-DAG: #[[RHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
 // CHECK-DAG: #[[RESMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
@@ -191,9 +189,13 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG: %[[CS_RESULT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1, 2], [3]] : tensor<8x14x14x16xf32> into tensor<8x196x16xf32>
 // CHECK: %[[IT:.+]] = tensor.empty() : tensor<8x196x36xf32>
 // CHECK: %[[IMG2COL:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP]]]
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAPI2C]]]
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[INPUT]] : tensor<8x16x16x4xf32>)
 // CHECK-SAME: outs(%[[IT]] : tensor<8x196x36xf32>)
+// CHECK: ^bb0(%[[IN:.+]]: f32, %out: f32):
+// CHECK: linalg.yield %[[IN]] : f32
+// CHECK: } -> tensor<8x196x36xf32>
 // CHECK: %[[MATMUL:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[LHSMAP]], #[[RHSMAP]], #[[RESMAP]]],
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
@@ -224,13 +226,9 @@ module attributes {transform.with_named_sequence} {

 // -----

-// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-
 // Im2col maps
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 floordiv 9)>
-// CHECK-DAG: #[[MAP7:.+]] = affine_map<()[s0, s1] -> (s0 floordiv 14 + (s1 mod 9) floordiv 3)>
-// CHECK-DAG: #[[MAP8:.+]] = affine_map<()[s0, s1] -> (s0 + s1 - (s0 floordiv 14) * 14 - (s1 floordiv 3) * 3)>
-
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 floordiv 9, d2 floordiv 14 + (d1 mod 9) floordiv 3, d2 mod 14 + d1 mod 3)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

 // CHECK-DAG: #[[LHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
 // CHECK-DAG: #[[RHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
@@ -242,32 +240,12 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG: %[[CS_RESULT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1], [2, 3]] : tensor<8x16x14x14xf32> into tensor<8x16x196xf32>
 // CHECK: %[[IT:.+]] = tensor.empty() : tensor<8x36x196xf32>
 // CHECK: %[[IMG2COL:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP]]]
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]]]
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[INPUT]] : tensor<8x4x16x16xf32>)
 // CHECK-SAME: outs(%[[IT]] : tensor<8x36x196xf32>)
-// Collapsed indices.
-// CHECK: %[[BINDEX:.+]] = linalg.index 0 : index
-// CHECK: %[[KINDEX:.+]] = linalg.index 1 : index
-// CHECK: %[[NINDEX:.+]] = linalg.index 2 : index
-
-// Compute input channel/convolved indices.
-// CHECK: %[[ICINDEX:.+]] = affine.apply #[[MAP1]]()[%[[KINDEX]]]
-// CHECK: %[[CONVH:.+]] = affine.apply #[[MAP7]]()[%[[NINDEX]], %[[KINDEX]]]
-// CHECK: %[[CONVW:.+]] = affine.apply #[[MAP8]]()[%[[NINDEX]], %[[KINDEX]]]
-
-// Extract from the input tensor.
-// CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract
-// CHECK-SAME: %[[INPUT]]{{\[}}%[[BINDEX]], %[[ICINDEX]], %[[CONVH]], %[[CONVW]]] : tensor<8x4x16x16xf32>
-// CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32
-// CHECK: %[[MATMUL:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[LHSMAP]], #[[RHSMAP]], #[[RESMAP]]],
-// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
-// CHECK-SAME: ins(%[[CS_FILTER]], %[[IMG2COL]] : tensor<16x36xf32>, tensor<8x36x196xf32>)
-// CHECK-SAME: outs(%[[CS_RESULT]] : tensor<8x16x196xf32>)
-// CHECK: ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32):
-// CHECK: %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
-// CHECK: %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32
-// CHECK: linalg.yield %[[ADD]] : f32
+// CHECK: ^bb0(%[[IN:.+]]: f32, %out: f32):
+// CHECK: linalg.yield %[[IN]] : f32
 // CHECK: } -> tensor<8x16x196xf32>
 // CHECK: %[[CS_FINAL:.+]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0], [1], [2, 3]] output_shape [8, 16, 14, 14] : tensor<8x16x196xf32> into tensor<8x16x14x14xf32>
 // CHECK: return %[[CS_FINAL]]
@@ -291,31 +269,19 @@ module attributes {transform.with_named_sequence} {

 // CHECK: IR printer: tensor_producer
 // CHECK-NEXT: %[[COL_TENSOR:.+]] = linalg.generic
+// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1 floordiv 14 + d2 floordiv 12, d1 mod 14 + (d2 mod 12) floordiv 4, d2 mod 4)>
 // CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
-// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
-
-// Collapsed indices.
-// CHECK: %[[BINDEX:.+]] = linalg.index 0 : index
-// CHECK: %[[MINDEX:.+]] = linalg.index 1 : index
-// CHECK: %[[KINDEX:.+]] = linalg.index 2 : index
-
-// Compute input channel/convolved indices.
-// CHECK: %[[ICINDEX:.+]] = affine.apply affine_map<()[s0] -> (s0 mod 4)>()[%[[KINDEX]]]
-// CHECK: %[[CONVH:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 floordiv 14 + s1 floordiv 12)>()[%[[MINDEX]], %[[KINDEX]]]
-// CHECK: %[[CONVW:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 mod 14 + (s1 mod 12) floordiv 4)>()[%[[MINDEX]], %[[KINDEX]]]
-
-// Extract from the input tensor.
-// CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract
-// CHECK-SAME: %{{.+}}{{\[}}%[[BINDEX]], %[[CONVH]], %[[CONVW]], %[[ICINDEX]]] : tensor<1x16x16x4xf32>
-// CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32
+// CHECK: ^bb0(%[[IN_DATA:.+]]: f32, %[[OUT_DATA:.+]]: f32)
+// CHECK: linalg.yield %[[IN_DATA]] : f32

 // CHECK: IR printer: transformed
 // CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>

-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 floordiv 14 + d2 floordiv 12, d1 mod 14 + (d2 mod 12) floordiv 4, d2 mod 4)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
 // CHECK: @conv_2d_nhwc_fhwc
 // CHECK-SAME: %[[INPUT:.+]]: tensor<1x16x16x4xf32>
 // CHECK-SAME: %[[FILTER:.+]]: tensor<16x3x3x4xf32>
@@ -324,13 +290,13 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
 // CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
 // CHECK: %[[COL_TENSOR:.+]] = linalg.generic
-// CHECK-SAME: #[[MAP0]]
+// CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
 // CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
 // CHECK: linalg.yield %{{.+}} : f32
 // CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
-// CHECK-SAME: #[[MAP1]]
 // CHECK-SAME: #[[MAP2]]
 // CHECK-SAME: #[[MAP3]]
+// CHECK-SAME: #[[MAP4]]
 // CHECK-SAME: ins(%[[COL_TENSOR]], %[[COLLAPSED_FILTER]] : tensor<1x196x36xf32>, tensor<16x36xf32>)
 // CHECK-SAME: outs(%[[COLLAPSED_OUT]] : tensor<1x196x16xf32>)
 // CHECK: ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32)