From 668dfb2d335f3446bb0ae2e9529d1e84f5eac8ac Mon Sep 17 00:00:00 2001 From: Sirui Mu Date: Mon, 23 Sep 2024 23:19:47 +0800 Subject: [PATCH 1/8] [mlir][LLVM] Add operand bundle support --- mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 44 ++- mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 4 + .../Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp | 10 +- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 251 ++++++++++++++++-- .../LLVMIR/LLVMToLLVMIRTranslation.cpp | 68 ++++- mlir/test/Dialect/LLVMIR/invalid.mlir | 18 +- mlir/test/Target/LLVMIR/llvmir.mlir | 49 ++++ 7 files changed, 401 insertions(+), 43 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 8584a25f8b3d6..030160821bd82 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -551,7 +551,15 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [ Variadic:$normalDestOperands, Variadic:$unwindDestOperands, OptionalAttr:$branch_weights, - DefaultValuedAttr:$CConv); + DefaultValuedAttr:$CConv, + VariadicOfVariadic:$op_bundle_operands, + DenseI32ArrayAttr:$op_bundle_sizes, + DefaultValuedProperty< + ArrayProperty, + "ArrayRef{}", + "SmallVector{}" + >:$op_bundle_tags); let results = (outs Optional:$result); let successors = (successor AnySuccessor:$normalDest, AnySuccessor:$unwindDest); @@ -607,7 +615,8 @@ def LLVM_VaArgOp : LLVM_Op<"va_arg"> { //===----------------------------------------------------------------------===// def LLVM_CallOp : LLVM_MemAccessOpBase<"call", - [DeclareOpInterfaceMethods, + [AttrSizedOperandSegments, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { @@ -661,8 +670,15 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call", OptionalAttr:$memory_effects, OptionalAttr:$convergent, OptionalAttr:$no_unwind, - OptionalAttr:$will_return - ); + OptionalAttr:$will_return, + VariadicOfVariadic:$op_bundle_operands, + DenseI32ArrayAttr:$op_bundle_sizes, + DefaultValuedProperty< + ArrayProperty, + "ArrayRef{}", + "SmallVector{}" + >:$op_bundle_tags); // Append the aliasing related attributes defined in LLVM_MemAccessOpBase. let arguments = !con(args, aliasAttrs); let results = (outs Optional:$result); @@ -682,6 +698,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call", OpBuilder<(ins "LLVMFunctionType":$calleeType, "StringRef":$callee, CArg<"ValueRange", "{}">:$args)> ]; + let hasVerifier = 1; let hasCustomAssemblyFormat = 1; let extraClassDeclaration = [{ /// Returns the callee function type. @@ -1895,7 +1912,8 @@ def LLVM_InlineAsmOp : LLVM_Op<"inline_asm", [DeclareOpInterfaceMethods]> { + [AttrSizedOperandSegments, + DeclareOpInterfaceMethods]> { let summary = "Call to an LLVM intrinsic function."; let description = [{ Call the specified llvm intrinsic. If the intrinsic is overloaded, use @@ -1903,13 +1921,25 @@ def LLVM_CallIntrinsicOp }]; let arguments = (ins StrAttr:$intrin, Variadic:$args, DefaultValuedAttr:$fastmathFlags); + "{}">:$fastmathFlags, + VariadicOfVariadic:$op_bundle_operands, + DenseI32ArrayAttr:$op_bundle_sizes, + DefaultValuedProperty< + ArrayProperty, + "ArrayRef{}", + "SmallVector{}" + >:$op_bundle_tags); let results = (outs Optional:$results); let llvmBuilder = [{ return convertCallLLVMIntrinsicOp(op, builder, moduleTranslation); }]; let assemblyFormat = [{ - $intrin `(` $args `)` `:` functional-type($args, $results) attr-dict + $intrin `(` $args `)` + ( custom($op_bundle_operands, type($op_bundle_operands), + $op_bundle_tags)^ )? + `:` functional-type($args, $results) + attr-dict }]; let hasVerifier = 1; diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp index 4c2e8682285c5..2cc77e8fd41b9 100644 --- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp +++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp @@ -544,6 +544,10 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern { callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(), promoted, callOp->getAttrs()); + newOp.getProperties().operandSegmentSizes = { + static_cast(promoted.size()), 0}; + newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({}); + SmallVector results; if (numResults < 2) { // If < 2 results, packing did not do anything and we can just return. diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp index ca78631632419..6ae607f75adbd 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -837,8 +837,11 @@ class FunctionCallPattern matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (callOp.getNumResults() == 0) { - rewriter.replaceOpWithNewOp( + auto newOp = rewriter.replaceOpWithNewOp( callOp, std::nullopt, adaptor.getOperands(), callOp->getAttrs()); + newOp.getProperties().operandSegmentSizes = { + static_cast(adaptor.getOperands().size()), 0}; + newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({}); return success(); } @@ -846,8 +849,11 @@ class FunctionCallPattern auto dstType = typeConverter.convertType(callOp.getType(0)); if (!dstType) return rewriter.notifyMatchFailure(callOp, "type conversion failed"); - rewriter.replaceOpWithNewOp( + auto newOp = rewriter.replaceOpWithNewOp( callOp, dstType, adaptor.getOperands(), callOp->getAttrs()); + newOp.getProperties().operandSegmentSizes = { + static_cast(adaptor.getOperands().size()), 0}; + newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({}); return success(); } }; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 205d7494d4378..837e0e41800d8 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -220,6 +220,88 @@ static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser, return static_cast(index); } +//===----------------------------------------------------------------------===// +// Operand bundle helpers. +//===----------------------------------------------------------------------===// + +static void printOneOpBundle(OpAsmPrinter &p, OperandRange operands, + TypeRange operandTypes, StringRef tag) { + p.printString(tag); + p << "("; + p.printOperands(operands); + p << " : "; + llvm::interleaveComma(operandTypes, p); + p << ")"; +} + +static void printOpBundles(OpAsmPrinter &p, Operation *op, + OperandRangeRange opBundleOperands, + TypeRangeRange opBundleOperandTypes, + ArrayRef opBundleTags) { + p << "["; + llvm::interleaveComma( + llvm::zip(opBundleOperands, opBundleOperandTypes, opBundleTags), p, + [&p](auto bundle) { + printOneOpBundle(p, std::get<0>(bundle), std::get<1>(bundle), + std::get<2>(bundle)); + }); + p << "]"; +} + +static ParseResult parseOneOpBundle( + OpAsmParser &p, + SmallVector> &opBundleOperands, + SmallVector> &opBundleOperandTypes, + SmallVector &opBundleTags) { + auto currentParserLoc = p.getCurrentLocation(); + SmallVector operands; + SmallVector types; + std::string tag; + + if (p.parseString(&tag)) + return p.emitError(currentParserLoc, "expect operand bundle tag"); + + if (p.parseLParen()) + return failure(); + + if (p.parseOperandList(operands)) + return failure(); + if (p.parseColon()) + return failure(); + if (p.parseTypeList(types)) + return failure(); + + if (p.parseRParen()) + return failure(); + + opBundleOperands.push_back(std::move(operands)); + opBundleOperandTypes.push_back(std::move(types)); + opBundleTags.push_back(std::move(tag)); + + return success(); +} + +static std::optional parseOpBundles( + OpAsmParser &p, + SmallVector> &opBundleOperands, + SmallVector> &opBundleOperandTypes, + SmallVector &opBundleTags) { + if (p.parseOptionalLSquare()) + return std::nullopt; + + auto bundleParser = [&] { + return parseOneOpBundle(p, opBundleOperands, opBundleOperandTypes, + opBundleTags); + }; + if (p.parseCommaSeparatedList(bundleParser)) + return failure(); + + if (p.parseRSquare()) + return failure(); + + return success(); +} + //===----------------------------------------------------------------------===// // Printing, parsing, folding and builder for LLVM::CmpOp. //===----------------------------------------------------------------------===// @@ -954,6 +1036,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results, /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr, /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr, + /*op_bundle_operands=*/{}, /*op_bundle_tags=*/std::nullopt, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); } @@ -980,6 +1063,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr, /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr, + /*op_bundle_operands=*/{}, /*op_bundle_tags=*/std::nullopt, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); } @@ -992,6 +1076,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr, /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr, /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr, + /*op_bundle_operands=*/{}, /*op_bundle_tags=*/std::nullopt, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); } @@ -1004,6 +1089,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func, /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr, /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr, /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr, + /*op_bundle_operands=*/{}, /*op_bundle_tags=*/std::nullopt, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); } @@ -1027,7 +1113,7 @@ void CallOp::setCalleeFromCallable(CallInterfaceCallable callee) { } Operation::operand_range CallOp::getArgOperands() { - return getOperands().drop_front(getCallee().has_value() ? 0 : 1); + return getCalleeOperands().drop_front(getCallee().has_value() ? 0 : 1); } MutableOperandRange CallOp::getArgOperandsMutable() { @@ -1100,6 +1186,27 @@ LogicalResult verifyCallOpVarCalleeType(OpTy callOp) { return success(); } +template +static LogicalResult verifyOperandBundles(OpType &op) { + OperandRangeRange opBundleOperands = op.getOpBundleOperands(); + std::optional> opBundleTags = op.getOpBundleTags(); + + if (!opBundleTags.has_value()) { + if (!opBundleOperands.empty()) + return op.emitError("expected operand bundle tags"); + return success(); + } + + if (opBundleTags->size() != opBundleOperands.size()) + return op.emitError("expected ") + << opBundleOperands.size() + << " operand bundle tags, but actually got " << opBundleTags->size(); + + return success(); +} + +LogicalResult CallOp::verify() { return verifyOperandBundles(*this); } + LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { if (failed(verifyCallOpVarCalleeType(*this))) return failure(); @@ -1150,15 +1257,15 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { // Verify that the operand and result types match the callee. if (!funcType.isVarArg() && - funcType.getNumParams() != (getNumOperands() - isIndirect)) + funcType.getNumParams() != (getCalleeOperands().size() - isIndirect)) return emitOpError() << "incorrect number of operands (" - << (getNumOperands() - isIndirect) + << (getCalleeOperands().size() - isIndirect) << ") for callee (expecting: " << funcType.getNumParams() << ")"; - if (funcType.getNumParams() > (getNumOperands() - isIndirect)) + if (funcType.getNumParams() > (getCalleeOperands().size() - isIndirect)) return emitOpError() << "incorrect number of operands (" - << (getNumOperands() - isIndirect) + << (getCalleeOperands().size() - isIndirect) << ") for varargs callee (expecting at least: " << funcType.getNumParams() << ")"; @@ -1208,16 +1315,24 @@ void CallOp::print(OpAsmPrinter &p) { else p << getOperand(0); - auto args = getOperands().drop_front(isDirect ? 0 : 1); + auto args = getCalleeOperands().drop_front(isDirect ? 0 : 1); p << '(' << args << ')'; // Print the variadic callee type if the call is variadic. if (std::optional varCalleeType = getVarCalleeType()) p << " vararg(" << *varCalleeType << ")"; + if (!getOpBundleOperands().empty()) { + p << " "; + printOpBundles(p, *this, getOpBundleOperands(), + getOpBundleOperands().getTypes(), getOpBundleTags()); + } + p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {getCalleeAttrName(), getTailCallKindAttrName(), - getVarCalleeTypeAttrName(), getCConvAttrName()}); + getVarCalleeTypeAttrName(), getCConvAttrName(), + getOperandSegmentSizesAttrName(), + getOpBundleSizesAttrName()}); p << " : "; if (!isDirect) @@ -1285,14 +1400,53 @@ static ParseResult parseOptionalCallFuncPtr( return success(); } +static ParseResult resolveOpBundleOperands( + OpAsmParser &parser, SMLoc loc, OperationState &state, + ArrayRef> opBundleOperands, + ArrayRef> opBundleOperandTypes, + StringAttr opBundleSizesAttrName) { + assert(opBundleOperands.size() == opBundleOperandTypes.size() && + "operand bundle operand groups and type groups should match"); + + unsigned opBundleIndex = 0; + for (const auto &[operands, types] : + llvm::zip(opBundleOperands, opBundleOperandTypes)) { + if (operands.size() != types.size()) + return parser.emitError(loc, "expected ") + << operands.size() + << " types for operand bundle operands for operand bundle #" + << opBundleIndex << ", but actually got " << types.size(); + if (parser.resolveOperands(operands, types, loc, state.operands)) + return failure(); + } + + SmallVector opBundleSizes; + opBundleSizes.reserve(opBundleOperands.size()); + for (const auto &operands : opBundleOperands) { + opBundleSizes.push_back(operands.size()); + } + + state.addAttribute( + opBundleSizesAttrName, + DenseI32ArrayAttr::get(parser.getContext(), opBundleSizes)); + + return success(); +} + // ::= `llvm.call` (cconv)? (tailcallkind)? (function-id | ssa-use) // `(` ssa-use-list `)` // ( `vararg(` var-callee-type `)` )? +// ( `bundlearg(` ssa-use-list-list `)` )? +// ( `bundletags(` str-elements-attr `) ) // attribute-dict? `:` (type `,`)? function-type +// (`,` `bundletype(` type-list-list `)`)? ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) { SymbolRefAttr funcAttr; TypeAttr varCalleeType; SmallVector operands; + SmallVector> opBundleOperands; + SmallVector> opBundleOperandTypes; + SmallVector opBundleTags; // Default to C Calling Convention if no keyword is provided. result.addAttribute( @@ -1333,11 +1487,35 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) { return failure(); } + auto opBundlesLoc = parser.getCurrentLocation(); + if (auto result = parseOpBundles(parser, opBundleOperands, + opBundleOperandTypes, opBundleTags); + result.has_value() && failed(*result)) + return failure(); + if (!opBundleTags.empty()) + result.getOrAddProperties().op_bundle_tags = + std::move(opBundleTags); + if (parser.parseOptionalAttrDict(result.attributes)) return failure(); // Parse the trailing type list and resolve the operands. - return parseCallTypeAndResolveOperands(parser, result, isDirect, operands); + if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands)) + return failure(); + if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands, + opBundleOperandTypes, + getOpBundleSizesAttrName(result.name))) + return failure(); + + int32_t numOpBundleOperands = 0; + for (const auto &operands : opBundleOperands) + numOpBundleOperands += operands.size(); + + result.addAttribute( + CallOp::getOperandSegmentSizeAttr(), + parser.getBuilder().getDenseI32ArrayAttr( + {static_cast(operands.size()), numOpBundleOperands})); + return success(); } LLVMFunctionType CallOp::getCalleeFunctionType() { @@ -1356,7 +1534,8 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func, auto calleeType = func.getFunctionType(); build(builder, state, getCallOpResultTypes(calleeType), getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), ops, - normalOps, unwindOps, nullptr, nullptr, normal, unwind); + normalOps, unwindOps, nullptr, nullptr, {}, std::nullopt, normal, + unwind); } void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys, @@ -1365,7 +1544,7 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys, ValueRange unwindOps) { build(builder, state, tys, /*var_callee_type=*/nullptr, callee, ops, normalOps, unwindOps, nullptr, - nullptr, normal, unwind); + nullptr, {}, std::nullopt, normal, unwind); } void InvokeOp::build(OpBuilder &builder, OperationState &state, @@ -1374,7 +1553,7 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, Block *unwind, ValueRange unwindOps) { build(builder, state, getCallOpResultTypes(calleeType), getCallOpVarCalleeType(calleeType), callee, ops, normalOps, unwindOps, - nullptr, nullptr, normal, unwind); + nullptr, nullptr, {}, std::nullopt, normal, unwind); } SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) { @@ -1402,7 +1581,7 @@ void InvokeOp::setCalleeFromCallable(CallInterfaceCallable callee) { } Operation::operand_range InvokeOp::getArgOperands() { - return getOperands().drop_front(getCallee().has_value() ? 0 : 1); + return getCalleeOperands().drop_front(getCallee().has_value() ? 0 : 1); } MutableOperandRange InvokeOp::getArgOperandsMutable() { @@ -1423,6 +1602,9 @@ LogicalResult InvokeOp::verify() { return emitError("first operation in unwind destination should be a " "llvm.landingpad operation"); + if (failed(verifyOperandBundles(*this))) + return failure(); + return success(); } @@ -1452,9 +1634,16 @@ void InvokeOp::print(OpAsmPrinter &p) { if (std::optional varCalleeType = getVarCalleeType()) p << " vararg(" << *varCalleeType << ")"; + if (!getOpBundleOperands().empty()) { + p << " "; + printOpBundles(p, *this, getOpBundleOperands(), + getOpBundleOperands().getTypes(), getOpBundleTags()); + } + p.printOptionalAttrDict((*this)->getAttrs(), {getCalleeAttrName(), getOperandSegmentSizeAttr(), - getCConvAttrName(), getVarCalleeTypeAttrName()}); + getCConvAttrName(), getVarCalleeTypeAttrName(), + getOpBundleSizesAttrName()}); p << " : "; if (!isDirect) @@ -1468,11 +1657,17 @@ void InvokeOp::print(OpAsmPrinter &p) { // `to` bb-id (`[` ssa-use-and-type-list `]`)? // `unwind` bb-id (`[` ssa-use-and-type-list `]`)? // ( `vararg(` var-callee-type `)` )? +// ( `bundlearg(` ssa-use-list-list `)` )? +// ( `bundletags(` str-elements-attr `) ) // attribute-dict? `:` (type `,`)? function-type +// (`,` `bundletype(` type-list-list `)`)? ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) { SmallVector operands; SymbolRefAttr funcAttr; TypeAttr varCalleeType; + SmallVector> opBundleOperands; + SmallVector> opBundleOperandTypes; + SmallVector opBundleTags; Block *normalDest, *unwindDest; SmallVector normalOperands, unwindOperands; Builder &builder = parser.getBuilder(); @@ -1513,22 +1708,40 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) { return failure(); } + auto opBundlesLoc = parser.getCurrentLocation(); + if (auto result = parseOpBundles(parser, opBundleOperands, + opBundleOperandTypes, opBundleTags); + result.has_value() && failed(*result)) + return failure(); + if (!opBundleTags.empty()) + result.getOrAddProperties().op_bundle_tags = + std::move(opBundleTags); + if (parser.parseOptionalAttrDict(result.attributes)) return failure(); // Parse the trailing type list and resolve the function operands. if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands)) return failure(); + if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands, + opBundleOperandTypes, + getOpBundleSizesAttrName(result.name))) + return failure(); result.addSuccessors({normalDest, unwindDest}); result.addOperands(normalOperands); result.addOperands(unwindOperands); - result.addAttribute(InvokeOp::getOperandSegmentSizeAttr(), - builder.getDenseI32ArrayAttr( - {static_cast(operands.size()), - static_cast(normalOperands.size()), - static_cast(unwindOperands.size())})); + int32_t numOpBundleOperands = 0; + for (const auto &operands : opBundleOperands) + numOpBundleOperands += operands.size(); + + result.addAttribute( + InvokeOp::getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr({static_cast(operands.size()), + static_cast(normalOperands.size()), + static_cast(unwindOperands.size()), + numOpBundleOperands})); return success(); } @@ -3108,6 +3321,8 @@ OpFoldResult LLVM::OrOp::fold(FoldAdaptor adaptor) { LogicalResult CallIntrinsicOp::verify() { if (!getIntrin().starts_with("llvm.")) return emitOpError() << "intrinsic name must start with 'llvm.'"; + if (failed(verifyOperandBundles(*this))) + return failure(); return success(); } diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp index d948ff5eaf176..53ca302518a90 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -102,6 +102,40 @@ getOverloadedDeclaration(CallIntrinsicOp op, llvm::Intrinsic::ID id, return llvm::Intrinsic::getDeclaration(module, id, overloadedArgTysRef); } +static llvm::OperandBundleDef +convertOperandBundle(OperandRange bundleOperands, StringRef bundleTag, + LLVM::ModuleTranslation &moduleTranslation) { + std::vector operands; + operands.reserve(bundleOperands.size()); + for (auto bundleArg : bundleOperands) + operands.push_back(moduleTranslation.lookupValue(bundleArg)); + return llvm::OperandBundleDef(bundleTag.str(), std::move(operands)); +} + +static SmallVector +convertOperandBundles(OperandRangeRange bundleOperands, + ArrayRef bundleTags, + LLVM::ModuleTranslation &moduleTranslation) { + assert(bundleOperands.size() == bundleTags.size() && + "operand bundles and tags do not match"); + + SmallVector bundles; + bundles.reserve(bundleOperands.size()); + + for (auto [operands, tag] : llvm::zip(bundleOperands, bundleTags)) + bundles.push_back(convertOperandBundle(operands, tag, moduleTranslation)); + return bundles; +} + +static SmallVector +convertOperandBundles(OperandRangeRange bundleOperands, + std::optional> bundleTags, + LLVM::ModuleTranslation &moduleTranslation) { + if (!bundleTags.has_value()) + bundleTags.emplace(); + return convertOperandBundles(bundleOperands, *bundleTags, moduleTranslation); +} + /// Builder for LLVM_CallIntrinsicOp static LogicalResult convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder, @@ -138,15 +172,15 @@ convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder, // Check the argument types of the call. If the function is variadic, check // the subrange of required arguments. if (!fn->getFunctionType()->isVarArg() && - op.getNumOperands() != fn->arg_size()) { + op.getArgs().size() != fn->arg_size()) { return mlir::emitError(op.getLoc(), "intrinsic call has ") - << op.getNumOperands() << " operands but " << op.getIntrinAttr() + << op.getArgs().size() << " operands but " << op.getIntrinAttr() << " expects " << fn->arg_size(); } if (fn->getFunctionType()->isVarArg() && - op.getNumOperands() < fn->arg_size()) { + op.getArgs().size() < fn->arg_size()) { return mlir::emitError(op.getLoc(), "intrinsic call has ") - << op.getNumOperands() << " operands but variadic " + << op.getArgs().size() << " operands but variadic " << op.getIntrinAttr() << " expects at least " << fn->arg_size(); } // Check the arguments up to the number the function requires. @@ -164,8 +198,10 @@ convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder, FastmathFlagsInterface itf = op; builder.setFastMathFlags(getFastmathFlags(itf)); - auto *inst = - builder.CreateCall(fn, moduleTranslation.lookupValues(op.getOperands())); + auto *inst = builder.CreateCall( + fn, moduleTranslation.lookupValues(op.getArgs()), + convertOperandBundles(op.getOpBundleOperands(), op.getOpBundleTags(), + moduleTranslation)); if (op.getNumResults() == 1) moduleTranslation.mapValue(op->getResults().front()) = inst; return success(); @@ -205,17 +241,21 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, // itself. Otherwise, this is an indirect call and the callee is the first // operand, look it up as a normal value. if (auto callOp = dyn_cast(opInst)) { - auto operands = moduleTranslation.lookupValues(callOp.getOperands()); + auto operands = moduleTranslation.lookupValues(callOp.getCalleeOperands()); + SmallVector opBundles = + convertOperandBundles(callOp.getOpBundleOperands(), + callOp.getOpBundleTags(), moduleTranslation); ArrayRef operandsRef(operands); llvm::CallInst *call; if (auto attr = callOp.getCalleeAttr()) { - call = builder.CreateCall( - moduleTranslation.lookupFunction(attr.getValue()), operandsRef); + call = + builder.CreateCall(moduleTranslation.lookupFunction(attr.getValue()), + operandsRef, opBundles); } else { llvm::FunctionType *calleeType = llvm::cast( moduleTranslation.convertType(callOp.getCalleeFunctionType())); call = builder.CreateCall(calleeType, operandsRef.front(), - operandsRef.drop_front()); + operandsRef.drop_front(), opBundles); } call->setCallingConv(convertCConvToLLVM(callOp.getCConv())); call->setTailCallKind(convertTailCallKindToLLVM(callOp.getTailCallKind())); @@ -312,13 +352,17 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, if (auto invOp = dyn_cast(opInst)) { auto operands = moduleTranslation.lookupValues(invOp.getCalleeOperands()); + SmallVector opBundles = + convertOperandBundles(invOp.getOpBundleOperands(), + invOp.getOpBundleTags(), moduleTranslation); ArrayRef operandsRef(operands); llvm::InvokeInst *result; if (auto attr = opInst.getAttrOfType("callee")) { result = builder.CreateInvoke( moduleTranslation.lookupFunction(attr.getValue()), moduleTranslation.lookupBlock(invOp.getSuccessor(0)), - moduleTranslation.lookupBlock(invOp.getSuccessor(1)), operandsRef); + moduleTranslation.lookupBlock(invOp.getSuccessor(1)), operandsRef, + opBundles); } else { llvm::FunctionType *calleeType = llvm::cast( moduleTranslation.convertType(invOp.getCalleeFunctionType())); @@ -326,7 +370,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, calleeType, operandsRef.front(), moduleTranslation.lookupBlock(invOp.getSuccessor(0)), moduleTranslation.lookupBlock(invOp.getSuccessor(1)), - operandsRef.drop_front()); + operandsRef.drop_front(), opBundles); } result->setCallingConv(convertCConvToLLVM(invOp.getCConv())); moduleTranslation.mapBranch(invOp, result); diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index 6670e4b186c39..1121691133108 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -218,7 +218,7 @@ func.func @store_unaligned_atomic(%val : f32, %ptr : !llvm.ptr) { func.func @invalid_call() { // expected-error@+1 {{'llvm.call' op must have either a `callee` attribute or at least an operand}} - "llvm.call"() : () -> () + "llvm.call"() {op_bundle_sizes = array} : () -> () llvm.return } @@ -286,7 +286,7 @@ func.func @call_non_llvm() { func.func @call_non_llvm_arg(%arg0 : tensor<*xi32>) { // expected-error@+1 {{'llvm.call' op operand #0 must be variadic of LLVM dialect-compatible type}} - "llvm.call"(%arg0) : (tensor<*xi32>) -> () + "llvm.call"(%arg0) {operandSegmentSizes = array, op_bundle_sizes = array} : (tensor<*xi32>) -> () llvm.return } @@ -1588,7 +1588,7 @@ llvm.func @variadic(...) llvm.func @invalid_variadic_call(%arg: i32) { // expected-error@+1 {{missing var_callee_type attribute for vararg call}} - "llvm.call"(%arg) <{callee = @variadic}> : (i32) -> () + "llvm.call"(%arg) <{callee = @variadic}> {operandSegmentSizes = array, op_bundle_sizes = array} : (i32) -> () llvm.return } @@ -1598,7 +1598,7 @@ llvm.func @variadic(...) llvm.func @invalid_variadic_call(%arg: i32) { // expected-error@+1 {{missing var_callee_type attribute for vararg call}} - "llvm.call"(%arg) <{callee = @variadic}> : (i32) -> () + "llvm.call"(%arg) <{callee = @variadic}> {operandSegmentSizes = array, op_bundle_sizes = array} : (i32) -> () llvm.return } @@ -1655,3 +1655,13 @@ llvm.func @alwaysinline_noinline() attributes { always_inline, no_inline } { llvm.func @optnone_requires_noinline() attributes { optimize_none } { llvm.return } + +// ----- + +llvm.func @foo() +llvm.func @wrong_number_of_bundle_types() { + %0 = llvm.mlir.constant(0 : i32) : i32 + // expected-error@+1 {{expected 1 types for operand bundle operands for operand bundle #0, but actually got 2}} + llvm.call @foo() ["tag"(%0 : i32, i32)] : () -> () bundletype((i32, i32)) + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir index 966a00f9e3c67..025ff4a35a552 100644 --- a/mlir/test/Target/LLVMIR/llvmir.mlir +++ b/mlir/test/Target/LLVMIR/llvmir.mlir @@ -2626,3 +2626,52 @@ llvm.func @reqd_work_group_size() attributes {reqd_work_group_size = array () + llvm.return +} + +// CHECK: define void @call_with_opbundle() { +// CHECK-NEXT: call void @foo() [ "tag1"(i32 1, i32 2), "tag2"(i32 3) ] +// CHECK-NEXT: ret void +// CHECK-NEXT: } + +llvm.func @__gxx_personality_v0(...) -> i32 +llvm.func @invoke_with_opbundle() attributes { personality = @__gxx_personality_v0 } { + %0 = llvm.mlir.constant(1 : i32) : i32 + %1 = llvm.mlir.constant(2 : i32) : i32 + %2 = llvm.mlir.constant(3 : i32) : i32 + llvm.invoke @foo() to ^bb2 unwind ^bb1 ["tag1"(%0, %1 : i32, i32), "tag2"(%2 : i32)] : () -> () + +^bb1: + %3 = llvm.landingpad cleanup : !llvm.struct<(ptr, i32)> + llvm.return + +^bb2: + llvm.return +} + +// CHECK: define void @invoke_with_opbundle() personality ptr @__gxx_personality_v0 { +// CHECK-NEXT: invoke void @foo() [ "tag1"(i32 1, i32 2), "tag2"(i32 3) ] +// CHECK-NEXT: to label %{{.+}} unwind label %{{.+}} +// CHECK: } + +llvm.func @call_intrin_with_opbundle(%arg0 : !llvm.ptr) { + %0 = llvm.mlir.constant(1 : i1) : i1 + %1 = llvm.mlir.constant(16 : i32) : i32 + llvm.call_intrinsic "llvm.assume"(%0) ["align"(%arg0, %1 : !llvm.ptr, i32)] : (i1) -> () + llvm.return +} + +// CHECK: define void @call_intrin_with_opbundle(ptr %0) { +// CHECK-NEXT: call void @llvm.assume(i1 true) [ "align"(ptr %0, i32 16) ] +// CHECK-NEXT: ret void +// CHECK-NEXT: } From 3a8fc34d90d4e5071bfc7e47bda55e78c3c0a8e2 Mon Sep 17 00:00:00 2001 From: Sirui Mu Date: Tue, 24 Sep 2024 20:28:32 +0800 Subject: [PATCH 2/8] resolve nits before adding more tests --- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 46 +++++++------------ .../LLVMIR/LLVMToLLVMIRTranslation.cpp | 4 +- mlir/test/Dialect/LLVMIR/invalid.mlir | 2 +- 3 files changed, 19 insertions(+), 33 deletions(-) diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 837e0e41800d8..e7550b653d27b 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -253,7 +253,7 @@ static ParseResult parseOneOpBundle( SmallVector> &opBundleOperands, SmallVector> &opBundleOperandTypes, SmallVector &opBundleTags) { - auto currentParserLoc = p.getCurrentLocation(); + SMLoc currentParserLoc = p.getCurrentLocation(); SmallVector operands; SmallVector types; std::string tag; @@ -1189,18 +1189,12 @@ LogicalResult verifyCallOpVarCalleeType(OpTy callOp) { template static LogicalResult verifyOperandBundles(OpType &op) { OperandRangeRange opBundleOperands = op.getOpBundleOperands(); - std::optional> opBundleTags = op.getOpBundleTags(); + ArrayRef opBundleTags = op.getOpBundleTags(); - if (!opBundleTags.has_value()) { - if (!opBundleOperands.empty()) - return op.emitError("expected operand bundle tags"); - return success(); - } - - if (opBundleTags->size() != opBundleOperands.size()) + if (opBundleTags.size() != opBundleOperands.size()) return op.emitError("expected ") << opBundleOperands.size() - << " operand bundle tags, but actually got " << opBundleTags->size(); + << " operand bundle tags, but actually got " << opBundleTags.size(); return success(); } @@ -1405,12 +1399,9 @@ static ParseResult resolveOpBundleOperands( ArrayRef> opBundleOperands, ArrayRef> opBundleOperandTypes, StringAttr opBundleSizesAttrName) { - assert(opBundleOperands.size() == opBundleOperandTypes.size() && - "operand bundle operand groups and type groups should match"); - unsigned opBundleIndex = 0; for (const auto &[operands, types] : - llvm::zip(opBundleOperands, opBundleOperandTypes)) { + llvm::zip_equal(opBundleOperands, opBundleOperandTypes)) { if (operands.size() != types.size()) return parser.emitError(loc, "expected ") << operands.size() @@ -1422,9 +1413,8 @@ static ParseResult resolveOpBundleOperands( SmallVector opBundleSizes; opBundleSizes.reserve(opBundleOperands.size()); - for (const auto &operands : opBundleOperands) { + for (const auto &operands : opBundleOperands) opBundleSizes.push_back(operands.size()); - } state.addAttribute( opBundleSizesAttrName, @@ -1436,10 +1426,8 @@ static ParseResult resolveOpBundleOperands( // ::= `llvm.call` (cconv)? (tailcallkind)? (function-id | ssa-use) // `(` ssa-use-list `)` // ( `vararg(` var-callee-type `)` )? -// ( `bundlearg(` ssa-use-list-list `)` )? -// ( `bundletags(` str-elements-attr `) ) +// ( `[` op-bundles-list `]` )? // attribute-dict? `:` (type `,`)? function-type -// (`,` `bundletype(` type-list-list `)`)? ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) { SymbolRefAttr funcAttr; TypeAttr varCalleeType; @@ -1487,10 +1475,10 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) { return failure(); } - auto opBundlesLoc = parser.getCurrentLocation(); - if (auto result = parseOpBundles(parser, opBundleOperands, - opBundleOperandTypes, opBundleTags); - result.has_value() && failed(*result)) + SMLoc opBundlesLoc = parser.getCurrentLocation(); + if (std::optional result = parseOpBundles( + parser, opBundleOperands, opBundleOperandTypes, opBundleTags); + result && failed(*result)) return failure(); if (!opBundleTags.empty()) result.getOrAddProperties().op_bundle_tags = @@ -1657,10 +1645,8 @@ void InvokeOp::print(OpAsmPrinter &p) { // `to` bb-id (`[` ssa-use-and-type-list `]`)? // `unwind` bb-id (`[` ssa-use-and-type-list `]`)? // ( `vararg(` var-callee-type `)` )? -// ( `bundlearg(` ssa-use-list-list `)` )? -// ( `bundletags(` str-elements-attr `) ) +// ( `[` op-bundles-list `]` )? // attribute-dict? `:` (type `,`)? function-type -// (`,` `bundletype(` type-list-list `)`)? ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) { SmallVector operands; SymbolRefAttr funcAttr; @@ -1708,10 +1694,10 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) { return failure(); } - auto opBundlesLoc = parser.getCurrentLocation(); - if (auto result = parseOpBundles(parser, opBundleOperands, - opBundleOperandTypes, opBundleTags); - result.has_value() && failed(*result)) + SMLoc opBundlesLoc = parser.getCurrentLocation(); + if (std::optional result = parseOpBundles( + parser, opBundleOperands, opBundleOperandTypes, opBundleTags); + result && failed(*result)) return failure(); if (!opBundleTags.empty()) result.getOrAddProperties().op_bundle_tags = diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp index 53ca302518a90..cd4e760c3b4bc 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -107,7 +107,7 @@ convertOperandBundle(OperandRange bundleOperands, StringRef bundleTag, LLVM::ModuleTranslation &moduleTranslation) { std::vector operands; operands.reserve(bundleOperands.size()); - for (auto bundleArg : bundleOperands) + for (Value bundleArg : bundleOperands) operands.push_back(moduleTranslation.lookupValue(bundleArg)); return llvm::OperandBundleDef(bundleTag.str(), std::move(operands)); } @@ -131,7 +131,7 @@ static SmallVector convertOperandBundles(OperandRangeRange bundleOperands, std::optional> bundleTags, LLVM::ModuleTranslation &moduleTranslation) { - if (!bundleTags.has_value()) + if (!bundleTags) bundleTags.emplace(); return convertOperandBundles(bundleOperands, *bundleTags, moduleTranslation); } diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index 1121691133108..afe01d3ff89d6 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1662,6 +1662,6 @@ llvm.func @foo() llvm.func @wrong_number_of_bundle_types() { %0 = llvm.mlir.constant(0 : i32) : i32 // expected-error@+1 {{expected 1 types for operand bundle operands for operand bundle #0, but actually got 2}} - llvm.call @foo() ["tag"(%0 : i32, i32)] : () -> () bundletype((i32, i32)) + llvm.call @foo() ["tag"(%0 : i32, i32)] : () -> () llvm.return } From ac42d461064eea3a428dd22410baeb4cb79e5ae0 Mon Sep 17 00:00:00 2001 From: Sirui Mu Date: Tue, 24 Sep 2024 20:51:38 +0800 Subject: [PATCH 3/8] parse empty operand bundles list --- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 3 +++ mlir/test/Target/LLVMIR/llvmir.mlir | 10 ++++++++++ 2 files changed, 13 insertions(+) diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index e7550b653d27b..80f6ae7a224c8 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -289,6 +289,9 @@ static std::optional parseOpBundles( if (p.parseOptionalLSquare()) return std::nullopt; + if (succeeded(p.parseOptionalRSquare())) + return success(); + auto bundleParser = [&] { return parseOneOpBundle(p, opBundleOperands, opBundleOperandTypes, opBundleTags); diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir index 025ff4a35a552..189e541e5fc33 100644 --- a/mlir/test/Target/LLVMIR/llvmir.mlir +++ b/mlir/test/Target/LLVMIR/llvmir.mlir @@ -2631,6 +2631,16 @@ llvm.func @intel_reqd_sub_group_size() attributes {intel_reqd_sub_group_size = 3 llvm.func @foo() +llvm.func @call_with_empty_opbundle() { + llvm.call @foo() [] : () -> () + llvm.return +} + +// CHECK: define void @call_with_empty_opbundle() { +// CHECK-NEXT: call void @foo() +// CHECK-NEXT: ret void +// CHECK-NEXT: } + llvm.func @call_with_opbundle() { %0 = llvm.mlir.constant(1 : i32) : i32 %1 = llvm.mlir.constant(2 : i32) : i32 From 9acfb951dc47bffc9057b74b6ce4669e6e30a712 Mon Sep 17 00:00:00 2001 From: Sirui Mu Date: Tue, 24 Sep 2024 21:02:35 +0800 Subject: [PATCH 4/8] add test for operand bundle verifier --- mlir/test/Dialect/LLVMIR/invalid.mlir | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index afe01d3ff89d6..9388d7ef24936 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1665,3 +1665,18 @@ llvm.func @wrong_number_of_bundle_types() { llvm.call @foo() ["tag"(%0 : i32, i32)] : () -> () llvm.return } + +// ----- + +llvm.func @foo() +llvm.func @wrong_number_of_bundle_tags() { + %0 = llvm.mlir.constant(0 : i32) : i32 + %1 = llvm.mlir.constant(1 : i32) : i32 + // expected-error@+1 {{expected 2 operand bundle tags, but actually got 1}} + "llvm.call"(%0, %1) <{ op_bundle_tags = ["tag"] }> { + callee = @foo, + operandSegmentSizes = array, + op_bundle_sizes = array + } : (i32, i32) -> () + llvm.return +} From cc0e132be3b4face9a1f3d9fd9062f878bc548e6 Mon Sep 17 00:00:00 2001 From: Sirui Mu Date: Tue, 24 Sep 2024 21:29:34 +0800 Subject: [PATCH 5/8] parse empty operands within a bundle --- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 14 +++++--------- mlir/test/Target/LLVMIR/llvmir.mlir | 10 ++++++++++ 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 80f6ae7a224c8..4b95e5486e74a 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -264,15 +264,11 @@ static ParseResult parseOneOpBundle( if (p.parseLParen()) return failure(); - if (p.parseOperandList(operands)) - return failure(); - if (p.parseColon()) - return failure(); - if (p.parseTypeList(types)) - return failure(); - - if (p.parseRParen()) - return failure(); + if (p.parseOptionalRParen()) { + if (p.parseOperandList(operands) || p.parseColon() || + p.parseTypeList(types) || p.parseRParen()) + return failure(); + } opBundleOperands.push_back(std::move(operands)); opBundleOperandTypes.push_back(std::move(types)); diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir index 189e541e5fc33..007284d0ca443 100644 --- a/mlir/test/Target/LLVMIR/llvmir.mlir +++ b/mlir/test/Target/LLVMIR/llvmir.mlir @@ -2641,6 +2641,16 @@ llvm.func @call_with_empty_opbundle() { // CHECK-NEXT: ret void // CHECK-NEXT: } +llvm.func @call_with_empty_opbundle_operands() { + llvm.call @foo() ["tag"()] : () -> () + llvm.return +} + +// CHECK: define void @call_with_empty_opbundle_operands() { +// CHECK-NEXT: call void @foo() [ "tag"() ] +// CHECK-NEXT: ret void +// CHECK-NEXT: } + llvm.func @call_with_opbundle() { %0 = llvm.mlir.constant(1 : i32) : i32 %1 = llvm.mlir.constant(2 : i32) : i32 From 2290580a6edf301b92c418509ce5aaa380d7d7d4 Mon Sep 17 00:00:00 2001 From: Sirui Mu Date: Tue, 24 Sep 2024 21:32:43 +0800 Subject: [PATCH 6/8] nit: replace llvm::zip with llvm::zip_equal --- .../Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp index cd4e760c3b4bc..78a3f1809aec3 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -116,13 +116,10 @@ static SmallVector convertOperandBundles(OperandRangeRange bundleOperands, ArrayRef bundleTags, LLVM::ModuleTranslation &moduleTranslation) { - assert(bundleOperands.size() == bundleTags.size() && - "operand bundles and tags do not match"); - SmallVector bundles; bundles.reserve(bundleOperands.size()); - for (auto [operands, tag] : llvm::zip(bundleOperands, bundleTags)) + for (auto [operands, tag] : llvm::zip_equal(bundleOperands, bundleTags)) bundles.push_back(convertOperandBundle(operands, tag, moduleTranslation)); return bundles; } From 43513981c203ef56eddfa6b10e82cb7835086ee3 Mon Sep 17 00:00:00 2001 From: Sirui Mu Date: Tue, 24 Sep 2024 21:58:14 +0800 Subject: [PATCH 7/8] add roundtrip test for operand bundle syntax --- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 17 +++-- mlir/test/Dialect/LLVMIR/roundtrip.mlir | 83 ++++++++++++++++++++++ 2 files changed, 94 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 4b95e5486e74a..0561c364c7d59 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -228,9 +228,13 @@ static void printOneOpBundle(OpAsmPrinter &p, OperandRange operands, TypeRange operandTypes, StringRef tag) { p.printString(tag); p << "("; - p.printOperands(operands); - p << " : "; - llvm::interleaveComma(operandTypes, p); + + if (!operands.empty()) { + p.printOperands(operands); + p << " : "; + llvm::interleaveComma(operandTypes, p); + } + p << ")"; } @@ -1611,7 +1615,7 @@ void InvokeOp::print(OpAsmPrinter &p) { else p << getOperand(0); - p << '(' << getOperands().drop_front(isDirect ? 0 : 1) << ')'; + p << '(' << getCalleeOperands().drop_front(isDirect ? 0 : 1) << ')'; p << " to "; p.printSuccessorAndUseList(getNormalDest(), getNormalDestOperands()); p << " unwind "; @@ -1635,8 +1639,9 @@ void InvokeOp::print(OpAsmPrinter &p) { p << " : "; if (!isDirect) p << getOperand(0).getType() << ", "; - p.printFunctionalType(llvm::drop_begin(getOperandTypes(), isDirect ? 0 : 1), - getResultTypes()); + p.printFunctionalType( + llvm::drop_begin(getCalleeOperands().getTypes(), isDirect ? 0 : 1), + getResultTypes()); } // ::= `llvm.invoke` (cconv)? (function-id | ssa-use) diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir index 89d303fcac8ff..62f1de2b7fe7d 100644 --- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir @@ -751,3 +751,86 @@ llvm.func @vector_predication_intrinsics(%A: vector<8xi32>, %B: vector<8xi32>, (vector<8xi32>, vector<8xi32>, vector<8xi1>, i32) -> vector<8xi32> llvm.return } + +llvm.func @op_bundle_target() + +// CHECK-LABEL: @test_call_with_empty_opbundle +llvm.func @test_call_with_empty_opbundle() { + // CHECK: llvm.call @op_bundle_target() : () -> () + llvm.call @op_bundle_target() [] : () -> () + llvm.return +} + +// CHECK-LABEL: @test_call_with_empty_opbundle_operands +llvm.func @test_call_with_empty_opbundle_operands() { + // CHECK: llvm.call @op_bundle_target() ["tag"()] : () -> () + llvm.call @op_bundle_target() ["tag"()] : () -> () + llvm.return +} + +// CHECK-LABEL: @test_call_with_opbundle +llvm.func @test_call_with_opbundle() { + %0 = llvm.mlir.constant(0 : i32) : i32 + %1 = llvm.mlir.constant(1 : i32) : i32 + %2 = llvm.mlir.constant(2 : i32) : i32 + // CHECK: llvm.call @op_bundle_target() ["tag1"(%{{.+}}, %{{.+}} : i32, i32), "tag2"(%{{.+}} : i32)] : () -> () + llvm.call @op_bundle_target() ["tag1"(%0, %1 : i32, i32), "tag2"(%2 : i32)] : () -> () + llvm.return +} + +// CHECK-LABEL: @test_invoke_with_empty_opbundle +llvm.func @test_invoke_with_empty_opbundle() attributes { personality = @__gxx_personality_v0 } { + %0 = llvm.mlir.constant(1 : i32) : i32 + %1 = llvm.mlir.constant(2 : i32) : i32 + %2 = llvm.mlir.constant(3 : i32) : i32 + // CHECK: llvm.invoke @op_bundle_target() to ^{{.+}} unwind ^{{.+}} : () -> () + llvm.invoke @op_bundle_target() to ^bb2 unwind ^bb1 [] : () -> () + +^bb1: + %3 = llvm.landingpad cleanup : !llvm.struct<(ptr, i32)> + llvm.return + +^bb2: + llvm.return +} + +// CHECK-LABEL: @test_invoke_with_empty_opbundle_operands +llvm.func @test_invoke_with_empty_opbundle_operands() attributes { personality = @__gxx_personality_v0 } { + %0 = llvm.mlir.constant(1 : i32) : i32 + %1 = llvm.mlir.constant(2 : i32) : i32 + %2 = llvm.mlir.constant(3 : i32) : i32 + // CHECK: llvm.invoke @op_bundle_target() to ^{{.+}} unwind ^{{.+}} ["tag"()] : () -> () + llvm.invoke @op_bundle_target() to ^bb2 unwind ^bb1 ["tag"()] : () -> () + +^bb1: + %3 = llvm.landingpad cleanup : !llvm.struct<(ptr, i32)> + llvm.return + +^bb2: + llvm.return +} + +// CHECK-LABEL: @test_invoke_with_opbundle +llvm.func @test_invoke_with_opbundle() attributes { personality = @__gxx_personality_v0 } { + %0 = llvm.mlir.constant(1 : i32) : i32 + %1 = llvm.mlir.constant(2 : i32) : i32 + %2 = llvm.mlir.constant(3 : i32) : i32 + // CHECK: llvm.invoke @op_bundle_target() to ^{{.+}} unwind ^{{.+}} ["tag1"(%{{.+}}, %{{.+}} : i32, i32), "tag2"(%{{.+}} : i32)] : () -> () + llvm.invoke @op_bundle_target() to ^bb2 unwind ^bb1 ["tag1"(%0, %1 : i32, i32), "tag2"(%2 : i32)] : () -> () + +^bb1: + %3 = llvm.landingpad cleanup : !llvm.struct<(ptr, i32)> + llvm.return + +^bb2: + llvm.return +} + +// CHECK-LABEL: @test_call_intrin_with_opbundle +llvm.func @test_call_intrin_with_opbundle(%arg0 : !llvm.ptr) { + %0 = llvm.mlir.constant(1 : i1) : i1 + %1 = llvm.mlir.constant(16 : i32) : i32 + // CHECK: llvm.call_intrinsic "llvm.assume"(%{{.+}}) ["align"(%{{.+}}, %{{.+}} : !llvm.ptr, i32)] : (i1) -> () + llvm.call_intrinsic "llvm.assume"(%0) ["align"(%arg0, %1 : !llvm.ptr, i32)] : (i1) -> () + llvm.return +} From bf4da2a7c95882011489cb39a0fe31c095ec5cc6 Mon Sep 17 00:00:00 2001 From: Sirui Mu Date: Thu, 26 Sep 2024 00:38:29 +0800 Subject: [PATCH 8/8] fix flang regressions --- flang/lib/Optimizer/CodeGen/CodeGen.cpp | 39 ++++++++++++++++++++----- 1 file changed, 32 insertions(+), 7 deletions(-) diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index 88293bcf36a78..efc28e9708e19 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -110,6 +110,26 @@ static unsigned getLenParamFieldId(mlir::Type ty) { return getTypeDescFieldId(ty) + 1; } +static llvm::SmallVector +addLLVMOpBundleAttrs(mlir::ConversionPatternRewriter &rewriter, + llvm::ArrayRef attrs, + int32_t numCallOperands) { + llvm::SmallVector newAttrs; + newAttrs.reserve(attrs.size() + 2); + + for (mlir::NamedAttribute attr : attrs) { + if (attr.getName() != "operandSegmentSizes") + newAttrs.push_back(attr); + } + + newAttrs.push_back(rewriter.getNamedAttr( + "operandSegmentSizes", + rewriter.getDenseI32ArrayAttr({numCallOperands, 0}))); + newAttrs.push_back(rewriter.getNamedAttr("op_bundle_sizes", + rewriter.getDenseI32ArrayAttr({}))); + return newAttrs; +} + namespace { /// Lower `fir.address_of` operation to `llvm.address_of` operation. struct AddrOfOpConversion : public fir::FIROpConversion { @@ -229,7 +249,8 @@ struct AllocaOpConversion : public fir::FIROpConversion { mlir::NamedAttribute attr = rewriter.getNamedAttr( "callee", mlir::SymbolRefAttr::get(memSizeFn)); auto call = rewriter.create( - loc, ity, lenParams, llvm::ArrayRef{attr}); + loc, ity, lenParams, + addLLVMOpBundleAttrs(rewriter, {attr}, lenParams.size())); size = call.getResult(); llvmObjectType = ::getI8Type(alloc.getContext()); } else { @@ -559,7 +580,9 @@ struct CallOpConversion : public fir::FIROpConversion { mlir::arith::AttrConvertFastMathToLLVM attrConvert(call); rewriter.replaceOpWithNewOp( - call, resultTys, adaptor.getOperands(), attrConvert.getAttrs()); + call, resultTys, adaptor.getOperands(), + addLLVMOpBundleAttrs(rewriter, attrConvert.getAttrs(), + adaptor.getOperands().size())); return mlir::success(); } }; @@ -980,7 +1003,8 @@ struct AllocMemOpConversion : public fir::FIROpConversion { loc, ity, size, integerCast(loc, rewriter, ity, opnd)); heap->setAttr("callee", getMalloc(heap, rewriter)); rewriter.replaceOpWithNewOp( - heap, ::getLlvmPtrType(heap.getContext()), size, heap->getAttrs()); + heap, ::getLlvmPtrType(heap.getContext()), size, + addLLVMOpBundleAttrs(rewriter, heap->getAttrs(), 1)); return mlir::success(); } @@ -1037,9 +1061,9 @@ struct FreeMemOpConversion : public fir::FIROpConversion { mlir::ConversionPatternRewriter &rewriter) const override { mlir::Location loc = freemem.getLoc(); freemem->setAttr("callee", getFree(freemem, rewriter)); - rewriter.create(loc, mlir::TypeRange{}, - mlir::ValueRange{adaptor.getHeapref()}, - freemem->getAttrs()); + rewriter.create( + loc, mlir::TypeRange{}, mlir::ValueRange{adaptor.getHeapref()}, + addLLVMOpBundleAttrs(rewriter, freemem->getAttrs(), 1)); rewriter.eraseOp(freemem); return mlir::success(); } @@ -2671,7 +2695,8 @@ struct FieldIndexOpConversion : public fir::FIROpConversion { "field", mlir::IntegerAttr::get(lowerTy().indexType(), index)); rewriter.replaceOpWithNewOp( field, lowerTy().offsetType(), adaptor.getOperands(), - llvm::ArrayRef{callAttr, fieldAttr}); + addLLVMOpBundleAttrs(rewriter, {callAttr, fieldAttr}, + adaptor.getOperands().size())); return mlir::success(); }