143 changes: 112 additions & 31 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4698,6 +4698,111 @@ OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
// Common Canonicalizers and Folders.
//===----------------------------------------------------------------------===//
bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
// 1. InsertSliceOp has its own logic about folding tensor.cast ops.
// 2. Exclude DPS ops that are also LoopLike from this interface as they
// might need special handling of attached regions.
if (isa<InsertSliceOp>(op.getOperation()) ||
isa<LoopLikeOpInterface>(op.getOperation()))
return false;

// If no operand comes from a tensor::CastOp and can be folded then fail.
bool hasTensorCastOperand =
llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
if (llvm::isa<BlockArgument>(opOperand.get()))
return false;
auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
return castOp && canFoldIntoConsumerOp(castOp);
});

return hasTensorCastOperand;
}

static SmallVector<Value> getNewOperands(DestinationStyleOpInterface op,
SmallVector<Type> &newResTy) {
SmallVector<Value> newOperands;
newOperands.reserve(op->getNumOperands());

// Assumes that the result has dpsInits followed by nonDpsInits.
int64_t dpsInitIdx = 0;
for (OpOperand &opOperand : op->getOpOperands()) {
auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
bool fold = canFoldIntoConsumerOp(tensorCastOp);
newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
if (op.isDpsInit(&opOperand) &&
!llvm::isa<MemRefType>(newOperands.back().getType()))
newResTy[dpsInitIdx++] = newOperands.back().getType();
}
return newOperands;
}

/// Folds a tensor.cast op into a consuming tensor::PackOp op if the
/// `tensor.cast` has source that is more static than the consuming op.
///
/// Example:
/// ```mlir
/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
/// %2 = tensor.pack %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
/// %2 = tensor.pack %0 ... : tensor<8x16xf32> ...
/// ```
struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
using OpRewritePattern<PackOp>::OpRewritePattern;

LogicalResult matchAndRewrite(PackOp op,
PatternRewriter &rewriter) const override {
if (!foldTensorCastPrecondition(op))
return failure();

SmallVector<Type> newResultTypes(op->getResultTypes());
SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);

// Get the updated mixed-tile-sizes attribute.
SmallVector<OpFoldResult> newMixedTileSizes;
for (auto it : llvm::zip(cast<ShapedType>(newResultTypes[0])
.getShape()
.take_back(op.getMixedTiles().size()),
op.getMixedTiles())) {
Review comment (Contributor) on lines +4766 to +4769:

nit:

Suggested change
    for (auto it : llvm::zip(cast<ShapedType>(newResultTypes[0])
                                 .getShape()
                                 .take_back(op.getMixedTiles().size()),
                             op.getMixedTiles())) {
    for (auto [shape, innerTile] : llvm::zip_equal(cast<ShapedType>(newResultTypes[0])
                                       .getShape()
                                       .take_back(op.getMixedTiles().size()),
                                   op.getMixedTiles())) {

int64_t shape = std::get<0>(it);
if (shape == ShapedType::kDynamic) {
newMixedTileSizes.push_back(std::get<1>(it));
continue;
}

if (Attribute attr =
llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
// Already a constant
newMixedTileSizes.push_back(std::get<1>(it));
} else {
int64_t tileSize = getConstantIntValue(std::get<1>(it)).value();
assert(tileSize == shape && "tile size and dim size don't match!");
newMixedTileSizes.push_back(
(rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
}
}

// Clone op.
PackOp newOp = rewriter.create<PackOp>(
op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());

// Replace op.
Value oldResult = op.getResult();
Value newResult = newOp.getResult();
Value replacement = (newResult.getType() != oldResult.getType())
? rewriter.create<tensor::CastOp>(
op->getLoc(), oldResult.getType(), newResult)
: newResult;

rewriter.replaceOp(op, {replacement});

return success();
}
};

/// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
/// the `tensor.cast` has source that is more static than the consuming op.
@@ -4722,42 +4827,17 @@ struct FoldTensorCastProducerOp

LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
PatternRewriter &rewriter) const override {
// InsertSliceOp has its own logic about folding tensor.cast ops.
if (isa<InsertSliceOp>(op.getOperation()))
return failure();

// Exclude DPS ops that are also LoopLike from this interface as they
// might need special handling of attached regions.
if (isa<LoopLikeOpInterface>(op.getOperation()))
// Reject tensor::PackOp - there's a dedicated pattern for that instead.
if (!foldTensorCastPrecondition(op) || dyn_cast<tensor::PackOp>(*op))
return failure();

// If no operand comes from a tensor::CastOp and can be folded then fail.
bool hasTensorCastOperand =
llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
if (llvm::isa<BlockArgument>(opOperand.get()))
return false;
auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
return castOp && canFoldIntoConsumerOp(castOp);
});
if (!hasTensorCastOperand)
return failure();
SmallVector<Type> newResultTypes(op->getResultTypes());
SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);

SmallVector<Type, 4> newResultTypes(op->getResultTypes());
SmallVector<Value, 4> newOperands;
newOperands.reserve(op->getNumOperands());
// Assumes that the result has dpsInits followed by nonDpsInits.
int64_t dpsInitIdx = 0;
for (OpOperand &opOperand : op->getOpOperands()) {
auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
bool fold = canFoldIntoConsumerOp(tensorCastOp);
newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
if (op.isDpsInit(&opOperand) &&
!llvm::isa<MemRefType>(newOperands.back().getType()))
newResultTypes[dpsInitIdx++] = newOperands.back().getType();
}
// Clone op
auto newOp = clone(rewriter, op, newResultTypes, newOperands);

// Clone op.
Operation *newOp = clone(rewriter, op, newResultTypes, newOperands);
SmallVector<Value, 4> replacements;
replacements.reserve(newOp->getNumResults());
for (auto [oldResult, newResult] :
@@ -4781,6 +4861,7 @@ struct FoldTensorCastProducerOp

void TensorDialect::getCanonicalizationPatterns(
RewritePatternSet &results) const {
results.add<FoldTensorCastPackOp>(getContext());
results.add<FoldTensorCastProducerOp>(getContext());
}

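A note on the reviewer's suggestion above: llvm::zip_equal behaves like llvm::zip, but it asserts (in builds with assertions enabled) that all ranges have the same length, and its elements unpack cleanly with structured bindings so the loop body does not need std::get. The standalone sketch below illustrates the difference; the two vectors are hypothetical stand-ins for the result shape and the mixed tile sizes used in the pattern, not code taken from this patch.

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>
#include <tuple>

int main() {
  llvm::SmallVector<int64_t> resultShape = {8, 1};
  llvm::SmallVector<int64_t> mixedTiles = {8, 1};

  // llvm::zip yields tuples, so elements are accessed with std::get<N>.
  // It silently stops at the end of the shorter range.
  for (auto it : llvm::zip(resultShape, mixedTiles))
    llvm::outs() << std::get<0>(it) << " <-> " << std::get<1>(it) << "\n";

  // llvm::zip_equal asserts that the ranges have equal length and pairs
  // naturally with structured bindings, which is what the reviewer suggests.
  for (auto [shape, innerTile] : llvm::zip_equal(resultShape, mixedTiles))
    llvm::outs() << shape << " <-> " << innerTile << "\n";
  return 0;
}
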
23 changes: 21 additions & 2 deletions mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2718,18 +2718,37 @@ func.func @dim_out_of_bounds() -> vector<7xi32> {

// -----

// CHECK-LABEL: func.func @test_destination_multiple_result(
// CHECK-LABEL: func.func @fold_cast_multiple_results(
// CHECK-SAME: %[[ARG1:.*]]: tensor<2x2xf32>,
// CHECK-SAME: %[[ARG2:.*]]: tensor<2x2xf32>) -> index {
// CHECK: %[[RES:.*]]:2 = test.destination_style_op ins(%[[ARG1]] : tensor<2x2xf32>)
// CHECK-SAME: outs(%[[ARG2]] : tensor<2x2xf32>) -> tensor<2x2xf32>, index
// CHECK: return %[[RES]]#1 : index
func.func @test_destination_multiple_result(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> index {
func.func @fold_cast_multiple_results(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> index {
%cast = tensor.cast %arg0 : tensor<2x2xf32> to tensor<?x2xf32>
%cast_0 = tensor.cast %arg1 : tensor<2x2xf32> to tensor<?x2xf32>
%0:2 = test.destination_style_op ins(%cast : tensor<?x2xf32>) outs(%cast_0 : tensor<?x2xf32>) -> tensor<?x2xf32>, index
return %0#1 : index
}
// -----

// CHECK-LABEL: func.func @fold_cast_pack_dynamic_tile_size
// CHECK-SAME: %[[DEST:.*]]: tensor<1x1x8x1xi32>,
// CHECK-SAME: %[[SRC:.*]]: tensor<7x?xi32>,
// CHECK-SAME: %[[PAD:.*]]: i32) -> tensor<1x1x8x1xi32> {
// CHECK: %[[PACK:.*]] = tensor.pack %[[SRC]] padding_value(%[[PAD]] : i32) inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] : tensor<7x?xi32> -> tensor<1x1x8x1xi32>
// CHECK: return %[[PACK]] : tensor<1x1x8x1xi32>
func.func @fold_cast_pack_dynamic_tile_size(
%dest: tensor<1x1x8x1xi32>,
%src: tensor<7x?xi32>,
%pad: i32) -> tensor<1x1x8x1xi32> {

%cast = tensor.cast %dest : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
%c8 = arith.constant 8 : index
%pack = tensor.pack %src padding_value(%pad : i32) inner_dims_pos = [0, 1] inner_tiles = [%c8, 1] into %cast : tensor<7x?xi32> -> tensor<1x1x?x1xi32>
%res = tensor.cast %pack : tensor<1x1x?x1xi32> to tensor<1x1x8x1xi32>
return %res : tensor<1x1x8x1xi32>
}

// -----

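The RUN line of canonicalize.mlir is not shown in this diff, but a new case such as @fold_cast_pack_dynamic_tile_size can typically be exercised in isolation with an invocation along these lines (the exact flags are an assumption, not taken from the patch):

mlir-opt --split-input-file --canonicalize mlir/test/Dialect/Tensor/canonicalize.mlir | FileCheck mlir/test/Dialect/Tensor/canonicalize.mlir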