Skip to content

Commit 855e738

Browse files
committed
[VectorOps] Implement a simple folder for identity vector.transpose operations.
Differential Revision: https://reviews.llvm.org/D77088
1 parent 43e5765 commit 855e738

File tree

3 files changed

+92
-0
lines changed

3 files changed

+92
-0
lines changed

mlir/include/mlir/Dialect/Vector/VectorOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1315,10 +1315,12 @@ def Vector_TransposeOp :
13151315
VectorType getResultType() {
13161316
return result().getType().cast<VectorType>();
13171317
}
1318+
void getTransp(SmallVectorImpl<int64_t> &results);
13181319
}];
13191320
let assemblyFormat = [{
13201321
$vector `,` $transp attr-dict `:` type($vector) `to` type($result)
13211322
}];
1323+
let hasFolder = 1;
13221324
}
13231325

13241326
def Vector_TupleGetOp :

mlir/lib/Dialect/Vector/VectorOps.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1524,6 +1524,23 @@ static LogicalResult verify(TupleOp op) { return success(); }
15241524
// TransposeOp
15251525
//===----------------------------------------------------------------------===//
15261526

1527+
// Eliminates transpose operations, which produce values identical to their
1528+
// input values. This happens when the dimensions of the input vector remain in
1529+
// their original order after the transpose operation.
1530+
OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
1531+
SmallVector<int64_t, 4> transp;
1532+
getTransp(transp);
1533+
1534+
// Check if the permutation of the dimensions contains sequential values:
1535+
// {0, 1, 2, ...}.
1536+
for (int64_t i = 0, e = transp.size(); i < e; i++) {
1537+
if (transp[i] != i)
1538+
return {};
1539+
}
1540+
1541+
return vector();
1542+
}
1543+
15271544
static LogicalResult verify(TransposeOp op) {
15281545
VectorType vectorType = op.getVectorType();
15291546
VectorType resultType = op.getResultType();
@@ -1549,6 +1566,10 @@ static LogicalResult verify(TransposeOp op) {
15491566
return success();
15501567
}
15511568

1569+
void TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
1570+
populateFromInt64AttrArray(transp(), results);
1571+
}
1572+
15521573
//===----------------------------------------------------------------------===//
15531574
// TupleGetOp
15541575
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,72 @@ func @strided_slice_of_constant_mask() -> (vector<2x1xi1>) {
8787
// CHECK: vector.constant_mask [1, 1] : vector<2x1xi1>
8888
return %1 : vector<2x1xi1>
8989
}
90+
91+
// -----
92+
93+
// CHECK-LABEL: transpose_1D_identity
94+
// CHECK-SAME: ([[ARG:%.*]]: vector<4xf32>)
95+
func @transpose_1D_identity(%arg : vector<4xf32>) -> vector<4xf32> {
96+
// CHECK-NOT: transpose
97+
%0 = vector.transpose %arg, [0] : vector<4xf32> to vector<4xf32>
98+
// CHECK-NEXT: return [[ARG]]
99+
return %0 : vector<4xf32>
100+
}
101+
102+
// -----
103+
104+
// CHECK-LABEL: transpose_2D_identity
105+
// CHECK-SAME: ([[ARG:%.*]]: vector<4x3xf32>)
106+
func @transpose_2D_identity(%arg : vector<4x3xf32>) -> vector<4x3xf32> {
107+
// CHECK-NOT: transpose
108+
%0 = vector.transpose %arg, [0, 1] : vector<4x3xf32> to vector<4x3xf32>
109+
// CHECK-NEXT: return [[ARG]]
110+
return %0 : vector<4x3xf32>
111+
}
112+
113+
// -----
114+
115+
// CHECK-LABEL: transpose_3D_identity
116+
// CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>)
117+
func @transpose_3D_identity(%arg : vector<4x3x2xf32>) -> vector<4x3x2xf32> {
118+
// CHECK-NOT: transpose
119+
%0 = vector.transpose %arg, [0, 1, 2] : vector<4x3x2xf32> to vector<4x3x2xf32>
120+
// CHECK-NEXT: return [[ARG]]
121+
return %0 : vector<4x3x2xf32>
122+
}
123+
124+
// -----
125+
126+
// CHECK-LABEL: transpose_2D_sequence
127+
// CHECK-SAME: ([[ARG:%.*]]: vector<4x3xf32>)
128+
func @transpose_2D_sequence(%arg : vector<4x3xf32>) -> vector<3x4xf32> {
129+
// CHECK-NOT: transpose
130+
%0 = vector.transpose %arg, [0, 1] : vector<4x3xf32> to vector<4x3xf32>
131+
// CHECK: [[T1:%.*]] = vector.transpose [[ARG]], [1, 0]
132+
%1 = vector.transpose %0, [1, 0] : vector<4x3xf32> to vector<3x4xf32>
133+
// CHECK-NOT: transpose
134+
%2 = vector.transpose %1, [0, 1] : vector<3x4xf32> to vector<3x4xf32>
135+
// CHECK: [[ADD:%.*]] = addf [[T1]], [[T1]]
136+
%4 = addf %1, %2 : vector<3x4xf32>
137+
// CHECK-NEXT: return [[ADD]]
138+
return %4 : vector<3x4xf32>
139+
}
140+
141+
// -----
142+
143+
// CHECK-LABEL: transpose_3D_sequence
144+
// CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>)
145+
func @transpose_3D_sequence(%arg : vector<4x3x2xf32>) -> vector<2x3x4xf32> {
146+
// CHECK: [[T0:%.*]] = vector.transpose [[ARG]], [1, 2, 0]
147+
%0 = vector.transpose %arg, [1, 2, 0] : vector<4x3x2xf32> to vector<3x2x4xf32>
148+
// CHECK-NOT: transpose
149+
%1 = vector.transpose %0, [0, 1, 2] : vector<3x2x4xf32> to vector<3x2x4xf32>
150+
// CHECK: [[T2:%.*]] = vector.transpose [[T0]], [1, 0, 2]
151+
%2 = vector.transpose %1, [1, 0, 2] : vector<3x2x4xf32> to vector<2x3x4xf32>
152+
// CHECK: [[ADD:%.*]] = addf [[T2]], [[T2]]
153+
%3 = addf %2, %2 : vector<2x3x4xf32>
154+
// CHECK-NOT: transpose
155+
%4 = vector.transpose %3, [0, 1, 2] : vector<2x3x4xf32> to vector<2x3x4xf32>
156+
// CHECK-NEXT: return [[ADD]]
157+
return %4 : vector<2x3x4xf32>
158+
}

0 commit comments

Comments
 (0)