From d80d54bf9178baecc46121f33eec861729544cd3 Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Mon, 4 Dec 2023 12:57:36 -0800 Subject: [PATCH 01/29] Add coexecute directives --- llvm/include/llvm/Frontend/OpenMP/OMP.td | 50 ++++++++++++++++++++++-- 1 file changed, 46 insertions(+), 4 deletions(-) diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td index a87111cb5a11d..696ec91083743 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMP.td +++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td @@ -685,7 +685,7 @@ def OMP_CancellationPoint : Directive<[Spelling<"cancellation point">]> { let association = AS_None; let category = CA_Executable; } -def OMP_Critical : Directive<[Spelling<"critical">]> { +def OMP_Critical : Directive<"critical"> { let allowedOnceClauses = [ VersionedClause, ]; @@ -2206,8 +2206,34 @@ def OMP_TargetTeams : Directive<[Spelling<"target teams">]> { let leafConstructs = [OMP_Target, OMP_Teams]; let category = CA_Executable; } -def OMP_TargetTeamsDistribute - : Directive<[Spelling<"target teams distribute">]> { +def OMP_TargetTeamsCoexecute : Directive<"target teams coexecute"> { + let allowedClauses = [ + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + ]; + + let allowedOnceClauses = [ + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + ]; +} +def OMP_TargetTeamsDistribute : Directive<"target teams distribute"> { let allowedClauses = [ VersionedClause, VersionedClause, @@ -2493,7 +2519,23 @@ def OMP_TaskLoopSimd : Directive<[Spelling<"taskloop simd">]> { let leafConstructs = [OMP_TaskLoop, OMP_Simd]; let category = CA_Executable; } -def OMP_TeamsDistribute : Directive<[Spelling<"teams distribute">]> { +def OMP_TeamsCoexecute : Directive<"teams coexecute"> { + let allowedClauses = [ + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + ]; + let allowedOnceClauses = [ + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause + ]; +} +def OMP_TeamsDistribute : Directive<"teams distribute"> { let allowedClauses = [ VersionedClause, VersionedClause, From 9235aaee4a4cb60fd85227681c3bbf1c99f6b01d Mon Sep 17 00:00:00 2001 From: skc7 Date: Tue, 13 May 2025 11:01:45 +0530 Subject: [PATCH 02/29] [OpenMP] Fix Coexecute definitions --- llvm/include/llvm/Frontend/OpenMP/OMP.td | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td index 696ec91083743..9b325057a01b7 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMP.td +++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td @@ -685,6 +685,15 @@ def OMP_CancellationPoint : Directive<[Spelling<"cancellation point">]> { let association = AS_None; let category = CA_Executable; } +def OMP_Coexecute : Directive<"coexecute"> { + let association = AS_Block; + let category = CA_Executable; +} +def OMP_EndCoexecute : Directive<"end coexecute"> { + let leafConstructs = OMP_Coexecute.leafConstructs; + let association = OMP_Coexecute.association; + let category = OMP_Coexecute.category; +} def OMP_Critical : Directive<"critical"> { let allowedOnceClauses = [ VersionedClause, @@ -2230,8 +2239,10 @@ def OMP_TargetTeamsCoexecute : Directive<"target teams 
coexecute"> { VersionedClause, VersionedClause, VersionedClause, - VersionedClause, + VersionedClause, ]; + let leafConstructs = [OMP_Target, OMP_Teams, OMP_Coexecute]; + let category = CA_Executable; } def OMP_TargetTeamsDistribute : Directive<"target teams distribute"> { let allowedClauses = [ @@ -2534,6 +2545,8 @@ def OMP_TeamsCoexecute : Directive<"teams coexecute"> { VersionedClause, VersionedClause ]; + let leafConstructs = [OMP_Target, OMP_Teams]; + let category = CA_Executable; } def OMP_TeamsDistribute : Directive<"teams distribute"> { let allowedClauses = [ From e67e21ef820bdb506866345479bb74a51ae31de5 Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Mon, 4 Dec 2023 12:58:10 -0800 Subject: [PATCH 03/29] Add omp.coexecute op --- mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index ac80926053a2d..bb61a46e13d6b 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -325,6 +325,41 @@ def SectionsOp : OpenMP_Op<"sections", traits = [ let hasRegionVerifier = 1; } +//===----------------------------------------------------------------------===// +// Coexecute Construct +//===----------------------------------------------------------------------===// + +def CoexecuteOp : OpenMP_Op<"coexecute"> { + let summary = "coexecute directive"; + let description = [{ + The coexecute construct specifies that the teams from the teams directive + this is nested in shall cooperate to execute the computation in this region. + There is no implicit barrier at the end as specified in the standard. + + TODO + We should probably change the defaut behaviour to have a barrier unless + nowait is specified, see below snippet. + + ``` + !$omp target teams + !$omp coexecute + tmp = matmul(x, y) + !$omp end coexecute + a = tmp(0, 0) ! there is no implicit barrier! the matmul hasnt completed! 
+ !$omp end target teams coexecute + ``` + + }]; + + let arguments = (ins UnitAttr:$nowait); + + let regions = (region AnyRegion:$region); + + let assemblyFormat = [{ + oilist(`nowait` $nowait) $region attr-dict + }]; +} + //===----------------------------------------------------------------------===// // 2.8.2 Single Construct //===----------------------------------------------------------------------===// From 2fb623becf2b4ea69fe48a0a51a6ac40cac1e0c6 Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Mon, 4 Dec 2023 17:50:41 -0800 Subject: [PATCH 04/29] Initial frontend support for coexecute --- .../include/flang/Semantics/openmp-directive-sets.h | 13 +++++++++++++ flang/lib/Lower/OpenMP/OpenMP.cpp | 12 ++++++++++++ flang/lib/Parser/openmp-parsers.cpp | 5 ++++- flang/lib/Semantics/resolve-directives.cpp | 6 ++++++ 4 files changed, 35 insertions(+), 1 deletion(-) diff --git a/flang/include/flang/Semantics/openmp-directive-sets.h b/flang/include/flang/Semantics/openmp-directive-sets.h index dd610c9702c28..5c316e030c63f 100644 --- a/flang/include/flang/Semantics/openmp-directive-sets.h +++ b/flang/include/flang/Semantics/openmp-directive-sets.h @@ -143,6 +143,7 @@ static const OmpDirectiveSet topTargetSet{ Directive::OMPD_target_teams_distribute_parallel_do_simd, Directive::OMPD_target_teams_distribute_simd, Directive::OMPD_target_teams_loop, + Directive::OMPD_target_teams_coexecute, }; static const OmpDirectiveSet allTargetSet{topTargetSet}; @@ -187,9 +188,16 @@ static const OmpDirectiveSet allTeamsSet{ Directive::OMPD_target_teams_distribute_parallel_do_simd, Directive::OMPD_target_teams_distribute_simd, Directive::OMPD_target_teams_loop, + Directive::OMPD_target_teams_coexecute, } | topTeamsSet, }; +static const OmpDirectiveSet allCoexecuteSet{ + Directive::OMPD_coexecute, + Directive::OMPD_teams_coexecute, + Directive::OMPD_target_teams_coexecute, +}; + //===----------------------------------------------------------------------===// // Directive sets for groups of multiple directives //===----------------------------------------------------------------------===// @@ -230,6 +238,9 @@ static const OmpDirectiveSet blockConstructSet{ Directive::OMPD_taskgroup, Directive::OMPD_teams, Directive::OMPD_workshare, + Directive::OMPD_target_teams_coexecute, + Directive::OMPD_teams_coexecute, + Directive::OMPD_coexecute, }; static const OmpDirectiveSet loopConstructSet{ @@ -294,6 +305,7 @@ static const OmpDirectiveSet workShareSet{ Directive::OMPD_scope, Directive::OMPD_sections, Directive::OMPD_single, + Directive::OMPD_coexecute, } | allDoSet, }; @@ -376,6 +388,7 @@ static const OmpDirectiveSet nestedReduceWorkshareAllowedSet{ }; static const OmpDirectiveSet nestedTeamsAllowedSet{ + Directive::OMPD_coexecute, Directive::OMPD_distribute, Directive::OMPD_distribute_parallel_do, Directive::OMPD_distribute_parallel_do_simd, diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index ebd1d038716e4..390bba9cc5adf 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -2682,6 +2682,15 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, queue, item, clauseOps); } +static mlir::omp::CoexecuteOp +genCoexecuteOp(Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + mlir::Location currentLocation, + const Fortran::parser::OmpClauseList &clauseList) { + return genOpWithBody( + converter, eval, currentLocation, /*outerCombined=*/false, &clauseList); +} + 
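For reference, a minimal sketch of the Fortran-level usage this lowering path is intended to accept (it mirrors the coexecute tests added later in this series; `f1` stands in for any external subroutine):

    subroutine teams_coexecute_example()
      !$omp teams coexecute
      call f1()
      !$omp end teams coexecute
    end subroutine teams_coexecute_example
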
//===----------------------------------------------------------------------===// // Code generation functions for the standalone version of constructs that can // also be a leaf of a composite construct @@ -3296,6 +3305,9 @@ static void genOMPDispatch(lower::AbstractConverter &converter, newOp = genTeamsOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item); break; + case llvm::omp::Directive::OMPD_coexecute: + newOp = genCoexecuteOp(converter, eval, currentLocation, beginClauseList); + break; case llvm::omp::Directive::OMPD_tile: case llvm::omp::Directive::OMPD_unroll: { unsigned version = semaCtx.langOptions().OpenMPVersion; diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp index c55642d969503..ebf2bb0c19bfd 100644 --- a/flang/lib/Parser/openmp-parsers.cpp +++ b/flang/lib/Parser/openmp-parsers.cpp @@ -1492,12 +1492,15 @@ TYPE_PARSER( "SINGLE" >> pure(llvm::omp::Directive::OMPD_single), "TARGET DATA" >> pure(llvm::omp::Directive::OMPD_target_data), "TARGET PARALLEL" >> pure(llvm::omp::Directive::OMPD_target_parallel), + "TARGET TEAMS COEXECUTE" >> pure(llvm::omp::Directive::OMPD_target_teams_coexecute), "TARGET TEAMS" >> pure(llvm::omp::Directive::OMPD_target_teams), "TARGET" >> pure(llvm::omp::Directive::OMPD_target), "TASK"_id >> pure(llvm::omp::Directive::OMPD_task), "TASKGROUP" >> pure(llvm::omp::Directive::OMPD_taskgroup), + "TEAMS COEXECUTE" >> pure(llvm::omp::Directive::OMPD_teams_coexecute), "TEAMS" >> pure(llvm::omp::Directive::OMPD_teams), - "WORKSHARE" >> pure(llvm::omp::Directive::OMPD_workshare)))) + "WORKSHARE" >> pure(llvm::omp::Directive::OMPD_workshare), + "COEXECUTE" >> pure(llvm::omp::Directive::OMPD_coexecute)))) TYPE_PARSER(sourced(construct( sourced(Parser{}), Parser{}))) diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp index 885c02e6ec74b..133d4a6c18f17 100644 --- a/flang/lib/Semantics/resolve-directives.cpp +++ b/flang/lib/Semantics/resolve-directives.cpp @@ -1656,6 +1656,9 @@ bool OmpAttributeVisitor::Pre(const parser::OpenMPBlockConstruct &x) { case llvm::omp::Directive::OMPD_task: case llvm::omp::Directive::OMPD_taskgroup: case llvm::omp::Directive::OMPD_teams: + case llvm::omp::Directive::OMPD_coexecute: + case llvm::omp::Directive::OMPD_teams_coexecute: + case llvm::omp::Directive::OMPD_target_teams_coexecute: case llvm::omp::Directive::OMPD_workshare: case llvm::omp::Directive::OMPD_parallel_workshare: case llvm::omp::Directive::OMPD_target_teams: @@ -1689,6 +1692,9 @@ void OmpAttributeVisitor::Post(const parser::OpenMPBlockConstruct &x) { case llvm::omp::Directive::OMPD_target: case llvm::omp::Directive::OMPD_task: case llvm::omp::Directive::OMPD_teams: + case llvm::omp::Directive::OMPD_coexecute: + case llvm::omp::Directive::OMPD_teams_coexecute: + case llvm::omp::Directive::OMPD_target_teams_coexecute: case llvm::omp::Directive::OMPD_parallel_workshare: case llvm::omp::Directive::OMPD_target_teams: case llvm::omp::Directive::OMPD_target_parallel: { From 21a93675e2b15a853fa69d80b9ecb4bc7b47156d Mon Sep 17 00:00:00 2001 From: skc7 Date: Tue, 13 May 2025 15:09:45 +0530 Subject: [PATCH 05/29] [OpenMP] Fixes for coexecute definitions --- .../flang/Semantics/openmp-directive-sets.h | 1 + flang/lib/Lower/OpenMP/OpenMP.cpp | 13 ++-- flang/test/Lower/OpenMP/coexecute.f90 | 59 +++++++++++++++++++ llvm/include/llvm/Frontend/OpenMP/OMP.td | 33 +++++------ 4 files changed, 83 insertions(+), 23 deletions(-) create mode 100644 flang/test/Lower/OpenMP/coexecute.f90 diff 
--git a/flang/include/flang/Semantics/openmp-directive-sets.h b/flang/include/flang/Semantics/openmp-directive-sets.h index 5c316e030c63f..43f4e642b3d86 100644 --- a/flang/include/flang/Semantics/openmp-directive-sets.h +++ b/flang/include/flang/Semantics/openmp-directive-sets.h @@ -173,6 +173,7 @@ static const OmpDirectiveSet topTeamsSet{ Directive::OMPD_teams_distribute_parallel_do_simd, Directive::OMPD_teams_distribute_simd, Directive::OMPD_teams_loop, + Directive::OMPD_teams_coexecute, }; static const OmpDirectiveSet bottomTeamsSet{ diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 390bba9cc5adf..0c436c15fd6c9 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -2683,12 +2683,13 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, } static mlir::omp::CoexecuteOp -genCoexecuteOp(Fortran::lower::AbstractConverter &converter, - Fortran::lower::pft::Evaluation &eval, - mlir::Location currentLocation, - const Fortran::parser::OmpClauseList &clauseList) { +genCoexecuteOp(lower::AbstractConverter &converter, lower::SymMap &symTable, + semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, + mlir::Location loc, const ConstructQueue &queue, + ConstructQueue::const_iterator item) { return genOpWithBody( - converter, eval, currentLocation, /*outerCombined=*/false, &clauseList); + OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval, + llvm::omp::Directive::OMPD_coexecute), queue, item); } //===----------------------------------------------------------------------===// @@ -3306,7 +3307,7 @@ static void genOMPDispatch(lower::AbstractConverter &converter, item); break; case llvm::omp::Directive::OMPD_coexecute: - newOp = genCoexecuteOp(converter, eval, currentLocation, beginClauseList); + newOp = genCoexecuteOp(converter, symTable, semaCtx, eval, loc, queue, item); break; case llvm::omp::Directive::OMPD_tile: case llvm::omp::Directive::OMPD_unroll: { diff --git a/flang/test/Lower/OpenMP/coexecute.f90 b/flang/test/Lower/OpenMP/coexecute.f90 new file mode 100644 index 0000000000000..b14f71f9bbbfa --- /dev/null +++ b/flang/test/Lower/OpenMP/coexecute.f90 @@ -0,0 +1,59 @@ +! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s + +! CHECK-LABEL: func @_QPtarget_teams_coexecute +subroutine target_teams_coexecute() + ! CHECK: omp.target + ! CHECK: omp.teams + ! CHECK: omp.coexecute + !$omp target teams coexecute + ! CHECK: fir.call + call f1() + ! CHECK: omp.terminator + ! CHECK: omp.terminator + ! CHECK: omp.terminator + !$omp end target teams coexecute +end subroutine target_teams_coexecute + +! CHECK-LABEL: func @_QPteams_coexecute +subroutine teams_coexecute() + ! CHECK: omp.teams + ! CHECK: omp.coexecute + !$omp teams coexecute + ! CHECK: fir.call + call f1() + ! CHECK: omp.terminator + ! CHECK: omp.terminator + !$omp end teams coexecute +end subroutine teams_coexecute + +! CHECK-LABEL: func @_QPtarget_teams_coexecute_m +subroutine target_teams_coexecute_m() + ! CHECK: omp.target + ! CHECK: omp.teams + ! CHECK: omp.coexecute + !$omp target + !$omp teams + !$omp coexecute + ! CHECK: fir.call + call f1() + ! CHECK: omp.terminator + ! CHECK: omp.terminator + ! CHECK: omp.terminator + !$omp end coexecute + !$omp end teams + !$omp end target +end subroutine target_teams_coexecute_m + +! CHECK-LABEL: func @_QPteams_coexecute_m +subroutine teams_coexecute_m() + ! CHECK: omp.teams + ! CHECK: omp.coexecute + !$omp teams + !$omp coexecute + ! CHECK: fir.call + call f1() + ! 
CHECK: omp.terminator + ! CHECK: omp.terminator + !$omp end coexecute + !$omp end teams +end subroutine teams_coexecute_m diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td index 9b325057a01b7..f9d790a7bc987 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMP.td +++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td @@ -2217,29 +2217,28 @@ def OMP_TargetTeams : Directive<[Spelling<"target teams">]> { } def OMP_TargetTeamsCoexecute : Directive<"target teams coexecute"> { let allowedClauses = [ - VersionedClause, - VersionedClause, - VersionedClause, + VersionedClause, VersionedClause, VersionedClause, - VersionedClause, VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, VersionedClause, - VersionedClause, - VersionedClause, VersionedClause, - VersionedClause, + VersionedClause, ]; - let allowedOnceClauses = [ + VersionedClause, + VersionedClause, VersionedClause, VersionedClause, - VersionedClause, - VersionedClause, VersionedClause, - VersionedClause, VersionedClause, VersionedClause, + VersionedClause, ]; let leafConstructs = [OMP_Target, OMP_Teams, OMP_Coexecute]; let category = CA_Executable; @@ -2532,20 +2531,20 @@ def OMP_TaskLoopSimd : Directive<[Spelling<"taskloop simd">]> { } def OMP_TeamsCoexecute : Directive<"teams coexecute"> { let allowedClauses = [ - VersionedClause, - VersionedClause, - VersionedClause, - VersionedClause, VersionedClause, + VersionedClause, VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, ]; let allowedOnceClauses = [ VersionedClause, VersionedClause, VersionedClause, - VersionedClause + VersionedClause, ]; - let leafConstructs = [OMP_Target, OMP_Teams]; + let leafConstructs = [OMP_Teams, OMP_Coexecute]; let category = CA_Executable; } def OMP_TeamsDistribute : Directive<"teams distribute"> { From 479f166e3836f485ef6d7143fecda9ef1d8ce84c Mon Sep 17 00:00:00 2001 From: skc7 Date: Wed, 14 May 2025 14:48:52 +0530 Subject: [PATCH 06/29] [OpenMP] Use workdistribute instead of coexecute --- .../flang/Semantics/openmp-directive-sets.h | 24 ++-- flang/lib/Lower/OpenMP/OpenMP.cpp | 15 ++- flang/lib/Parser/openmp-parsers.cpp | 6 +- flang/lib/Semantics/resolve-directives.cpp | 12 +- flang/test/Lower/OpenMP/coexecute.f90 | 59 ---------- flang/test/Lower/OpenMP/workdistribute.f90 | 59 ++++++++++ llvm/include/llvm/Frontend/OpenMP/OMP.td | 103 ++++++++++-------- mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 28 ++--- 8 files changed, 153 insertions(+), 153 deletions(-) delete mode 100644 flang/test/Lower/OpenMP/coexecute.f90 create mode 100644 flang/test/Lower/OpenMP/workdistribute.f90 diff --git a/flang/include/flang/Semantics/openmp-directive-sets.h b/flang/include/flang/Semantics/openmp-directive-sets.h index 43f4e642b3d86..7ced6ed9b44d6 100644 --- a/flang/include/flang/Semantics/openmp-directive-sets.h +++ b/flang/include/flang/Semantics/openmp-directive-sets.h @@ -143,7 +143,7 @@ static const OmpDirectiveSet topTargetSet{ Directive::OMPD_target_teams_distribute_parallel_do_simd, Directive::OMPD_target_teams_distribute_simd, Directive::OMPD_target_teams_loop, - Directive::OMPD_target_teams_coexecute, + Directive::OMPD_target_teams_workdistribute, }; static const OmpDirectiveSet allTargetSet{topTargetSet}; @@ -173,7 +173,7 @@ static const OmpDirectiveSet topTeamsSet{ Directive::OMPD_teams_distribute_parallel_do_simd, Directive::OMPD_teams_distribute_simd, Directive::OMPD_teams_loop, - Directive::OMPD_teams_coexecute, + 
Directive::OMPD_teams_workdistribute, }; static const OmpDirectiveSet bottomTeamsSet{ @@ -189,14 +189,14 @@ static const OmpDirectiveSet allTeamsSet{ Directive::OMPD_target_teams_distribute_parallel_do_simd, Directive::OMPD_target_teams_distribute_simd, Directive::OMPD_target_teams_loop, - Directive::OMPD_target_teams_coexecute, + Directive::OMPD_target_teams_workdistribute, } | topTeamsSet, }; -static const OmpDirectiveSet allCoexecuteSet{ - Directive::OMPD_coexecute, - Directive::OMPD_teams_coexecute, - Directive::OMPD_target_teams_coexecute, +static const OmpDirectiveSet allWorkdistributeSet{ + Directive::OMPD_workdistribute, + Directive::OMPD_teams_workdistribute, + Directive::OMPD_target_teams_workdistribute, }; //===----------------------------------------------------------------------===// @@ -239,9 +239,9 @@ static const OmpDirectiveSet blockConstructSet{ Directive::OMPD_taskgroup, Directive::OMPD_teams, Directive::OMPD_workshare, - Directive::OMPD_target_teams_coexecute, - Directive::OMPD_teams_coexecute, - Directive::OMPD_coexecute, + Directive::OMPD_target_teams_workdistribute, + Directive::OMPD_teams_workdistribute, + Directive::OMPD_workdistribute, }; static const OmpDirectiveSet loopConstructSet{ @@ -306,7 +306,7 @@ static const OmpDirectiveSet workShareSet{ Directive::OMPD_scope, Directive::OMPD_sections, Directive::OMPD_single, - Directive::OMPD_coexecute, + Directive::OMPD_workdistribute, } | allDoSet, }; @@ -389,7 +389,7 @@ static const OmpDirectiveSet nestedReduceWorkshareAllowedSet{ }; static const OmpDirectiveSet nestedTeamsAllowedSet{ - Directive::OMPD_coexecute, + Directive::OMPD_workdistribute, Directive::OMPD_distribute, Directive::OMPD_distribute_parallel_do, Directive::OMPD_distribute_parallel_do_simd, diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 0c436c15fd6c9..0133dbf912659 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -2682,14 +2682,14 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, queue, item, clauseOps); } -static mlir::omp::CoexecuteOp -genCoexecuteOp(lower::AbstractConverter &converter, lower::SymMap &symTable, +static mlir::omp::WorkdistributeOp +genWorkdistributeOp(lower::AbstractConverter &converter, lower::SymMap &symTable, semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, mlir::Location loc, const ConstructQueue &queue, ConstructQueue::const_iterator item) { - return genOpWithBody( + return genOpWithBody( OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval, - llvm::omp::Directive::OMPD_coexecute), queue, item); + llvm::omp::Directive::OMPD_workdistribute), queue, item); } //===----------------------------------------------------------------------===// @@ -3306,16 +3306,15 @@ static void genOMPDispatch(lower::AbstractConverter &converter, newOp = genTeamsOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item); break; - case llvm::omp::Directive::OMPD_coexecute: - newOp = genCoexecuteOp(converter, symTable, semaCtx, eval, loc, queue, item); - break; case llvm::omp::Directive::OMPD_tile: case llvm::omp::Directive::OMPD_unroll: { unsigned version = semaCtx.langOptions().OpenMPVersion; TODO(loc, "Unhandled loop directive (" + llvm::omp::getOpenMPDirectiveName(dir, version) + ")"); } - // case llvm::omp::Directive::OMPD_workdistribute: + case llvm::omp::Directive::OMPD_workdistribute: + newOp = genWorkdistributeOp(converter, symTable, semaCtx, eval, loc, queue, item); + break; case 
llvm::omp::Directive::OMPD_workshare: newOp = genWorkshareOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item); diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp index ebf2bb0c19bfd..cd0eccffa43f5 100644 --- a/flang/lib/Parser/openmp-parsers.cpp +++ b/flang/lib/Parser/openmp-parsers.cpp @@ -1492,15 +1492,15 @@ TYPE_PARSER( "SINGLE" >> pure(llvm::omp::Directive::OMPD_single), "TARGET DATA" >> pure(llvm::omp::Directive::OMPD_target_data), "TARGET PARALLEL" >> pure(llvm::omp::Directive::OMPD_target_parallel), - "TARGET TEAMS COEXECUTE" >> pure(llvm::omp::Directive::OMPD_target_teams_coexecute), + "TARGET TEAMS WORKDISTRIBUTE" >> pure(llvm::omp::Directive::OMPD_target_teams_workdistribute), "TARGET TEAMS" >> pure(llvm::omp::Directive::OMPD_target_teams), "TARGET" >> pure(llvm::omp::Directive::OMPD_target), "TASK"_id >> pure(llvm::omp::Directive::OMPD_task), "TASKGROUP" >> pure(llvm::omp::Directive::OMPD_taskgroup), - "TEAMS COEXECUTE" >> pure(llvm::omp::Directive::OMPD_teams_coexecute), + "TEAMS WORKDISTRIBUTE" >> pure(llvm::omp::Directive::OMPD_teams_workdistribute), "TEAMS" >> pure(llvm::omp::Directive::OMPD_teams), "WORKSHARE" >> pure(llvm::omp::Directive::OMPD_workshare), - "COEXECUTE" >> pure(llvm::omp::Directive::OMPD_coexecute)))) + "WORKDISTRIBUTE" >> pure(llvm::omp::Directive::OMPD_workdistribute)))) TYPE_PARSER(sourced(construct( sourced(Parser{}), Parser{}))) diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp index 133d4a6c18f17..da3315ff1acfb 100644 --- a/flang/lib/Semantics/resolve-directives.cpp +++ b/flang/lib/Semantics/resolve-directives.cpp @@ -1656,9 +1656,9 @@ bool OmpAttributeVisitor::Pre(const parser::OpenMPBlockConstruct &x) { case llvm::omp::Directive::OMPD_task: case llvm::omp::Directive::OMPD_taskgroup: case llvm::omp::Directive::OMPD_teams: - case llvm::omp::Directive::OMPD_coexecute: - case llvm::omp::Directive::OMPD_teams_coexecute: - case llvm::omp::Directive::OMPD_target_teams_coexecute: + case llvm::omp::Directive::OMPD_workdistribute: + case llvm::omp::Directive::OMPD_teams_workdistribute: + case llvm::omp::Directive::OMPD_target_teams_workdistribute: case llvm::omp::Directive::OMPD_workshare: case llvm::omp::Directive::OMPD_parallel_workshare: case llvm::omp::Directive::OMPD_target_teams: @@ -1692,9 +1692,9 @@ void OmpAttributeVisitor::Post(const parser::OpenMPBlockConstruct &x) { case llvm::omp::Directive::OMPD_target: case llvm::omp::Directive::OMPD_task: case llvm::omp::Directive::OMPD_teams: - case llvm::omp::Directive::OMPD_coexecute: - case llvm::omp::Directive::OMPD_teams_coexecute: - case llvm::omp::Directive::OMPD_target_teams_coexecute: + case llvm::omp::Directive::OMPD_workdistribute: + case llvm::omp::Directive::OMPD_teams_workdistribute: + case llvm::omp::Directive::OMPD_target_teams_workdistribute: case llvm::omp::Directive::OMPD_parallel_workshare: case llvm::omp::Directive::OMPD_target_teams: case llvm::omp::Directive::OMPD_target_parallel: { diff --git a/flang/test/Lower/OpenMP/coexecute.f90 b/flang/test/Lower/OpenMP/coexecute.f90 deleted file mode 100644 index b14f71f9bbbfa..0000000000000 --- a/flang/test/Lower/OpenMP/coexecute.f90 +++ /dev/null @@ -1,59 +0,0 @@ -! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s - -! CHECK-LABEL: func @_QPtarget_teams_coexecute -subroutine target_teams_coexecute() - ! CHECK: omp.target - ! CHECK: omp.teams - ! CHECK: omp.coexecute - !$omp target teams coexecute - ! CHECK: fir.call - call f1() - ! 
CHECK: omp.terminator - ! CHECK: omp.terminator - ! CHECK: omp.terminator - !$omp end target teams coexecute -end subroutine target_teams_coexecute - -! CHECK-LABEL: func @_QPteams_coexecute -subroutine teams_coexecute() - ! CHECK: omp.teams - ! CHECK: omp.coexecute - !$omp teams coexecute - ! CHECK: fir.call - call f1() - ! CHECK: omp.terminator - ! CHECK: omp.terminator - !$omp end teams coexecute -end subroutine teams_coexecute - -! CHECK-LABEL: func @_QPtarget_teams_coexecute_m -subroutine target_teams_coexecute_m() - ! CHECK: omp.target - ! CHECK: omp.teams - ! CHECK: omp.coexecute - !$omp target - !$omp teams - !$omp coexecute - ! CHECK: fir.call - call f1() - ! CHECK: omp.terminator - ! CHECK: omp.terminator - ! CHECK: omp.terminator - !$omp end coexecute - !$omp end teams - !$omp end target -end subroutine target_teams_coexecute_m - -! CHECK-LABEL: func @_QPteams_coexecute_m -subroutine teams_coexecute_m() - ! CHECK: omp.teams - ! CHECK: omp.coexecute - !$omp teams - !$omp coexecute - ! CHECK: fir.call - call f1() - ! CHECK: omp.terminator - ! CHECK: omp.terminator - !$omp end coexecute - !$omp end teams -end subroutine teams_coexecute_m diff --git a/flang/test/Lower/OpenMP/workdistribute.f90 b/flang/test/Lower/OpenMP/workdistribute.f90 new file mode 100644 index 0000000000000..924205bb72e5e --- /dev/null +++ b/flang/test/Lower/OpenMP/workdistribute.f90 @@ -0,0 +1,59 @@ +! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s + +! CHECK-LABEL: func @_QPtarget_teams_workdistribute +subroutine target_teams_workdistribute() + ! CHECK: omp.target + ! CHECK: omp.teams + ! CHECK: omp.workdistribute + !$omp target teams workdistribute + ! CHECK: fir.call + call f1() + ! CHECK: omp.terminator + ! CHECK: omp.terminator + ! CHECK: omp.terminator + !$omp end target teams workdistribute +end subroutine target_teams_workdistribute + +! CHECK-LABEL: func @_QPteams_workdistribute +subroutine teams_workdistribute() + ! CHECK: omp.teams + ! CHECK: omp.workdistribute + !$omp teams workdistribute + ! CHECK: fir.call + call f1() + ! CHECK: omp.terminator + ! CHECK: omp.terminator + !$omp end teams workdistribute +end subroutine teams_workdistribute + +! CHECK-LABEL: func @_QPtarget_teams_workdistribute_m +subroutine target_teams_workdistribute_m() + ! CHECK: omp.target + ! CHECK: omp.teams + ! CHECK: omp.workdistribute + !$omp target + !$omp teams + !$omp workdistribute + ! CHECK: fir.call + call f1() + ! CHECK: omp.terminator + ! CHECK: omp.terminator + ! CHECK: omp.terminator + !$omp end workdistribute + !$omp end teams + !$omp end target +end subroutine target_teams_workdistribute_m + +! CHECK-LABEL: func @_QPteams_workdistribute_m +subroutine teams_workdistribute_m() + ! CHECK: omp.teams + ! CHECK: omp.workdistribute + !$omp teams + !$omp workdistribute + ! CHECK: fir.call + call f1() + ! CHECK: omp.terminator + ! 
CHECK: omp.terminator + !$omp end workdistribute + !$omp end teams +end subroutine teams_workdistribute_m diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td index f9d790a7bc987..b6d92d572206a 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMP.td +++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td @@ -1295,6 +1295,15 @@ def OMP_EndWorkshare : Directive<[Spelling<"end workshare">]> { let category = OMP_Workshare.category; let languages = [L_Fortran]; } +def OMP_Workdistribute : Directive<"workdistribute"> { + let association = AS_Block; + let category = CA_Executable; +} +def OMP_EndWorkdistribute : Directive<"end workdistribute"> { + let leafConstructs = OMP_Workdistribute.leafConstructs; + let association = OMP_Workdistribute.association; + let category = OMP_Workdistribute.category; +} //===----------------------------------------------------------------------===// // Definitions of OpenMP compound directives @@ -2215,34 +2224,6 @@ def OMP_TargetTeams : Directive<[Spelling<"target teams">]> { let leafConstructs = [OMP_Target, OMP_Teams]; let category = CA_Executable; } -def OMP_TargetTeamsCoexecute : Directive<"target teams coexecute"> { - let allowedClauses = [ - VersionedClause, - VersionedClause, - VersionedClause, - VersionedClause, - VersionedClause, - VersionedClause, - VersionedClause, - VersionedClause, - VersionedClause, - VersionedClause, - VersionedClause, - VersionedClause, - ]; - let allowedOnceClauses = [ - VersionedClause, - VersionedClause, - VersionedClause, - VersionedClause, - VersionedClause, - VersionedClause, - VersionedClause, - VersionedClause, - ]; - let leafConstructs = [OMP_Target, OMP_Teams, OMP_Coexecute]; - let category = CA_Executable; -} def OMP_TargetTeamsDistribute : Directive<"target teams distribute"> { let allowedClauses = [ VersionedClause, @@ -2465,7 +2446,35 @@ def OMP_TargetTeamsDistributeSimd let leafConstructs = [OMP_Target, OMP_Teams, OMP_Distribute, OMP_Simd]; let category = CA_Executable; } -def OMP_target_teams_loop : Directive<[Spelling<"target teams loop">]> { +def OMP_TargetTeamsWorkdistribute : Directive<"target teams workdistribute"> { + let allowedClauses = [ + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + ]; + let allowedOnceClauses = [ + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + ]; + let leafConstructs = [OMP_Target, OMP_Teams, OMP_Workdistribute]; + let category = CA_Executable; +} +def OMP_target_teams_loop : Directive<"target teams loop"> { let allowedClauses = [ VersionedClause, VersionedClause, @@ -2529,24 +2538,6 @@ def OMP_TaskLoopSimd : Directive<[Spelling<"taskloop simd">]> { let leafConstructs = [OMP_TaskLoop, OMP_Simd]; let category = CA_Executable; } -def OMP_TeamsCoexecute : Directive<"teams coexecute"> { - let allowedClauses = [ - VersionedClause, - VersionedClause, - VersionedClause, - VersionedClause, - VersionedClause, - VersionedClause, - ]; - let allowedOnceClauses = [ - VersionedClause, - VersionedClause, - VersionedClause, - VersionedClause, - ]; - let leafConstructs = [OMP_Teams, OMP_Coexecute]; - let category = CA_Executable; -} def OMP_TeamsDistribute : Directive<"teams distribute"> { let allowedClauses = [ VersionedClause, @@ -2734,3 +2725,21 @@ def OMP_teams_loop : 
Directive<[Spelling<"teams loop">]> { let leafConstructs = [OMP_Teams, OMP_loop]; let category = CA_Executable; } +def OMP_TeamsWorkdistribute : Directive<"teams workdistribute"> { + let allowedClauses = [ + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + ]; + let allowedOnceClauses = [ + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + ]; + let leafConstructs = [OMP_Teams, OMP_Workdistribute]; + let category = CA_Executable; +} diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index bb61a46e13d6b..8d65b37330eb8 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -326,38 +326,30 @@ def SectionsOp : OpenMP_Op<"sections", traits = [ } //===----------------------------------------------------------------------===// -// Coexecute Construct +// workdistribute Construct //===----------------------------------------------------------------------===// -def CoexecuteOp : OpenMP_Op<"coexecute"> { - let summary = "coexecute directive"; +def WorkdistributeOp : OpenMP_Op<"workdistribute"> { + let summary = "workdistribute directive"; let description = [{ - The coexecute construct specifies that the teams from the teams directive - this is nested in shall cooperate to execute the computation in this region. - There is no implicit barrier at the end as specified in the standard. - - TODO - We should probably change the defaut behaviour to have a barrier unless - nowait is specified, see below snippet. + workdistribute divides execution of the enclosed structured block into + separate units of work, each executed only once by each + initial thread in the league. ``` !$omp target teams - !$omp coexecute + !$omp workdistribute tmp = matmul(x, y) - !$omp end coexecute + !$omp end workdistribute a = tmp(0, 0) ! there is no implicit barrier! the matmul hasnt completed! 
- !$omp end target teams coexecute + !$omp end target teams workdistribute ``` }]; - let arguments = (ins UnitAttr:$nowait); - let regions = (region AnyRegion:$region); - let assemblyFormat = [{ - oilist(`nowait` $nowait) $region attr-dict - }]; + let assemblyFormat = "$region attr-dict"; } //===----------------------------------------------------------------------===// From f7d6a3b9b30a82ca7849f47343bd94a430630efe Mon Sep 17 00:00:00 2001 From: skc7 Date: Wed, 14 May 2025 16:17:14 +0530 Subject: [PATCH 07/29] [OpenMP] workdistribute trivial lowering Lowering logic inspired from ivanradanov coexeute lowering f56da1a207df4a40776a8570122a33f047074a3c --- .../include/flang/Optimizer/OpenMP/Passes.td | 4 + flang/lib/Optimizer/OpenMP/CMakeLists.txt | 1 + .../Optimizer/OpenMP/LowerWorkdistribute.cpp | 101 ++++++++++++++++++ .../OpenMP/lower-workdistribute.mlir | 52 +++++++++ 4 files changed, 158 insertions(+) create mode 100644 flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp create mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute.mlir diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.td b/flang/include/flang/Optimizer/OpenMP/Passes.td index 704faf0ccd856..743b6d381ed42 100644 --- a/flang/include/flang/Optimizer/OpenMP/Passes.td +++ b/flang/include/flang/Optimizer/OpenMP/Passes.td @@ -93,6 +93,10 @@ def LowerWorkshare : Pass<"lower-workshare", "::mlir::ModuleOp"> { let summary = "Lower workshare construct"; } +def LowerWorkdistribute : Pass<"lower-workdistribute", "::mlir::ModuleOp"> { + let summary = "Lower workdistribute construct"; +} + def GenericLoopConversionPass : Pass<"omp-generic-loop-conversion", "mlir::func::FuncOp"> { let summary = "Converts OpenMP generic `omp.loop` to semantically " diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt index e31543328a9f9..cd746834741f9 100644 --- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt +++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt @@ -7,6 +7,7 @@ add_flang_library(FlangOpenMPTransforms MapsForPrivatizedSymbols.cpp MapInfoFinalization.cpp MarkDeclareTarget.cpp + LowerWorkdistribute.cpp LowerWorkshare.cpp LowerNontemporal.cpp diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp new file mode 100644 index 0000000000000..75c9d2b0d494e --- /dev/null +++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp @@ -0,0 +1,101 @@ +//===- LowerWorkshare.cpp - special cases for bufferization -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the lowering of omp.workdistribute. 
+// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include + +namespace flangomp { +#define GEN_PASS_DEF_LOWERWORKDISTRIBUTE +#include "flang/Optimizer/OpenMP/Passes.h.inc" +} // namespace flangomp + +#define DEBUG_TYPE "lower-workdistribute" + +using namespace mlir; + +namespace { + +struct WorkdistributeToSingle : public mlir::OpRewritePattern { +using OpRewritePattern::OpRewritePattern; +mlir::LogicalResult + matchAndRewrite(mlir::omp::WorkdistributeOp workdistribute, + mlir::PatternRewriter &rewriter) const override { + auto loc = workdistribute->getLoc(); + auto teams = llvm::dyn_cast(workdistribute->getParentOp()); + if (!teams) { + mlir::emitError(loc, "workdistribute not nested in teams\n"); + return mlir::failure(); + } + if (workdistribute.getRegion().getBlocks().size() != 1) { + mlir::emitError(loc, "workdistribute with multiple blocks\n"); + return mlir::failure(); + } + if (teams.getRegion().getBlocks().size() != 1) { + mlir::emitError(loc, "teams with multiple blocks\n"); + return mlir::failure(); + } + if (teams.getRegion().getBlocks().front().getOperations().size() != 2) { + mlir::emitError(loc, "teams with multiple nested ops\n"); + return mlir::failure(); + } + mlir::Block *workdistributeBlock = &workdistribute.getRegion().front(); + rewriter.eraseOp(workdistributeBlock->getTerminator()); + rewriter.inlineBlockBefore(workdistributeBlock, teams); + rewriter.eraseOp(teams); + return mlir::success(); + } +}; + +class LowerWorkdistributePass + : public flangomp::impl::LowerWorkdistributeBase { +public: + void runOnOperation() override { + mlir::MLIRContext &context = getContext(); + mlir::RewritePatternSet patterns(&context); + mlir::GreedyRewriteConfig config; + // prevent the pattern driver form merging blocks + config.setRegionSimplificationLevel( + mlir::GreedySimplifyRegionLevel::Disabled); + + patterns.insert(&context); + mlir::Operation *op = getOperation(); + if (mlir::failed(mlir::applyPatternsGreedily(op, std::move(patterns), config))) { + mlir::emitError(op->getLoc(), DEBUG_TYPE " pass failed\n"); + signalPassFailure(); + } + } +}; +} diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute.mlir new file mode 100644 index 0000000000000..34c8c3f01976d --- /dev/null +++ b/flang/test/Transforms/OpenMP/lower-workdistribute.mlir @@ -0,0 +1,52 @@ +// RUN: fir-opt --lower-workdistribute %s | FileCheck %s + +// CHECK-LABEL: func.func @_QPtarget_simple() { +// CHECK: %[[VAL_0:.*]] = arith.constant 2 : i32 +// CHECK: %[[VAL_1:.*]] = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFtarget_simpleEa"} +// CHECK: %[[VAL_2:.*]]:2 = hlfir.declare %[[VAL_1]] {uniq_name = "_QFtarget_simpleEa"} : (!fir.ref) -> (!fir.ref, !fir.ref) +// CHECK: %[[VAL_3:.*]] = fir.alloca !fir.box> {bindc_name = "simple_var", uniq_name = "_QFtarget_simpleEsimple_var"} +// CHECK: %[[VAL_4:.*]] = fir.zero_bits !fir.heap +// CHECK: %[[VAL_5:.*]] = fir.embox %[[VAL_4]] : (!fir.heap) -> !fir.box> +// CHECK: fir.store %[[VAL_5]] to %[[VAL_3]] : !fir.ref>> +// CHECK: %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_3]] {fortran_attrs = #fir.var_attrs, uniq_name = "_QFtarget_simpleEsimple_var"} : (!fir.ref>>) -> (!fir.ref>>, !fir.ref>>) 
+// CHECK: hlfir.assign %[[VAL_0]] to %[[VAL_2]]#0 : i32, !fir.ref +// CHECK: %[[VAL_7:.*]] = omp.map.info var_ptr(%[[VAL_2]]#1 : !fir.ref, i32) map_clauses(to) capture(ByRef) -> !fir.ref {name = "a"} +// CHECK: omp.target map_entries(%[[VAL_7]] -> %[[VAL_8:.*]] : !fir.ref) private(@_QFtarget_simpleEsimple_var_private_ref_box_heap_i32 %[[VAL_6]]#0 -> %[[VAL_9:.*]] : !fir.ref>>) { +// CHECK: %[[VAL_10:.*]] = arith.constant 10 : i32 +// CHECK: %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_8]] {uniq_name = "_QFtarget_simpleEa"} : (!fir.ref) -> (!fir.ref, !fir.ref) +// CHECK: %[[VAL_12:.*]]:2 = hlfir.declare %[[VAL_9]] {fortran_attrs = #fir.var_attrs, uniq_name = "_QFtarget_simpleEsimple_var"} : (!fir.ref>>) -> (!fir.ref>>, !fir.ref>>) +// CHECK: %[[VAL_13:.*]] = fir.load %[[VAL_11]]#0 : !fir.ref +// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_10]] : i32 +// CHECK: hlfir.assign %[[VAL_14]] to %[[VAL_12]]#0 realloc : i32, !fir.ref>> +// CHECK: omp.terminator +// CHECK: } +// CHECK: return +// CHECK: } +func.func @_QPtarget_simple() { + %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFtarget_simpleEa"} + %1:2 = hlfir.declare %0 {uniq_name = "_QFtarget_simpleEa"} : (!fir.ref) -> (!fir.ref, !fir.ref) + %2 = fir.alloca !fir.box> {bindc_name = "simple_var", uniq_name = "_QFtarget_simpleEsimple_var"} + %3 = fir.zero_bits !fir.heap + %4 = fir.embox %3 : (!fir.heap) -> !fir.box> + fir.store %4 to %2 : !fir.ref>> + %5:2 = hlfir.declare %2 {fortran_attrs = #fir.var_attrs, uniq_name = "_QFtarget_simpleEsimple_var"} : (!fir.ref>>) -> (!fir.ref>>, !fir.ref>>) + %c2_i32 = arith.constant 2 : i32 + hlfir.assign %c2_i32 to %1#0 : i32, !fir.ref + %6 = omp.map.info var_ptr(%1#1 : !fir.ref, i32) map_clauses(to) capture(ByRef) -> !fir.ref {name = "a"} + omp.target map_entries(%6 -> %arg0 : !fir.ref) private(@_QFtarget_simpleEsimple_var_private_ref_box_heap_i32 %5#0 -> %arg1 : !fir.ref>>){ + omp.teams { + omp.workdistribute { + %11:2 = hlfir.declare %arg0 {uniq_name = "_QFtarget_simpleEa"} : (!fir.ref) -> (!fir.ref, !fir.ref) + %12:2 = hlfir.declare %arg1 {fortran_attrs = #fir.var_attrs, uniq_name = "_QFtarget_simpleEsimple_var"} : (!fir.ref>>) -> (!fir.ref>>, !fir.ref>>) + %c10_i32 = arith.constant 10 : i32 + %13 = fir.load %11#0 : !fir.ref + %14 = arith.addi %c10_i32, %13 : i32 + hlfir.assign %14 to %12#0 realloc : i32, !fir.ref>> + omp.terminator + } + omp.terminator + } + omp.terminator + } + return +} \ No newline at end of file From 46001010c8dca18dcb5d3c1edcb109b39709ce52 Mon Sep 17 00:00:00 2001 From: skc7 Date: Wed, 14 May 2025 19:29:33 +0530 Subject: [PATCH 08/29] [Flang][OpenMP] Add workdistribute lower pass to pipeline --- flang/lib/Optimizer/Passes/Pipelines.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp index 70f57bdeddd3f..c63e3799be650 100644 --- a/flang/lib/Optimizer/Passes/Pipelines.cpp +++ b/flang/lib/Optimizer/Passes/Pipelines.cpp @@ -288,8 +288,10 @@ void createHLFIRToFIRPassPipeline(mlir::PassManager &pm, bool enableOpenMP, addNestedPassToAllTopLevelOperations( pm, hlfir::createInlineHLFIRAssign); pm.addPass(hlfir::createConvertHLFIRtoFIR()); - if (enableOpenMP) + if (enableOpenMP) { pm.addPass(flangomp::createLowerWorkshare()); + pm.addPass(flangomp::createLowerWorkdistribute()); + } } /// Create a pass pipeline for handling certain OpenMP transformations needed From 3ad6d57b2b68304012bb6917e881b3fafffe651b Mon Sep 17 00:00:00 2001 From: skc7 Date: Thu, 15 May 2025 16:39:21 
+0530 Subject: [PATCH 09/29] [Flang][OpenMP] Add FissionWorkdistribute lowering. Fission logic inspired from ivanradanov implementation : c97eca4010e460aac5a3d795614ca0980bce4565 --- .../Optimizer/OpenMP/LowerWorkdistribute.cpp | 233 ++++++++++++++---- .../OpenMP/lower-workdistribute-fission.mlir | 60 +++++ ...ir => lower-workdistribute-to-single.mlir} | 2 +- 3 files changed, 243 insertions(+), 52 deletions(-) create mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir rename flang/test/Transforms/OpenMP/{lower-workdistribute.mlir => lower-workdistribute-to-single.mlir} (99%) diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp index 75c9d2b0d494e..f799202be2645 100644 --- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp +++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp @@ -10,31 +10,26 @@ // //===----------------------------------------------------------------------===// -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include "flang/Optimizer/Dialect/FIRDialect.h" +#include "flang/Optimizer/Dialect/FIROps.h" +#include "flang/Optimizer/Dialect/FIRType.h" +#include "flang/Optimizer/Transforms/Passes.h" +#include "flang/Optimizer/HLFIR/Passes.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Value.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include #include -#include -#include -#include +#include +#include #include +#include #include -#include #include -#include -#include #include #include -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - +#include #include namespace flangomp { @@ -48,52 +43,188 @@ using namespace mlir; namespace { -struct WorkdistributeToSingle : public mlir::OpRewritePattern { -using OpRewritePattern::OpRewritePattern; -mlir::LogicalResult - matchAndRewrite(mlir::omp::WorkdistributeOp workdistribute, - mlir::PatternRewriter &rewriter) const override { - auto loc = workdistribute->getLoc(); - auto teams = llvm::dyn_cast(workdistribute->getParentOp()); - if (!teams) { - mlir::emitError(loc, "workdistribute not nested in teams\n"); - return mlir::failure(); - } - if (workdistribute.getRegion().getBlocks().size() != 1) { - mlir::emitError(loc, "workdistribute with multiple blocks\n"); - return mlir::failure(); +template +static T getPerfectlyNested(Operation *op) { + if (op->getNumRegions() != 1) + return nullptr; + auto ®ion = op->getRegion(0); + if (region.getBlocks().size() != 1) + return nullptr; + auto *block = ®ion.front(); + auto *firstOp = &block->front(); + if (auto nested = dyn_cast(firstOp)) + if (firstOp->getNextNode() == block->getTerminator()) + return nested; + return nullptr; +} + +/// This is the single source of truth about whether we should parallelize an +/// operation nested in an omp.workdistribute region. 
+static bool shouldParallelize(Operation *op) { + // Currently we cannot parallelize operations with results that have uses + if (llvm::any_of(op->getResults(), + [](OpResult v) -> bool { return !v.use_empty(); })) + return false; + // We will parallelize unordered loops - these come from array syntax + if (auto loop = dyn_cast(op)) { + auto unordered = loop.getUnordered(); + if (!unordered) + return false; + return *unordered; + } + if (auto callOp = dyn_cast(op)) { + auto callee = callOp.getCallee(); + if (!callee) + return false; + auto *func = op->getParentOfType().lookupSymbol(*callee); + // TODO need to insert a check here whether it is a call we can actually + // parallelize currently + if (func->getAttr(fir::FIROpsDialect::getFirRuntimeAttrName())) + return true; + return false; + } + // We cannot parallise anything else + return false; +} + +struct WorkdistributeToSingle : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(omp::TeamsOp teamsOp, + PatternRewriter &rewriter) const override { + auto workdistributeOp = getPerfectlyNested(teamsOp); + if (!workdistributeOp) { + LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " No workdistribute nested\n"); + return failure(); } - if (teams.getRegion().getBlocks().size() != 1) { - mlir::emitError(loc, "teams with multiple blocks\n"); - return mlir::failure(); + + Block *workdistributeBlock = &workdistributeOp.getRegion().front(); + rewriter.eraseOp(workdistributeBlock->getTerminator()); + rewriter.inlineBlockBefore(workdistributeBlock, teamsOp); + rewriter.eraseOp(teamsOp); + workdistributeOp.emitWarning("unable to parallelize coexecute"); + return success(); + } +}; + +/// If B() and D() are parallelizable, +/// +/// omp.teams { +/// omp.workdistribute { +/// A() +/// B() +/// C() +/// D() +/// E() +/// } +/// } +/// +/// becomes +/// +/// A() +/// omp.teams { +/// omp.workdistribute { +/// B() +/// } +/// } +/// C() +/// omp.teams { +/// omp.workdistribute { +/// D() +/// } +/// } +/// E() + +struct FissionWorkdistribute + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult + matchAndRewrite(omp::WorkdistributeOp workdistribute, + PatternRewriter &rewriter) const override { + auto loc = workdistribute->getLoc(); + auto teams = dyn_cast(workdistribute->getParentOp()); + if (!teams) { + emitError(loc, "workdistribute not nested in teams\n"); + return failure(); + } + if (workdistribute.getRegion().getBlocks().size() != 1) { + emitError(loc, "workdistribute with multiple blocks\n"); + return failure(); + } + if (teams.getRegion().getBlocks().size() != 1) { + emitError(loc, "teams with multiple blocks\n"); + return failure(); + } + if (teams.getRegion().getBlocks().front().getOperations().size() != 2) { + emitError(loc, "teams with multiple nested ops\n"); + return failure(); + } + + auto *teamsBlock = &teams.getRegion().front(); + + // While we have unhandled operations in the original workdistribute + auto *workdistributeBlock = &workdistribute.getRegion().front(); + auto *terminator = workdistributeBlock->getTerminator(); + bool changed = false; + while (&workdistributeBlock->front() != terminator) { + rewriter.setInsertionPoint(teams); + IRMapping mapping; + llvm::SmallVector hoisted; + Operation *parallelize = nullptr; + for (auto &op : workdistribute.getOps()) { + if (&op == terminator) { + break; } - if (teams.getRegion().getBlocks().front().getOperations().size() != 2) { - mlir::emitError(loc, "teams with multiple nested ops\n"); - return 
mlir::failure(); + if (shouldParallelize(&op)) { + parallelize = &op; + break; + } else { + rewriter.clone(op, mapping); + hoisted.push_back(&op); + changed = true; } - mlir::Block *workdistributeBlock = &workdistribute.getRegion().front(); - rewriter.eraseOp(workdistributeBlock->getTerminator()); - rewriter.inlineBlockBefore(workdistributeBlock, teams); - rewriter.eraseOp(teams); - return mlir::success(); + } + + for (auto *op : hoisted) + rewriter.replaceOp(op, mapping.lookup(op)); + + if (parallelize && hoisted.empty() && + parallelize->getNextNode() == terminator) + break; + if (parallelize) { + auto newTeams = rewriter.cloneWithoutRegions(teams); + auto *newTeamsBlock = rewriter.createBlock( + &newTeams.getRegion(), newTeams.getRegion().begin(), {}, {}); + for (auto arg : teamsBlock->getArguments()) + newTeamsBlock->addArgument(arg.getType(), arg.getLoc()); + auto newWorkdistribute = rewriter.create(loc); + rewriter.create(loc); + rewriter.createBlock(&newWorkdistribute.getRegion(), + newWorkdistribute.getRegion().begin(), {}, {}); + auto *cloned = rewriter.clone(*parallelize); + rewriter.replaceOp(parallelize, cloned); + rewriter.create(loc); + changed = true; + } } + return success(changed); + } }; class LowerWorkdistributePass : public flangomp::impl::LowerWorkdistributeBase { public: void runOnOperation() override { - mlir::MLIRContext &context = getContext(); - mlir::RewritePatternSet patterns(&context); - mlir::GreedyRewriteConfig config; + MLIRContext &context = getContext(); + RewritePatternSet patterns(&context); + GreedyRewriteConfig config; // prevent the pattern driver form merging blocks config.setRegionSimplificationLevel( - mlir::GreedySimplifyRegionLevel::Disabled); + GreedySimplifyRegionLevel::Disabled); - patterns.insert(&context); - mlir::Operation *op = getOperation(); - if (mlir::failed(mlir::applyPatternsGreedily(op, std::move(patterns), config))) { - mlir::emitError(op->getLoc(), DEBUG_TYPE " pass failed\n"); + patterns.insert(&context); + Operation *op = getOperation(); + if (failed(applyPatternsGreedily(op, std::move(patterns), config))) { + emitError(op->getLoc(), DEBUG_TYPE " pass failed\n"); signalPassFailure(); } } diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir new file mode 100644 index 0000000000000..ea03a10dd3d44 --- /dev/null +++ b/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir @@ -0,0 +1,60 @@ +// RUN: fir-opt --lower-workdistribute %s | FileCheck %s + +// CHECK-LABEL: func.func @test_fission_workdistribute({{.*}}) { +// CHECK: %[[VAL_0:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_2:.*]] = arith.constant 9 : index +// CHECK: %[[VAL_3:.*]] = arith.constant 5.000000e+00 : f32 +// CHECK: fir.store %[[VAL_3]] to %[[ARG2:.*]] : !fir.ref +// CHECK: fir.do_loop %[[VAL_4:.*]] = %[[VAL_0]] to %[[VAL_2]] step %[[VAL_1]] unordered { +// CHECK: %[[VAL_5:.*]] = fir.coordinate_of %[[ARG0:.*]], %[[VAL_4]] : (!fir.ref>, index) -> !fir.ref +// CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref +// CHECK: %[[VAL_7:.*]] = fir.coordinate_of %[[ARG1:.*]], %[[VAL_4]] : (!fir.ref>, index) -> !fir.ref +// CHECK: fir.store %[[VAL_6]] to %[[VAL_7]] : !fir.ref +// CHECK: } +// CHECK: fir.call @regular_side_effect_func(%[[ARG2:.*]]) : (!fir.ref) -> () +// CHECK: fir.call @my_fir_parallel_runtime_func(%[[ARG3:.*]]) : (!fir.ref) -> () +// CHECK: fir.do_loop %[[VAL_8:.*]] = %[[VAL_0]] to %[[VAL_2]] step %[[VAL_1]] 
{ +// CHECK: %[[VAL_9:.*]] = fir.coordinate_of %[[ARG0:.*]], %[[VAL_8]] : (!fir.ref>, index) -> !fir.ref +// CHECK: fir.store %[[VAL_3]] to %[[VAL_9]] : !fir.ref +// CHECK: } +// CHECK: %[[VAL_10:.*]] = fir.load %[[ARG2:.*]] : !fir.ref +// CHECK: fir.store %[[VAL_10]] to %[[ARG3:.*]] : !fir.ref +// CHECK: return +// CHECK: } +module { +func.func @regular_side_effect_func(%arg0: !fir.ref) { + return +} +func.func @my_fir_parallel_runtime_func(%arg0: !fir.ref) attributes {fir.runtime} { + return +} +func.func @test_fission_workdistribute(%arr1: !fir.ref>, %arr2: !fir.ref>, %scalar_ref1: !fir.ref, %scalar_ref2: !fir.ref) { + %c0_idx = arith.constant 0 : index + %c1_idx = arith.constant 1 : index + %c9_idx = arith.constant 9 : index + %float_val = arith.constant 5.0 : f32 + omp.teams { + omp.workdistribute { + fir.store %float_val to %scalar_ref1 : !fir.ref + fir.do_loop %iv = %c0_idx to %c9_idx step %c1_idx unordered { + %elem_ptr_arr1 = fir.coordinate_of %arr1, %iv : (!fir.ref>, index) -> !fir.ref + %loaded_val_loop1 = fir.load %elem_ptr_arr1 : !fir.ref + %elem_ptr_arr2 = fir.coordinate_of %arr2, %iv : (!fir.ref>, index) -> !fir.ref + fir.store %loaded_val_loop1 to %elem_ptr_arr2 : !fir.ref + } + fir.call @regular_side_effect_func(%scalar_ref1) : (!fir.ref) -> () + fir.call @my_fir_parallel_runtime_func(%scalar_ref2) : (!fir.ref) -> () + fir.do_loop %jv = %c0_idx to %c9_idx step %c1_idx { + %elem_ptr_ordered_loop = fir.coordinate_of %arr1, %jv : (!fir.ref>, index) -> !fir.ref + fir.store %float_val to %elem_ptr_ordered_loop : !fir.ref + } + %loaded_for_hoist = fir.load %scalar_ref1 : !fir.ref + fir.store %loaded_for_hoist to %scalar_ref2 : !fir.ref + omp.terminator + } + omp.terminator + } + return +} +} diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-to-single.mlir similarity index 99% rename from flang/test/Transforms/OpenMP/lower-workdistribute.mlir rename to flang/test/Transforms/OpenMP/lower-workdistribute-to-single.mlir index 34c8c3f01976d..0cc2aeded2532 100644 --- a/flang/test/Transforms/OpenMP/lower-workdistribute.mlir +++ b/flang/test/Transforms/OpenMP/lower-workdistribute-to-single.mlir @@ -49,4 +49,4 @@ func.func @_QPtarget_simple() { omp.terminator } return -} \ No newline at end of file +} From 7becfeede0bea64a137a5ef43e213213f3021462 Mon Sep 17 00:00:00 2001 From: skc7 Date: Sun, 18 May 2025 12:37:53 +0530 Subject: [PATCH 10/29] [OpenMP][Flang] Lower teams workdistribute do_loop to wsloop. Logic inspired from ivanradanov commit 5682e9ea7fcba64693f7cfdc0f1970fab2d7d4ae --- .../Optimizer/OpenMP/LowerWorkdistribute.cpp | 177 +++++++++++++++--- .../OpenMP/lower-workdistribute-doloop.mlir | 28 +++ .../OpenMP/lower-workdistribute-fission.mlir | 22 ++- 3 files changed, 193 insertions(+), 34 deletions(-) create mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp index f799202be2645..de208a8190650 100644 --- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp +++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp @@ -6,18 +6,22 @@ // //===----------------------------------------------------------------------===// // -// This file implements the lowering of omp.workdistribute. +// This file implements the lowering and optimisations of omp.workdistribute. 
// //===----------------------------------------------------------------------===// +#include "flang/Optimizer/Builder/FIRBuilder.h" #include "flang/Optimizer/Dialect/FIRDialect.h" #include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/Dialect/FIRType.h" #include "flang/Optimizer/Transforms/Passes.h" #include "flang/Optimizer/HLFIR/Passes.h" +#include "flang/Optimizer/OpenMP/Utils.h" +#include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Value.h" +#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include #include @@ -29,6 +33,7 @@ #include #include #include +#include "mlir/Transforms/RegionUtils.h" #include #include @@ -87,25 +92,6 @@ static bool shouldParallelize(Operation *op) { return false; } -struct WorkdistributeToSingle : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(omp::TeamsOp teamsOp, - PatternRewriter &rewriter) const override { - auto workdistributeOp = getPerfectlyNested(teamsOp); - if (!workdistributeOp) { - LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " No workdistribute nested\n"); - return failure(); - } - - Block *workdistributeBlock = &workdistributeOp.getRegion().front(); - rewriter.eraseOp(workdistributeBlock->getTerminator()); - rewriter.inlineBlockBefore(workdistributeBlock, teamsOp); - rewriter.eraseOp(teamsOp); - workdistributeOp.emitWarning("unable to parallelize coexecute"); - return success(); - } -}; - /// If B() and D() are parallelizable, /// /// omp.teams { @@ -210,22 +196,161 @@ struct FissionWorkdistribute } }; +static void +genLoopNestClauseOps(mlir::Location loc, + mlir::PatternRewriter &rewriter, + fir::DoLoopOp loop, + mlir::omp::LoopNestOperands &loopNestClauseOps) { + assert(loopNestClauseOps.loopLowerBounds.empty() && + "Loop nest bounds were already emitted!"); + loopNestClauseOps.loopLowerBounds.push_back(loop.getLowerBound()); + loopNestClauseOps.loopUpperBounds.push_back(loop.getUpperBound()); + loopNestClauseOps.loopSteps.push_back(loop.getStep()); + loopNestClauseOps.loopInclusive = rewriter.getUnitAttr(); +} + +static void +genWsLoopOp(mlir::PatternRewriter &rewriter, + fir::DoLoopOp doLoop, + const mlir::omp::LoopNestOperands &clauseOps) { + + auto wsloopOp = rewriter.create(doLoop.getLoc()); + rewriter.createBlock(&wsloopOp.getRegion()); + + auto loopNestOp = + rewriter.create(doLoop.getLoc(), clauseOps); + + // Clone the loop's body inside the loop nest construct using the + // mapped values. + rewriter.cloneRegionBefore(doLoop.getRegion(), loopNestOp.getRegion(), + loopNestOp.getRegion().begin()); + Block *clonedBlock = &loopNestOp.getRegion().back(); + mlir::Operation *terminatorOp = clonedBlock->getTerminator(); + + // Erase fir.result op of do loop and create yield op. + if (auto resultOp = dyn_cast(terminatorOp)) { + rewriter.setInsertionPoint(terminatorOp); + rewriter.create(doLoop->getLoc()); + rewriter.eraseOp(terminatorOp); + } + return; +} + +/// If fir.do_loop id present inside teams workdistribute +/// +/// omp.teams { +/// omp.workdistribute { +/// fir.do_loop unoredered { +/// ... +/// } +/// } +/// } +/// +/// Then, its lowered to +/// +/// omp.teams { +/// omp.workdistribute { +/// omp.parallel { +/// omp.wsloop { +/// omp.loop_nest +/// ... 
+/// } +/// } +/// } +/// } +/// } + +struct TeamsWorkdistributeLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(omp::TeamsOp teamsOp, + PatternRewriter &rewriter) const override { + auto teamsLoc = teamsOp->getLoc(); + auto workdistributeOp = getPerfectlyNested(teamsOp); + if (!workdistributeOp) { + LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " No workdistribute nested\n"); + return failure(); + } + assert(teamsOp.getReductionVars().empty()); + + auto doLoop = getPerfectlyNested(workdistributeOp); + if (doLoop && shouldParallelize(doLoop)) { + + auto parallelOp = rewriter.create(teamsLoc); + rewriter.createBlock(¶llelOp.getRegion()); + rewriter.setInsertionPoint(rewriter.create(doLoop.getLoc())); + + mlir::omp::LoopNestOperands loopNestClauseOps; + genLoopNestClauseOps(doLoop.getLoc(), rewriter, doLoop, + loopNestClauseOps); + + genWsLoopOp(rewriter, doLoop, loopNestClauseOps); + rewriter.setInsertionPoint(doLoop); + rewriter.eraseOp(doLoop); + return success(); + } + return failure(); + } +}; + + +/// If A() and B () are present inside teams workdistribute +/// +/// omp.teams { +/// omp.workdistribute { +/// A() +/// B() +/// } +/// } +/// +/// Then, its lowered to +/// +/// A() +/// B() +/// + +struct TeamsWorkdistributeToSingle : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(omp::TeamsOp teamsOp, + PatternRewriter &rewriter) const override { + auto workdistributeOp = getPerfectlyNested(teamsOp); + if (!workdistributeOp) { + LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " No workdistribute nested\n"); + return failure(); + } + Block *workdistributeBlock = &workdistributeOp.getRegion().front(); + rewriter.eraseOp(workdistributeBlock->getTerminator()); + rewriter.inlineBlockBefore(workdistributeBlock, teamsOp); + rewriter.eraseOp(teamsOp); + return success(); + } +}; + class LowerWorkdistributePass : public flangomp::impl::LowerWorkdistributeBase { public: void runOnOperation() override { MLIRContext &context = getContext(); - RewritePatternSet patterns(&context); GreedyRewriteConfig config; // prevent the pattern driver form merging blocks config.setRegionSimplificationLevel( GreedySimplifyRegionLevel::Disabled); - - patterns.insert(&context); + Operation *op = getOperation(); - if (failed(applyPatternsGreedily(op, std::move(patterns), config))) { - emitError(op->getLoc(), DEBUG_TYPE " pass failed\n"); - signalPassFailure(); + { + RewritePatternSet patterns(&context); + patterns.insert(&context); + if (failed(applyPatternsGreedily(op, std::move(patterns), config))) { + emitError(op->getLoc(), DEBUG_TYPE " pass failed\n"); + signalPassFailure(); + } + } + { + RewritePatternSet patterns(&context); + patterns.insert(&context); + if (failed(applyPatternsGreedily(op, std::move(patterns), config))) { + emitError(op->getLoc(), DEBUG_TYPE " pass failed\n"); + signalPassFailure(); + } } } }; diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir new file mode 100644 index 0000000000000..666bdb3ced647 --- /dev/null +++ b/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir @@ -0,0 +1,28 @@ +// RUN: fir-opt --lower-workdistribute %s | FileCheck %s + +// CHECK-LABEL: func.func @x({{.*}}) +// CHECK: %[[VAL_0:.*]] = arith.constant 0 : index +// CHECK: omp.parallel { +// CHECK: omp.wsloop { +// CHECK: omp.loop_nest (%[[VAL_1:.*]]) : index = (%[[ARG0:.*]]) to (%[[ARG1:.*]]) inclusive step 
(%[[ARG2:.*]]) { +// CHECK: fir.store %[[VAL_0]] to %[[ARG4:.*]] : !fir.ref +// CHECK: omp.yield +// CHECK: } +// CHECK: } +// CHECK: omp.terminator +// CHECK: } +// CHECK: return +// CHECK: } +func.func @x(%lb : index, %ub : index, %step : index, %b : i1, %addr : !fir.ref) { + omp.teams { + omp.workdistribute { + fir.do_loop %iv = %lb to %ub step %step unordered { + %zero = arith.constant 0 : index + fir.store %zero to %addr : !fir.ref + } + omp.terminator + } + omp.terminator + } + return +} \ No newline at end of file diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir index ea03a10dd3d44..cf50d135d01ec 100644 --- a/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir +++ b/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir @@ -6,20 +6,26 @@ // CHECK: %[[VAL_2:.*]] = arith.constant 9 : index // CHECK: %[[VAL_3:.*]] = arith.constant 5.000000e+00 : f32 // CHECK: fir.store %[[VAL_3]] to %[[ARG2:.*]] : !fir.ref -// CHECK: fir.do_loop %[[VAL_4:.*]] = %[[VAL_0]] to %[[VAL_2]] step %[[VAL_1]] unordered { -// CHECK: %[[VAL_5:.*]] = fir.coordinate_of %[[ARG0:.*]], %[[VAL_4]] : (!fir.ref>, index) -> !fir.ref -// CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref -// CHECK: %[[VAL_7:.*]] = fir.coordinate_of %[[ARG1:.*]], %[[VAL_4]] : (!fir.ref>, index) -> !fir.ref -// CHECK: fir.store %[[VAL_6]] to %[[VAL_7]] : !fir.ref +// CHECK: omp.parallel { +// CHECK: omp.wsloop { +// CHECK: omp.loop_nest (%[[VAL_4:.*]]) : index = (%[[VAL_0]]) to (%[[VAL_2]]) inclusive step (%[[VAL_1]]) { +// CHECK: %[[VAL_5:.*]] = fir.coordinate_of %[[ARG0:.*]], %[[VAL_4]] : (!fir.ref>, index) -> !fir.ref +// CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref +// CHECK: %[[VAL_7:.*]] = fir.coordinate_of %[[ARG1:.*]], %[[VAL_4]] : (!fir.ref>, index) -> !fir.ref +// CHECK: fir.store %[[VAL_6]] to %[[VAL_7]] : !fir.ref +// CHECK: omp.yield +// CHECK: } +// CHECK: } +// CHECK: omp.terminator // CHECK: } // CHECK: fir.call @regular_side_effect_func(%[[ARG2:.*]]) : (!fir.ref) -> () // CHECK: fir.call @my_fir_parallel_runtime_func(%[[ARG3:.*]]) : (!fir.ref) -> () // CHECK: fir.do_loop %[[VAL_8:.*]] = %[[VAL_0]] to %[[VAL_2]] step %[[VAL_1]] { -// CHECK: %[[VAL_9:.*]] = fir.coordinate_of %[[ARG0:.*]], %[[VAL_8]] : (!fir.ref>, index) -> !fir.ref +// CHECK: %[[VAL_9:.*]] = fir.coordinate_of %[[ARG0]], %[[VAL_8]] : (!fir.ref>, index) -> !fir.ref // CHECK: fir.store %[[VAL_3]] to %[[VAL_9]] : !fir.ref // CHECK: } -// CHECK: %[[VAL_10:.*]] = fir.load %[[ARG2:.*]] : !fir.ref -// CHECK: fir.store %[[VAL_10]] to %[[ARG3:.*]] : !fir.ref +// CHECK: %[[VAL_10:.*]] = fir.load %[[ARG2]] : !fir.ref +// CHECK: fir.store %[[VAL_10]] to %[[ARG3]] : !fir.ref // CHECK: return // CHECK: } module { From b23969d36dc149e3ef86474400908c6079a5f15a Mon Sep 17 00:00:00 2001 From: skc7 Date: Mon, 19 May 2025 15:33:53 +0530 Subject: [PATCH 11/29] clang format --- flang/lib/Lower/OpenMP/OpenMP.cpp | 18 +-- .../Optimizer/OpenMP/LowerWorkdistribute.cpp | 108 +++++++++--------- flang/lib/Parser/openmp-parsers.cpp | 6 +- .../OpenMP/lower-workdistribute-doloop.mlir | 2 +- 4 files changed, 67 insertions(+), 67 deletions(-) diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 0133dbf912659..a956498f6e521 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -2682,14 +2682,15 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, queue, item, clauseOps); } -static 
mlir::omp::WorkdistributeOp -genWorkdistributeOp(lower::AbstractConverter &converter, lower::SymMap &symTable, - semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, - mlir::Location loc, const ConstructQueue &queue, - ConstructQueue::const_iterator item) { +static mlir::omp::WorkdistributeOp genWorkdistributeOp( + lower::AbstractConverter &converter, lower::SymMap &symTable, + semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, + mlir::Location loc, const ConstructQueue &queue, + ConstructQueue::const_iterator item) { return genOpWithBody( - OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval, - llvm::omp::Directive::OMPD_workdistribute), queue, item); + OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval, + llvm::omp::Directive::OMPD_workdistribute), + queue, item); } //===----------------------------------------------------------------------===// @@ -3313,7 +3314,8 @@ static void genOMPDispatch(lower::AbstractConverter &converter, llvm::omp::getOpenMPDirectiveName(dir, version) + ")"); } case llvm::omp::Directive::OMPD_workdistribute: - newOp = genWorkdistributeOp(converter, symTable, semaCtx, eval, loc, queue, item); + newOp = genWorkdistributeOp(converter, symTable, semaCtx, eval, loc, queue, + item); break; case llvm::omp::Directive::OMPD_workshare: newOp = genWorkshareOp(converter, symTable, stmtCtx, semaCtx, eval, loc, diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp index de208a8190650..f75d4d1988fd2 100644 --- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp +++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp @@ -14,15 +14,16 @@ #include "flang/Optimizer/Dialect/FIRDialect.h" #include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/Dialect/FIRType.h" -#include "flang/Optimizer/Transforms/Passes.h" #include "flang/Optimizer/HLFIR/Passes.h" #include "flang/Optimizer/OpenMP/Utils.h" +#include "flang/Optimizer/Transforms/Passes.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Value.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/RegionUtils.h" #include #include #include @@ -33,7 +34,6 @@ #include #include #include -#include "mlir/Transforms/RegionUtils.h" #include #include @@ -66,30 +66,30 @@ static T getPerfectlyNested(Operation *op) { /// This is the single source of truth about whether we should parallelize an /// operation nested in an omp.workdistribute region. 
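/// Illustrative example (assumed IR, modelled on the tests in this patch set,
/// not copied from them):
///
///   fir.do_loop %i = %lb to %ub step %c1 unordered { ... }   // parallelized
///   fir.call @a_runtime_func(%x) : (!fir.ref) -> ()          // parallelized when the
///                                                            // callee carries fir.runtime
///   %v = fir.load %addr : !fir.ref                           // not parallelized: result
///                                                            // %v still has uses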
static bool shouldParallelize(Operation *op) { - // Currently we cannot parallelize operations with results that have uses - if (llvm::any_of(op->getResults(), - [](OpResult v) -> bool { return !v.use_empty(); })) + // Currently we cannot parallelize operations with results that have uses + if (llvm::any_of(op->getResults(), + [](OpResult v) -> bool { return !v.use_empty(); })) + return false; + // We will parallelize unordered loops - these come from array syntax + if (auto loop = dyn_cast(op)) { + auto unordered = loop.getUnordered(); + if (!unordered) return false; - // We will parallelize unordered loops - these come from array syntax - if (auto loop = dyn_cast(op)) { - auto unordered = loop.getUnordered(); - if (!unordered) - return false; - return *unordered; - } - if (auto callOp = dyn_cast(op)) { - auto callee = callOp.getCallee(); - if (!callee) - return false; - auto *func = op->getParentOfType().lookupSymbol(*callee); - // TODO need to insert a check here whether it is a call we can actually - // parallelize currently - if (func->getAttr(fir::FIROpsDialect::getFirRuntimeAttrName())) - return true; + return *unordered; + } + if (auto callOp = dyn_cast(op)) { + auto callee = callOp.getCallee(); + if (!callee) return false; - } - // We cannot parallise anything else + auto *func = op->getParentOfType().lookupSymbol(*callee); + // TODO need to insert a check here whether it is a call we can actually + // parallelize currently + if (func->getAttr(fir::FIROpsDialect::getFirRuntimeAttrName())) + return true; return false; + } + // We cannot parallise anything else + return false; } /// If B() and D() are parallelizable, @@ -120,12 +120,10 @@ static bool shouldParallelize(Operation *op) { /// } /// E() -struct FissionWorkdistribute - : public OpRewritePattern { +struct FissionWorkdistribute : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult - matchAndRewrite(omp::WorkdistributeOp workdistribute, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(omp::WorkdistributeOp workdistribute, + PatternRewriter &rewriter) const override { auto loc = workdistribute->getLoc(); auto teams = dyn_cast(workdistribute->getParentOp()); if (!teams) { @@ -185,7 +183,7 @@ struct FissionWorkdistribute auto newWorkdistribute = rewriter.create(loc); rewriter.create(loc); rewriter.createBlock(&newWorkdistribute.getRegion(), - newWorkdistribute.getRegion().begin(), {}, {}); + newWorkdistribute.getRegion().begin(), {}, {}); auto *cloned = rewriter.clone(*parallelize); rewriter.replaceOp(parallelize, cloned); rewriter.create(loc); @@ -197,8 +195,7 @@ struct FissionWorkdistribute }; static void -genLoopNestClauseOps(mlir::Location loc, - mlir::PatternRewriter &rewriter, +genLoopNestClauseOps(mlir::Location loc, mlir::PatternRewriter &rewriter, fir::DoLoopOp loop, mlir::omp::LoopNestOperands &loopNestClauseOps) { assert(loopNestClauseOps.loopLowerBounds.empty() && @@ -209,10 +206,8 @@ genLoopNestClauseOps(mlir::Location loc, loopNestClauseOps.loopInclusive = rewriter.getUnitAttr(); } -static void -genWsLoopOp(mlir::PatternRewriter &rewriter, - fir::DoLoopOp doLoop, - const mlir::omp::LoopNestOperands &clauseOps) { +static void genWsLoopOp(mlir::PatternRewriter &rewriter, fir::DoLoopOp doLoop, + const mlir::omp::LoopNestOperands &clauseOps) { auto wsloopOp = rewriter.create(doLoop.getLoc()); rewriter.createBlock(&wsloopOp.getRegion()); @@ -236,7 +231,7 @@ genWsLoopOp(mlir::PatternRewriter &rewriter, return; } -/// If fir.do_loop id present inside teams 
workdistribute +/// If fir.do_loop is present inside teams workdistribute /// /// omp.teams { /// omp.workdistribute { @@ -246,7 +241,7 @@ genWsLoopOp(mlir::PatternRewriter &rewriter, /// } /// } /// -/// Then, its lowered to +/// Then, its lowered to /// /// omp.teams { /// omp.workdistribute { @@ -277,7 +272,8 @@ struct TeamsWorkdistributeLowering : public OpRewritePattern { auto parallelOp = rewriter.create(teamsLoc); rewriter.createBlock(¶llelOp.getRegion()); - rewriter.setInsertionPoint(rewriter.create(doLoop.getLoc())); + rewriter.setInsertionPoint( + rewriter.create(doLoop.getLoc())); mlir::omp::LoopNestOperands loopNestClauseOps; genLoopNestClauseOps(doLoop.getLoc(), rewriter, doLoop, @@ -292,7 +288,6 @@ struct TeamsWorkdistributeLowering : public OpRewritePattern { } }; - /// If A() and B () are present inside teams workdistribute /// /// omp.teams { @@ -311,17 +306,17 @@ struct TeamsWorkdistributeLowering : public OpRewritePattern { struct TeamsWorkdistributeToSingle : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(omp::TeamsOp teamsOp, - PatternRewriter &rewriter) const override { - auto workdistributeOp = getPerfectlyNested(teamsOp); - if (!workdistributeOp) { - LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " No workdistribute nested\n"); - return failure(); - } - Block *workdistributeBlock = &workdistributeOp.getRegion().front(); - rewriter.eraseOp(workdistributeBlock->getTerminator()); - rewriter.inlineBlockBefore(workdistributeBlock, teamsOp); - rewriter.eraseOp(teamsOp); - return success(); + PatternRewriter &rewriter) const override { + auto workdistributeOp = getPerfectlyNested(teamsOp); + if (!workdistributeOp) { + LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " No workdistribute nested\n"); + return failure(); + } + Block *workdistributeBlock = &workdistributeOp.getRegion().front(); + rewriter.eraseOp(workdistributeBlock->getTerminator()); + rewriter.inlineBlockBefore(workdistributeBlock, teamsOp); + rewriter.eraseOp(teamsOp); + return success(); } }; @@ -332,13 +327,13 @@ class LowerWorkdistributePass MLIRContext &context = getContext(); GreedyRewriteConfig config; // prevent the pattern driver form merging blocks - config.setRegionSimplificationLevel( - GreedySimplifyRegionLevel::Disabled); - + config.setRegionSimplificationLevel(GreedySimplifyRegionLevel::Disabled); + Operation *op = getOperation(); { RewritePatternSet patterns(&context); - patterns.insert(&context); + patterns.insert( + &context); if (failed(applyPatternsGreedily(op, std::move(patterns), config))) { emitError(op->getLoc(), DEBUG_TYPE " pass failed\n"); signalPassFailure(); @@ -346,7 +341,8 @@ class LowerWorkdistributePass } { RewritePatternSet patterns(&context); - patterns.insert(&context); + patterns.insert( + &context); if (failed(applyPatternsGreedily(op, std::move(patterns), config))) { emitError(op->getLoc(), DEBUG_TYPE " pass failed\n"); signalPassFailure(); @@ -354,4 +350,4 @@ class LowerWorkdistributePass } } }; -} +} // namespace diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp index cd0eccffa43f5..ad729932a5f00 100644 --- a/flang/lib/Parser/openmp-parsers.cpp +++ b/flang/lib/Parser/openmp-parsers.cpp @@ -1492,12 +1492,14 @@ TYPE_PARSER( "SINGLE" >> pure(llvm::omp::Directive::OMPD_single), "TARGET DATA" >> pure(llvm::omp::Directive::OMPD_target_data), "TARGET PARALLEL" >> pure(llvm::omp::Directive::OMPD_target_parallel), - "TARGET TEAMS WORKDISTRIBUTE" >> 
pure(llvm::omp::Directive::OMPD_target_teams_workdistribute), + "TARGET TEAMS WORKDISTRIBUTE" >> + pure(llvm::omp::Directive::OMPD_target_teams_workdistribute), "TARGET TEAMS" >> pure(llvm::omp::Directive::OMPD_target_teams), "TARGET" >> pure(llvm::omp::Directive::OMPD_target), "TASK"_id >> pure(llvm::omp::Directive::OMPD_task), "TASKGROUP" >> pure(llvm::omp::Directive::OMPD_taskgroup), - "TEAMS WORKDISTRIBUTE" >> pure(llvm::omp::Directive::OMPD_teams_workdistribute), + "TEAMS WORKDISTRIBUTE" >> + pure(llvm::omp::Directive::OMPD_teams_workdistribute), "TEAMS" >> pure(llvm::omp::Directive::OMPD_teams), "WORKSHARE" >> pure(llvm::omp::Directive::OMPD_workshare), "WORKDISTRIBUTE" >> pure(llvm::omp::Directive::OMPD_workdistribute)))) diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir index 666bdb3ced647..9fb970246b90c 100644 --- a/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir +++ b/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir @@ -25,4 +25,4 @@ func.func @x(%lb : index, %ub : index, %step : index, %b : i1, %addr : !fir.ref< omp.terminator } return -} \ No newline at end of file +} From 44709698f5baea865de0c197d06eabcf36bad66e Mon Sep 17 00:00:00 2001 From: skc7 Date: Tue, 27 May 2025 16:24:26 +0530 Subject: [PATCH 12/29] update to workdistribute lowering --- .../Optimizer/OpenMP/LowerWorkdistribute.cpp | 194 ++++++++++-------- .../OpenMP/lower-workdistribute-doloop.mlir | 19 +- .../OpenMP/lower-workdistribute-fission.mlir | 31 +-- 3 files changed, 139 insertions(+), 105 deletions(-) diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp index f75d4d1988fd2..c9c7827ace217 100644 --- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp +++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp @@ -48,25 +48,21 @@ using namespace mlir; namespace { -template -static T getPerfectlyNested(Operation *op) { - if (op->getNumRegions() != 1) - return nullptr; - auto ®ion = op->getRegion(0); - if (region.getBlocks().size() != 1) - return nullptr; - auto *block = ®ion.front(); - auto *firstOp = &block->front(); - if (auto nested = dyn_cast(firstOp)) - if (firstOp->getNextNode() == block->getTerminator()) - return nested; - return nullptr; +static bool isRuntimeCall(Operation *op) { + if (auto callOp = dyn_cast(op)) { + auto callee = callOp.getCallee(); + if (!callee) + return false; + auto *func = op->getParentOfType().lookupSymbol(*callee); + if (func->getAttr(fir::FIROpsDialect::getFirRuntimeAttrName())) + return true; + } + return false; } /// This is the single source of truth about whether we should parallelize an -/// operation nested in an omp.workdistribute region. +/// operation nested in an omp.execute region. 
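/// The isRuntimeCall helper above keys off the callee's fir.runtime attribute;
/// for instance, a callee declared the way the fission test does,
///
///   func.func @my_fir_parallel_runtime_func(%arg0: !fir.ref) attributes {fir.runtime}
///
/// makes a fir.call to it parallelizable here (illustrative, not an exhaustive
/// rule).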
static bool shouldParallelize(Operation *op) { - // Currently we cannot parallelize operations with results that have uses if (llvm::any_of(op->getResults(), [](OpResult v) -> bool { return !v.use_empty(); })) return false; @@ -77,21 +73,28 @@ static bool shouldParallelize(Operation *op) { return false; return *unordered; } - if (auto callOp = dyn_cast(op)) { - auto callee = callOp.getCallee(); - if (!callee) - return false; - auto *func = op->getParentOfType().lookupSymbol(*callee); - // TODO need to insert a check here whether it is a call we can actually - // parallelize currently - if (func->getAttr(fir::FIROpsDialect::getFirRuntimeAttrName())) - return true; - return false; + if (isRuntimeCall(op)) { + return true; } // We cannot parallise anything else return false; } +template +static T getPerfectlyNested(Operation *op) { + if (op->getNumRegions() != 1) + return nullptr; + auto ®ion = op->getRegion(0); + if (region.getBlocks().size() != 1) + return nullptr; + auto *block = ®ion.front(); + auto *firstOp = &block->front(); + if (auto nested = dyn_cast(firstOp)) + if (firstOp->getNextNode() == block->getTerminator()) + return nested; + return nullptr; +} + /// If B() and D() are parallelizable, /// /// omp.teams { @@ -138,17 +141,33 @@ struct FissionWorkdistribute : public OpRewritePattern { emitError(loc, "teams with multiple blocks\n"); return failure(); } - if (teams.getRegion().getBlocks().front().getOperations().size() != 2) { - emitError(loc, "teams with multiple nested ops\n"); - return failure(); - } auto *teamsBlock = &teams.getRegion().front(); + bool changed = false; + // Move the ops inside teams and before workdistribute outside. + IRMapping irMapping; + llvm::SmallVector teamsHoisted; + for (auto &op : teams.getOps()) { + if (&op == workdistribute) { + break; + } + if (shouldParallelize(&op)) { + emitError(loc, + "teams has parallelize ops before first workdistribute\n"); + return failure(); + } else { + rewriter.setInsertionPoint(teams); + rewriter.clone(op, irMapping); + teamsHoisted.push_back(&op); + changed = true; + } + } + for (auto *op : teamsHoisted) + rewriter.replaceOp(op, irMapping.lookup(op)); // While we have unhandled operations in the original workdistribute auto *workdistributeBlock = &workdistribute.getRegion().front(); auto *terminator = workdistributeBlock->getTerminator(); - bool changed = false; while (&workdistributeBlock->front() != terminator) { rewriter.setInsertionPoint(teams); IRMapping mapping; @@ -194,9 +213,51 @@ struct FissionWorkdistribute : public OpRewritePattern { } }; +/// If fir.do_loop is present inside teams workdistribute +/// +/// omp.teams { +/// omp.workdistribute { +/// fir.do_loop unoredered { +/// ... +/// } +/// } +/// } +/// +/// Then, its lowered to +/// +/// omp.teams { +/// omp.parallel { +/// omp.distribute { +/// omp.wsloop { +/// omp.loop_nest +/// ... 
+/// } +/// } +/// } +/// } + +static void genParallelOp(Location loc, PatternRewriter &rewriter, + bool composite) { + auto parallelOp = rewriter.create(loc); + parallelOp.setComposite(composite); + rewriter.createBlock(¶llelOp.getRegion()); + rewriter.setInsertionPoint(rewriter.create(loc)); + return; +} + +static void genDistributeOp(Location loc, PatternRewriter &rewriter, + bool composite) { + mlir::omp::DistributeOperands distributeClauseOps; + auto distributeOp = + rewriter.create(loc, distributeClauseOps); + distributeOp.setComposite(composite); + auto distributeBlock = rewriter.createBlock(&distributeOp.getRegion()); + rewriter.setInsertionPointToStart(distributeBlock); + return; +} + static void -genLoopNestClauseOps(mlir::Location loc, mlir::PatternRewriter &rewriter, - fir::DoLoopOp loop, +genLoopNestClauseOps(mlir::PatternRewriter &rewriter, fir::DoLoopOp loop, mlir::omp::LoopNestOperands &loopNestClauseOps) { assert(loopNestClauseOps.loopLowerBounds.empty() && "Loop nest bounds were already emitted!"); @@ -207,9 +268,11 @@ genLoopNestClauseOps(mlir::Location loc, mlir::PatternRewriter &rewriter, } static void genWsLoopOp(mlir::PatternRewriter &rewriter, fir::DoLoopOp doLoop, - const mlir::omp::LoopNestOperands &clauseOps) { + const mlir::omp::LoopNestOperands &clauseOps, + bool composite) { auto wsloopOp = rewriter.create(doLoop.getLoc()); + wsloopOp.setComposite(composite); rewriter.createBlock(&wsloopOp.getRegion()); auto loopNestOp = @@ -231,57 +294,20 @@ static void genWsLoopOp(mlir::PatternRewriter &rewriter, fir::DoLoopOp doLoop, return; } -/// If fir.do_loop is present inside teams workdistribute -/// -/// omp.teams { -/// omp.workdistribute { -/// fir.do_loop unoredered { -/// ... -/// } -/// } -/// } -/// -/// Then, its lowered to -/// -/// omp.teams { -/// omp.workdistribute { -/// omp.parallel { -/// omp.wsloop { -/// omp.loop_nest -/// ... 
-/// } -/// } -/// } -/// } -/// } - -struct TeamsWorkdistributeLowering : public OpRewritePattern { +struct WorkdistributeDoLower : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(omp::TeamsOp teamsOp, + LogicalResult matchAndRewrite(omp::WorkdistributeOp workdistribute, PatternRewriter &rewriter) const override { - auto teamsLoc = teamsOp->getLoc(); - auto workdistributeOp = getPerfectlyNested(teamsOp); - if (!workdistributeOp) { - LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " No workdistribute nested\n"); - return failure(); - } - assert(teamsOp.getReductionVars().empty()); - - auto doLoop = getPerfectlyNested(workdistributeOp); + auto doLoop = getPerfectlyNested(workdistribute); + auto wdLoc = workdistribute->getLoc(); if (doLoop && shouldParallelize(doLoop)) { - - auto parallelOp = rewriter.create(teamsLoc); - rewriter.createBlock(¶llelOp.getRegion()); - rewriter.setInsertionPoint( - rewriter.create(doLoop.getLoc())); - + assert(doLoop.getReduceOperands().empty()); + genParallelOp(wdLoc, rewriter, true); + genDistributeOp(wdLoc, rewriter, true); mlir::omp::LoopNestOperands loopNestClauseOps; - genLoopNestClauseOps(doLoop.getLoc(), rewriter, doLoop, - loopNestClauseOps); - - genWsLoopOp(rewriter, doLoop, loopNestClauseOps); - rewriter.setInsertionPoint(doLoop); - rewriter.eraseOp(doLoop); + genLoopNestClauseOps(rewriter, doLoop, loopNestClauseOps); + genWsLoopOp(rewriter, doLoop, loopNestClauseOps, true); + rewriter.eraseOp(workdistribute); return success(); } return failure(); @@ -315,7 +341,7 @@ struct TeamsWorkdistributeToSingle : public OpRewritePattern { Block *workdistributeBlock = &workdistributeOp.getRegion().front(); rewriter.eraseOp(workdistributeBlock->getTerminator()); rewriter.inlineBlockBefore(workdistributeBlock, teamsOp); - rewriter.eraseOp(teamsOp); + rewriter.eraseOp(workdistributeOp); return success(); } }; @@ -332,8 +358,7 @@ class LowerWorkdistributePass Operation *op = getOperation(); { RewritePatternSet patterns(&context); - patterns.insert( - &context); + patterns.insert(&context); if (failed(applyPatternsGreedily(op, std::move(patterns), config))) { emitError(op->getLoc(), DEBUG_TYPE " pass failed\n"); signalPassFailure(); @@ -341,8 +366,7 @@ class LowerWorkdistributePass } { RewritePatternSet patterns(&context); - patterns.insert( - &context); + patterns.insert(&context); if (failed(applyPatternsGreedily(op, std::move(patterns), config))) { emitError(op->getLoc(), DEBUG_TYPE " pass failed\n"); signalPassFailure(); diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir index 9fb970246b90c..f8351bb64e6e8 100644 --- a/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir +++ b/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir @@ -2,13 +2,18 @@ // CHECK-LABEL: func.func @x({{.*}}) // CHECK: %[[VAL_0:.*]] = arith.constant 0 : index -// CHECK: omp.parallel { -// CHECK: omp.wsloop { -// CHECK: omp.loop_nest (%[[VAL_1:.*]]) : index = (%[[ARG0:.*]]) to (%[[ARG1:.*]]) inclusive step (%[[ARG2:.*]]) { -// CHECK: fir.store %[[VAL_0]] to %[[ARG4:.*]] : !fir.ref -// CHECK: omp.yield -// CHECK: } -// CHECK: } +// CHECK: omp.teams { +// CHECK: omp.parallel { +// CHECK: omp.distribute { +// CHECK: omp.wsloop { +// CHECK: omp.loop_nest (%[[VAL_1:.*]]) : index = (%[[ARG0:.*]]) to (%[[ARG1:.*]]) inclusive step (%[[ARG2:.*]]) { +// CHECK: fir.store %[[VAL_0]] to %[[ARG4:.*]] : !fir.ref +// CHECK: omp.yield +// CHECK: } +// CHECK: } 
{omp.composite} +// CHECK: } {omp.composite} +// CHECK: omp.terminator +// CHECK: } {omp.composite} // CHECK: omp.terminator // CHECK: } // CHECK: return diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir index cf50d135d01ec..c562b7009664d 100644 --- a/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir +++ b/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir @@ -1,21 +1,26 @@ // RUN: fir-opt --lower-workdistribute %s | FileCheck %s -// CHECK-LABEL: func.func @test_fission_workdistribute({{.*}}) { +// CHECK-LABEL: func.func @test_fission_workdistribute( // CHECK: %[[VAL_0:.*]] = arith.constant 0 : index // CHECK: %[[VAL_1:.*]] = arith.constant 1 : index // CHECK: %[[VAL_2:.*]] = arith.constant 9 : index // CHECK: %[[VAL_3:.*]] = arith.constant 5.000000e+00 : f32 // CHECK: fir.store %[[VAL_3]] to %[[ARG2:.*]] : !fir.ref -// CHECK: omp.parallel { -// CHECK: omp.wsloop { -// CHECK: omp.loop_nest (%[[VAL_4:.*]]) : index = (%[[VAL_0]]) to (%[[VAL_2]]) inclusive step (%[[VAL_1]]) { -// CHECK: %[[VAL_5:.*]] = fir.coordinate_of %[[ARG0:.*]], %[[VAL_4]] : (!fir.ref>, index) -> !fir.ref -// CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref -// CHECK: %[[VAL_7:.*]] = fir.coordinate_of %[[ARG1:.*]], %[[VAL_4]] : (!fir.ref>, index) -> !fir.ref -// CHECK: fir.store %[[VAL_6]] to %[[VAL_7]] : !fir.ref -// CHECK: omp.yield -// CHECK: } -// CHECK: } +// CHECK: omp.teams { +// CHECK: omp.parallel { +// CHECK: omp.distribute { +// CHECK: omp.wsloop { +// CHECK: omp.loop_nest (%[[VAL_4:.*]]) : index = (%[[VAL_0]]) to (%[[VAL_2]]) inclusive step (%[[VAL_1]]) { +// CHECK: %[[VAL_5:.*]] = fir.coordinate_of %[[ARG0:.*]], %[[VAL_4]] : (!fir.ref>, index) -> !fir.ref +// CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref +// CHECK: %[[VAL_7:.*]] = fir.coordinate_of %[[ARG1:.*]], %[[VAL_4]] : (!fir.ref>, index) -> !fir.ref +// CHECK: fir.store %[[VAL_6]] to %[[VAL_7]] : !fir.ref +// CHECK: omp.yield +// CHECK: } +// CHECK: } {omp.composite} +// CHECK: } {omp.composite} +// CHECK: omp.terminator +// CHECK: } {omp.composite} // CHECK: omp.terminator // CHECK: } // CHECK: fir.call @regular_side_effect_func(%[[ARG2:.*]]) : (!fir.ref) -> () @@ -24,8 +29,8 @@ // CHECK: %[[VAL_9:.*]] = fir.coordinate_of %[[ARG0]], %[[VAL_8]] : (!fir.ref>, index) -> !fir.ref // CHECK: fir.store %[[VAL_3]] to %[[VAL_9]] : !fir.ref // CHECK: } -// CHECK: %[[VAL_10:.*]] = fir.load %[[ARG2]] : !fir.ref -// CHECK: fir.store %[[VAL_10]] to %[[ARG3]] : !fir.ref +// CHECK: %[[VAL_10:.*]] = fir.load %[[ARG2:.*]] : !fir.ref +// CHECK: fir.store %[[VAL_10]] to %[[ARG3:.*]] : !fir.ref // CHECK: return // CHECK: } module { From 0166ceb0ac9c4b651ac2e969eb55224e83a02ec1 Mon Sep 17 00:00:00 2001 From: skc7 Date: Wed, 28 May 2025 21:41:25 +0530 Subject: [PATCH 13/29] Fix basic-program.fir test. 
--- flang/test/Fir/basic-program.fir | 1 + 1 file changed, 1 insertion(+) diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir index 7ac8b92f48953..a611629eeb280 100644 --- a/flang/test/Fir/basic-program.fir +++ b/flang/test/Fir/basic-program.fir @@ -69,6 +69,7 @@ func.func @_QQmain() { // PASSES-NEXT: InlineHLFIRAssign // PASSES-NEXT: ConvertHLFIRtoFIR // PASSES-NEXT: LowerWorkshare +// PASSES-NEXT: LowerWorkdistribute // PASSES-NEXT: CSE // PASSES-NEXT: (S) 0 num-cse'd - Number of operations CSE'd // PASSES-NEXT: (S) 0 num-dce'd - Number of operations DCE'd From 7f8b3a29be467d998aa0591c77a75c68659f63db Mon Sep 17 00:00:00 2001 From: skc7 Date: Fri, 30 May 2025 15:12:46 +0530 Subject: [PATCH 14/29] Wrap omp.target with omp.target_data --- .../Optimizer/OpenMP/LowerWorkdistribute.cpp | 88 +++++++++++++++++++ .../OpenMP/lower-workdistribute-target.mlir | 36 ++++++++ .../lower-workdistribute-to-single.mlir | 52 ----------- 3 files changed, 124 insertions(+), 52 deletions(-) create mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir delete mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute-to-single.mlir diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp index c9c7827ace217..6509cc5014dd7 100644 --- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp +++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp @@ -346,6 +346,85 @@ struct TeamsWorkdistributeToSingle : public OpRewritePattern { } }; +static std::optional> +getNestedOpToIsolate(omp::TargetOp targetOp) { + auto *targetBlock = &targetOp.getRegion().front(); + for (auto &op : *targetBlock) { + bool first = &op == &*targetBlock->begin(); + bool last = op.getNextNode() == targetBlock->getTerminator(); + if (first && last) + return std::nullopt; + + if (isa(&op)) + return {{&op, first, last}}; + } + return std::nullopt; +} + +struct SplitTargetResult { + omp::TargetOp targetOp; + omp::TargetDataOp dataOp; +}; + +/// If multiple coexecutes are nested in a target regions, we will need to split +/// the target region, but we want to preserve the data semantics of the +/// original data region and avoid unnecessary data movement at each of the +/// subkernels - we split the target region into a target_data{target} +/// nest where only the outer one moves the data +std::optional splitTargetData(omp::TargetOp targetOp, + RewriterBase &rewriter) { + + auto loc = targetOp->getLoc(); + if (targetOp.getMapVars().empty()) { + LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " target region has no data maps\n"); + return std::nullopt; + } + + // Collect all map_entries with capture(ByRef) + SmallVector byRefMapInfos; + SmallVector MapInfos; + for (auto opr : targetOp.getMapVars()) { + auto mapInfo = cast(opr.getDefiningOp()); + MapInfos.push_back(mapInfo); + if (mapInfo.getMapCaptureType() == omp::VariableCaptureKind::ByRef) + byRefMapInfos.push_back(opr); + } + + // Create the new omp.target_data op with these collected map_entries + auto targetLoc = targetOp.getLoc(); + rewriter.setInsertionPoint(targetOp); + auto device = targetOp.getDevice(); + auto ifExpr = targetOp.getIfExpr(); + auto deviceAddrVars = targetOp.getHasDeviceAddrVars(); + auto devicePtrVars = targetOp.getIsDevicePtrVars(); + auto targetDataOp = rewriter.create(loc, device, ifExpr, + mlir::ValueRange{byRefMapInfos}, + deviceAddrVars, + devicePtrVars); + + auto taregtDataBlock = rewriter.createBlock(&targetDataOp.getRegion()); + rewriter.create(loc); + 
rewriter.setInsertionPointToStart(taregtDataBlock); + + // Clone mapInfo ops inside omp.target_data region + IRMapping mapping; + for (auto mapInfo : MapInfos) { + rewriter.clone(*mapInfo, mapping); + } + // Clone omp.target from exisiting targetOp inside target_data region. + auto newTargetOp = rewriter.clone(*targetOp, mapping); + + // Erase TargetOp and its MapInfoOps + rewriter.eraseOp(targetOp); + + for (auto mapInfo : MapInfos) { + auto mapInfoRes = mapInfo.getResult(); + if (mapInfoRes.getUsers().empty()) + rewriter.eraseOp(mapInfo); + } + return SplitTargetResult{targetOp, targetDataOp}; +} + class LowerWorkdistributePass : public flangomp::impl::LowerWorkdistributeBase { public: @@ -372,6 +451,15 @@ class LowerWorkdistributePass signalPassFailure(); } } + { + SmallVector targetOps; + op->walk([&](omp::TargetOp targetOp) { targetOps.push_back(targetOp); }); + IRRewriter rewriter(&context); + for (auto targetOp : targetOps) { + auto res = splitTargetData(targetOp, rewriter); + } + } + } }; } // namespace diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir new file mode 100644 index 0000000000000..e6ca98d3bf596 --- /dev/null +++ b/flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir @@ -0,0 +1,36 @@ +// RUN: fir-opt --lower-workdistribute %s | FileCheck %s + +// CHECK-LABEL: func.func @test_nested_derived_type_map_operand_and_block_addition( +// CHECK-SAME: %[[ARG0:.*]]: !fir.ref}>>) { +// CHECK: %[[VAL_0:.*]] = fir.declare %[[ARG0]] {uniq_name = "_QFmaptype_derived_nested_explicit_multiple_membersEsa"} : (!fir.ref}>>) -> !fir.ref}>> +// CHECK: %[[VAL_1:.*]] = fir.coordinate_of %[[VAL_0]], n : (!fir.ref}>>) -> !fir.ref> +// CHECK: %[[VAL_2:.*]] = fir.coordinate_of %[[VAL_1]], i : (!fir.ref>) -> !fir.ref +// CHECK: %[[VAL_3:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref {name = "sa%[[VAL_4:.*]]%[[VAL_5:.*]]"} +// CHECK: %[[VAL_6:.*]] = fir.coordinate_of %[[VAL_0]], n : (!fir.ref}>>) -> !fir.ref> +// CHECK: %[[VAL_7:.*]] = fir.coordinate_of %[[VAL_6]], r : (!fir.ref>) -> !fir.ref +// CHECK: %[[VAL_8:.*]] = omp.map.info var_ptr(%[[VAL_7]] : !fir.ref, f32) map_clauses(tofrom) capture(ByRef) -> !fir.ref {name = "sa%[[VAL_4]]%[[VAL_9:.*]]"} +// CHECK: %[[VAL_10:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref}>>, !fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>) map_clauses(tofrom) capture(ByRef) members(%[[VAL_3]], %[[VAL_8]] : [1, 0], [1, 1] : !fir.ref, !fir.ref) -> !fir.ref}>> {name = "sa", partial_map = true} +// CHECK: omp.target_data map_entries(%[[VAL_10]] : !fir.ref}>>) { +// CHECK: %[[VAL_11:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref}>>, !fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>) map_clauses(tofrom) capture(ByRef) members(%[[VAL_3]], %[[VAL_8]] : [1, 0], [1, 1] : !fir.ref, !fir.ref) -> !fir.ref}>> {name = "sa", partial_map = true} +// CHECK: omp.target map_entries(%[[VAL_11]] -> %[[VAL_12:.*]] : !fir.ref}>>) { +// CHECK: omp.terminator +// CHECK: } +// CHECK: omp.terminator +// CHECK: } +// CHECK: return +// CHECK: } + +func.func @test_nested_derived_type_map_operand_and_block_addition(%arg0: !fir.ref}>>) { + %0 = fir.declare %arg0 {uniq_name = 
"_QFmaptype_derived_nested_explicit_multiple_membersEsa"} : (!fir.ref}>>) -> !fir.ref}>> + %2 = fir.coordinate_of %0, n : (!fir.ref}>>) -> !fir.ref> + %4 = fir.coordinate_of %2, i : (!fir.ref>) -> !fir.ref + %5 = omp.map.info var_ptr(%4 : !fir.ref, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref {name = "sa%n%i"} + %7 = fir.coordinate_of %0, n : (!fir.ref}>>) -> !fir.ref> + %9 = fir.coordinate_of %7, r : (!fir.ref>) -> !fir.ref + %10 = omp.map.info var_ptr(%9 : !fir.ref, f32) map_clauses(tofrom) capture(ByRef) -> !fir.ref {name = "sa%n%r"} + %11 = omp.map.info var_ptr(%0 : !fir.ref}>>, !fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>) map_clauses(tofrom) capture(ByRef) members(%5, %10 : [1,0], [1,1] : !fir.ref, !fir.ref) -> !fir.ref}>> {name = "sa", partial_map = true} + omp.target map_entries(%11 -> %arg1 : !fir.ref}>>) { + omp.terminator + } + return +} diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-to-single.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-to-single.mlir deleted file mode 100644 index 0cc2aeded2532..0000000000000 --- a/flang/test/Transforms/OpenMP/lower-workdistribute-to-single.mlir +++ /dev/null @@ -1,52 +0,0 @@ -// RUN: fir-opt --lower-workdistribute %s | FileCheck %s - -// CHECK-LABEL: func.func @_QPtarget_simple() { -// CHECK: %[[VAL_0:.*]] = arith.constant 2 : i32 -// CHECK: %[[VAL_1:.*]] = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFtarget_simpleEa"} -// CHECK: %[[VAL_2:.*]]:2 = hlfir.declare %[[VAL_1]] {uniq_name = "_QFtarget_simpleEa"} : (!fir.ref) -> (!fir.ref, !fir.ref) -// CHECK: %[[VAL_3:.*]] = fir.alloca !fir.box> {bindc_name = "simple_var", uniq_name = "_QFtarget_simpleEsimple_var"} -// CHECK: %[[VAL_4:.*]] = fir.zero_bits !fir.heap -// CHECK: %[[VAL_5:.*]] = fir.embox %[[VAL_4]] : (!fir.heap) -> !fir.box> -// CHECK: fir.store %[[VAL_5]] to %[[VAL_3]] : !fir.ref>> -// CHECK: %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_3]] {fortran_attrs = #fir.var_attrs, uniq_name = "_QFtarget_simpleEsimple_var"} : (!fir.ref>>) -> (!fir.ref>>, !fir.ref>>) -// CHECK: hlfir.assign %[[VAL_0]] to %[[VAL_2]]#0 : i32, !fir.ref -// CHECK: %[[VAL_7:.*]] = omp.map.info var_ptr(%[[VAL_2]]#1 : !fir.ref, i32) map_clauses(to) capture(ByRef) -> !fir.ref {name = "a"} -// CHECK: omp.target map_entries(%[[VAL_7]] -> %[[VAL_8:.*]] : !fir.ref) private(@_QFtarget_simpleEsimple_var_private_ref_box_heap_i32 %[[VAL_6]]#0 -> %[[VAL_9:.*]] : !fir.ref>>) { -// CHECK: %[[VAL_10:.*]] = arith.constant 10 : i32 -// CHECK: %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_8]] {uniq_name = "_QFtarget_simpleEa"} : (!fir.ref) -> (!fir.ref, !fir.ref) -// CHECK: %[[VAL_12:.*]]:2 = hlfir.declare %[[VAL_9]] {fortran_attrs = #fir.var_attrs, uniq_name = "_QFtarget_simpleEsimple_var"} : (!fir.ref>>) -> (!fir.ref>>, !fir.ref>>) -// CHECK: %[[VAL_13:.*]] = fir.load %[[VAL_11]]#0 : !fir.ref -// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_10]] : i32 -// CHECK: hlfir.assign %[[VAL_14]] to %[[VAL_12]]#0 realloc : i32, !fir.ref>> -// CHECK: omp.terminator -// CHECK: } -// CHECK: return -// CHECK: } -func.func @_QPtarget_simple() { - %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFtarget_simpleEa"} - %1:2 = hlfir.declare %0 {uniq_name = "_QFtarget_simpleEa"} : (!fir.ref) -> (!fir.ref, !fir.ref) - %2 = fir.alloca !fir.box> {bindc_name = "simple_var", uniq_name = "_QFtarget_simpleEsimple_var"} - %3 = fir.zero_bits !fir.heap - %4 = fir.embox %3 : (!fir.heap) -> 
!fir.box> - fir.store %4 to %2 : !fir.ref>> - %5:2 = hlfir.declare %2 {fortran_attrs = #fir.var_attrs, uniq_name = "_QFtarget_simpleEsimple_var"} : (!fir.ref>>) -> (!fir.ref>>, !fir.ref>>) - %c2_i32 = arith.constant 2 : i32 - hlfir.assign %c2_i32 to %1#0 : i32, !fir.ref - %6 = omp.map.info var_ptr(%1#1 : !fir.ref, i32) map_clauses(to) capture(ByRef) -> !fir.ref {name = "a"} - omp.target map_entries(%6 -> %arg0 : !fir.ref) private(@_QFtarget_simpleEsimple_var_private_ref_box_heap_i32 %5#0 -> %arg1 : !fir.ref>>){ - omp.teams { - omp.workdistribute { - %11:2 = hlfir.declare %arg0 {uniq_name = "_QFtarget_simpleEa"} : (!fir.ref) -> (!fir.ref, !fir.ref) - %12:2 = hlfir.declare %arg1 {fortran_attrs = #fir.var_attrs, uniq_name = "_QFtarget_simpleEsimple_var"} : (!fir.ref>>) -> (!fir.ref>>, !fir.ref>>) - %c10_i32 = arith.constant 10 : i32 - %13 = fir.load %11#0 : !fir.ref - %14 = arith.addi %c10_i32, %13 : i32 - hlfir.assign %14 to %12#0 realloc : i32, !fir.ref>> - omp.terminator - } - omp.terminator - } - omp.terminator - } - return -} From c36df1081b6eb222d941a54e8b958bcae144e352 Mon Sep 17 00:00:00 2001 From: skc7 Date: Tue, 3 Jun 2025 15:47:08 +0530 Subject: [PATCH 15/29] Add fission of target region Logic inspired from ivanradanov llvm branch: flang_workdistribute_iwomp_2024 commit: a77451505dbd728a7a339f6c7c4c1382c709c502 --- .../Optimizer/OpenMP/LowerWorkdistribute.cpp | 437 +++++++++++++++++- .../lower-workdistribute-fission-target.mlir | 104 +++++ 2 files changed, 521 insertions(+), 20 deletions(-) create mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp index 6509cc5014dd7..8f2de92cfd186 100644 --- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp +++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp @@ -34,6 +34,7 @@ #include #include #include +#include "llvm/Frontend/OpenMP/OMPConstants.h" #include #include @@ -346,21 +347,6 @@ struct TeamsWorkdistributeToSingle : public OpRewritePattern { } }; -static std::optional> -getNestedOpToIsolate(omp::TargetOp targetOp) { - auto *targetBlock = &targetOp.getRegion().front(); - for (auto &op : *targetBlock) { - bool first = &op == &*targetBlock->begin(); - bool last = op.getNextNode() == targetBlock->getTerminator(); - if (first && last) - return std::nullopt; - - if (isa(&op)) - return {{&op, first, last}}; - } - return std::nullopt; -} - struct SplitTargetResult { omp::TargetOp targetOp; omp::TargetDataOp dataOp; @@ -371,8 +357,7 @@ struct SplitTargetResult { /// original data region and avoid unnecessary data movement at each of the /// subkernels - we split the target region into a target_data{target} /// nest where only the outer one moves the data -std::optional splitTargetData(omp::TargetOp targetOp, - RewriterBase &rewriter) { +std::optional splitTargetData(omp::TargetOp targetOp, RewriterBase &rewriter) { auto loc = targetOp->getLoc(); if (targetOp.getMapVars().empty()) { @@ -391,7 +376,6 @@ std::optional splitTargetData(omp::TargetOp targetOp, } // Create the new omp.target_data op with these collected map_entries - auto targetLoc = targetOp.getLoc(); rewriter.setInsertionPoint(targetOp); auto device = targetOp.getDevice(); auto ifExpr = targetOp.getIfExpr(); @@ -422,8 +406,420 @@ std::optional splitTargetData(omp::TargetOp targetOp, if (mapInfoRes.getUsers().empty()) rewriter.eraseOp(mapInfo); } - return SplitTargetResult{targetOp, targetDataOp}; -} + return 
SplitTargetResult{cast(newTargetOp), targetDataOp}; +} + +static std::optional> +getNestedOpToIsolate(omp::TargetOp targetOp) { + if (targetOp.getRegion().empty()) + return std::nullopt; + auto *targetBlock = &targetOp.getRegion().front(); + for (auto &op : *targetBlock) { + bool first = &op == &*targetBlock->begin(); + bool last = op.getNextNode() == targetBlock->getTerminator(); + if (first && last) + return std::nullopt; + + if (isa(&op)) + return {{&op, first, last}}; + } + return std::nullopt; +} + +struct TempOmpVar { + omp::MapInfoOp from, to; +}; + +static bool isPtr(Type ty) { + return isa(ty) || isa(ty); +} + +static Type getPtrTypeForOmp(Type ty) { + if (isPtr(ty)) + return LLVM::LLVMPointerType::get(ty.getContext()); + else + return fir::LLVMPointerType::get(ty); +} + +static TempOmpVar +allocateTempOmpVar(Location loc, Type ty, RewriterBase &rewriter) { + MLIRContext& ctx = *ty.getContext(); + Value alloc; + Type allocType; + auto llvmPtrTy = LLVM::LLVMPointerType::get(&ctx); + if (isPtr(ty)) { + Type intTy = rewriter.getI32Type(); + auto one = rewriter.create(loc, intTy, 1); + allocType = llvmPtrTy; + alloc = rewriter.create(loc, llvmPtrTy, allocType, one); + allocType = intTy; + } + else { + allocType = ty; + alloc = rewriter.create(loc, allocType); + } + auto getMapInfo = [&](uint64_t mappingFlags, const char *name) { + return rewriter.create( + loc, alloc.getType(), alloc, + TypeAttr::get(allocType), + rewriter.getIntegerAttr(rewriter.getIntegerType(64, /*isSigned=*/false), mappingFlags), + rewriter.getAttr( + omp::VariableCaptureKind::ByRef), + /*varPtrPtr=*/Value{}, + /*members=*/SmallVector{}, + /*member_index=*/mlir::ArrayAttr{}, + /*bounds=*/ValueRange(), + /*mapperId=*/mlir::FlatSymbolRefAttr(), + /*name=*/rewriter.getStringAttr(name), + rewriter.getBoolAttr(false)); + }; + uint64_t mapFrom = static_cast>(llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM); + uint64_t mapTo = static_cast>(llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO); + auto mapInfoFrom = getMapInfo(mapFrom, "__flang_workdistribute_from"); + auto mapInfoTo = getMapInfo(mapTo, "__flang_workdistribute_to"); + return TempOmpVar{mapInfoFrom, mapInfoTo}; +}; + +static bool usedOutsideSplit(Value v, Operation *split) { + if (!split) + return false; + auto targetOp = cast(split->getParentOp()); + auto *targetBlock = &targetOp.getRegion().front(); + for (auto *user : v.getUsers()) { + while (user->getBlock() != targetBlock) { + user = user->getParentOp(); + } + if (!user->isBeforeInBlock(split)) + return true; + } + return false; +}; + +static bool isOpToBeCached(Operation *op) { + if (auto loadOp = dyn_cast(op)) { + Value memref = loadOp.getMemref(); + if (auto blockArg = dyn_cast(memref)) { + // 'op' is an operation within the targetOp that 'splitBefore' is also in. + Operation *parentOpOfLoadBlock = op->getBlock()->getParentOp(); + // Ensure the blockArg belongs to the entry block of this parent omp.TargetOp. + // This implies the load is from a variable directly mapped into the target region. + if (isa(parentOpOfLoadBlock) && + !parentOpOfLoadBlock->getRegions().empty()) { + Block *targetOpEntryBlock = &parentOpOfLoadBlock->getRegions().front().front(); + if (blockArg.getOwner() == targetOpEntryBlock) { + // This load is from a direct argument of the target op. + // It's safe to recompute. 
+ return false; + } + } + } + } + return true; +} + +static bool isRecomputableAfterFission(Operation *op, Operation *splitBefore) { + if (isa(op)) + return true; + + if (auto loadOp = dyn_cast(op)) { + Value memref = loadOp.getMemref(); + if (auto blockArg = dyn_cast(memref)) { + // 'op' is an operation within the targetOp that 'splitBefore' is also in. + Operation *parentOpOfLoadBlock = op->getBlock()->getParentOp(); + // Ensure the blockArg belongs to the entry block of this parent omp.TargetOp. + // This implies the load is from a variable directly mapped into the target region. + if (isa(parentOpOfLoadBlock) && + !parentOpOfLoadBlock->getRegions().empty()) { + Block *targetOpEntryBlock = &parentOpOfLoadBlock->getRegions().front().front(); + if (blockArg.getOwner() == targetOpEntryBlock) { + // This load is from a direct argument of the target op. + // It's safe to recompute. + return true; + } + } + } + } + + llvm::SmallVector effects; + MemoryEffectOpInterface interface = dyn_cast(op); + if (!interface) { + return false; + } + interface.getEffects(effects); + if (effects.empty()) + return true; + return false; +} + +struct SplitResult { + omp::TargetOp preTargetOp; + omp::TargetOp isolatedTargetOp; + omp::TargetOp postTargetOp; +}; + +static void collectNonRecomputableDeps(Value& v, + omp::TargetOp targetOp, + SetVector& nonRecomputable, + SetVector& toCache, + SetVector& toRecompute) { + Operation *op = v.getDefiningOp(); + if (!op) { + assert(cast(v).getOwner()->getParentOp() == targetOp); + return; + } + if (nonRecomputable.contains(op)) { + toCache.insert(op); + return; + } + toRecompute.insert(op); + for (auto opr : op->getOperands()) + collectNonRecomputableDeps(opr, targetOp, nonRecomputable, toCache, toRecompute); +} + + +static void reloadCacheAndRecompute(Location loc, RewriterBase &rewriter, + MLIRContext& ctx, + IRMapping &mapping, Operation *splitBefore, + Block *targetBlock, Block *newTargetBlock, + SmallVector& allocs, + SetVector& toRecompute) { + for (unsigned i = 0; i < targetBlock->getNumArguments(); i++) { + auto originalArg = targetBlock->getArgument(i); + auto newArg = newTargetBlock->addArgument(originalArg.getType(), + originalArg.getLoc()); + mapping.map(originalArg, newArg); + } + auto llvmPtrTy = LLVM::LLVMPointerType::get(&ctx); + for (auto original : allocs) { + Value newArg = newTargetBlock->addArgument( + getPtrTypeForOmp(original.getType()), original.getLoc()); + Value restored; + if (isPtr(original.getType())) { + restored = rewriter.create(loc, llvmPtrTy, newArg); + if (!isa(original.getType())) + restored = rewriter.create(loc, original.getType(), ValueRange(restored)) + .getResult(0); + } + else { + restored = rewriter.create(loc, newArg); + } + mapping.map(original, restored); + } + for (auto it = targetBlock->begin(); it != splitBefore->getIterator(); it++) { + if (toRecompute.contains(&*it)) + rewriter.clone(*it, mapping); + } +} + +static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter, + RewriterBase &rewriter) { + auto targetOp = cast(splitBeforeOp->getParentOp()); + MLIRContext& ctx = *targetOp.getContext(); + assert(targetOp); + auto loc = targetOp.getLoc(); + auto *targetBlock = &targetOp.getRegion().front(); + rewriter.setInsertionPoint(targetOp); + + auto preMapOperands = SmallVector(targetOp.getMapVars()); + auto postMapOperands = SmallVector(targetOp.getMapVars()); + + SmallVector requiredVals; + SetVector toCache; + SetVector toRecompute; + SetVector nonRecomputable; + SmallVector allocs; + + for (auto it = 
targetBlock->begin(); it != splitBeforeOp->getIterator(); it++) { + for (auto res : it->getResults()) { + if (usedOutsideSplit(res, splitBeforeOp)) + requiredVals.push_back(res); + } + if (!isRecomputableAfterFission(&*it, splitBeforeOp)) + nonRecomputable.insert(&*it); + } + + for (auto requiredVal : requiredVals) + collectNonRecomputableDeps(requiredVal, targetOp, nonRecomputable, toCache, toRecompute); + + for (Operation *op : toCache) { + for (auto res : op->getResults()) { + auto alloc = allocateTempOmpVar(targetOp.getLoc(), res.getType(), rewriter); + allocs.push_back(res); + preMapOperands.push_back(alloc.from); + postMapOperands.push_back(alloc.to); + } + } + + rewriter.setInsertionPoint(targetOp); + + auto preTargetOp = rewriter.create( + targetOp.getLoc(), targetOp.getAllocateVars(), targetOp.getAllocatorVars(), + targetOp.getBareAttr(), targetOp.getDependKindsAttr(), + targetOp.getDependVars(), targetOp.getDevice(), + targetOp.getHasDeviceAddrVars(), targetOp.getHostEvalVars(), + targetOp.getIfExpr(), targetOp.getInReductionVars(), + targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(), + targetOp.getIsDevicePtrVars(), preMapOperands, + targetOp.getNowaitAttr(), targetOp.getPrivateVars(), + targetOp.getPrivateSymsAttr(), targetOp.getThreadLimit(), + targetOp.getPrivateMapsAttr()); + auto *preTargetBlock = rewriter.createBlock( + &preTargetOp.getRegion(), preTargetOp.getRegion().begin(), {}, {}); + IRMapping preMapping; + for (unsigned i = 0; i < targetBlock->getNumArguments(); i++) { + auto originalArg = targetBlock->getArgument(i); + auto newArg = preTargetBlock->addArgument(originalArg.getType(), + originalArg.getLoc()); + preMapping.map(originalArg, newArg); + } + for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator(); it++) + rewriter.clone(*it, preMapping); + + auto llvmPtrTy = LLVM::LLVMPointerType::get(targetOp.getContext()); + + + for (auto original : allocs) { + Value toStore = preMapping.lookup(original); + auto newArg = preTargetBlock->addArgument( + getPtrTypeForOmp(original.getType()), original.getLoc()); + if (isPtr(original.getType())) { + if (!isa(toStore.getType())) + toStore = rewriter.create(loc, llvmPtrTy, + ValueRange(toStore)) + .getResult(0); + rewriter.create(loc, toStore, newArg); + } else { + rewriter.create(loc, toStore, newArg); + } + } + rewriter.create(loc); + + rewriter.setInsertionPoint(targetOp); + + auto isolatedTargetOp = rewriter.create( + targetOp.getLoc(), targetOp.getAllocateVars(), targetOp.getAllocatorVars(), + targetOp.getBareAttr(), targetOp.getDependKindsAttr(), + targetOp.getDependVars(), targetOp.getDevice(), + targetOp.getHasDeviceAddrVars(), targetOp.getHostEvalVars(), + targetOp.getIfExpr(), targetOp.getInReductionVars(), + targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(), + targetOp.getIsDevicePtrVars(), postMapOperands, + targetOp.getNowaitAttr(), targetOp.getPrivateVars(), + targetOp.getPrivateSymsAttr(), targetOp.getThreadLimit(), + targetOp.getPrivateMapsAttr()); + + auto *isolatedTargetBlock = + rewriter.createBlock(&isolatedTargetOp.getRegion(), + isolatedTargetOp.getRegion().begin(), {}, {}); + + IRMapping isolatedMapping; + reloadCacheAndRecompute(loc, rewriter, ctx, isolatedMapping, splitBeforeOp, + targetBlock, isolatedTargetBlock, + allocs, toRecompute); + rewriter.clone(*splitBeforeOp, isolatedMapping); + rewriter.create(loc); + + omp::TargetOp postTargetOp = nullptr; + + if (splitAfter) { + rewriter.setInsertionPoint(targetOp); + postTargetOp = rewriter.create( + 
targetOp.getLoc(), targetOp.getAllocateVars(), targetOp.getAllocatorVars(), + targetOp.getBareAttr(), targetOp.getDependKindsAttr(), + targetOp.getDependVars(), targetOp.getDevice(), + targetOp.getHasDeviceAddrVars(), targetOp.getHostEvalVars(), + targetOp.getIfExpr(), targetOp.getInReductionVars(), + targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(), + targetOp.getIsDevicePtrVars(), postMapOperands, + targetOp.getNowaitAttr(), targetOp.getPrivateVars(), + targetOp.getPrivateSymsAttr(), targetOp.getThreadLimit(), + targetOp.getPrivateMapsAttr()); + auto *postTargetBlock = rewriter.createBlock( + &postTargetOp.getRegion(), postTargetOp.getRegion().begin(), {}, {}); + IRMapping postMapping; + reloadCacheAndRecompute(loc, rewriter, ctx, postMapping, splitBeforeOp, + targetBlock, postTargetBlock, + allocs, toRecompute); + + assert(splitBeforeOp->getNumResults() == 0 || + llvm::all_of(splitBeforeOp->getResults(), + [](Value result) { return result.use_empty(); })); + + for (auto it = std::next(splitBeforeOp->getIterator()); + it != targetBlock->end(); it++) + rewriter.clone(*it, postMapping); + } + + rewriter.eraseOp(targetOp); + return SplitResult{preTargetOp, isolatedTargetOp, postTargetOp}; +} + +static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) { + OpBuilder::InsertionGuard guard(rewriter); + Block *targetBlock = &targetOp.getRegion().front(); + assert(targetBlock == &targetOp.getRegion().back()); + IRMapping mapping; + for (auto map : + zip_equal(targetOp.getMapVars(), targetBlock->getArguments())) { + Value mapInfo = std::get<0>(map); + BlockArgument arg = std::get<1>(map); + Operation *op = mapInfo.getDefiningOp(); + assert(op); + auto mapInfoOp = cast(op); + mapping.map(arg, mapInfoOp.getVarPtr()); + } + rewriter.setInsertionPoint(targetOp); + SmallVector opsToMove; + for (auto it = targetBlock->begin(), end = std::prev(targetBlock->end()); + it != end; ++it) { + auto *op = &*it; + auto allocOp = dyn_cast(op); + auto freeOp = dyn_cast(op); + fir::CallOp runtimeCall = nullptr; + if (isRuntimeCall(op)) + runtimeCall = cast(op); + + if (allocOp || freeOp || runtimeCall) + continue; + opsToMove.push_back(op); + } + // Move ops before targetOp and erase from region + for (Operation *op : opsToMove) + rewriter.clone(*op, mapping); + + rewriter.eraseOp(targetOp); +} + +void fissionTarget(omp::TargetOp targetOp, RewriterBase &rewriter) { + auto tuple = getNestedOpToIsolate(targetOp); + if (!tuple) { + LLVM_DEBUG(llvm::dbgs() << " No op to isolate\n"); + //moveToHost(targetOp, rewriter); + return; + } + + Operation *toIsolate = std::get<0>(*tuple); + bool splitBefore = !std::get<1>(*tuple); + bool splitAfter = !std::get<2>(*tuple); + + if (splitBefore && splitAfter) { + auto res = isolateOp(toIsolate, splitAfter, rewriter); + //moveToHost(res.preTargetOp, rewriter); + fissionTarget(res.postTargetOp, rewriter); + return; + } + if (splitBefore) { + auto res = isolateOp(toIsolate, splitAfter, rewriter); + //moveToHost(res.preTargetOp, rewriter); + return; + } + if (splitAfter) { + assert(false && "TODO"); + auto res = isolateOp(toIsolate->getNextNode(), splitAfter, rewriter); + fissionTarget(res.postTargetOp, rewriter); + return; + } +} class LowerWorkdistributePass : public flangomp::impl::LowerWorkdistributeBase { @@ -457,6 +853,7 @@ class LowerWorkdistributePass IRRewriter rewriter(&context); for (auto targetOp : targetOps) { auto res = splitTargetData(targetOp, rewriter); + if (res) fissionTarget(res->targetOp, rewriter); } } diff --git 
a/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir new file mode 100644 index 0000000000000..ed6c641f2e934 --- /dev/null +++ b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir @@ -0,0 +1,104 @@ +// RUN: fir-opt --lower-workdistribute %s | FileCheck %s + +// CHECK-LABEL: func.func @x +// CHECK: %[[VAL_0:.*]] = fir.alloca index {bindc_name = "lb"} +// CHECK: fir.store %[[ARG0:.*]] to %[[VAL_0]] : !fir.ref +// CHECK: %[[VAL_1:.*]] = fir.alloca index {bindc_name = "ub"} +// CHECK: fir.store %[[ARG1:.*]] to %[[VAL_1]] : !fir.ref +// CHECK: %[[VAL_2:.*]] = fir.alloca index {bindc_name = "step"} +// CHECK: fir.store %[[ARG2:.*]] to %[[VAL_2]] : !fir.ref +// CHECK: %[[VAL_3:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref, index) map_clauses(to) capture(ByRef) -> !fir.ref {name = "lb"} +// CHECK: %[[VAL_4:.*]] = omp.map.info var_ptr(%[[VAL_1]] : !fir.ref, index) map_clauses(to) capture(ByRef) -> !fir.ref {name = "ub"} +// CHECK: %[[VAL_5:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref, index) map_clauses(to) capture(ByRef) -> !fir.ref {name = "step"} +// CHECK: %[[VAL_6:.*]] = omp.map.info var_ptr(%[[ARG3:.*]] : !fir.ref, index) map_clauses(tofrom) capture(ByRef) -> !fir.ref {name = "addr"} +// CHECK: omp.target_data map_entries(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]], %[[VAL_6]] : !fir.ref, !fir.ref, !fir.ref, !fir.ref) { +// CHECK: %[[VAL_7:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref, index) map_clauses(to) capture(ByRef) -> !fir.ref {name = "lb"} +// CHECK: %[[VAL_8:.*]] = omp.map.info var_ptr(%[[VAL_1]] : !fir.ref, index) map_clauses(to) capture(ByRef) -> !fir.ref {name = "ub"} +// CHECK: %[[VAL_9:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref, index) map_clauses(to) capture(ByRef) -> !fir.ref {name = "step"} +// CHECK: %[[VAL_10:.*]] = omp.map.info var_ptr(%[[ARG3:.*]] : !fir.ref, index) map_clauses(tofrom) capture(ByRef) -> !fir.ref {name = "addr"} +// CHECK: %[[VAL_11:.*]] = fir.alloca !fir.heap +// CHECK: %[[VAL_12:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref>, !fir.heap) map_clauses(from) capture(ByRef) -> !fir.ref> {name = "__flang_workdistribute_from"} +// CHECK: %[[VAL_13:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref>, !fir.heap) map_clauses(to) capture(ByRef) -> !fir.ref> {name = "__flang_workdistribute_to"} +// CHECK: omp.target map_entries(%[[VAL_7]] -> %[[VAL_14:.*]], %[[VAL_8]] -> %[[VAL_15:.*]], %[[VAL_9]] -> %[[VAL_16:.*]], %[[VAL_10]] -> %[[VAL_17:.*]], %[[VAL_12]] -> %[[VAL_18:.*]] : !fir.ref, !fir.ref, !fir.ref, !fir.ref, !fir.ref>) { +// CHECK: %[[VAL_19:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_20:.*]] = fir.load %[[VAL_14]] : !fir.ref +// CHECK: %[[VAL_21:.*]] = fir.load %[[VAL_15]] : !fir.ref +// CHECK: %[[VAL_22:.*]] = fir.load %[[VAL_16]] : !fir.ref +// CHECK: %[[VAL_23:.*]] = arith.addi %[[VAL_21]], %[[VAL_21]] : index +// CHECK: %[[VAL_24:.*]] = fir.allocmem index, %[[VAL_19]] {uniq_name = "dev_buf"} +// CHECK: fir.store %[[VAL_24]] to %[[VAL_18]] : !fir.llvm_ptr> +// CHECK: omp.terminator +// CHECK: } +// CHECK: omp.target map_entries(%[[VAL_7]] -> %[[VAL_25:.*]], %[[VAL_8]] -> %[[VAL_26:.*]], %[[VAL_9]] -> %[[VAL_27:.*]], %[[VAL_10]] -> %[[VAL_28:.*]], %[[VAL_13]] -> %[[VAL_29:.*]] : !fir.ref, !fir.ref, !fir.ref, !fir.ref, !fir.ref>) { +// CHECK: %[[VAL_30:.*]] = fir.load %[[VAL_29]] : !fir.llvm_ptr> +// CHECK: %[[VAL_31:.*]] = fir.load %[[VAL_25]] : !fir.ref +// CHECK: %[[VAL_32:.*]] = fir.load %[[VAL_26]] : !fir.ref +// CHECK: 
%[[VAL_33:.*]] = fir.load %[[VAL_27]] : !fir.ref +// CHECK: %[[VAL_34:.*]] = arith.addi %[[VAL_32]], %[[VAL_32]] : index +// CHECK: omp.teams { +// CHECK: omp.parallel { +// CHECK: omp.distribute { +// CHECK: omp.wsloop { +// CHECK: omp.loop_nest (%[[VAL_35:.*]]) : index = (%[[VAL_31]]) to (%[[VAL_32]]) inclusive step (%[[VAL_33]]) { +// CHECK: fir.store %[[VAL_34]] to %[[VAL_30]] : !fir.heap +// CHECK: omp.yield +// CHECK: } +// CHECK: } {omp.composite} +// CHECK: } {omp.composite} +// CHECK: omp.terminator +// CHECK: } {omp.composite} +// CHECK: omp.terminator +// CHECK: } +// CHECK: omp.terminator +// CHECK: } +// CHECK: omp.target map_entries(%[[VAL_7]] -> %[[VAL_36:.*]], %[[VAL_8]] -> %[[VAL_37:.*]], %[[VAL_9]] -> %[[VAL_38:.*]], %[[VAL_10]] -> %[[VAL_39:.*]], %[[VAL_13]] -> %[[VAL_40:.*]] : !fir.ref, !fir.ref, !fir.ref, !fir.ref, !fir.ref>) { +// CHECK: %[[VAL_41:.*]] = fir.load %[[VAL_40]] : !fir.llvm_ptr> +// CHECK: %[[VAL_42:.*]] = fir.load %[[VAL_36]] : !fir.ref +// CHECK: %[[VAL_43:.*]] = fir.load %[[VAL_37]] : !fir.ref +// CHECK: %[[VAL_44:.*]] = fir.load %[[VAL_38]] : !fir.ref +// CHECK: %[[VAL_45:.*]] = arith.addi %[[VAL_43]], %[[VAL_43]] : index +// CHECK: fir.store %[[VAL_42]] to %[[VAL_41]] : !fir.heap +// CHECK: fir.freemem %[[VAL_41]] : !fir.heap +// CHECK: omp.terminator +// CHECK: } +// CHECK: omp.terminator +// CHECK: } +// CHECK: return +// CHECK: } + +func.func @x(%lb : index, %ub : index, %step : index, %addr : !fir.ref) { + %lb_ref = fir.alloca index {bindc_name = "lb"} + fir.store %lb to %lb_ref : !fir.ref + %ub_ref = fir.alloca index {bindc_name = "ub"} + fir.store %ub to %ub_ref : !fir.ref + %step_ref = fir.alloca index {bindc_name = "step"} + fir.store %step to %step_ref : !fir.ref + + %lb_map = omp.map.info var_ptr(%lb_ref : !fir.ref, index) map_clauses(to) capture(ByRef) -> !fir.ref {name = "lb"} + %ub_map = omp.map.info var_ptr(%ub_ref : !fir.ref, index) map_clauses(to) capture(ByRef) -> !fir.ref {name = "ub"} + %step_map = omp.map.info var_ptr(%step_ref : !fir.ref, index) map_clauses(to) capture(ByRef) -> !fir.ref {name = "step"} + %addr_map = omp.map.info var_ptr(%addr : !fir.ref, index) map_clauses(tofrom) capture(ByRef) -> !fir.ref {name = "addr"} + + omp.target map_entries(%lb_map -> %arg0, %ub_map -> %arg1, %step_map -> %arg2, %addr_map -> %arg3 : !fir.ref, !fir.ref, !fir.ref, !fir.ref) { + %lb_val = fir.load %arg0 : !fir.ref + %ub_val = fir.load %arg1 : !fir.ref + %step_val = fir.load %arg2 : !fir.ref + %one = arith.constant 1 : index + + %20 = arith.addi %ub_val, %ub_val : index + omp.teams { + omp.workdistribute { + %dev_mem = fir.allocmem index, %one {uniq_name = "dev_buf"} + fir.do_loop %iv = %lb_val to %ub_val step %step_val unordered { + fir.store %20 to %dev_mem : !fir.heap + } + fir.store %lb_val to %dev_mem : !fir.heap + fir.freemem %dev_mem : !fir.heap + omp.terminator + } + omp.terminator + } + omp.terminator + } + return +} From 611f8b9344502dcf1311b108ba18804fff89b53b Mon Sep 17 00:00:00 2001 From: skc7 Date: Fri, 6 Jun 2025 14:20:08 +0530 Subject: [PATCH 16/29] Use fir.convert instead of unrealised cast --- flang/lib/Lower/OpenMP/OpenMP.cpp | 10 ++++++++++ flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp | 7 ++----- .../Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp | 3 +++ 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index a956498f6e521..1ae6366b3631a 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -565,6 
+565,16 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
       });
       break;
 
+  case OMPD_teams_workdistribute:
+    cp.processThreadLimit(stmtCtx, hostInfo.ops);
+    [[fallthrough]];
+  case OMPD_target_teams_workdistribute:
+    cp.processNumTeams(stmtCtx, hostInfo.ops);
+    processSingleNestedIf([](Directive nestedDir) {
+      return topDistributeSet.test(nestedDir) || topLoopSet.test(nestedDir);
+    });
+    break;
+
   case OMPD_teams_distribute:
   case OMPD_teams_distribute_simd:
     cp.processThreadLimit(stmtCtx, hostInfo.ops);
diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index 8f2de92cfd186..6d6de47f7741e 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -597,8 +597,7 @@ static void reloadCacheAndRecompute(Location loc, RewriterBase &rewriter,
     if (isPtr(original.getType())) {
       restored = rewriter.create(loc, llvmPtrTy, newArg);
       if (!isa(original.getType()))
-        restored = rewriter.create(loc, original.getType(), ValueRange(restored))
-                       .getResult(0);
+        restored = rewriter.create(loc, original.getType(), restored);
     } else {
       restored = rewriter.create(loc, newArg);
@@ -684,9 +683,7 @@ static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
         getPtrTypeForOmp(original.getType()), original.getLoc());
     if (isPtr(original.getType())) {
       if (!isa(toStore.getType()))
-        toStore = rewriter.create(loc, llvmPtrTy,
-                                  ValueRange(toStore))
-                      .getResult(0);
+        toStore = rewriter.create(loc, llvmPtrTy, toStore);
       rewriter.create(loc, toStore, newArg);
     } else {
       rewriter.create(loc, toStore, newArg);
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index eece8573f00ec..3fed83112dc97 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -5246,6 +5246,9 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
   omp::LoopNestOp loopOp = castOrGetParentOfType(capturedOp);
   unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
 
+  if (targetOp.getHostEvalVars().empty())
+    numLoops = 0;
+
   Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
   llvm::SmallVector lowerBounds(numLoops), upperBounds(numLoops),
       steps(numLoops);

From 99b8d676648f20bd84634c6e6b1e06772c4bf9ad Mon Sep 17 00:00:00 2001
From: skc7
Date: Tue, 10 Jun 2025 18:46:57 +0530
Subject: [PATCH 17/29] [Flang] Add fir omp target alloc and free ops

This commit is cherry-picked from ivanradanov commit
be860ac8baf24b8405e6f396c75d7f0d26375de5
---
 .../include/flang/Optimizer/Dialect/FIROps.td | 61 +++++++++++++++++++
 1 file changed, 61 insertions(+)

diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 8ac847dd7dd0a..466699cc4d476 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -517,6 +517,67 @@ def fir_ZeroOp : fir_OneResultOp<"zero_bits", [NoMemoryEffect]> {
   let assemblyFormat = "type($intype) attr-dict";
 }
 
+def fir_OmpTargetAllocMemOp : fir_Op<"omp_target_allocmem",
+    [MemoryEffects<[MemAlloc]>, AttrSizedOperandSegments]> {
+  let summary = "allocate storage on an OpenMP device for an object of a given type";
+
+  let description = [{
+    Creates a heap memory reference suitable for storing a value of the
+    given type, T. The heap reference returned has type `!fir.heap`.
+    The memory object is in an undefined state. `fir.omp_target_allocmem`
+    operations must be paired with `fir.omp_target_freemem` operations to
+    avoid memory leaks.
+
+    ```
+    %0 = fir.omp_target_allocmem !fir.array<10 x f32>
+    ```
+  }];
+
+  let arguments = (ins
+    Arg:$device,
+    TypeAttr:$in_type,
+    OptionalAttr:$uniq_name,
+    OptionalAttr:$bindc_name,
+    Variadic:$typeparams,
+    Variadic:$shape
+  );
+  let results = (outs fir_HeapType);
+
+  let extraClassDeclaration = [{
+    mlir::Type getAllocatedType();
+    bool hasLenParams() { return !getTypeparams().empty(); }
+    bool hasShapeOperands() { return !getShape().empty(); }
+    unsigned numLenParams() { return getTypeparams().size(); }
+    operand_range getLenParams() { return getTypeparams(); }
+    unsigned numShapeOperands() { return getShape().size(); }
+    operand_range getShapeOperands() { return getShape(); }
+    static mlir::Type getRefTy(mlir::Type ty);
+  }];
+}
+
+def fir_OmpTargetFreeMemOp : fir_Op<"omp_target_freemem",
+    [MemoryEffects<[MemFree]>]> {
+  let summary = "free a heap object";
+
+  let description = [{
+    Deallocates a heap memory reference that was allocated by a
+    `fir.omp_target_allocmem`. The memory object that is deallocated is
+    placed in an undefined state after `fir.omp_target_freemem`.
+    Optimizations may treat the loading of an object in the undefined state
+    as undefined behavior. This includes aliasing references, such as the
+    result of an `fir.embox`.
+
+    ```
+    %21 = fir.omp_target_allocmem !fir.type
+    ...
+    fir.omp_target_freemem %21 : !fir.heap>
+    ```
+  }];
+
+  let arguments = (ins
+    Arg:$device,
+    Arg:$heapref
+  );
+}
+
 //===----------------------------------------------------------------------===//
 // Terminator operations
 //===----------------------------------------------------------------------===//

From cd7eb59ef83a4af943afeb3215f476a07762586a Mon Sep 17 00:00:00 2001
From: skc7
Date: Tue, 10 Jun 2025 22:11:36 +0530
Subject: [PATCH 18/29] [Flang] Add fir omp target alloc and free ops

This commit is cherry-picked from ivanradanov commit
be860ac8baf24b8405e6f396c75d7f0d26375de5
---
 flang/lib/Optimizer/CodeGen/CodeGen.cpp       | 102 +++++++++++++++++-
 .../Optimizer/OpenMP/LowerWorkdistribute.cpp  |  81 ++++++++++++--
 2 files changed, 172 insertions(+), 11 deletions(-)

diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index a3de3ae9d116a..f94ccbb91de57 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -1168,6 +1168,105 @@ struct FreeMemOpConversion : public fir::FIROpConversion {
 };
 } // namespace
 
+static mlir::LLVM::LLVMFuncOp getOmpTargetAlloc(mlir::Operation *op) {
+  auto module = op->getParentOfType();
+  if (mlir::LLVM::LLVMFuncOp mallocFunc =
+          module.lookupSymbol("omp_target_alloc"))
+    return mallocFunc;
+  mlir::OpBuilder moduleBuilder(module.getBodyRegion());
+  auto i64Ty = mlir::IntegerType::get(module->getContext(), 64);
+  auto i32Ty = mlir::IntegerType::get(module->getContext(), 32);
+  return moduleBuilder.create(
+      moduleBuilder.getUnknownLoc(), "omp_target_alloc",
+      mlir::LLVM::LLVMFunctionType::get(
+          mlir::LLVM::LLVMPointerType::get(module->getContext()),
+          {i64Ty, i32Ty},
+          /*isVarArg=*/false));
+}
+
+namespace {
+struct OmpTargetAllocMemOpConversion
+    : public fir::FIROpConversion {
+  using FIROpConversion::FIROpConversion;
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::OmpTargetAllocMemOp heap, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    mlir::Type heapTy = heap.getType();
+    mlir::LLVM::LLVMFuncOp mallocFunc =
getOmpTargetAlloc(heap); + mlir::Location loc = heap.getLoc(); + auto ity = lowerTy().indexType(); + mlir::Type dataTy = fir::unwrapRefType(heapTy); + mlir::Type llvmObjectTy = convertObjectType(dataTy); + if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy))) + TODO(loc, "fir.omp_target_allocmem codegen of derived type with length " + "parameters"); + mlir::Value size = genTypeSizeInBytes(loc, ity, rewriter, llvmObjectTy); + if (auto scaleSize = genAllocationScaleSize(heap, ity, rewriter)) + size = rewriter.create(loc, ity, size, scaleSize); + for (mlir::Value opnd : adaptor.getOperands()) + size = rewriter.create( + loc, ity, size, integerCast(loc, rewriter, ity, opnd)); + auto mallocTyWidth = lowerTy().getIndexTypeBitwidth(); + auto mallocTy = + mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth); + if (mallocTyWidth != ity.getIntOrFloatBitWidth()) + size = integerCast(loc, rewriter, mallocTy, size); + heap->setAttr("callee", mlir::SymbolRefAttr::get(mallocFunc)); + rewriter.replaceOpWithNewOp( + heap, ::getLlvmPtrType(heap.getContext()), + mlir::SmallVector({size, heap.getDevice()}), + addLLVMOpBundleAttrs(rewriter, heap->getAttrs(), 2)); + return mlir::success(); + } + + /// Compute the allocation size in bytes of the element type of + /// \p llTy pointer type. The result is returned as a value of \p idxTy + /// integer type. + mlir::Value genTypeSizeInBytes(mlir::Location loc, mlir::Type idxTy, + mlir::ConversionPatternRewriter &rewriter, + mlir::Type llTy) const { + return computeElementDistance(loc, llTy, idxTy, rewriter, getDataLayout()); + } +}; +} // namespace + +static mlir::LLVM::LLVMFuncOp getOmpTargetFree(mlir::Operation *op) { + auto module = op->getParentOfType(); + if (mlir::LLVM::LLVMFuncOp freeFunc = + module.lookupSymbol("omp_target_free")) + return freeFunc; + mlir::OpBuilder moduleBuilder(module.getBodyRegion()); + auto i32Ty = mlir::IntegerType::get(module->getContext(), 32); + return moduleBuilder.create( + moduleBuilder.getUnknownLoc(), "omp_target_free", + mlir::LLVM::LLVMFunctionType::get( + mlir::LLVM::LLVMVoidType::get(module->getContext()), + {getLlvmPtrType(module->getContext()), i32Ty}, + /*isVarArg=*/false)); +} + +namespace { +struct OmpTargetFreeMemOpConversion + : public fir::FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::OmpTargetFreeMemOp freemem, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::LLVM::LLVMFuncOp freeFunc = getOmpTargetFree(freemem); + mlir::Location loc = freemem.getLoc(); + freemem->setAttr("callee", mlir::SymbolRefAttr::get(freeFunc)); + rewriter.create( + loc, mlir::TypeRange{}, + mlir::ValueRange{adaptor.getHeapref(), freemem.getDevice()}, + addLLVMOpBundleAttrs(rewriter, freemem->getAttrs(), 2)); + rewriter.eraseOp(freemem); + return mlir::success(); + } +}; +} // namespace + // Convert subcomponent array indices from column-major to row-major ordering. 
static llvm::SmallVector convertSubcomponentIndices(mlir::Location loc, mlir::Type eleTy, @@ -4274,7 +4373,8 @@ void fir::populateFIRToLLVMConversionPatterns( GlobalLenOpConversion, GlobalOpConversion, InsertOnRangeOpConversion, IsPresentOpConversion, LenParamIndexOpConversion, LoadOpConversion, LocalitySpecifierOpConversion, MulcOpConversion, NegcOpConversion, - NoReassocOpConversion, SelectCaseOpConversion, SelectOpConversion, + NoReassocOpConversion, OmpTargetAllocMemOpConversion, + OmpTargetFreeMemOpConversion,SelectCaseOpConversion, SelectOpConversion, SelectRankOpConversion, SelectTypeOpConversion, ShapeOpConversion, ShapeShiftOpConversion, ShiftOpConversion, SliceOpConversion, StoreOpConversion, StringLitOpConversion, SubcOpConversion, diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp index 6d6de47f7741e..f0b4f24c2db5b 100644 --- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp +++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp @@ -751,6 +751,15 @@ static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter, return SplitResult{preTargetOp, isolatedTargetOp, postTargetOp}; } +static mlir::LLVM::ConstantOp +genI32Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) { + mlir::Type i32Ty = rewriter.getI32Type(); + mlir::IntegerAttr attr = rewriter.getI32IntegerAttr(value); + return rewriter.create(loc, i32Ty, attr); +} + +static Type getOmpDeviceType(MLIRContext *c) { return IntegerType::get(c, 32); } + static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) { OpBuilder::InsertionGuard guard(rewriter); Block *targetBlock = &targetOp.getRegion().front(); @@ -776,14 +785,66 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) { if (isRuntimeCall(op)) runtimeCall = cast(op); - if (allocOp || freeOp || runtimeCall) - continue; - opsToMove.push_back(op); + if (allocOp || freeOp || runtimeCall) { + Value device = targetOp.getDevice(); + if (!device) { + device = genI32Constant(it->getLoc(), rewriter, 0); + } + if (allocOp) { + auto tmpAllocOp = rewriter.create( + allocOp.getLoc(), allocOp.getType(), device, + allocOp.getInTypeAttr(), allocOp.getUniqNameAttr(), + allocOp.getBindcNameAttr(), allocOp.getTypeparams(), + allocOp.getShape()); + auto newAllocOp = cast( + rewriter.clone(*tmpAllocOp.getOperation(), mapping)); + mapping.map(allocOp.getResult(), newAllocOp.getResult()); + rewriter.eraseOp(tmpAllocOp); + } else if (freeOp) { + auto tmpFreeOp = rewriter.create( + freeOp.getLoc(), device, freeOp.getHeapref()); + rewriter.clone(*tmpFreeOp.getOperation(), mapping); + rewriter.eraseOp(tmpFreeOp); + } else if (runtimeCall) { + auto module = runtimeCall->getParentOfType(); + auto callee = cast( + module.lookupSymbol(runtimeCall.getCalleeAttr())); + std::string newCalleeName = (callee.getName()).str(); + mlir::OpBuilder moduleBuilder(module.getBodyRegion()); + func::FuncOp newCallee = + cast_or_null(module.lookupSymbol(newCalleeName)); + if (!newCallee) { + SmallVector argTypes(callee.getFunctionType().getInputs()); + argTypes.push_back(getOmpDeviceType(rewriter.getContext())); + newCallee = moduleBuilder.create( + callee->getLoc(), newCalleeName, + FunctionType::get(rewriter.getContext(), argTypes, + callee.getFunctionType().getResults())); + if (callee.getArgAttrs()) + newCallee.setArgAttrsAttr(*callee.getArgAttrs()); + if (callee.getResAttrs()) + newCallee.setResAttrsAttr(*callee.getResAttrs()); + newCallee.setSymVisibility(callee.getSymVisibility()); + 
newCallee->setDiscardableAttrs( + callee->getDiscardableAttrDictionary()); + } + SmallVector operands = runtimeCall.getOperands(); + operands.push_back(device); + auto tmpCall = rewriter.create( + runtimeCall.getLoc(), runtimeCall.getResultTypes(), + SymbolRefAttr::get(newCallee), operands, nullptr, nullptr, nullptr, + runtimeCall.getFastmathAttr()); + Operation *newCall = rewriter.clone(*tmpCall, mapping); + mapping.map(&*it, newCall); + rewriter.eraseOp(tmpCall); + } + } else { + Operation *clonedOp = rewriter.clone(*op, mapping); + for (unsigned i = 0; i < op->getNumResults(); ++i) { + mapping.map(op->getResult(i), clonedOp->getResult(i)); + } + } } - // Move ops before targetOp and erase from region - for (Operation *op : opsToMove) - rewriter.clone(*op, mapping); - rewriter.eraseOp(targetOp); } @@ -791,7 +852,7 @@ void fissionTarget(omp::TargetOp targetOp, RewriterBase &rewriter) { auto tuple = getNestedOpToIsolate(targetOp); if (!tuple) { LLVM_DEBUG(llvm::dbgs() << " No op to isolate\n"); - //moveToHost(targetOp, rewriter); + moveToHost(targetOp, rewriter); return; } @@ -801,13 +862,13 @@ void fissionTarget(omp::TargetOp targetOp, RewriterBase &rewriter) { if (splitBefore && splitAfter) { auto res = isolateOp(toIsolate, splitAfter, rewriter); - //moveToHost(res.preTargetOp, rewriter); + moveToHost(res.preTargetOp, rewriter); fissionTarget(res.postTargetOp, rewriter); return; } if (splitBefore) { auto res = isolateOp(toIsolate, splitAfter, rewriter); - //moveToHost(res.preTargetOp, rewriter); + moveToHost(res.preTargetOp, rewriter); return; } if (splitAfter) { From 7feb0199033ab970f4e39eaca40e3c56ef1b146b Mon Sep 17 00:00:00 2001 From: skc7 Date: Fri, 13 Jun 2025 15:03:58 +0530 Subject: [PATCH 19/29] [Flang] Add Assign_omp fortran-rt function This commit is c-p from ivanradanov commit d7e44991415663f93cfc342b06e1c81c03161ed6 --- flang-rt/lib/runtime/CMakeLists.txt | 2 + flang-rt/lib/runtime/assign_omp.cpp | 72 +++++++++++++++++++ flang/include/flang/Runtime/assign.h | 2 + .../Optimizer/OpenMP/LowerWorkdistribute.cpp | 2 +- 4 files changed, 77 insertions(+), 1 deletion(-) create mode 100644 flang-rt/lib/runtime/assign_omp.cpp diff --git a/flang-rt/lib/runtime/CMakeLists.txt b/flang-rt/lib/runtime/CMakeLists.txt index 332c0872e065f..5200b2b710a5e 100644 --- a/flang-rt/lib/runtime/CMakeLists.txt +++ b/flang-rt/lib/runtime/CMakeLists.txt @@ -21,6 +21,7 @@ set(supported_sources allocatable.cpp array-constructor.cpp assign.cpp + assign_omp.cpp buffer.cpp character.cpp connection.cpp @@ -99,6 +100,7 @@ set(gpu_sources allocatable.cpp array-constructor.cpp assign.cpp + assign_omp.cpp buffer.cpp character.cpp connection.cpp diff --git a/flang-rt/lib/runtime/assign_omp.cpp b/flang-rt/lib/runtime/assign_omp.cpp new file mode 100644 index 0000000000000..e3680c59ed5a1 --- /dev/null +++ b/flang-rt/lib/runtime/assign_omp.cpp @@ -0,0 +1,72 @@ +//===-- lib/runtime/assign_omp.cpp ----------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "flang/Runtime/assign.h" +#include "flang-rt/runtime/assign-impl.h" +#include "flang-rt/runtime/derived.h" +#include "flang-rt/runtime/descriptor.h" +#include "flang-rt/runtime/stat.h" +#include "flang-rt/runtime/terminator.h" +#include "flang-rt/runtime/tools.h" +#include "flang-rt/runtime/type-info.h" + +#include + +namespace Fortran::runtime { + +RT_API_ATTRS static void Assign(Descriptor &to, const Descriptor &from, + Terminator &terminator, int flags, int32_t omp_device) { + std::size_t toElementBytes{to.ElementBytes()}; + std::size_t fromElementBytes{from.ElementBytes()}; + std::size_t toElements{to.Elements()}; + std::size_t fromElements{from.Elements()}; + + if (toElementBytes != fromElementBytes) + terminator.Crash("Assign: toElementBytes != fromElementBytes"); + if (toElements != fromElements) + terminator.Crash("Assign: toElements != fromElements"); + + void *host_to_ptr = to.raw().base_addr; + void *host_from_ptr = from.raw().base_addr; + size_t length = toElements * toElementBytes; + + printf("assign length: %zu\n", length); + + if (!omp_target_is_present(host_to_ptr, omp_device)) + terminator.Crash("Assign: !omp_target_is_present(host_to_ptr, omp_device)"); + if (!omp_target_is_present(host_from_ptr, omp_device)) + terminator.Crash( + "Assign: !omp_target_is_present(host_from_ptr, omp_device)"); + + printf("host_to_ptr: %p\n", host_to_ptr); +#pragma omp target data use_device_ptr(host_to_ptr, host_from_ptr) device(omp_device) + { + printf("device_to_ptr: %p\n", host_to_ptr); + // TODO do we need to handle overlapping memory? does this function do that? + omp_target_memcpy(host_to_ptr, host_from_ptr, length, /*dst_offset*/ 0, + /*src_offset*/ 0, /*dst*/ omp_device, /*src*/ omp_device); + } + + return; +} + +extern "C" { +RT_EXT_API_GROUP_BEGIN +void RTDEF(Assign_omp)(Descriptor &to, const Descriptor &from, + const char *sourceFile, int sourceLine, int32_t omp_device) { + Terminator terminator{sourceFile, sourceLine}; + // All top-level defined assignments can be recognized in semantics and + // will have been already been converted to calls, so don't check for + // defined assignment apart from components. + Assign(to, from, terminator, + MaybeReallocate | NeedFinalization | ComponentCanBeDefinedAssignment, + omp_device); +} +} // extern "C" + +} \ No newline at end of file diff --git a/flang/include/flang/Runtime/assign.h b/flang/include/flang/Runtime/assign.h index 7d198bdcc9e89..0be52413e4814 100644 --- a/flang/include/flang/Runtime/assign.h +++ b/flang/include/flang/Runtime/assign.h @@ -56,6 +56,8 @@ extern "C" { // API for lowering assignment void RTDECL(Assign)(Descriptor &to, const Descriptor &from, const char *sourceFile = nullptr, int sourceLine = 0); +void RTDECL(Assign_omp)(Descriptor &to, const Descriptor &from, + const char *sourceFile = nullptr, int sourceLine = 0, int32_t omp_device = 0); // This variant has no finalization, defined assignment, or allocatable // reallocation. 
void RTDECL(AssignTemporary)(Descriptor &to, const Descriptor &from, diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp index f0b4f24c2db5b..d82a61705eae3 100644 --- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp +++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp @@ -809,7 +809,7 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) { auto module = runtimeCall->getParentOfType(); auto callee = cast( module.lookupSymbol(runtimeCall.getCalleeAttr())); - std::string newCalleeName = (callee.getName()).str(); + std::string newCalleeName = (callee.getName() + "_omp").str(); mlir::OpBuilder moduleBuilder(module.getBodyRegion()); func::FuncOp newCallee = cast_or_null(module.lookupSymbol(newCalleeName)); From 099274d9b93182bda46bb740608175519297a3a5 Mon Sep 17 00:00:00 2001 From: skc7 Date: Fri, 13 Jun 2025 17:37:28 +0530 Subject: [PATCH 20/29] [Flang] Fix workdistribute tests --- .../lower-workdistribute-fission-target.mlir | 62 +++++++++---------- .../OpenMP/lower-workdistribute-target.mlir | 3 - 2 files changed, 29 insertions(+), 36 deletions(-) diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir index ed6c641f2e934..53c09581f9816 100644 --- a/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir +++ b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir @@ -1,6 +1,6 @@ // RUN: fir-opt --lower-workdistribute %s | FileCheck %s -// CHECK-LABEL: func.func @x +// CHECK-LABEL: func.func @x( // CHECK: %[[VAL_0:.*]] = fir.alloca index {bindc_name = "lb"} // CHECK: fir.store %[[ARG0:.*]] to %[[VAL_0]] : !fir.ref // CHECK: %[[VAL_1:.*]] = fir.alloca index {bindc_name = "ub"} @@ -19,28 +19,26 @@ // CHECK: %[[VAL_11:.*]] = fir.alloca !fir.heap // CHECK: %[[VAL_12:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref>, !fir.heap) map_clauses(from) capture(ByRef) -> !fir.ref> {name = "__flang_workdistribute_from"} // CHECK: %[[VAL_13:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref>, !fir.heap) map_clauses(to) capture(ByRef) -> !fir.ref> {name = "__flang_workdistribute_to"} -// CHECK: omp.target map_entries(%[[VAL_7]] -> %[[VAL_14:.*]], %[[VAL_8]] -> %[[VAL_15:.*]], %[[VAL_9]] -> %[[VAL_16:.*]], %[[VAL_10]] -> %[[VAL_17:.*]], %[[VAL_12]] -> %[[VAL_18:.*]] : !fir.ref, !fir.ref, !fir.ref, !fir.ref, !fir.ref>) { -// CHECK: %[[VAL_19:.*]] = arith.constant 1 : index -// CHECK: %[[VAL_20:.*]] = fir.load %[[VAL_14]] : !fir.ref -// CHECK: %[[VAL_21:.*]] = fir.load %[[VAL_15]] : !fir.ref -// CHECK: %[[VAL_22:.*]] = fir.load %[[VAL_16]] : !fir.ref -// CHECK: %[[VAL_23:.*]] = arith.addi %[[VAL_21]], %[[VAL_21]] : index -// CHECK: %[[VAL_24:.*]] = fir.allocmem index, %[[VAL_19]] {uniq_name = "dev_buf"} -// CHECK: fir.store %[[VAL_24]] to %[[VAL_18]] : !fir.llvm_ptr> -// CHECK: omp.terminator -// CHECK: } -// CHECK: omp.target map_entries(%[[VAL_7]] -> %[[VAL_25:.*]], %[[VAL_8]] -> %[[VAL_26:.*]], %[[VAL_9]] -> %[[VAL_27:.*]], %[[VAL_10]] -> %[[VAL_28:.*]], %[[VAL_13]] -> %[[VAL_29:.*]] : !fir.ref, !fir.ref, !fir.ref, !fir.ref, !fir.ref>) { -// CHECK: %[[VAL_30:.*]] = fir.load %[[VAL_29]] : !fir.llvm_ptr> -// CHECK: %[[VAL_31:.*]] = fir.load %[[VAL_25]] : !fir.ref -// CHECK: %[[VAL_32:.*]] = fir.load %[[VAL_26]] : !fir.ref -// CHECK: %[[VAL_33:.*]] = fir.load %[[VAL_27]] : !fir.ref -// CHECK: %[[VAL_34:.*]] = arith.addi %[[VAL_32]], %[[VAL_32]] : index +// CHECK: %[[VAL_14:.*]] = 
arith.constant 1 : index +// CHECK: %[[VAL_15:.*]] = fir.load %[[VAL_0]] : !fir.ref +// CHECK: %[[VAL_16:.*]] = fir.load %[[VAL_1]] : !fir.ref +// CHECK: %[[VAL_17:.*]] = fir.load %[[VAL_2]] : !fir.ref +// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_16]], %[[VAL_16]] : index +// CHECK: %[[VAL_19:.*]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK: %[[VAL_20:.*]] = "fir.omp_target_allocmem"(%[[VAL_19]], %[[VAL_14]]) <{in_type = index, operandSegmentSizes = array, uniq_name = "dev_buf"}> : (i32, index) -> !fir.heap +// CHECK: fir.store %[[VAL_20]] to %[[VAL_11]] : !fir.ref> +// CHECK: omp.target map_entries(%[[VAL_7]] -> %[[VAL_21:.*]], %[[VAL_8]] -> %[[VAL_22:.*]], %[[VAL_9]] -> %[[VAL_23:.*]], %[[VAL_10]] -> %[[VAL_24:.*]], %[[VAL_13]] -> %[[VAL_25:.*]] : !fir.ref, !fir.ref, !fir.ref, !fir.ref, !fir.ref>) { +// CHECK: %[[VAL_26:.*]] = fir.load %[[VAL_25]] : !fir.llvm_ptr> +// CHECK: %[[VAL_27:.*]] = fir.load %[[VAL_21]] : !fir.ref +// CHECK: %[[VAL_28:.*]] = fir.load %[[VAL_22]] : !fir.ref +// CHECK: %[[VAL_29:.*]] = fir.load %[[VAL_23]] : !fir.ref +// CHECK: %[[VAL_30:.*]] = arith.addi %[[VAL_28]], %[[VAL_28]] : index // CHECK: omp.teams { // CHECK: omp.parallel { // CHECK: omp.distribute { // CHECK: omp.wsloop { -// CHECK: omp.loop_nest (%[[VAL_35:.*]]) : index = (%[[VAL_31]]) to (%[[VAL_32]]) inclusive step (%[[VAL_33]]) { -// CHECK: fir.store %[[VAL_34]] to %[[VAL_30]] : !fir.heap +// CHECK: omp.loop_nest (%[[VAL_31:.*]]) : index = (%[[VAL_27]]) to (%[[VAL_28]]) inclusive step (%[[VAL_29]]) { +// CHECK: fir.store %[[VAL_30]] to %[[VAL_26]] : !fir.heap // CHECK: omp.yield // CHECK: } // CHECK: } {omp.composite} @@ -51,16 +49,14 @@ // CHECK: } // CHECK: omp.terminator // CHECK: } -// CHECK: omp.target map_entries(%[[VAL_7]] -> %[[VAL_36:.*]], %[[VAL_8]] -> %[[VAL_37:.*]], %[[VAL_9]] -> %[[VAL_38:.*]], %[[VAL_10]] -> %[[VAL_39:.*]], %[[VAL_13]] -> %[[VAL_40:.*]] : !fir.ref, !fir.ref, !fir.ref, !fir.ref, !fir.ref>) { -// CHECK: %[[VAL_41:.*]] = fir.load %[[VAL_40]] : !fir.llvm_ptr> -// CHECK: %[[VAL_42:.*]] = fir.load %[[VAL_36]] : !fir.ref -// CHECK: %[[VAL_43:.*]] = fir.load %[[VAL_37]] : !fir.ref -// CHECK: %[[VAL_44:.*]] = fir.load %[[VAL_38]] : !fir.ref -// CHECK: %[[VAL_45:.*]] = arith.addi %[[VAL_43]], %[[VAL_43]] : index -// CHECK: fir.store %[[VAL_42]] to %[[VAL_41]] : !fir.heap -// CHECK: fir.freemem %[[VAL_41]] : !fir.heap -// CHECK: omp.terminator -// CHECK: } +// CHECK: %[[VAL_32:.*]] = fir.load %[[VAL_11]] : !fir.ref> +// CHECK: %[[VAL_33:.*]] = fir.load %[[VAL_0]] : !fir.ref +// CHECK: %[[VAL_34:.*]] = fir.load %[[VAL_1]] : !fir.ref +// CHECK: %[[VAL_35:.*]] = fir.load %[[VAL_2]] : !fir.ref +// CHECK: %[[VAL_36:.*]] = arith.addi %[[VAL_34]], %[[VAL_34]] : index +// CHECK: fir.store %[[VAL_33]] to %[[VAL_32]] : !fir.heap +// CHECK: %[[VAL_37:.*]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK: "fir.omp_target_freemem"(%[[VAL_37]], %[[VAL_32]]) : (i32, !fir.heap) -> () // CHECK: omp.terminator // CHECK: } // CHECK: return @@ -79,10 +75,10 @@ func.func @x(%lb : index, %ub : index, %step : index, %addr : !fir.ref) { %step_map = omp.map.info var_ptr(%step_ref : !fir.ref, index) map_clauses(to) capture(ByRef) -> !fir.ref {name = "step"} %addr_map = omp.map.info var_ptr(%addr : !fir.ref, index) map_clauses(tofrom) capture(ByRef) -> !fir.ref {name = "addr"} - omp.target map_entries(%lb_map -> %arg0, %ub_map -> %arg1, %step_map -> %arg2, %addr_map -> %arg3 : !fir.ref, !fir.ref, !fir.ref, !fir.ref) { - %lb_val = fir.load %arg0 : !fir.ref - %ub_val = fir.load %arg1 : !fir.ref - %step_val = 
fir.load %arg2 : !fir.ref + omp.target map_entries(%lb_map -> %ARG0, %ub_map -> %ARG1, %step_map -> %ARG2, %addr_map -> %ARG3 : !fir.ref, !fir.ref, !fir.ref, !fir.ref) { + %lb_val = fir.load %ARG0 : !fir.ref + %ub_val = fir.load %ARG1 : !fir.ref + %step_val = fir.load %ARG2 : !fir.ref %one = arith.constant 1 : index %20 = arith.addi %ub_val, %ub_val : index diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir index e6ca98d3bf596..ad2cd422d9533 100644 --- a/flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir +++ b/flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir @@ -12,9 +12,6 @@ // CHECK: %[[VAL_10:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref}>>, !fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>) map_clauses(tofrom) capture(ByRef) members(%[[VAL_3]], %[[VAL_8]] : [1, 0], [1, 1] : !fir.ref, !fir.ref) -> !fir.ref}>> {name = "sa", partial_map = true} // CHECK: omp.target_data map_entries(%[[VAL_10]] : !fir.ref}>>) { // CHECK: %[[VAL_11:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref}>>, !fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>) map_clauses(tofrom) capture(ByRef) members(%[[VAL_3]], %[[VAL_8]] : [1, 0], [1, 1] : !fir.ref, !fir.ref) -> !fir.ref}>> {name = "sa", partial_map = true} -// CHECK: omp.target map_entries(%[[VAL_11]] -> %[[VAL_12:.*]] : !fir.ref}>>) { -// CHECK: omp.terminator -// CHECK: } // CHECK: omp.terminator // CHECK: } // CHECK: return From 73f1e0df3a010c72f855553eecc4c8b0cb0e381f Mon Sep 17 00:00:00 2001 From: skc7 Date: Mon, 16 Jun 2025 13:40:27 +0530 Subject: [PATCH 21/29] [Flang] Fix omp target alloc mem lowering C-P from ivanradanov commit 73fd86537980fe1d9454d5a60642f1b71290cc55 --- flang/lib/Optimizer/CodeGen/CodeGen.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index f94ccbb91de57..cf4ca5f1436b5 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -1204,7 +1204,7 @@ struct OmpTargetAllocMemOpConversion mlir::Value size = genTypeSizeInBytes(loc, ity, rewriter, llvmObjectTy); if (auto scaleSize = genAllocationScaleSize(heap, ity, rewriter)) size = rewriter.create(loc, ity, size, scaleSize); - for (mlir::Value opnd : adaptor.getOperands()) + for (mlir::Value opnd : adaptor.getOperands().drop_front()) size = rewriter.create( loc, ity, size, integerCast(loc, rewriter, ity, opnd)); auto mallocTyWidth = lowerTy().getIndexTypeBitwidth(); From 559df28e0db964e95072c319375db1aa7bb67fd4 Mon Sep 17 00:00:00 2001 From: skc7 Date: Wed, 18 Jun 2025 10:10:21 +0530 Subject: [PATCH 22/29] [Flang] Update assign_omp logic --- flang-rt/lib/runtime/assign_omp.cpp | 56 ++++++----- .../Optimizer/OpenMP/LowerWorkdistribute.cpp | 99 ++++++++++--------- .../lower-workdistribute-fission-target.mlir | 74 ++++++++------ .../OpenMP/lower-workdistribute-target.mlir | 2 +- 4 files changed, 126 insertions(+), 105 deletions(-) diff --git a/flang-rt/lib/runtime/assign_omp.cpp b/flang-rt/lib/runtime/assign_omp.cpp index e3680c59ed5a1..dee912155829b 100644 --- a/flang-rt/lib/runtime/assign_omp.cpp +++ b/flang-rt/lib/runtime/assign_omp.cpp @@ -18,9 +18,23 @@ #include namespace Fortran::runtime { 
+namespace omp { + +typedef int32_t OMPDeviceTy; + +template static T *getDevicePtr(T *anyPtr, OMPDeviceTy ompDevice) { + auto voidAnyPtr = reinterpret_cast(anyPtr); + // If not present on the device it should already be a device ptr + if (!omp_target_is_present(voidAnyPtr, ompDevice)) + return anyPtr; + T *device_ptr = nullptr; +#pragma omp target data use_device_ptr(anyPtr) device(ompDevice) + device_ptr = anyPtr; + return device_ptr; +} RT_API_ATTRS static void Assign(Descriptor &to, const Descriptor &from, - Terminator &terminator, int flags, int32_t omp_device) { + Terminator &terminator, int flags, OMPDeviceTy omp_device) { std::size_t toElementBytes{to.ElementBytes()}; std::size_t fromElementBytes{from.ElementBytes()}; std::size_t toElements{to.Elements()}; @@ -31,42 +45,34 @@ RT_API_ATTRS static void Assign(Descriptor &to, const Descriptor &from, if (toElements != fromElements) terminator.Crash("Assign: toElements != fromElements"); - void *host_to_ptr = to.raw().base_addr; - void *host_from_ptr = from.raw().base_addr; + // Get base addresses and calculate length + void *to_base = to.raw().base_addr; + void *from_base = from.raw().base_addr; size_t length = toElements * toElementBytes; - printf("assign length: %zu\n", length); + // Get device pointers after ensuring data is on device + void *to_ptr = getDevicePtr(to_base, omp_device); + void *from_ptr = getDevicePtr(from_base, omp_device); - if (!omp_target_is_present(host_to_ptr, omp_device)) - terminator.Crash("Assign: !omp_target_is_present(host_to_ptr, omp_device)"); - if (!omp_target_is_present(host_from_ptr, omp_device)) - terminator.Crash( - "Assign: !omp_target_is_present(host_from_ptr, omp_device)"); - - printf("host_to_ptr: %p\n", host_to_ptr); -#pragma omp target data use_device_ptr(host_to_ptr, host_from_ptr) device(omp_device) - { - printf("device_to_ptr: %p\n", host_to_ptr); - // TODO do we need to handle overlapping memory? does this function do that? - omp_target_memcpy(host_to_ptr, host_from_ptr, length, /*dst_offset*/ 0, - /*src_offset*/ 0, /*dst*/ omp_device, /*src*/ omp_device); - } + // Perform copy between device pointers + int result = omp_target_memcpy(to_ptr, from_ptr, length, + /*dst_offset*/ 0, /*src_offset*/ 0, omp_device, omp_device); + if (result != 0) + terminator.Crash("Assign: omp_target_memcpy failed"); return; } extern "C" { RT_EXT_API_GROUP_BEGIN void RTDEF(Assign_omp)(Descriptor &to, const Descriptor &from, - const char *sourceFile, int sourceLine, int32_t omp_device) { + const char *sourceFile, int sourceLine, omp::OMPDeviceTy omp_device) { Terminator terminator{sourceFile, sourceLine}; - // All top-level defined assignments can be recognized in semantics and - // will have been already been converted to calls, so don't check for - // defined assignment apart from components. 
- Assign(to, from, terminator, + omp::Assign(to, from, terminator, MaybeReallocate | NeedFinalization | ComponentCanBeDefinedAssignment, omp_device); } -} // extern "C" -} \ No newline at end of file +} // extern "C" +} // namespace omp +} // namespace Fortran::runtime diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp index d82a61705eae3..5900d93c4e770 100644 --- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp +++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp @@ -357,55 +357,77 @@ struct SplitTargetResult { /// original data region and avoid unnecessary data movement at each of the /// subkernels - we split the target region into a target_data{target} /// nest where only the outer one moves the data -std::optional splitTargetData(omp::TargetOp targetOp, RewriterBase &rewriter) { - +std::optional splitTargetData(omp::TargetOp targetOp, + RewriterBase &rewriter) { auto loc = targetOp->getLoc(); if (targetOp.getMapVars().empty()) { LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " target region has no data maps\n"); return std::nullopt; } - // Collect all map_entries with capture(ByRef) - SmallVector byRefMapInfos; - SmallVector MapInfos; + SmallVector mapInfos; for (auto opr : targetOp.getMapVars()) { auto mapInfo = cast(opr.getDefiningOp()); - MapInfos.push_back(mapInfo); - if (mapInfo.getMapCaptureType() == omp::VariableCaptureKind::ByRef) - byRefMapInfos.push_back(opr); + mapInfos.push_back(mapInfo); + } + + rewriter.setInsertionPoint(targetOp); + SmallVector innerMapInfos; + SmallVector outerMapInfos; + + for (auto mapInfo : mapInfos) { + auto originalMapType = + (llvm::omp::OpenMPOffloadMappingFlags)(mapInfo.getMapType()); + auto originalCaptureType = mapInfo.getMapCaptureType(); + llvm::omp::OpenMPOffloadMappingFlags newMapType; + mlir::omp::VariableCaptureKind newCaptureType; + + if (originalCaptureType == mlir::omp::VariableCaptureKind::ByCopy) { + newMapType = originalMapType; + newCaptureType = originalCaptureType; + } else if (originalCaptureType == mlir::omp::VariableCaptureKind::ByRef) { + newMapType = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE; + newCaptureType = originalCaptureType; + outerMapInfos.push_back(mapInfo); + } else { + llvm_unreachable("Unhandled case"); + } + auto innerMapInfo = cast(rewriter.clone(*mapInfo)); + innerMapInfo.setMapTypeAttr(rewriter.getIntegerAttr( + rewriter.getIntegerType(64, false), + static_cast< + std::underlying_type_t>( + newMapType))); + innerMapInfo.setMapCaptureType(newCaptureType); + innerMapInfos.push_back(innerMapInfo.getResult()); } - // Create the new omp.target_data op with these collected map_entries rewriter.setInsertionPoint(targetOp); auto device = targetOp.getDevice(); auto ifExpr = targetOp.getIfExpr(); auto deviceAddrVars = targetOp.getHasDeviceAddrVars(); auto devicePtrVars = targetOp.getIsDevicePtrVars(); - auto targetDataOp = rewriter.create(loc, device, ifExpr, - mlir::ValueRange{byRefMapInfos}, - deviceAddrVars, - devicePtrVars); - + auto targetDataOp = rewriter.create( + loc, device, ifExpr, outerMapInfos, deviceAddrVars, devicePtrVars); auto taregtDataBlock = rewriter.createBlock(&targetDataOp.getRegion()); rewriter.create(loc); rewriter.setInsertionPointToStart(taregtDataBlock); - // Clone mapInfo ops inside omp.target_data region - IRMapping mapping; - for (auto mapInfo : MapInfos) { - rewriter.clone(*mapInfo, mapping); - } - // Clone omp.target from exisiting targetOp inside target_data region. 
- auto newTargetOp = rewriter.clone(*targetOp, mapping); + auto newTargetOp = rewriter.create( + targetOp.getLoc(), targetOp.getAllocateVars(), + targetOp.getAllocatorVars(), targetOp.getBareAttr(), + targetOp.getDependKindsAttr(), targetOp.getDependVars(), + targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), + targetOp.getHostEvalVars(), targetOp.getIfExpr(), + targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(), + targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(), + innerMapInfos, targetOp.getNowaitAttr(), targetOp.getPrivateVars(), + targetOp.getPrivateSymsAttr(), targetOp.getThreadLimit(), + targetOp.getPrivateMapsAttr()); + rewriter.inlineRegionBefore(targetOp.getRegion(), newTargetOp.getRegion(), + newTargetOp.getRegion().begin()); - // Erase TargetOp and its MapInfoOps - rewriter.eraseOp(targetOp); - - for (auto mapInfo : MapInfos) { - auto mapInfoRes = mapInfo.getResult(); - if (mapInfoRes.getUsers().empty()) - rewriter.eraseOp(mapInfo); - } + rewriter.replaceOp(targetOp, newTargetOp); return SplitTargetResult{cast(newTargetOp), targetDataOp}; } @@ -521,25 +543,6 @@ static bool isRecomputableAfterFission(Operation *op, Operation *splitBefore) { if (isa(op)) return true; - if (auto loadOp = dyn_cast(op)) { - Value memref = loadOp.getMemref(); - if (auto blockArg = dyn_cast(memref)) { - // 'op' is an operation within the targetOp that 'splitBefore' is also in. - Operation *parentOpOfLoadBlock = op->getBlock()->getParentOp(); - // Ensure the blockArg belongs to the entry block of this parent omp.TargetOp. - // This implies the load is from a variable directly mapped into the target region. - if (isa(parentOpOfLoadBlock) && - !parentOpOfLoadBlock->getRegions().empty()) { - Block *targetOpEntryBlock = &parentOpOfLoadBlock->getRegions().front().front(); - if (blockArg.getOwner() == targetOpEntryBlock) { - // This load is from a direct argument of the target op. - // It's safe to recompute. 
- return true; - } - } - } - } - llvm::SmallVector effects; MemoryEffectOpInterface interface = dyn_cast(op); if (!interface) { diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir index 53c09581f9816..19bdb9ce10fbd 100644 --- a/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir +++ b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir @@ -11,34 +11,46 @@ // CHECK: %[[VAL_4:.*]] = omp.map.info var_ptr(%[[VAL_1]] : !fir.ref, index) map_clauses(to) capture(ByRef) -> !fir.ref {name = "ub"} // CHECK: %[[VAL_5:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref, index) map_clauses(to) capture(ByRef) -> !fir.ref {name = "step"} // CHECK: %[[VAL_6:.*]] = omp.map.info var_ptr(%[[ARG3:.*]] : !fir.ref, index) map_clauses(tofrom) capture(ByRef) -> !fir.ref {name = "addr"} +// CHECK: %[[VAL_7:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref {name = "lb"} +// CHECK: %[[VAL_8:.*]] = omp.map.info var_ptr(%[[VAL_1]] : !fir.ref, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref {name = "ub"} +// CHECK: %[[VAL_9:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref {name = "step"} +// CHECK: %[[VAL_10:.*]] = omp.map.info var_ptr(%[[ARG3:.*]] : !fir.ref, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref {name = "addr"} // CHECK: omp.target_data map_entries(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]], %[[VAL_6]] : !fir.ref, !fir.ref, !fir.ref, !fir.ref) { -// CHECK: %[[VAL_7:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref, index) map_clauses(to) capture(ByRef) -> !fir.ref {name = "lb"} -// CHECK: %[[VAL_8:.*]] = omp.map.info var_ptr(%[[VAL_1]] : !fir.ref, index) map_clauses(to) capture(ByRef) -> !fir.ref {name = "ub"} -// CHECK: %[[VAL_9:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref, index) map_clauses(to) capture(ByRef) -> !fir.ref {name = "step"} -// CHECK: %[[VAL_10:.*]] = omp.map.info var_ptr(%[[ARG3:.*]] : !fir.ref, index) map_clauses(tofrom) capture(ByRef) -> !fir.ref {name = "addr"} -// CHECK: %[[VAL_11:.*]] = fir.alloca !fir.heap -// CHECK: %[[VAL_12:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref>, !fir.heap) map_clauses(from) capture(ByRef) -> !fir.ref> {name = "__flang_workdistribute_from"} -// CHECK: %[[VAL_13:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref>, !fir.heap) map_clauses(to) capture(ByRef) -> !fir.ref> {name = "__flang_workdistribute_to"} -// CHECK: %[[VAL_14:.*]] = arith.constant 1 : index -// CHECK: %[[VAL_15:.*]] = fir.load %[[VAL_0]] : !fir.ref -// CHECK: %[[VAL_16:.*]] = fir.load %[[VAL_1]] : !fir.ref -// CHECK: %[[VAL_17:.*]] = fir.load %[[VAL_2]] : !fir.ref -// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_16]], %[[VAL_16]] : index -// CHECK: %[[VAL_19:.*]] = llvm.mlir.constant(0 : i32) : i32 -// CHECK: %[[VAL_20:.*]] = "fir.omp_target_allocmem"(%[[VAL_19]], %[[VAL_14]]) <{in_type = index, operandSegmentSizes = array, uniq_name = "dev_buf"}> : (i32, index) -> !fir.heap -// CHECK: fir.store %[[VAL_20]] to %[[VAL_11]] : !fir.ref> -// CHECK: omp.target map_entries(%[[VAL_7]] -> %[[VAL_21:.*]], %[[VAL_8]] -> %[[VAL_22:.*]], %[[VAL_9]] -> %[[VAL_23:.*]], %[[VAL_10]] -> %[[VAL_24:.*]], %[[VAL_13]] -> %[[VAL_25:.*]] : !fir.ref, !fir.ref, !fir.ref, !fir.ref, !fir.ref>) { -// CHECK: %[[VAL_26:.*]] = fir.load %[[VAL_25]] : !fir.llvm_ptr> -// CHECK: 
%[[VAL_27:.*]] = fir.load %[[VAL_21]] : !fir.ref -// CHECK: %[[VAL_28:.*]] = fir.load %[[VAL_22]] : !fir.ref -// CHECK: %[[VAL_29:.*]] = fir.load %[[VAL_23]] : !fir.ref -// CHECK: %[[VAL_30:.*]] = arith.addi %[[VAL_28]], %[[VAL_28]] : index +// CHECK: %[[VAL_11:.*]] = fir.alloca index +// CHECK: %[[VAL_12:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref, index) map_clauses(from) capture(ByRef) -> !fir.ref {name = "__flang_workdistribute_from"} +// CHECK: %[[VAL_13:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref, index) map_clauses(to) capture(ByRef) -> !fir.ref {name = "__flang_workdistribute_to"} +// CHECK: %[[VAL_14:.*]] = fir.alloca index +// CHECK: %[[VAL_15:.*]] = omp.map.info var_ptr(%[[VAL_14]] : !fir.ref, index) map_clauses(from) capture(ByRef) -> !fir.ref {name = "__flang_workdistribute_from"} +// CHECK: %[[VAL_16:.*]] = omp.map.info var_ptr(%[[VAL_14]] : !fir.ref, index) map_clauses(to) capture(ByRef) -> !fir.ref {name = "__flang_workdistribute_to"} +// CHECK: %[[VAL_17:.*]] = fir.alloca index +// CHECK: %[[VAL_18:.*]] = omp.map.info var_ptr(%[[VAL_17]] : !fir.ref, index) map_clauses(from) capture(ByRef) -> !fir.ref {name = "__flang_workdistribute_from"} +// CHECK: %[[VAL_19:.*]] = omp.map.info var_ptr(%[[VAL_17]] : !fir.ref, index) map_clauses(to) capture(ByRef) -> !fir.ref {name = "__flang_workdistribute_to"} +// CHECK: %[[VAL_20:.*]] = fir.alloca !fir.heap +// CHECK: %[[VAL_21:.*]] = omp.map.info var_ptr(%[[VAL_20]] : !fir.ref>, !fir.heap) map_clauses(from) capture(ByRef) -> !fir.ref> {name = "__flang_workdistribute_from"} +// CHECK: %[[VAL_22:.*]] = omp.map.info var_ptr(%[[VAL_20]] : !fir.ref>, !fir.heap) map_clauses(to) capture(ByRef) -> !fir.ref> {name = "__flang_workdistribute_to"} +// CHECK: %[[VAL_23:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_24:.*]] = fir.load %[[VAL_0]] : !fir.ref +// CHECK: %[[VAL_25:.*]] = fir.load %[[VAL_1]] : !fir.ref +// CHECK: %[[VAL_26:.*]] = fir.load %[[VAL_2]] : !fir.ref +// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_25]], %[[VAL_25]] : index +// CHECK: %[[VAL_28:.*]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK: %[[VAL_29:.*]] = "fir.omp_target_allocmem"(%[[VAL_28]], %[[VAL_23]]) <{in_type = index, operandSegmentSizes = array, uniq_name = "dev_buf"}> : (i32, index) -> !fir.heap +// CHECK: fir.store %[[VAL_24]] to %[[VAL_11]] : !fir.ref +// CHECK: fir.store %[[VAL_25]] to %[[VAL_14]] : !fir.ref +// CHECK: fir.store %[[VAL_26]] to %[[VAL_17]] : !fir.ref +// CHECK: fir.store %[[VAL_29]] to %[[VAL_20]] : !fir.ref> +// CHECK: omp.target map_entries(%[[VAL_7]] -> %[[VAL_30:.*]], %[[VAL_8]] -> %[[VAL_31:.*]], %[[VAL_9]] -> %[[VAL_32:.*]], %[[VAL_10]] -> %[[VAL_33:.*]], %[[VAL_13]] -> %[[VAL_34:.*]], %[[VAL_16]] -> %[[VAL_35:.*]], %[[VAL_19]] -> %[[VAL_36:.*]], %[[VAL_22]] -> %[[VAL_37:.*]] : !fir.ref, !fir.ref, !fir.ref, !fir.ref, !fir.ref, !fir.ref, !fir.ref, !fir.ref>) { +// CHECK: %[[VAL_38:.*]] = fir.load %[[VAL_34]] : !fir.llvm_ptr +// CHECK: %[[VAL_39:.*]] = fir.load %[[VAL_35]] : !fir.llvm_ptr +// CHECK: %[[VAL_40:.*]] = fir.load %[[VAL_36]] : !fir.llvm_ptr +// CHECK: %[[VAL_41:.*]] = fir.load %[[VAL_37]] : !fir.llvm_ptr> +// CHECK: %[[VAL_42:.*]] = arith.addi %[[VAL_39]], %[[VAL_39]] : index // CHECK: omp.teams { // CHECK: omp.parallel { // CHECK: omp.distribute { // CHECK: omp.wsloop { -// CHECK: omp.loop_nest (%[[VAL_31:.*]]) : index = (%[[VAL_27]]) to (%[[VAL_28]]) inclusive step (%[[VAL_29]]) { -// CHECK: fir.store %[[VAL_30]] to %[[VAL_26]] : !fir.heap +// CHECK: omp.loop_nest (%[[VAL_43:.*]]) : index = (%[[VAL_38]]) to 
(%[[VAL_39]]) inclusive step (%[[VAL_40]]) { +// CHECK: fir.store %[[VAL_42]] to %[[VAL_41]] : !fir.heap // CHECK: omp.yield // CHECK: } // CHECK: } {omp.composite} @@ -49,14 +61,14 @@ // CHECK: } // CHECK: omp.terminator // CHECK: } -// CHECK: %[[VAL_32:.*]] = fir.load %[[VAL_11]] : !fir.ref> -// CHECK: %[[VAL_33:.*]] = fir.load %[[VAL_0]] : !fir.ref -// CHECK: %[[VAL_34:.*]] = fir.load %[[VAL_1]] : !fir.ref -// CHECK: %[[VAL_35:.*]] = fir.load %[[VAL_2]] : !fir.ref -// CHECK: %[[VAL_36:.*]] = arith.addi %[[VAL_34]], %[[VAL_34]] : index -// CHECK: fir.store %[[VAL_33]] to %[[VAL_32]] : !fir.heap -// CHECK: %[[VAL_37:.*]] = llvm.mlir.constant(0 : i32) : i32 -// CHECK: "fir.omp_target_freemem"(%[[VAL_37]], %[[VAL_32]]) : (i32, !fir.heap) -> () +// CHECK: %[[VAL_44:.*]] = fir.load %[[VAL_11]] : !fir.ref +// CHECK: %[[VAL_45:.*]] = fir.load %[[VAL_14]] : !fir.ref +// CHECK: %[[VAL_46:.*]] = fir.load %[[VAL_17]] : !fir.ref +// CHECK: %[[VAL_47:.*]] = fir.load %[[VAL_20]] : !fir.ref> +// CHECK: %[[VAL_48:.*]] = arith.addi %[[VAL_45]], %[[VAL_45]] : index +// CHECK: fir.store %[[VAL_44]] to %[[VAL_47]] : !fir.heap +// CHECK: %[[VAL_49:.*]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK: "fir.omp_target_freemem"(%[[VAL_49]], %[[VAL_47]]) : (i32, !fir.heap) -> () // CHECK: omp.terminator // CHECK: } // CHECK: return diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir index ad2cd422d9533..91e6d5b7201a7 100644 --- a/flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir +++ b/flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir @@ -10,8 +10,8 @@ // CHECK: %[[VAL_7:.*]] = fir.coordinate_of %[[VAL_6]], r : (!fir.ref>) -> !fir.ref // CHECK: %[[VAL_8:.*]] = omp.map.info var_ptr(%[[VAL_7]] : !fir.ref, f32) map_clauses(tofrom) capture(ByRef) -> !fir.ref {name = "sa%[[VAL_4]]%[[VAL_9:.*]]"} // CHECK: %[[VAL_10:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref}>>, !fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>) map_clauses(tofrom) capture(ByRef) members(%[[VAL_3]], %[[VAL_8]] : [1, 0], [1, 1] : !fir.ref, !fir.ref) -> !fir.ref}>> {name = "sa", partial_map = true} +// CHECK: %[[VAL_11:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref}>>, !fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) members(%[[VAL_3]], %[[VAL_8]] : [1, 0], [1, 1] : !fir.ref, !fir.ref) -> !fir.ref}>> {name = "sa", partial_map = true} // CHECK: omp.target_data map_entries(%[[VAL_10]] : !fir.ref}>>) { -// CHECK: %[[VAL_11:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref}>>, !fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>) map_clauses(tofrom) capture(ByRef) members(%[[VAL_3]], %[[VAL_8]] : [1, 0], [1, 1] : !fir.ref, !fir.ref) -> !fir.ref}>> {name = "sa", partial_map = true} // CHECK: omp.terminator // CHECK: } // CHECK: return From 3b34aef42f52576fef93cdfb813abc57f3b8c9f9 Mon Sep 17 00:00:00 2001 From: skc7 Date: Wed, 18 Jun 2025 15:37:32 +0530 Subject: [PATCH 23/29] [Flang] Bail out if lower-workdistribute didn't patternmatch. 
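
Only run the omp.target data-splitting and fission step when one of the
workdistribute rewrite patterns actually fired; otherwise leave the module
untouched. As a rough sketch of the gating added to LowerWorkdistributePass
(condensed from the hunks below, with the SmallVector element type spelled
out for illustration; not the verbatim final code):

```
  // Each rewrite pattern records in a static flag whether it ever matched.
  bool anyPatternChanged = false;
  anyPatternChanged |=
      FissionWorkdistribute::fissionWorkdistributePatternMatched;
  anyPatternChanged |=
      WorkdistributeDoLower::workdistributeDoLowerPatternMatched;
  anyPatternChanged |=
      TeamsWorkdistributeToSingle::teamsWorkdistributeToSinglePatternMatched;

  // Bail out unless something was rewritten: only then split the target
  // data region and fission the omp.target ops.
  if (anyPatternChanged) {
    SmallVector<omp::TargetOp> targetOps;
    op->walk([&](omp::TargetOp targetOp) { targetOps.push_back(targetOp); });
    IRRewriter rewriter(&context);
    for (auto targetOp : targetOps) {
      auto res = splitTargetData(targetOp, rewriter);
      if (res)
        fissionTarget(res->targetOp, rewriter);
    }
  }
```

Keeping the matched/changed state in static flags on the pattern classes is
what the diff below does; the assign_omp.cpp hunk only fully qualifies the
runtime Assign call.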
--- flang-rt/lib/runtime/assign_omp.cpp | 2 +- .../Optimizer/OpenMP/LowerWorkdistribute.cpp | 47 +++++++++---------- .../OpenMP/lower-workdistribute-target.mlir | 3 +- 3 files changed, 24 insertions(+), 28 deletions(-) diff --git a/flang-rt/lib/runtime/assign_omp.cpp b/flang-rt/lib/runtime/assign_omp.cpp index dee912155829b..80c1c5cccb1ca 100644 --- a/flang-rt/lib/runtime/assign_omp.cpp +++ b/flang-rt/lib/runtime/assign_omp.cpp @@ -68,7 +68,7 @@ RT_EXT_API_GROUP_BEGIN void RTDEF(Assign_omp)(Descriptor &to, const Descriptor &from, const char *sourceFile, int sourceLine, omp::OMPDeviceTy omp_device) { Terminator terminator{sourceFile, sourceLine}; - omp::Assign(to, from, terminator, + Fortran::runtime::omp::Assign(to, from, terminator, MaybeReallocate | NeedFinalization | ComponentCanBeDefinedAssignment, omp_device); } diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp index 5900d93c4e770..3dfb574977036 100644 --- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp +++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp @@ -125,6 +125,7 @@ static T getPerfectlyNested(Operation *op) { /// E() struct FissionWorkdistribute : public OpRewritePattern { + static bool fissionWorkdistributePatternMatched; using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(omp::WorkdistributeOp workdistribute, PatternRewriter &rewriter) const override { @@ -210,9 +211,12 @@ struct FissionWorkdistribute : public OpRewritePattern { changed = true; } } + if (changed) + fissionWorkdistributePatternMatched = true; return success(changed); } }; +bool FissionWorkdistribute::fissionWorkdistributePatternMatched = false; /// If fir.do_loop is present inside teams workdistribute /// @@ -296,6 +300,7 @@ static void genWsLoopOp(mlir::PatternRewriter &rewriter, fir::DoLoopOp doLoop, } struct WorkdistributeDoLower : public OpRewritePattern { + static bool workdistributeDoLowerPatternMatched; using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(omp::WorkdistributeOp workdistribute, PatternRewriter &rewriter) const override { @@ -309,12 +314,15 @@ struct WorkdistributeDoLower : public OpRewritePattern { genLoopNestClauseOps(rewriter, doLoop, loopNestClauseOps); genWsLoopOp(rewriter, doLoop, loopNestClauseOps, true); rewriter.eraseOp(workdistribute); + workdistributeDoLowerPatternMatched = true; return success(); } return failure(); } }; +bool WorkdistributeDoLower::workdistributeDoLowerPatternMatched = false; + /// If A() and B () are present inside teams workdistribute /// /// omp.teams { @@ -331,6 +339,7 @@ struct WorkdistributeDoLower : public OpRewritePattern { /// struct TeamsWorkdistributeToSingle : public OpRewritePattern { + static bool teamsWorkdistributeToSinglePatternMatched; using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(omp::TeamsOp teamsOp, PatternRewriter &rewriter) const override { @@ -343,9 +352,12 @@ struct TeamsWorkdistributeToSingle : public OpRewritePattern { rewriter.eraseOp(workdistributeBlock->getTerminator()); rewriter.inlineBlockBefore(workdistributeBlock, teamsOp); rewriter.eraseOp(workdistributeOp); + teamsWorkdistributeToSinglePatternMatched = true; return success(); } }; +bool TeamsWorkdistributeToSingle::teamsWorkdistributeToSinglePatternMatched = + false; struct SplitTargetResult { omp::TargetOp targetOp; @@ -517,28 +529,6 @@ static bool usedOutsideSplit(Value v, Operation *split) { return false; }; -static bool isOpToBeCached(Operation *op) { - if (auto loadOp = 
dyn_cast(op)) { - Value memref = loadOp.getMemref(); - if (auto blockArg = dyn_cast(memref)) { - // 'op' is an operation within the targetOp that 'splitBefore' is also in. - Operation *parentOpOfLoadBlock = op->getBlock()->getParentOp(); - // Ensure the blockArg belongs to the entry block of this parent omp.TargetOp. - // This implies the load is from a variable directly mapped into the target region. - if (isa(parentOpOfLoadBlock) && - !parentOpOfLoadBlock->getRegions().empty()) { - Block *targetOpEntryBlock = &parentOpOfLoadBlock->getRegions().front().front(); - if (blockArg.getOwner() == targetOpEntryBlock) { - // This load is from a direct argument of the target op. - // It's safe to recompute. - return false; - } - } - } - } - return true; -} - static bool isRecomputableAfterFission(Operation *op, Operation *splitBefore) { if (isa(op)) return true; @@ -892,6 +882,7 @@ class LowerWorkdistributePass config.setRegionSimplificationLevel(GreedySimplifyRegionLevel::Disabled); Operation *op = getOperation(); + bool anyPatternChanged = false; { RewritePatternSet patterns(&context); patterns.insert(&context); @@ -899,16 +890,23 @@ class LowerWorkdistributePass emitError(op->getLoc(), DEBUG_TYPE " pass failed\n"); signalPassFailure(); } + anyPatternChanged |= + FissionWorkdistribute::fissionWorkdistributePatternMatched; + anyPatternChanged |= + WorkdistributeDoLower::workdistributeDoLowerPatternMatched; } { RewritePatternSet patterns(&context); - patterns.insert(&context); + patterns.insert( + &context); if (failed(applyPatternsGreedily(op, std::move(patterns), config))) { emitError(op->getLoc(), DEBUG_TYPE " pass failed\n"); signalPassFailure(); } + anyPatternChanged |= TeamsWorkdistributeToSingle:: + teamsWorkdistributeToSinglePatternMatched; } - { + if (anyPatternChanged) { SmallVector targetOps; op->walk([&](omp::TargetOp targetOp) { targetOps.push_back(targetOp); }); IRRewriter rewriter(&context); @@ -917,7 +915,6 @@ class LowerWorkdistributePass if (res) fissionTarget(res->targetOp, rewriter); } } - } }; } // namespace diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir index 91e6d5b7201a7..d96068b26ca2f 100644 --- a/flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir +++ b/flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir @@ -10,8 +10,7 @@ // CHECK: %[[VAL_7:.*]] = fir.coordinate_of %[[VAL_6]], r : (!fir.ref>) -> !fir.ref // CHECK: %[[VAL_8:.*]] = omp.map.info var_ptr(%[[VAL_7]] : !fir.ref, f32) map_clauses(tofrom) capture(ByRef) -> !fir.ref {name = "sa%[[VAL_4]]%[[VAL_9:.*]]"} // CHECK: %[[VAL_10:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref}>>, !fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>) map_clauses(tofrom) capture(ByRef) members(%[[VAL_3]], %[[VAL_8]] : [1, 0], [1, 1] : !fir.ref, !fir.ref) -> !fir.ref}>> {name = "sa", partial_map = true} -// CHECK: %[[VAL_11:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref}>>, !fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) members(%[[VAL_3]], %[[VAL_8]] : [1, 0], [1, 1] : !fir.ref, !fir.ref) -> !fir.ref}>> {name = "sa", partial_map = true} -// CHECK: omp.target_data map_entries(%[[VAL_10]] : !fir.ref}>>) { +// CHECK: omp.target 
map_entries(%[[VAL_10]] -> %[[VAL_11:.*]] : !fir.ref}>>) { // CHECK: omp.terminator // CHECK: } // CHECK: return From 7ab0ba7cc714ae7094cd7ffa9ea5c9766e75e248 Mon Sep 17 00:00:00 2001 From: skc7 Date: Thu, 19 Jun 2025 12:20:40 +0530 Subject: [PATCH 24/29] [flang-rt] Use omp_get_mapped_ptr to get device ptrs. --- flang-rt/lib/runtime/assign_omp.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/flang-rt/lib/runtime/assign_omp.cpp b/flang-rt/lib/runtime/assign_omp.cpp index 80c1c5cccb1ca..a214afea11380 100644 --- a/flang-rt/lib/runtime/assign_omp.cpp +++ b/flang-rt/lib/runtime/assign_omp.cpp @@ -27,9 +27,7 @@ template static T *getDevicePtr(T *anyPtr, OMPDeviceTy ompDevice) { // If not present on the device it should already be a device ptr if (!omp_target_is_present(voidAnyPtr, ompDevice)) return anyPtr; - T *device_ptr = nullptr; -#pragma omp target data use_device_ptr(anyPtr) device(ompDevice) - device_ptr = anyPtr; + T *device_ptr = omp_get_mapped_ptr(anyPtr, ompDevice); return device_ptr; } From a5adfcc0c513b8f780b1d750a19fc0d8d94ea6a6 Mon Sep 17 00:00:00 2001 From: skc7 Date: Mon, 23 Jun 2025 12:44:12 +0530 Subject: [PATCH 25/29] [OMP] Update OMP_Workdistribute directive definition. --- llvm/include/llvm/Frontend/OpenMP/OMP.td | 25 ++++++++---------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td index b6d92d572206a..73fba71120af1 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMP.td +++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td @@ -685,16 +685,7 @@ def OMP_CancellationPoint : Directive<[Spelling<"cancellation point">]> { let association = AS_None; let category = CA_Executable; } -def OMP_Coexecute : Directive<"coexecute"> { - let association = AS_Block; - let category = CA_Executable; -} -def OMP_EndCoexecute : Directive<"end coexecute"> { - let leafConstructs = OMP_Coexecute.leafConstructs; - let association = OMP_Coexecute.association; - let category = OMP_Coexecute.category; -} -def OMP_Critical : Directive<"critical"> { +def OMP_Critical : Directive<[Spelling<"critical">]> { let allowedOnceClauses = [ VersionedClause, ]; @@ -1295,11 +1286,11 @@ def OMP_EndWorkshare : Directive<[Spelling<"end workshare">]> { let category = OMP_Workshare.category; let languages = [L_Fortran]; } -def OMP_Workdistribute : Directive<"workdistribute"> { +def OMP_Workdistribute : Directive<[Spelling<"workdistribute">]> { let association = AS_Block; let category = CA_Executable; } -def OMP_EndWorkdistribute : Directive<"end workdistribute"> { +def OMP_EndWorkdistribute : Directive<[Spelling<"end workdistribute">]> { let leafConstructs = OMP_Workdistribute.leafConstructs; let association = OMP_Workdistribute.association; let category = OMP_Workdistribute.category; @@ -2224,7 +2215,7 @@ def OMP_TargetTeams : Directive<[Spelling<"target teams">]> { let leafConstructs = [OMP_Target, OMP_Teams]; let category = CA_Executable; } -def OMP_TargetTeamsDistribute : Directive<"target teams distribute"> { +def OMP_TargetTeamsDistribute : Directive<[Spelling<"target teams distribute">]> { let allowedClauses = [ VersionedClause, VersionedClause, @@ -2446,7 +2437,7 @@ def OMP_TargetTeamsDistributeSimd let leafConstructs = [OMP_Target, OMP_Teams, OMP_Distribute, OMP_Simd]; let category = CA_Executable; } -def OMP_TargetTeamsWorkdistribute : Directive<"target teams workdistribute"> { +def OMP_TargetTeamsWorkdistribute : Directive<[Spelling<"target teams workdistribute">]> { let allowedClauses = [ 
VersionedClause, VersionedClause, @@ -2474,7 +2465,7 @@ def OMP_TargetTeamsWorkdistribute : Directive<"target teams workdistribute"> { let leafConstructs = [OMP_Target, OMP_Teams, OMP_Workdistribute]; let category = CA_Executable; } -def OMP_target_teams_loop : Directive<"target teams loop"> { +def OMP_target_teams_loop : Directive<[Spelling<"target teams loop">]> { let allowedClauses = [ VersionedClause, VersionedClause, @@ -2538,7 +2529,7 @@ def OMP_TaskLoopSimd : Directive<[Spelling<"taskloop simd">]> { let leafConstructs = [OMP_TaskLoop, OMP_Simd]; let category = CA_Executable; } -def OMP_TeamsDistribute : Directive<"teams distribute"> { +def OMP_TeamsDistribute : Directive<[Spelling<"teams distribute">]> { let allowedClauses = [ VersionedClause, VersionedClause, @@ -2725,7 +2716,7 @@ def OMP_teams_loop : Directive<[Spelling<"teams loop">]> { let leafConstructs = [OMP_Teams, OMP_loop]; let category = CA_Executable; } -def OMP_TeamsWorkdistribute : Directive<"teams workdistribute"> { +def OMP_TeamsWorkdistribute : Directive<[Spelling<"teams workdistribute">]> { let allowedClauses = [ VersionedClause, VersionedClause, From 8669a400cf1b2ab65ea0cc80fee10fcdd24ab8ce Mon Sep 17 00:00:00 2001 From: skc7 Date: Mon, 23 Jun 2025 14:20:28 +0530 Subject: [PATCH 26/29] [Flang] Update omp::TargetOp calls --- .../lib/Optimizer/OpenMP/LowerWorkdistribute.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp index 3dfb574977036..1fe2592d1a357 100644 --- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp +++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp @@ -434,8 +434,8 @@ std::optional splitTargetData(omp::TargetOp targetOp, targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(), innerMapInfos, targetOp.getNowaitAttr(), targetOp.getPrivateVars(), - targetOp.getPrivateSymsAttr(), targetOp.getThreadLimit(), - targetOp.getPrivateMapsAttr()); + targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(), + targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr()); rewriter.inlineRegionBefore(targetOp.getRegion(), newTargetOp.getRegion(), newTargetOp.getRegion().begin()); @@ -653,8 +653,8 @@ static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter, targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(), preMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(), - targetOp.getPrivateSymsAttr(), targetOp.getThreadLimit(), - targetOp.getPrivateMapsAttr()); + targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(), + targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr()); auto *preTargetBlock = rewriter.createBlock( &preTargetOp.getRegion(), preTargetOp.getRegion().begin(), {}, {}); IRMapping preMapping; @@ -695,8 +695,8 @@ static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter, targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(), postMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(), - targetOp.getPrivateSymsAttr(), targetOp.getThreadLimit(), - targetOp.getPrivateMapsAttr()); + targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(), + targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr()); auto *isolatedTargetBlock = rewriter.createBlock(&isolatedTargetOp.getRegion(), @@ -722,8 +722,8 @@ static 
SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter, targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(), postMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(), - targetOp.getPrivateSymsAttr(), targetOp.getThreadLimit(), - targetOp.getPrivateMapsAttr()); + targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(), + targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr()); auto *postTargetBlock = rewriter.createBlock( &postTargetOp.getRegion(), postTargetOp.getRegion().begin(), {}, {}); IRMapping postMapping; From 002610d18a1268b23f3fa1ac76147b36179457f5 Mon Sep 17 00:00:00 2001 From: skc7 Date: Fri, 27 Jun 2025 12:45:22 +0530 Subject: [PATCH 27/29] Remove frontend and pre-requisite changes, as they are in separate PRs now. --- flang-rt/lib/runtime/CMakeLists.txt | 2 - flang-rt/lib/runtime/assign_omp.cpp | 76 ------------- .../include/flang/Optimizer/Dialect/FIROps.td | 61 ----------- flang/include/flang/Runtime/assign.h | 2 - .../flang/Semantics/openmp-directive-sets.h | 14 --- flang/lib/Lower/OpenMP/OpenMP.cpp | 26 +---- flang/lib/Optimizer/CodeGen/CodeGen.cpp | 102 +----------------- flang/lib/Parser/openmp-parsers.cpp | 7 +- flang/lib/Semantics/resolve-directives.cpp | 6 -- flang/test/Lower/OpenMP/workdistribute.f90 | 59 ---------- llvm/include/llvm/Frontend/OpenMP/OMP.td | 55 ---------- mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 27 ----- 12 files changed, 3 insertions(+), 434 deletions(-) delete mode 100644 flang-rt/lib/runtime/assign_omp.cpp delete mode 100644 flang/test/Lower/OpenMP/workdistribute.f90 diff --git a/flang-rt/lib/runtime/CMakeLists.txt b/flang-rt/lib/runtime/CMakeLists.txt index 5200b2b710a5e..332c0872e065f 100644 --- a/flang-rt/lib/runtime/CMakeLists.txt +++ b/flang-rt/lib/runtime/CMakeLists.txt @@ -21,7 +21,6 @@ set(supported_sources allocatable.cpp array-constructor.cpp assign.cpp - assign_omp.cpp buffer.cpp character.cpp connection.cpp @@ -100,7 +99,6 @@ set(gpu_sources allocatable.cpp array-constructor.cpp assign.cpp - assign_omp.cpp buffer.cpp character.cpp connection.cpp diff --git a/flang-rt/lib/runtime/assign_omp.cpp b/flang-rt/lib/runtime/assign_omp.cpp deleted file mode 100644 index a214afea11380..0000000000000 --- a/flang-rt/lib/runtime/assign_omp.cpp +++ /dev/null @@ -1,76 +0,0 @@ -//===-- lib/runtime/assign_omp.cpp ----------------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "flang/Runtime/assign.h" -#include "flang-rt/runtime/assign-impl.h" -#include "flang-rt/runtime/derived.h" -#include "flang-rt/runtime/descriptor.h" -#include "flang-rt/runtime/stat.h" -#include "flang-rt/runtime/terminator.h" -#include "flang-rt/runtime/tools.h" -#include "flang-rt/runtime/type-info.h" - -#include - -namespace Fortran::runtime { -namespace omp { - -typedef int32_t OMPDeviceTy; - -template static T *getDevicePtr(T *anyPtr, OMPDeviceTy ompDevice) { - auto voidAnyPtr = reinterpret_cast(anyPtr); - // If not present on the device it should already be a device ptr - if (!omp_target_is_present(voidAnyPtr, ompDevice)) - return anyPtr; - T *device_ptr = omp_get_mapped_ptr(anyPtr, ompDevice); - return device_ptr; -} - -RT_API_ATTRS static void Assign(Descriptor &to, const Descriptor &from, - Terminator &terminator, int flags, OMPDeviceTy omp_device) { - std::size_t toElementBytes{to.ElementBytes()}; - std::size_t fromElementBytes{from.ElementBytes()}; - std::size_t toElements{to.Elements()}; - std::size_t fromElements{from.Elements()}; - - if (toElementBytes != fromElementBytes) - terminator.Crash("Assign: toElementBytes != fromElementBytes"); - if (toElements != fromElements) - terminator.Crash("Assign: toElements != fromElements"); - - // Get base addresses and calculate length - void *to_base = to.raw().base_addr; - void *from_base = from.raw().base_addr; - size_t length = toElements * toElementBytes; - - // Get device pointers after ensuring data is on device - void *to_ptr = getDevicePtr(to_base, omp_device); - void *from_ptr = getDevicePtr(from_base, omp_device); - - // Perform copy between device pointers - int result = omp_target_memcpy(to_ptr, from_ptr, length, - /*dst_offset*/ 0, /*src_offset*/ 0, omp_device, omp_device); - - if (result != 0) - terminator.Crash("Assign: omp_target_memcpy failed"); - return; -} - -extern "C" { -RT_EXT_API_GROUP_BEGIN -void RTDEF(Assign_omp)(Descriptor &to, const Descriptor &from, - const char *sourceFile, int sourceLine, omp::OMPDeviceTy omp_device) { - Terminator terminator{sourceFile, sourceLine}; - Fortran::runtime::omp::Assign(to, from, terminator, - MaybeReallocate | NeedFinalization | ComponentCanBeDefinedAssignment, - omp_device); -} - -} // extern "C" -} // namespace omp -} // namespace Fortran::runtime diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td index 466699cc4d476..8ac847dd7dd0a 100644 --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -517,67 +517,6 @@ def fir_ZeroOp : fir_OneResultOp<"zero_bits", [NoMemoryEffect]> { let assemblyFormat = "type($intype) attr-dict"; } -def fir_OmpTargetAllocMemOp : fir_Op<"omp_target_allocmem", - [MemoryEffects<[MemAlloc]>, AttrSizedOperandSegments]> { - let summary = "allocate storage on an openmp device for an object of a given type"; - - let description = [{ - Creates a heap memory reference suitable for storing a value of the - given type, T. The heap refernce returned has type `!fir.heap`. - The memory object is in an undefined state. `allocmem` operations must - be paired with `freemem` operations to avoid memory leaks. 
- - ``` - %0 = fir.omp_target_allocmem !fir.array<10 x f32> - ``` - }]; - - let arguments = (ins - Arg:$device, - TypeAttr:$in_type, - OptionalAttr:$uniq_name, - OptionalAttr:$bindc_name, - Variadic:$typeparams, - Variadic:$shape - ); - let results = (outs fir_HeapType); - - let extraClassDeclaration = [{ - mlir::Type getAllocatedType(); - bool hasLenParams() { return !getTypeparams().empty(); } - bool hasShapeOperands() { return !getShape().empty(); } - unsigned numLenParams() { return getTypeparams().size(); } - operand_range getLenParams() { return getTypeparams(); } - unsigned numShapeOperands() { return getShape().size(); } - operand_range getShapeOperands() { return getShape(); } - static mlir::Type getRefTy(mlir::Type ty); - }]; -} - -def fir_OmpTargetFreeMemOp : fir_Op<"omp_target_freemem", - [MemoryEffects<[MemFree]>]> { - let summary = "free a heap object"; - - let description = [{ - Deallocates a heap memory reference that was allocated by an `allocmem`. - The memory object that is deallocated is placed in an undefined state - after `fir.freemem`. Optimizations may treat the loading of an object - in the undefined state as undefined behavior. This includes aliasing - references, such as the result of an `fir.embox`. - - ``` - %21 = fir.omp_target_allocmem !fir.type - ... - fir.omp_target_freemem %21 : !fir.heap> - ``` - }]; - - let arguments = (ins - Arg:$device, - Arg:$heapref - ); -} - //===----------------------------------------------------------------------===// // Terminator operations //===----------------------------------------------------------------------===// diff --git a/flang/include/flang/Runtime/assign.h b/flang/include/flang/Runtime/assign.h index 0be52413e4814..7d198bdcc9e89 100644 --- a/flang/include/flang/Runtime/assign.h +++ b/flang/include/flang/Runtime/assign.h @@ -56,8 +56,6 @@ extern "C" { // API for lowering assignment void RTDECL(Assign)(Descriptor &to, const Descriptor &from, const char *sourceFile = nullptr, int sourceLine = 0); -void RTDECL(Assign_omp)(Descriptor &to, const Descriptor &from, - const char *sourceFile = nullptr, int sourceLine = 0, int32_t omp_device = 0); // This variant has no finalization, defined assignment, or allocatable // reallocation.
void RTDECL(AssignTemporary)(Descriptor &to, const Descriptor &from, diff --git a/flang/include/flang/Semantics/openmp-directive-sets.h b/flang/include/flang/Semantics/openmp-directive-sets.h index 7ced6ed9b44d6..dd610c9702c28 100644 --- a/flang/include/flang/Semantics/openmp-directive-sets.h +++ b/flang/include/flang/Semantics/openmp-directive-sets.h @@ -143,7 +143,6 @@ static const OmpDirectiveSet topTargetSet{ Directive::OMPD_target_teams_distribute_parallel_do_simd, Directive::OMPD_target_teams_distribute_simd, Directive::OMPD_target_teams_loop, - Directive::OMPD_target_teams_workdistribute, }; static const OmpDirectiveSet allTargetSet{topTargetSet}; @@ -173,7 +172,6 @@ static const OmpDirectiveSet topTeamsSet{ Directive::OMPD_teams_distribute_parallel_do_simd, Directive::OMPD_teams_distribute_simd, Directive::OMPD_teams_loop, - Directive::OMPD_teams_workdistribute, }; static const OmpDirectiveSet bottomTeamsSet{ @@ -189,16 +187,9 @@ static const OmpDirectiveSet allTeamsSet{ Directive::OMPD_target_teams_distribute_parallel_do_simd, Directive::OMPD_target_teams_distribute_simd, Directive::OMPD_target_teams_loop, - Directive::OMPD_target_teams_workdistribute, } | topTeamsSet, }; -static const OmpDirectiveSet allWorkdistributeSet{ - Directive::OMPD_workdistribute, - Directive::OMPD_teams_workdistribute, - Directive::OMPD_target_teams_workdistribute, -}; - //===----------------------------------------------------------------------===// // Directive sets for groups of multiple directives //===----------------------------------------------------------------------===// @@ -239,9 +230,6 @@ static const OmpDirectiveSet blockConstructSet{ Directive::OMPD_taskgroup, Directive::OMPD_teams, Directive::OMPD_workshare, - Directive::OMPD_target_teams_workdistribute, - Directive::OMPD_teams_workdistribute, - Directive::OMPD_workdistribute, }; static const OmpDirectiveSet loopConstructSet{ @@ -306,7 +294,6 @@ static const OmpDirectiveSet workShareSet{ Directive::OMPD_scope, Directive::OMPD_sections, Directive::OMPD_single, - Directive::OMPD_workdistribute, } | allDoSet, }; @@ -389,7 +376,6 @@ static const OmpDirectiveSet nestedReduceWorkshareAllowedSet{ }; static const OmpDirectiveSet nestedTeamsAllowedSet{ - Directive::OMPD_workdistribute, Directive::OMPD_distribute, Directive::OMPD_distribute_parallel_do, Directive::OMPD_distribute_parallel_do_simd, diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 1ae6366b3631a..ebd1d038716e4 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -565,16 +565,6 @@ static void processHostEvalClauses(lower::AbstractConverter &converter, }); break; - case OMPD_teams_workdistribute: - cp.processThreadLimit(stmtCtx, hostInfo.ops); - [[fallthrough]]; - case OMPD_target_teams_workdistribute: - cp.processNumTeams(stmtCtx, hostInfo.ops); - processSingleNestedIf([](Directive nestedDir) { - return topDistributeSet.test(nestedDir) || topLoopSet.test(nestedDir); - }); - break; - case OMPD_teams_distribute: case OMPD_teams_distribute_simd: cp.processThreadLimit(stmtCtx, hostInfo.ops); @@ -2692,17 +2682,6 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, queue, item, clauseOps); } -static mlir::omp::WorkdistributeOp genWorkdistributeOp( - lower::AbstractConverter &converter, lower::SymMap &symTable, - semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, - mlir::Location loc, const ConstructQueue &queue, - ConstructQueue::const_iterator item) { - return genOpWithBody( - 
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval, - llvm::omp::Directive::OMPD_workdistribute), - queue, item); -} - //===----------------------------------------------------------------------===// // Code generation functions for the standalone version of constructs that can // also be a leaf of a composite construct @@ -3323,10 +3302,7 @@ static void genOMPDispatch(lower::AbstractConverter &converter, TODO(loc, "Unhandled loop directive (" + llvm::omp::getOpenMPDirectiveName(dir, version) + ")"); } - case llvm::omp::Directive::OMPD_workdistribute: - newOp = genWorkdistributeOp(converter, symTable, semaCtx, eval, loc, queue, - item); - break; + // case llvm::omp::Directive::OMPD_workdistribute: case llvm::omp::Directive::OMPD_workshare: newOp = genWorkshareOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item); diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index cf4ca5f1436b5..a3de3ae9d116a 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -1168,105 +1168,6 @@ struct FreeMemOpConversion : public fir::FIROpConversion { }; } // namespace -static mlir::LLVM::LLVMFuncOp getOmpTargetAlloc(mlir::Operation *op) { - auto module = op->getParentOfType(); - if (mlir::LLVM::LLVMFuncOp mallocFunc = - module.lookupSymbol("omp_target_alloc")) - return mallocFunc; - mlir::OpBuilder moduleBuilder(module.getBodyRegion()); - auto i64Ty = mlir::IntegerType::get(module->getContext(), 64); - auto i32Ty = mlir::IntegerType::get(module->getContext(), 32); - return moduleBuilder.create( - moduleBuilder.getUnknownLoc(), "omp_target_alloc", - mlir::LLVM::LLVMFunctionType::get( - mlir::LLVM::LLVMPointerType::get(module->getContext()), - {i64Ty, i32Ty}, - /*isVarArg=*/false)); -} - -namespace { -struct OmpTargetAllocMemOpConversion - : public fir::FIROpConversion { - using FIROpConversion::FIROpConversion; - - mlir::LogicalResult - matchAndRewrite(fir::OmpTargetAllocMemOp heap, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - mlir::Type heapTy = heap.getType(); - mlir::LLVM::LLVMFuncOp mallocFunc = getOmpTargetAlloc(heap); - mlir::Location loc = heap.getLoc(); - auto ity = lowerTy().indexType(); - mlir::Type dataTy = fir::unwrapRefType(heapTy); - mlir::Type llvmObjectTy = convertObjectType(dataTy); - if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy))) - TODO(loc, "fir.omp_target_allocmem codegen of derived type with length " - "parameters"); - mlir::Value size = genTypeSizeInBytes(loc, ity, rewriter, llvmObjectTy); - if (auto scaleSize = genAllocationScaleSize(heap, ity, rewriter)) - size = rewriter.create(loc, ity, size, scaleSize); - for (mlir::Value opnd : adaptor.getOperands().drop_front()) - size = rewriter.create( - loc, ity, size, integerCast(loc, rewriter, ity, opnd)); - auto mallocTyWidth = lowerTy().getIndexTypeBitwidth(); - auto mallocTy = - mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth); - if (mallocTyWidth != ity.getIntOrFloatBitWidth()) - size = integerCast(loc, rewriter, mallocTy, size); - heap->setAttr("callee", mlir::SymbolRefAttr::get(mallocFunc)); - rewriter.replaceOpWithNewOp( - heap, ::getLlvmPtrType(heap.getContext()), - mlir::SmallVector({size, heap.getDevice()}), - addLLVMOpBundleAttrs(rewriter, heap->getAttrs(), 2)); - return mlir::success(); - } - - /// Compute the allocation size in bytes of the element type of - /// \p llTy pointer type. The result is returned as a value of \p idxTy - /// integer type. 
- mlir::Value genTypeSizeInBytes(mlir::Location loc, mlir::Type idxTy, - mlir::ConversionPatternRewriter &rewriter, - mlir::Type llTy) const { - return computeElementDistance(loc, llTy, idxTy, rewriter, getDataLayout()); - } -}; -} // namespace - -static mlir::LLVM::LLVMFuncOp getOmpTargetFree(mlir::Operation *op) { - auto module = op->getParentOfType(); - if (mlir::LLVM::LLVMFuncOp freeFunc = - module.lookupSymbol("omp_target_free")) - return freeFunc; - mlir::OpBuilder moduleBuilder(module.getBodyRegion()); - auto i32Ty = mlir::IntegerType::get(module->getContext(), 32); - return moduleBuilder.create( - moduleBuilder.getUnknownLoc(), "omp_target_free", - mlir::LLVM::LLVMFunctionType::get( - mlir::LLVM::LLVMVoidType::get(module->getContext()), - {getLlvmPtrType(module->getContext()), i32Ty}, - /*isVarArg=*/false)); -} - -namespace { -struct OmpTargetFreeMemOpConversion - : public fir::FIROpConversion { - using FIROpConversion::FIROpConversion; - - mlir::LogicalResult - matchAndRewrite(fir::OmpTargetFreeMemOp freemem, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - mlir::LLVM::LLVMFuncOp freeFunc = getOmpTargetFree(freemem); - mlir::Location loc = freemem.getLoc(); - freemem->setAttr("callee", mlir::SymbolRefAttr::get(freeFunc)); - rewriter.create( - loc, mlir::TypeRange{}, - mlir::ValueRange{adaptor.getHeapref(), freemem.getDevice()}, - addLLVMOpBundleAttrs(rewriter, freemem->getAttrs(), 2)); - rewriter.eraseOp(freemem); - return mlir::success(); - } -}; -} // namespace - // Convert subcomponent array indices from column-major to row-major ordering. static llvm::SmallVector convertSubcomponentIndices(mlir::Location loc, mlir::Type eleTy, @@ -4373,8 +4274,7 @@ void fir::populateFIRToLLVMConversionPatterns( GlobalLenOpConversion, GlobalOpConversion, InsertOnRangeOpConversion, IsPresentOpConversion, LenParamIndexOpConversion, LoadOpConversion, LocalitySpecifierOpConversion, MulcOpConversion, NegcOpConversion, - NoReassocOpConversion, OmpTargetAllocMemOpConversion, - OmpTargetFreeMemOpConversion,SelectCaseOpConversion, SelectOpConversion, + NoReassocOpConversion, SelectCaseOpConversion, SelectOpConversion, SelectRankOpConversion, SelectTypeOpConversion, ShapeOpConversion, ShapeShiftOpConversion, ShiftOpConversion, SliceOpConversion, StoreOpConversion, StringLitOpConversion, SubcOpConversion, diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp index ad729932a5f00..c55642d969503 100644 --- a/flang/lib/Parser/openmp-parsers.cpp +++ b/flang/lib/Parser/openmp-parsers.cpp @@ -1492,17 +1492,12 @@ TYPE_PARSER( "SINGLE" >> pure(llvm::omp::Directive::OMPD_single), "TARGET DATA" >> pure(llvm::omp::Directive::OMPD_target_data), "TARGET PARALLEL" >> pure(llvm::omp::Directive::OMPD_target_parallel), - "TARGET TEAMS WORKDISTRIBUTE" >> - pure(llvm::omp::Directive::OMPD_target_teams_workdistribute), "TARGET TEAMS" >> pure(llvm::omp::Directive::OMPD_target_teams), "TARGET" >> pure(llvm::omp::Directive::OMPD_target), "TASK"_id >> pure(llvm::omp::Directive::OMPD_task), "TASKGROUP" >> pure(llvm::omp::Directive::OMPD_taskgroup), - "TEAMS WORKDISTRIBUTE" >> - pure(llvm::omp::Directive::OMPD_teams_workdistribute), "TEAMS" >> pure(llvm::omp::Directive::OMPD_teams), - "WORKSHARE" >> pure(llvm::omp::Directive::OMPD_workshare), - "WORKDISTRIBUTE" >> pure(llvm::omp::Directive::OMPD_workdistribute)))) + "WORKSHARE" >> pure(llvm::omp::Directive::OMPD_workshare)))) TYPE_PARSER(sourced(construct( sourced(Parser{}), Parser{}))) diff --git 
a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp index da3315ff1acfb..885c02e6ec74b 100644 --- a/flang/lib/Semantics/resolve-directives.cpp +++ b/flang/lib/Semantics/resolve-directives.cpp @@ -1656,9 +1656,6 @@ bool OmpAttributeVisitor::Pre(const parser::OpenMPBlockConstruct &x) { case llvm::omp::Directive::OMPD_task: case llvm::omp::Directive::OMPD_taskgroup: case llvm::omp::Directive::OMPD_teams: - case llvm::omp::Directive::OMPD_workdistribute: - case llvm::omp::Directive::OMPD_teams_workdistribute: - case llvm::omp::Directive::OMPD_target_teams_workdistribute: case llvm::omp::Directive::OMPD_workshare: case llvm::omp::Directive::OMPD_parallel_workshare: case llvm::omp::Directive::OMPD_target_teams: @@ -1692,9 +1689,6 @@ void OmpAttributeVisitor::Post(const parser::OpenMPBlockConstruct &x) { case llvm::omp::Directive::OMPD_target: case llvm::omp::Directive::OMPD_task: case llvm::omp::Directive::OMPD_teams: - case llvm::omp::Directive::OMPD_workdistribute: - case llvm::omp::Directive::OMPD_teams_workdistribute: - case llvm::omp::Directive::OMPD_target_teams_workdistribute: case llvm::omp::Directive::OMPD_parallel_workshare: case llvm::omp::Directive::OMPD_target_teams: case llvm::omp::Directive::OMPD_target_parallel: { diff --git a/flang/test/Lower/OpenMP/workdistribute.f90 b/flang/test/Lower/OpenMP/workdistribute.f90 deleted file mode 100644 index 924205bb72e5e..0000000000000 --- a/flang/test/Lower/OpenMP/workdistribute.f90 +++ /dev/null @@ -1,59 +0,0 @@ -! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s - -! CHECK-LABEL: func @_QPtarget_teams_workdistribute -subroutine target_teams_workdistribute() - ! CHECK: omp.target - ! CHECK: omp.teams - ! CHECK: omp.workdistribute - !$omp target teams workdistribute - ! CHECK: fir.call - call f1() - ! CHECK: omp.terminator - ! CHECK: omp.terminator - ! CHECK: omp.terminator - !$omp end target teams workdistribute -end subroutine target_teams_workdistribute - -! CHECK-LABEL: func @_QPteams_workdistribute -subroutine teams_workdistribute() - ! CHECK: omp.teams - ! CHECK: omp.workdistribute - !$omp teams workdistribute - ! CHECK: fir.call - call f1() - ! CHECK: omp.terminator - ! CHECK: omp.terminator - !$omp end teams workdistribute -end subroutine teams_workdistribute - -! CHECK-LABEL: func @_QPtarget_teams_workdistribute_m -subroutine target_teams_workdistribute_m() - ! CHECK: omp.target - ! CHECK: omp.teams - ! CHECK: omp.workdistribute - !$omp target - !$omp teams - !$omp workdistribute - ! CHECK: fir.call - call f1() - ! CHECK: omp.terminator - ! CHECK: omp.terminator - ! CHECK: omp.terminator - !$omp end workdistribute - !$omp end teams - !$omp end target -end subroutine target_teams_workdistribute_m - -! CHECK-LABEL: func @_QPteams_workdistribute_m -subroutine teams_workdistribute_m() - ! CHECK: omp.teams - ! CHECK: omp.workdistribute - !$omp teams - !$omp workdistribute - ! CHECK: fir.call - call f1() - ! CHECK: omp.terminator - ! 
CHECK: omp.terminator - !$omp end workdistribute - !$omp end teams -end subroutine teams_workdistribute_m diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td index 73fba71120af1..c13215a1beb4d 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMP.td +++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td @@ -1286,15 +1286,6 @@ def OMP_EndWorkshare : Directive<[Spelling<"end workshare">]> { let category = OMP_Workshare.category; let languages = [L_Fortran]; } -def OMP_Workdistribute : Directive<[Spelling<"workdistribute">]> { - let association = AS_Block; - let category = CA_Executable; -} -def OMP_EndWorkdistribute : Directive<[Spelling<"end workdistribute">]> { - let leafConstructs = OMP_Workdistribute.leafConstructs; - let association = OMP_Workdistribute.association; - let category = OMP_Workdistribute.category; -} //===----------------------------------------------------------------------===// // Definitions of OpenMP compound directives @@ -2437,34 +2428,6 @@ def OMP_TargetTeamsDistributeSimd let leafConstructs = [OMP_Target, OMP_Teams, OMP_Distribute, OMP_Simd]; let category = CA_Executable; } -def OMP_TargetTeamsWorkdistribute : Directive<[Spelling<"target teams workdistribute">]> { - let allowedClauses = [ - VersionedClause, - VersionedClause, - VersionedClause, - VersionedClause, - VersionedClause, - VersionedClause, - VersionedClause, - VersionedClause, - VersionedClause, - VersionedClause, - VersionedClause, - VersionedClause, - ]; - let allowedOnceClauses = [ - VersionedClause, - VersionedClause, - VersionedClause, - VersionedClause, - VersionedClause, - VersionedClause, - VersionedClause, - VersionedClause, - ]; - let leafConstructs = [OMP_Target, OMP_Teams, OMP_Workdistribute]; - let category = CA_Executable; -} def OMP_target_teams_loop : Directive<[Spelling<"target teams loop">]> { let allowedClauses = [ VersionedClause, @@ -2716,21 +2679,3 @@ def OMP_teams_loop : Directive<[Spelling<"teams loop">]> { let leafConstructs = [OMP_Teams, OMP_loop]; let category = CA_Executable; } -def OMP_TeamsWorkdistribute : Directive<[Spelling<"teams workdistribute">]> { - let allowedClauses = [ - VersionedClause, - VersionedClause, - VersionedClause, - VersionedClause, - VersionedClause, - VersionedClause, - ]; - let allowedOnceClauses = [ - VersionedClause, - VersionedClause, - VersionedClause, - VersionedClause, - ]; - let leafConstructs = [OMP_Teams, OMP_Workdistribute]; - let category = CA_Executable; -} diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index 8d65b37330eb8..ac80926053a2d 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -325,33 +325,6 @@ def SectionsOp : OpenMP_Op<"sections", traits = [ let hasRegionVerifier = 1; } -//===----------------------------------------------------------------------===// -// workdistribute Construct -//===----------------------------------------------------------------------===// - -def WorkdistributeOp : OpenMP_Op<"workdistribute"> { - let summary = "workdistribute directive"; - let description = [{ - workdistribute divides execution of the enclosed structured block into - separate units of work, each executed only once by each - initial thread in the league. - - ``` - !$omp target teams - !$omp workdistribute - tmp = matmul(x, y) - !$omp end workdistribute - a = tmp(0, 0) ! there is no implicit barrier! the matmul hasnt completed! 
- !$omp end target teams workdistribute - ``` - - }]; - - let regions = (region AnyRegion:$region); - - let assemblyFormat = "$region attr-dict"; -} - //===----------------------------------------------------------------------===// // 2.8.2 Single Construct //===----------------------------------------------------------------------===// From cd3ed43a9a2755ed8cda96f69641bb3293c40749 Mon Sep 17 00:00:00 2001 From: skc7 Date: Fri, 27 Jun 2025 12:47:41 +0530 Subject: [PATCH 28/29] Remove older changes --- llvm/include/llvm/Frontend/OpenMP/OMP.td | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td index c13215a1beb4d..a87111cb5a11d 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMP.td +++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td @@ -2206,7 +2206,8 @@ def OMP_TargetTeams : Directive<[Spelling<"target teams">]> { let leafConstructs = [OMP_Target, OMP_Teams]; let category = CA_Executable; } -def OMP_TargetTeamsDistribute : Directive<[Spelling<"target teams distribute">]> { +def OMP_TargetTeamsDistribute + : Directive<[Spelling<"target teams distribute">]> { let allowedClauses = [ VersionedClause, VersionedClause, From b6ca26a145c6e445599bdcfcd0de10f2e122c12a Mon Sep 17 00:00:00 2001 From: skc7 Date: Fri, 4 Jul 2025 14:33:46 +0530 Subject: [PATCH 29/29] [Flang] Remove rewrite patterns in lower-workdistribute. --- .../Optimizer/OpenMP/LowerWorkdistribute.cpp | 302 ++++++++---------- .../OpenMP/lower-workdistribute-doloop.mlir | 2 +- 2 files changed, 141 insertions(+), 163 deletions(-) diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp index 1fe2592d1a357..3f4116d524452 100644 --- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp +++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp @@ -124,99 +124,96 @@ static T getPerfectlyNested(Operation *op) { /// } /// E() -struct FissionWorkdistribute : public OpRewritePattern { - static bool fissionWorkdistributePatternMatched; - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(omp::WorkdistributeOp workdistribute, - PatternRewriter &rewriter) const override { - auto loc = workdistribute->getLoc(); - auto teams = dyn_cast(workdistribute->getParentOp()); - if (!teams) { - emitError(loc, "workdistribute not nested in teams\n"); - return failure(); - } - if (workdistribute.getRegion().getBlocks().size() != 1) { - emitError(loc, "workdistribute with multiple blocks\n"); - return failure(); +static bool FissionWorkdistribute(omp::WorkdistributeOp workdistribute) { + OpBuilder rewriter(workdistribute); + auto loc = workdistribute->getLoc(); + auto teams = dyn_cast(workdistribute->getParentOp()); + if (!teams) { + emitError(loc, "workdistribute not nested in teams\n"); + return false; + } + if (workdistribute.getRegion().getBlocks().size() != 1) { + emitError(loc, "workdistribute with multiple blocks\n"); + return false; + } + if (teams.getRegion().getBlocks().size() != 1) { + emitError(loc, "teams with multiple blocks\n"); + return false; + } + + auto *teamsBlock = &teams.getRegion().front(); + bool changed = false; + // Move the ops inside teams and before workdistribute outside. 
+ IRMapping irMapping; + llvm::SmallVector teamsHoisted; + for (auto &op : teams.getOps()) { + if (&op == workdistribute) { + break; } - if (teams.getRegion().getBlocks().size() != 1) { - emitError(loc, "teams with multiple blocks\n"); - return failure(); + if (shouldParallelize(&op)) { + emitError(loc, "teams has parallelize ops before first workdistribute\n"); + return false; + } else { + rewriter.setInsertionPoint(teams); + rewriter.clone(op, irMapping); + teamsHoisted.push_back(&op); + changed = true; } - - auto *teamsBlock = &teams.getRegion().front(); - bool changed = false; - // Move the ops inside teams and before workdistribute outside. - IRMapping irMapping; - llvm::SmallVector teamsHoisted; - for (auto &op : teams.getOps()) { - if (&op == workdistribute) { + } + for (auto *op : llvm::reverse(teamsHoisted)) { + op->replaceAllUsesWith(irMapping.lookup(op)); + op->erase(); + } + + // While we have unhandled operations in the original workdistribute + auto *workdistributeBlock = &workdistribute.getRegion().front(); + auto *terminator = workdistributeBlock->getTerminator(); + while (&workdistributeBlock->front() != terminator) { + rewriter.setInsertionPoint(teams); + IRMapping mapping; + llvm::SmallVector hoisted; + Operation *parallelize = nullptr; + for (auto &op : workdistribute.getOps()) { + if (&op == terminator) { break; } if (shouldParallelize(&op)) { - emitError(loc, - "teams has parallelize ops before first workdistribute\n"); - return failure(); + parallelize = &op; + break; } else { - rewriter.setInsertionPoint(teams); - rewriter.clone(op, irMapping); - teamsHoisted.push_back(&op); + rewriter.clone(op, mapping); + hoisted.push_back(&op); changed = true; } } - for (auto *op : teamsHoisted) - rewriter.replaceOp(op, irMapping.lookup(op)); - - // While we have unhandled operations in the original workdistribute - auto *workdistributeBlock = &workdistribute.getRegion().front(); - auto *terminator = workdistributeBlock->getTerminator(); - while (&workdistributeBlock->front() != terminator) { - rewriter.setInsertionPoint(teams); - IRMapping mapping; - llvm::SmallVector hoisted; - Operation *parallelize = nullptr; - for (auto &op : workdistribute.getOps()) { - if (&op == terminator) { - break; - } - if (shouldParallelize(&op)) { - parallelize = &op; - break; - } else { - rewriter.clone(op, mapping); - hoisted.push_back(&op); - changed = true; - } - } - for (auto *op : hoisted) - rewriter.replaceOp(op, mapping.lookup(op)); + for (auto *op : llvm::reverse(hoisted)) { + op->replaceAllUsesWith(mapping.lookup(op)); + op->erase(); + } - if (parallelize && hoisted.empty() && - parallelize->getNextNode() == terminator) - break; - if (parallelize) { - auto newTeams = rewriter.cloneWithoutRegions(teams); - auto *newTeamsBlock = rewriter.createBlock( - &newTeams.getRegion(), newTeams.getRegion().begin(), {}, {}); - for (auto arg : teamsBlock->getArguments()) - newTeamsBlock->addArgument(arg.getType(), arg.getLoc()); - auto newWorkdistribute = rewriter.create(loc); - rewriter.create(loc); - rewriter.createBlock(&newWorkdistribute.getRegion(), - newWorkdistribute.getRegion().begin(), {}, {}); - auto *cloned = rewriter.clone(*parallelize); - rewriter.replaceOp(parallelize, cloned); - rewriter.create(loc); - changed = true; - } + if (parallelize && hoisted.empty() && + parallelize->getNextNode() == terminator) + break; + if (parallelize) { + auto newTeams = rewriter.cloneWithoutRegions(teams); + auto *newTeamsBlock = rewriter.createBlock( + &newTeams.getRegion(), newTeams.getRegion().begin(), 
{}, {}); + for (auto arg : teamsBlock->getArguments()) + newTeamsBlock->addArgument(arg.getType(), arg.getLoc()); + auto newWorkdistribute = rewriter.create(loc); + rewriter.create(loc); + rewriter.createBlock(&newWorkdistribute.getRegion(), + newWorkdistribute.getRegion().begin(), {}, {}); + auto *cloned = rewriter.clone(*parallelize); + parallelize->replaceAllUsesWith(cloned); + parallelize->erase(); + rewriter.create(loc); + changed = true; } - if (changed) - fissionWorkdistributePatternMatched = true; - return success(changed); } -}; -bool FissionWorkdistribute::fissionWorkdistributePatternMatched = false; + return changed; +} /// If fir.do_loop is present inside teams workdistribute /// @@ -241,8 +238,7 @@ bool FissionWorkdistribute::fissionWorkdistributePatternMatched = false; /// } /// } -static void genParallelOp(Location loc, PatternRewriter &rewriter, - bool composite) { +static void genParallelOp(Location loc, OpBuilder &rewriter, bool composite) { auto parallelOp = rewriter.create(loc); parallelOp.setComposite(composite); rewriter.createBlock(¶llelOp.getRegion()); @@ -250,8 +246,7 @@ static void genParallelOp(Location loc, PatternRewriter &rewriter, return; } -static void genDistributeOp(Location loc, PatternRewriter &rewriter, - bool composite) { +static void genDistributeOp(Location loc, OpBuilder &rewriter, bool composite) { mlir::omp::DistributeOperands distributeClauseOps; auto distributeOp = rewriter.create(loc, distributeClauseOps); @@ -262,7 +257,7 @@ static void genDistributeOp(Location loc, PatternRewriter &rewriter, } static void -genLoopNestClauseOps(mlir::PatternRewriter &rewriter, fir::DoLoopOp loop, +genLoopNestClauseOps(OpBuilder &rewriter, fir::DoLoopOp loop, mlir::omp::LoopNestOperands &loopNestClauseOps) { assert(loopNestClauseOps.loopLowerBounds.empty() && "Loop nest bounds were already emitted!"); @@ -272,7 +267,7 @@ genLoopNestClauseOps(mlir::PatternRewriter &rewriter, fir::DoLoopOp loop, loopNestClauseOps.loopInclusive = rewriter.getUnitAttr(); } -static void genWsLoopOp(mlir::PatternRewriter &rewriter, fir::DoLoopOp doLoop, +static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop, const mlir::omp::LoopNestOperands &clauseOps, bool composite) { @@ -294,34 +289,28 @@ static void genWsLoopOp(mlir::PatternRewriter &rewriter, fir::DoLoopOp doLoop, if (auto resultOp = dyn_cast(terminatorOp)) { rewriter.setInsertionPoint(terminatorOp); rewriter.create(doLoop->getLoc()); - rewriter.eraseOp(terminatorOp); + // rewriter.erase(terminatorOp); + terminatorOp->erase(); } return; } -struct WorkdistributeDoLower : public OpRewritePattern { - static bool workdistributeDoLowerPatternMatched; - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(omp::WorkdistributeOp workdistribute, - PatternRewriter &rewriter) const override { - auto doLoop = getPerfectlyNested(workdistribute); - auto wdLoc = workdistribute->getLoc(); - if (doLoop && shouldParallelize(doLoop)) { - assert(doLoop.getReduceOperands().empty()); - genParallelOp(wdLoc, rewriter, true); - genDistributeOp(wdLoc, rewriter, true); - mlir::omp::LoopNestOperands loopNestClauseOps; - genLoopNestClauseOps(rewriter, doLoop, loopNestClauseOps); - genWsLoopOp(rewriter, doLoop, loopNestClauseOps, true); - rewriter.eraseOp(workdistribute); - workdistributeDoLowerPatternMatched = true; - return success(); - } - return failure(); +static bool WorkdistributeDoLower(omp::WorkdistributeOp workdistribute) { + OpBuilder rewriter(workdistribute); + auto doLoop = 
getPerfectlyNested(workdistribute); + auto wdLoc = workdistribute->getLoc(); + if (doLoop && shouldParallelize(doLoop)) { + assert(doLoop.getReduceOperands().empty()); + genParallelOp(wdLoc, rewriter, true); + genDistributeOp(wdLoc, rewriter, true); + mlir::omp::LoopNestOperands loopNestClauseOps; + genLoopNestClauseOps(rewriter, doLoop, loopNestClauseOps); + genWsLoopOp(rewriter, doLoop, loopNestClauseOps, true); + workdistribute.erase(); + return true; } -}; - -bool WorkdistributeDoLower::workdistributeDoLowerPatternMatched = false; + return false; +} /// If A() and B () are present inside teams workdistribute /// @@ -338,34 +327,39 @@ bool WorkdistributeDoLower::workdistributeDoLowerPatternMatched = false; /// B() /// -struct TeamsWorkdistributeToSingle : public OpRewritePattern { - static bool teamsWorkdistributeToSinglePatternMatched; - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(omp::TeamsOp teamsOp, - PatternRewriter &rewriter) const override { - auto workdistributeOp = getPerfectlyNested(teamsOp); - if (!workdistributeOp) { - LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " No workdistribute nested\n"); - return failure(); - } - Block *workdistributeBlock = &workdistributeOp.getRegion().front(); - rewriter.eraseOp(workdistributeBlock->getTerminator()); - rewriter.inlineBlockBefore(workdistributeBlock, teamsOp); - rewriter.eraseOp(workdistributeOp); - teamsWorkdistributeToSinglePatternMatched = true; - return success(); - } -}; -bool TeamsWorkdistributeToSingle::teamsWorkdistributeToSinglePatternMatched = - false; +static bool TeamsWorkdistributeToSingleOp(omp::TeamsOp teamsOp) { + auto workdistributeOp = getPerfectlyNested(teamsOp); + if (!workdistributeOp) + return false; + // Get the block containing teamsOp (the parent block). + Block *parentBlock = teamsOp->getBlock(); + Block &workdistributeBlock = *workdistributeOp.getRegion().begin(); + auto insertPoint = Block::iterator(teamsOp); + // Get the range of operations to move (excluding the terminator). + auto workdistributeBegin = workdistributeBlock.begin(); + auto workdistributeEnd = workdistributeBlock.getTerminator()->getIterator(); + // Move the operations from workdistribute block to before teamsOp. + parentBlock->getOperations().splice(insertPoint, + workdistributeBlock.getOperations(), + workdistributeBegin, workdistributeEnd); + // Erase the now-empty workdistributeOp. + workdistributeOp.erase(); + Block &teamsBlock = *teamsOp.getRegion().begin(); + // Check if only the terminator remains and erase teams op. 
+ if (teamsBlock.getOperations().size() == 1 && + teamsBlock.getTerminator() != nullptr) { + teamsOp.erase(); + } + return true; +} struct SplitTargetResult { omp::TargetOp targetOp; omp::TargetDataOp dataOp; }; -/// If multiple coexecutes are nested in a target regions, we will need to split -/// the target region, but we want to preserve the data semantics of the +/// If multiple workdistribute are nested in a target regions, we will need to +/// split the target region, but we want to preserve the data semantics of the /// original data region and avoid unnecessary data movement at each of the /// subkernels - we split the target region into a target_data{target} /// nest where only the outer one moves the data @@ -877,38 +871,22 @@ class LowerWorkdistributePass public: void runOnOperation() override { MLIRContext &context = getContext(); - GreedyRewriteConfig config; - // prevent the pattern driver form merging blocks - config.setRegionSimplificationLevel(GreedySimplifyRegionLevel::Disabled); - - Operation *op = getOperation(); - bool anyPatternChanged = false; - { - RewritePatternSet patterns(&context); - patterns.insert(&context); - if (failed(applyPatternsGreedily(op, std::move(patterns), config))) { - emitError(op->getLoc(), DEBUG_TYPE " pass failed\n"); - signalPassFailure(); - } - anyPatternChanged |= - FissionWorkdistribute::fissionWorkdistributePatternMatched; - anyPatternChanged |= - WorkdistributeDoLower::workdistributeDoLowerPatternMatched; - } - { - RewritePatternSet patterns(&context); - patterns.insert( - &context); - if (failed(applyPatternsGreedily(op, std::move(patterns), config))) { - emitError(op->getLoc(), DEBUG_TYPE " pass failed\n"); - signalPassFailure(); - } - anyPatternChanged |= TeamsWorkdistributeToSingle:: - teamsWorkdistributeToSinglePatternMatched; - } - if (anyPatternChanged) { + auto moduleOp = getOperation(); + bool changed = false; + moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) { + changed |= FissionWorkdistribute(workdistribute); + }); + moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) { + changed |= WorkdistributeDoLower(workdistribute); + }); + moduleOp->walk([&](mlir::omp::TeamsOp teams) { + changed |= TeamsWorkdistributeToSingleOp(teams); + }); + + if (changed) { SmallVector targetOps; - op->walk([&](omp::TargetOp targetOp) { targetOps.push_back(targetOp); }); + moduleOp->walk( + [&](omp::TargetOp targetOp) { targetOps.push_back(targetOp); }); IRRewriter rewriter(&context); for (auto targetOp : targetOps) { auto res = splitTargetData(targetOp, rewriter); diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir index f8351bb64e6e8..00d10d6264ec9 100644 --- a/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir +++ b/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir @@ -1,12 +1,12 @@ // RUN: fir-opt --lower-workdistribute %s | FileCheck %s // CHECK-LABEL: func.func @x({{.*}}) -// CHECK: %[[VAL_0:.*]] = arith.constant 0 : index // CHECK: omp.teams { // CHECK: omp.parallel { // CHECK: omp.distribute { // CHECK: omp.wsloop { // CHECK: omp.loop_nest (%[[VAL_1:.*]]) : index = (%[[ARG0:.*]]) to (%[[ARG1:.*]]) inclusive step (%[[ARG2:.*]]) { +// CHECK: %[[VAL_0:.*]] = arith.constant 0 : index // CHECK: fir.store %[[VAL_0]] to %[[ARG4:.*]] : !fir.ref // CHECK: omp.yield // CHECK: }
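Taken as a whole, with PATCH 27 moving the frontend and runtime pieces into separate PRs, this series is centered on the lower-workdistribute pass. A minimal usage sketch of the construct it targets (hypothetical subroutine, not one of the tests above); the array assignment is the kind of implicit loop that WorkdistributeDoLower is expected to turn into distribute/wsloop nests inside the target region:

```
subroutine axpy(a, x, y, n)
  integer :: n
  real :: a, x(n), y(n)
  !$omp target teams workdistribute
  y = a * x + y
  !$omp end target teams workdistribute
end subroutine axpy
```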