Skip to content
81 changes: 81 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2254,6 +2254,87 @@ def CIR_StackRestoreOp : CIR_Op<"stackrestore"> {
let assemblyFormat = "$ptr attr-dict `:` qualified(type($ptr))";
}

//===----------------------------------------------------------------------===//
// InlineAsmOp
//===----------------------------------------------------------------------===//

def CIR_AsmFlavor : CIR_I32EnumAttr<"AsmFlavor", "ATT or Intel",
[I32EnumAttrCase<"x86_att", 0>,
I32EnumAttrCase<"x86_intel", 1>]>;

def CIR_InlineAsmOp : CIR_Op<"asm", [RecursiveMemoryEffects]> {
let description = [{
The `cir.asm` operation represents C/C++ asm inline.

CIR constraints strings follow the same rules that are established for
the C level assembler constraints with several differences caused by
clang::AsmStmt processing.

Thus, numbers that appears in the constraint string may also refer to:
- the output variable index referenced by the input operands.
- the index of early-clobber operand

Operand attributes are a storage, where each element corresponds to the
operand with the same index. The first index relates to the operation
result (if any).
The operands themselves are stored as VariadicOfVariadic in the following
order: output, input and then in/out operands. When several output operands
are present, the result type may be represented as an anonymous record type.

Example:
```C++
__asm__("foo" : : : );
__asm__("bar $42 %[val]" : [val] "=r" (x), "+&r"(x));
__asm__("baz $42 %[val]" : [val] "=r" (x), "+&r"(x) : "[val]"(y));
```

```mlir
!rec_22anon2E022 = !cir.record<struct "anon.0" {!cir.int<s, 32>, !cir.int<s, 32>}>
!rec_22anon2E122 = !cir.record<struct "anon.1" {!cir.int<s, 32>, !cir.int<s, 32>}>
...
%0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["x", init]
%1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["y", init]
...
%2 = cir.load %0 : !cir.ptr<!s32i>, !s32i
%3 = cir.load %1 : !cir.ptr<!s32i>, !s32i

cir.asm(x86_att,
out = [],
in = [],
in_out = [],
{"foo" "~{dirflag},~{fpsr},~{flags}"}) side_effects

cir.asm(x86_att,
out = [],
in = [],
in_out = [%2 : !s32i],
{"bar $$42 $0" "=r,=&r,1,~{dirflag},~{fpsr},~{flags}"}) -> !rec_22anon2E022

cir.asm(x86_att,
out = [],
in = [%3 : !s32i],
in_out = [%2 : !s32i],
{"baz $$42 $0" "=r,=&r,0,1,~{dirflag},~{fpsr},~{flags}"}) -> !rec_22anon2E122
```
}];

let results = (outs Optional<CIR_AnyType>:$res);

let arguments =
(ins VariadicOfVariadic<AnyType, "operands_segments">:$asm_operands,
StrAttr:$asm_string, StrAttr:$constraints, UnitAttr:$side_effects,
CIR_AsmFlavor:$asm_flavor, ArrayAttr:$operand_attrs,
DenseI32ArrayAttr:$operands_segments);

let builders = [OpBuilder<(ins
"llvm::ArrayRef<mlir::ValueRange>":$asmOperands,
"llvm::StringRef":$asmString, "llvm::StringRef":$constraints,
"bool":$sideEffects, "AsmFlavor":$asmFlavor,
"llvm::ArrayRef<mlir::Attribute>":$operandAttrs)>];

let hasCustomAssemblyFormat = 1;
}

//===----------------------------------------------------------------------===//
// UnreachableOp
//===----------------------------------------------------------------------===//
Expand Down
203 changes: 203 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2419,6 +2419,209 @@ OpFoldResult RotateOp::fold(FoldAdaptor adaptor) {
return IntAttr::get(input.getContext(), input.getType(), resultValue);
}

//===----------------------------------------------------------------------===//
// InlineAsmOp
//===----------------------------------------------------------------------===//

void cir::InlineAsmOp::print(OpAsmPrinter &p) {
p << '(' << getAsmFlavor() << ", ";
p.increaseIndent();
p.printNewline();

llvm::SmallVector<std::string, 3> names{"out", "in", "in_out"};
auto *nameIt = names.begin();
auto *attrIt = getOperandAttrs().begin();

for (mlir::OperandRange ops : getAsmOperands()) {
p << *nameIt << " = ";

p << '[';
llvm::interleaveComma(llvm::make_range(ops.begin(), ops.end()), p,
[&](Value value) {
p.printOperand(value);
p << " : " << value.getType();
if (*attrIt)
p << " (maybe_memory)";
attrIt++;
});
p << "],";
p.printNewline();
++nameIt;
}

p << "{";
p.printString(getAsmString());
p << " ";
p.printString(getConstraints());
p << "}";
p.decreaseIndent();
p << ')';
if (getSideEffects())
p << " side_effects";

std::array elidedAttrs{
llvm::StringRef("asm_flavor"), llvm::StringRef("asm_string"),
llvm::StringRef("constraints"), llvm::StringRef("operand_attrs"),
llvm::StringRef("operands_segments"), llvm::StringRef("side_effects")};
p.printOptionalAttrDict(getOperation()->getAttrs(), elidedAttrs);

if (auto v = getRes())
p << " -> " << v.getType();
}

void cir::InlineAsmOp::build(OpBuilder &odsBuilder, OperationState &odsState,
ArrayRef<ValueRange> asmOperands,
StringRef asmString, StringRef constraints,
bool sideEffects, cir::AsmFlavor asmFlavor,
ArrayRef<Attribute> operandAttrs) {
// Set up the operands_segments for VariadicOfVariadic
SmallVector<int32_t> segments;
for (auto operandRange : asmOperands) {
segments.push_back(operandRange.size());
odsState.addOperands(operandRange);
}

odsState.addAttribute(
"operands_segments",
DenseI32ArrayAttr::get(odsBuilder.getContext(), segments));
odsState.addAttribute("asm_string", odsBuilder.getStringAttr(asmString));
odsState.addAttribute("constraints", odsBuilder.getStringAttr(constraints));
odsState.addAttribute("asm_flavor",
AsmFlavorAttr::get(odsBuilder.getContext(), asmFlavor));

if (sideEffects)
odsState.addAttribute("side_effects", odsBuilder.getUnitAttr());

odsState.addAttribute("operand_attrs", odsBuilder.getArrayAttr(operandAttrs));
}

ParseResult cir::InlineAsmOp::parse(OpAsmParser &parser,
OperationState &result) {
llvm::SmallVector<mlir::Attribute> operandAttrs;
llvm::SmallVector<int32_t> operandsGroupSizes;
std::string asmString, constraints;
Type resType;
MLIRContext *ctxt = parser.getBuilder().getContext();

auto error = [&](const Twine &msg) -> LogicalResult {
return parser.emitError(parser.getCurrentLocation(), msg);
};

auto expected = [&](const std::string &c) {
return error("expected '" + c + "'");
};

if (parser.parseLParen().failed())
return expected("(");

auto flavor = FieldParser<AsmFlavor, AsmFlavor>::parse(parser);
if (failed(flavor))
return error("Unknown AsmFlavor");

if (parser.parseComma().failed())
return expected(",");

auto parseValue = [&](Value &v) {
OpAsmParser::UnresolvedOperand op;

if (parser.parseOperand(op) || parser.parseColon())
return error("can't parse operand");

Type typ;
if (parser.parseType(typ).failed())
return error("can't parse operand type");
llvm::SmallVector<mlir::Value> tmp;
if (parser.resolveOperand(op, typ, tmp))
return error("can't resolve operand");
v = tmp[0];
return mlir::success();
};

auto parseOperands = [&](llvm::StringRef name) {
if (parser.parseKeyword(name).failed())
return error("expected " + name + " operands here");
if (parser.parseEqual().failed())
return expected("=");
if (parser.parseLSquare().failed())
return expected("[");

int size = 0;
if (parser.parseOptionalRSquare().succeeded()) {
operandsGroupSizes.push_back(size);
if (parser.parseComma())
return expected(",");
return mlir::success();
}

auto parseOperand = [&]() {
Value val;
if (parseValue(val).succeeded()) {
result.operands.push_back(val);
size++;

if (parser.parseOptionalLParen().failed()) {
operandAttrs.push_back(mlir::Attribute());
return mlir::success();
}

if (parser.parseKeyword("maybe_memory").succeeded()) {
operandAttrs.push_back(mlir::UnitAttr::get(ctxt));
if (parser.parseRParen())
return expected(")");
return mlir::success();
} else {
return expected("maybe_memory");
}
}
return mlir::failure();
};

if (parser.parseCommaSeparatedList(parseOperand).failed())
return mlir::failure();

if (parser.parseRSquare().failed() || parser.parseComma().failed())
return expected("]");
operandsGroupSizes.push_back(size);
return mlir::success();
};

if (parseOperands("out").failed() || parseOperands("in").failed() ||
parseOperands("in_out").failed())
return error("failed to parse operands");

if (parser.parseLBrace())
return expected("{");
if (parser.parseString(&asmString))
return error("asm string parsing failed");
if (parser.parseString(&constraints))
return error("constraints string parsing failed");
if (parser.parseRBrace())
return expected("}");
if (parser.parseRParen())
return expected(")");

if (parser.parseOptionalKeyword("side_effects").succeeded())
result.attributes.set("side_effects", UnitAttr::get(ctxt));

if (parser.parseOptionalArrow().succeeded() &&
parser.parseType(resType).failed())
return mlir::failure();

if (parser.parseOptionalAttrDict(result.attributes).failed())
return mlir::failure();

result.attributes.set("asm_flavor", AsmFlavorAttr::get(ctxt, *flavor));
result.attributes.set("asm_string", StringAttr::get(ctxt, asmString));
result.attributes.set("constraints", StringAttr::get(ctxt, constraints));
result.attributes.set("operand_attrs", ArrayAttr::get(ctxt, operandAttrs));
result.getOrAddProperties<InlineAsmOp::Properties>().operands_segments =
parser.getBuilder().getDenseI32ArrayAttr(operandsGroupSizes);
if (resType)
result.addTypes(TypeRange{resType});

return mlir::success();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
Expand Down
64 changes: 64 additions & 0 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2272,6 +2272,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
patterns.add<CIRToLLVMCastOpLowering>(converter, patterns.getContext(), dl);
patterns.add<CIRToLLVMPtrStrideOpLowering>(converter, patterns.getContext(),
dl);
patterns.add<CIRToLLVMInlineAsmOpLowering>(converter, patterns.getContext(),
dl);
patterns.add<
// clang-format off
CIRToLLVMAssumeOpLowering,
Expand Down Expand Up @@ -2905,6 +2907,68 @@ mlir::LogicalResult CIRToLLVMGetBitfieldOpLowering::matchAndRewrite(
return mlir::success();
}

mlir::LogicalResult CIRToLLVMInlineAsmOpLowering::matchAndRewrite(
cir::InlineAsmOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
mlir::Type llResTy;
if (op.getNumResults())
llResTy = getTypeConverter()->convertType(op.getType(0));

cir::AsmFlavor dialect = op.getAsmFlavor();
mlir::LLVM::AsmDialect llDialect = dialect == cir::AsmFlavor::x86_att
? mlir::LLVM::AsmDialect::AD_ATT
: mlir::LLVM::AsmDialect::AD_Intel;

SmallVector<mlir::Attribute> opAttrs;
StringRef llvmAttrName = mlir::LLVM::InlineAsmOp::getElementTypeAttrName();

// this is for the lowering to LLVM from LLVM dialect. Otherwise, if we
// don't have the result (i.e. void type as a result of operation), the
// element type attribute will be attached to the whole instruction, but not
// to the operand
if (!op.getNumResults())
opAttrs.push_back(mlir::Attribute());

SmallVector<mlir::Value> llvmOperands;
SmallVector<mlir::Value> cirOperands;
for (auto const&[llvmOp, cirOp] :
zip(adaptor.getAsmOperands(), op.getAsmOperands())) {
append_range(llvmOperands, llvmOp);
append_range(cirOperands, cirOp);
}

// so far we infer the llvm dialect element type attr from
// CIR operand type.
for (auto const&[cirOpAttr, cirOp] : zip(op.getOperandAttrs(), cirOperands)) {
if (!cirOpAttr) {
opAttrs.push_back(mlir::Attribute());
continue;
}

llvm::SmallVector<mlir::NamedAttribute, 1> attrs;
cir::PointerType typ =
mlir::cast<cir::PointerType>(cirOp.getType());
mlir::TypeAttr typAttr = mlir::TypeAttr::get(convertTypeForMemory(
*getTypeConverter(), dataLayout, typ.getPointee()));

attrs.push_back(rewriter.getNamedAttr(llvmAttrName, typAttr));
mlir::DictionaryAttr newDict = rewriter.getDictionaryAttr(attrs);
opAttrs.push_back(newDict);
}

rewriter.replaceOpWithNewOp<mlir::LLVM::InlineAsmOp>(
op, llResTy, llvmOperands, op.getAsmStringAttr(), op.getConstraintsAttr(),
op.getSideEffectsAttr(),
/*is_align_stack*/ mlir::UnitAttr(),
/*tail_call_kind*/
mlir::LLVM::TailCallKindAttr::get(
getContext(), mlir::LLVM::tailcallkind::TailCallKind::None),
mlir::LLVM::AsmDialectAttr::get(getContext(), llDialect),
rewriter.getArrayAttr(opAttrs));

return mlir::success();
}

std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
return std::make_unique<ConvertCIRToLLVMPass>();
}
Expand Down
17 changes: 17 additions & 0 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,23 @@ class CIRToLLVMFAbsOpLowering : public mlir::OpConversionPattern<cir::FAbsOp> {
mlir::ConversionPatternRewriter &) const override;
};

class CIRToLLVMInlineAsmOpLowering
: public mlir::OpConversionPattern<cir::InlineAsmOp> {
mlir::DataLayout const &dataLayout;

public:
CIRToLLVMInlineAsmOpLowering(const mlir::TypeConverter &typeConverter,
mlir::MLIRContext *context,
mlir::DataLayout const &dataLayout)
: OpConversionPattern(typeConverter, context), dataLayout(dataLayout) {}

using mlir::OpConversionPattern<cir::InlineAsmOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(cir::InlineAsmOp op, OpAdaptor,
mlir::ConversionPatternRewriter &) const override;
};

} // namespace direct
} // namespace cir

Expand Down
Loading