diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 4f6200d29a70a..7ffdfe15ac0d9 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -244,6 +244,92 @@ static llvm::BasicBlock *convertOmpOpRegions( return continuationBlock; } +/// Finds the set of \c llvm.alloca instructions associated to \c LLVM::AllocaOp +/// MLIR operations for primitive types that are defined outside of the given +/// \p region but only used inside of it. +static void +gatherSinkableAllocas(const LLVM::ModuleTranslation &moduleTranslation, + Region ®ion, + SetVector &allocasToSink) { + Operation *op = region.getParentOp(); + + auto processLoadStore = [&](auto loadStoreOp) { + Value addr = loadStoreOp.getAddr(); + Operation *addrOp = addr.getDefiningOp(); + + // The destination address is already defined in this region or it is not an + // llvm.alloca operation, so skip it. + if (!isa_and_present(addrOp) || op->isAncestor(addrOp)) + return; + + // Get LLVM value to which the address is mapped. It has to be mapped to the + // allocation instruction of a scalar type to be marked as sinkable by this + // function. + llvm::Value *llvmAddr = moduleTranslation.lookupValue(addr); + if (!isa_and_present(llvmAddr)) + return; + + auto *llvmAlloca = cast(llvmAddr); + if (llvmAlloca->getAllocatedType()->getPrimitiveSizeInBits() == 0) + return; + + // Check that the address is only used inside of the region. + bool addressUsedOnlyInternally = true; + for (auto &addrUse : addr.getUses()) { + if (!op->isAncestor(addrUse.getOwner())) { + addressUsedOnlyInternally = false; + break; + } + } + + if (!addressUsedOnlyInternally) + return; + + allocasToSink.insert(llvmAlloca); + }; + + region.walk([&processLoadStore](Operation *op) { + if (auto loadOp = dyn_cast(op)) + processLoadStore(loadOp); + else if (auto storeOp = dyn_cast(op)) + processLoadStore(storeOp); + }); +} + +/// Converts the given region that appears within an OpenMP dialect operation to +/// LLVM IR, according to the process described in \c convertOmpOpRegions(), and +/// marks the lifetime of allocas read/written exclusively inside of the region +/// but defined outside of it. +/// +/// This information enables later compilation stages to sink these allocations +/// inside of the region, such as when outlining it into a separate function. +static llvm::BasicBlock *convertOmpOpRegionsWithAllocaLifetimes( + Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation, LogicalResult &bodyGenStatus) { + SetVector allocasToSink; + gatherSinkableAllocas(moduleTranslation, region, allocasToSink); + + for (auto *alloca : allocasToSink) { + unsigned size = alloca->getAllocatedType()->getPrimitiveSizeInBits() / 8; + builder.CreateLifetimeStart(alloca, builder.getInt64(size)); + } + + llvm::BasicBlock *continuationBlock = convertOmpOpRegions( + region, blockName, builder, moduleTranslation, bodyGenStatus); + + if (!allocasToSink.empty()) { + llvm::IRBuilderBase::InsertPointGuard guard(builder); + builder.SetInsertPoint(continuationBlock, continuationBlock->begin()); + + for (auto *alloca : allocasToSink) { + unsigned size = alloca->getAllocatedType()->getPrimitiveSizeInBits() / 8; + builder.CreateLifetimeEnd(alloca, builder.getInt64(size)); + } + } + + return continuationBlock; +} + /// Convert ProcBindKind from MLIR-generated enum to LLVM enum. static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) { switch (kind) { @@ -910,8 +996,9 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder, // Convert the body of the loop. builder.restoreIP(ip); - convertOmpOpRegions(loop.getRegion(), "omp.wsloop.region", builder, - moduleTranslation, bodyGenStatus); + convertOmpOpRegionsWithAllocaLifetimes(loop.getRegion(), + "omp.wsloop.region", builder, + moduleTranslation, bodyGenStatus); }; // Delegate actual loop construction to the OpenMP IRBuilder. @@ -1151,8 +1238,9 @@ convertOmpSimdLoop(Operation &opInst, llvm::IRBuilderBase &builder, // Convert the body of the loop. builder.restoreIP(ip); - convertOmpOpRegions(loop.getRegion(), "omp.simdloop.region", builder, - moduleTranslation, bodyGenStatus); + convertOmpOpRegionsWithAllocaLifetimes(loop.getRegion(), + "omp.simdloop.region", builder, + moduleTranslation, bodyGenStatus); }; // Delegate actual loop construction to the OpenMP IRBuilder. diff --git a/mlir/test/Target/LLVMIR/openmp-alloca-lifetime.mlir b/mlir/test/Target/LLVMIR/openmp-alloca-lifetime.mlir new file mode 100644 index 0000000000000..d9229bd24c71d --- /dev/null +++ b/mlir/test/Target/LLVMIR/openmp-alloca-lifetime.mlir @@ -0,0 +1,124 @@ +// This test checks the introduction of lifetime information for allocas defined +// outside of omp.wsloop and omp.simdloop regions but only used inside of them. + +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +llvm.func @foo(%arg0 : i32) { + llvm.return +} + +llvm.func @bar(%arg0 : i64) { + llvm.return +} + +// CHECK-LABEL: define void @wsloop_i32 +llvm.func @wsloop_i32(%size : i64, %lb : i32, %ub : i32, %step : i32) { + // CHECK-DAG: %[[LASTITER:.*]] = alloca i32 + // CHECK-DAG: %[[LB:.*]] = alloca i32 + // CHECK-DAG: %[[UB:.*]] = alloca i32 + // CHECK-DAG: %[[STRIDE:.*]] = alloca i32 + // CHECK-DAG: %[[I:.*]] = alloca i32 + %1 = llvm.alloca %size x i32 : (i64) -> !llvm.ptr + + // CHECK-NOT: %[[I]] + // CHECK: call void @llvm.lifetime.start.p0(i64 4, ptr %[[I]]) + // CHECK-NEXT: br label %[[WSLOOP_BB:.*]] + // CHECK-NOT: %[[I]] + // CHECK: [[WSLOOP_BB]]: + // CHECK-NOT: {{^.*}}: + // CHECK: br label %[[CONT_BB:.*]] + // CHECK: [[CONT_BB]]: + // CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 4, ptr %[[I]]) + // CHECK-NOT: %[[I]] + omp.wsloop for (%iv) : i32 = (%lb) to (%ub) step (%step) { + llvm.store %iv, %1 : i32, !llvm.ptr + %2 = llvm.load %1 : !llvm.ptr -> i32 + llvm.call @foo(%2) : (i32) -> () + omp.yield + } + + // CHECK: ret void + llvm.return +} + +// CHECK-LABEL: define void @wsloop_i64 +llvm.func @wsloop_i64(%size : i64, %lb : i64, %ub : i64, %step : i64) { + // CHECK-DAG: %[[LASTITER:.*]] = alloca i32 + // CHECK-DAG: %[[LB:.*]] = alloca i64 + // CHECK-DAG: %[[UB:.*]] = alloca i64 + // CHECK-DAG: %[[STRIDE:.*]] = alloca i64 + // CHECK-DAG: %[[I:.*]] = alloca i64 + %1 = llvm.alloca %size x i64 : (i64) -> !llvm.ptr + + // CHECK-NOT: %[[I]] + // CHECK: call void @llvm.lifetime.start.p0(i64 8, ptr %[[I]]) + // CHECK-NEXT: br label %[[WSLOOP_BB:.*]] + // CHECK-NOT: %[[I]] + // CHECK: [[WSLOOP_BB]]: + // CHECK-NOT: {{^.*}}: + // CHECK: br label %[[CONT_BB:.*]] + // CHECK: [[CONT_BB]]: + // CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 8, ptr %[[I]]) + // CHECK-NOT: %[[I]] + omp.wsloop for (%iv) : i64 = (%lb) to (%ub) step (%step) { + llvm.store %iv, %1 : i64, !llvm.ptr + %2 = llvm.load %1 : !llvm.ptr -> i64 + llvm.call @bar(%2) : (i64) -> () + omp.yield + } + + // CHECK: ret void + llvm.return +} + +// CHECK-LABEL: define void @simdloop_i32 +llvm.func @simdloop_i32(%size : i64, %lb : i32, %ub : i32, %step : i32) { + // CHECK: %[[I:.*]] = alloca i32 + %1 = llvm.alloca %size x i32 : (i64) -> !llvm.ptr + + // CHECK-NOT: %[[I]] + // CHECK: call void @llvm.lifetime.start.p0(i64 4, ptr %[[I]]) + // CHECK-NEXT: br label %[[SIMDLOOP_BB:.*]] + // CHECK-NOT: %[[I]] + // CHECK: [[SIMDLOOP_BB]]: + // CHECK-NOT: {{^.*}}: + // CHECK: br label %[[CONT_BB:.*]] + // CHECK: [[CONT_BB]]: + // CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 4, ptr %[[I]]) + // CHECK-NOT: %[[I]] + omp.simdloop for (%iv) : i32 = (%lb) to (%ub) step (%step) { + llvm.store %iv, %1 : i32, !llvm.ptr + %2 = llvm.load %1 : !llvm.ptr -> i32 + llvm.call @foo(%2) : (i32) -> () + omp.yield + } + + // CHECK: ret void + llvm.return +} + +// CHECK-LABEL: define void @simdloop_i64 +llvm.func @simdloop_i64(%size : i64, %lb : i64, %ub : i64, %step : i64) { + // CHECK: %[[I:.*]] = alloca i64 + %1 = llvm.alloca %size x i64 : (i64) -> !llvm.ptr + + // CHECK-NOT: %[[I]] + // CHECK: call void @llvm.lifetime.start.p0(i64 8, ptr %[[I]]) + // CHECK-NEXT: br label %[[SIMDLOOP_BB:.*]] + // CHECK-NOT: %[[I]] + // CHECK: [[SIMDLOOP_BB]]: + // CHECK-NOT: {{^.*}}: + // CHECK: br label %[[CONT_BB:.*]] + // CHECK: [[CONT_BB]]: + // CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 8, ptr %[[I]]) + // CHECK-NOT: %[[I]] + omp.simdloop for (%iv) : i64 = (%lb) to (%ub) step (%step) { + llvm.store %iv, %1 : i64, !llvm.ptr + %2 = llvm.load %1 : !llvm.ptr -> i64 + llvm.call @bar(%2) : (i64) -> () + omp.yield + } + + // CHECK: ret void + llvm.return +}