Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 6 additions & 66 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -4696,72 +4696,6 @@ def FrameAddrOp : FuncAddrBuiltinOp<"frame_address"> {
}];
}

//===----------------------------------------------------------------------===//
// StdFindOp
//===----------------------------------------------------------------------===//

def StdFindOp : CIR_Op<"std.find", [SameFirstSecondOperandAndResultType]> {
let arguments = (ins FlatSymbolRefAttr:$original_fn,
CIR_AnyType:$first,
CIR_AnyType:$last,
CIR_AnyType:$pattern);
let summary = "std:find()";
let results = (outs CIR_AnyType:$result);

let description = [{
Search for `pattern` in data range from `first` to `last`. This currently
maps to only one form of `std::find`. The `original_fn` operand tracks the
mangled named that can be used when lowering to a `cir.call`.

Example:

```mlir
...
%result = cir.std.find(@original_fn,
%first : !T, %last : !T, %pattern : !P) -> !T
```
}];

let assemblyFormat = [{
`(`
$original_fn
`,` $first `:` type($first)
`,` $last `:` type($last)
`,` $pattern `:` type($pattern)
`)` `->` type($result) attr-dict
}];
let hasVerifier = 0;
}

//===----------------------------------------------------------------------===//
// IterBegin/End
//===----------------------------------------------------------------------===//

def IterBeginOp : CIR_Op<"iterator_begin"> {
let arguments = (ins FlatSymbolRefAttr:$original_fn, CIR_AnyType:$container);
let summary = "Returns an iterator to the first element of a container";
let results = (outs CIR_AnyType:$result);
let assemblyFormat = [{
`(`
$original_fn `,` $container `:` type($container)
`)` `->` type($result) attr-dict
}];
let hasVerifier = 0;
}

def IterEndOp : CIR_Op<"iterator_end"> {
let arguments = (ins FlatSymbolRefAttr:$original_fn, CIR_AnyType:$container);
let summary = "Returns an iterator to the element following the last element"
" of a container";
let results = (outs CIR_AnyType:$result);
let assemblyFormat = [{
`(`
$original_fn `,` $container `:` type($container)
`)` `->` type($result) attr-dict
}];
let hasVerifier = 0;
}

//===----------------------------------------------------------------------===//
// Floating Point Ops
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -5751,4 +5685,10 @@ def SignBitOp : CIR_Op<"signbit", [Pure]> {
}];
}

//===----------------------------------------------------------------------===//
// Standard library function calls
//===----------------------------------------------------------------------===//

include "clang/CIR/Dialect/IR/CIRStdOps.td"

#endif // LLVM_CLANG_CIR_DIALECT_IR_CIROPS
66 changes: 66 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRStdOps.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
//===-- CIRStdOps.td - CIR standard library ops ------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// Defines ops representing standard library calls
///
//===----------------------------------------------------------------------===//

#ifndef LLVM_CLANG_CIR_DIALECT_IR_CIRSTDOPS
#define LLVM_CLANG_CIR_DIALECT_IR_CIRSTDOPS

class CIRStdOp<string functionName, dag args, dag res, list<Trait> traits = []>:
CIR_Op<"std." # functionName, traits> {
string funcName = functionName;

let arguments = !con((ins FlatSymbolRefAttr:$original_fn), args);

let summary = "std::" # functionName # "()";
let results = res;

let extraClassDeclaration = [{
static constexpr unsigned getNumArgs() {
return }] # !size(args) # [{;
}
static llvm::StringRef getFunctionName() {
return "}] # functionName # [{";
}
}];

string argsAssemblyFormat = !interleave(
!foreach(
name,
!foreach(i, !range(!size(args)), !getdagname(args, i)),
!strconcat("$", name, " `:` type($", name, ")")
), " `,` "
);

string resultAssemblyFormat = !if(
!empty(res),
"",
" `->` type($" # !getdagname(res, 0) # ")"
);

let assemblyFormat = !strconcat("`(` ", argsAssemblyFormat,
" `,` $original_fn `)`", resultAssemblyFormat,
" attr-dict");

let hasVerifier = 0;
}

def StdFindOp : CIRStdOp<"find",
(ins CIR_AnyType:$first, CIR_AnyType:$last, CIR_AnyType:$pattern),
(outs CIR_AnyType:$result),
[SameFirstSecondOperandAndResultType]>;
def IterBeginOp: CIRStdOp<"begin",
(ins CIR_AnyType:$container),
(outs CIR_AnyType:$result)>;
def IterEndOp: CIRStdOp<"end",
(ins CIR_AnyType:$container),
(outs CIR_AnyType:$result)>;

#endif
79 changes: 52 additions & 27 deletions clang/lib/CIR/Dialect/Transforms/IdiomRecognizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,52 @@ using namespace cir;

namespace {

// Recognizes a cir.call that calls a standard library function represented
// by `TargetOp`, and raise it to that operation.
template <typename TargetOp> class StdRecognizer {
private:
// Reserved for template specialization.
static bool checkArguments(mlir::ValueRange) { return true; }

template <size_t... Indices>
static TargetOp buildCall(CIRBaseBuilderTy &builder, CallOp call,
std::index_sequence<Indices...>) {
return builder.create<TargetOp>(call.getLoc(), call.getResult().getType(),
call.getCalleeAttr(),
call.getOperand(Indices)...);
}

public:
static bool raise(CallOp call, mlir::MLIRContext &context, bool remark) {
constexpr int numArgs = TargetOp::getNumArgs();
if (call.getNumOperands() != numArgs)
return false;

auto callExprAttr = call.getAstAttr();
llvm::StringRef stdFuncName = TargetOp::getFunctionName();
if (!callExprAttr || !callExprAttr.isStdFunctionCall(stdFuncName))
return false;

if (!checkArguments(call.getArgOperands()))
return false;

if (remark)
mlir::emitRemark(call.getLoc())
<< "found call to std::" << stdFuncName << "()";

CIRBaseBuilderTy builder(context);
builder.setInsertionPointAfter(call.getOperation());
TargetOp op = buildCall(builder, call, std::make_index_sequence<numArgs>());
call.replaceAllUsesWith(op);
call.erase();
return true;
}
};

struct IdiomRecognizerPass : public IdiomRecognizerBase<IdiomRecognizerPass> {
IdiomRecognizerPass() = default;
void runOnOperation() override;
void recognizeCall(CallOp call);
bool raiseStdFind(CallOp call);
bool raiseIteratorBeginEnd(CallOp call);

// Handle pass options
Expand Down Expand Up @@ -88,30 +129,6 @@ struct IdiomRecognizerPass : public IdiomRecognizerBase<IdiomRecognizerPass> {
};
} // namespace

bool IdiomRecognizerPass::raiseStdFind(CallOp call) {
// FIXME: tablegen all of this function.
if (call.getNumOperands() != 3)
return false;

auto callExprAttr = call.getAstAttr();
if (!callExprAttr || !callExprAttr.isStdFunctionCall("find")) {
return false;
}

if (opts.emitRemarkFoundCalls())
emitRemark(call.getLoc()) << "found call to std::find()";

CIRBaseBuilderTy builder(getContext());
builder.setInsertionPointAfter(call.getOperation());
auto findOp = builder.create<cir::StdFindOp>(
call.getLoc(), call.getResult().getType(), call.getCalleeAttr(),
call.getOperand(0), call.getOperand(1), call.getOperand(2));

call.replaceAllUsesWith(findOp);
call.erase();
return true;
}

static bool isIteratorLikeType(mlir::Type t) {
// TODO: some iterators are going to be represented with structs,
// in which case we could look at ASTRecordDeclInterface for more
Expand Down Expand Up @@ -175,8 +192,16 @@ void IdiomRecognizerPass::recognizeCall(CallOp call) {
if (raiseIteratorBeginEnd(call))
return;

if (raiseStdFind(call))
return;
bool remark = opts.emitRemarkFoundCalls();

using StdFunctionsRecognizer = std::tuple<StdRecognizer<StdFindOp>>;

// MSVC requires explicitly capturing these variables.
std::apply(
[&, call, remark, this](auto... recognizers) {
(decltype(recognizers)::raise(call, this->getContext(), remark) || ...);
},
StdFunctionsRecognizer());
}

void IdiomRecognizerPass::runOnOperation() {
Expand Down
6 changes: 2 additions & 4 deletions clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1481,8 +1481,7 @@ void LoweringPreparePass::lowerIterBeginOp(IterBeginOp op) {
CIRBaseBuilderTy builder(getContext());
builder.setInsertionPointAfter(op.getOperation());
auto call = builder.createCallOp(op.getLoc(), op.getOriginalFnAttr(),
op.getResult().getType(),
mlir::ValueRange{op.getOperand()});
op.getResult().getType(), op.getOperand());

op.replaceAllUsesWith(call);
op.erase();
Expand All @@ -1492,8 +1491,7 @@ void LoweringPreparePass::lowerIterEndOp(IterEndOp op) {
CIRBaseBuilderTy builder(getContext());
builder.setInsertionPointAfter(op.getOperation());
auto call = builder.createCallOp(op.getLoc(), op.getOriginalFnAttr(),
op.getResult().getType(),
mlir::ValueRange{op.getOperand()});
op.getResult().getType(), op.getOperand());

op.replaceAllUsesWith(call);
op.erase();
Expand Down
11 changes: 5 additions & 6 deletions clang/test/CIR/Transforms/idiom-recognizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ int test_find(unsigned char n = 3)
// expected-remark@-2 {{found call to end() iterator}}

// BEFORE-IDIOM: {{.*}} cir.call @_ZNSt5arrayIhLj9EE5beginEv(
// AFTER-IDIOM: {{.*}} cir.iterator_begin(@_ZNSt5arrayIhLj9EE5beginEv,
// AFTER-IDIOM: {{.*}} cir.std.begin({{.*}}, @_ZNSt5arrayIhLj9EE5beginEv
// AFTER-LOWERING-PREPARE: {{.*}} cir.call @_ZNSt5arrayIhLj9EE5beginEv(

// BEFORE-IDIOM: {{.*}} cir.call @_ZNSt5arrayIhLj9EE3endEv(
// AFTER-IDIOM: {{.*}} cir.iterator_end(@_ZNSt5arrayIhLj9EE3endEv,
// AFTER-IDIOM: {{.*}} cir.std.end({{.*}}, @_ZNSt5arrayIhLj9EE3endEv
// AFTER-LOWERING-PREPARE: {{.*}} cir.call @_ZNSt5arrayIhLj9EE3endEv(

// BEFORE-IDIOM: {{.*}} cir.call @_ZSt4findIPhhET_S1_S1_RKT0_(
// AFTER-IDIOM: {{.*}} cir.std.find(@_ZSt4findIPhhET_S1_S1_RKT0_,
// AFTER-IDIOM: {{.*}} cir.std.find({{.*}}, @_ZSt4findIPhhET_S1_S1_RKT0_
// AFTER-LOWERING-PREPARE: {{.*}} cir.call @_ZSt4findIPhhET_S1_S1_RKT0_(

if (f != v.end()) // expected-remark {{found call to end() iterator}}
Expand All @@ -43,8 +43,7 @@ template<typename T, unsigned N> struct array {
};
}

int iter_test()
{
void iter_test() {
yolo::array<unsigned char, 3> v = {1, 2, 3};
(void)v.begin(); // no remark should be produced.
}
}
Loading