diff --git a/flang/include/flang/Optimizer/Transforms/CUFOpConversion.h b/flang/include/flang/Optimizer/Transforms/CUFOpConversion.h index f061323db1704..336cf46d82bab 100644 --- a/flang/include/flang/Optimizer/Transforms/CUFOpConversion.h +++ b/flang/include/flang/Optimizer/Transforms/CUFOpConversion.h @@ -23,11 +23,16 @@ class SymbolTable; namespace cuf { +/// Patterns that convert CUF operations to runtime calls. void populateCUFToFIRConversionPatterns(const fir::LLVMTypeConverter &converter, mlir::DataLayout &dl, const mlir::SymbolTable &symtab, mlir::RewritePatternSet &patterns); +/// Patterns that updates fir operations in presence of CUF. +void populateFIRCUFConversionPatterns(const mlir::SymbolTable &symtab, + mlir::RewritePatternSet &patterns); + } // namespace cuf #endif // FORTRAN_OPTIMIZER_TRANSFORMS_CUFOPCONVERSION_H_ diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp index 337ea04755d1a..7f6843d66d39f 100644 --- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp +++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp @@ -81,6 +81,15 @@ static bool hasDoubleDescriptors(OpTy op) { return false; } +bool isDeviceGlobal(fir::GlobalOp op) { + auto attr = op.getDataAttr(); + if (attr && (*attr == cuf::DataAttribute::Device || + *attr == cuf::DataAttribute::Managed || + *attr == cuf::DataAttribute::Constant)) + return true; + return false; +} + static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter, mlir::Location loc, mlir::Type toTy, mlir::Value val) { @@ -89,62 +98,6 @@ static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter, return val; } -mlir::Value getDeviceAddress(mlir::PatternRewriter &rewriter, - mlir::OpOperand &operand, - const mlir::SymbolTable &symtab) { - mlir::Value v = operand.get(); - auto declareOp = v.getDefiningOp(); - if (!declareOp) - return v; - - auto addrOfOp = declareOp.getMemref().getDefiningOp(); - if (!addrOfOp) - return v; - - auto globalOp = symtab.lookup( - addrOfOp.getSymbol().getRootReference().getValue()); - - if (!globalOp) - return v; - - bool isDevGlobal{false}; - auto attr = globalOp.getDataAttrAttr(); - if (attr) { - switch (attr.getValue()) { - case cuf::DataAttribute::Device: - case cuf::DataAttribute::Managed: - case cuf::DataAttribute::Constant: - isDevGlobal = true; - break; - default: - break; - } - } - if (!isDevGlobal) - return v; - mlir::OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(operand.getOwner()); - auto loc = declareOp.getLoc(); - auto mod = declareOp->getParentOfType(); - fir::FirOpBuilder builder(rewriter, mod); - - mlir::func::FuncOp callee = - fir::runtime::getRuntimeFunc(loc, builder); - auto fTy = callee.getFunctionType(); - auto toTy = fTy.getInput(0); - mlir::Value inputArg = - createConvertOp(rewriter, loc, toTy, declareOp.getResult()); - mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); - mlir::Value sourceLine = - fir::factory::locationToLineNo(builder, loc, fTy.getInput(2)); - llvm::SmallVector args{fir::runtime::createArguments( - builder, loc, fTy, inputArg, sourceFile, sourceLine)}; - auto call = rewriter.create(loc, callee, args); - mlir::Value cast = createConvertOp( - rewriter, loc, declareOp.getMemref().getType(), call->getResult(0)); - return cast; -} - template static mlir::LogicalResult convertOpToCall(OpTy op, mlir::PatternRewriter &rewriter, @@ -422,6 +375,54 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern { const fir::LLVMTypeConverter *typeConverter; }; +struct DeclareOpConversion : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + DeclareOpConversion(mlir::MLIRContext *context, + const mlir::SymbolTable &symtab) + : OpRewritePattern(context), symTab{symtab} {} + + mlir::LogicalResult + matchAndRewrite(fir::DeclareOp op, + mlir::PatternRewriter &rewriter) const override { + if (auto addrOfOp = op.getMemref().getDefiningOp()) { + if (auto global = symTab.lookup( + addrOfOp.getSymbol().getRootReference().getValue())) { + if (isDeviceGlobal(global)) { + rewriter.setInsertionPointAfter(addrOfOp); + auto mod = op->getParentOfType(); + fir::FirOpBuilder builder(rewriter, mod); + mlir::Location loc = op.getLoc(); + mlir::func::FuncOp callee = + fir::runtime::getRuntimeFunc( + loc, builder); + auto fTy = callee.getFunctionType(); + mlir::Type toTy = fTy.getInput(0); + mlir::Value inputArg = + createConvertOp(rewriter, loc, toTy, addrOfOp.getResult()); + mlir::Value sourceFile = + fir::factory::locationToFilename(builder, loc); + mlir::Value sourceLine = + fir::factory::locationToLineNo(builder, loc, fTy.getInput(2)); + llvm::SmallVector args{fir::runtime::createArguments( + builder, loc, fTy, inputArg, sourceFile, sourceLine)}; + auto call = rewriter.create(loc, callee, args); + mlir::Value cast = createConvertOp( + rewriter, loc, op.getMemref().getType(), call->getResult(0)); + rewriter.startOpModification(op); + op.getMemrefMutable().assign(cast); + rewriter.finalizeOpModification(op); + return success(); + } + } + } + return failure(); + } + +private: + const mlir::SymbolTable &symTab; +}; + struct CUFFreeOpConversion : public mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -511,7 +512,7 @@ static mlir::Value emboxSrc(mlir::PatternRewriter &rewriter, builder.create(loc, src, alloc); addr = alloc; } else { - addr = getDeviceAddress(rewriter, op.getSrcMutable(), symtab); + addr = op.getSrc(); } llvm::SmallVector lenParams; mlir::Type boxTy = fir::BoxType::get(srcTy); @@ -531,7 +532,7 @@ static mlir::Value emboxDst(mlir::PatternRewriter &rewriter, mlir::Location loc = op.getLoc(); fir::FirOpBuilder builder(rewriter, mod); mlir::Type dstTy = fir::unwrapRefType(op.getDst().getType()); - mlir::Value dstAddr = getDeviceAddress(rewriter, op.getDstMutable(), symtab); + mlir::Value dstAddr = op.getDst(); mlir::Type dstBoxTy = fir::BoxType::get(dstTy); llvm::SmallVector lenParams; mlir::Value dstBox = @@ -652,8 +653,8 @@ struct CUFDataTransferOpConversion mlir::Value sourceLine = fir::factory::locationToLineNo(builder, loc, fTy.getInput(5)); - mlir::Value dst = getDeviceAddress(rewriter, op.getDstMutable(), symtab); - mlir::Value src = getDeviceAddress(rewriter, op.getSrcMutable(), symtab); + mlir::Value dst = op.getDst(); + mlir::Value src = op.getSrc(); // Materialize the src if constant. if (matchPattern(src.getDefiningOp(), mlir::m_Constant())) { mlir::Value temp = builder.createTemporary(loc, srcTy); @@ -823,6 +824,30 @@ class CUFOpConversion : public fir::impl::CUFOpConversionBase { "error in CUF op conversion\n"); signalPassFailure(); } + + target.addDynamicallyLegalOp([&](fir::DeclareOp op) { + if (inDeviceContext(op)) + return true; + if (auto addrOfOp = op.getMemref().getDefiningOp()) { + if (auto global = symtab.lookup( + addrOfOp.getSymbol().getRootReference().getValue())) { + if (mlir::isa(fir::unwrapRefType(global.getType()))) + return true; + if (isDeviceGlobal(global)) + return false; + } + } + return true; + }); + + patterns.clear(); + cuf::populateFIRCUFConversionPatterns(symtab, patterns); + if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, + std::move(patterns)))) { + mlir::emitError(mlir::UnknownLoc::get(ctx), + "error in CUF op conversion\n"); + signalPassFailure(); + } } }; } // namespace @@ -837,3 +862,8 @@ void cuf::populateCUFToFIRConversionPatterns( &dl, &converter); patterns.insert(patterns.getContext(), symtab); } + +void cuf::populateFIRCUFConversionPatterns(const mlir::SymbolTable &symtab, + mlir::RewritePatternSet &patterns) { + patterns.insert(patterns.getContext(), symtab); +} diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir index b371d39777728..7203c33e7eb11 100644 --- a/flang/test/Fir/CUDA/cuda-data-transfer.fir +++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir @@ -199,12 +199,12 @@ func.func @_QPsub8() attributes {fir.bindc_name = "t"} { // CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.array<5xi32> // CHECK: %[[LOCAL:.*]] = fir.declare %[[ALLOCA]] // CHECK: %[[GBL:.*]] = fir.address_of(@_QMmtestsEn) : !fir.ref> -// CHECK: %[[DECL:.*]] = fir.declare %[[GBL]] -// CHECK: %[[HOST:.*]] = fir.convert %[[DECL]] : (!fir.ref>) -> !fir.llvm_ptr -// CHECK: %[[SRC:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[HOST]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr, !fir.ref, i32) -> !fir.llvm_ptr -// CHECK: %[[SRC_CONV:.*]] = fir.convert %[[SRC]] : (!fir.llvm_ptr) -> !fir.ref> +// CHECK: %[[GBL_CONV:.*]] = fir.convert %[[GBL]] : (!fir.ref>) -> !fir.llvm_ptr +// CHECK: %[[ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[GBL_CONV]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr, !fir.ref, i32) -> !fir.llvm_ptr +// CHECK: %[[ADDR_CONV:.*]] = fir.convert %[[ADDR]] : (!fir.llvm_ptr) -> !fir.ref> +// CHECK: %[[DECL:.*]] = fir.declare %[[ADDR_CONV]] // CHECK: %[[DST:.*]] = fir.convert %[[LOCAL]] : (!fir.ref>) -> !fir.llvm_ptr -// CHECK: %[[SRC:.*]] = fir.convert %[[SRC_CONV]] : (!fir.ref>) -> !fir.llvm_ptr +// CHECK: %[[SRC:.*]] = fir.convert %[[DECL]] : (!fir.ref>) -> !fir.llvm_ptr // CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr, !fir.llvm_ptr, i64, i32, !fir.ref, i32) -> none @@ -223,11 +223,11 @@ func.func @_QPsub9() { // CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.array<5xi32> // CHECK: %[[LOCAL:.*]] = fir.declare %[[ALLOCA]] // CHECK: %[[GBL:.*]] = fir.address_of(@_QMmtestsEn) : !fir.ref> -// CHECK: %[[DECL:.*]] = fir.declare %[[GBL]] -// CHECK: %[[HOST:.*]] = fir.convert %[[DECL]] : (!fir.ref>) -> !fir.llvm_ptr -// CHECK: %[[DST:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[HOST]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr, !fir.ref, i32) -> !fir.llvm_ptr -// CHECK: %[[DST_CONV:.*]] = fir.convert %[[DST]] : (!fir.llvm_ptr) -> !fir.ref> -// CHECK: %[[DST:.*]] = fir.convert %[[DST_CONV]] : (!fir.ref>) -> !fir.llvm_ptr +// CHECK: %[[GBL_CONV:.*]] = fir.convert %[[GBL]] : (!fir.ref>) -> !fir.llvm_ptr +// CHECK: %[[ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[GBL_CONV]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr, !fir.ref, i32) -> !fir.llvm_ptr +// CHECK: %[[ADDR_CONV:.*]] = fir.convert %[[ADDR]] : (!fir.llvm_ptr) -> !fir.ref> +// CHECK: %[[DECL:.*]] = fir.declare %[[ADDR_CONV]] +// CHECK: %[[DST:.*]] = fir.convert %[[DECL]] : (!fir.ref>) -> !fir.llvm_ptr // CHECK: %[[SRC:.*]] = fir.convert %[[LOCAL]] : (!fir.ref>) -> !fir.llvm_ptr // CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr, !fir.llvm_ptr, i64, i32, !fir.ref, i32) -> none @@ -380,9 +380,12 @@ func.func @_QPdevice_addr_conv() { } // CHECK-LABEL: func.func @_QPdevice_addr_conv() -// CHECK: %[[DEV_ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr, !fir.ref, i32) -> !fir.llvm_ptr -// CHECK: %[[DEV_ADDR_CONV:.*]] = fir.convert %[[DEV_ADDR]] : (!fir.llvm_ptr) -> !fir.ref> -// CHECK: fir.embox %[[DEV_ADDR_CONV]](%{{.*}}) : (!fir.ref>, !fir.shape<1>) -> !fir.box> +// CHECK: %[[GBL:.*]] = fir.address_of(@_QMmod1Ea_dev) : !fir.ref> +// CHECK: %[[GBL_CONV:.*]] = fir.convert %[[GBL]] : (!fir.ref>) -> !fir.llvm_ptr +// CHECK: %[[ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[GBL_CONV]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr, !fir.ref, i32) -> !fir.llvm_ptr +// CHECK: %[[ADDR_CONV:.*]] = fir.convert %[[ADDR]] : (!fir.llvm_ptr) -> !fir.ref> +// CHECK: %[[DECL:.*]] = fir.declare %[[ADDR_CONV]](%{{.*}}) {data_attr = #cuf.cuda, uniq_name = "_QMmod1Ea_dev"} : (!fir.ref>, !fir.shape<1>) -> !fir.ref> +// CHECK: fir.embox %[[DECL]](%{{.*}}) : (!fir.ref>, !fir.shape<1>) -> !fir.box> // CHECK: fir.call @_FortranACUFDataTransferCstDesc func.func @_QQchar_transfer() attributes {fir.bindc_name = "char_transfer"} { diff --git a/flang/test/Fir/CUDA/cuda-global-addr.mlir b/flang/test/Fir/CUDA/cuda-global-addr.mlir new file mode 100644 index 0000000000000..2baead4010f5c --- /dev/null +++ b/flang/test/Fir/CUDA/cuda-global-addr.mlir @@ -0,0 +1,36 @@ +// RUN: fir-opt --cuf-convert %s | FileCheck %s + +module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry, dense<64> : vector<4xi64>>, #dlti.dl_entry, dense<32> : vector<4xi64>>, #dlti.dl_entry, dense<32> : vector<4xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<4xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>} { +fir.global @_QMmod1Eadev {data_attr = #cuf.cuda} : !fir.array<10xi32> { + %0 = fir.zero_bits !fir.array<10xi32> + fir.has_value %0 : !fir.array<10xi32> +} +func.func @_QQmain() attributes {fir.bindc_name = "test"} { + %c14_i32 = arith.constant 14 : i32 + %c6_i32 = arith.constant 6 : i32 + %c4 = arith.constant 4 : index + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %c10 = arith.constant 10 : index + %1 = fir.shape %c10 : (index) -> !fir.shape<1> + %3 = fir.address_of(@_QMmod1Eadev) : !fir.ref> + %4 = fir.declare %3(%1) {data_attr = #cuf.cuda, uniq_name = "_QMmod1Eadev"} : (!fir.ref>, !fir.shape<1>) -> !fir.ref> + %5 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFEi"} + %6 = fir.declare %5 {uniq_name = "_QFEi"} : (!fir.ref) -> !fir.ref + fir.store %c0_i32 to %6 : !fir.ref + %7 = fir.array_coor %4(%1) %c4 : (!fir.ref>, !fir.shape<1>, index) -> !fir.ref + cuf.data_transfer %c1_i32 to %7 {transfer_kind = #cuf.cuda_transfer} : i32, !fir.ref + return +} + +} + +// CHECK-LABEL: func.func @_QQmain() +// CHECK: %[[ADDR:.*]] = fir.address_of(@_QMmod1Eadev) : !fir.ref> +// CHECK: %[[ADDRPTR:.*]] = fir.convert %[[ADDR]] : (!fir.ref>) -> !fir.llvm_ptr +// CHECK: %[[DEVICE_ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[ADDRPTR]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr, !fir.ref, i32) -> !fir.llvm_ptr +// CHECK: %[[DEVICE_ADDR_CONV:.*]] = fir.convert %[[DEVICE_ADDR]] : (!fir.llvm_ptr) -> !fir.ref> +// CHECK: %[[DECL:.*]] = fir.declare %[[DEVICE_ADDR_CONV]](%{{.*}}) {data_attr = #cuf.cuda, uniq_name = "_QMmod1Eadev"} : (!fir.ref>, !fir.shape<1>) -> !fir.ref> +// CHECK: %[[ARRAY_COOR:.*]] = fir.array_coor %[[DECL]](%{{.*}}) %c4{{.*}} : (!fir.ref>, !fir.shape<1>, index) -> !fir.ref +// CHECK: %[[ARRAY_COOR_PTR:.*]] = fir.convert %[[ARRAY_COOR]] : (!fir.ref) -> !fir.llvm_ptr +// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[ARRAY_COOR_PTR]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr, !fir.llvm_ptr, i64, i32, !fir.ref, i32) -> none