Skip to content

Commit b6b0f97

Browse files
authored
[flang][OpenMP] Support reduction of POINTER variables (#95148)
Just treat them the same as ALLOCATABLE. gfortran doesn't allow POINTER objects in a REDUCTION clause, but so far as I can tell the standard explicitly allows it (openmp5.2 section 5.5.5).
1 parent d62ff71 commit b6b0f97

File tree

3 files changed

+255
-8
lines changed

3 files changed

+255
-8
lines changed

flang/lib/Lower/OpenMP/ReductionProcessor.cpp

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -332,15 +332,17 @@ static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
332332
fir::unwrapRefType(boxTy.getEleTy()));
333333
fir::HeapType heapTy =
334334
mlir::dyn_cast_or_null<fir::HeapType>(boxTy.getEleTy());
335-
if ((!seqTy || seqTy.hasUnknownShape()) && !heapTy)
335+
fir::PointerType ptrTy =
336+
mlir::dyn_cast_or_null<fir::PointerType>(boxTy.getEleTy());
337+
if ((!seqTy || seqTy.hasUnknownShape()) && !heapTy && !ptrTy)
336338
TODO(loc, "Unsupported boxed type in OpenMP reduction");
337339

338340
// load fir.ref<fir.box<...>>
339341
mlir::Value lhsAddr = lhs;
340342
lhs = builder.create<fir::LoadOp>(loc, lhs);
341343
rhs = builder.create<fir::LoadOp>(loc, rhs);
342344

343-
if (heapTy && !seqTy) {
345+
if ((heapTy || ptrTy) && !seqTy) {
344346
// get box contents (heap pointers)
345347
lhs = builder.create<fir::BoxAddrOp>(loc, lhs);
346348
rhs = builder.create<fir::BoxAddrOp>(loc, rhs);
@@ -350,8 +352,10 @@ static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
350352
lhs = builder.create<fir::LoadOp>(loc, lhs);
351353
rhs = builder.create<fir::LoadOp>(loc, rhs);
352354

355+
mlir::Type eleTy = heapTy ? heapTy.getEleTy() : ptrTy.getEleTy();
356+
353357
mlir::Value result = ReductionProcessor::createScalarCombiner(
354-
builder, loc, redId, heapTy.getEleTy(), lhs, rhs);
358+
builder, loc, redId, eleTy, lhs, rhs);
355359
builder.create<fir::StoreOp>(loc, result, lhsValAddr);
356360
builder.create<mlir::omp::YieldOp>(loc, lhsAddr);
357361
return;
@@ -439,7 +443,7 @@ createReductionCleanupRegion(fir::FirOpBuilder &builder, mlir::Location loc,
439443

440444
mlir::Type valTy = fir::unwrapRefType(redTy);
441445
if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(valTy)) {
442-
if (!mlir::isa<fir::HeapType>(boxTy.getEleTy())) {
446+
if (!mlir::isa<fir::HeapType, fir::PointerType>(boxTy.getEleTy())) {
443447
mlir::Type innerTy = fir::extractSequenceType(boxTy);
444448
if (!mlir::isa<fir::SequenceType>(innerTy))
445449
typeError();
@@ -533,12 +537,13 @@ createReductionInitRegion(fir::FirOpBuilder &builder, mlir::Location loc,
533537
// all arrays are boxed
534538
if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) {
535539
assert(isByRef && "passing boxes by value is unsupported");
536-
bool isAllocatable = mlir::isa<fir::HeapType>(boxTy.getEleTy());
540+
bool isAllocatableOrPointer =
541+
mlir::isa<fir::HeapType, fir::PointerType>(boxTy.getEleTy());
537542
mlir::Value boxAlloca = builder.create<fir::AllocaOp>(loc, ty);
538543
mlir::Type innerTy = fir::unwrapRefType(boxTy.getEleTy());
539544
if (fir::isa_trivial(innerTy)) {
540545
// boxed non-sequence value e.g. !fir.box<!fir.heap<i32>>
541-
if (!isAllocatable)
546+
if (!isAllocatableOrPointer)
542547
TODO(loc, "Reduction of non-allocatable trivial typed box");
543548

544549
fir::IfOp ifUnallocated = handleNullAllocatable(boxAlloca);
@@ -560,7 +565,7 @@ createReductionInitRegion(fir::FirOpBuilder &builder, mlir::Location loc,
560565
TODO(loc, "Unsupported boxed type for reduction");
561566

562567
fir::IfOp ifUnallocated{nullptr};
563-
if (isAllocatable) {
568+
if (isAllocatableOrPointer) {
564569
ifUnallocated = handleNullAllocatable(boxAlloca);
565570
builder.setInsertionPointToStart(&ifUnallocated.getElseRegion().front());
566571
}
@@ -587,7 +592,8 @@ createReductionInitRegion(fir::FirOpBuilder &builder, mlir::Location loc,
587592
mlir::OpBuilder::InsertionGuard guard(builder);
588593
createReductionCleanupRegion(builder, loc, reductionDecl);
589594
} else {
590-
assert(!isAllocatable && "Allocatable arrays must be heap allocated");
595+
assert(!isAllocatableOrPointer &&
596+
"Pointer-like arrays must be heap allocated");
591597
}
592598

593599
// Put the temporary inside of a box:
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
! RUN: bbc -emit-hlfir -fopenmp -o - %s | FileCheck %s
2+
! RUN: %flang_fc1 -emit-hlfir -fopenmp -o - %s | FileCheck %s
3+
4+
program reduce
5+
integer :: i = 0
6+
integer, dimension(:), pointer :: r
7+
8+
allocate(r(2))
9+
10+
!$omp parallel do reduction(+:r)
11+
do i=0,10
12+
r(1) = i
13+
r(2) = -i
14+
enddo
15+
!$omp end parallel do
16+
17+
print *,r
18+
deallocate(r)
19+
20+
end program
21+
22+
! CHECK-LABEL: omp.declare_reduction @add_reduction_byref_box_ptr_Uxi32 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>> init {
23+
! CHECK: ^bb0(%[[VAL_0:.*]]: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>):
24+
! CHECK: %[[VAL_1:.*]] = arith.constant 0 : i32
25+
! CHECK: %[[VAL_2:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>
26+
! CHECK: %[[VAL_3:.*]] = fir.alloca !fir.box<!fir.ptr<!fir.array<?xi32>>>
27+
! CHECK: %[[VAL_4:.*]] = fir.box_addr %[[VAL_2]] : (!fir.box<!fir.ptr<!fir.array<?xi32>>>) -> !fir.ptr<!fir.array<?xi32>>
28+
! CHECK: %[[VAL_5:.*]] = fir.convert %[[VAL_4]] : (!fir.ptr<!fir.array<?xi32>>) -> i64
29+
! CHECK: %[[VAL_6:.*]] = arith.constant 0 : i64
30+
! CHECK: %[[VAL_7:.*]] = arith.cmpi eq, %[[VAL_5]], %[[VAL_6]] : i64
31+
! CHECK: fir.if %[[VAL_7]] {
32+
! CHECK: %[[VAL_8:.*]] = fir.embox %[[VAL_4]] : (!fir.ptr<!fir.array<?xi32>>) -> !fir.box<!fir.ptr<!fir.array<?xi32>>>
33+
! CHECK: fir.store %[[VAL_8]] to %[[VAL_3]] : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>
34+
! CHECK: } else {
35+
! CHECK: %[[VAL_9:.*]] = arith.constant 0 : index
36+
! CHECK: %[[VAL_10:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_9]] : (!fir.box<!fir.ptr<!fir.array<?xi32>>>, index) -> (index, index, index)
37+
! CHECK: %[[VAL_11:.*]] = fir.shape %[[VAL_10]]#1 : (index) -> !fir.shape<1>
38+
! CHECK: %[[VAL_12:.*]] = fir.allocmem !fir.array<?xi32>, %[[VAL_10]]#1 {bindc_name = ".tmp", uniq_name = ""}
39+
! CHECK: %[[VAL_13:.*]] = arith.constant true
40+
! CHECK: %[[VAL_14:.*]]:2 = hlfir.declare %[[VAL_12]](%[[VAL_11]]) {uniq_name = ".tmp"} : (!fir.heap<!fir.array<?xi32>>, !fir.shape<1>) -> (!fir.box<!fir.array<?xi32>>, !fir.heap<!fir.array<?xi32>>)
41+
! CHECK: %[[VAL_15:.*]] = arith.constant 0 : index
42+
! CHECK: %[[VAL_16:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_15]] : (!fir.box<!fir.ptr<!fir.array<?xi32>>>, index) -> (index, index, index)
43+
! CHECK: %[[VAL_17:.*]] = fir.shape_shift %[[VAL_16]]#0, %[[VAL_16]]#1 : (index, index) -> !fir.shapeshift<1>
44+
! CHECK: %[[VAL_18:.*]] = fir.rebox %[[VAL_14]]#0(%[[VAL_17]]) : (!fir.box<!fir.array<?xi32>>, !fir.shapeshift<1>) -> !fir.box<!fir.ptr<!fir.array<?xi32>>>
45+
! CHECK: hlfir.assign %[[VAL_1]] to %[[VAL_18]] : i32, !fir.box<!fir.ptr<!fir.array<?xi32>>>
46+
! CHECK: fir.store %[[VAL_18]] to %[[VAL_3]] : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>
47+
! CHECK: }
48+
! CHECK: omp.yield(%[[VAL_3]] : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>)
49+
! CHECK-LABEL: } combiner {
50+
! CHECK: ^bb0(%[[VAL_0:.*]]: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, %[[VAL_1:.*]]: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>):
51+
! CHECK: %[[VAL_2:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>
52+
! CHECK: %[[VAL_3:.*]] = fir.load %[[VAL_1]] : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>
53+
! CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
54+
! CHECK: %[[VAL_5:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_4]] : (!fir.box<!fir.ptr<!fir.array<?xi32>>>, index) -> (index, index, index)
55+
! CHECK: %[[VAL_6:.*]] = fir.shape_shift %[[VAL_5]]#0, %[[VAL_5]]#1 : (index, index) -> !fir.shapeshift<1>
56+
! CHECK: %[[VAL_7:.*]] = arith.constant 1 : index
57+
! CHECK: fir.do_loop %[[VAL_8:.*]] = %[[VAL_7]] to %[[VAL_5]]#1 step %[[VAL_7]] unordered {
58+
! CHECK: %[[VAL_9:.*]] = fir.array_coor %[[VAL_2]](%[[VAL_6]]) %[[VAL_8]] : (!fir.box<!fir.ptr<!fir.array<?xi32>>>, !fir.shapeshift<1>, index) -> !fir.ref<i32>
59+
! CHECK: %[[VAL_10:.*]] = fir.array_coor %[[VAL_3]](%[[VAL_6]]) %[[VAL_8]] : (!fir.box<!fir.ptr<!fir.array<?xi32>>>, !fir.shapeshift<1>, index) -> !fir.ref<i32>
60+
! CHECK: %[[VAL_11:.*]] = fir.load %[[VAL_9]] : !fir.ref<i32>
61+
! CHECK: %[[VAL_12:.*]] = fir.load %[[VAL_10]] : !fir.ref<i32>
62+
! CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_12]] : i32
63+
! CHECK: fir.store %[[VAL_13]] to %[[VAL_9]] : !fir.ref<i32>
64+
! CHECK: }
65+
! CHECK: omp.yield(%[[VAL_0]] : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>)
66+
! CHECK-LABEL: } cleanup {
67+
! CHECK: ^bb0(%[[VAL_0:.*]]: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>):
68+
! CHECK: %[[VAL_1:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>
69+
! CHECK: %[[VAL_2:.*]] = fir.box_addr %[[VAL_1]] : (!fir.box<!fir.ptr<!fir.array<?xi32>>>) -> !fir.ptr<!fir.array<?xi32>>
70+
! CHECK: %[[VAL_3:.*]] = fir.convert %[[VAL_2]] : (!fir.ptr<!fir.array<?xi32>>) -> i64
71+
! CHECK: %[[VAL_4:.*]] = arith.constant 0 : i64
72+
! CHECK: %[[VAL_5:.*]] = arith.cmpi ne, %[[VAL_3]], %[[VAL_4]] : i64
73+
! CHECK: fir.if %[[VAL_5]] {
74+
! CHECK: %[[VAL_6:.*]] = fir.convert %[[VAL_2]] : (!fir.ptr<!fir.array<?xi32>>) -> !fir.heap<!fir.array<?xi32>>
75+
! CHECK: fir.freemem %[[VAL_6]] : !fir.heap<!fir.array<?xi32>>
76+
! CHECK: }
77+
! CHECK: omp.yield
78+
! CHECK: }
79+
80+
! CHECK-LABEL: func.func @_QQmain() attributes {fir.bindc_name = "reduce"} {
81+
! CHECK: %[[VAL_0:.*]] = fir.address_of(@_QFEi) : !fir.ref<i32>
82+
! CHECK: %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = "_QFEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
83+
! CHECK: %[[VAL_2:.*]] = fir.address_of(@_QFEr) : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>
84+
! CHECK: %[[VAL_3:.*]]:2 = hlfir.declare %[[VAL_2]] {fortran_attrs = {{.*}}<pointer>, uniq_name = "_QFEr"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>)
85+
! CHECK: %[[VAL_4:.*]] = arith.constant false
86+
! CHECK: %[[VAL_5:.*]] = fir.absent !fir.box<none>
87+
! CHECK: %[[VAL_6:.*]] = fir.address_of(
88+
! CHECK: %[[VAL_7:.*]] = arith.constant 8 : i32
89+
! CHECK: %[[VAL_8:.*]] = fir.zero_bits !fir.ptr<!fir.array<?xi32>>
90+
! CHECK: %[[VAL_9:.*]] = arith.constant 0 : index
91+
! CHECK: %[[VAL_10:.*]] = fir.shape %[[VAL_9]] : (index) -> !fir.shape<1>
92+
! CHECK: %[[VAL_11:.*]] = fir.embox %[[VAL_8]](%[[VAL_10]]) : (!fir.ptr<!fir.array<?xi32>>, !fir.shape<1>) -> !fir.box<!fir.ptr<!fir.array<?xi32>>>
93+
! CHECK: fir.store %[[VAL_11]] to %[[VAL_3]]#1 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>
94+
! CHECK: %[[VAL_12:.*]] = arith.constant 1 : index
95+
! CHECK: %[[VAL_13:.*]] = arith.constant 2 : i32
96+
! CHECK: %[[VAL_14:.*]] = arith.constant 0 : i32
97+
! CHECK: %[[VAL_15:.*]] = fir.convert %[[VAL_3]]#1 : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
98+
! CHECK: %[[VAL_16:.*]] = fir.convert %[[VAL_12]] : (index) -> i64
99+
! CHECK: %[[VAL_17:.*]] = fir.convert %[[VAL_13]] : (i32) -> i64
100+
! CHECK: %[[VAL_18:.*]] = fir.call @_FortranAPointerSetBounds(%[[VAL_15]], %[[VAL_14]], %[[VAL_16]], %[[VAL_17]]) fastmath<contract> : (!fir.ref<!fir.box<none>>, i32, i64, i64) -> none
101+
! CHECK: %[[VAL_19:.*]] = fir.convert %[[VAL_3]]#1 : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
102+
! CHECK: %[[VAL_20:.*]] = fir.convert %[[VAL_6]] : (!fir.ref<!fir.char<{{.*}}>>) -> !fir.ref<i8>
103+
! CHECK: %[[VAL_21:.*]] = fir.call @_FortranAPointerAllocate(%[[VAL_19]], %[[VAL_4]], %[[VAL_5]], %[[VAL_20]], %[[VAL_7]]) fastmath<contract> : (!fir.ref<!fir.box<none>>, i1, !fir.box<none>, !fir.ref<i8>, i32) -> i32
104+
! CHECK: omp.parallel {
105+
! CHECK: %[[VAL_22:.*]] = fir.alloca i32 {bindc_name = "i", pinned, uniq_name = "_QFEi"}
106+
! CHECK: %[[VAL_23:.*]]:2 = hlfir.declare %[[VAL_22]] {uniq_name = "_QFEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
107+
! CHECK: %[[VAL_24:.*]] = arith.constant 0 : i32
108+
! CHECK: %[[VAL_25:.*]] = arith.constant 10 : i32
109+
! CHECK: %[[VAL_26:.*]] = arith.constant 1 : i32
110+
! CHECK: omp.wsloop reduction(byref @add_reduction_byref_box_ptr_Uxi32 %[[VAL_3]]#0 -> %[[VAL_27:.*]] : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) {
111+
! CHECK: omp.loop_nest (%[[VAL_28:.*]]) : i32 = (%[[VAL_24]]) to (%[[VAL_25]]) inclusive step (%[[VAL_26]]) {
112+
! CHECK: %[[VAL_29:.*]]:2 = hlfir.declare %[[VAL_27]] {fortran_attrs = {{.*}}<pointer>, uniq_name = "_QFEr"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>)
113+
! CHECK: fir.store %[[VAL_28]] to %[[VAL_23]]#1 : !fir.ref<i32>
114+
! CHECK: %[[VAL_30:.*]] = fir.load %[[VAL_23]]#0 : !fir.ref<i32>
115+
! CHECK: %[[VAL_31:.*]] = fir.load %[[VAL_29]]#0 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>
116+
! CHECK: %[[VAL_32:.*]] = arith.constant 1 : index
117+
! CHECK: %[[VAL_33:.*]] = hlfir.designate %[[VAL_31]] (%[[VAL_32]]) : (!fir.box<!fir.ptr<!fir.array<?xi32>>>, index) -> !fir.ref<i32>
118+
! CHECK: hlfir.assign %[[VAL_30]] to %[[VAL_33]] : i32, !fir.ref<i32>
119+
! CHECK: %[[VAL_34:.*]] = fir.load %[[VAL_23]]#0 : !fir.ref<i32>
120+
! CHECK: %[[VAL_35:.*]] = arith.constant 0 : i32
121+
! CHECK: %[[VAL_36:.*]] = arith.subi %[[VAL_35]], %[[VAL_34]] : i32
122+
! CHECK: %[[VAL_37:.*]] = fir.load %[[VAL_29]]#0 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>
123+
! CHECK: %[[VAL_38:.*]] = arith.constant 2 : index
124+
! CHECK: %[[VAL_39:.*]] = hlfir.designate %[[VAL_37]] (%[[VAL_38]]) : (!fir.box<!fir.ptr<!fir.array<?xi32>>>, index) -> !fir.ref<i32>
125+
! CHECK: hlfir.assign %[[VAL_36]] to %[[VAL_39]] : i32, !fir.ref<i32>
126+
! CHECK: omp.yield
127+
! CHECK: }
128+
! CHECK: omp.terminator
129+
! CHECK: }
130+
! CHECK: omp.terminator
131+
! CHECK: }

0 commit comments

Comments
 (0)