Skip to content

Commit 1c95c7a

Browse files
authored
[flang][cuda] Add interfaces and lowering for barrier_arrive (llvm#162949)
1 parent 6a0e5b2 commit 1c95c7a

File tree

4 files changed

+80
-9
lines changed

4 files changed

+80
-9
lines changed

flang/include/flang/Optimizer/Builder/IntrinsicCall.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,8 @@ struct IntrinsicLibrary {
208208
fir::ExtendedValue genAssociated(mlir::Type,
209209
llvm::ArrayRef<fir::ExtendedValue>);
210210
mlir::Value genAtand(mlir::Type, llvm::ArrayRef<mlir::Value>);
211+
mlir::Value genBarrierArrive(mlir::Type, llvm::ArrayRef<mlir::Value>);
212+
mlir::Value genBarrierArriveCnt(mlir::Type, llvm::ArrayRef<mlir::Value>);
211213
void genBarrierInit(llvm::ArrayRef<fir::ExtendedValue>);
212214
fir::ExtendedValue genBesselJn(mlir::Type,
213215
llvm::ArrayRef<fir::ExtendedValue>);

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,14 @@ static constexpr IntrinsicHandler handlers[]{
346346
&I::genVoteSync<mlir::NVVM::VoteSyncKind::ballot>,
347347
{{{"mask", asValue}, {"pred", asValue}}},
348348
/*isElemental=*/false},
349+
{"barrier_arrive",
350+
&I::genBarrierArrive,
351+
{{{"barrier", asAddr}}},
352+
/*isElemental=*/false},
353+
{"barrier_arrive_cnt",
354+
&I::genBarrierArriveCnt,
355+
{{{"barrier", asAddr}, {"count", asValue}}},
356+
/*isElemental=*/false},
349357
{"barrier_init",
350358
&I::genBarrierInit,
351359
{{{"barrier", asAddr}, {"count", asValue}}},
@@ -3180,19 +3188,53 @@ IntrinsicLibrary::genAssociated(mlir::Type resultType,
31803188
return fir::runtime::genAssociated(builder, loc, pointerBox, targetBox);
31813189
}
31823190

3183-
// BARRIER_INIT (CUDA)
3184-
void IntrinsicLibrary::genBarrierInit(llvm::ArrayRef<fir::ExtendedValue> args) {
3185-
assert(args.size() == 2);
3186-
auto llvmPtr = fir::ConvertOp::create(
3191+
static mlir::Value convertBarrierToLLVM(fir::FirOpBuilder &builder,
3192+
mlir::Location loc,
3193+
mlir::Value barrier) {
3194+
mlir::Value llvmPtr = fir::ConvertOp::create(
31873195
builder, loc, mlir::LLVM::LLVMPointerType::get(builder.getContext()),
3188-
fir::getBase(args[0]));
3189-
auto addrCast = mlir::LLVM::AddrSpaceCastOp::create(
3196+
barrier);
3197+
mlir::Value addrCast = mlir::LLVM::AddrSpaceCastOp::create(
31903198
builder, loc,
31913199
mlir::LLVM::LLVMPointerType::get(
31923200
builder.getContext(),
31933201
static_cast<unsigned>(mlir::NVVM::NVVMMemorySpace::Shared)),
31943202
llvmPtr);
3195-
mlir::NVVM::MBarrierInitSharedOp::create(builder, loc, addrCast,
3203+
return addrCast;
3204+
}
3205+
3206+
// BARRIER_ARRIVE (CUDA)
3207+
mlir::Value
3208+
IntrinsicLibrary::genBarrierArrive(mlir::Type resultType,
3209+
llvm::ArrayRef<mlir::Value> args) {
3210+
assert(args.size() == 1);
3211+
mlir::Value barrier = convertBarrierToLLVM(builder, loc, args[0]);
3212+
return mlir::NVVM::MBarrierArriveSharedOp::create(builder, loc, resultType,
3213+
barrier)
3214+
.getResult();
3215+
}
3216+
3217+
// BARRIER_ARRIVE_CNT (CUDA)
3218+
mlir::Value
3219+
IntrinsicLibrary::genBarrierArriveCnt(mlir::Type resultType,
3220+
llvm::ArrayRef<mlir::Value> args) {
3221+
assert(args.size() == 2);
3222+
mlir::Value barrier = convertBarrierToLLVM(builder, loc, args[0]);
3223+
mlir::Value token = fir::AllocaOp::create(builder, loc, resultType);
3224+
// TODO: the MBarrierArriveExpectTxOp is not taking the state argument and
3225+
// currently just uses the sink symbol `_`.
3226+
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive
3227+
mlir::NVVM::MBarrierArriveExpectTxOp::create(builder, loc, barrier, args[1],
3228+
{});
3229+
return fir::LoadOp::create(builder, loc, token);
3230+
}
3231+
3232+
// BARRIER_INIT (CUDA)
3233+
void IntrinsicLibrary::genBarrierInit(llvm::ArrayRef<fir::ExtendedValue> args) {
3234+
assert(args.size() == 2);
3235+
mlir::Value barrier =
3236+
convertBarrierToLLVM(builder, loc, fir::getBase(args[0]));
3237+
mlir::NVVM::MBarrierInitSharedOp::create(builder, loc, barrier,
31963238
fir::getBase(args[1]), {});
31973239
}
31983240

flang/module/cudadevice.f90

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1987,13 +1987,27 @@ attributes(device,host) logical function on_device() bind(c)
19871987
end function
19881988
end interface
19891989

1990+
! TMA Operations
1991+
19901992
interface
19911993
attributes(device) subroutine barrier_init(barrier, count)
1992-
integer(8) :: barrier
1994+
integer(8), shared :: barrier
19931995
integer(4) :: count
19941996
end subroutine
19951997
end interface
19961998

1999+
interface barrier_arrive
2000+
attributes(device) function barrier_arrive(barrier) result(token)
2001+
integer(8), shared :: barrier
2002+
integer(8) :: token
2003+
end function
2004+
attributes(device) function barrier_arrive_cnt(barrier, count) result(token)
2005+
integer(8), shared :: barrier
2006+
integer(4) :: count
2007+
integer(8) :: token
2008+
end function
2009+
end interface
2010+
19972011
contains
19982012

19992013
attributes(device) subroutine syncthreads()

flang/test/Lower/CUDA/cuda-device-proc.cuf

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,9 +394,14 @@ end subroutine
394394

395395
attributes(global) subroutine test_barrier()
396396
integer(8), shared :: barrier
397+
integer(8) :: token
398+
integer :: count
397399
call barrier_init(barrier, 256)
398-
end subroutine
399400

401+
token = barrier_arrive(barrier)
402+
403+
token = barrier_arrive(barrier, count)
404+
end subroutine
400405

401406
! CHECK-LABEL: func.func @_QPtest_barrier()
402407

@@ -406,3 +411,11 @@ end subroutine
406411
! CHECK: %[[LLVM_PTR:.*]] = fir.convert %[[DECL_SHARED]]#0 : (!fir.ref<i64>) -> !llvm.ptr
407412
! CHECK: %[[SHARED_PTR:.*]] = llvm.addrspacecast %[[LLVM_PTR]] : !llvm.ptr to !llvm.ptr<3>
408413
! CHECK: nvvm.mbarrier.init.shared %[[SHARED_PTR]], %[[COUNT]] : !llvm.ptr<3>, i32
414+
415+
! CHECK: %[[LLVM_PTR:.*]] = fir.convert %[[DECL_SHARED]]#0 : (!fir.ref<i64>) -> !llvm.ptr
416+
! CHECK: %[[SHARED_PTR:.*]] = llvm.addrspacecast %[[LLVM_PTR]] : !llvm.ptr to !llvm.ptr<3>
417+
! CHECK: %{{.*}} = nvvm.mbarrier.arrive.shared %[[SHARED_PTR]] : !llvm.ptr<3> -> i64
418+
419+
! CHECK: %[[LLVM_PTR:.*]] = fir.convert %[[DECL_SHARED]]#0 : (!fir.ref<i64>) -> !llvm.ptr
420+
! CHECK: %[[SHARED_PTR:.*]] = llvm.addrspacecast %[[LLVM_PTR]] : !llvm.ptr to !llvm.ptr<3>
421+
! CHECK: nvvm.mbarrier.arrive.expect_tx %[[SHARED_PTR]], %{{.*}} : !llvm.ptr<3>, i32

0 commit comments

Comments
 (0)