Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions mlir/include/mlir/Dialect/Linalg/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,13 @@ def LinalgSpecializeGenericOpsPass : Pass<"linalg-specialize-generic-ops">,
let dependentDialects = ["linalg::LinalgDialect"];
}

def LinalgNamedOpConversionPass: Pass<"linalg-named-op-conversion"> {
let summary = "Convert from one named linalg op to another.";
// ------------------ End of "form" conversions

def SimplifyDepthwiseConvPass: Pass<"simplify-depthwise-conv"> {
let summary = "Simplify depthwise convolution.";
let dependentDialects = ["linalg::LinalgDialect", "tensor::TensorDialect"];
}

// ------------------ End of "form" conversions

def ConvertElementwiseToLinalgPass : Pass<"convert-elementwise-to-linalg", ""> {
let summary = "Convert ElementwiseMappable ops to linalg";
let description = [{
Expand Down
5 changes: 2 additions & 3 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -1962,9 +1962,8 @@ void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns);
void populateFuseTensorPadWithProducerLinalgOpPatterns(
RewritePatternSet &patterns);

/// Patterns to convert from one named op to another. These can be seen as
/// canonicalizations of named ops into another named op.
void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns);
/// Patterns to simplify depthwise convolutions.
void populateSimplifyDepthwiseConvPatterns(RewritePatternSet &patterns);

/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
/// tensors via reassociative reshape ops.
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MorphOps.cpp
TransposeMatmul.cpp
ShardingInterfaceImpl.cpp
NamedOpConversions.cpp
SimplifyDepthwiseConv.cpp
NamedToElementwise.cpp
BlockPackMatmul.cpp
PackAndUnpackPatterns.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
#include "llvm/ADT/TypeSwitch.h"

namespace mlir {
#define GEN_PASS_DEF_LINALGNAMEDOPCONVERSIONPASS
#define GEN_PASS_DEF_SIMPLIFYDEPTHWISECONVPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

Expand Down Expand Up @@ -143,23 +143,22 @@ struct SimplifyDepthwiseConvQOp
}
};

struct LinalgNamedOpConversionPass
: public impl::LinalgNamedOpConversionPassBase<
LinalgNamedOpConversionPass> {
using impl::LinalgNamedOpConversionPassBase<
LinalgNamedOpConversionPass>::LinalgNamedOpConversionPassBase;
struct SimplifyDepthwiseConvPass
: public impl::SimplifyDepthwiseConvPassBase<SimplifyDepthwiseConvPass> {
using impl::SimplifyDepthwiseConvPassBase<
SimplifyDepthwiseConvPass>::SimplifyDepthwiseConvPassBase;

void runOnOperation() override {
Operation *op = getOperation();
RewritePatternSet patterns(op->getContext());
populateLinalgNamedOpConversionPatterns(patterns);
populateSimplifyDepthwiseConvPatterns(patterns);
if (failed(applyPatternsGreedily(op, std::move(patterns))))
return signalPassFailure();
}
};
} // namespace

void mlir::linalg::populateLinalgNamedOpConversionPatterns(
void mlir::linalg::populateSimplifyDepthwiseConvPatterns(
RewritePatternSet &patterns) {
patterns.add<SimplifyDepthwiseConvOp, SimplifyDepthwiseConvQOp>(
patterns.getContext());
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -linalg-named-op-conversion -split-input-file | FileCheck %s
// RUN: mlir-opt %s --simplify-depthwise-conv -split-input-file | FileCheck %s

// CHECK-LABEL: @depthwise_conv
func.func @depthwise_conv(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x1xf32>, %arg2: tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32> {
Expand Down