Skip to content

Commit d10b1a3

Browse files
Tobias Gysiftynse
authored andcommitted
[mlir] make the bitwidth of device side index computations configurable
The patch makes the index type lowering of the GPU to NVVM/ROCDL conversion configurable. It introduces a pass option that controls the bitwidth used when lowering index computations. Differential Revision: https://reviews.llvm.org/D80285
1 parent e935a54 commit d10b1a3

File tree

10 files changed

+206
-146
lines changed

10 files changed

+206
-146
lines changed

mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#ifndef MLIR_CONVERSION_GPUTONVVM_GPUTONVVMPASS_H_
99
#define MLIR_CONVERSION_GPUTONVVM_GPUTONVVMPASS_H_
1010

11+
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
1112
#include <memory>
1213

1314
namespace mlir {
@@ -24,9 +25,11 @@ class GPUModuleOp;
2425
void populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
2526
OwningRewritePatternList &patterns);
2627

27-
/// Creates a pass that lowers GPU dialect operations to NVVM counterparts.
28-
std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
29-
createLowerGpuOpsToNVVMOpsPass();
28+
/// Creates a pass that lowers GPU dialect operations to NVVM counterparts. The
29+
/// index bitwidth used for the lowering of the device side index computations
30+
/// is configurable.
31+
std::unique_ptr<OperationPass<gpu::GPUModuleOp>> createLowerGpuOpsToNVVMOpsPass(
32+
unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout);
3033

3134
} // namespace mlir
3235

mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#ifndef MLIR_CONVERSION_GPUTOROCDL_GPUTOROCDLPASS_H_
99
#define MLIR_CONVERSION_GPUTOROCDL_GPUTOROCDLPASS_H_
1010

11+
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
1112
#include <memory>
1213

1314
namespace mlir {
@@ -25,9 +26,12 @@ class GPUModuleOp;
2526
void populateGpuToROCDLConversionPatterns(LLVMTypeConverter &converter,
2627
OwningRewritePatternList &patterns);
2728

28-
/// Creates a pass that lowers GPU dialect operations to ROCDL counterparts.
29+
/// Creates a pass that lowers GPU dialect operations to ROCDL counterparts. The
30+
/// index bitwidth used for the lowering of the device side index computations
31+
/// is configurable.
2932
std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
30-
createLowerGpuOpsToROCDLOpsPass();
33+
createLowerGpuOpsToROCDLOpsPass(
34+
unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout);
3135

3236
} // namespace mlir
3337

mlir/include/mlir/Conversion/Passes.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,11 @@ def ConvertGpuLaunchFuncToGpuRuntimeCalls : Pass<"launch-func-to-gpu-runtime",
100100
def ConvertGpuOpsToNVVMOps : Pass<"convert-gpu-to-nvvm", "gpu::GPUModuleOp"> {
101101
let summary = "Generate NVVM operations for gpu operations";
102102
let constructor = "mlir::createLowerGpuOpsToNVVMOpsPass()";
103+
let options = [
104+
Option<"indexBitwidth", "index-bitwidth", "unsigned",
105+
/*default=kDeriveIndexBitwidthFromDataLayout*/"0",
106+
"Bitwidth of the index type, 0 to use size of machine word">
107+
];
103108
}
104109

105110
//===----------------------------------------------------------------------===//
@@ -109,6 +114,11 @@ def ConvertGpuOpsToNVVMOps : Pass<"convert-gpu-to-nvvm", "gpu::GPUModuleOp"> {
109114
def ConvertGpuOpsToROCDLOps : Pass<"convert-gpu-to-rocdl", "gpu::GPUModuleOp"> {
110115
let summary = "Generate ROCDL operations for gpu operations";
111116
let constructor = "mlir::createLowerGpuOpsToROCDLOpsPass()";
117+
let options = [
118+
Option<"indexBitwidth", "index-bitwidth", "unsigned",
119+
/*default=kDeriveIndexBitwidthFromDataLayout*/"0",
120+
"Bitwidth of the index type, 0 to use size of machine word">
121+
];
112122
}
113123

114124
//===----------------------------------------------------------------------===//

mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#ifndef MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H
1616
#define MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H
1717

18+
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
1819
#include "mlir/Transforms/DialectConversion.h"
1920

2021
namespace llvm {
@@ -35,22 +36,6 @@ class LLVMDialect;
3536
class LLVMType;
3637
} // namespace LLVM
3738

38-
/// Set of callbacks that allows the customization of LLVMTypeConverter.
39-
struct LLVMTypeConverterCustomization {
40-
using CustomCallback = std::function<LogicalResult(LLVMTypeConverter &, Type,
41-
SmallVectorImpl<Type> &)>;
42-
43-
/// Customize the type conversion of function arguments.
44-
CustomCallback funcArgConverter;
45-
46-
/// Used to determine the bitwidth of the LLVM integer type that the index
47-
/// type gets lowered to. Defaults to deriving the size from the data layout.
48-
unsigned indexBitwidth;
49-
50-
/// Initialize customization to default callbacks.
51-
LLVMTypeConverterCustomization();
52-
};
53-
5439
/// Callback to convert function argument types. It converts a MemRef function
5540
/// argument to a list of non-aggregate types containing descriptor
5641
/// information, and an UnrankedmemRef function argument to a list containing
@@ -75,13 +60,11 @@ class LLVMTypeConverter : public TypeConverter {
7560
public:
7661
using TypeConverter::convertType;
7762

78-
/// Create an LLVMTypeConverter using the default
79-
/// LLVMTypeConverterCustomization.
63+
/// Create an LLVMTypeConverter using the default LowerToLLVMOptions.
8064
LLVMTypeConverter(MLIRContext *ctx);
8165

82-
/// Create an LLVMTypeConverter using 'custom' customizations.
83-
LLVMTypeConverter(MLIRContext *ctx,
84-
const LLVMTypeConverterCustomization &custom);
66+
/// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
67+
LLVMTypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options);
8568

8669
/// Convert a function type. The arguments and results are converted one by
8770
/// one and results are packed into a wrapped LLVM IR structure type. `result`
@@ -127,7 +110,7 @@ class LLVMTypeConverter : public TypeConverter {
127110
LLVM::LLVMType getIndexType();
128111

129112
/// Gets the bitwidth of the index type when converted to LLVM.
130-
unsigned getIndexTypeBitwidth() { return customizations.indexBitwidth; }
113+
unsigned getIndexTypeBitwidth() { return options.indexBitwidth; }
131114

132115
protected:
133116
/// LLVM IR module used to parse/create types.
@@ -193,8 +176,8 @@ class LLVMTypeConverter : public TypeConverter {
193176
// Convert a 1D vector type into an LLVM vector type.
194177
Type convertVectorType(VectorType type);
195178

196-
/// Callbacks for customizing the type conversion.
197-
LLVMTypeConverterCustomization customizations;
179+
/// Options for customizing the llvm lowering.
180+
LowerToLLVMOptions options;
198181
};
199182

200183
/// Helper class to produce LLVM dialect operations extracting or inserting
@@ -389,11 +372,17 @@ class UnrankedMemRefDescriptor : public StructBuilder {
389372
};
390373

391374
/// Base class for operation conversions targeting the LLVM IR dialect. Provides
392-
/// conversion patterns with access to an LLVMTypeConverter.
375+
/// conversion patterns with access to an LLVMTypeConverter and the
376+
/// LowerToLLVMOptions.
393377
class ConvertToLLVMPattern : public ConversionPattern {
394378
public:
395379
ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context,
396380
LLVMTypeConverter &typeConverter,
381+
const LowerToLLVMOptions &options = {
382+
/*useBarePtrCallConv=*/false,
383+
/*emitCWrappers=*/false,
384+
/*indexBitwidth=*/kDeriveIndexBitwidthFromDataLayout,
385+
/*useAlignedAlloc=*/false},
397386
PatternBenefit benefit = 1);
398387

399388
/// Returns the LLVM dialect.
@@ -445,6 +434,9 @@ class ConvertToLLVMPattern : public ConversionPattern {
445434
protected:
446435
/// Reference to the type converter, with potential extensions.
447436
LLVMTypeConverter &typeConverter;
437+
438+
/// Reference to the llvm lowering options.
439+
const LowerToLLVMOptions &options;
448440
};
449441

450442
/// Utility class for operation conversions targeting the LLVM dialect that
@@ -453,10 +445,11 @@ template <typename OpTy>
453445
class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
454446
public:
455447
ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
448+
const LowerToLLVMOptions &options,
456449
PatternBenefit benefit = 1)
457450
: ConvertToLLVMPattern(OpTy::getOperationName(),
458451
&typeConverter.getContext(), typeConverter,
459-
benefit) {}
452+
options, benefit) {}
460453
};
461454

462455
namespace LLVM {

mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h

Lines changed: 25 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -14,54 +14,50 @@
1414
namespace mlir {
1515
class LLVMTypeConverter;
1616
class ModuleOp;
17-
template <typename T> class OperationPass;
17+
template <typename T>
18+
class OperationPass;
1819
class OwningRewritePatternList;
1920

21+
/// Value to pass as bitwidth for the index type when the converter is expected
22+
/// to derive the bitwidth from the LLVM data layout.
23+
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout = 0;
24+
25+
struct LowerToLLVMOptions {
26+
bool useBarePtrCallConv = false;
27+
bool emitCWrappers = false;
28+
unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout;
29+
/// Use aligned_alloc for heap allocations.
30+
bool useAlignedAlloc = false;
31+
};
32+
2033
/// Collect a set of patterns to convert memory-related operations from the
2134
/// Standard dialect to the LLVM dialect, excluding non-memory-related
2235
/// operations and FuncOp.
2336
void populateStdToLLVMMemoryConversionPatterns(
2437
LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
25-
bool useAlignedAlloc);
38+
const LowerToLLVMOptions &options);
2639

2740
/// Collect a set of patterns to convert from the Standard dialect to the LLVM
2841
/// dialect, excluding the memory-related operations.
2942
void populateStdToLLVMNonMemoryConversionPatterns(
30-
LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
43+
LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
44+
const LowerToLLVMOptions &options);
3145

3246
/// Collect the default pattern to convert a FuncOp to the LLVM dialect. If
3347
/// `emitCWrappers` is set, the pattern will also produce functions
3448
/// that pass memref descriptors by pointer-to-structure in addition to the
3549
/// default unpacked form.
36-
void populateStdToLLVMDefaultFuncOpConversionPattern(
50+
void populateStdToLLVMFuncOpConversionPattern(
3751
LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
38-
bool emitCWrappers = false);
52+
const LowerToLLVMOptions &options);
3953

40-
/// Collect a set of default patterns to convert from the Standard dialect to
41-
/// LLVM.
42-
void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
43-
OwningRewritePatternList &patterns,
44-
bool emitCWrappers = false,
45-
bool useAlignedAlloc = false);
46-
47-
/// Collect a set of patterns to convert from the Standard dialect to
48-
/// LLVM using the bare pointer calling convention for MemRef function
49-
/// arguments.
50-
void populateStdToLLVMBarePtrConversionPatterns(
54+
/// Collect the patterns to convert from the Standard dialect to LLVM.
55+
void populateStdToLLVMConversionPatterns(
5156
LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
52-
bool useAlignedAlloc);
53-
54-
/// Value to pass as bitwidth for the index type when the converter is expected
55-
/// to derive the bitwidth from the LLVM data layout.
56-
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout = 0;
57-
58-
struct LowerToLLVMOptions {
59-
bool useBarePtrCallConv = false;
60-
bool emitCWrappers = false;
61-
unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout;
62-
/// Use aligned_alloc for heap allocations.
63-
bool useAlignedAlloc = false;
64-
};
57+
const LowerToLLVMOptions &options = {
58+
/*useBarePtrCallConv=*/false, /*emitCWrappers=*/false,
59+
/*indexBitwidth=*/kDeriveIndexBitwidthFromDataLayout,
60+
/*useAlignedAlloc=*/false});
6561

6662
/// Creates a pass to convert the Standard dialect into the LLVMIR dialect.
6763
/// stdlib malloc/free is used by default for allocating memrefs allocated with

mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ using namespace mlir;
3030

3131
namespace {
3232

33-
3433
struct GPUShuffleOpLowering : public ConvertToLLVMPattern {
3534
explicit GPUShuffleOpLowering(LLVMTypeConverter &lowering_)
3635
: ConvertToLLVMPattern(gpu::ShuffleOp::getOperationName(),
@@ -97,17 +96,27 @@ struct GPUShuffleOpLowering : public ConvertToLLVMPattern {
9796
///
9897
/// This pass only handles device code and is not meant to be run on GPU host
9998
/// code.
100-
class LowerGpuOpsToNVVMOpsPass
99+
struct LowerGpuOpsToNVVMOpsPass
101100
: public ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> {
102-
public:
101+
LowerGpuOpsToNVVMOpsPass() = default;
102+
LowerGpuOpsToNVVMOpsPass(unsigned indexBitwidth) {
103+
this->indexBitwidth = indexBitwidth;
104+
}
105+
103106
void runOnOperation() override {
104107
gpu::GPUModuleOp m = getOperation();
105108

109+
/// Customize the bitwidth used for the device side index computations.
110+
LowerToLLVMOptions options = {/*useBarePtrCallConv =*/false,
111+
/*emitCWrappers = */ true,
112+
/*indexBitwidth =*/indexBitwidth,
113+
/*useAlignedAlloc =*/false};
114+
106115
/// MemRef conversion for GPU to NVVM lowering. The GPU dialect uses memory
107116
/// space 5 for private memory attributions, but NVVM represents private
108117
/// memory allocations as local `alloca`s in the default address space. This
109118
/// converter drops the private memory space to support the use case above.
110-
LLVMTypeConverter converter(m.getContext());
119+
LLVMTypeConverter converter(m.getContext(), options);
111120
converter.addConversion([&](MemRefType type) -> Optional<Type> {
112121
if (type.getMemorySpace() != gpu::GPUDialect::getPrivateAddressSpace())
113122
return llvm::None;
@@ -176,6 +185,6 @@ void mlir::populateGpuToNVVMConversionPatterns(
176185
}
177186

178187
std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
179-
mlir::createLowerGpuOpsToNVVMOpsPass() {
180-
return std::make_unique<LowerGpuOpsToNVVMOpsPass>();
188+
mlir::createLowerGpuOpsToNVVMOpsPass(unsigned indexBitwidth) {
189+
return std::make_unique<LowerGpuOpsToNVVMOpsPass>(indexBitwidth);
181190
}

mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,22 @@ namespace {
4141
//
4242
// This pass only handles device code and is not meant to be run on GPU host
4343
// code.
44-
class LowerGpuOpsToROCDLOpsPass
44+
struct LowerGpuOpsToROCDLOpsPass
4545
: public ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> {
46-
public:
46+
LowerGpuOpsToROCDLOpsPass() = default;
47+
LowerGpuOpsToROCDLOpsPass(unsigned indexBitwidth) {
48+
this->indexBitwidth = indexBitwidth;
49+
}
50+
4751
void runOnOperation() override {
4852
gpu::GPUModuleOp m = getOperation();
4953

50-
LLVMTypeConverter converter(m.getContext());
54+
/// Customize the bitwidth used for the device side index computations.
55+
LowerToLLVMOptions options = {/*useBarePtrCallConv =*/false,
56+
/*emitCWrappers = */ true,
57+
/*indexBitwidth =*/indexBitwidth,
58+
/*useAlignedAlloc =*/false};
59+
LLVMTypeConverter converter(m.getContext(), options);
5160

5261
OwningRewritePatternList patterns;
5362

@@ -106,6 +115,6 @@ void mlir::populateGpuToROCDLConversionPatterns(
106115
}
107116

108117
std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
109-
mlir::createLowerGpuOpsToROCDLOpsPass() {
110-
return std::make_unique<LowerGpuOpsToROCDLOpsPass>();
118+
mlir::createLowerGpuOpsToROCDLOpsPass(unsigned indexBitwidth) {
119+
return std::make_unique<LowerGpuOpsToROCDLOpsPass>(indexBitwidth);
111120
}

0 commit comments

Comments
 (0)