Skip to content

Conversation

@RoboTux
Copy link
Contributor

@RoboTux RoboTux commented Sep 11, 2024

Winograd lowering involves a number of matmul and batch_matmul which
are currently passed tensor.empty result as out parameter, thereby
are undefined behaviour. This commit adds the necessary linalg.fill.

Winograd lowering involves a number of matmul and batch_matmul which
are currently passed tensor.empty result as out parameter, thereby
are undefined behaviour. This commit adds the necessary linalg.fill.
@llvmbot
Copy link
Member

llvmbot commented Sep 11, 2024

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Thomas Preud'homme (RoboTux)

Changes

Winograd lowering involves a number of matmul and batch_matmul which
are currently passed tensor.empty result as out parameter, thereby
are undefined behaviour. This commit adds the necessary linalg.fill.


Patch is 70.38 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/108181.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp (+43-14)
  • (modified) mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir (+91-88)
  • (modified) mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir (+52-44)
  • (modified) mlir/test/Dialect/Linalg/winograd-conv2d.mlir (+80-67)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index b65b18699a15aa..80edf4a32c6df8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -390,6 +390,8 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
     TransformMapKeyTy key = {m, r};
     int64_t retRows = 1;
     Value matmulRetValue = extractFilter;
+    Value zero = builder.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(elementType));
     if (leftTransform) {
       // Get constant transform matrix G.
       auto it = GMatrices.find(key);
@@ -399,8 +401,11 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
 
       retRows = GMatrix.rows;
       auto matmulType = RankedTensorType::get({retRows, filterW}, elementType);
-      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
-                                                  elementType);
+      auto empty =
+          builder
+              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
+              .getResult();
+      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
 
       Value G = create2DTransformMatrix(builder, loc, GMatrix, elementType);
       // Multiply G x g.
@@ -418,8 +423,11 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
 
       auto matmulType =
           RankedTensorType::get({retRows, GTMatrix.cols}, elementType);
-      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
-                                                  elementType);
+      auto empty =
+          builder
+              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
+              .getResult();
+      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
 
       Value GT = create2DTransformMatrix(builder, loc, GTMatrix, elementType);
       // Multiply u = (G x g) x GT.
@@ -523,6 +531,8 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
     int64_t retRows = 1;
     int64_t retCols = 1;
     Value matmulRetValue = extractInput;
+    Value zero = builder.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(elementType));
     if (leftTransform) {
       // Get constant transform matrix BT.
       auto it = BTMatrices.find(key);
@@ -532,8 +542,11 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
 
       retRows = BTMatrix.rows;
       auto matmulType = RankedTensorType::get({retRows, alphaW}, elementType);
-      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
-                                                  elementType);
+      auto empty =
+          builder
+              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
+              .getResult();
+      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
 
       Value BT =
           create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type());
@@ -552,8 +565,11 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
 
       retCols = BMatrix.cols;
       auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
-      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
-                                                  elementType);
+      auto empty =
+          builder
+              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
+              .getResult();
+      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
       Value B =
           create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type());
       // Multiply v = (BT x d) x B.
@@ -636,8 +652,13 @@ static Value matrixMultiply(RewriterBase &rewriter, Location loc,
       {inputShape[0] * inputShape[1],
        inputShape[2] * inputShape[3] * inputShape[4], filterShape[3]},
       outputElementType);
-  Value init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
-                                                outputElementType);
+  Value empty = rewriter
+                    .create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                             outputElementType)
+                    .getResult();
+  Value zero = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getZeroAttr(outputElementType));
+  Value init = rewriter.create<linalg::FillOp>(loc, zero, empty).getResult(0);
 
   auto matmulOp = rewriter.create<linalg::BatchMatmulOp>(
       loc, matmulType, ValueRange({collapseInput, collapseFilter}),
@@ -725,6 +746,8 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
     int64_t leftScalarFactor = 1;
     int64_t rightScalarFactor = 1;
     Value matmulRetValue = extractValue;
+    Value zero = builder.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(elementType));
     if (leftTransform) {
       // Get constant transform matrix AT.
       auto it = ATMatrices.find(key);
@@ -735,8 +758,11 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
       leftScalarFactor = ATMatrix.scalarFactor;
       retRows = ATMatrix.rows;
       auto matmulType = RankedTensorType::get({retRows, valueW}, elementType);
-      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
-                                                  elementType);
+      auto empty =
+          builder
+              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
+              .getResult();
+      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
 
       Value AT = create2DTransformMatrix(builder, loc, ATMatrix, elementType);
       // Multiply AT x m.
@@ -756,8 +782,11 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
       auto matmulType =
           RankedTensorType::get({retRows, AMatrix.cols}, elementType);
       retCols = AMatrix.cols;
-      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
-                                                  elementType);
+      auto empty =
+          builder
+              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
+              .getResult();
+      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
 
       Value A = create2DTransformMatrix(builder, loc, AMatrix, elementType);
       // Multiply y = (AT x m) x A.
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
index 6bb3fb1423edc6..21dcea968615f6 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
@@ -44,9 +44,9 @@ module attributes {transform.with_named_sequence} {
 // CHECK:  %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]])
 // CHECK:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
 // CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 3, 1] [1, 1, 1, 1]
-// CHECK:      %[[S11:.*]] = linalg.matmul
-// CHECK:      %[[S13:.*]] = linalg.matmul
-// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 6, 1, 1] [1, 1, 1, 1]
+// CHECK:      %[[S12:.*]] = linalg.matmul
+// CHECK:      %[[S15:.*]] = linalg.matmul
+// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 6, 1, 1] [1, 1, 1, 1]
 // CHECK:      scf.yield %[[INSERTED_SLICE]]
 // CHECK:    scf.yield %[[S9]]
 // CHECK:  %[[S2:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
@@ -56,20 +56,20 @@ module attributes {transform.with_named_sequence} {
 // CHECK:      %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
 // CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
 // CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1]
-// CHECK:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
-// CHECK:      %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]])
+// CHECK:      %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK:      %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_8]])
 // CHECK:        %[[S13:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
-// CHECK:          %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1]
-// CHECK:          %[[S15:.*]] = linalg.matmul
-// CHECK:          %[[S17:.*]] = linalg.matmul
-// CHECK:          %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S17]] into %[[ARG10]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
-// CHECK:          scf.yield %[[INSERTED_SLICE_9]]
+// CHECK:          %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1]
+// CHECK:          %[[S16:.*]] = linalg.matmul
+// CHECK:          %[[S19:.*]] = linalg.matmul
+// CHECK:          %[[INSERTED_SLICE_10:.*]] = tensor.insert_slice %[[S19]] into %[[ARG10]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK:          scf.yield %[[INSERTED_SLICE_10]]
 // CHECK:        scf.yield %[[S13]]
 // CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG6]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
 // CHECK:      scf.yield %[[INSERTED_SLICE]]
 // CHECK:    scf.yield %[[S9]]
 // CHECK:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
-// CHECK:  %[[COLLAPSED_6:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]]
+// CHECK:  %[[COLLAPSED_7:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]]
 // CHECK:  %[[S6:.*]] = linalg.batch_matmul
 // CHECK:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2]
 // CHECK:  %[[S7:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
@@ -78,20 +78,20 @@ module attributes {transform.with_named_sequence} {
 // CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
 // CHECK:      %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
 // CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
-// CHECK:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG2]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
-// CHECK:      %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]])
+// CHECK:      %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[ARG2]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
+// CHECK:      %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_8]])
 // CHECK:        %[[S15:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
-// CHECK:          %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
-// CHECK:          %[[S17:.*]] = linalg.matmul
-// CHECK:          %[[S19:.*]] = linalg.matmul
-// CHECK:          %[[S20:.*]] = tensor.empty()
-// CHECK:          %[[S21:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S20]] : tensor<4x4xf32>) {
+// CHECK:          %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK:          %[[S18:.*]] = linalg.matmul
+// CHECK:          %[[S21:.*]] = linalg.matmul
+// CHECK:          %[[S22:.*]] = tensor.empty()
+// CHECK:          %[[S23:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S22]] : tensor<4x4xf32>) {
 // CHECK:          ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
 // CHECK:            linalg.yield %[[IN]] : f32
 // CHECK:          } -> tensor<4x4xf32>
-// CHECK:          %[[S22:.*]] = linalg.mul ins(%[[S21]], %[[S19]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S20]] : tensor<4x4xf32>) -> tensor<4x4xf32>
-// CHECK:          %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S22]] into %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1]
-// CHECK:          scf.yield %[[INSERTED_SLICE_9]]
+// CHECK:          %[[S24:.*]] = linalg.mul ins(%[[S23]], %[[S21]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S22]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK:          %[[INSERTED_SLICE_10:.*]] = tensor.insert_slice %[[S24]] into %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1]
+// CHECK:          scf.yield %[[INSERTED_SLICE_10]]
 // CHECK:        scf.yield %[[S15]]
 // CHECK:      %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
 // CHECK:      %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
@@ -114,14 +114,15 @@ func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5x
   %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
   %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
   %4 = tensor.empty() : tensor<36x18x2xf32>
-  %5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%4 : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
-  %expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
+  %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
+  %6 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%5 : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
+  %expanded = tensor.expand_shape %6 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
   %padded_1 = tensor.pad %arg2 low[0, 0, 0, 0] high[0, 3, 3, 0] {
   ^bb0(%arg4: index, %arg5: index, %arg6: index, %arg7: index):
     tensor.yield %cst : f32
   } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
-  %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x6x3x3x2x2xf32>) outs(%padded_1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
-  %extracted_slice = tensor.extract_slice %6[0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
+  %7 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x6x3x3x2x2xf32>) outs(%padded_1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+  %extracted_slice = tensor.extract_slice %7[0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
   return %extracted_slice : tensor<2x9x9x2xf32>
 }
 
@@ -153,71 +154,72 @@ module attributes {transform.with_named_sequence} {
 // CHECK:  %[[C2:.*]] = arith.constant 2 : index
 // CHECK:  %[[C0:.*]] = arith.constant 0 : index
 // CHECK:  %[[S0:.*]] = tensor.empty()
-// CHECK:  %[[S1:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S0]])
-// CHECK:    %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
-// CHECK:      %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 3, 3, 1] [1, 1, 1, 1]
-// CHECK:      %[[S11:.*]] = linalg.matmul
+// CHECK:  %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]])
+// CHECK:    %[[S10:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK:      %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 3, 1] [1, 1, 1, 1]
 // CHECK:      %[[S13:.*]] = linalg.matmul
-// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[ARG7]][0, 0, %[[ARG6]], %[[ARG4]]] [6, 6, 1, 1] [1, 1, 1, 1]
+// CHECK:      %[[S16:.*]] = linalg.matmul
+// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S16]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 6, 1, 1] [1, 1, 1, 1]
 // CHECK:      scf.yield %[[INSERTED_SLICE]] : tensor<6x6x5x2xf32>
-// CHECK:    scf.yield %[[S9]] : tensor<6x6x5x2xf32>
+// CHECK:    scf.yield %[[S10]] : tensor<6x6x5x2xf32>
 // CHECK:  %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0]
 // CHECK:  %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
 // CHECK:  %[[S3:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
-// CHECK:  %[[S4:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S3]])
-// CHECK:    %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
-// CHECK:      %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
-// CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
-// CHECK:      %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[PADDED]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1]
-// CHECK:      %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
-// CHECK:      %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]])
-// CHECK:        %[[S13:.*]] = scf.for %[[ARG10:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]])
-// CHECK:          %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 6, 6, 1] [1, 1, 1, 1]
-// CHECK:          %[[S15:.*]] = linalg.matmul
+// CHECK:  %[[S4:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S3]])
+// CHECK:    %[[S10:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK:      %[[S12:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK:      %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[PADDED]][0, %[[S11]], %[[S12]], 0] [2, 6, 6, 5] [1, 1, 1, 1]
+// CHECK:      %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK:      %[[S13:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_10]])
+// CHECK:        %[[S14:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
+// CHECK:          %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1]
 // CHECK:          %[[S17:.*]] = linalg.matmul
-// CHECK:          %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S17]] into %[[ARG11]][0, 0, 0, 0, %[[ARG8]], %[[ARG10]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK:          %[[S20:.*]] = linalg.matmul
+// CHECK:          %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S20]] into %[[ARG10]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
 // CHECK:          scf.yield %[[INSERTED_SLICE_12]] : tensor<6x6x1x1x2x5xf32>
-// CHECK:        scf.yield %[[S13]] : tensor<6x6x1x1x2x5xf32>
-// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK:        scf.yield %[[S14]] : tensor<6x6x1x1x2x5xf32>
+// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[ARG6]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
 // CHECK:      scf.yield %[[INSERTED_SLI...
[truncated]

Copy link
Contributor

@Max191 Max191 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for catching this. It's a little hard to see what's changing in the lit tests, since all the names in the checks have changed. Could you keep the names there the same? Also, please add checks in the decomposition tests that fill ops are created and used in the matmuls.

Reduce diff noise in tests and add tensor.empty and linalg.fill check in
decomposition tests.
@RoboTux
Copy link
Contributor Author

RoboTux commented Sep 12, 2024

Thanks for catching this. It's a little hard to see what's changing in the lit tests, since all the names in the checks have changed. Could you keep the names there the same? Also, please add checks in the decomposition tests that fill ops are created and used in the matmuls.

I've reduced the noise somewhat but adding the fill inevitably resequence quite a bit of SSA values. It also leads to some FileCheck variable reuse because not everything was renamed. Alternatively if you prefer I could do a NFC patch that just update all FileCheck variable like in my first attempt but doesn't have the fix. Then in a second patch I can add just the fix and the new CHECK and land that as a rebase rather than squash if that's possible in github.

@RoboTux RoboTux requested a review from Max191 September 13, 2024 09:42
Copy link
Contributor

@Max191 Max191 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! This looks much better.

@Max191
Copy link
Contributor

Max191 commented Sep 13, 2024

I've reduced the noise somewhat but adding the fill inevitably resequence quite a bit of SSA values. It also leads to some FileCheck variable reuse because not everything was renamed. Alternatively if you prefer I could do a NFC patch that just update all FileCheck variable like in my first attempt but doesn't have the fix. Then in a second patch I can add just the fix and the new CHECK and land that as a rebase rather than squash if that's possible in github.

It looks better now, so no worries. It is usually good to have more meaningful FileCheck variable names, but I probably should have caught that on the original PRs. It makes lit tests more readable and maintainable when the variables aren't just numbered ssa values (you don't need to do anything about that in this PR, just noting).

@RoboTux RoboTux merged commit 326287f into llvm:main Sep 13, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants