1313#include " check-omp-structure.h"
1414
1515#include " flang/Common/indirection.h"
16+ #include " flang/Common/template.h"
1617#include " flang/Evaluate/expression.h"
18+ #include " flang/Evaluate/match.h"
1719#include " flang/Evaluate/rewrite.h"
1820#include " flang/Evaluate/tools.h"
1921#include " flang/Parser/char-block.h"
@@ -50,6 +52,127 @@ static bool operator!=(const evaluate::Expr<T> &e, const evaluate::Expr<U> &f) {
5052 return !(e == f);
5153}
5254
55+ namespace {
56+ template <typename ...> struct IsIntegral {
57+ static constexpr bool value{false };
58+ };
59+
60+ template <common::TypeCategory C, int K>
61+ struct IsIntegral <evaluate::Type<C, K>> {
62+ static constexpr bool value{//
63+ C == common::TypeCategory::Integer ||
64+ C == common::TypeCategory::Unsigned ||
65+ C == common::TypeCategory::Logical};
66+ };
67+
68+ template <typename T> constexpr bool is_integral_v{IsIntegral<T>::value};
69+
70+ template <typename T, typename Op0, typename Op1>
71+ using ReassocOpBase = evaluate::match::AnyOfPattern< //
72+ evaluate::match::Add<T, Op0, Op1>, //
73+ evaluate::match::Mul<T, Op0, Op1>>;
74+
75+ template <typename T, typename Op0, typename Op1>
76+ struct ReassocOp : public ReassocOpBase <T, Op0, Op1> {
77+ using Base = ReassocOpBase<T, Op0, Op1>;
78+ using Base::Base;
79+ };
80+
81+ template <typename T, typename Op0, typename Op1>
82+ ReassocOp<T, Op0, Op1> reassocOp (const Op0 &op0, const Op1 &op1) {
83+ return ReassocOp<T, Op0, Op1>(op0, op1);
84+ }
85+ } // namespace
86+
87+ struct ReassocRewriter : public evaluate ::rewrite::Identity {
88+ using Id = evaluate::rewrite::Identity;
89+ using Id::operator ();
90+ struct NonIntegralTag {};
91+
92+ ReassocRewriter (const SomeExpr &atom) : atom_(atom) {}
93+
94+ // Try to find cases where the input expression is of the form
95+ // (1) (a . b) . c, or
96+ // (2) a . (b . c),
97+ // where . denotes an associative operation (currently + or *), and a, b, c
98+ // are some subexpresions.
99+ // If one of the operands in the nested operation is the atomic variable
100+ // (with some possible type conversions applied to it), bring it to the
101+ // top-level operation, and move the top-level operand into the nested
102+ // operation.
103+ // For example, assuming x is the atomic variable:
104+ // (a + x) + b -> (a + b) + x, i.e. (conceptually) swap x and b.
105+ template <typename T, typename U,
106+ typename = std::enable_if_t <is_integral_v<T>>>
107+ evaluate::Expr<T> operator ()(evaluate::Expr<T> &&x, const U &u) {
108+ // As per the above comment, there are 3 subexpressions involved in this
109+ // transformation. A match::Expr<T> will match evaluate::Expr<U> when T is
110+ // same as U, plus it will store a pointer (ref) to the matched expression.
111+ // When the match is successful, the sub[i].ref will point to a, b, x (in
112+ // some order) from the example above.
113+ evaluate::match::Expr<T> sub[3 ];
114+ auto inner{reassocOp<T>(sub[0 ], sub[1 ])};
115+ auto outer1{reassocOp<T>(inner, sub[2 ])}; // inner + something
116+ auto outer2{reassocOp<T>(sub[2 ], inner)}; // something + inner
117+ // There is no way to ensure that the outer operation is the same as
118+ // the inner one. They are matched independently, so we need to compare
119+ // the index in the member variant that represents the matched type.
120+ if ((match (outer1, x) && outer1.ref .index () == inner.ref .index ()) ||
121+ (match (outer2, x) && outer2.ref .index () == inner.ref .index ())) {
122+ size_t atomIdx{[&]() { // sub[atomIdx] will be the atom.
123+ size_t idx;
124+ for (idx = 0 ; idx != 3 ; ++idx) {
125+ if (IsAtom (*sub[idx].ref )) {
126+ break ;
127+ }
128+ }
129+ return idx;
130+ }()};
131+
132+ if (atomIdx > 2 ) {
133+ return Id::operator ()(std::move (x), u);
134+ }
135+ return common::visit (
136+ [&](auto &&s) {
137+ using Expr = evaluate::Expr<T>;
138+ using TypeS = llvm::remove_cvref_t <decltype (s)>;
139+ // This visitor has to be semantically correct for all possible
140+ // types of s even though at runtime s will only be one of the
141+ // matched types.
142+ // Limit the construction to the operation types that we tried
143+ // to match (otherwise TypeS(op1, op2) would fail for non-binary
144+ // operations).
145+ if constexpr (common::HasMember<TypeS,
146+ typename decltype (outer1)::MatchTypes>) {
147+ Expr atom{*sub[atomIdx].ref };
148+ Expr op1{*sub[(atomIdx + 1 ) % 3 ].ref };
149+ Expr op2{*sub[(atomIdx + 2 ) % 3 ].ref };
150+ return Expr (
151+ TypeS (atom, Expr (TypeS (std::move (op1), std::move (op2)))));
152+ } else {
153+ return Expr (TypeS (s));
154+ }
155+ },
156+ evaluate::match::deparen (x).u );
157+ }
158+ return Id::operator ()(std::move (x), u);
159+ }
160+
161+ template <typename T, typename U,
162+ typename = std::enable_if_t <!is_integral_v<T>>>
163+ evaluate::Expr<T> operator ()(
164+ evaluate::Expr<T> &&x, const U &u, NonIntegralTag = {}) {
165+ return Id::operator ()(std::move (x), u);
166+ }
167+
168+ private:
169+ template <typename T> bool IsAtom (const evaluate::Expr<T> &x) const {
170+ return IsSameOrConvertOf (evaluate::AsGenericExpr (AsRvalue (x)), atom_);
171+ }
172+
173+ const SomeExpr &atom_;
174+ };
175+
53176struct AnalyzedCondStmt {
54177 SomeExpr cond{evaluate::NullPointer{}}; // Default ctor is deleted
55178 parser::CharBlock source;
@@ -199,6 +322,26 @@ static std::pair<parser::CharBlock, parser::CharBlock> SplitAssignmentSource(
199322 llvm_unreachable (" Could not find assignment operator" );
200323}
201324
325+ static std::vector<SomeExpr> GetNonAtomExpressions (
326+ const SomeExpr &atom, const std::vector<SomeExpr> &exprs) {
327+ std::vector<SomeExpr> nonAtom;
328+ for (const SomeExpr &e : exprs) {
329+ if (!IsSameOrConvertOf (e, atom)) {
330+ nonAtom.push_back (e);
331+ }
332+ }
333+ return nonAtom;
334+ }
335+
336+ static std::vector<SomeExpr> GetNonAtomArguments (
337+ const SomeExpr &atom, const SomeExpr &expr) {
338+ if (auto &&maybe{GetConvertInput (expr)}) {
339+ return GetNonAtomExpressions (
340+ atom, GetTopLevelOperationIgnoreResizing (*maybe).second );
341+ }
342+ return {};
343+ }
344+
202345static bool IsCheckForAssociated (const SomeExpr &cond) {
203346 return GetTopLevelOperationIgnoreResizing (cond).first ==
204347 operation::Operator::Associated;
@@ -625,7 +768,8 @@ void OmpStructureChecker::CheckAtomicWriteAssignment(
625768 }
626769}
627770
628- void OmpStructureChecker::CheckAtomicUpdateAssignment (
771+ std::optional<evaluate::Assignment>
772+ OmpStructureChecker::CheckAtomicUpdateAssignment (
629773 const evaluate::Assignment &update, parser::CharBlock source) {
630774 // [6.0:191:1-7]
631775 // An update structured block is update-statement, an update statement
@@ -641,14 +785,46 @@ void OmpStructureChecker::CheckAtomicUpdateAssignment(
641785 if (!IsVarOrFunctionRef (atom)) {
642786 ErrorShouldBeVariable (atom, rsrc);
643787 // Skip other checks.
644- return ;
788+ return std:: nullopt ;
645789 }
646790
647791 CheckAtomicVariable (atom, lsrc);
648792
793+ auto [hasErrors, tryReassoc]{CheckAtomicUpdateAssignmentRhs (
794+ atom, update.rhs , source, /* suppressDiagnostics=*/ true )};
795+
796+ if (!hasErrors) {
797+ CheckStorageOverlap (atom, GetNonAtomArguments (atom, update.rhs ), source);
798+ return std::nullopt ;
799+ } else if (tryReassoc) {
800+ ReassocRewriter ra (atom);
801+ SomeExpr raRhs{evaluate::rewrite::Mutator (ra)(update.rhs )};
802+
803+ std::tie (hasErrors, tryReassoc) = CheckAtomicUpdateAssignmentRhs (
804+ atom, raRhs, source, /* suppressDiagnostics=*/ true );
805+ if (!hasErrors) {
806+ CheckStorageOverlap (atom, GetNonAtomArguments (atom, raRhs), source);
807+
808+ evaluate::Assignment raAssign (update);
809+ raAssign.rhs = raRhs;
810+ return raAssign;
811+ }
812+ }
813+
814+ // This is guaranteed to report errors.
815+ CheckAtomicUpdateAssignmentRhs (
816+ atom, update.rhs , source, /* suppressDiagnostics=*/ false );
817+ return std::nullopt ;
818+ }
819+
820+ std::pair<bool , bool > OmpStructureChecker::CheckAtomicUpdateAssignmentRhs (
821+ const SomeExpr &atom, const SomeExpr &rhs, parser::CharBlock source,
822+ bool suppressDiagnostics) {
823+ auto [lsrc, rsrc]{SplitAssignmentSource (source)};
824+
649825 std::pair<operation::Operator, std::vector<SomeExpr>> top{
650826 operation::Operator::Unknown, {}};
651- if (auto &&maybeInput{GetConvertInput (update. rhs )}) {
827+ if (auto &&maybeInput{GetConvertInput (rhs)}) {
652828 top = GetTopLevelOperationIgnoreResizing (*maybeInput);
653829 }
654830 switch (top.first ) {
@@ -665,29 +841,39 @@ void OmpStructureChecker::CheckAtomicUpdateAssignment(
665841 case operation::Operator::Identity:
666842 break ;
667843 case operation::Operator::Call:
668- context_.Say (source,
669- " A call to this function is not a valid ATOMIC UPDATE operation" _err_en_US);
670- return ;
844+ if (!suppressDiagnostics) {
845+ context_.Say (source,
846+ " A call to this function is not a valid ATOMIC UPDATE operation" _err_en_US);
847+ }
848+ return std::make_pair (true , false );
671849 case operation::Operator::Convert:
672- context_.Say (source,
673- " An implicit or explicit type conversion is not a valid ATOMIC UPDATE operation" _err_en_US);
674- return ;
850+ if (!suppressDiagnostics) {
851+ context_.Say (source,
852+ " An implicit or explicit type conversion is not a valid ATOMIC UPDATE operation" _err_en_US);
853+ }
854+ return std::make_pair (true , false );
675855 case operation::Operator::Intrinsic:
676- context_.Say (source,
677- " This intrinsic function is not a valid ATOMIC UPDATE operation" _err_en_US);
678- return ;
856+ if (!suppressDiagnostics) {
857+ context_.Say (source,
858+ " This intrinsic function is not a valid ATOMIC UPDATE operation" _err_en_US);
859+ }
860+ return std::make_pair (true , false );
679861 case operation::Operator::Constant:
680862 case operation::Operator::Unknown:
681- context_.Say (
682- source, " This is not a valid ATOMIC UPDATE operation" _err_en_US);
683- return ;
863+ if (!suppressDiagnostics) {
864+ context_.Say (
865+ source, " This is not a valid ATOMIC UPDATE operation" _err_en_US);
866+ }
867+ return std::make_pair (true , false );
684868 default :
685869 assert (
686870 top.first != operation::Operator::Identity && " Handle this separately" );
687- context_.Say (source,
688- " The %s operator is not a valid ATOMIC UPDATE operation" _err_en_US,
689- operation::ToString (top.first ));
690- return ;
871+ if (!suppressDiagnostics) {
872+ context_.Say (source,
873+ " The %s operator is not a valid ATOMIC UPDATE operation" _err_en_US,
874+ operation::ToString (top.first ));
875+ }
876+ return std::make_pair (true , false );
691877 }
692878 // Check how many times `atom` occurs as an argument, if it's a subexpression
693879 // of an argument, and collect the non-atom arguments.
@@ -708,39 +894,48 @@ void OmpStructureChecker::CheckAtomicUpdateAssignment(
708894 return count;
709895 }()};
710896
711- bool hasError{false };
897+ bool hasError{false }, tryReassoc{ false } ;
712898 if (subExpr) {
713- context_.Say (rsrc,
714- " The atomic variable %s cannot be a proper subexpression of an argument (here: %s) in the update operation" _err_en_US,
715- atom.AsFortran (), subExpr->AsFortran ());
899+ if (!suppressDiagnostics) {
900+ context_.Say (rsrc,
901+ " The atomic variable %s cannot be a proper subexpression of an argument (here: %s) in the update operation" _err_en_US,
902+ atom.AsFortran (), subExpr->AsFortran ());
903+ }
716904 hasError = true ;
717905 }
718906 if (top.first == operation::Operator::Identity) {
719907 // This is "x = y".
720908 assert ((atomCount == 0 || atomCount == 1 ) && " Unexpected count" );
721909 if (atomCount == 0 ) {
722- context_.Say (rsrc,
723- " The atomic variable %s should appear as an argument in the update operation" _err_en_US,
724- atom.AsFortran ());
910+ if (!suppressDiagnostics) {
911+ context_.Say (rsrc,
912+ " The atomic variable %s should appear as an argument in the update operation" _err_en_US,
913+ atom.AsFortran ());
914+ }
725915 hasError = true ;
726916 }
727917 } else {
728918 if (atomCount == 0 ) {
729- context_.Say (rsrc,
730- " The atomic variable %s should appear as an argument of the top-level %s operator" _err_en_US,
731- atom.AsFortran (), operation::ToString (top.first ));
919+ if (!suppressDiagnostics) {
920+ context_.Say (rsrc,
921+ " The atomic variable %s should appear as an argument of the top-level %s operator" _err_en_US,
922+ atom.AsFortran (), operation::ToString (top.first ));
923+ }
924+ // If `atom` is a proper subexpression, and it not present as an
925+ // argument on its own, reassociation may be able to help.
926+ tryReassoc = subExpr.has_value ();
732927 hasError = true ;
733928 } else if (atomCount > 1 ) {
734- context_.Say (rsrc,
735- " The atomic variable %s should be exactly one of the arguments of the top-level %s operator" _err_en_US,
736- atom.AsFortran (), operation::ToString (top.first ));
929+ if (!suppressDiagnostics) {
930+ context_.Say (rsrc,
931+ " The atomic variable %s should be exactly one of the arguments of the top-level %s operator" _err_en_US,
932+ atom.AsFortran (), operation::ToString (top.first ));
933+ }
737934 hasError = true ;
738935 }
739936 }
740937
741- if (!hasError) {
742- CheckStorageOverlap (atom, nonAtom, source);
743- }
938+ return std::make_pair (hasError, tryReassoc);
744939}
745940
746941void OmpStructureChecker::CheckAtomicConditionalUpdateAssignment (
@@ -843,11 +1038,13 @@ void OmpStructureChecker::CheckAtomicUpdateOnly(
8431038 SourcedActionStmt action{GetActionStmt (&body.front ())};
8441039 if (auto maybeUpdate{GetEvaluateAssignment (action.stmt )}) {
8451040 const SomeExpr &atom{maybeUpdate->lhs };
846- CheckAtomicUpdateAssignment (*maybeUpdate, action.source );
1041+ auto maybeAssign{
1042+ CheckAtomicUpdateAssignment (*maybeUpdate, action.source )};
1043+ auto &updateAssign{maybeAssign.has_value () ? maybeAssign : maybeUpdate};
8471044
8481045 using Analysis = parser::OpenMPAtomicConstruct::Analysis;
8491046 x.analysis = AtomicAnalysis (atom)
850- .addOp0 (Analysis::Update, maybeUpdate )
1047+ .addOp0 (Analysis::Update, updateAssign )
8511048 .addOp1 (Analysis::None);
8521049 } else if (!IsAssignment (action.stmt )) {
8531050 context_.Say (
@@ -963,29 +1160,32 @@ void OmpStructureChecker::CheckAtomicUpdateCapture(
9631160 using Analysis = parser::OpenMPAtomicConstruct::Analysis;
9641161 int action;
9651162
1163+ std::optional<evaluate::Assignment> updateAssign{update};
9661164 if (IsMaybeAtomicWrite (update)) {
9671165 action = Analysis::Write;
9681166 CheckAtomicWriteAssignment (update, uact.source );
9691167 } else {
9701168 action = Analysis::Update;
971- CheckAtomicUpdateAssignment (update, uact.source );
1169+ if (auto &&maybe{CheckAtomicUpdateAssignment (update, uact.source )}) {
1170+ updateAssign = maybe;
1171+ }
9721172 }
9731173 CheckAtomicCaptureAssignment (capture, atom, cact.source );
9741174
975- if (IsPointerAssignment (update ) != IsPointerAssignment (capture)) {
1175+ if (IsPointerAssignment (*updateAssign ) != IsPointerAssignment (capture)) {
9761176 context_.Say (cact.source ,
9771177 " The update and capture assignments should both be pointer-assignments or both be non-pointer-assignments" _err_en_US);
9781178 return ;
9791179 }
9801180
9811181 if (GetActionStmt (&body.front ()).stmt == uact.stmt ) {
9821182 x.analysis = AtomicAnalysis (atom)
983- .addOp0 (action, update )
1183+ .addOp0 (action, updateAssign )
9841184 .addOp1 (Analysis::Read, capture);
9851185 } else {
9861186 x.analysis = AtomicAnalysis (atom)
9871187 .addOp0 (Analysis::Read, capture)
988- .addOp1 (action, update );
1188+ .addOp1 (action, updateAssign );
9891189 }
9901190}
9911191
0 commit comments