From 4240341b4f06f1b77f63b0f619cae3804d88eb68 Mon Sep 17 00:00:00 2001 From: Hsiangkai Wang Date: Mon, 17 Jun 2024 11:24:07 +0100 Subject: [PATCH 1/9] [mlir][linalg] Implement Conv2D using Winograd Conv2D algorithm Define high level winograd operators and convert conv_2d_nhwc_fhwc into winograd operators. According to Winograd Conv2D algorithm, we need three transform operators for input, filter, and output transformation. The formula of Winograd Conv2D algorithm is Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A filter transform: G x g x G^T input transform: B^T x d x B output transform: A^T x y x A The implementation is based on the paper, Fast Algorithm for Convolutional Neural Networks. (https://arxiv.org/abs/1509.09308) --- .../mlir/Dialect/Linalg/IR/LinalgOps.td | 114 +++++++ .../Dialect/Linalg/Transforms/Transforms.h | 4 + mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 78 +++++ .../Dialect/Linalg/Transforms/CMakeLists.txt | 1 + .../Linalg/Transforms/WinogradConv2D.cpp | 321 ++++++++++++++++++ mlir/test/Dialect/Linalg/winograd-conv2d.mlir | 248 ++++++++++++++ .../Dialect/Linalg/TestLinalgTransforms.cpp | 13 + 7 files changed, 779 insertions(+) create mode 100644 mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp create mode 100644 mlir/test/Dialect/Linalg/winograd-conv2d.mlir diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td index 64c538367267d..de1097b6ac27b 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -154,4 +154,118 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax", let hasVerifier = 1; } +def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform"> { + let summary = "Winograd filter transform operator"; + let description = [{ + Winograd Conv2D algorithm will convert linalg Conv2D operator into batched + matrix multiply. Before the matrix multiply, it will convert filter and + input into a format suitable for batched matrix multiply. After the matrix + multiply, it will convert output to the final result tensor. + + The algorithm F(m x m, r x r) is + + Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A + + The size of output Y is m x m. The size of filter g is r x r. The size of + input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are + transformation matrices. + + This operator is defined to represent the high level concept of filter + transformation (G x g x G^T) in the Winograd Conv2D algorithm. + }]; + + let arguments = (ins AnyRankedTensor:$filter, + AnyRankedTensor:$output, + I64Attr:$m, + I64Attr:$r + ); + + let results = (outs AnyRankedTensor:$result); + let assemblyFormat = [{ + attr-dict + `m` `(` $m `)` + `r` `(` $r `)` + `ins` `(` $filter `:` type($filter) `)` + `outs` `(` $output `:` type($output) `)` + `->` type($result) + }]; + let hasVerifier = 1; +} + +def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform"> { + let summary = "Winograd input transform operator"; + let description = [{ + Winograd Conv2D algorithm will convert linalg Conv2D operator into batched + matrix multiply. Before the matrix multiply, it will convert filter and + input into a format suitable for batched matrix multiply. After the matrix + multiply, it will convert output to the final result tensor. + + The algorithm F(m x m, r x r) is + + Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A + + The size of output Y is m x m. The size of filter g is r x r. The size of + input d is (m + r - 1) x (m + r - 1). 
A^T, A, G^T, G, B^T, and B are + transformation matrices. + + This operator is defined to represent the high level concept of input + transformation (B^T x d x B) in the Winograd Conv2D algorithm. + }]; + + let arguments = (ins AnyRankedTensor:$input, + AnyRankedTensor:$output, + I64Attr:$m, + I64Attr:$r + ); + + let results = (outs AnyRankedTensor:$result); + let assemblyFormat = [{ + attr-dict + `m` `(` $m `)` + `r` `(` $r `)` + `ins` `(` $input `:` type($input) `)` + `outs` `(` $output `:` type($output) `)` + `->` type($result) + }]; + let hasVerifier = 1; +} + +def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform"> { + let summary = "Winograd output transform operator"; + let description = [{ + Winograd Conv2D algorithm will convert linalg Conv2D operator into batched + matrix multiply. Before the matrix multiply, it will convert filter and + input into a format suitable for batched matrix multiply. After the matrix + multiply, it will convert output to the final result tensor. + + The algorithm F(m x m, r x r) is + + Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A + + The size of output Y is m x m. The size of filter g is r x r. The size of + input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are + transformation matrices. + + This operator is defined to represent the high level concept of output + transformation (A^T x y x A) in the Winograd Conv2D algorithm. + }]; + + let arguments = (ins AnyRankedTensor:$value, + AnyRankedTensor:$output, + I64Attr:$m, + I64Attr:$r + ); + + let results = (outs AnyRankedTensor:$result); + let assemblyFormat = [{ + attr-dict + `m` `(` $m `)` + `r` `(` $r `)` + `ins` `(` $value `:` type($value) `)` + `outs` `(` $output `:` type($output) `)` + `->` type($result) + }]; + let hasVerifier = 1; +} + #endif // LINALG_OPS diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 05e97befdec1f..835aeaf2ffed3 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1692,6 +1692,10 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns, void populateBlockPackMatmulPatterns(RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn); +/// Patterns to apply Winograd Conv2D algorithm F(m x m, r x r). 
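+/// m is the dimension size of the output and r is the dimension size of the
+/// filter; the transformed tile size is m + r - 1, so F(4, 3) maps 6x6 input
+/// tiles to 4x4 output tiles through a 3x3 filter.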
+void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m, + int64_t r); + } // namespace linalg } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 57d126603ebd7..7bf2a5bca037f 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -2734,6 +2734,84 @@ FailureOr> SoftmaxOp::decomposeOperation(OpBuilder &b) { return SmallVector{result}; } +//===----------------------------------------------------------------------===// +// WinogradFilterTransformOp +//===----------------------------------------------------------------------===// + +LogicalResult WinogradFilterTransformOp::verify() { + auto filterType = cast(getFilter().getType()); + auto outputType = cast(getOutput().getType()); + auto filterElemType = filterType.getElementType(); + auto outputElemType = outputType.getElementType(); + if (filterElemType != outputElemType) { + return emitOpError() << "expected element type of input " << filterElemType + << " to match element type of output " + << outputElemType; + } + + unsigned filterRank = filterType.getRank(); + if (filterRank != 4) + return emitOpError() << "expected rank of input is 4"; + + unsigned outputRank = outputType.getRank(); + if (outputRank != 6) + return emitOpError() << "expected rank of output is 6"; + + return success(); +} + +//===----------------------------------------------------------------------===// +// WinogradInputTransformOp +//===----------------------------------------------------------------------===// + +LogicalResult WinogradInputTransformOp::verify() { + auto inputType = cast(getInput().getType()); + auto outputType = cast(getOutput().getType()); + auto inputElemType = inputType.getElementType(); + auto outputElemType = outputType.getElementType(); + if (inputElemType != outputElemType) { + return emitOpError() << "expected element type of input " << inputElemType + << " to match element type of output " + << outputElemType; + } + + unsigned inputRank = inputType.getRank(); + if (inputRank != 4) + return emitOpError() << "expected rank of input is 4"; + + unsigned outputRank = outputType.getRank(); + if (outputRank != 6) + return emitOpError() << "expected rank of output is 6"; + + return success(); +} + +//===----------------------------------------------------------------------===// +// WinogradOutputTransformOp +//===----------------------------------------------------------------------===// + +LogicalResult WinogradOutputTransformOp::verify() { + auto valueType = cast(getValue().getType()); + auto outputType = cast(getOutput().getType()); + auto valueElemType = valueType.getElementType(); + auto outputElemType = outputType.getElementType(); + if (valueElemType != outputElemType) { + return emitOpError() << "expected element type of value " << valueElemType + << " to match element type of output " + << outputElemType; + } + + unsigned valueRank = valueType.getRank(); + if (valueRank != 6) + return emitOpError() << "expected rank of input is 6"; + + unsigned outputRank = outputType.getRank(); + if (outputRank != 4) + return emitOpError() << "expected rank of output is 4"; + + return success(); +} + //===----------------------------------------------------------------------===// // LinalgDialect //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt index 
7e3dc56e0acdc..a7dcc29b5b9be 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -38,6 +38,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms Transforms.cpp TransposeConv2D.cpp Vectorization.cpp + WinogradConv2D.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp new file mode 100644 index 0000000000000..86e834d51f2fc --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp @@ -0,0 +1,321 @@ +//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Implement Winograd Conv2D algorithm. The implementation is based on the +// paper: Fast Algorithms for Convolutional Neural Networks +// (https://arxiv.org/abs/1509.09308) +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/MathExtras.h" + +namespace mlir { +namespace linalg { + +namespace { + +using TransformMapKeyTy = std::pair; + +// We use F(m, r) to define the size of minimal filtering algorithms. +// m is the output dimension and r is the filter dimension. We can get +// the input dimension, alpha, from the formula, alpha = m + r - 1. +// +// For example, when m = 2 and r = 3, we know its input size is 4. +// The Conv2D will operate on 4x4 input data with 3x3 filter and get +// 2x2 output result. +constexpr TransformMapKeyTy F_2_3{2, 3}; +constexpr TransformMapKeyTy F_4_3{4, 3}; +constexpr TransformMapKeyTy F_2_5{2, 5}; + +Value collapse2DData(RewriterBase &rewriter, Location loc, Value data) { + auto type = cast(data.getType()); + auto elementType = type.getElementType(); + auto shape = type.getShape(); + auto collapseType = RankedTensorType::get( + {shape[0] * shape[1] * shape[2] * shape[3], shape[4], shape[5]}, + elementType); + SmallVector reassociation = {{0, 1, 2, 3}, {4}, {5}}; + return rewriter.create(loc, collapseType, data, + reassociation); +} + +// This function generates linalg.batch_matmul to multiply input with filter. +// linalg.batch_matmul only supports 3-dimension data sets. We can treat +// tileH x tileW x H x W data as the 1-dimension data array. That is to convert +// [tileH, tileW, H, W, N, C] to [tileH x tileW x H x W, N, C]. In this way, we +// can convert 6-dimension input data to 3-dimension representation that is +// suitable for linalg.batch_matmul. +// +// Batched matmul will do the matrix multiply with the reduction on channel. +// +// We get +// +// %collapsed_input = tensor.collapse_shape %input +// %collapsed_filter = tensor.collapse_shape %filter +// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter +// %expanded_ret = tensor.expand_shape %ret +// +// After this function, we get return value with data layout +// (tileH, tileW, H, W, N, F). 
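+//
+// For example, for F(4, 3) with one 6x6 tile, N = 2, C = 5, and F = 2, the
+// transformed filter tensor<1x1x6x6x5x2xf32> and transformed input
+// tensor<1x1x6x6x2x5xf32> collapse to tensor<36x5x2xf32> and
+// tensor<36x2x5xf32>, linalg.batch_matmul yields tensor<36x2x2xf32>, and the
+// result expands back to tensor<1x1x6x6x2x2xf32>.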
+Value matrixMultiply(RewriterBase &rewriter, Location loc, + Value transformedFilter, Value transformedInput) { + auto collapseFilter = collapse2DData(rewriter, loc, transformedFilter); + auto collapseInput = collapse2DData(rewriter, loc, transformedInput); + + // Batched matrix multiply + auto filterType = cast(transformedFilter.getType()); + auto filterShape = filterType.getShape(); + auto inputType = cast(transformedInput.getType()); + auto inputElemType = inputType.getElementType(); + auto inputShape = inputType.getShape(); + + auto matmulType = RankedTensorType::get( + {inputShape[0] * inputShape[1] * inputShape[2] * inputShape[3], + inputShape[4], filterShape[5]}, + inputElemType); + Value init = rewriter.create(loc, matmulType.getShape(), + inputElemType); + + auto matmulOp = rewriter.create( + loc, matmulType, ValueRange({collapseInput, collapseFilter}), + ValueRange{init}); + + // Expand matmul result + SmallVector reassociation = {{0, 1, 2, 3}, {4}, {5}}; + auto expandType = + RankedTensorType::get({inputShape[0], inputShape[1], inputShape[2], + inputShape[3], inputShape[4], filterShape[5]}, + inputElemType); + auto expandOutput = rewriter.create( + loc, expandType, matmulOp.getResult(0), reassociation); + return expandOutput; +} + +Value insertToAlignedTensor(RewriterBase &rewriter, Location loc, Value value, + RankedTensorType alignedType) { + Value alignedInput = rewriter.create( + loc, alignedType.getShape(), alignedType.getElementType()); + + auto zeroIndex = rewriter.getIndexAttr(0); + auto oneIndex = rewriter.getIndexAttr(1); + SmallVector offsets(4, zeroIndex); + SmallVector strides(4, oneIndex); + + auto valueType = cast(value.getType()); + auto valueShape = valueType.getShape(); + SmallVector sizes; + sizes.emplace_back(rewriter.getIndexAttr(valueShape[0])); + sizes.emplace_back(rewriter.getIndexAttr(valueShape[1])); + sizes.emplace_back(rewriter.getIndexAttr(valueShape[2])); + sizes.emplace_back(rewriter.getIndexAttr(valueShape[3])); + + return rewriter.create(loc, value, alignedInput, + offsets, sizes, strides); +} + +Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc, + Value value, RankedTensorType extractedType) { + auto zeroIndex = rewriter.getIndexAttr(0); + auto oneIndex = rewriter.getIndexAttr(1); + SmallVector offsets(4, zeroIndex); + SmallVector strides(4, oneIndex); + + auto extractedShape = extractedType.getShape(); + SmallVector sizes; + sizes.emplace_back(rewriter.getIndexAttr(extractedShape[0])); + sizes.emplace_back(rewriter.getIndexAttr(extractedShape[1])); + sizes.emplace_back(rewriter.getIndexAttr(extractedShape[2])); + sizes.emplace_back(rewriter.getIndexAttr(extractedShape[3])); + + return rewriter.create(loc, extractedType, value, + offsets, sizes, strides); +} + +bool hasAllOneValues(DenseIntElementsAttr attr) { + return llvm::all_of( + attr, [](const APInt &element) { return element.getSExtValue() == 1; }); +} + +FailureOr winogradConv2DHelper(RewriterBase &rewriter, + linalg::Conv2DNhwcFhwcOp convOp, + int64_t m, int64_t r) { + Value input = convOp.getInputs()[0]; + Value filter = convOp.getInputs()[1]; + Value output = convOp.getOutputs()[0]; + auto inputType = cast(input.getType()); + auto filterType = cast(filter.getType()); + auto outputType = cast(output.getType()); + + if (!inputType.hasStaticShape()) + return rewriter.notifyMatchFailure(convOp, + "expected a static shape for the input"); + + if (!filterType.hasStaticShape()) + return rewriter.notifyMatchFailure( + convOp, "expected a static shape for the filter"); + + 
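// The Winograd rewrite assumes unit strides and dilations; the constant
+  // transform matrices are derived for a dense, non-strided sliding window,
+  // so reject any other configuration.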
+  if (!hasAllOneValues(convOp.getDilations()))
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected all ones for dilations");
+
+  if (!hasAllOneValues(convOp.getStrides()))
+    return rewriter.notifyMatchFailure(convOp, "expected all ones for strides");
+
+  auto filterShape = filterType.getShape();
+  int64_t filterF = filterShape[0];
+  int64_t filterH = filterShape[1];
+  int64_t filterW = filterShape[2];
+  int64_t filterC = filterShape[3];
+  auto inputShape = inputType.getShape();
+  int64_t inputN = inputShape[0];
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+  int64_t inputC = inputShape[3];
+  auto outputShape = outputType.getShape();
+  int64_t outputN = outputShape[0];
+  int64_t outputH = outputShape[1];
+  int64_t outputW = outputShape[2];
+  int64_t outputF = outputShape[3];
+
+  // Only support F(m x m, r x r), F(m x 1, r x 1), or F(1 x m, 1 x r).
+  bool isSupportedFilter = false;
+  if (filterH == filterW && filterH == r)
+    isSupportedFilter = true;
+  if (filterH == r && filterW == 1)
+    isSupportedFilter = true;
+  if (filterH == 1 && filterW == r)
+    isSupportedFilter = true;
+
+  if (!isSupportedFilter)
+    return rewriter.notifyMatchFailure(
+        convOp, "only support filter (r x r), (r x 1) or (1 x r)");
+
+  // Currently, we support (m, r) = (2, 3), (4, 3), or (2, 5).
+  static const llvm::SmallVector<TransformMapKeyTy, 3> validConfigs = {
+      F_2_3, F_4_3, F_2_5};
+
+  TransformMapKeyTy key = {m, r};
+  auto it = std::find(validConfigs.begin(), validConfigs.end(), key);
+  // If we cannot find the constant transformation matrices, we do not
+  // support this configuration yet.
+  if (it == validConfigs.end())
+    return failure();
+
+  // All the criteria are satisfied. We can do Winograd Conv2D.
+  Location loc = convOp.getLoc();
+
+  // For F(m x 1, r x 1), we only need to do the left side transform.
+  bool leftTransform = filterH != 1;
+  // For F(1 x m, 1 x r), we only need to do the right side transform.
+  bool rightTransform = filterW != 1;
+  int64_t heightM = leftTransform ? m : 1;
+  int64_t widthM = rightTransform ? m : 1;
+  int64_t heightR = leftTransform ? r : 1;
+  int64_t widthR = rightTransform ? r : 1;
+
+  // --- Create operator for filter transform ---
+  Type elementType = filterType.getElementType();
+  int64_t alphaH = heightM + heightR - 1;
+  int64_t alphaW = widthM + widthR - 1;
+  int64_t tileH = llvm::divideCeilSigned(outputH, heightM);
+  int64_t tileW = llvm::divideCeilSigned(outputW, widthM);
+  auto retType = RankedTensorType::get(
+      {tileH, tileW, alphaH, alphaW, filterC, filterF}, elementType);
+  Value retValue =
+      rewriter.create<tensor::EmptyOp>(loc, retType.getShape(), elementType);
+  auto transformedFilter = rewriter.create<linalg::WinogradFilterTransformOp>(
+      loc, retType, filter, retValue, m, r);
+
+  // --- Create operator for input transform ---
+
+  // When (input size - (r - 1)) is not aligned with the output tile size, pad
+  // the input so that tiling produces only full tiles.
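+  //
+  // For example, with m = 4 and r = 3, a 2x11x11x5 input that produces a
+  // 2x9x9x2 output needs tileH = tileW = ceil(9 / 4) = 3, so the input is
+  // padded to 2x14x14x5 (3 * 4 + 3 - 1 = 14) before the input transform.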
+ int64_t alignedInputH = tileH * heightM + (heightR - 1); + int64_t alignedInputW = tileW * widthM + (widthR - 1); + if (alignedInputH != inputH || alignedInputW != inputW) { + auto alignedInputType = RankedTensorType::get( + {inputN, alignedInputH, alignedInputW, inputC}, elementType); + input = insertToAlignedTensor(rewriter, loc, input, alignedInputType); + } + + retType = RankedTensorType::get( + {tileH, tileW, alphaH, alphaW, inputN, inputC}, elementType); + retValue = + rewriter.create(loc, retType.getShape(), elementType); + auto transformedInput = rewriter.create( + loc, retType, input, retValue, m, r); + + Value matmulRet = + matrixMultiply(rewriter, loc, transformedFilter, transformedInput); + + // --- Create operator for output transform --- + + // When output size is not aligned with output tile size, we need to pad the + // output buffer to insert the full tiles after tiling. + int64_t alignedOutputH = tileH * heightM; + int64_t alignedOutputW = tileW * widthM; + bool isOutputUnaligned = + ((alignedOutputH != outputH) || (alignedOutputW != outputW)); + if (isOutputUnaligned) { + auto alignedOutputType = RankedTensorType::get( + {outputN, alignedOutputH, alignedOutputW, outputF}, elementType); + output = insertToAlignedTensor(rewriter, loc, output, alignedOutputType); + outputType = alignedOutputType; + } + + Value transformedOutput = rewriter.create( + loc, outputType, matmulRet, output, m, r); + + // When output size is not aligned with output tile size, extract the + // value from the padded buffer. + if (isOutputUnaligned) { + transformedOutput = extractFromAlignedTensor( + rewriter, loc, transformedOutput, + RankedTensorType::get({outputN, outputH, outputW, outputF}, + elementType)); + } + + rewriter.replaceOp(convOp, transformedOutput); + + return transformedOutput.getDefiningOp(); +} + +class WinogradConv2DNhwcFhwc final + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + WinogradConv2DNhwcFhwc(mlir::MLIRContext *context, int64_t m, int64_t r) + : OpRewritePattern(context), m(m), r(r) {} + + LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp, + PatternRewriter &rewriter) const override { + if (failed(winogradConv2DHelper(rewriter, convOp, m, r))) + return failure(); + + return success(); + } + +private: + int64_t m; + int64_t r; +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m, + int64_t r) { + MLIRContext *context = patterns.getContext(); + patterns.insert(context, m, r); +} + +} // end namespace linalg +} // end namespace mlir diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir new file mode 100644 index 0000000000000..6cca3c602d4c0 --- /dev/null +++ b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir @@ -0,0 +1,248 @@ +// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-winograd-conv2d | FileCheck %s + +func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> { + %0 = tensor.empty() : tensor<2x4x4x2xf32> + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<2x4x4x2xf32> + %2 = 
linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> + return %2 : tensor<2x4x4x2xf32> +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func.func @conv2d_4x4_3x3 +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x4x4x2xf32> { +// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x4x4x2xf32> +// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x4x4x2xf32>) { +// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[IN]] : f32 +// CHECK-NEXT: } -> tensor<2x4x4x2xf32> +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<1x1x6x6x5x2xf32> +// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<1x1x6x6x5x2xf32>) -> tensor<1x1x6x6x5x2xf32> +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<1x1x6x6x2x5xf32> +// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<1x1x6x6x2x5xf32>) -> tensor<1x1x6x6x2x5xf32> +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x5x2xf32> into tensor<36x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x2x5xf32> into tensor<36x2x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32> +// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<1x1x6x6x2x2xf32> +// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x1x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> +// CHECK-NEXT: return %[[S8]] : tensor<2x4x4x2xf32> +// CHECK-NEXT: } + +// ----- + +func.func @conv2d_2x2_5x5(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x5x5x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x2x2x2xf32> { + %0 = tensor.empty() : tensor<2x2x2x2xf32> + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x2x2x2xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<2x2x2x2xf32> + %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf32>, tensor<2x5x5x5xf32>) outs(%1 : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> + return %2 : tensor<2x2x2x2xf32> +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func.func @conv2d_2x2_5x5 +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x5x5x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> 
tensor<2x2x2x2xf32> { +// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x2x2x2xf32> +// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x2x2x2xf32>) { +// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[IN]] : f32 +// CHECK-NEXT: } -> tensor<2x2x2x2xf32> +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<1x1x6x6x5x2xf32> +// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(2) r(5) ins(%[[ARG1]] : tensor<2x5x5x5xf32>) outs(%[[S2]] : tensor<1x1x6x6x5x2xf32>) -> tensor<1x1x6x6x5x2xf32> +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<1x1x6x6x2x5xf32> +// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(2) r(5) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<1x1x6x6x2x5xf32>) -> tensor<1x1x6x6x2x5xf32> +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x5x2xf32> into tensor<36x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x2x5xf32> into tensor<36x2x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32> +// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<1x1x6x6x2x2xf32> +// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(2) r(5) ins(%[[EXPANDED]] : tensor<1x1x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> +// CHECK-NEXT: return %[[S8]] : tensor<2x2x2x2xf32> +// CHECK-NEXT: } + +// ----- + +func.func @conv2d_1x4_1x3(%arg0: tensor<2x1x6x5xf32>, %arg1: tensor<2x1x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x1x4x2xf32> { + %0 = tensor.empty() : tensor<2x1x4x2xf32> + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x1x4x2xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<2x1x4x2xf32> + %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x1x6x5xf32>, tensor<2x1x3x5xf32>) outs(%1 : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32> + return %2 : tensor<2x1x4x2xf32> +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func.func @conv2d_1x4_1x3 +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x1x6x5xf32>, %[[ARG1:.*]]: tensor<2x1x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x1x4x2xf32> { +// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x1x4x2xf32> +// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x1x4x2xf32>) { +// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[IN]] : f32 +// CHECK-NEXT: } -> tensor<2x1x4x2xf32> +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<1x1x1x6x5x2xf32> +// CHECK-NEXT: %[[S3:.*]] = 
linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x1x3x5xf32>) outs(%[[S2]] : tensor<1x1x1x6x5x2xf32>) -> tensor<1x1x1x6x5x2xf32> +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<1x1x1x6x2x5xf32> +// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x1x6x5xf32>) outs(%[[S4]] : tensor<1x1x1x6x2x5xf32>) -> tensor<1x1x1x6x2x5xf32> +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x1x6x5x2xf32> into tensor<6x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x1x6x2x5xf32> into tensor<6x2x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32> +// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 1, 6, 2, 2] : tensor<6x2x2xf32> into tensor<1x1x1x6x2x2xf32> +// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x1x1x6x2x2xf32>) outs(%[[S1]] : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32> +// CHECK-NEXT: return %[[S8]] : tensor<2x1x4x2xf32> +// CHECK-NEXT: } + +// ----- + +func.func @conv2d_4x1_3x1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x1x2xf32> { + %0 = tensor.empty() : tensor<2x4x1x2xf32> + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x1x2xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<2x4x1x2xf32> + %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x1x5xf32>, tensor<2x3x1x5xf32>) outs(%1 : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> + return %2 : tensor<2x4x1x2xf32> +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func.func @conv2d_4x1_3x1 +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x4x1x2xf32> { +// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x4x1x2xf32> +// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x4x1x2xf32>) { +// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[IN]] : f32 +// CHECK-NEXT: } -> tensor<2x4x1x2xf32> +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<1x1x6x1x5x2xf32> +// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x1x5xf32>) outs(%[[S2]] : tensor<1x1x6x1x5x2xf32>) -> tensor<1x1x6x1x5x2xf32> +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<1x1x6x1x2x5xf32> +// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x1x5xf32>) outs(%[[S4]] : tensor<1x1x6x1x2x5xf32>) -> tensor<1x1x6x1x2x5xf32> +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x1x5x2xf32> into tensor<6x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_0:.*]] 
= tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x1x2x5xf32> into tensor<6x2x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32> +// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 1, 2, 2] : tensor<6x2x2xf32> into tensor<1x1x6x1x2x2xf32> +// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x1x6x1x2x2xf32>) outs(%[[S1]] : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> +// CHECK-NEXT: return %[[S8]] : tensor<2x4x1x2xf32> +// CHECK-NEXT: } + +// ----- + +func.func @conv2d_aligned(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x8x8x2xf32> { + %0 = tensor.empty() : tensor<2x8x8x2xf32> + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x8x8x2xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<2x8x8x2xf32> + %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> + return %2 : tensor<2x8x8x2xf32> +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func.func @conv2d_aligned +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x8x8x2xf32> { +// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x8x8x2xf32> +// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x8x8x2xf32>) { +// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[IN]] : f32 +// CHECK-NEXT: } -> tensor<2x8x8x2xf32> +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<2x2x6x6x5x2xf32> +// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<2x2x6x6x5x2xf32>) -> tensor<2x2x6x6x5x2xf32> +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<2x2x6x6x2x5xf32> +// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<2x2x6x6x2x5xf32>) -> tensor<2x2x6x6x2x5xf32> +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x5x2xf32> into tensor<144x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x2x5xf32> into tensor<144x2x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<144x2x2xf32> +// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%[[S6]] : tensor<144x2x2xf32>) -> tensor<144x2x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [2, 2, 6, 6, 2, 2] : tensor<144x2x2xf32> into tensor<2x2x6x6x2x2xf32> +// CHECK-NEXT: 
%[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<2x2x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> +// CHECK-NEXT: return %[[S8]] : tensor<2x8x8x2xf32> +// CHECK-NEXT: } + +// ----- + +func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x9x9x2xf32> { + %0 = tensor.empty() : tensor<2x9x9x2xf32> + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x9x9x2xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<2x9x9x2xf32> + %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> + return %2 : tensor<2x9x9x2xf32> +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func.func @conv2d_unaligned +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x9x9x2xf32> { +// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x9x9x2xf32> +// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x9x9x2xf32>) { +// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[IN]] : f32 +// CHECK-NEXT: } -> tensor<2x9x9x2xf32> +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<3x3x6x6x5x2xf32> +// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<3x3x6x6x5x2xf32>) -> tensor<3x3x6x6x5x2xf32> +// CHECK-NEXT: %[[INPUT_BUF:.*]] = tensor.empty() : tensor<2x14x14x5xf32> +// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[ARG0]] into %[[INPUT_BUF]][0, 0, 0, 0] [2, 11, 11, 5] [1, 1, 1, 1] : tensor<2x11x11x5xf32> into tensor<2x14x14x5xf32> +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<3x3x6x6x2x5xf32> +// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[INSERTED_SLICE]] : tensor<2x14x14x5xf32>) outs(%[[S4]] : tensor<3x3x6x6x2x5xf32>) -> tensor<3x3x6x6x2x5xf32> +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x5x2xf32> into tensor<324x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x2x5xf32> into tensor<324x2x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<324x2x2xf32> +// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<324x2x5xf32>, tensor<324x5x2xf32>) outs(%[[S6]] : tensor<324x2x2xf32>) -> tensor<324x2x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [3, 3, 6, 6, 2, 2] : tensor<324x2x2xf32> into tensor<3x3x6x6x2x2xf32> +// CHECK-NEXT: %[[OUTPUT_BUF:.*]] = tensor.empty() : tensor<2x12x12x2xf32> +// CHECK-NEXT: %[[INSERTED_SLICE_2:.*]] = tensor.insert_slice %[[S1]] into %[[OUTPUT_BUF]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x9x9x2xf32> into tensor<2x12x12x2xf32> +// CHECK-NEXT: %[[S8:.*]] = 
linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<3x3x6x6x2x2xf32>) outs(%[[INSERTED_SLICE_2]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32> +// CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32> +// CHECK-NEXT: return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32> +// CHECK-NEXT: } + +// ----- + +func.func @conv2d_unsupported_1(%arg0: tensor<2x6x5x5xf32>, %arg1: tensor<2x3x2x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> { + %0 = tensor.empty() : tensor<2x4x4x2xf32> + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<2x4x4x2xf32> + %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x5x5xf32>, tensor<2x3x2x5xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> + return %2 : tensor<2x4x4x2xf32> +} + +// CHECK-LABEL: conv2d_unsupported_1 +// CHECK: linalg.conv_2d_nhwc_fhwc + +// ----- + +func.func @conv2d_unsupported_2(%arg0: tensor<2x7x7x5xf32>, %arg1: tensor<2x4x4x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> { + %0 = tensor.empty() : tensor<2x4x4x2xf32> + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<2x4x4x2xf32> + %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x7x7x5xf32>, tensor<2x4x4x5xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> + return %2 : tensor<2x4x4x2xf32> +} + +// CHECK-LABEL: conv2d_unsupported_2 +// CHECK: linalg.conv_2d_nhwc_fhwc + +// ----- + +func.func @conv2d_unsupported_3(%arg0: tensor, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor) -> tensor { + %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor, tensor<2x3x3x5xf32>) outs(%arg2 : tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: conv2d_unsupported_3 +// CHECK: linalg.conv_2d_nhwc_fhwc diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp index 4892fa2f99a7c..12cb46a5968f1 100644 --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -123,6 +123,10 @@ struct TestLinalgTransforms *this, "test-erase-unnecessary-inputs", llvm::cl::desc("Test patterns to erase unnecessary inputs"), llvm::cl::init(false)}; + Option testWinogradConv2D{ + *this, "test-winograd-conv2d", + llvm::cl::desc("Test transform conv2d by Winograd conv2d algorithm"), + llvm::cl::init(false)}; }; } // namespace @@ -207,6 +211,13 @@ static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) { (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } +static void applyWinogradConv2D(func::FuncOp funcOp) { + RewritePatternSet patterns(funcOp.getContext()); + populateWinogradConv2DPatterns(patterns, /*m=*/4, /*r=*/3); + populateWinogradConv2DPatterns(patterns, /*m=*/2, 
/*r=*/5); + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); +} + /// Apply transformations specified as patterns. void TestLinalgTransforms::runOnOperation() { if (testPatterns) @@ -231,6 +242,8 @@ void TestLinalgTransforms::runOnOperation() { return applyEraseUnusedOperandsAndResultsPatterns(getOperation()); if (testEraseUnnecessaryInputs) return applyEraseUnnecessaryInputs(getOperation()); + if (testWinogradConv2D) + return applyWinogradConv2D(getOperation()); } namespace mlir { From 374b0d5b83ce080bea690199380e270a36ad1c52 Mon Sep 17 00:00:00 2001 From: Hsiangkai Wang Date: Mon, 17 Jun 2024 11:49:08 +0100 Subject: [PATCH 2/9] [mlir][linalg] Add transform operator for Winograd Conv2D algorithm Add a transform operator structured.winograd_conv2d to convert linalg.conv_2d_nhwc_fhwc to Linalg winograd operators. --- .../Linalg/TransformOps/LinalgTransformOps.td | 51 +++++++++++ .../Dialect/Linalg/Transforms/Transforms.h | 7 ++ .../TransformOps/LinalgTransformOps.cpp | 25 ++++++ .../Linalg/Transforms/WinogradConv2D.cpp | 6 ++ .../Linalg/transform-winograd-conv2d.mlir | 88 +++++++++++++++++++ 5 files changed, 177 insertions(+) create mode 100644 mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 93e2c2db729da..68d0f713caad4 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -2587,4 +2587,55 @@ def MapCopyToThreadsOp : }]; } +//===----------------------------------------------------------------------===// +// Winograd Conv2D +//===----------------------------------------------------------------------===// + +def WinogradConv2DOp : Op { + let description = [{ + Winograd Conv2D algorithm will convert linalg Conv2D operator into batched + matrix multiply. Before the matrix multiply, it will convert filter and + input into a format suitable for batched matrix multiply. After the matrix + multiply, it will convert output to the final result tensor. + + The algorithm F(m x m, r x r) is + + Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A + + The size of output Y is m x m. The size of filter g is r x r. The size of + input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are + transformation matrices. + + #### Return modes: + + This operation fails if `target` is unsupported. Otherwise, the operation + succeeds and returns a handle of the sequence that replaces the original + convolution. 
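+
+    For example, given a handle `%0` matched from `linalg.conv_2d_nhwc_fhwc`,
+    the convolution can be rewritten with F(4, 3) as:
+
+    ```mlir
+    %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 }
+      : (!transform.any_op) -> (!transform.any_op)
+    ```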
+ }]; + + let arguments = (ins TransformHandleTypeInterface:$target, + I64Attr:$m, + I64Attr:$r); + let results = (outs TransformHandleTypeInterface:$transformed); + + let assemblyFormat = + "$target attr-dict `:` functional-type($target, results)"; + + let builders = [ + OpBuilder<(ins "Value":$target)> + ]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::linalg::LinalgOp target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + #endif // LINALG_TRANSFORM_OPS diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 835aeaf2ffed3..da107b66257a5 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1312,6 +1312,13 @@ FailureOr transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS = true); +/// Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm +/// F(m x m, r x r). m is the dimension size of output and r is the dimension +/// size of filter. +FailureOr winogradConv2D(RewriterBase &rewriter, + linalg::Conv2DNhwcFhwcOp op, int64_t m, + int64_t r); + //===----------------------------------------------------------------------===// // Rewrite patterns wrapping transformations. // TODO: every single such pattern should be a close to noop wrapper around a diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index bc02788f9c441..d051b29e1f06f 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3480,6 +3480,31 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne( return DiagnosedSilenceableFailure::success(); } +//===----------------------------------------------------------------------===// +// WinogradConv2DOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne( + transform::TransformRewriter &rewriter, linalg::LinalgOp target, + transform::ApplyToEachResultList &results, + transform::TransformState &state) { + rewriter.setInsertionPoint(target); + auto maybeTransformed = + TypeSwitch>(target) + .Case([&](linalg::Conv2DNhwcFhwcOp op) { + return winogradConv2D(rewriter, op, getM(), getR()); + }) + .Default([&](Operation *op) { + return rewriter.notifyMatchFailure(op, "not supported"); + }); + + if (failed(maybeTransformed)) + return emitDefaultSilenceableFailure(target); + + results.push_back(*maybeTransformed); + return DiagnosedSilenceableFailure::success(); +} + #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc" #define GET_OP_CLASSES diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp index 86e834d51f2fc..d1f4be8bbf29a 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp @@ -311,6 +311,12 @@ class WinogradConv2DNhwcFhwc final } // end anonymous namespace //===----------------------------------------------------------------------===// +FailureOr winogradConv2D(RewriterBase &rewriter, + linalg::Conv2DNhwcFhwcOp op, int64_t m, + int64_t r) { + return winogradConv2DHelper(rewriter, op, 
m, r); +} + void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m, int64_t r) { MLIRContext *context = patterns.getContext(); diff --git a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir new file mode 100644 index 0000000000000..1e74fea5a1c31 --- /dev/null +++ b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir @@ -0,0 +1,88 @@ +// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file | FileCheck %s + +func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x8x8x2xf32> { + %0 = tensor.empty() : tensor<2x8x8x2xf32> + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x8x8x2xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<2x8x8x2xf32> + %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> + return %2 : tensor<2x8x8x2xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op) + transform.yield + } +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func.func @conv2d +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x8x8x2xf32> { +// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x8x8x2xf32> +// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x8x8x2xf32>) { +// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[IN]] : f32 +// CHECK-NEXT: } -> tensor<2x8x8x2xf32> +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<2x2x6x6x5x2xf32> +// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<2x2x6x6x5x2xf32>) -> tensor<2x2x6x6x5x2xf32> +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<2x2x6x6x2x5xf32> +// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<2x2x6x6x2x5xf32>) -> tensor<2x2x6x6x2x5xf32> +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x5x2xf32> into tensor<144x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x2x5xf32> into tensor<144x2x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<144x2x2xf32> +// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%[[S6]] : tensor<144x2x2xf32>) -> tensor<144x2x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 
3], [4], [5]] output_shape [2, 2, 6, 6, 2, 2] : tensor<144x2x2xf32> into tensor<2x2x6x6x2x2xf32> +// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<2x2x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> +// CHECK-NEXT: return %[[S8]] : tensor<2x8x8x2xf32> +// CHECK-NEXT: } + +// ----- + +func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x9x9x2xf32> { + %0 = tensor.empty() : tensor<2x9x9x2xf32> + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x9x9x2xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<2x9x9x2xf32> + %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> + return %2 : tensor<2x9x9x2xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op) + transform.yield + } +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func.func @conv2d_unaligned +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x9x9x2xf32> { +// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x9x9x2xf32> +// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x9x9x2xf32>) { +// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[IN]] : f32 +// CHECK-NEXT: } -> tensor<2x9x9x2xf32> +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<3x3x6x6x5x2xf32> +// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<3x3x6x6x5x2xf32>) -> tensor<3x3x6x6x5x2xf32> +// CHECK-NEXT: %[[INPUT_BUF:.*]] = tensor.empty() : tensor<2x14x14x5xf32> +// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[ARG0]] into %[[INPUT_BUF]][0, 0, 0, 0] [2, 11, 11, 5] [1, 1, 1, 1] : tensor<2x11x11x5xf32> into tensor<2x14x14x5xf32> +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<3x3x6x6x2x5xf32> +// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[INSERTED_SLICE]] : tensor<2x14x14x5xf32>) outs(%[[S4]] : tensor<3x3x6x6x2x5xf32>) -> tensor<3x3x6x6x2x5xf32> +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x5x2xf32> into tensor<324x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x2x5xf32> into tensor<324x2x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<324x2x2xf32> +// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<324x2x5xf32>, tensor<324x5x2xf32>) outs(%[[S6]] 
: tensor<324x2x2xf32>) -> tensor<324x2x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [3, 3, 6, 6, 2, 2] : tensor<324x2x2xf32> into tensor<3x3x6x6x2x2xf32> +// CHECK-NEXT: %[[OUTPUT_BUF:.*]] = tensor.empty() : tensor<2x12x12x2xf32> +// CHECK-NEXT: %[[INSERTED_SLICE_2:.*]] = tensor.insert_slice %[[S1]] into %[[OUTPUT_BUF]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x9x9x2xf32> into tensor<2x12x12x2xf32> +// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<3x3x6x6x2x2xf32>) outs(%[[INSERTED_SLICE_2]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32> +// CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32> +// CHECK-NEXT: return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32> +// CHECK-NEXT: } From c94b1a3d2b30eefaa556b8ddf1f4767d89d72fe0 Mon Sep 17 00:00:00 2001 From: Hsiangkai Wang Date: Wed, 26 Jun 2024 09:43:43 +0100 Subject: [PATCH 3/9] Revert "[mlir][linalg] Add transform operator for Winograd Conv2D algorithm" This reverts commit 374b0d5b83ce080bea690199380e270a36ad1c52. --- .../Linalg/TransformOps/LinalgTransformOps.td | 51 ----------- .../Dialect/Linalg/Transforms/Transforms.h | 7 -- .../TransformOps/LinalgTransformOps.cpp | 25 ------ .../Linalg/Transforms/WinogradConv2D.cpp | 6 -- .../Linalg/transform-winograd-conv2d.mlir | 88 ------------------- 5 files changed, 177 deletions(-) delete mode 100644 mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 68d0f713caad4..93e2c2db729da 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -2587,55 +2587,4 @@ def MapCopyToThreadsOp : }]; } -//===----------------------------------------------------------------------===// -// Winograd Conv2D -//===----------------------------------------------------------------------===// - -def WinogradConv2DOp : Op { - let description = [{ - Winograd Conv2D algorithm will convert linalg Conv2D operator into batched - matrix multiply. Before the matrix multiply, it will convert filter and - input into a format suitable for batched matrix multiply. After the matrix - multiply, it will convert output to the final result tensor. - - The algorithm F(m x m, r x r) is - - Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A - - The size of output Y is m x m. The size of filter g is r x r. The size of - input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are - transformation matrices. - - #### Return modes: - - This operation fails if `target` is unsupported. Otherwise, the operation - succeeds and returns a handle of the sequence that replaces the original - convolution. 
- }]; - - let arguments = (ins TransformHandleTypeInterface:$target, - I64Attr:$m, - I64Attr:$r); - let results = (outs TransformHandleTypeInterface:$transformed); - - let assemblyFormat = - "$target attr-dict `:` functional-type($target, results)"; - - let builders = [ - OpBuilder<(ins "Value":$target)> - ]; - - let extraClassDeclaration = [{ - ::mlir::DiagnosedSilenceableFailure applyToOne( - ::mlir::transform::TransformRewriter &rewriter, - ::mlir::linalg::LinalgOp target, - ::mlir::transform::ApplyToEachResultList &results, - ::mlir::transform::TransformState &state); - }]; -} - #endif // LINALG_TRANSFORM_OPS diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index da107b66257a5..835aeaf2ffed3 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1312,13 +1312,6 @@ FailureOr transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS = true); -/// Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm -/// F(m x m, r x r). m is the dimension size of output and r is the dimension -/// size of filter. -FailureOr winogradConv2D(RewriterBase &rewriter, - linalg::Conv2DNhwcFhwcOp op, int64_t m, - int64_t r); - //===----------------------------------------------------------------------===// // Rewrite patterns wrapping transformations. // TODO: every single such pattern should be a close to noop wrapper around a diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index d051b29e1f06f..bc02788f9c441 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3480,31 +3480,6 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne( return DiagnosedSilenceableFailure::success(); } -//===----------------------------------------------------------------------===// -// WinogradConv2DOp -//===----------------------------------------------------------------------===// - -DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne( - transform::TransformRewriter &rewriter, linalg::LinalgOp target, - transform::ApplyToEachResultList &results, - transform::TransformState &state) { - rewriter.setInsertionPoint(target); - auto maybeTransformed = - TypeSwitch>(target) - .Case([&](linalg::Conv2DNhwcFhwcOp op) { - return winogradConv2D(rewriter, op, getM(), getR()); - }) - .Default([&](Operation *op) { - return rewriter.notifyMatchFailure(op, "not supported"); - }); - - if (failed(maybeTransformed)) - return emitDefaultSilenceableFailure(target); - - results.push_back(*maybeTransformed); - return DiagnosedSilenceableFailure::success(); -} - #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc" #define GET_OP_CLASSES diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp index d1f4be8bbf29a..86e834d51f2fc 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp @@ -311,12 +311,6 @@ class WinogradConv2DNhwcFhwc final } // end anonymous namespace //===----------------------------------------------------------------------===// -FailureOr winogradConv2D(RewriterBase &rewriter, - linalg::Conv2DNhwcFhwcOp op, int64_t m, - int64_t r) { - return winogradConv2DHelper(rewriter, op, 
m, r); -} - void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m, int64_t r) { MLIRContext *context = patterns.getContext(); diff --git a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir deleted file mode 100644 index 1e74fea5a1c31..0000000000000 --- a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir +++ /dev/null @@ -1,88 +0,0 @@ -// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file | FileCheck %s - -func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x8x8x2xf32> { - %0 = tensor.empty() : tensor<2x8x8x2xf32> - %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x8x8x2xf32>) { - ^bb0(%in: f32, %out: f32): - linalg.yield %in : f32 - } -> tensor<2x8x8x2xf32> - %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> - return %2 : tensor<2x8x8x2xf32> -} - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op) - transform.yield - } -} - -// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> -// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -// CHECK-LABEL: func.func @conv2d -// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x8x8x2xf32> { -// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x8x8x2xf32> -// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x8x8x2xf32>) { -// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): -// CHECK-NEXT: linalg.yield %[[IN]] : f32 -// CHECK-NEXT: } -> tensor<2x8x8x2xf32> -// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<2x2x6x6x5x2xf32> -// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<2x2x6x6x5x2xf32>) -> tensor<2x2x6x6x5x2xf32> -// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<2x2x6x6x2x5xf32> -// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<2x2x6x6x2x5xf32>) -> tensor<2x2x6x6x2x5xf32> -// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x5x2xf32> into tensor<144x5x2xf32> -// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x2x5xf32> into tensor<144x2x5xf32> -// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<144x2x2xf32> -// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%[[S6]] : tensor<144x2x2xf32>) -> tensor<144x2x2xf32> -// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 
2, 3], [4], [5]] output_shape [2, 2, 6, 6, 2, 2] : tensor<144x2x2xf32> into tensor<2x2x6x6x2x2xf32> -// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<2x2x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> -// CHECK-NEXT: return %[[S8]] : tensor<2x8x8x2xf32> -// CHECK-NEXT: } - -// ----- - -func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x9x9x2xf32> { - %0 = tensor.empty() : tensor<2x9x9x2xf32> - %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x9x9x2xf32>) { - ^bb0(%in: f32, %out: f32): - linalg.yield %in : f32 - } -> tensor<2x9x9x2xf32> - %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> - return %2 : tensor<2x9x9x2xf32> -} - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op) - transform.yield - } -} - -// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> -// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -// CHECK-LABEL: func.func @conv2d_unaligned -// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x9x9x2xf32> { -// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x9x9x2xf32> -// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x9x9x2xf32>) { -// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): -// CHECK-NEXT: linalg.yield %[[IN]] : f32 -// CHECK-NEXT: } -> tensor<2x9x9x2xf32> -// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<3x3x6x6x5x2xf32> -// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<3x3x6x6x5x2xf32>) -> tensor<3x3x6x6x5x2xf32> -// CHECK-NEXT: %[[INPUT_BUF:.*]] = tensor.empty() : tensor<2x14x14x5xf32> -// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[ARG0]] into %[[INPUT_BUF]][0, 0, 0, 0] [2, 11, 11, 5] [1, 1, 1, 1] : tensor<2x11x11x5xf32> into tensor<2x14x14x5xf32> -// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<3x3x6x6x2x5xf32> -// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[INSERTED_SLICE]] : tensor<2x14x14x5xf32>) outs(%[[S4]] : tensor<3x3x6x6x2x5xf32>) -> tensor<3x3x6x6x2x5xf32> -// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x5x2xf32> into tensor<324x5x2xf32> -// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x2x5xf32> into tensor<324x2x5xf32> -// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<324x2x2xf32> -// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<324x2x5xf32>, tensor<324x5x2xf32>) 
outs(%[[S6]] : tensor<324x2x2xf32>) -> tensor<324x2x2xf32> -// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [3, 3, 6, 6, 2, 2] : tensor<324x2x2xf32> into tensor<3x3x6x6x2x2xf32> -// CHECK-NEXT: %[[OUTPUT_BUF:.*]] = tensor.empty() : tensor<2x12x12x2xf32> -// CHECK-NEXT: %[[INSERTED_SLICE_2:.*]] = tensor.insert_slice %[[S1]] into %[[OUTPUT_BUF]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x9x9x2xf32> into tensor<2x12x12x2xf32> -// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<3x3x6x6x2x2xf32>) outs(%[[INSERTED_SLICE_2]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32> -// CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32> -// CHECK-NEXT: return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32> -// CHECK-NEXT: } From 5a391881394094bfd747cb97bf023ed3df06923e Mon Sep 17 00:00:00 2001 From: Hsiangkai Wang Date: Wed, 26 Jun 2024 09:44:19 +0100 Subject: [PATCH 4/9] Revert "[mlir][linalg] Implement Conv2D using Winograd Conv2D algorithm" This reverts commit 4240341b4f06f1b77f63b0f619cae3804d88eb68. --- .../mlir/Dialect/Linalg/IR/LinalgOps.td | 114 ------- .../Dialect/Linalg/Transforms/Transforms.h | 4 - mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 78 ----- .../Dialect/Linalg/Transforms/CMakeLists.txt | 1 - .../Linalg/Transforms/WinogradConv2D.cpp | 321 ------------------ mlir/test/Dialect/Linalg/winograd-conv2d.mlir | 248 -------------- .../Dialect/Linalg/TestLinalgTransforms.cpp | 13 - 7 files changed, 779 deletions(-) delete mode 100644 mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp delete mode 100644 mlir/test/Dialect/Linalg/winograd-conv2d.mlir diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td index de1097b6ac27b..64c538367267d 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -154,118 +154,4 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax", let hasVerifier = 1; } -def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform"> { - let summary = "Winograd filter transform operator"; - let description = [{ - Winograd Conv2D algorithm will convert linalg Conv2D operator into batched - matrix multiply. Before the matrix multiply, it will convert filter and - input into a format suitable for batched matrix multiply. After the matrix - multiply, it will convert output to the final result tensor. - - The algorithm F(m x m, r x r) is - - Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A - - The size of output Y is m x m. The size of filter g is r x r. The size of - input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are - transformation matrices. - - This operator is defined to represent the high level concept of filter - transformation (G x g x G^T) in the Winograd Conv2D algorithm. 
- }]; - - let arguments = (ins AnyRankedTensor:$filter, - AnyRankedTensor:$output, - I64Attr:$m, - I64Attr:$r - ); - - let results = (outs AnyRankedTensor:$result); - let assemblyFormat = [{ - attr-dict - `m` `(` $m `)` - `r` `(` $r `)` - `ins` `(` $filter `:` type($filter) `)` - `outs` `(` $output `:` type($output) `)` - `->` type($result) - }]; - let hasVerifier = 1; -} - -def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform"> { - let summary = "Winograd input transform operator"; - let description = [{ - Winograd Conv2D algorithm will convert linalg Conv2D operator into batched - matrix multiply. Before the matrix multiply, it will convert filter and - input into a format suitable for batched matrix multiply. After the matrix - multiply, it will convert output to the final result tensor. - - The algorithm F(m x m, r x r) is - - Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A - - The size of output Y is m x m. The size of filter g is r x r. The size of - input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are - transformation matrices. - - This operator is defined to represent the high level concept of input - transformation (B^T x d x B) in the Winograd Conv2D algorithm. - }]; - - let arguments = (ins AnyRankedTensor:$input, - AnyRankedTensor:$output, - I64Attr:$m, - I64Attr:$r - ); - - let results = (outs AnyRankedTensor:$result); - let assemblyFormat = [{ - attr-dict - `m` `(` $m `)` - `r` `(` $r `)` - `ins` `(` $input `:` type($input) `)` - `outs` `(` $output `:` type($output) `)` - `->` type($result) - }]; - let hasVerifier = 1; -} - -def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform"> { - let summary = "Winograd output transform operator"; - let description = [{ - Winograd Conv2D algorithm will convert linalg Conv2D operator into batched - matrix multiply. Before the matrix multiply, it will convert filter and - input into a format suitable for batched matrix multiply. After the matrix - multiply, it will convert output to the final result tensor. - - The algorithm F(m x m, r x r) is - - Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A - - The size of output Y is m x m. The size of filter g is r x r. The size of - input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are - transformation matrices. - - This operator is defined to represent the high level concept of output - transformation (A^T x y x A) in the Winograd Conv2D algorithm. - }]; - - let arguments = (ins AnyRankedTensor:$value, - AnyRankedTensor:$output, - I64Attr:$m, - I64Attr:$r - ); - - let results = (outs AnyRankedTensor:$result); - let assemblyFormat = [{ - attr-dict - `m` `(` $m `)` - `r` `(` $r `)` - `ins` `(` $value `:` type($value) `)` - `outs` `(` $output `:` type($output) `)` - `->` type($result) - }]; - let hasVerifier = 1; -} - #endif // LINALG_OPS diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 835aeaf2ffed3..05e97befdec1f 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1692,10 +1692,6 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns, void populateBlockPackMatmulPatterns(RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn); -/// Patterns to apply Winograd Conv2D algorithm F(m x m, r x r). 
-void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m, - int64_t r); - } // namespace linalg } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 7bf2a5bca037f..57d126603ebd7 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -2734,84 +2734,6 @@ FailureOr> SoftmaxOp::decomposeOperation(OpBuilder &b) { return SmallVector{result}; } -//===----------------------------------------------------------------------===// -// WinogradFilterTransformOp -//===----------------------------------------------------------------------===// - -LogicalResult WinogradFilterTransformOp::verify() { - auto filterType = cast(getFilter().getType()); - auto outputType = cast(getOutput().getType()); - auto filterElemType = filterType.getElementType(); - auto outputElemType = outputType.getElementType(); - if (filterElemType != outputElemType) { - return emitOpError() << "expected element type of input " << filterElemType - << " to match element type of output " - << outputElemType; - } - - unsigned filterRank = filterType.getRank(); - if (filterRank != 4) - return emitOpError() << "expected rank of input is 4"; - - unsigned outputRank = outputType.getRank(); - if (outputRank != 6) - return emitOpError() << "expected rank of output is 6"; - - return success(); -} - -//===----------------------------------------------------------------------===// -// WinogradInputTransformOp -//===----------------------------------------------------------------------===// - -LogicalResult WinogradInputTransformOp::verify() { - auto inputType = cast(getInput().getType()); - auto outputType = cast(getOutput().getType()); - auto inputElemType = inputType.getElementType(); - auto outputElemType = outputType.getElementType(); - if (inputElemType != outputElemType) { - return emitOpError() << "expected element type of input " << inputElemType - << " to match element type of output " - << outputElemType; - } - - unsigned inputRank = inputType.getRank(); - if (inputRank != 4) - return emitOpError() << "expected rank of input is 4"; - - unsigned outputRank = outputType.getRank(); - if (outputRank != 6) - return emitOpError() << "expected rank of output is 6"; - - return success(); -} - -//===----------------------------------------------------------------------===// -// WinogradOutputTransformOp -//===----------------------------------------------------------------------===// - -LogicalResult WinogradOutputTransformOp::verify() { - auto valueType = cast(getValue().getType()); - auto outputType = cast(getOutput().getType()); - auto valueElemType = valueType.getElementType(); - auto outputElemType = outputType.getElementType(); - if (valueElemType != outputElemType) { - return emitOpError() << "expected element type of value " << valueElemType - << " to match element type of output " - << outputElemType; - } - - unsigned valueRank = valueType.getRank(); - if (valueRank != 6) - return emitOpError() << "expected rank of input is 6"; - - unsigned outputRank = outputType.getRank(); - if (outputRank != 4) - return emitOpError() << "expected rank of output is 4"; - - return success(); -} - //===----------------------------------------------------------------------===// // LinalgDialect //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt index 
a7dcc29b5b9be..7e3dc56e0acdc 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -38,7 +38,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms Transforms.cpp TransposeConv2D.cpp Vectorization.cpp - WinogradConv2D.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp deleted file mode 100644 index 86e834d51f2fc..0000000000000 --- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp +++ /dev/null @@ -1,321 +0,0 @@ -//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Implement Winograd Conv2D algorithm. The implementation is based on the -// paper: Fast Algorithms for Convolutional Neural Networks -// (https://arxiv.org/abs/1509.09308) -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/Support/MathExtras.h" - -namespace mlir { -namespace linalg { - -namespace { - -using TransformMapKeyTy = std::pair; - -// We use F(m, r) to define the size of minimal filtering algorithms. -// m is the output dimension and r is the filter dimension. We can get -// the input dimension, alpha, from the formula, alpha = m + r - 1. -// -// For example, when m = 2 and r = 3, we know its input size is 4. -// The Conv2D will operate on 4x4 input data with 3x3 filter and get -// 2x2 output result. -constexpr TransformMapKeyTy F_2_3{2, 3}; -constexpr TransformMapKeyTy F_4_3{4, 3}; -constexpr TransformMapKeyTy F_2_5{2, 5}; - -Value collapse2DData(RewriterBase &rewriter, Location loc, Value data) { - auto type = cast(data.getType()); - auto elementType = type.getElementType(); - auto shape = type.getShape(); - auto collapseType = RankedTensorType::get( - {shape[0] * shape[1] * shape[2] * shape[3], shape[4], shape[5]}, - elementType); - SmallVector reassociation = {{0, 1, 2, 3}, {4}, {5}}; - return rewriter.create(loc, collapseType, data, - reassociation); -} - -// This function generates linalg.batch_matmul to multiply input with filter. -// linalg.batch_matmul only supports 3-dimension data sets. We can treat -// tileH x tileW x H x W data as the 1-dimension data array. That is to convert -// [tileH, tileW, H, W, N, C] to [tileH x tileW x H x W, N, C]. In this way, we -// can convert 6-dimension input data to 3-dimension representation that is -// suitable for linalg.batch_matmul. -// -// Batched matmul will do the matrix multiply with the reduction on channel. -// -// We get -// -// %collapsed_input = tensor.collapse_shape %input -// %collapsed_filter = tensor.collapse_shape %filter -// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter -// %expanded_ret = tensor.expand_shape %ret -// -// After this function, we get return value with data layout -// (tileH, tileW, H, W, N, F). 
-Value matrixMultiply(RewriterBase &rewriter, Location loc, - Value transformedFilter, Value transformedInput) { - auto collapseFilter = collapse2DData(rewriter, loc, transformedFilter); - auto collapseInput = collapse2DData(rewriter, loc, transformedInput); - - // Batched matrix multiply - auto filterType = cast(transformedFilter.getType()); - auto filterShape = filterType.getShape(); - auto inputType = cast(transformedInput.getType()); - auto inputElemType = inputType.getElementType(); - auto inputShape = inputType.getShape(); - - auto matmulType = RankedTensorType::get( - {inputShape[0] * inputShape[1] * inputShape[2] * inputShape[3], - inputShape[4], filterShape[5]}, - inputElemType); - Value init = rewriter.create(loc, matmulType.getShape(), - inputElemType); - - auto matmulOp = rewriter.create( - loc, matmulType, ValueRange({collapseInput, collapseFilter}), - ValueRange{init}); - - // Expand matmul result - SmallVector reassociation = {{0, 1, 2, 3}, {4}, {5}}; - auto expandType = - RankedTensorType::get({inputShape[0], inputShape[1], inputShape[2], - inputShape[3], inputShape[4], filterShape[5]}, - inputElemType); - auto expandOutput = rewriter.create( - loc, expandType, matmulOp.getResult(0), reassociation); - return expandOutput; -} - -Value insertToAlignedTensor(RewriterBase &rewriter, Location loc, Value value, - RankedTensorType alignedType) { - Value alignedInput = rewriter.create( - loc, alignedType.getShape(), alignedType.getElementType()); - - auto zeroIndex = rewriter.getIndexAttr(0); - auto oneIndex = rewriter.getIndexAttr(1); - SmallVector offsets(4, zeroIndex); - SmallVector strides(4, oneIndex); - - auto valueType = cast(value.getType()); - auto valueShape = valueType.getShape(); - SmallVector sizes; - sizes.emplace_back(rewriter.getIndexAttr(valueShape[0])); - sizes.emplace_back(rewriter.getIndexAttr(valueShape[1])); - sizes.emplace_back(rewriter.getIndexAttr(valueShape[2])); - sizes.emplace_back(rewriter.getIndexAttr(valueShape[3])); - - return rewriter.create(loc, value, alignedInput, - offsets, sizes, strides); -} - -Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc, - Value value, RankedTensorType extractedType) { - auto zeroIndex = rewriter.getIndexAttr(0); - auto oneIndex = rewriter.getIndexAttr(1); - SmallVector offsets(4, zeroIndex); - SmallVector strides(4, oneIndex); - - auto extractedShape = extractedType.getShape(); - SmallVector sizes; - sizes.emplace_back(rewriter.getIndexAttr(extractedShape[0])); - sizes.emplace_back(rewriter.getIndexAttr(extractedShape[1])); - sizes.emplace_back(rewriter.getIndexAttr(extractedShape[2])); - sizes.emplace_back(rewriter.getIndexAttr(extractedShape[3])); - - return rewriter.create(loc, extractedType, value, - offsets, sizes, strides); -} - -bool hasAllOneValues(DenseIntElementsAttr attr) { - return llvm::all_of( - attr, [](const APInt &element) { return element.getSExtValue() == 1; }); -} - -FailureOr winogradConv2DHelper(RewriterBase &rewriter, - linalg::Conv2DNhwcFhwcOp convOp, - int64_t m, int64_t r) { - Value input = convOp.getInputs()[0]; - Value filter = convOp.getInputs()[1]; - Value output = convOp.getOutputs()[0]; - auto inputType = cast(input.getType()); - auto filterType = cast(filter.getType()); - auto outputType = cast(output.getType()); - - if (!inputType.hasStaticShape()) - return rewriter.notifyMatchFailure(convOp, - "expected a static shape for the input"); - - if (!filterType.hasStaticShape()) - return rewriter.notifyMatchFailure( - convOp, "expected a static shape for the filter"); - - 
if (!hasAllOneValues(convOp.getDilations())) - return rewriter.notifyMatchFailure(convOp, - "expected all ones for dilations"); - - if (!hasAllOneValues(convOp.getStrides())) - return rewriter.notifyMatchFailure(convOp, "expected all ones for strides"); - - auto filterShape = filterType.getShape(); - int64_t filterF = filterShape[0]; - int64_t filterH = filterShape[1]; - int64_t filterW = filterShape[2]; - int64_t filterC = filterShape[3]; - auto inputShape = inputType.getShape(); - int64_t inputN = inputShape[0]; - int64_t inputH = inputShape[1]; - int64_t inputW = inputShape[2]; - int64_t inputC = inputShape[3]; - auto outputShape = outputType.getShape(); - int64_t outputN = outputShape[0]; - int64_t outputH = outputShape[1]; - int64_t outputW = outputShape[2]; - int64_t outputF = outputShape[3]; - - // Only support F(m x m, r x r), F(m x 1, r x 1) or F(1 x m, 1 x r) - bool isSupportedFilter = false; - if (filterH == filterW && filterH == r) - isSupportedFilter = true; - if (filterH == r && filterW == 1) - isSupportedFilter = true; - if (filterH == 1 && filterW == r) - isSupportedFilter = true; - - if (!isSupportedFilter) - return rewriter.notifyMatchFailure( - convOp, "only support filter (r x r), (r x 1) or (1 x r)"); - - // Currently, we support (m, r) = (2, 3) or (4, 3) or (2, 5) - static const llvm::SmallVector validConfigs = { - F_2_3, F_4_3, F_2_5}; - - TransformMapKeyTy key = {m, r}; - auto it = std::find(validConfigs.begin(), validConfigs.end(), key); - // If we cannot find the constant transformation matrix, it means we do - // not support this configuration yet. - if (it == validConfigs.end()) - return failure(); - - // All the criterias are satisfied. We can do Winograd Conv2D. - Location loc = convOp.getLoc(); - - // For F(m x 1, r x 1), we only need to do left side transform. - bool leftTransform = filterH != 1; - // For F(1 x m, 1 x r), we only need to do right side transform. - bool rightTransform = filterW != 1; - int64_t heightM = leftTransform ? m : 1; - int64_t widthM = rightTransform ? m : 1; - int64_t heightR = leftTransform ? r : 1; - int64_t widthR = rightTransform ? r : 1; - - // --- Create operator for filter transform --- - Type elementType = filterType.getElementType(); - int64_t alphaH = heightM + heightR - 1; - int64_t alphaW = widthM + widthR - 1; - int64_t tileH = llvm::divideCeilSigned(outputH, heightM); - int64_t tileW = llvm::divideCeilSigned(outputW, widthM); - auto retType = RankedTensorType::get( - {tileH, tileW, alphaH, alphaW, filterC, filterF}, elementType); - Value retValue = - rewriter.create(loc, retType.getShape(), elementType); - auto transformedFilter = rewriter.create( - loc, retType, filter, retValue, m, r); - - // --- Create operator for input transform --- - - // When input size - (r - 1) is not aligned with output tile size, we need to - // pad the input data to create the full tiles as tiling. 
- int64_t alignedInputH = tileH * heightM + (heightR - 1); - int64_t alignedInputW = tileW * widthM + (widthR - 1); - if (alignedInputH != inputH || alignedInputW != inputW) { - auto alignedInputType = RankedTensorType::get( - {inputN, alignedInputH, alignedInputW, inputC}, elementType); - input = insertToAlignedTensor(rewriter, loc, input, alignedInputType); - } - - retType = RankedTensorType::get( - {tileH, tileW, alphaH, alphaW, inputN, inputC}, elementType); - retValue = - rewriter.create(loc, retType.getShape(), elementType); - auto transformedInput = rewriter.create( - loc, retType, input, retValue, m, r); - - Value matmulRet = - matrixMultiply(rewriter, loc, transformedFilter, transformedInput); - - // --- Create operator for output transform --- - - // When output size is not aligned with output tile size, we need to pad the - // output buffer to insert the full tiles after tiling. - int64_t alignedOutputH = tileH * heightM; - int64_t alignedOutputW = tileW * widthM; - bool isOutputUnaligned = - ((alignedOutputH != outputH) || (alignedOutputW != outputW)); - if (isOutputUnaligned) { - auto alignedOutputType = RankedTensorType::get( - {outputN, alignedOutputH, alignedOutputW, outputF}, elementType); - output = insertToAlignedTensor(rewriter, loc, output, alignedOutputType); - outputType = alignedOutputType; - } - - Value transformedOutput = rewriter.create( - loc, outputType, matmulRet, output, m, r); - - // When output size is not aligned with output tile size, extract the - // value from the padded buffer. - if (isOutputUnaligned) { - transformedOutput = extractFromAlignedTensor( - rewriter, loc, transformedOutput, - RankedTensorType::get({outputN, outputH, outputW, outputF}, - elementType)); - } - - rewriter.replaceOp(convOp, transformedOutput); - - return transformedOutput.getDefiningOp(); -} - -class WinogradConv2DNhwcFhwc final - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - WinogradConv2DNhwcFhwc(mlir::MLIRContext *context, int64_t m, int64_t r) - : OpRewritePattern(context), m(m), r(r) {} - - LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp, - PatternRewriter &rewriter) const override { - if (failed(winogradConv2DHelper(rewriter, convOp, m, r))) - return failure(); - - return success(); - } - -private: - int64_t m; - int64_t r; -}; -} // end anonymous namespace - -//===----------------------------------------------------------------------===// -void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m, - int64_t r) { - MLIRContext *context = patterns.getContext(); - patterns.insert(context, m, r); -} - -} // end namespace linalg -} // end namespace mlir diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir deleted file mode 100644 index 6cca3c602d4c0..0000000000000 --- a/mlir/test/Dialect/Linalg/winograd-conv2d.mlir +++ /dev/null @@ -1,248 +0,0 @@ -// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-winograd-conv2d | FileCheck %s - -func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> { - %0 = tensor.empty() : tensor<2x4x4x2xf32> - %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) { - ^bb0(%in: f32, %out: f32): - linalg.yield %in : f32 - } -> tensor<2x4x4x2xf32> - 
%2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> - return %2 : tensor<2x4x4x2xf32> -} - -// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> -// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -// CHECK-LABEL: func.func @conv2d_4x4_3x3 -// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x4x4x2xf32> { -// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x4x4x2xf32> -// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x4x4x2xf32>) { -// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): -// CHECK-NEXT: linalg.yield %[[IN]] : f32 -// CHECK-NEXT: } -> tensor<2x4x4x2xf32> -// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<1x1x6x6x5x2xf32> -// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<1x1x6x6x5x2xf32>) -> tensor<1x1x6x6x5x2xf32> -// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<1x1x6x6x2x5xf32> -// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<1x1x6x6x2x5xf32>) -> tensor<1x1x6x6x2x5xf32> -// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x5x2xf32> into tensor<36x5x2xf32> -// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x2x5xf32> into tensor<36x2x5xf32> -// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32> -// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32> -// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<1x1x6x6x2x2xf32> -// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x1x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> -// CHECK-NEXT: return %[[S8]] : tensor<2x4x4x2xf32> -// CHECK-NEXT: } - -// ----- - -func.func @conv2d_2x2_5x5(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x5x5x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x2x2x2xf32> { - %0 = tensor.empty() : tensor<2x2x2x2xf32> - %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x2x2x2xf32>) { - ^bb0(%in: f32, %out: f32): - linalg.yield %in : f32 - } -> tensor<2x2x2x2xf32> - %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf32>, tensor<2x5x5x5xf32>) outs(%1 : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> - return %2 : tensor<2x2x2x2xf32> -} - -// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> -// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -// CHECK-LABEL: func.func @conv2d_2x2_5x5 -// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x5x5x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> 
tensor<2x2x2x2xf32> { -// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x2x2x2xf32> -// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x2x2x2xf32>) { -// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): -// CHECK-NEXT: linalg.yield %[[IN]] : f32 -// CHECK-NEXT: } -> tensor<2x2x2x2xf32> -// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<1x1x6x6x5x2xf32> -// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(2) r(5) ins(%[[ARG1]] : tensor<2x5x5x5xf32>) outs(%[[S2]] : tensor<1x1x6x6x5x2xf32>) -> tensor<1x1x6x6x5x2xf32> -// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<1x1x6x6x2x5xf32> -// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(2) r(5) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<1x1x6x6x2x5xf32>) -> tensor<1x1x6x6x2x5xf32> -// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x5x2xf32> into tensor<36x5x2xf32> -// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x2x5xf32> into tensor<36x2x5xf32> -// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32> -// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32> -// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<1x1x6x6x2x2xf32> -// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(2) r(5) ins(%[[EXPANDED]] : tensor<1x1x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> -// CHECK-NEXT: return %[[S8]] : tensor<2x2x2x2xf32> -// CHECK-NEXT: } - -// ----- - -func.func @conv2d_1x4_1x3(%arg0: tensor<2x1x6x5xf32>, %arg1: tensor<2x1x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x1x4x2xf32> { - %0 = tensor.empty() : tensor<2x1x4x2xf32> - %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x1x4x2xf32>) { - ^bb0(%in: f32, %out: f32): - linalg.yield %in : f32 - } -> tensor<2x1x4x2xf32> - %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x1x6x5xf32>, tensor<2x1x3x5xf32>) outs(%1 : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32> - return %2 : tensor<2x1x4x2xf32> -} - -// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> -// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -// CHECK-LABEL: func.func @conv2d_1x4_1x3 -// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x1x6x5xf32>, %[[ARG1:.*]]: tensor<2x1x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x1x4x2xf32> { -// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x1x4x2xf32> -// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x1x4x2xf32>) { -// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): -// CHECK-NEXT: linalg.yield %[[IN]] : f32 -// CHECK-NEXT: } -> tensor<2x1x4x2xf32> -// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<1x1x1x6x5x2xf32> -// CHECK-NEXT: %[[S3:.*]] = 
linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x1x3x5xf32>) outs(%[[S2]] : tensor<1x1x1x6x5x2xf32>) -> tensor<1x1x1x6x5x2xf32> -// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<1x1x1x6x2x5xf32> -// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x1x6x5xf32>) outs(%[[S4]] : tensor<1x1x1x6x2x5xf32>) -> tensor<1x1x1x6x2x5xf32> -// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x1x6x5x2xf32> into tensor<6x5x2xf32> -// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x1x6x2x5xf32> into tensor<6x2x5xf32> -// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32> -// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32> -// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 1, 6, 2, 2] : tensor<6x2x2xf32> into tensor<1x1x1x6x2x2xf32> -// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x1x1x6x2x2xf32>) outs(%[[S1]] : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32> -// CHECK-NEXT: return %[[S8]] : tensor<2x1x4x2xf32> -// CHECK-NEXT: } - -// ----- - -func.func @conv2d_4x1_3x1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x1x2xf32> { - %0 = tensor.empty() : tensor<2x4x1x2xf32> - %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x1x2xf32>) { - ^bb0(%in: f32, %out: f32): - linalg.yield %in : f32 - } -> tensor<2x4x1x2xf32> - %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x1x5xf32>, tensor<2x3x1x5xf32>) outs(%1 : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> - return %2 : tensor<2x4x1x2xf32> -} - -// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> -// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -// CHECK-LABEL: func.func @conv2d_4x1_3x1 -// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x4x1x2xf32> { -// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x4x1x2xf32> -// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x4x1x2xf32>) { -// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): -// CHECK-NEXT: linalg.yield %[[IN]] : f32 -// CHECK-NEXT: } -> tensor<2x4x1x2xf32> -// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<1x1x6x1x5x2xf32> -// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x1x5xf32>) outs(%[[S2]] : tensor<1x1x6x1x5x2xf32>) -> tensor<1x1x6x1x5x2xf32> -// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<1x1x6x1x2x5xf32> -// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x1x5xf32>) outs(%[[S4]] : tensor<1x1x6x1x2x5xf32>) -> tensor<1x1x6x1x2x5xf32> -// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x1x5x2xf32> into tensor<6x5x2xf32> -// CHECK-NEXT: %[[COLLAPSED_0:.*]] 
= tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x1x2x5xf32> into tensor<6x2x5xf32> -// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32> -// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32> -// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 1, 2, 2] : tensor<6x2x2xf32> into tensor<1x1x6x1x2x2xf32> -// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x1x6x1x2x2xf32>) outs(%[[S1]] : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> -// CHECK-NEXT: return %[[S8]] : tensor<2x4x1x2xf32> -// CHECK-NEXT: } - -// ----- - -func.func @conv2d_aligned(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x8x8x2xf32> { - %0 = tensor.empty() : tensor<2x8x8x2xf32> - %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x8x8x2xf32>) { - ^bb0(%in: f32, %out: f32): - linalg.yield %in : f32 - } -> tensor<2x8x8x2xf32> - %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> - return %2 : tensor<2x8x8x2xf32> -} - -// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> -// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -// CHECK-LABEL: func.func @conv2d_aligned -// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x8x8x2xf32> { -// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x8x8x2xf32> -// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x8x8x2xf32>) { -// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): -// CHECK-NEXT: linalg.yield %[[IN]] : f32 -// CHECK-NEXT: } -> tensor<2x8x8x2xf32> -// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<2x2x6x6x5x2xf32> -// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<2x2x6x6x5x2xf32>) -> tensor<2x2x6x6x5x2xf32> -// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<2x2x6x6x2x5xf32> -// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<2x2x6x6x2x5xf32>) -> tensor<2x2x6x6x2x5xf32> -// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x5x2xf32> into tensor<144x5x2xf32> -// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x2x5xf32> into tensor<144x2x5xf32> -// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<144x2x2xf32> -// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%[[S6]] : tensor<144x2x2xf32>) -> tensor<144x2x2xf32> -// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [2, 2, 6, 6, 2, 2] : tensor<144x2x2xf32> into tensor<2x2x6x6x2x2xf32> -// CHECK-NEXT: 
%[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<2x2x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> -// CHECK-NEXT: return %[[S8]] : tensor<2x8x8x2xf32> -// CHECK-NEXT: } - -// ----- - -func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x9x9x2xf32> { - %0 = tensor.empty() : tensor<2x9x9x2xf32> - %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x9x9x2xf32>) { - ^bb0(%in: f32, %out: f32): - linalg.yield %in : f32 - } -> tensor<2x9x9x2xf32> - %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> - return %2 : tensor<2x9x9x2xf32> -} - -// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> -// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -// CHECK-LABEL: func.func @conv2d_unaligned -// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x9x9x2xf32> { -// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x9x9x2xf32> -// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x9x9x2xf32>) { -// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): -// CHECK-NEXT: linalg.yield %[[IN]] : f32 -// CHECK-NEXT: } -> tensor<2x9x9x2xf32> -// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<3x3x6x6x5x2xf32> -// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<3x3x6x6x5x2xf32>) -> tensor<3x3x6x6x5x2xf32> -// CHECK-NEXT: %[[INPUT_BUF:.*]] = tensor.empty() : tensor<2x14x14x5xf32> -// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[ARG0]] into %[[INPUT_BUF]][0, 0, 0, 0] [2, 11, 11, 5] [1, 1, 1, 1] : tensor<2x11x11x5xf32> into tensor<2x14x14x5xf32> -// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<3x3x6x6x2x5xf32> -// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[INSERTED_SLICE]] : tensor<2x14x14x5xf32>) outs(%[[S4]] : tensor<3x3x6x6x2x5xf32>) -> tensor<3x3x6x6x2x5xf32> -// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x5x2xf32> into tensor<324x5x2xf32> -// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x2x5xf32> into tensor<324x2x5xf32> -// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<324x2x2xf32> -// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<324x2x5xf32>, tensor<324x5x2xf32>) outs(%[[S6]] : tensor<324x2x2xf32>) -> tensor<324x2x2xf32> -// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [3, 3, 6, 6, 2, 2] : tensor<324x2x2xf32> into tensor<3x3x6x6x2x2xf32> -// CHECK-NEXT: %[[OUTPUT_BUF:.*]] = tensor.empty() : tensor<2x12x12x2xf32> -// CHECK-NEXT: %[[INSERTED_SLICE_2:.*]] = tensor.insert_slice %[[S1]] into %[[OUTPUT_BUF]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x9x9x2xf32> into tensor<2x12x12x2xf32> -// CHECK-NEXT: %[[S8:.*]] = 
linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<3x3x6x6x2x2xf32>) outs(%[[INSERTED_SLICE_2]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32> -// CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32> -// CHECK-NEXT: return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32> -// CHECK-NEXT: } - -// ----- - -func.func @conv2d_unsupported_1(%arg0: tensor<2x6x5x5xf32>, %arg1: tensor<2x3x2x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> { - %0 = tensor.empty() : tensor<2x4x4x2xf32> - %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) { - ^bb0(%in: f32, %out: f32): - linalg.yield %in : f32 - } -> tensor<2x4x4x2xf32> - %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x5x5xf32>, tensor<2x3x2x5xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> - return %2 : tensor<2x4x4x2xf32> -} - -// CHECK-LABEL: conv2d_unsupported_1 -// CHECK: linalg.conv_2d_nhwc_fhwc - -// ----- - -func.func @conv2d_unsupported_2(%arg0: tensor<2x7x7x5xf32>, %arg1: tensor<2x4x4x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> { - %0 = tensor.empty() : tensor<2x4x4x2xf32> - %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) { - ^bb0(%in: f32, %out: f32): - linalg.yield %in : f32 - } -> tensor<2x4x4x2xf32> - %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x7x7x5xf32>, tensor<2x4x4x5xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> - return %2 : tensor<2x4x4x2xf32> -} - -// CHECK-LABEL: conv2d_unsupported_2 -// CHECK: linalg.conv_2d_nhwc_fhwc - -// ----- - -func.func @conv2d_unsupported_3(%arg0: tensor, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor) -> tensor { - %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor, tensor<2x3x3x5xf32>) outs(%arg2 : tensor) -> tensor - return %0 : tensor -} - -// CHECK-LABEL: conv2d_unsupported_3 -// CHECK: linalg.conv_2d_nhwc_fhwc diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp index 12cb46a5968f1..4892fa2f99a7c 100644 --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -123,10 +123,6 @@ struct TestLinalgTransforms *this, "test-erase-unnecessary-inputs", llvm::cl::desc("Test patterns to erase unnecessary inputs"), llvm::cl::init(false)}; - Option testWinogradConv2D{ - *this, "test-winograd-conv2d", - llvm::cl::desc("Test transform conv2d by Winograd conv2d algorithm"), - llvm::cl::init(false)}; }; } // namespace @@ -211,13 +207,6 @@ static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) { (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } -static void applyWinogradConv2D(func::FuncOp funcOp) { - RewritePatternSet patterns(funcOp.getContext()); - populateWinogradConv2DPatterns(patterns, /*m=*/4, /*r=*/3); - populateWinogradConv2DPatterns(patterns, /*m=*/2, 
/*r=*/5); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); -} - /// Apply transformations specified as patterns. void TestLinalgTransforms::runOnOperation() { if (testPatterns) @@ -242,8 +231,6 @@ void TestLinalgTransforms::runOnOperation() { return applyEraseUnusedOperandsAndResultsPatterns(getOperation()); if (testEraseUnnecessaryInputs) return applyEraseUnnecessaryInputs(getOperation()); - if (testWinogradConv2D) - return applyWinogradConv2D(getOperation()); } namespace mlir { From 690662771c806a2f7301bdc4dedc983047c41d35 Mon Sep 17 00:00:00 2001 From: Hsiangkai Wang Date: Mon, 17 Jun 2024 11:24:07 +0100 Subject: [PATCH 5/9] [mlir][linalg] Implement Conv2D using Winograd Conv2D algorithm Define high level winograd operators and convert conv_2d_nhwc_fhwc into winograd operators. According to Winograd Conv2D algorithm, we need three transform operators for input, filter, and output transformation. The formula of Winograd Conv2D algorithm is Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A filter transform: G x g x G^T input transform: B^T x d x B output transform: A^T x y x A The implementation is based on the paper, Fast Algorithm for Convolutional Neural Networks. (https://arxiv.org/abs/1509.09308) --- .../mlir/Dialect/Linalg/IR/LinalgOps.td | 117 ++++++ .../Dialect/Linalg/Transforms/Transforms.h | 4 + mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 107 ++++++ .../Dialect/Linalg/Transforms/CMakeLists.txt | 1 + .../Linalg/Transforms/WinogradConv2D.cpp | 334 ++++++++++++++++++ mlir/test/Dialect/Linalg/winograd-conv2d.mlir | 193 ++++++++++ .../Dialect/Linalg/TestLinalgTransforms.cpp | 13 + 7 files changed, 769 insertions(+) create mode 100644 mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp create mode 100644 mlir/test/Dialect/Linalg/winograd-conv2d.mlir diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td index 64c538367267d..a9007c8db3078 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -154,4 +154,121 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax", let hasVerifier = 1; } +def Linalg_WinogradFilterTransformOp : + Linalg_Op<"winograd_filter_transform", [AllElementTypesMatch<["filter", "output"]>]> { + let summary = "Winograd filter transform operator"; + let description = [{ + Winograd Conv2D algorithm will convert linalg Conv2D operator into batched + matrix multiply. Before the matrix multiply, it will convert filter and + input into a format suitable for batched matrix multiply. After the matrix + multiply, it will convert output to the final result tensor. + + The algorithm F(m x m, r x r) is + + Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A + + The size of output Y is m x m. The size of filter g is r x r. The size of + input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are + transformation matrices. + + This operator is defined to represent the high level concept of filter + transformation (G x g x G^T) in the Winograd Conv2D algorithm. 
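+
+    Example, mirroring the tests added in this patch (SSA names are
+    illustrative), where a 3 x 3 filter with 5 input channels and 2 output
+    filters becomes a 6 x 6 Winograd-domain filter:
+
+    ```mlir
+    %1 = linalg.winograd_filter_transform m(4) r(3)
+      ins(%arg1 : tensor<2x3x3x5xf32>)
+      outs(%0 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+    ```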
+ }]; + + let arguments = (ins TensorRankOf<[AnyType], [4]>:$filter, + TensorRankOf<[AnyType], [4]>:$output, + I64Attr:$m, + I64Attr:$r + ); + + let results = (outs TensorRankOf<[AnyType], [4]>:$result); + let assemblyFormat = [{ + attr-dict + `m` `(` $m `)` + `r` `(` $r `)` + `ins` `(` $filter `:` type($filter) `)` + `outs` `(` $output `:` type($output) `)` + `->` type($result) + }]; + let hasVerifier = 1; +} + +def Linalg_WinogradInputTransformOp : + Linalg_Op<"winograd_input_transform", [AllElementTypesMatch<["input", "output"]>]> { + let summary = "Winograd input transform operator"; + let description = [{ + Winograd Conv2D algorithm will convert linalg Conv2D operator into batched + matrix multiply. Before the matrix multiply, it will convert filter and + input into a format suitable for batched matrix multiply. After the matrix + multiply, it will convert output to the final result tensor. + + The algorithm F(m x m, r x r) is + + Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A + + The size of output Y is m x m. The size of filter g is r x r. The size of + input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are + transformation matrices. + + This operator is defined to represent the high level concept of input + transformation (B^T x d x B) in the Winograd Conv2D algorithm. + }]; + + let arguments = (ins TensorRankOf<[AnyType], [4]>:$input, + TensorRankOf<[AnyType], [6]>:$output, + I64Attr:$m, + I64Attr:$r + ); + + let results = (outs TensorRankOf<[AnyType], [6]>:$result); + let assemblyFormat = [{ + attr-dict + `m` `(` $m `)` + `r` `(` $r `)` + `ins` `(` $input `:` type($input) `)` + `outs` `(` $output `:` type($output) `)` + `->` type($result) + }]; + let hasVerifier = 1; +} + +def Linalg_WinogradOutputTransformOp : + Linalg_Op<"winograd_output_transform", [AllElementTypesMatch<["value", "output"]>]> { + let summary = "Winograd output transform operator"; + let description = [{ + Winograd Conv2D algorithm will convert linalg Conv2D operator into batched + matrix multiply. Before the matrix multiply, it will convert filter and + input into a format suitable for batched matrix multiply. After the matrix + multiply, it will convert output to the final result tensor. + + The algorithm F(m x m, r x r) is + + Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A + + The size of output Y is m x m. The size of filter g is r x r. The size of + input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are + transformation matrices. + + This operator is defined to represent the high level concept of output + transformation (A^T x y x A) in the Winograd Conv2D algorithm. 
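+
+    Example, mirroring the tests added in this patch (SSA names are
+    illustrative), where one 6 x 6 Winograd-domain tile is turned back into
+    a 4 x 4 output tile:
+
+    ```mlir
+    %2 = linalg.winograd_output_transform m(4) r(3)
+      ins(%expanded : tensor<6x6x1x1x2x2xf32>)
+      outs(%arg3 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+    ```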
+  }];
+
+  let arguments = (ins TensorRankOf<[AnyType], [6]>:$value,
+                   TensorRankOf<[AnyType], [4]>:$output,
+                   I64Attr:$m,
+                   I64Attr:$r
+  );
+
+  let results = (outs TensorRankOf<[AnyType], [4]>:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $value `:` type($value) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // LINALG_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 05e97befdec1f..835aeaf2ffed3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1692,6 +1692,10 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
 void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
                                      const ControlBlockPackMatmulFn &controlFn);
 
+/// Patterns to apply Winograd Conv2D algorithm F(m x m, r x r).
+void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
+                                    int64_t r);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 57d126603ebd7..1283315f2eaef 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2734,6 +2734,113 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
   return SmallVector<Value>{result};
 }
 
+//===----------------------------------------------------------------------===//
+// WinogradFilterTransformOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WinogradFilterTransformOp::verify() {
+  auto filterType = cast<ShapedType>(getFilter().getType());
+  ArrayRef<int64_t> filterShape = filterType.getShape();
+  int64_t filterH = filterShape[1];
+  int64_t filterW = filterShape[2];
+  int64_t r = getR();
+
+  if (filterH != r && filterH != 1)
+    return failure();
+  if (filterW != r && filterW != 1)
+    return failure();
+  if (filterH == 1 && filterW == 1)
+    return failure();
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// WinogradInputTransformOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WinogradInputTransformOp::verify() {
+  auto inputType = cast<ShapedType>(getInput().getType());
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+  auto outputType = cast<ShapedType>(getOutput().getType());
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+  int64_t outputH = outputShape[0];
+  int64_t outputW = outputShape[1];
+  int64_t outputTileH = outputShape[2];
+  int64_t outputTileW = outputShape[3];
+  int m = getM();
+  int r = getR();
+  bool leftTransform = inputH != 1;
+  bool rightTransform = inputW != 1;
+
+  if (!leftTransform && !rightTransform)
+    return failure();
+
+  if (leftTransform) {
+    int64_t tileH = (inputH - (r - 1)) / m;
+    if (inputH != tileH * m + (r - 1))
+      return failure();
+    if (tileH != outputTileH)
+      return failure();
+    if (outputH != m + r - 1)
+      return failure();
+  }
+
+  if (rightTransform) {
+    int64_t tileW = (inputW - (r - 1)) / m;
+    if (inputW != tileW * m + (r - 1))
+      return failure();
+    if (tileW != outputTileW)
+      return failure();
+    if (outputW != m + r - 1)
+      return failure();
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// WinogradOutputTransformOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WinogradOutputTransformOp::verify() {
+  auto valueType = cast<ShapedType>(getValue().getType());
+  ArrayRef<int64_t> valueShape = valueType.getShape();
+  int64_t valueH = valueShape[0];
+  int64_t valueW = valueShape[1];
+  int64_t valueTileH = valueShape[2];
+  int64_t valueTileW = valueShape[3];
+  auto outputType = cast<ShapedType>(getOutput().getType());
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+  int64_t outputH = outputShape[1];
+  int64_t outputW = outputShape[2];
+  int m = getM();
+  int r = getR();
+  bool leftTransform = valueH != 1;
+  bool rightTransform = valueW != 1;
+
+  if (!leftTransform && !rightTransform)
+    return failure();
+
+  if (leftTransform) {
+    if (valueH != m + r - 1)
+      return failure();
+    if (outputH != m * valueTileH)
+      return failure();
+  }
+
+  if (rightTransform) {
+    if (valueW != m + r - 1)
+      return failure();
+    if (outputW != m * valueTileW)
+      return failure();
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // LinalgDialect
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 7e3dc56e0acdc..a7dcc29b5b9be 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -38,6 +38,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   Transforms.cpp
   TransposeConv2D.cpp
   Vectorization.cpp
+  WinogradConv2D.cpp
 
   ADDITIONAL_HEADER_DIRS
     ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
new file mode 100644
index 0000000000000..6b46f9e07abf8
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -0,0 +1,334 @@
+//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implement Winograd Conv2D algorithm. The implementation is based on the
+// paper: Fast Algorithms for Convolutional Neural Networks
+// (https://arxiv.org/abs/1509.09308)
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/MathExtras.h"
+
+namespace mlir {
+namespace linalg {
+
+namespace {
+
+using TransformMapKeyTy = std::pair<int64_t, int64_t>;
+
+/// We use F(m, r) to define the size of minimal filtering algorithms.
+/// m is the output dimension and r is the filter dimension. We can get
+/// the input dimension, alpha, from the formula, alpha = m + r - 1.
+///
+/// For example, when m = 2 and r = 3, we know its input size is 4.
+/// The Conv2D will operate on 4x4 input data with 3x3 filter and get
+/// 2x2 output result.
+constexpr TransformMapKeyTy F_2_3{2, 3};
+constexpr TransformMapKeyTy F_4_3{4, 3};
+constexpr TransformMapKeyTy F_2_5{2, 5};
+
+/// This function generates linalg.batch_matmul to multiply input with filter.
+/// linalg.batch_matmul only supports 3-dimensional inputs. We can treat
+/// alphaH x alphaW as the batch dimension and fold tileH x tileW x N into a
+/// single dimension. That is to convert [alphaH, alphaW, tileH, tileW, N, C]
+/// to [alphaH x alphaW, tileH x tileW x N, C]. In this way, we can convert
+/// 6-dimensional inputs to 3-dimensional representation that is suitable
+/// for linalg.batch_matmul.
+///
+/// Batched matmul will do the matrix multiply with the reduction on channel.
+///
+/// We get
+///
+///   %collapsed_input = tensor.collapse_shape %input
+///   %collapsed_filter = tensor.collapse_shape %filter
+///   %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter
+///   %expanded_ret = tensor.expand_shape %ret
+///
+/// After this function, we get return value with data layout
+/// (alphaH, alphaW, tileH, tileW, N, F).
+static Value matrixMultiply(RewriterBase &rewriter, Location loc,
+                            Value transformedFilter, Value transformedInput,
+                            Type outputElementType) {
+  // Convert (alphaH, alphaW, C, F) to (alphaH x alphaW, C, F) for filter.
+  auto filterType = cast<ShapedType>(transformedFilter.getType());
+  assert(filterType.hasStaticShape() && "only support static shapes.");
+  ArrayRef<int64_t> filterShape = filterType.getShape();
+  Type filterElementType = filterType.getElementType();
+  auto filterReassocType = RankedTensorType::get(
+      {filterShape[0] * filterShape[1], filterShape[2], filterShape[3]},
+      filterElementType);
+  SmallVector<ReassociationIndices> filterReassoc = {{0, 1}, {2}, {3}};
+  Value collapseFilter = rewriter.create<tensor::CollapseShapeOp>(
+      loc, filterReassocType, transformedFilter, filterReassoc);
+
+  // Convert (alphaH, alphaW, tileH, tileW, N, C) to
+  // (alphaH x alphaW, tileH x tileW x N, C) for input.
+  auto inputType = cast<ShapedType>(transformedInput.getType());
+  assert(inputType.hasStaticShape() && "only support static shapes.");
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  Type inputElementType = inputType.getElementType();
+  auto inputReassocType = RankedTensorType::get(
+      {inputShape[0] * inputShape[1],
+       inputShape[2] * inputShape[3] * inputShape[4], inputShape[5]},
+      inputElementType);
+  SmallVector<ReassociationIndices> inputReassoc = {{0, 1}, {2, 3, 4}, {5}};
+  Value collapseInput = rewriter.create<tensor::CollapseShapeOp>(
+      loc, inputReassocType, transformedInput, inputReassoc);
+
+  // Batched matrix multiply.
+  auto matmulType = RankedTensorType::get(
+      {inputShape[0] * inputShape[1],
+       inputShape[2] * inputShape[3] * inputShape[4], filterShape[3]},
+      outputElementType);
+  Value init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                outputElementType);
+
+  auto matmulOp = rewriter.create<linalg::BatchMatmulOp>(
+      loc, matmulType, ValueRange({collapseInput, collapseFilter}),
+      ValueRange{init});
+
+  // The result shape of batch matmul is (alphaH x alphaW, tileH x tileW x N, F)
+  // Expand matmul result to (alphaH, alphaW, tileH, tileW, N, F).
+  SmallVector<ReassociationIndices> outputReassoc = {{0, 1}, {2, 3, 4}, {5}};
+  auto outputReassocType =
+      RankedTensorType::get({inputShape[0], inputShape[1], inputShape[2],
+                             inputShape[3], inputShape[4], filterShape[3]},
+                            outputElementType);
+  auto expandOutput = rewriter.create<tensor::ExpandShapeOp>(
+      loc, outputReassocType, matmulOp.getResult(0), outputReassoc);
+  return expandOutput;
+}
+
+/// Pad value with zeros to create a tensor with the aligned shape. The
+/// original data is placed at offset zero and the high sides are
+/// zero-padded.
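+/// For example, padding a 2x11x11x5 input to the aligned shape 2x14x14x5
+/// produces zero padding of [0, 3, 3, 0] on the high sides, as in the
+/// conv2d_unaligned test added in this patch.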
+static Value insertToAlignedTensor(RewriterBase &rewriter, Location loc,
+                                   Value value,
+                                   ArrayRef<int64_t> alignedShape) {
+  OpFoldResult zeroIndex = rewriter.getIndexAttr(0);
+  auto valueType = cast<ShapedType>(value.getType());
+  Type elementType = valueType.getElementType();
+  ArrayRef<int64_t> valueShape = valueType.getShape();
+  SmallVector<OpFoldResult> lowIndices(alignedShape.size(), zeroIndex);
+  SmallVector<OpFoldResult> highIndices;
+  for (unsigned i = 0; i < alignedShape.size(); ++i) {
+    highIndices.emplace_back(
+        rewriter.getIndexAttr(alignedShape[i] - valueShape[i]));
+  }
+  auto alignedType = RankedTensorType::get(alignedShape, elementType);
+  Value padValue = rewriter.create<arith::ConstantOp>(
+      loc, elementType, rewriter.getZeroAttr(elementType));
+  return rewriter.create<tensor::PadOp>(loc, alignedType, value, lowIndices,
+                                        highIndices, padValue);
+}
+
+/// Extract sub-tensor with extractedType from value.
+static Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc,
+                                      Value value,
+                                      RankedTensorType extractedType) {
+  OpFoldResult zeroIndex = rewriter.getIndexAttr(0);
+  OpFoldResult oneIndex = rewriter.getIndexAttr(1);
+  SmallVector<OpFoldResult> offsets(4, zeroIndex);
+  SmallVector<OpFoldResult> strides(4, oneIndex);
+
+  ArrayRef<int64_t> extractedShape = extractedType.getShape();
+  SmallVector<OpFoldResult> sizes =
+      getAsOpFoldResult(rewriter.getI64ArrayAttr(extractedShape));
+
+  return rewriter.create<tensor::ExtractSliceOp>(loc, extractedType, value,
+                                                 offsets, sizes, strides);
+}
+
+/// Utility function to check all values in the attribute are 1.
+static bool hasAllOneValues(DenseIntElementsAttr attr) {
+  return llvm::all_of(
+      attr, [](const APInt &element) { return element.getSExtValue() == 1; });
+}
+
+/// A helper function to convert linalg.conv_2d_nhwc_fhwc to
+/// linalg.winograd_*_transform ops.
+static FailureOr<Operation *>
+winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
+                     int64_t m, int64_t r) {
+  Value input = convOp.getInputs()[0];
+  Value filter = convOp.getInputs()[1];
+  Value output = convOp.getOutputs()[0];
+  auto inputType = cast<ShapedType>(input.getType());
+  auto filterType = cast<ShapedType>(filter.getType());
+  auto outputType = cast<ShapedType>(output.getType());
+
+  if (!inputType.hasStaticShape())
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected a static shape for the input");
+
+  if (!filterType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected a static shape for the filter");
+
+  if (!hasAllOneValues(convOp.getDilations()))
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected all ones for dilations");
+
+  if (!hasAllOneValues(convOp.getStrides()))
+    return rewriter.notifyMatchFailure(convOp, "expected all ones for strides");
+
+  ArrayRef<int64_t> filterShape = filterType.getShape();
+  int64_t filterF = filterShape[0];
+  int64_t filterH = filterShape[1];
+  int64_t filterW = filterShape[2];
+  int64_t filterC = filterShape[3];
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  int64_t inputN = inputShape[0];
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+  int64_t inputC = inputShape[3];
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+  int64_t outputN = outputShape[0];
+  int64_t outputH = outputShape[1];
+  int64_t outputW = outputShape[2];
+  int64_t outputF = outputShape[3];
+
+  // Only support F(m x m, r x r), F(m x 1, r x 1) or F(1 x m, 1 x r).
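+  // Note that a 1 x 1 filter cannot match any of the cases below for the
+  // supported r values (3 and 5), so pointwise convolutions are left to
+  // other patterns.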
+  bool isSupportedFilter = false;
+  if (filterH == filterW && filterH == r)
+    isSupportedFilter = true;
+  if (filterH == r && filterW == 1)
+    isSupportedFilter = true;
+  if (filterH == 1 && filterW == r)
+    isSupportedFilter = true;
+
+  if (!isSupportedFilter)
+    return rewriter.notifyMatchFailure(
+        convOp, "only support filter (r x r), (r x 1) or (1 x r)");
+
+  // Currently, we support (m, r) = (2, 3) or (4, 3) or (2, 5).
+  static const llvm::SmallVector<TransformMapKeyTy, 3> validConfigs = {
+      F_2_3, F_4_3, F_2_5};
+
+  TransformMapKeyTy key = {m, r};
+  auto it = std::find(validConfigs.begin(), validConfigs.end(), key);
+  // If we cannot find the constant transformation matrix, it means we do
+  // not support this configuration yet.
+  if (it == validConfigs.end())
+    return failure();
+
+  // All the criteria are satisfied. We can do Winograd Conv2D.
+  Location loc = convOp.getLoc();
+
+  // For F(m x 1, r x 1), we only need to do left side transform.
+  bool leftTransform = filterH != 1;
+  // For F(1 x m, 1 x r), we only need to do right side transform.
+  bool rightTransform = filterW != 1;
+  int64_t heightM = leftTransform ? m : 1;
+  int64_t widthM = rightTransform ? m : 1;
+  int64_t heightR = leftTransform ? r : 1;
+  int64_t widthR = rightTransform ? r : 1;
+
+  // --- Create operation for filter transform ---
+  Type filterElementType = filterType.getElementType();
+  int64_t alphaH = heightM + heightR - 1;
+  int64_t alphaW = widthM + widthR - 1;
+  int64_t tileH = llvm::divideCeilSigned(outputH, heightM);
+  int64_t tileW = llvm::divideCeilSigned(outputW, widthM);
+  auto retType = RankedTensorType::get({alphaH, alphaW, filterC, filterF},
+                                       filterElementType);
+  Value retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
+                                                    filterElementType);
+  auto transformedFilter = rewriter.create<linalg::WinogradFilterTransformOp>(
+      loc, retType, filter, retValue, m, r);
+
+  // --- Create operation for input transform ---
+
+  // When input size - (r - 1) is not aligned with output tile size, we need to
+  // pad the input data to create the full tiles as tiling.
+  Type inputElementType = inputType.getElementType();
+  int64_t alignedInputH = tileH * heightM + (heightR - 1);
+  int64_t alignedInputW = tileW * widthM + (widthR - 1);
+  if (alignedInputH != inputH || alignedInputW != inputW) {
+    input = insertToAlignedTensor(
+        rewriter, loc, input, {inputN, alignedInputH, alignedInputW, inputC});
+  }
+
+  retType = RankedTensorType::get(
+      {alphaH, alphaW, tileH, tileW, inputN, inputC}, inputElementType);
+  retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
+                                              inputElementType);
+  auto transformedInput = rewriter.create<linalg::WinogradInputTransformOp>(
+      loc, retType, input, retValue, m, r);
+
+  Type outputElementType = outputType.getElementType();
+  Value matmulRet = matrixMultiply(rewriter, loc, transformedFilter,
+                                   transformedInput, outputElementType);
+
+  // --- Create operation for output transform ---
+
+  // When output size is not aligned with output tile size, we need to pad the
+  // output buffer to insert the full tiles after tiling.
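+  // For example, a 2x9x9x2 output with m = 4 needs 3 x 3 output tiles and is
+  // padded to 2x12x12x2; the padded region is stripped off again after the
+  // output transform.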
+  int64_t alignedOutputH = tileH * heightM;
+  int64_t alignedOutputW = tileW * widthM;
+  bool isOutputUnaligned =
+      ((alignedOutputH != outputH) || (alignedOutputW != outputW));
+  if (isOutputUnaligned) {
+    auto alignedOutputType = RankedTensorType::get(
+        {outputN, alignedOutputH, alignedOutputW, outputF}, outputElementType);
+    output = insertToAlignedTensor(rewriter, loc, output,
+                                   alignedOutputType.getShape());
+    outputType = alignedOutputType;
+  }
+
+  Value transformedOutput = rewriter.create<linalg::WinogradOutputTransformOp>(
+      loc, outputType, matmulRet, output, m, r);
+
+  // When output size is not aligned with output tile size, extract the
+  // value from the padded buffer.
+  if (isOutputUnaligned) {
+    transformedOutput = extractFromAlignedTensor(
+        rewriter, loc, transformedOutput,
+        RankedTensorType::get({outputN, outputH, outputW, outputF},
+                              outputElementType));
+  }
+
+  rewriter.replaceOp(convOp, transformedOutput);
+
+  return transformedOutput.getDefiningOp();
+}
+
+/// A rewrite pattern for Winograd Conv2D algorithm.
+class WinogradConv2DNhwcFhwc final
+    : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  WinogradConv2DNhwcFhwc(mlir::MLIRContext *context, int64_t m, int64_t r)
+      : OpRewritePattern(context), m(m), r(r) {}
+
+  LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(winogradConv2DHelper(rewriter, convOp, m, r)))
+      return failure();
+
+    return success();
+  }
+
+private:
+  int64_t m;
+  int64_t r;
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
+                                    int64_t r) {
+  MLIRContext *context = patterns.getContext();
+  patterns.insert<WinogradConv2DNhwcFhwc>(context, m, r);
+}
+
+} // end namespace linalg
+} // end namespace mlir
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
new file mode 100644
index 0000000000000..ec11a6ef8fbee
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
@@ -0,0 +1,193 @@
+// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-winograd-conv2d | FileCheck %s
+
+func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf32>, tensor<2x3x3x5xf32>) outs(%out : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  return %0 : tensor<2x4x4x2xf32>
+}
+
+// CHECK-LABEL: func.func @conv2d_4x4_3x3
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT:  %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 
3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32> +// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32> +// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> +// CHECK-NEXT: return %[[S8]] : tensor<2x4x4x2xf32> +// CHECK-NEXT: } + +// ----- + +func.func @conv2d_2x2_5x5(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x5x5x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> { + %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf32>, tensor<2x5x5x5xf32>) outs(%out : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> + return %0 : tensor<2x2x2x2xf32> +} + +// CHECK-LABEL: func.func @conv2d_2x2_5x5 +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x5x5x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> { +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32> +// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(2) r(5) ins(%[[ARG1]] : tensor<2x5x5x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf32> +// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(2) r(5) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32> +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32> +// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32> +// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(2) r(5) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> +// CHECK-NEXT: return %[[S8]] : tensor<2x2x2x2xf32> +// CHECK-NEXT: } + +// ----- + +func.func @conv2d_1x4_1x3(%arg0: tensor<2x1x6x5xf32>, %arg1: tensor<2x1x3x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32> { + %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x1x6x5xf32>, tensor<2x1x3x5xf32>) outs(%out : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32> + return %0 : tensor<2x1x4x2xf32> +} + +// CHECK-LABEL: func.func @conv2d_1x4_1x3 +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x1x6x5xf32>, %[[ARG1:.*]]: tensor<2x1x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32> { +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<1x6x5x2xf32> +// 
CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x1x3x5xf32>) outs(%[[S2]] : tensor<1x6x5x2xf32>) -> tensor<1x6x5x2xf32> +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<1x6x1x1x2x5xf32> +// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x1x6x5xf32>) outs(%[[S4]] : tensor<1x6x1x1x2x5xf32>) -> tensor<1x6x1x1x2x5xf32> +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<1x6x5x2xf32> into tensor<6x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<1x6x1x1x2x5xf32> into tensor<6x2x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32> +// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [1, 6, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<1x6x1x1x2x2xf32> +// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32> +// CHECK-NEXT: return %[[S8]] : tensor<2x1x4x2xf32> +// CHECK-NEXT: } + +// ----- + +func.func @conv2d_4x1_3x1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> { + %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x1x5xf32>, tensor<2x3x1x5xf32>) outs(%out : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> + return %0 : tensor<2x4x1x2xf32> +} + +// CHECK-LABEL: func.func @conv2d_4x1_3x1 +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> { +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x1x5x2xf32> +// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x1x5xf32>) outs(%[[S2]] : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32> +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<6x1x1x1x2x5xf32> +// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x1x5xf32>) outs(%[[S4]] : tensor<6x1x1x1x2x5xf32>) -> tensor<6x1x1x1x2x5xf32> +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x1x1x1x2x5xf32> into tensor<6x2x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32> +// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32> +// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x1x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> +// CHECK-NEXT: return %[[S8]] : tensor<2x4x1x2xf32> +// CHECK-NEXT: } + +// ----- + +func.func @conv2d_aligned(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: 
tensor<1xf32>, %out: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> { + %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%out : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> + return %0 : tensor<2x8x8x2xf32> +} + +// CHECK-LABEL: func.func @conv2d_aligned +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> { +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32> +// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32> +// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x2x2x2x5xf32> into tensor<36x8x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<36x8x2xf32> +// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x8x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x8x2xf32>) -> tensor<36x8x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2] : tensor<36x8x2xf32> into tensor<6x6x2x2x2x2xf32> +// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x2x2x2x2xf32>) outs(%[[ARG3]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> +// CHECK-NEXT: return %[[S8]] : tensor<2x8x8x2xf32> +// CHECK-NEXT: } + +// ----- + +func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> { + %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%arg3 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> + return %0 : tensor<2x9x9x2xf32> +} + +// CHECK-LABEL: func.func @conv2d_unaligned +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> { +// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32> +// CHECK-NEXT: %[[S1:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S0]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> +// CHECK-NEXT: %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0] { +// CHECK-NEXT: ^bb0 +// CHECK-NEXT: tensor.yield %[[CST]] : f32 +// CHECK-NEXT: } : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32> +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32> +// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[PADDED]] : tensor<2x14x14x5xf32>) outs(%[[S2]] : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32> +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32> 
+// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
+// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<36x18x2xf32>
+// CHECK-NEXT:  %[[S5:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%[[S4]] : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
+// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
+// CHECK-NEXT:  %[[PADDED_1:.*]] = tensor.pad %[[ARG3]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
+// CHECK-NEXT:  ^bb0
+// CHECK-NEXT:    tensor.yield %[[CST]] : f32
+// CHECK-NEXT:  } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
+// CHECK-NEXT:  %[[S6:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x3x3x2x2xf32>) outs(%[[PADDED_1]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+// CHECK-NEXT:  %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S6]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
+// CHECK-NEXT:  return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+func.func @conv2d_type_promotion(%arg0: tensor<2x6x6x5xf16>, %arg1: tensor<2x3x3x5xf16>, %arg2: tensor<1xf32>, %out: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf16>, tensor<2x3x3x5xf16>) outs(%out : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  return %0 : tensor<2x4x4x2xf32>
+}
+
+// CHECK-LABEL: func.func @conv2d_type_promotion
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf16>, %[[ARG1:.*]]: tensor<2x3x3x5xf16>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf16>
+// CHECK-NEXT:  %[[S1:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf16>) outs(%[[S0]] : tensor<6x6x5x2xf16>) -> tensor<6x6x5x2xf16>
+// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x6x5xf16>) outs(%[[S2]] : tensor<6x6x1x1x2x5xf16>) -> tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf16> into tensor<36x5x2xf16>
+// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf16> into tensor<36x2x5xf16>
+// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT:  %[[S5:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf16>, tensor<36x5x2xf16>) outs(%[[S4]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+// CHECK-NEXT:  %[[S6:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+// CHECK-NEXT:  return %[[S6]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+func.func @conv2d_unsupported_1(%arg0: tensor<2x6x5x5xf32>, %arg1: tensor<2x3x2x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : 
tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x5x5xf32>, tensor<2x3x2x5xf32>) outs(%out : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  return %0 : tensor<2x4x4x2xf32>
+}
+
+// CHECK-LABEL: conv2d_unsupported_1
+// CHECK: linalg.conv_2d_nhwc_fhwc
+
+// -----
+
+func.func @conv2d_unsupported_2(%arg0: tensor<2x7x7x5xf32>, %arg1: tensor<2x4x4x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x7x7x5xf32>, tensor<2x4x4x5xf32>) outs(%out : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  return %0 : tensor<2x4x4x2xf32>
+}
+
+// CHECK-LABEL: conv2d_unsupported_2
+// CHECK: linalg.conv_2d_nhwc_fhwc
+
+// -----
+
+func.func @conv2d_unsupported_3(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<2x3x3x5xf32>) outs(%arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+// CHECK-LABEL: conv2d_unsupported_3
+// CHECK: linalg.conv_2d_nhwc_fhwc
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 4892fa2f99a7c..12cb46a5968f1 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -123,6 +123,10 @@ struct TestLinalgTransforms
           *this, "test-erase-unnecessary-inputs",
           llvm::cl::desc("Test patterns to erase unnecessary inputs"),
           llvm::cl::init(false)};
+  Option<bool> testWinogradConv2D{
+      *this, "test-winograd-conv2d",
+      llvm::cl::desc("Test transform conv2d by Winograd conv2d algorithm"),
+      llvm::cl::init(false)};
 };
 } // namespace
 
@@ -207,6 +211,13 @@ static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) {
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
+static void applyWinogradConv2D(func::FuncOp funcOp) {
+  RewritePatternSet patterns(funcOp.getContext());
+  populateWinogradConv2DPatterns(patterns, /*m=*/4, /*r=*/3);
+  populateWinogradConv2DPatterns(patterns, /*m=*/2, /*r=*/5);
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnOperation() {
   if (testPatterns)
@@ -231,6 +242,8 @@ void TestLinalgTransforms::runOnOperation() {
     return applyEraseUnusedOperandsAndResultsPatterns(getOperation());
   if (testEraseUnnecessaryInputs)
     return applyEraseUnnecessaryInputs(getOperation());
+  if (testWinogradConv2D)
+    return applyWinogradConv2D(getOperation());
 }
 
 namespace mlir {

From bb8087930cfd79a3d4ebf6a8e959f4c30bb70fcf Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang
Date: Mon, 17 Jun 2024 11:49:08 +0100
Subject: [PATCH 6/9] [mlir][linalg] Add transform operator for Winograd
 Conv2D algorithm

Add a transform operator structured.winograd_conv2d to convert
linalg.conv_2d_nhwc_fhwc to Linalg winograd operators.
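A minimal usage sketch, mirroring the test added in this patch (the
handle names are illustrative):

  %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1
    : (!transform.any_op) -> !transform.any_op
  %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 }
    : (!transform.any_op) -> (!transform.any_op)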
---
 .../Linalg/TransformOps/LinalgTransformOps.td | 51 +++++++++++
 .../Dialect/Linalg/Transforms/Transforms.h    |  7 ++
 .../TransformOps/LinalgTransformOps.cpp       | 25 ++++++
 .../Linalg/Transforms/WinogradConv2D.cpp      |  6 ++
 .../Linalg/transform-winograd-conv2d.mlir     | 88 +++++++++++++++++++
 5 files changed, 177 insertions(+)
 create mode 100644 mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 93e2c2db729da..68d0f713caad4 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2587,4 +2587,55 @@ def MapCopyToThreadsOp :
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Winograd Conv2D
+//===----------------------------------------------------------------------===//
+
+def WinogradConv2DOp : Op<Transform_Dialect, "structured.winograd_conv2d",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
+    matrix multiply. Before the matrix multiply, it will convert filter and
+    input into a format suitable for batched matrix multiply. After the matrix
+    multiply, it will convert output to the final result tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The size of output Y is m x m. The size of filter g is r x r. The size of
+    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+    transformation matrices.
+
+    #### Return modes:
+
+    This operation fails if `target` is unsupported. Otherwise, the operation
+    succeeds and returns a handle of the sequence that replaces the original
+    convolution.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                   I64Attr:$m,
+                   I64Attr:$r);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat =
+    "$target attr-dict `:` functional-type($target, results)";
+
+  let builders = [
+    OpBuilder<(ins "Value":$target)>
+  ];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::linalg::LinalgOp target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 #endif // LINALG_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 835aeaf2ffed3..da107b66257a5 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1312,6 +1312,13 @@ FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
                                             linalg::BatchMatmulOp op,
                                             bool transposeLHS = true);
 
+/// Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm
+/// F(m x m, r x r). m is the dimension size of output and r is the dimension
+/// size of filter.
+FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
+                                      linalg::Conv2DNhwcFhwcOp op, int64_t m,
+                                      int64_t r);
+
 //===----------------------------------------------------------------------===//
 // Rewrite patterns wrapping transformations.
 // TODO: every single such pattern should be a close to noop wrapper around a
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index bc02788f9c441..d051b29e1f06f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3480,6 +3480,31 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// WinogradConv2DOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
+    transform::TransformRewriter &rewriter, linalg::LinalgOp target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  rewriter.setInsertionPoint(target);
+  auto maybeTransformed =
+      TypeSwitch<Operation *, FailureOr<Operation *>>(target)
+          .Case([&](linalg::Conv2DNhwcFhwcOp op) {
+            return winogradConv2D(rewriter, op, getM(), getR());
+          })
+          .Default([&](Operation *op) {
+            return rewriter.notifyMatchFailure(op, "not supported");
+          });
+
+  if (failed(maybeTransformed))
+    return emitDefaultSilenceableFailure(target);
+
+  results.push_back(*maybeTransformed);
+  return DiagnosedSilenceableFailure::success();
+}
+
 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
 
 #define GET_OP_CLASSES
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index 6b46f9e07abf8..843db0c069813 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -324,6 +324,12 @@ class WinogradConv2DNhwcFhwc final
 } // end anonymous namespace
 
 //===----------------------------------------------------------------------===//
+FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
+                                      linalg::Conv2DNhwcFhwcOp op, int64_t m,
+                                      int64_t r) {
+  return winogradConv2DHelper(rewriter, op, m, r);
+}
+
 void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
                                     int64_t r) {
   MLIRContext *context = patterns.getContext();
diff --git a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
new file mode 100644
index 0000000000000..1e74fea5a1c31
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
@@ -0,0 +1,88 @@
+// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file | FileCheck %s
+
+func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
+  %0 = tensor.empty() : tensor<2x8x8x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x8x8x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x8x8x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+  return %2 : tensor<2x8x8x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = 
transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op) + transform.yield + } +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func.func @conv2d +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x8x8x2xf32> { +// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x8x8x2xf32> +// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x8x8x2xf32>) { +// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[IN]] : f32 +// CHECK-NEXT: } -> tensor<2x8x8x2xf32> +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<2x2x6x6x5x2xf32> +// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<2x2x6x6x5x2xf32>) -> tensor<2x2x6x6x5x2xf32> +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<2x2x6x6x2x5xf32> +// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<2x2x6x6x2x5xf32>) -> tensor<2x2x6x6x2x5xf32> +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x5x2xf32> into tensor<144x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x2x5xf32> into tensor<144x2x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<144x2x2xf32> +// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%[[S6]] : tensor<144x2x2xf32>) -> tensor<144x2x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [2, 2, 6, 6, 2, 2] : tensor<144x2x2xf32> into tensor<2x2x6x6x2x2xf32> +// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<2x2x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> +// CHECK-NEXT: return %[[S8]] : tensor<2x8x8x2xf32> +// CHECK-NEXT: } + +// ----- + +func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x9x9x2xf32> { + %0 = tensor.empty() : tensor<2x9x9x2xf32> + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x9x9x2xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<2x9x9x2xf32> + %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> + return %2 : tensor<2x9x9x2xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = 
transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op) + transform.yield + } +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func.func @conv2d_unaligned +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x9x9x2xf32> { +// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x9x9x2xf32> +// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x9x9x2xf32>) { +// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[IN]] : f32 +// CHECK-NEXT: } -> tensor<2x9x9x2xf32> +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<3x3x6x6x5x2xf32> +// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<3x3x6x6x5x2xf32>) -> tensor<3x3x6x6x5x2xf32> +// CHECK-NEXT: %[[INPUT_BUF:.*]] = tensor.empty() : tensor<2x14x14x5xf32> +// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[ARG0]] into %[[INPUT_BUF]][0, 0, 0, 0] [2, 11, 11, 5] [1, 1, 1, 1] : tensor<2x11x11x5xf32> into tensor<2x14x14x5xf32> +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<3x3x6x6x2x5xf32> +// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[INSERTED_SLICE]] : tensor<2x14x14x5xf32>) outs(%[[S4]] : tensor<3x3x6x6x2x5xf32>) -> tensor<3x3x6x6x2x5xf32> +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x5x2xf32> into tensor<324x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x2x5xf32> into tensor<324x2x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<324x2x2xf32> +// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<324x2x5xf32>, tensor<324x5x2xf32>) outs(%[[S6]] : tensor<324x2x2xf32>) -> tensor<324x2x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [3, 3, 6, 6, 2, 2] : tensor<324x2x2xf32> into tensor<3x3x6x6x2x2xf32> +// CHECK-NEXT: %[[OUTPUT_BUF:.*]] = tensor.empty() : tensor<2x12x12x2xf32> +// CHECK-NEXT: %[[INSERTED_SLICE_2:.*]] = tensor.insert_slice %[[S1]] into %[[OUTPUT_BUF]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x9x9x2xf32> into tensor<2x12x12x2xf32> +// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<3x3x6x6x2x2xf32>) outs(%[[INSERTED_SLICE_2]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32> +// CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32> +// CHECK-NEXT: return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32> +// CHECK-NEXT: } From cc23f43cfab82f1c0b9ddbf6cacd29a20f99d825 Mon Sep 17 00:00:00 2001 From: Hsiangkai Wang Date: Wed, 26 Jun 2024 12:26:15 +0100 Subject: [PATCH 7/9] Address ftynse's comments --- .../Linalg/TransformOps/LinalgTransformOps.td | 8 +- .../TransformOps/LinalgTransformOps.cpp | 26 ++-- .../Linalg/transform-winograd-conv2d.mlir | 112 ++++++++---------- 3 files changed, 69 insertions(+), 77 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td 
b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 68d0f713caad4..5ef56bc97fef1 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -2597,7 +2597,7 @@ def WinogradConv2DOp : Op { let description = [{ - Winograd Conv2D algorithm will convert linalg Conv2D operator into batched + Winograd Conv2D algorithm will convert linalg Conv2D operation into batched matrix multiply. Before the matrix multiply, it will convert filter and input into a format suitable for batched matrix multiply. After the matrix multiply, it will convert output to the final result tensor. @@ -2612,9 +2612,9 @@ def WinogradConv2DOp : Op>(target) - .Case([&](linalg::Conv2DNhwcFhwcOp op) { - return winogradConv2D(rewriter, op, getM(), getR()); - }) - .Default([&](Operation *op) { - return rewriter.notifyMatchFailure(op, "not supported"); - }); - - if (failed(maybeTransformed)) - return emitDefaultSilenceableFailure(target); + FailureOr maybeTransformed = failure(); + bool supported = TypeSwitch(target) + .Case([&](linalg::Conv2DNhwcFhwcOp op) { + maybeTransformed = + winogradConv2D(rewriter, op, getM(), getR()); + return true; + }) + .Default([&](Operation *op) { + op->emitError("not supported"); + return false; + }); + + if (supported && failed(maybeTransformed)) { + return emitSilenceableError() << "apply Winograd Conv2D failed"; + } results.push_back(*maybeTransformed); return DiagnosedSilenceableFailure::success(); diff --git a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir index 1e74fea5a1c31..0a2dcc035ebd3 100644 --- a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir +++ b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir @@ -1,13 +1,8 @@ -// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file | FileCheck %s +// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file -verify-diagnostics| FileCheck %s -func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x8x8x2xf32> { - %0 = tensor.empty() : tensor<2x8x8x2xf32> - %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x8x8x2xf32>) { - ^bb0(%in: f32, %out: f32): - linalg.yield %in : f32 - } -> tensor<2x8x8x2xf32> - %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> - return %2 : tensor<2x8x8x2xf32> +func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> { + %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%arg3 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> + return %0 : tensor<2x8x8x2xf32> } module attributes {transform.with_named_sequence} { @@ -18,38 +13,17 @@ module attributes {transform.with_named_sequence} { } } -// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> -// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> // CHECK-LABEL: func.func @conv2d -// CHECK-SAME: 
-// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
-// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x8x8x2xf32>) {
-// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT:   linalg.yield %[[IN]] : f32
-// CHECK-NEXT: } -> tensor<2x8x8x2xf32>
-// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<2x2x6x6x5x2xf32>
-// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<2x2x6x6x5x2xf32>) -> tensor<2x2x6x6x5x2xf32>
-// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<2x2x6x6x2x5xf32>
-// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<2x2x6x6x2x5xf32>) -> tensor<2x2x6x6x2x5xf32>
-// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x5x2xf32> into tensor<144x5x2xf32>
-// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x2x5xf32> into tensor<144x2x5xf32>
-// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<144x2x2xf32>
-// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%[[S6]] : tensor<144x2x2xf32>) -> tensor<144x2x2xf32>
-// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [2, 2, 6, 6, 2, 2] : tensor<144x2x2xf32> into tensor<2x2x6x6x2x2xf32>
-// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<2x2x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
-// CHECK-NEXT: return %[[S8]] : tensor<2x8x8x2xf32>
-// CHECK-NEXT: }
+// CHECK: linalg.winograd_filter_transform m(4) r(3)
+// CHECK: linalg.winograd_input_transform m(4) r(3)
+// CHECK: linalg.batch_matmul
+// CHECK: linalg.winograd_output_transform m(4) r(3)

 // -----

-func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x9x9x2xf32> {
-  %0 = tensor.empty() : tensor<2x9x9x2xf32>
-  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x9x9x2xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    linalg.yield %in : f32
-  } -> tensor<2x9x9x2xf32>
-  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32>
-  return %2 : tensor<2x9x9x2xf32>
+func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%arg3 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32>
+  return %0 : tensor<2x9x9x2xf32>
 }

 module attributes {transform.with_named_sequence} {
@@ -60,29 +34,43 @@ module attributes {transform.with_named_sequence} {
   }
 }

-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 // CHECK-LABEL: func.func @conv2d_unaligned
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x9x9x2xf32> {
-// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x9x9x2xf32>
-// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x9x9x2xf32>) {
-// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT:   linalg.yield %[[IN]] : f32
-// CHECK-NEXT: } -> tensor<2x9x9x2xf32>
-// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<3x3x6x6x5x2xf32>
-// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<3x3x6x6x5x2xf32>) -> tensor<3x3x6x6x5x2xf32>
-// CHECK-NEXT: %[[INPUT_BUF:.*]] = tensor.empty() : tensor<2x14x14x5xf32>
-// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[ARG0]] into %[[INPUT_BUF]][0, 0, 0, 0] [2, 11, 11, 5] [1, 1, 1, 1] : tensor<2x11x11x5xf32> into tensor<2x14x14x5xf32>
-// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<3x3x6x6x2x5xf32>
-// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[INSERTED_SLICE]] : tensor<2x14x14x5xf32>) outs(%[[S4]] : tensor<3x3x6x6x2x5xf32>) -> tensor<3x3x6x6x2x5xf32>
-// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x5x2xf32> into tensor<324x5x2xf32>
-// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x2x5xf32> into tensor<324x2x5xf32>
-// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<324x2x2xf32>
-// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<324x2x5xf32>, tensor<324x5x2xf32>) outs(%[[S6]] : tensor<324x2x2xf32>) -> tensor<324x2x2xf32>
-// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [3, 3, 6, 6, 2, 2] : tensor<324x2x2xf32> into tensor<3x3x6x6x2x2xf32>
-// CHECK-NEXT: %[[OUTPUT_BUF:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
-// CHECK-NEXT: %[[INSERTED_SLICE_2:.*]] = tensor.insert_slice %[[S1]] into %[[OUTPUT_BUF]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x9x9x2xf32> into tensor<2x12x12x2xf32>
-// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<3x3x6x6x2x2xf32>) outs(%[[INSERTED_SLICE_2]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
-// CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
-// CHECK-NEXT: return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
-// CHECK-NEXT: }
+// CHECK: linalg.winograd_filter_transform m(4) r(3)
+// CHECK: tensor.pad
+// CHECK-SAME: low[0, 0, 0, 0] high[0, 3, 3, 0]
+// CHECK: linalg.winograd_input_transform m(4) r(3)
+// CHECK: tensor.pad
+// CHECK-SAME: low[0, 0, 0, 0] high[0, 3, 3, 0]
+// CHECK: linalg.winograd_output_transform m(4) r(3)
+
+// -----
+
+func.func @conv2d_unsupported(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<3x3x5x2xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
+  // expected-error @+1 {{not supported}}
+  %0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<3x3x5x2xf32>) outs(%arg3 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+  return %0 : tensor<2x8x8x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @conv2d(%arg0: tensor<2x?x?x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x?x?x2xf32>) -> tensor<2x?x?x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x?x?x5xf32>, tensor<2x3x3x5xf32>) outs(%arg3 : tensor<2x?x?x2xf32>) -> tensor<2x?x?x2xf32>
+  return %0 : tensor<2x?x?x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error @+1 {{apply Winograd Conv2D failed}}
+    %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}

From a93529db926dfabbc9354888efc2917bd2330d62 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang
Date: Thu, 27 Jun 2024 10:04:50 +0100
Subject: [PATCH 8/9] fix failed test

---
 mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 6 +++++-
 mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir     | 2 +-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index e0f2d00400d63..6f03d71fd0e1f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3497,10 +3497,14 @@ DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
                          return true;
                        })
                        .Default([&](Operation *op) {
-                         op->emitError("not supported");
                          return false;
                        });

+  if (!supported) {
+    return emitSilenceableError()
+           << "this operation cannot be converted to Winograd Conv2D";
+  }
+
   if (supported && failed(maybeTransformed)) {
     return emitSilenceableError() << "apply Winograd Conv2D failed";
   }
diff --git a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
index 0a2dcc035ebd3..c10e0ccebfd7c 100644
--- a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
+++ b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
@@ -46,7 +46,6 @@ module attributes {transform.with_named_sequence} {
 // -----

 func.func @conv2d_unsupported(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<3x3x5x2xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
-  // expected-error @+1 {{not supported}}
   %0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<3x3x5x2xf32>) outs(%arg3 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
   return %0 : tensor<2x8x8x2xf32>
 }
@@ -54,6 +53,7 @@ func.func @conv2d_unsupported(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<3x3x5x
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error @+1 {{this operation cannot be converted to Winograd Conv2D}}
     %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
     transform.yield
   }

From 0bb0f05959eefe7bc828edd568e2e8f2158f3c4c Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang
Date: Thu, 11 Jul 2024 10:03:09 +0100
Subject: [PATCH 9/9] clang-format

---
 mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 4 +---
 mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp       | 4 ++--
 2 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 03ce455a409bf..bffe7a4e7d62c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3727,9 +3727,7 @@ DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
                              winogradConv2D(rewriter, op, getM(), getR());
                          return true;
                        })
-                       .Default([&](Operation *op) {
-                         return false;
-                       });
+                       .Default([&](Operation *op) { return false; });

   if (!supported) {
     return emitSilenceableError()
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index 53008b876a650..9b8fa7cf6bac1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -13,11 +13,11 @@
 //===----------------------------------------------------------------------===//

 #include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/MathExtras.h"

 namespace mlir {
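
Beyond the transform-dialect interface exercised in the tests above, the rewrite is also reachable from C++ through the Linalg pattern-population entry point. A hedged sketch of driving it from a pass follows; it assumes a populateWinogradConv2DPatterns(patterns, m, r) hook with this signature, which may differ from the final revision:

    #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    // Sketch only: apply the Winograd Conv2D rewrite F(4 x 4, 3 x 3) to every
    // matching linalg.conv_2d_nhwc_fhwc under `root`; convolutions that do not
    // match (e.g. other layouts) are simply left untouched.
    static void applyWinograd(mlir::Operation *root) {
      mlir::RewritePatternSet patterns(root->getContext());
      mlir::linalg::populateWinogradConv2DPatterns(patterns, /*m=*/4, /*r=*/3);
      (void)mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
    }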