
Commit 593b5b8

[mlir][Linalg] Fix linalg.generic iteration domain collapse for dynamic dims
1 parent: c40877d
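
In brief, as the diff below shows: collapseOpIterationDims now keeps the collapsed loop-range sizes as OpFoldResult and passes them to the expand_shape ops that reconstruct the original result shapes as an explicit output_shape, instead of recomputing dynamic sizes from the collapsed result with arith.divsi. A minimal before/after sketch of the generated IR, adapted from the updated fusion-push-reshape.mlir test in this commit (SSA value names are illustrative):

Before, the dynamic size of the expanded result was rederived from the collapsed op:

    %dim = tensor.dim %R, %c0 : tensor<?x16xf32>
    %sz  = arith.divsi %dim, %c112 : index
    %rr  = tensor.expand_shape %R [[0, 1], [2]] output_shape [%sz, 112, 16]
        : tensor<?x16xf32> into tensor<?x112x16xf32>

After, the size comes straight from the original iteration domain (here a tensor.dim of the expanded input):

    %dim = tensor.dim %expanded, %c0 : tensor<?x112x16xf32>
    %rr  = tensor.expand_shape %R [[0, 1], [2]] output_shape [%dim, 112, 16]
        : tensor<?x16xf32> into tensor<?x112x16xf32>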

File tree (3 files changed: +62 −31 lines)

    mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
    mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
    mlir/test/Dialect/Linalg/fusion-push-reshape.mlir

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 14 additions & 9 deletions
@@ -1550,7 +1550,7 @@ static Value getCollapsedOpOperand(Location loc, LinalgOp op,
 /// value in the collapsed operation.
 void generateCollapsedIndexingRegion(Location loc, Block *block,
                                      const CollapsingInfo &collapsingInfo,
-                                     ValueRange loopRange,
+                                     ArrayRef<OpFoldResult> loopRange,
                                      RewriterBase &rewriter) {
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPointToStart(block);
@@ -1572,10 +1572,12 @@ void generateCollapsedIndexingRegion(Location loc, Block *block,
     Value newIndexVal =
         rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
     for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
+      Value loopDim =
+          getValueOrCreateConstantIndexOp(rewriter, loc, loopRange[dim]);
       indexReplacementVals[dim] =
-          rewriter.create<arith::RemSIOp>(loc, newIndexVal, loopRange[dim]);
+          rewriter.createOrFold<arith::RemSIOp>(loc, newIndexVal, loopDim);
       newIndexVal =
-          rewriter.create<arith::DivSIOp>(loc, newIndexVal, loopRange[dim]);
+          rewriter.createOrFold<arith::DivSIOp>(loc, newIndexVal, loopDim);
     }
     indexReplacementVals[foldedDims.value().front()] = newIndexVal;
   }
@@ -1722,14 +1724,13 @@ FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
   LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);
 
   Location loc = op->getLoc();
+  SmallVector<OpFoldResult> loopBound =
+      llvm::map_to_vector(loopRanges, [&](Range range) { return range.size; });
+
   if (collapsedOp.hasIndexSemantics()) {
     // Collect the loop range of the generic op.
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPoint(collapsedOp);
-    SmallVector<Value> loopBound =
-        llvm::map_to_vector(loopRanges, [&](Range range) {
-          return getValueOrCreateConstantIndexOp(rewriter, loc, range.size);
-        });
     generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(),
                                     collapsingInfo, loopBound, rewriter);
   }
@@ -1747,15 +1748,19 @@ FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
           op.getIndexingMapMatchingResult(originalResult.value());
       SmallVector<ReassociationIndices> reassociation =
          getOperandReassociation(indexingMap, collapsingInfo);
+      SmallVector<OpFoldResult> resultShape =
+          applyPermutationMap(indexingMap, ArrayRef(loopBound));
      Value result;
      if (isa<MemRefType>(collapsedOpResult.getType())) {
        MemRefType expandShapeResultType = MemRefType::get(
            originalResultType.getShape(), originalResultType.getElementType());
        result = rewriter.create<memref::ExpandShapeOp>(
-            loc, expandShapeResultType, collapsedOpResult, reassociation);
+            loc, expandShapeResultType, collapsedOpResult, reassociation,
+            resultShape);
      } else {
        result = rewriter.create<tensor::ExpandShapeOp>(
-            loc, originalResultType, collapsedOpResult, reassociation);
+            loc, originalResultType, collapsedOpResult, reassociation,
+            resultShape);
      }
      results.push_back(result);
    } else {
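
For ops with index semantics, generateCollapsedIndexingRegion rebuilds the original linalg.index values by delinearizing each collapsed index against the sizes of the folded loops. With this change the sizes travel as OpFoldResult and are only materialized here (getValueOrCreateConstantIndexOp), and createOrFold lets trivial cases (for example a unit-sized loop) fold away instead of emitting dead arithmetic. A rough sketch of the IR this produces inside the collapsed body when two loops d0 and d1 are folded into one; names and the origin of %size1 are illustrative (for a dynamic dimension it would typically come from a tensor.dim on an operand):

    %idx = linalg.index 0 : index
    // Inner (fastest-varying) original index: d1 = idx mod size(d1).
    %d1 = arith.remsi %idx, %size1 : index
    // Outer original index: d0 = idx div size(d1).
    %d0 = arith.divsi %idx, %size1 : index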

mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir

Lines changed: 45 additions & 18 deletions
@@ -225,6 +225,38 @@ func.func @fuse_by_collapsing_dynamic(%arg0 : tensor<?x?x?x?x?xi32>,
 
 // -----
 
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @fuse_by_collapsing_dynamic_2(%arg0 : tensor<?xf32>, %sz0: index, %sz1: index) -> tensor<?x?xf32> {
+  %0 = tensor.expand_shape %arg0 [[0, 1]] output_shape [%sz0, %sz1] : tensor<?xf32> into tensor<?x?xf32>
+  %init = tensor.empty(%sz1, %sz0) : tensor<?x?xf32>
+  %1 = linalg.generic {
+      indexing_maps = [#map0, #map0],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%0 : tensor<?x?xf32>)
+      outs(%init : tensor<?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32):
+      %out = arith.negf %b0 : f32
+      linalg.yield %out : f32
+  } -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func @fuse_by_collapsing_dynamic_2
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
+// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[EXPANDED]], %[[C1]]
+// CHECK: %[[OUT:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]] : tensor<?xf32>)
+// CHECK-SAME: outs(%{{.*}} : tensor<?xf32>)
+// CHECK: %[[EXPANDED_1:.+]] = tensor.expand_shape %[[OUT]]
+// CHECK-SAME: output_shape [%[[DIM0]], %[[DIM1]]]
+// CHECK: return %[[EXPANDED_1]]
+
+// -----
+
 #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3)>
 func.func @fuse_reductions(%arg0 : tensor<2x?x5xf32>, %arg1 : tensor<2x5xf32>, %sz0: index) -> tensor<2x5xf32> {
@@ -425,10 +457,11 @@ func.func @fuse_only_one_reassociation(%arg0 : tensor<?x?xf32>, %arg1 : tensor<4
 // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK: func @fuse_only_one_reassociation
 // CHECK-SAME: (%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<4x?x?x8xf32>, %[[SZ0:.+]]: index, %[[SZ1:.+]]: index)
-// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
 // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[EXPAND_ARG0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}} output_shape [%[[SZ0]], 4, %[[SZ1]], 8]
+// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[EXPAND_ARG0]], %[[C0]] : tensor<?x4x?x8xf32>
+// CHECK-DAG: %[[DIM_2:.+]] = tensor.dim %[[EXPAND_ARG0]], %[[C2]] : tensor<?x4x?x8xf32>
 // CHECK-DAG: %[[COLLAPSE_ARG0:.+]] = tensor.collapse_shape %[[EXPAND_ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
 // CHECK-DAG: %[[COLLAPSE_ARG1_0:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
 // CHECK-DAG: %[[COLLAPSE_ARG1_1:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
@@ -437,10 +470,7 @@ func.func @fuse_only_one_reassociation(%arg0 : tensor<?x?xf32>, %arg1 : tensor<4
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
 // CHECK-SAME: ins(%[[COLLAPSE_ARG0]], %[[COLLAPSE_ARG1_0]] :
 // CHECK-SAME: outs(%[[COLLAPSE_ARG1_1]] :
-// CHECK: %[[DIM:.+]] = tensor.dim %[[GENERIC]], %[[C1]] : tensor<4x?x?xf32>
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[GENERIC]], %[[C2]] : tensor<4x?x?xf32>
-// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_2]], %[[C8]] : index
-// CHECK: %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0], [1], [2, 3]] output_shape [4, %[[DIM]], %[[VAL_1]], 8] : tensor<4x?x?xf32> into tensor<4x?x?x8xf32>
+// CHECK: %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0], [1], [2, 3]] output_shape [4, %[[DIM]], %[[DIM_2]], 8] : tensor<4x?x?xf32> into tensor<4x?x?x8xf32>
 // CHECK: return %[[EXPANDED_3]]
 
 // -----
@@ -475,15 +505,16 @@ func.func @fold_non_consecutive_dims(%arg0 : tensor<?x?xi32>, %sz0: index, %sz1:
 // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d1, d0)>
 // CHECK: func @fold_non_consecutive_dims(
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>, %[[SZ0:.+]]: index, %[[SZ1:.+]]: index)
-// CHECK: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[C4:.+]] = arith.constant 4 : index
-// CHECK: %[[C8:.+]] = arith.constant 8 : index
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
+// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
 // CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[SZ0]], 4, %[[SZ1]], 8] : tensor<?x?xi32> into tensor<?x4x?x8xi32>
-// CHECK: %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[EXPANDED]], %[[C2]]
+// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
+// CHECK-DAG: %[[DIM_0:.+]] = tensor.dim %[[EXPANDED]], %[[C2]]
 // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM_0]], %[[DIM]])
+// CHECK-DAG: %[[DIM_1:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
+// CHECK-DAG: %[[DIM_2:.+]] = tensor.dim %[[EXPANDED]], %[[C2]]
 // CHECK: %[[COLLAPSE_INIT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2, 3]{{\]}}
 // CHECK: %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
@@ -502,11 +533,7 @@ func.func @fold_non_consecutive_dims(%arg0 : tensor<?x?xi32>, %sz0: index, %sz1:
 // CHECK-DAG: %[[T6:.+]] = arith.addi %[[T5]], %[[T3]]
 // CHECK-DAG: %[[T7:.+]] = arith.index_cast %[[T6]]
 // CHECK: linalg.yield %[[T7]]
-// CHECK: %[[DIM_1:.+]] = tensor.dim %[[GENERIC]], %[[C0]] : tensor<?x?xi32>
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[GENERIC]], %[[C1]] : tensor<?x?xi32>
-// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_1]], %[[C8]] : index
-// CHECK: %[[VAL_3:.+]] = arith.divsi %[[DIM_2]], %[[C4]] : index
-// CHECK: %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_2]], 8, %[[VAL_3]], 4] : tensor<?x?xi32> into tensor<?x8x?x4xi32>
+// CHECK: %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[DIM_2]], 8, %[[DIM_1]], 4] : tensor<?x?xi32> into tensor<?x8x?x4xi32>
 // CHECK: return %[[EXPANDED_3]]
 
 // -----

mlir/test/Dialect/Linalg/fusion-push-reshape.mlir

Lines changed: 3 additions & 4 deletions
@@ -5,15 +5,14 @@
 
 // CHECK-LABEL: func @reshape
 // CHECK-SAME: (%[[A:.*]]: tensor<?x16xf32>, %[[B:.*]]: tensor<16xf32>, %[[INIT:.*]]: tensor<?x112x16xf32>, %[[SZ0:.*]]: index)
-// CHECK: %[[C112:.*]] = arith.constant 112 : index
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[A]]
+// CHECK: %[[DIM:.*]] = tensor.dim %[[EXPANDED]], %[[C0]]
 // CHECK: %[[RI:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]] : tensor<?x112x16xf32> into tensor<?x16xf32>
 // CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP2]]],
 // CHECK-SAME: iterator_types = ["parallel", "parallel"]}
 // CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x16xf32>, tensor<16xf32>) outs(%[[RI]] : tensor<?x16xf32>)
-// CHECK: %[[DIM:.*]] = tensor.dim %[[R]], %[[C0]] : tensor<?x16xf32>
-// CHECK: %[[VAL_1:.*]] = arith.divsi %[[DIM]], %[[C112]] : index
-// CHECK: %[[RR:.*]] = tensor.expand_shape %[[R]] {{\[\[}}0, 1], [2]] output_shape [%[[VAL_1]], 112, 16] : tensor<?x16xf32> into tensor<?x112x16xf32>
+// CHECK: %[[RR:.*]] = tensor.expand_shape %[[R]] {{\[\[}}0, 1], [2]] output_shape [%[[DIM]], 112, 16] : tensor<?x16xf32> into tensor<?x112x16xf32>
 // CHECK: return %[[RR]] : tensor<?x112x16xf32>
 func.func @reshape(%A: tensor<?x16xf32>, %B: tensor<16xf32>, %init: tensor<?x112x16xf32>, %sz0: index) -> tensor<?x112x16xf32> {
   %0 = tensor.expand_shape %A [[0, 1], [2]] output_shape [%sz0, 112, 16]
