Skip to content

Commit bbb2bc2

Browse files
committed
[mlir][LLVM] Add operand bundle support
1 parent 30d7dcc commit bbb2bc2

File tree

8 files changed

+363
-54
lines changed

8 files changed

+363
-54
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1210,4 +1210,33 @@ def WorkgroupAttributionAttr
12101210
let assemblyFormat = "`<` $num_elements `,` $element_type `>`";
12111211
}
12121212

1213+
//===----------------------------------------------------------------------===//
1214+
// OperandBundleAttr
1215+
//===----------------------------------------------------------------------===//
1216+
1217+
def LLVM_OperandBundleAttr : LLVM_Attr<"OperandBundle", "opbundle"> {
1218+
let summary = "Operand bundle information";
1219+
let description = [{
1220+
Provide information about a single operand bundle. Each operand bundle has a
1221+
string tag together with various number of SSA value uses. The SSA values
1222+
are specified through indices into the operation's operand bundle operands.
1223+
}];
1224+
1225+
let parameters = (ins "StringAttr":$tag,
1226+
OptionalArrayRefParameter<"uint32_t">:$argIndices);
1227+
let assemblyFormat = [{
1228+
`<` $tag (`,` $argIndices^)? `>`
1229+
}];
1230+
}
1231+
1232+
def LLVM_OperandBundlesAttr : LLVM_Attr<"OperandBundles", "opbundles"> {
1233+
let summary = "A list of operand bundle attributes";
1234+
let description = "A list of operand bundle attributes";
1235+
1236+
let parameters = (ins ArrayRefParameter<"OperandBundleAttr">:$bundles);
1237+
let assemblyFormat = [{
1238+
`<` $bundles `>`
1239+
}];
1240+
}
1241+
12131242
#endif // LLVMIR_ATTRDEFS

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -550,8 +550,10 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
550550
Variadic<LLVM_Type>:$callee_operands,
551551
Variadic<LLVM_Type>:$normalDestOperands,
552552
Variadic<LLVM_Type>:$unwindDestOperands,
553+
Variadic<LLVM_Type>:$bundle_operands,
553554
OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
554-
DefaultValuedAttr<CConv, "CConv::C">:$CConv);
555+
DefaultValuedAttr<CConv, "CConv::C">:$CConv,
556+
OptionalAttr<LLVM_OperandBundlesAttr>:$op_bundles);
555557
let results = (outs Optional<LLVM_Type>:$result);
556558
let successors = (successor AnySuccessor:$normalDest,
557559
AnySuccessor:$unwindDest);
@@ -587,7 +589,8 @@ def LLVM_LandingpadOp : LLVM_Op<"landingpad"> {
587589
//===----------------------------------------------------------------------===//
588590

589591
def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
590-
[DeclareOpInterfaceMethods<FastmathFlagsInterface>,
592+
[AttrSizedOperandSegments,
593+
DeclareOpInterfaceMethods<FastmathFlagsInterface>,
591594
DeclareOpInterfaceMethods<CallOpInterface>,
592595
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
593596
DeclareOpInterfaceMethods<BranchWeightOpInterface>]> {
@@ -633,6 +636,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
633636
dag args = (ins OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$var_callee_type,
634637
OptionalAttr<FlatSymbolRefAttr>:$callee,
635638
Variadic<LLVM_Type>:$callee_operands,
639+
Variadic<LLVM_Type>:$bundle_operands,
636640
DefaultValuedAttr<LLVM_FastmathFlagsAttr,
637641
"{}">:$fastmathFlags,
638642
OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
@@ -641,7 +645,8 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
641645
OptionalAttr<LLVM_MemoryEffectsAttr>:$memory_effects,
642646
OptionalAttr<UnitAttr>:$convergent,
643647
OptionalAttr<UnitAttr>:$no_unwind,
644-
OptionalAttr<UnitAttr>:$will_return
648+
OptionalAttr<UnitAttr>:$will_return,
649+
OptionalAttr<LLVM_OperandBundlesAttr>:$op_bundles
645650
);
646651
// Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
647652
let arguments = !con(args, aliasAttrs);
@@ -662,6 +667,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
662667
OpBuilder<(ins "LLVMFunctionType":$calleeType, "StringRef":$callee,
663668
CArg<"ValueRange", "{}">:$args)>
664669
];
670+
let hasVerifier = 1;
665671
let hasCustomAssemblyFormat = 1;
666672
let extraClassDeclaration = [{
667673
/// Returns the callee function type.
@@ -1875,21 +1881,28 @@ def LLVM_InlineAsmOp : LLVM_Op<"inline_asm", [DeclareOpInterfaceMethods<MemoryEf
18751881

18761882
def LLVM_CallIntrinsicOp
18771883
: LLVM_Op<"call_intrinsic",
1878-
[DeclareOpInterfaceMethods<FastmathFlagsInterface>]> {
1884+
[AttrSizedOperandSegments,
1885+
DeclareOpInterfaceMethods<FastmathFlagsInterface>]> {
18791886
let summary = "Call to an LLVM intrinsic function.";
18801887
let description = [{
18811888
Call the specified llvm intrinsic. If the intrinsic is overloaded, use
18821889
the MLIR function type of this op to determine which intrinsic to call.
18831890
}];
18841891
let arguments = (ins StrAttr:$intrin, Variadic<LLVM_Type>:$args,
1892+
Variadic<LLVM_Type>:$bundle_operands,
18851893
DefaultValuedAttr<LLVM_FastmathFlagsAttr,
1886-
"{}">:$fastmathFlags);
1894+
"{}">:$fastmathFlags,
1895+
OptionalAttr<LLVM_OperandBundlesAttr>:$op_bundles);
18871896
let results = (outs Optional<LLVM_Type>:$results);
18881897
let llvmBuilder = [{
18891898
return convertCallLLVMIntrinsicOp(op, builder, moduleTranslation);
18901899
}];
18911900
let assemblyFormat = [{
1892-
$intrin `(` $args `)` `:` functional-type($args, $results) attr-dict
1901+
$intrin `(` $args `)`
1902+
( `bundlearg` `(` $bundle_operands^ `)` )?
1903+
`:` functional-type($args, $results)
1904+
( `,` `tuple` `<` type($bundle_operands)^ `>` )?
1905+
attr-dict
18931906
}];
18941907

18951908
let hasVerifier = 1;

mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,10 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
544544
callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
545545
promoted, callOp->getAttrs());
546546

547+
newOp->setAttr(newOp.getOperandSegmentSizesAttrName(),
548+
rewriter.getDenseI32ArrayAttr(
549+
{static_cast<int32_t>(promoted.size()), 0}));
550+
547551
SmallVector<Value, 4> results;
548552
if (numResults < 2) {
549553
// If < 2 results, packing did not do anything and we can just return.

mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -837,17 +837,25 @@ class FunctionCallPattern
837837
matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
838838
ConversionPatternRewriter &rewriter) const override {
839839
if (callOp.getNumResults() == 0) {
840-
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
840+
auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
841841
callOp, std::nullopt, adaptor.getOperands(), callOp->getAttrs());
842+
newOp->setAttr(
843+
newOp.getOperandSegmentSizesAttrName(),
844+
rewriter.getDenseI32ArrayAttr(
845+
{static_cast<int32_t>(adaptor.getOperands().size()), 0}));
842846
return success();
843847
}
844848

845849
// Function returns a single result.
846850
auto dstType = typeConverter.convertType(callOp.getType(0));
847851
if (!dstType)
848852
return rewriter.notifyMatchFailure(callOp, "type conversion failed");
849-
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
853+
auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
850854
callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
855+
newOp->setAttr(
856+
newOp.getOperandSegmentSizesAttrName(),
857+
rewriter.getDenseI32ArrayAttr(
858+
{static_cast<int32_t>(adaptor.getOperands().size()), 0}));
851859
return success();
852860
}
853861
};

0 commit comments

Comments
 (0)