Skip to content

[flang][cuda] Get device address in fir.declare #118591

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions flang/include/flang/Optimizer/Transforms/CUFOpConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,16 @@ class SymbolTable;

namespace cuf {

/// Patterns that convert CUF operations to runtime calls.
void populateCUFToFIRConversionPatterns(const fir::LLVMTypeConverter &converter,
mlir::DataLayout &dl,
const mlir::SymbolTable &symtab,
mlir::RewritePatternSet &patterns);

/// Patterns that updates fir operations in presence of CUF.
void populateFIRCUFConversionPatterns(const mlir::SymbolTable &symtab,
mlir::RewritePatternSet &patterns);

} // namespace cuf

#endif // FORTRAN_OPTIMIZER_TRANSFORMS_CUFOPCONVERSION_H_
150 changes: 90 additions & 60 deletions flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,15 @@ static bool hasDoubleDescriptors(OpTy op) {
return false;
}

bool isDeviceGlobal(fir::GlobalOp op) {
auto attr = op.getDataAttr();
if (attr && (*attr == cuf::DataAttribute::Device ||
*attr == cuf::DataAttribute::Managed ||
*attr == cuf::DataAttribute::Constant))
return true;
return false;
}

static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
mlir::Location loc, mlir::Type toTy,
mlir::Value val) {
Expand All @@ -89,62 +98,6 @@ static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
return val;
}

mlir::Value getDeviceAddress(mlir::PatternRewriter &rewriter,
mlir::OpOperand &operand,
const mlir::SymbolTable &symtab) {
mlir::Value v = operand.get();
auto declareOp = v.getDefiningOp<fir::DeclareOp>();
if (!declareOp)
return v;

auto addrOfOp = declareOp.getMemref().getDefiningOp<fir::AddrOfOp>();
if (!addrOfOp)
return v;

auto globalOp = symtab.lookup<fir::GlobalOp>(
addrOfOp.getSymbol().getRootReference().getValue());

if (!globalOp)
return v;

bool isDevGlobal{false};
auto attr = globalOp.getDataAttrAttr();
if (attr) {
switch (attr.getValue()) {
case cuf::DataAttribute::Device:
case cuf::DataAttribute::Managed:
case cuf::DataAttribute::Constant:
isDevGlobal = true;
break;
default:
break;
}
}
if (!isDevGlobal)
return v;
mlir::OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(operand.getOwner());
auto loc = declareOp.getLoc();
auto mod = declareOp->getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder(rewriter, mod);

mlir::func::FuncOp callee =
fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(loc, builder);
auto fTy = callee.getFunctionType();
auto toTy = fTy.getInput(0);
mlir::Value inputArg =
createConvertOp(rewriter, loc, toTy, declareOp.getResult());
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
mlir::Value sourceLine =
fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
builder, loc, fTy, inputArg, sourceFile, sourceLine)};
auto call = rewriter.create<fir::CallOp>(loc, callee, args);
mlir::Value cast = createConvertOp(
rewriter, loc, declareOp.getMemref().getType(), call->getResult(0));
return cast;
}

template <typename OpTy>
static mlir::LogicalResult convertOpToCall(OpTy op,
mlir::PatternRewriter &rewriter,
Expand Down Expand Up @@ -422,6 +375,54 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
const fir::LLVMTypeConverter *typeConverter;
};

struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
using OpRewritePattern::OpRewritePattern;

DeclareOpConversion(mlir::MLIRContext *context,
const mlir::SymbolTable &symtab)
: OpRewritePattern(context), symTab{symtab} {}

mlir::LogicalResult
matchAndRewrite(fir::DeclareOp op,
mlir::PatternRewriter &rewriter) const override {
if (auto addrOfOp = op.getMemref().getDefiningOp<fir::AddrOfOp>()) {
if (auto global = symTab.lookup<fir::GlobalOp>(
addrOfOp.getSymbol().getRootReference().getValue())) {
if (isDeviceGlobal(global)) {
rewriter.setInsertionPointAfter(addrOfOp);
auto mod = op->getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder(rewriter, mod);
mlir::Location loc = op.getLoc();
mlir::func::FuncOp callee =
fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(
loc, builder);
auto fTy = callee.getFunctionType();
mlir::Type toTy = fTy.getInput(0);
mlir::Value inputArg =
createConvertOp(rewriter, loc, toTy, addrOfOp.getResult());
mlir::Value sourceFile =
fir::factory::locationToFilename(builder, loc);
mlir::Value sourceLine =
fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
builder, loc, fTy, inputArg, sourceFile, sourceLine)};
auto call = rewriter.create<fir::CallOp>(loc, callee, args);
mlir::Value cast = createConvertOp(
rewriter, loc, op.getMemref().getType(), call->getResult(0));
rewriter.startOpModification(op);
op.getMemrefMutable().assign(cast);
rewriter.finalizeOpModification(op);
return success();
}
}
}
return failure();
}

private:
const mlir::SymbolTable &symTab;
};

struct CUFFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
using OpRewritePattern::OpRewritePattern;

Expand Down Expand Up @@ -511,7 +512,7 @@ static mlir::Value emboxSrc(mlir::PatternRewriter &rewriter,
builder.create<fir::StoreOp>(loc, src, alloc);
addr = alloc;
} else {
addr = getDeviceAddress(rewriter, op.getSrcMutable(), symtab);
addr = op.getSrc();
}
llvm::SmallVector<mlir::Value> lenParams;
mlir::Type boxTy = fir::BoxType::get(srcTy);
Expand All @@ -531,7 +532,7 @@ static mlir::Value emboxDst(mlir::PatternRewriter &rewriter,
mlir::Location loc = op.getLoc();
fir::FirOpBuilder builder(rewriter, mod);
mlir::Type dstTy = fir::unwrapRefType(op.getDst().getType());
mlir::Value dstAddr = getDeviceAddress(rewriter, op.getDstMutable(), symtab);
mlir::Value dstAddr = op.getDst();
mlir::Type dstBoxTy = fir::BoxType::get(dstTy);
llvm::SmallVector<mlir::Value> lenParams;
mlir::Value dstBox =
Expand Down Expand Up @@ -652,8 +653,8 @@ struct CUFDataTransferOpConversion
mlir::Value sourceLine =
fir::factory::locationToLineNo(builder, loc, fTy.getInput(5));

mlir::Value dst = getDeviceAddress(rewriter, op.getDstMutable(), symtab);
mlir::Value src = getDeviceAddress(rewriter, op.getSrcMutable(), symtab);
mlir::Value dst = op.getDst();
mlir::Value src = op.getSrc();
// Materialize the src if constant.
if (matchPattern(src.getDefiningOp(), mlir::m_Constant())) {
mlir::Value temp = builder.createTemporary(loc, srcTy);
Expand Down Expand Up @@ -823,6 +824,30 @@ class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> {
"error in CUF op conversion\n");
signalPassFailure();
}

target.addDynamicallyLegalOp<fir::DeclareOp>([&](fir::DeclareOp op) {
if (inDeviceContext(op))
return true;
if (auto addrOfOp = op.getMemref().getDefiningOp<fir::AddrOfOp>()) {
if (auto global = symtab.lookup<fir::GlobalOp>(
addrOfOp.getSymbol().getRootReference().getValue())) {
if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(global.getType())))
return true;
if (isDeviceGlobal(global))
return false;
}
}
return true;
});

patterns.clear();
cuf::populateFIRCUFConversionPatterns(symtab, patterns);
if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
mlir::emitError(mlir::UnknownLoc::get(ctx),
"error in CUF op conversion\n");
signalPassFailure();
}
}
};
} // namespace
Expand All @@ -837,3 +862,8 @@ void cuf::populateCUFToFIRConversionPatterns(
&dl, &converter);
patterns.insert<CUFLaunchOpConversion>(patterns.getContext(), symtab);
}

void cuf::populateFIRCUFConversionPatterns(const mlir::SymbolTable &symtab,
mlir::RewritePatternSet &patterns) {
patterns.insert<DeclareOpConversion>(patterns.getContext(), symtab);
}
29 changes: 16 additions & 13 deletions flang/test/Fir/CUDA/cuda-data-transfer.fir
Original file line number Diff line number Diff line change
Expand Up @@ -199,12 +199,12 @@ func.func @_QPsub8() attributes {fir.bindc_name = "t"} {
// CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.array<5xi32>
// CHECK: %[[LOCAL:.*]] = fir.declare %[[ALLOCA]]
// CHECK: %[[GBL:.*]] = fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
// CHECK: %[[DECL:.*]] = fir.declare %[[GBL]]
// CHECK: %[[HOST:.*]] = fir.convert %[[DECL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: %[[SRC:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[HOST]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
// CHECK: %[[SRC_CONV:.*]] = fir.convert %[[SRC]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<5xi32>>
// CHECK: %[[GBL_CONV:.*]] = fir.convert %[[GBL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: %[[ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[GBL_CONV]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
// CHECK: %[[ADDR_CONV:.*]] = fir.convert %[[ADDR]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<5xi32>>
// CHECK: %[[DECL:.*]] = fir.declare %[[ADDR_CONV]]
// CHECK: %[[DST:.*]] = fir.convert %[[LOCAL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: %[[SRC:.*]] = fir.convert %[[SRC_CONV]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: %[[SRC:.*]] = fir.convert %[[DECL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none


Expand All @@ -223,11 +223,11 @@ func.func @_QPsub9() {
// CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.array<5xi32>
// CHECK: %[[LOCAL:.*]] = fir.declare %[[ALLOCA]]
// CHECK: %[[GBL:.*]] = fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
// CHECK: %[[DECL:.*]] = fir.declare %[[GBL]]
// CHECK: %[[HOST:.*]] = fir.convert %[[DECL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: %[[DST:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[HOST]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
// CHECK: %[[DST_CONV:.*]] = fir.convert %[[DST]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<5xi32>>
// CHECK: %[[DST:.*]] = fir.convert %[[DST_CONV]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: %[[GBL_CONV:.*]] = fir.convert %[[GBL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: %[[ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[GBL_CONV]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
// CHECK: %[[ADDR_CONV:.*]] = fir.convert %[[ADDR]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<5xi32>>
// CHECK: %[[DECL:.*]] = fir.declare %[[ADDR_CONV]]
// CHECK: %[[DST:.*]] = fir.convert %[[DECL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: %[[SRC:.*]] = fir.convert %[[LOCAL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none

Expand Down Expand Up @@ -380,9 +380,12 @@ func.func @_QPdevice_addr_conv() {
}

// CHECK-LABEL: func.func @_QPdevice_addr_conv()
// CHECK: %[[DEV_ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
// CHECK: %[[DEV_ADDR_CONV:.*]] = fir.convert %[[DEV_ADDR]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<4xf32>>
// CHECK: fir.embox %[[DEV_ADDR_CONV]](%{{.*}}) : (!fir.ref<!fir.array<4xf32>>, !fir.shape<1>) -> !fir.box<!fir.array<4xf32>>
// CHECK: %[[GBL:.*]] = fir.address_of(@_QMmod1Ea_dev) : !fir.ref<!fir.array<4xf32>>
// CHECK: %[[GBL_CONV:.*]] = fir.convert %[[GBL]] : (!fir.ref<!fir.array<4xf32>>) -> !fir.llvm_ptr<i8>
// CHECK: %[[ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[GBL_CONV]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
// CHECK: %[[ADDR_CONV:.*]] = fir.convert %[[ADDR]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<4xf32>>
// CHECK: %[[DECL:.*]] = fir.declare %[[ADDR_CONV]](%{{.*}}) {data_attr = #cuf.cuda<device>, uniq_name = "_QMmod1Ea_dev"} : (!fir.ref<!fir.array<4xf32>>, !fir.shape<1>) -> !fir.ref<!fir.array<4xf32>>
// CHECK: fir.embox %[[DECL]](%{{.*}}) : (!fir.ref<!fir.array<4xf32>>, !fir.shape<1>) -> !fir.box<!fir.array<4xf32>>
// CHECK: fir.call @_FortranACUFDataTransferCstDesc

func.func @_QQchar_transfer() attributes {fir.bindc_name = "char_transfer"} {
Expand Down
36 changes: 36 additions & 0 deletions flang/test/Fir/CUDA/cuda-global-addr.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// RUN: fir-opt --cuf-convert %s | FileCheck %s

module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> : vector<2xi64>>, #dlti.dl_entry<i128, dense<128> : vector<2xi64>>, #dlti.dl_entry<i64, dense<64> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr<272>, dense<64> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<271>, dense<32> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<270>, dense<32> : vector<4xi64>>, #dlti.dl_entry<f128, dense<128> : vector<2xi64>>, #dlti.dl_entry<f64, dense<64> : vector<2xi64>>, #dlti.dl_entry<f16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i32, dense<32> : vector<2xi64>>, #dlti.dl_entry<i16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i8, dense<8> : vector<2xi64>>, #dlti.dl_entry<i1, dense<8> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr, dense<64> : vector<4xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>} {
fir.global @_QMmod1Eadev {data_attr = #cuf.cuda<device>} : !fir.array<10xi32> {
%0 = fir.zero_bits !fir.array<10xi32>
fir.has_value %0 : !fir.array<10xi32>
}
func.func @_QQmain() attributes {fir.bindc_name = "test"} {
%c14_i32 = arith.constant 14 : i32
%c6_i32 = arith.constant 6 : i32
%c4 = arith.constant 4 : index
%c1_i32 = arith.constant 1 : i32
%c0_i32 = arith.constant 0 : i32
%c10 = arith.constant 10 : index
%1 = fir.shape %c10 : (index) -> !fir.shape<1>
%3 = fir.address_of(@_QMmod1Eadev) : !fir.ref<!fir.array<10xi32>>
%4 = fir.declare %3(%1) {data_attr = #cuf.cuda<device>, uniq_name = "_QMmod1Eadev"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<10xi32>>
%5 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFEi"}
%6 = fir.declare %5 {uniq_name = "_QFEi"} : (!fir.ref<i32>) -> !fir.ref<i32>
fir.store %c0_i32 to %6 : !fir.ref<i32>
%7 = fir.array_coor %4(%1) %c4 : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
cuf.data_transfer %c1_i32 to %7 {transfer_kind = #cuf.cuda_transfer<host_device>} : i32, !fir.ref<i32>
return
}

}

// CHECK-LABEL: func.func @_QQmain()
// CHECK: %[[ADDR:.*]] = fir.address_of(@_QMmod1Eadev) : !fir.ref<!fir.array<10xi32>>
// CHECK: %[[ADDRPTR:.*]] = fir.convert %[[ADDR]] : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: %[[DEVICE_ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[ADDRPTR]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
// CHECK: %[[DEVICE_ADDR_CONV:.*]] = fir.convert %[[DEVICE_ADDR]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<10xi32>>
// CHECK: %[[DECL:.*]] = fir.declare %[[DEVICE_ADDR_CONV]](%{{.*}}) {data_attr = #cuf.cuda<device>, uniq_name = "_QMmod1Eadev"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<10xi32>>
// CHECK: %[[ARRAY_COOR:.*]] = fir.array_coor %[[DECL]](%{{.*}}) %c4{{.*}} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
// CHECK: %[[ARRAY_COOR_PTR:.*]] = fir.convert %[[ARRAY_COOR]] : (!fir.ref<i32>) -> !fir.llvm_ptr<i8>
// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[ARRAY_COOR_PTR]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
Loading