Skip to content

Commit 95d6caa

Browse files
authored
[flang][cuda] Add interfaces and lowering for atomicaddvector (#166275)
1 parent 57730f6 commit 95d6caa

File tree

4 files changed

+86
-0
lines changed

4 files changed

+86
-0
lines changed

flang/include/flang/Optimizer/Builder/IntrinsicCall.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,8 @@ struct IntrinsicLibrary {
190190
mlir::Value genAtomicAdd(mlir::Type, llvm::ArrayRef<mlir::Value>);
191191
fir::ExtendedValue genAtomicAddR2(mlir::Type,
192192
llvm::ArrayRef<fir::ExtendedValue>);
193+
fir::ExtendedValue genAtomicAddVector(mlir::Type,
194+
llvm::ArrayRef<fir::ExtendedValue>);
193195
mlir::Value genAtomicAnd(mlir::Type, llvm::ArrayRef<mlir::Value>);
194196
fir::ExtendedValue genAtomicCas(mlir::Type,
195197
llvm::ArrayRef<fir::ExtendedValue>);

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,14 @@ static constexpr IntrinsicHandler handlers[]{
290290
{"atan2pi", &I::genAtanpi},
291291
{"atand", &I::genAtand},
292292
{"atanpi", &I::genAtanpi},
293+
{"atomicadd_r2x2",
294+
&I::genAtomicAddVector,
295+
{{{"a", asAddr}, {"v", asAddr}}},
296+
false},
297+
{"atomicadd_r4x2",
298+
&I::genAtomicAddVector,
299+
{{{"a", asAddr}, {"v", asAddr}}},
300+
false},
293301
{"atomicaddd", &I::genAtomicAdd, {{{"a", asAddr}, {"v", asValue}}}, false},
294302
{"atomicaddf", &I::genAtomicAdd, {{{"a", asAddr}, {"v", asValue}}}, false},
295303
{"atomicaddi", &I::genAtomicAdd, {{{"a", asAddr}, {"v", asValue}}}, false},
@@ -3168,6 +3176,47 @@ IntrinsicLibrary::genAtomicAddR2(mlir::Type resultType,
31683176
mlir::ArrayRef<int64_t>{0});
31693177
}
31703178

3179+
fir::ExtendedValue
3180+
IntrinsicLibrary::genAtomicAddVector(mlir::Type resultType,
3181+
llvm::ArrayRef<fir::ExtendedValue> args) {
3182+
assert(args.size() == 2);
3183+
mlir::Value res = fir::AllocaOp::create(
3184+
builder, loc, fir::SequenceType::get({2}, resultType));
3185+
mlir::Value a = fir::getBase(args[0]);
3186+
if (mlir::isa<fir::BaseBoxType>(a.getType())) {
3187+
a = fir::BoxAddrOp::create(builder, loc, a);
3188+
}
3189+
auto vecTy = mlir::VectorType::get({2}, resultType);
3190+
auto refTy = fir::ReferenceType::get(resultType);
3191+
mlir::Type i32Ty = builder.getI32Type();
3192+
mlir::Type idxTy = builder.getIndexType();
3193+
mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0);
3194+
mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
3195+
mlir::Value v1Coord = fir::CoordinateOp::create(builder, loc, refTy,
3196+
fir::getBase(args[1]), zero);
3197+
mlir::Value v2Coord = fir::CoordinateOp::create(builder, loc, refTy,
3198+
fir::getBase(args[1]), one);
3199+
mlir::Value v1 = fir::LoadOp::create(builder, loc, v1Coord);
3200+
mlir::Value v2 = fir::LoadOp::create(builder, loc, v2Coord);
3201+
mlir::Value undef = mlir::LLVM::UndefOp::create(builder, loc, vecTy);
3202+
mlir::Value vec1 = mlir::LLVM::InsertElementOp::create(
3203+
builder, loc, undef, v1, builder.createIntegerConstant(loc, i32Ty, 0));
3204+
mlir::Value vec2 = mlir::LLVM::InsertElementOp::create(
3205+
builder, loc, vec1, v2, builder.createIntegerConstant(loc, i32Ty, 1));
3206+
mlir::Value add =
3207+
genAtomBinOp(builder, loc, mlir::LLVM::AtomicBinOp::fadd, a, vec2);
3208+
mlir::Value r1 = mlir::LLVM::ExtractElementOp::create(
3209+
builder, loc, add, builder.createIntegerConstant(loc, i32Ty, 0));
3210+
mlir::Value r2 = mlir::LLVM::ExtractElementOp::create(
3211+
builder, loc, add, builder.createIntegerConstant(loc, i32Ty, 1));
3212+
mlir::Value c1 = fir::CoordinateOp::create(builder, loc, refTy, res, zero);
3213+
mlir::Value c2 = fir::CoordinateOp::create(builder, loc, refTy, res, one);
3214+
fir::StoreOp::create(builder, loc, r1, c1);
3215+
fir::StoreOp::create(builder, loc, r2, c2);
3216+
mlir::Value ext = builder.createIntegerConstant(loc, idxTy, 2);
3217+
return fir::ArrayBoxValue(res, {ext});
3218+
}
3219+
31713220
mlir::Value IntrinsicLibrary::genAtomicSub(mlir::Type resultType,
31723221
llvm::ArrayRef<mlir::Value> args) {
31733222
assert(args.size() == 2);

flang/module/cudadevice.f90

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1178,6 +1178,22 @@ attributes(device) pure integer(4) function atomicaddr2(address, val)
11781178
end function
11791179
end interface
11801180

1181+
interface atomicaddvector
1182+
attributes(device) pure function atomicadd_r2x2(address, val) result(z)
1183+
!dir$ ignore_tkr (rd) address, (d) val
1184+
real(2), dimension(2), intent(inout) :: address
1185+
real(2), dimension(2), intent(in) :: val
1186+
real(2), dimension(2) :: z
1187+
end function
1188+
1189+
attributes(device) pure function atomicadd_r4x2(address, val) result(z)
1190+
!dir$ ignore_tkr (rd) address, (d) val
1191+
real(4), dimension(2), intent(inout) :: address
1192+
real(4), dimension(2), intent(in) :: val
1193+
real(4), dimension(2) :: z
1194+
end function
1195+
end interface
1196+
11811197
interface atomicsub
11821198
attributes(device) pure integer function atomicsubi(address, val)
11831199
!dir$ ignore_tkr (d) address, (d) val
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
2+
3+
! Test CUDA Fortran atomicadd functions available in the cudadevice module
4+
5+
attributes(global) subroutine atomicaddvector_r2()
6+
real(2), device :: a(2), tmp1(2), tmp2(2)
7+
tmp1 = atomicAddVector(a, tmp2)
8+
end subroutine
9+
10+
! CHECK-LABEL: func.func @_QPatomicaddvector_r2() attributes {cuf.proc_attr = #cuf.cuda_proc<global>}
11+
! CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, vector<2xf16>
12+
13+
attributes(global) subroutine atomicaddvector_r4()
14+
real(4), device :: a(2), tmp1(2), tmp2(2)
15+
tmp1 = atomicAddVector(a, tmp2)
16+
end subroutine
17+
18+
! CHECK-LABEL: func.func @_QPatomicaddvector_r4() attributes {cuf.proc_attr = #cuf.cuda_proc<global>}
19+
! CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, vector<2xf32>

0 commit comments

Comments
 (0)