Skip to content

[flang] Add support for workdistribute construct in flang frontend #146029

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

skc7
Copy link
Contributor

@skc7 skc7 commented Jun 27, 2025

This PR introduces wordistribute construct support in flang frontend.
Also adds a workdistribute mlir op.

The work in this PR is C-P and updated from @ivanradanov commit from coexecute implementation:
flang_workdistribute_iwomp_2024

@skc7 skc7 requested a review from mjklemm June 27, 2025 11:09
@skc7 skc7 marked this pull request as ready for review July 4, 2025 09:09
@llvmbot llvmbot added mlir flang Flang issues not falling into any other category mlir:openmp flang:fir-hlfir flang:openmp flang:semantics flang:parser clang:openmp OpenMP related changes to Clang labels Jul 4, 2025
@llvmbot
Copy link
Member

llvmbot commented Jul 4, 2025

@llvm/pr-subscribers-flang-fir-hlfir
@llvm/pr-subscribers-flang-parser

@llvm/pr-subscribers-flang-openmp

Author: Chaitanya (skc7)

Changes

This PR introduces wordistribute construct support in flang frontend.
Also adds a workdistribute mlir op.

The work in this PR is C-P and updated from @ivanradanov commit from coexecute implementation:
flang_workdistribute_iwomp_2024


Full diff: https://github.com/llvm/llvm-project/pull/146029.diff

10 Files Affected:

  • (modified) flang/include/flang/Semantics/openmp-directive-sets.h (+14)
  • (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+25-1)
  • (modified) flang/lib/Parser/openmp-parsers.cpp (+6-1)
  • (modified) flang/lib/Semantics/resolve-directives.cpp (+7-1)
  • (added) flang/test/Lower/OpenMP/workdistribute.f90 (+59)
  • (modified) llvm/include/llvm/Frontend/OpenMP/OMP.td (+55)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+23)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+15)
  • (modified) mlir/test/Dialect/OpenMP/invalid.mlir (+21)
  • (modified) mlir/test/Dialect/OpenMP/ops.mlir (+13)
diff --git a/flang/include/flang/Semantics/openmp-directive-sets.h b/flang/include/flang/Semantics/openmp-directive-sets.h
index dd610c9702c28..7ced6ed9b44d6 100644
--- a/flang/include/flang/Semantics/openmp-directive-sets.h
+++ b/flang/include/flang/Semantics/openmp-directive-sets.h
@@ -143,6 +143,7 @@ static const OmpDirectiveSet topTargetSet{
     Directive::OMPD_target_teams_distribute_parallel_do_simd,
     Directive::OMPD_target_teams_distribute_simd,
     Directive::OMPD_target_teams_loop,
+    Directive::OMPD_target_teams_workdistribute,
 };
 
 static const OmpDirectiveSet allTargetSet{topTargetSet};
@@ -172,6 +173,7 @@ static const OmpDirectiveSet topTeamsSet{
     Directive::OMPD_teams_distribute_parallel_do_simd,
     Directive::OMPD_teams_distribute_simd,
     Directive::OMPD_teams_loop,
+    Directive::OMPD_teams_workdistribute,
 };
 
 static const OmpDirectiveSet bottomTeamsSet{
@@ -187,9 +189,16 @@ static const OmpDirectiveSet allTeamsSet{
         Directive::OMPD_target_teams_distribute_parallel_do_simd,
         Directive::OMPD_target_teams_distribute_simd,
         Directive::OMPD_target_teams_loop,
+        Directive::OMPD_target_teams_workdistribute,
     } | topTeamsSet,
 };
 
+static const OmpDirectiveSet allWorkdistributeSet{
+    Directive::OMPD_workdistribute,
+    Directive::OMPD_teams_workdistribute,
+    Directive::OMPD_target_teams_workdistribute,
+};
+
 //===----------------------------------------------------------------------===//
 // Directive sets for groups of multiple directives
 //===----------------------------------------------------------------------===//
@@ -230,6 +239,9 @@ static const OmpDirectiveSet blockConstructSet{
     Directive::OMPD_taskgroup,
     Directive::OMPD_teams,
     Directive::OMPD_workshare,
+    Directive::OMPD_target_teams_workdistribute,
+    Directive::OMPD_teams_workdistribute,
+    Directive::OMPD_workdistribute,
 };
 
 static const OmpDirectiveSet loopConstructSet{
@@ -294,6 +306,7 @@ static const OmpDirectiveSet workShareSet{
         Directive::OMPD_scope,
         Directive::OMPD_sections,
         Directive::OMPD_single,
+        Directive::OMPD_workdistribute,
     } | allDoSet,
 };
 
@@ -376,6 +389,7 @@ static const OmpDirectiveSet nestedReduceWorkshareAllowedSet{
 };
 
 static const OmpDirectiveSet nestedTeamsAllowedSet{
+    Directive::OMPD_workdistribute,
     Directive::OMPD_distribute,
     Directive::OMPD_distribute_parallel_do,
     Directive::OMPD_distribute_parallel_do_simd,
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index ebd1d038716e4..16d58b6be535f 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -585,6 +585,16 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
       cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv);
       break;
 
+    case OMPD_teams_workdistribute:
+      cp.processThreadLimit(stmtCtx, hostInfo.ops);
+      [[fallthrough]];
+    case OMPD_target_teams_workdistribute:
+      cp.processNumTeams(stmtCtx, hostInfo.ops);
+      processSingleNestedIf([](Directive nestedDir) {
+        return topDistributeSet.test(nestedDir) || topLoopSet.test(nestedDir);
+      });
+      break;
+
     // Standalone 'target' case.
     case OMPD_target: {
       processSingleNestedIf(
@@ -2682,6 +2692,17 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
       queue, item, clauseOps);
 }
 
+static mlir::omp::WorkdistributeOp genWorkdistributeOp(
+    lower::AbstractConverter &converter, lower::SymMap &symTable,
+    semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
+    mlir::Location loc, const ConstructQueue &queue,
+    ConstructQueue::const_iterator item) {
+  return genOpWithBody<mlir::omp::WorkdistributeOp>(
+      OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
+                        llvm::omp::Directive::OMPD_workdistribute),
+      queue, item);
+}
+
 //===----------------------------------------------------------------------===//
 // Code generation functions for the standalone version of constructs that can
 // also be a leaf of a composite construct
@@ -3302,7 +3323,10 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
     TODO(loc, "Unhandled loop directive (" +
                   llvm::omp::getOpenMPDirectiveName(dir, version) + ")");
   }
-  // case llvm::omp::Directive::OMPD_workdistribute:
+  case llvm::omp::Directive::OMPD_workdistribute:
+    newOp = genWorkdistributeOp(converter, symTable, semaCtx, eval, loc, queue,
+                                item);
+    break;
   case llvm::omp::Directive::OMPD_workshare:
     newOp = genWorkshareOp(converter, symTable, stmtCtx, semaCtx, eval, loc,
                            queue, item);
diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp
index c55642d969503..ad729932a5f00 100644
--- a/flang/lib/Parser/openmp-parsers.cpp
+++ b/flang/lib/Parser/openmp-parsers.cpp
@@ -1492,12 +1492,17 @@ TYPE_PARSER(
         "SINGLE" >> pure(llvm::omp::Directive::OMPD_single),
         "TARGET DATA" >> pure(llvm::omp::Directive::OMPD_target_data),
         "TARGET PARALLEL" >> pure(llvm::omp::Directive::OMPD_target_parallel),
+        "TARGET TEAMS WORKDISTRIBUTE" >>
+            pure(llvm::omp::Directive::OMPD_target_teams_workdistribute),
         "TARGET TEAMS" >> pure(llvm::omp::Directive::OMPD_target_teams),
         "TARGET" >> pure(llvm::omp::Directive::OMPD_target),
         "TASK"_id >> pure(llvm::omp::Directive::OMPD_task),
         "TASKGROUP" >> pure(llvm::omp::Directive::OMPD_taskgroup),
+        "TEAMS WORKDISTRIBUTE" >>
+            pure(llvm::omp::Directive::OMPD_teams_workdistribute),
         "TEAMS" >> pure(llvm::omp::Directive::OMPD_teams),
-        "WORKSHARE" >> pure(llvm::omp::Directive::OMPD_workshare))))
+        "WORKSHARE" >> pure(llvm::omp::Directive::OMPD_workshare),
+        "WORKDISTRIBUTE" >> pure(llvm::omp::Directive::OMPD_workdistribute))))
 
 TYPE_PARSER(sourced(construct<OmpBeginBlockDirective>(
     sourced(Parser<OmpBlockDirective>{}), Parser<OmpClauseList>{})))
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index 885c02e6ec74b..2e4e05f9e293b 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -1656,10 +1656,13 @@ bool OmpAttributeVisitor::Pre(const parser::OpenMPBlockConstruct &x) {
   case llvm::omp::Directive::OMPD_task:
   case llvm::omp::Directive::OMPD_taskgroup:
   case llvm::omp::Directive::OMPD_teams:
+  case llvm::omp::Directive::OMPD_workdistribute:
   case llvm::omp::Directive::OMPD_workshare:
   case llvm::omp::Directive::OMPD_parallel_workshare:
   case llvm::omp::Directive::OMPD_target_teams:
+  case llvm::omp::Directive::OMPD_target_teams_workdistribute:
   case llvm::omp::Directive::OMPD_target_parallel:
+  case llvm::omp::Directive::OMPD_teams_workdistribute:
     PushContext(beginDir.source, beginDir.v);
     break;
   default:
@@ -1689,9 +1692,12 @@ void OmpAttributeVisitor::Post(const parser::OpenMPBlockConstruct &x) {
   case llvm::omp::Directive::OMPD_target:
   case llvm::omp::Directive::OMPD_task:
   case llvm::omp::Directive::OMPD_teams:
+  case llvm::omp::Directive::OMPD_workdistribute:
   case llvm::omp::Directive::OMPD_parallel_workshare:
   case llvm::omp::Directive::OMPD_target_teams:
-  case llvm::omp::Directive::OMPD_target_parallel: {
+  case llvm::omp::Directive::OMPD_target_parallel:
+  case llvm::omp::Directive::OMPD_target_teams_workdistribute:
+  case llvm::omp::Directive::OMPD_teams_workdistribute: {
     bool hasPrivate;
     for (const auto *allocName : allocateNames_) {
       hasPrivate = false;
diff --git a/flang/test/Lower/OpenMP/workdistribute.f90 b/flang/test/Lower/OpenMP/workdistribute.f90
new file mode 100644
index 0000000000000..924205bb72e5e
--- /dev/null
+++ b/flang/test/Lower/OpenMP/workdistribute.f90
@@ -0,0 +1,59 @@
+! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
+
+! CHECK-LABEL: func @_QPtarget_teams_workdistribute
+subroutine target_teams_workdistribute()
+  ! CHECK: omp.target
+  ! CHECK: omp.teams
+  ! CHECK: omp.workdistribute
+  !$omp target teams workdistribute
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  !$omp end target teams workdistribute
+end subroutine target_teams_workdistribute
+
+! CHECK-LABEL: func @_QPteams_workdistribute
+subroutine teams_workdistribute()
+  ! CHECK: omp.teams
+  ! CHECK: omp.workdistribute
+  !$omp teams workdistribute
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  !$omp end teams workdistribute
+end subroutine teams_workdistribute
+
+! CHECK-LABEL: func @_QPtarget_teams_workdistribute_m
+subroutine target_teams_workdistribute_m()
+  ! CHECK: omp.target
+  ! CHECK: omp.teams
+  ! CHECK: omp.workdistribute
+  !$omp target
+  !$omp teams
+  !$omp workdistribute
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  !$omp end workdistribute
+  !$omp end teams
+  !$omp end target
+end subroutine target_teams_workdistribute_m
+
+! CHECK-LABEL: func @_QPteams_workdistribute_m
+subroutine teams_workdistribute_m()
+  ! CHECK: omp.teams
+  ! CHECK: omp.workdistribute
+  !$omp teams
+  !$omp workdistribute
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  !$omp end workdistribute
+  !$omp end teams
+end subroutine teams_workdistribute_m
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td
index a87111cb5a11d..d1831db37fc46 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -1286,6 +1286,15 @@ def OMP_EndWorkshare : Directive<[Spelling<"end workshare">]> {
   let category = OMP_Workshare.category;
   let languages = [L_Fortran];
 }
+def OMP_Workdistribute : Directive<[Spelling<"workdistribute">]> {
+  let association = AS_Block;
+  let category = CA_Executable;
+}
+def OMP_EndWorkdistribute : Directive<[Spelling<"end workdistribute">]> {
+  let leafConstructs = OMP_Workdistribute.leafConstructs;
+  let association = OMP_Workdistribute.association;
+  let category = OMP_Workdistribute.category;
+}
 
 //===----------------------------------------------------------------------===//
 // Definitions of OpenMP compound directives
@@ -2429,6 +2438,34 @@ def OMP_TargetTeamsDistributeSimd
   let leafConstructs = [OMP_Target, OMP_Teams, OMP_Distribute, OMP_Simd];
   let category = CA_Executable;
 }
+def OMP_TargetTeamsWorkdistribute : Directive<[Spelling<"target teams workdistribute">]> {
+  let allowedClauses = [
+    VersionedClause<OMPC_Allocate>,
+    VersionedClause<OMPC_Depend>,
+    VersionedClause<OMPC_FirstPrivate>,
+    VersionedClause<OMPC_HasDeviceAddr, 51>,
+    VersionedClause<OMPC_If>,
+    VersionedClause<OMPC_IsDevicePtr>,
+    VersionedClause<OMPC_Map>,
+    VersionedClause<OMPC_OMPX_Attribute>,
+    VersionedClause<OMPC_Private>,
+    VersionedClause<OMPC_Reduction>,
+    VersionedClause<OMPC_Shared>,
+    VersionedClause<OMPC_UsesAllocators, 50>,
+  ];
+  let allowedOnceClauses = [
+    VersionedClause<OMPC_Default>,
+    VersionedClause<OMPC_DefaultMap>,
+    VersionedClause<OMPC_Device>,
+    VersionedClause<OMPC_NoWait>,
+    VersionedClause<OMPC_NumTeams>,
+    VersionedClause<OMPC_OMPX_DynCGroupMem>,
+    VersionedClause<OMPC_OMPX_Bare>,
+    VersionedClause<OMPC_ThreadLimit>,
+  ];
+  let leafConstructs = [OMP_Target, OMP_Teams, OMP_Workdistribute];
+  let category = CA_Executable;
+}
 def OMP_target_teams_loop : Directive<[Spelling<"target teams loop">]> {
   let allowedClauses = [
     VersionedClause<OMPC_Allocate>,
@@ -2659,6 +2696,24 @@ def OMP_TeamsDistributeSimd : Directive<[Spelling<"teams distribute simd">]> {
   let leafConstructs = [OMP_Teams, OMP_Distribute, OMP_Simd];
   let category = CA_Executable;
 }
+def OMP_TeamsWorkdistribute : Directive<[Spelling<"teams workdistribute">]> {
+  let allowedClauses = [
+    VersionedClause<OMPC_Allocate>,
+    VersionedClause<OMPC_FirstPrivate>,
+    VersionedClause<OMPC_OMPX_Attribute>,
+    VersionedClause<OMPC_Private>,
+    VersionedClause<OMPC_Reduction>,
+    VersionedClause<OMPC_Shared>,
+  ];
+  let allowedOnceClauses = [
+    VersionedClause<OMPC_Default>,
+    VersionedClause<OMPC_If, 52>,
+    VersionedClause<OMPC_NumTeams>,
+    VersionedClause<OMPC_ThreadLimit>,
+  ];
+  let leafConstructs = [OMP_Teams, OMP_Workdistribute];
+  let category = CA_Executable;
+}
 def OMP_teams_loop : Directive<[Spelling<"teams loop">]> {
   let allowedClauses = [
     VersionedClause<OMPC_Allocate>,
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index ac80926053a2d..a58e09d7bda71 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1887,4 +1887,27 @@ def MaskedOp : OpenMP_Op<"masked", clauses = [
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// workdistribute Construct
+//===----------------------------------------------------------------------===//
+
+def WorkdistributeOp : OpenMP_Op<"workdistribute"> {
+  let summary = "workdistribute directive";
+  let description = [{
+    workdistribute divides execution of the enclosed structured block into
+    separate units of work, each executed only once by each
+    initial thread in the league.
+    ```
+    !$omp target teams
+        !$omp workdistribute
+        y = a * x + y 
+        !$omp end workdistribute
+    !$omp end target teams
+    ```
+  }];
+  let regions = (region AnyRegion:$region);
+  let hasVerifier = 1;
+  let assemblyFormat = "$region attr-dict";
+}
+
 #endif // OPENMP_OPS
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index e94d570b57122..e2dd338829e76 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -3493,6 +3493,21 @@ LogicalResult ScanOp::verify() {
                    "reduction modifier");
 }
 
+//===----------------------------------------------------------------------===//
+// WorkdistributeOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WorkdistributeOp::verify() {
+  Region &region = getRegion();
+  if (!region.hasOneBlock())
+    return emitOpError("region must contain exactly one block");
+
+  Operation *parentOp = (*this)->getParentOp();
+  if (!llvm::dyn_cast<TeamsOp>(parentOp))
+    return emitOpError("workdistribute must be nested under teams");
+  return success();
+}
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
 
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 060b3cd2455a0..522d20558a2b5 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -2960,3 +2960,24 @@ llvm.func @invalid_mapper(%0 : !llvm.ptr) {
   }
   llvm.return
 }
+
+func.func @invalid_workdistribute_with_multiple_blocks() {
+  // expected-error @below {{workdistribute must be nested under teams}}
+  omp.workdistribute {
+    omp.terminator
+  }
+  return
+}
+
+func.func @invalid_workdistribute_with_multiple_blocks() {
+  omp.teams {
+  // expected-error @below {{region must contain exactly one block}}
+  omp.workdistribute {
+    cf.br ^bb1
+  ^bb1:
+    omp.terminator
+  }
+  omp.terminator
+  }
+  return
+}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 47cfc5278a5d0..af80284e53537 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -3197,3 +3197,16 @@ func.func @omp_workshare_loop_wrapper_attrs(%idx : index) {
   }
   return
 }
+
+// CHECK-LABEL: func @omp_workdistribute
+func.func @omp_workdistribute() {
+  // CHECK: omp.teams
+  omp.teams {
+  // CHECK: omp.workdistribute
+  omp.workdistribute {
+    omp.terminator
+  }
+  omp.terminator
+  }
+  return
+}

@llvmbot
Copy link
Member

llvmbot commented Jul 4, 2025

@llvm/pr-subscribers-flang-semantics

Author: Chaitanya (skc7)

Changes

This PR introduces wordistribute construct support in flang frontend.
Also adds a workdistribute mlir op.

The work in this PR is C-P and updated from @ivanradanov commit from coexecute implementation:
flang_workdistribute_iwomp_2024


Full diff: https://github.com/llvm/llvm-project/pull/146029.diff

10 Files Affected:

  • (modified) flang/include/flang/Semantics/openmp-directive-sets.h (+14)
  • (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+25-1)
  • (modified) flang/lib/Parser/openmp-parsers.cpp (+6-1)
  • (modified) flang/lib/Semantics/resolve-directives.cpp (+7-1)
  • (added) flang/test/Lower/OpenMP/workdistribute.f90 (+59)
  • (modified) llvm/include/llvm/Frontend/OpenMP/OMP.td (+55)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+23)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+15)
  • (modified) mlir/test/Dialect/OpenMP/invalid.mlir (+21)
  • (modified) mlir/test/Dialect/OpenMP/ops.mlir (+13)
diff --git a/flang/include/flang/Semantics/openmp-directive-sets.h b/flang/include/flang/Semantics/openmp-directive-sets.h
index dd610c9702c28..7ced6ed9b44d6 100644
--- a/flang/include/flang/Semantics/openmp-directive-sets.h
+++ b/flang/include/flang/Semantics/openmp-directive-sets.h
@@ -143,6 +143,7 @@ static const OmpDirectiveSet topTargetSet{
     Directive::OMPD_target_teams_distribute_parallel_do_simd,
     Directive::OMPD_target_teams_distribute_simd,
     Directive::OMPD_target_teams_loop,
+    Directive::OMPD_target_teams_workdistribute,
 };
 
 static const OmpDirectiveSet allTargetSet{topTargetSet};
@@ -172,6 +173,7 @@ static const OmpDirectiveSet topTeamsSet{
     Directive::OMPD_teams_distribute_parallel_do_simd,
     Directive::OMPD_teams_distribute_simd,
     Directive::OMPD_teams_loop,
+    Directive::OMPD_teams_workdistribute,
 };
 
 static const OmpDirectiveSet bottomTeamsSet{
@@ -187,9 +189,16 @@ static const OmpDirectiveSet allTeamsSet{
         Directive::OMPD_target_teams_distribute_parallel_do_simd,
         Directive::OMPD_target_teams_distribute_simd,
         Directive::OMPD_target_teams_loop,
+        Directive::OMPD_target_teams_workdistribute,
     } | topTeamsSet,
 };
 
+static const OmpDirectiveSet allWorkdistributeSet{
+    Directive::OMPD_workdistribute,
+    Directive::OMPD_teams_workdistribute,
+    Directive::OMPD_target_teams_workdistribute,
+};
+
 //===----------------------------------------------------------------------===//
 // Directive sets for groups of multiple directives
 //===----------------------------------------------------------------------===//
@@ -230,6 +239,9 @@ static const OmpDirectiveSet blockConstructSet{
     Directive::OMPD_taskgroup,
     Directive::OMPD_teams,
     Directive::OMPD_workshare,
+    Directive::OMPD_target_teams_workdistribute,
+    Directive::OMPD_teams_workdistribute,
+    Directive::OMPD_workdistribute,
 };
 
 static const OmpDirectiveSet loopConstructSet{
@@ -294,6 +306,7 @@ static const OmpDirectiveSet workShareSet{
         Directive::OMPD_scope,
         Directive::OMPD_sections,
         Directive::OMPD_single,
+        Directive::OMPD_workdistribute,
     } | allDoSet,
 };
 
@@ -376,6 +389,7 @@ static const OmpDirectiveSet nestedReduceWorkshareAllowedSet{
 };
 
 static const OmpDirectiveSet nestedTeamsAllowedSet{
+    Directive::OMPD_workdistribute,
     Directive::OMPD_distribute,
     Directive::OMPD_distribute_parallel_do,
     Directive::OMPD_distribute_parallel_do_simd,
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index ebd1d038716e4..16d58b6be535f 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -585,6 +585,16 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
       cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv);
       break;
 
+    case OMPD_teams_workdistribute:
+      cp.processThreadLimit(stmtCtx, hostInfo.ops);
+      [[fallthrough]];
+    case OMPD_target_teams_workdistribute:
+      cp.processNumTeams(stmtCtx, hostInfo.ops);
+      processSingleNestedIf([](Directive nestedDir) {
+        return topDistributeSet.test(nestedDir) || topLoopSet.test(nestedDir);
+      });
+      break;
+
     // Standalone 'target' case.
     case OMPD_target: {
       processSingleNestedIf(
@@ -2682,6 +2692,17 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
       queue, item, clauseOps);
 }
 
+static mlir::omp::WorkdistributeOp genWorkdistributeOp(
+    lower::AbstractConverter &converter, lower::SymMap &symTable,
+    semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
+    mlir::Location loc, const ConstructQueue &queue,
+    ConstructQueue::const_iterator item) {
+  return genOpWithBody<mlir::omp::WorkdistributeOp>(
+      OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
+                        llvm::omp::Directive::OMPD_workdistribute),
+      queue, item);
+}
+
 //===----------------------------------------------------------------------===//
 // Code generation functions for the standalone version of constructs that can
 // also be a leaf of a composite construct
@@ -3302,7 +3323,10 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
     TODO(loc, "Unhandled loop directive (" +
                   llvm::omp::getOpenMPDirectiveName(dir, version) + ")");
   }
-  // case llvm::omp::Directive::OMPD_workdistribute:
+  case llvm::omp::Directive::OMPD_workdistribute:
+    newOp = genWorkdistributeOp(converter, symTable, semaCtx, eval, loc, queue,
+                                item);
+    break;
   case llvm::omp::Directive::OMPD_workshare:
     newOp = genWorkshareOp(converter, symTable, stmtCtx, semaCtx, eval, loc,
                            queue, item);
diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp
index c55642d969503..ad729932a5f00 100644
--- a/flang/lib/Parser/openmp-parsers.cpp
+++ b/flang/lib/Parser/openmp-parsers.cpp
@@ -1492,12 +1492,17 @@ TYPE_PARSER(
         "SINGLE" >> pure(llvm::omp::Directive::OMPD_single),
         "TARGET DATA" >> pure(llvm::omp::Directive::OMPD_target_data),
         "TARGET PARALLEL" >> pure(llvm::omp::Directive::OMPD_target_parallel),
+        "TARGET TEAMS WORKDISTRIBUTE" >>
+            pure(llvm::omp::Directive::OMPD_target_teams_workdistribute),
         "TARGET TEAMS" >> pure(llvm::omp::Directive::OMPD_target_teams),
         "TARGET" >> pure(llvm::omp::Directive::OMPD_target),
         "TASK"_id >> pure(llvm::omp::Directive::OMPD_task),
         "TASKGROUP" >> pure(llvm::omp::Directive::OMPD_taskgroup),
+        "TEAMS WORKDISTRIBUTE" >>
+            pure(llvm::omp::Directive::OMPD_teams_workdistribute),
         "TEAMS" >> pure(llvm::omp::Directive::OMPD_teams),
-        "WORKSHARE" >> pure(llvm::omp::Directive::OMPD_workshare))))
+        "WORKSHARE" >> pure(llvm::omp::Directive::OMPD_workshare),
+        "WORKDISTRIBUTE" >> pure(llvm::omp::Directive::OMPD_workdistribute))))
 
 TYPE_PARSER(sourced(construct<OmpBeginBlockDirective>(
     sourced(Parser<OmpBlockDirective>{}), Parser<OmpClauseList>{})))
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index 885c02e6ec74b..2e4e05f9e293b 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -1656,10 +1656,13 @@ bool OmpAttributeVisitor::Pre(const parser::OpenMPBlockConstruct &x) {
   case llvm::omp::Directive::OMPD_task:
   case llvm::omp::Directive::OMPD_taskgroup:
   case llvm::omp::Directive::OMPD_teams:
+  case llvm::omp::Directive::OMPD_workdistribute:
   case llvm::omp::Directive::OMPD_workshare:
   case llvm::omp::Directive::OMPD_parallel_workshare:
   case llvm::omp::Directive::OMPD_target_teams:
+  case llvm::omp::Directive::OMPD_target_teams_workdistribute:
   case llvm::omp::Directive::OMPD_target_parallel:
+  case llvm::omp::Directive::OMPD_teams_workdistribute:
     PushContext(beginDir.source, beginDir.v);
     break;
   default:
@@ -1689,9 +1692,12 @@ void OmpAttributeVisitor::Post(const parser::OpenMPBlockConstruct &x) {
   case llvm::omp::Directive::OMPD_target:
   case llvm::omp::Directive::OMPD_task:
   case llvm::omp::Directive::OMPD_teams:
+  case llvm::omp::Directive::OMPD_workdistribute:
   case llvm::omp::Directive::OMPD_parallel_workshare:
   case llvm::omp::Directive::OMPD_target_teams:
-  case llvm::omp::Directive::OMPD_target_parallel: {
+  case llvm::omp::Directive::OMPD_target_parallel:
+  case llvm::omp::Directive::OMPD_target_teams_workdistribute:
+  case llvm::omp::Directive::OMPD_teams_workdistribute: {
     bool hasPrivate;
     for (const auto *allocName : allocateNames_) {
       hasPrivate = false;
diff --git a/flang/test/Lower/OpenMP/workdistribute.f90 b/flang/test/Lower/OpenMP/workdistribute.f90
new file mode 100644
index 0000000000000..924205bb72e5e
--- /dev/null
+++ b/flang/test/Lower/OpenMP/workdistribute.f90
@@ -0,0 +1,59 @@
+! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
+
+! CHECK-LABEL: func @_QPtarget_teams_workdistribute
+subroutine target_teams_workdistribute()
+  ! CHECK: omp.target
+  ! CHECK: omp.teams
+  ! CHECK: omp.workdistribute
+  !$omp target teams workdistribute
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  !$omp end target teams workdistribute
+end subroutine target_teams_workdistribute
+
+! CHECK-LABEL: func @_QPteams_workdistribute
+subroutine teams_workdistribute()
+  ! CHECK: omp.teams
+  ! CHECK: omp.workdistribute
+  !$omp teams workdistribute
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  !$omp end teams workdistribute
+end subroutine teams_workdistribute
+
+! CHECK-LABEL: func @_QPtarget_teams_workdistribute_m
+subroutine target_teams_workdistribute_m()
+  ! CHECK: omp.target
+  ! CHECK: omp.teams
+  ! CHECK: omp.workdistribute
+  !$omp target
+  !$omp teams
+  !$omp workdistribute
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  !$omp end workdistribute
+  !$omp end teams
+  !$omp end target
+end subroutine target_teams_workdistribute_m
+
+! CHECK-LABEL: func @_QPteams_workdistribute_m
+subroutine teams_workdistribute_m()
+  ! CHECK: omp.teams
+  ! CHECK: omp.workdistribute
+  !$omp teams
+  !$omp workdistribute
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  !$omp end workdistribute
+  !$omp end teams
+end subroutine teams_workdistribute_m
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td
index a87111cb5a11d..d1831db37fc46 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -1286,6 +1286,15 @@ def OMP_EndWorkshare : Directive<[Spelling<"end workshare">]> {
   let category = OMP_Workshare.category;
   let languages = [L_Fortran];
 }
+def OMP_Workdistribute : Directive<[Spelling<"workdistribute">]> {
+  let association = AS_Block;
+  let category = CA_Executable;
+}
+def OMP_EndWorkdistribute : Directive<[Spelling<"end workdistribute">]> {
+  let leafConstructs = OMP_Workdistribute.leafConstructs;
+  let association = OMP_Workdistribute.association;
+  let category = OMP_Workdistribute.category;
+}
 
 //===----------------------------------------------------------------------===//
 // Definitions of OpenMP compound directives
@@ -2429,6 +2438,34 @@ def OMP_TargetTeamsDistributeSimd
   let leafConstructs = [OMP_Target, OMP_Teams, OMP_Distribute, OMP_Simd];
   let category = CA_Executable;
 }
+def OMP_TargetTeamsWorkdistribute : Directive<[Spelling<"target teams workdistribute">]> {
+  let allowedClauses = [
+    VersionedClause<OMPC_Allocate>,
+    VersionedClause<OMPC_Depend>,
+    VersionedClause<OMPC_FirstPrivate>,
+    VersionedClause<OMPC_HasDeviceAddr, 51>,
+    VersionedClause<OMPC_If>,
+    VersionedClause<OMPC_IsDevicePtr>,
+    VersionedClause<OMPC_Map>,
+    VersionedClause<OMPC_OMPX_Attribute>,
+    VersionedClause<OMPC_Private>,
+    VersionedClause<OMPC_Reduction>,
+    VersionedClause<OMPC_Shared>,
+    VersionedClause<OMPC_UsesAllocators, 50>,
+  ];
+  let allowedOnceClauses = [
+    VersionedClause<OMPC_Default>,
+    VersionedClause<OMPC_DefaultMap>,
+    VersionedClause<OMPC_Device>,
+    VersionedClause<OMPC_NoWait>,
+    VersionedClause<OMPC_NumTeams>,
+    VersionedClause<OMPC_OMPX_DynCGroupMem>,
+    VersionedClause<OMPC_OMPX_Bare>,
+    VersionedClause<OMPC_ThreadLimit>,
+  ];
+  let leafConstructs = [OMP_Target, OMP_Teams, OMP_Workdistribute];
+  let category = CA_Executable;
+}
 def OMP_target_teams_loop : Directive<[Spelling<"target teams loop">]> {
   let allowedClauses = [
     VersionedClause<OMPC_Allocate>,
@@ -2659,6 +2696,24 @@ def OMP_TeamsDistributeSimd : Directive<[Spelling<"teams distribute simd">]> {
   let leafConstructs = [OMP_Teams, OMP_Distribute, OMP_Simd];
   let category = CA_Executable;
 }
+def OMP_TeamsWorkdistribute : Directive<[Spelling<"teams workdistribute">]> {
+  let allowedClauses = [
+    VersionedClause<OMPC_Allocate>,
+    VersionedClause<OMPC_FirstPrivate>,
+    VersionedClause<OMPC_OMPX_Attribute>,
+    VersionedClause<OMPC_Private>,
+    VersionedClause<OMPC_Reduction>,
+    VersionedClause<OMPC_Shared>,
+  ];
+  let allowedOnceClauses = [
+    VersionedClause<OMPC_Default>,
+    VersionedClause<OMPC_If, 52>,
+    VersionedClause<OMPC_NumTeams>,
+    VersionedClause<OMPC_ThreadLimit>,
+  ];
+  let leafConstructs = [OMP_Teams, OMP_Workdistribute];
+  let category = CA_Executable;
+}
 def OMP_teams_loop : Directive<[Spelling<"teams loop">]> {
   let allowedClauses = [
     VersionedClause<OMPC_Allocate>,
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index ac80926053a2d..a58e09d7bda71 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1887,4 +1887,27 @@ def MaskedOp : OpenMP_Op<"masked", clauses = [
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// workdistribute Construct
+//===----------------------------------------------------------------------===//
+
+def WorkdistributeOp : OpenMP_Op<"workdistribute"> {
+  let summary = "workdistribute directive";
+  let description = [{
+    workdistribute divides execution of the enclosed structured block into
+    separate units of work, each executed only once by each
+    initial thread in the league.
+    ```
+    !$omp target teams
+        !$omp workdistribute
+        y = a * x + y 
+        !$omp end workdistribute
+    !$omp end target teams
+    ```
+  }];
+  let regions = (region AnyRegion:$region);
+  let hasVerifier = 1;
+  let assemblyFormat = "$region attr-dict";
+}
+
 #endif // OPENMP_OPS
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index e94d570b57122..e2dd338829e76 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -3493,6 +3493,21 @@ LogicalResult ScanOp::verify() {
                    "reduction modifier");
 }
 
+//===----------------------------------------------------------------------===//
+// WorkdistributeOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WorkdistributeOp::verify() {
+  Region &region = getRegion();
+  if (!region.hasOneBlock())
+    return emitOpError("region must contain exactly one block");
+
+  Operation *parentOp = (*this)->getParentOp();
+  if (!llvm::dyn_cast<TeamsOp>(parentOp))
+    return emitOpError("workdistribute must be nested under teams");
+  return success();
+}
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
 
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 060b3cd2455a0..522d20558a2b5 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -2960,3 +2960,24 @@ llvm.func @invalid_mapper(%0 : !llvm.ptr) {
   }
   llvm.return
 }
+
+func.func @invalid_workdistribute_with_multiple_blocks() {
+  // expected-error @below {{workdistribute must be nested under teams}}
+  omp.workdistribute {
+    omp.terminator
+  }
+  return
+}
+
+func.func @invalid_workdistribute_with_multiple_blocks() {
+  omp.teams {
+  // expected-error @below {{region must contain exactly one block}}
+  omp.workdistribute {
+    cf.br ^bb1
+  ^bb1:
+    omp.terminator
+  }
+  omp.terminator
+  }
+  return
+}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 47cfc5278a5d0..af80284e53537 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -3197,3 +3197,16 @@ func.func @omp_workshare_loop_wrapper_attrs(%idx : index) {
   }
   return
 }
+
+// CHECK-LABEL: func @omp_workdistribute
+func.func @omp_workdistribute() {
+  // CHECK: omp.teams
+  omp.teams {
+  // CHECK: omp.workdistribute
+  omp.workdistribute {
+    omp.terminator
+  }
+  omp.terminator
+  }
+  return
+}

@llvmbot
Copy link
Member

llvmbot commented Jul 4, 2025

@llvm/pr-subscribers-mlir

Author: Chaitanya (skc7)

Changes

This PR introduces wordistribute construct support in flang frontend.
Also adds a workdistribute mlir op.

The work in this PR is C-P and updated from @ivanradanov commit from coexecute implementation:
flang_workdistribute_iwomp_2024


Full diff: https://github.com/llvm/llvm-project/pull/146029.diff

10 Files Affected:

  • (modified) flang/include/flang/Semantics/openmp-directive-sets.h (+14)
  • (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+25-1)
  • (modified) flang/lib/Parser/openmp-parsers.cpp (+6-1)
  • (modified) flang/lib/Semantics/resolve-directives.cpp (+7-1)
  • (added) flang/test/Lower/OpenMP/workdistribute.f90 (+59)
  • (modified) llvm/include/llvm/Frontend/OpenMP/OMP.td (+55)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+23)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+15)
  • (modified) mlir/test/Dialect/OpenMP/invalid.mlir (+21)
  • (modified) mlir/test/Dialect/OpenMP/ops.mlir (+13)
diff --git a/flang/include/flang/Semantics/openmp-directive-sets.h b/flang/include/flang/Semantics/openmp-directive-sets.h
index dd610c9702c28..7ced6ed9b44d6 100644
--- a/flang/include/flang/Semantics/openmp-directive-sets.h
+++ b/flang/include/flang/Semantics/openmp-directive-sets.h
@@ -143,6 +143,7 @@ static const OmpDirectiveSet topTargetSet{
     Directive::OMPD_target_teams_distribute_parallel_do_simd,
     Directive::OMPD_target_teams_distribute_simd,
     Directive::OMPD_target_teams_loop,
+    Directive::OMPD_target_teams_workdistribute,
 };
 
 static const OmpDirectiveSet allTargetSet{topTargetSet};
@@ -172,6 +173,7 @@ static const OmpDirectiveSet topTeamsSet{
     Directive::OMPD_teams_distribute_parallel_do_simd,
     Directive::OMPD_teams_distribute_simd,
     Directive::OMPD_teams_loop,
+    Directive::OMPD_teams_workdistribute,
 };
 
 static const OmpDirectiveSet bottomTeamsSet{
@@ -187,9 +189,16 @@ static const OmpDirectiveSet allTeamsSet{
         Directive::OMPD_target_teams_distribute_parallel_do_simd,
         Directive::OMPD_target_teams_distribute_simd,
         Directive::OMPD_target_teams_loop,
+        Directive::OMPD_target_teams_workdistribute,
     } | topTeamsSet,
 };
 
+static const OmpDirectiveSet allWorkdistributeSet{
+    Directive::OMPD_workdistribute,
+    Directive::OMPD_teams_workdistribute,
+    Directive::OMPD_target_teams_workdistribute,
+};
+
 //===----------------------------------------------------------------------===//
 // Directive sets for groups of multiple directives
 //===----------------------------------------------------------------------===//
@@ -230,6 +239,9 @@ static const OmpDirectiveSet blockConstructSet{
     Directive::OMPD_taskgroup,
     Directive::OMPD_teams,
     Directive::OMPD_workshare,
+    Directive::OMPD_target_teams_workdistribute,
+    Directive::OMPD_teams_workdistribute,
+    Directive::OMPD_workdistribute,
 };
 
 static const OmpDirectiveSet loopConstructSet{
@@ -294,6 +306,7 @@ static const OmpDirectiveSet workShareSet{
         Directive::OMPD_scope,
         Directive::OMPD_sections,
         Directive::OMPD_single,
+        Directive::OMPD_workdistribute,
     } | allDoSet,
 };
 
@@ -376,6 +389,7 @@ static const OmpDirectiveSet nestedReduceWorkshareAllowedSet{
 };
 
 static const OmpDirectiveSet nestedTeamsAllowedSet{
+    Directive::OMPD_workdistribute,
     Directive::OMPD_distribute,
     Directive::OMPD_distribute_parallel_do,
     Directive::OMPD_distribute_parallel_do_simd,
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index ebd1d038716e4..16d58b6be535f 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -585,6 +585,16 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
       cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv);
       break;
 
+    case OMPD_teams_workdistribute:
+      cp.processThreadLimit(stmtCtx, hostInfo.ops);
+      [[fallthrough]];
+    case OMPD_target_teams_workdistribute:
+      cp.processNumTeams(stmtCtx, hostInfo.ops);
+      processSingleNestedIf([](Directive nestedDir) {
+        return topDistributeSet.test(nestedDir) || topLoopSet.test(nestedDir);
+      });
+      break;
+
     // Standalone 'target' case.
     case OMPD_target: {
       processSingleNestedIf(
@@ -2682,6 +2692,17 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
       queue, item, clauseOps);
 }
 
+static mlir::omp::WorkdistributeOp genWorkdistributeOp(
+    lower::AbstractConverter &converter, lower::SymMap &symTable,
+    semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
+    mlir::Location loc, const ConstructQueue &queue,
+    ConstructQueue::const_iterator item) {
+  return genOpWithBody<mlir::omp::WorkdistributeOp>(
+      OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
+                        llvm::omp::Directive::OMPD_workdistribute),
+      queue, item);
+}
+
 //===----------------------------------------------------------------------===//
 // Code generation functions for the standalone version of constructs that can
 // also be a leaf of a composite construct
@@ -3302,7 +3323,10 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
     TODO(loc, "Unhandled loop directive (" +
                   llvm::omp::getOpenMPDirectiveName(dir, version) + ")");
   }
-  // case llvm::omp::Directive::OMPD_workdistribute:
+  case llvm::omp::Directive::OMPD_workdistribute:
+    newOp = genWorkdistributeOp(converter, symTable, semaCtx, eval, loc, queue,
+                                item);
+    break;
   case llvm::omp::Directive::OMPD_workshare:
     newOp = genWorkshareOp(converter, symTable, stmtCtx, semaCtx, eval, loc,
                            queue, item);
diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp
index c55642d969503..ad729932a5f00 100644
--- a/flang/lib/Parser/openmp-parsers.cpp
+++ b/flang/lib/Parser/openmp-parsers.cpp
@@ -1492,12 +1492,17 @@ TYPE_PARSER(
         "SINGLE" >> pure(llvm::omp::Directive::OMPD_single),
         "TARGET DATA" >> pure(llvm::omp::Directive::OMPD_target_data),
         "TARGET PARALLEL" >> pure(llvm::omp::Directive::OMPD_target_parallel),
+        "TARGET TEAMS WORKDISTRIBUTE" >>
+            pure(llvm::omp::Directive::OMPD_target_teams_workdistribute),
         "TARGET TEAMS" >> pure(llvm::omp::Directive::OMPD_target_teams),
         "TARGET" >> pure(llvm::omp::Directive::OMPD_target),
         "TASK"_id >> pure(llvm::omp::Directive::OMPD_task),
         "TASKGROUP" >> pure(llvm::omp::Directive::OMPD_taskgroup),
+        "TEAMS WORKDISTRIBUTE" >>
+            pure(llvm::omp::Directive::OMPD_teams_workdistribute),
         "TEAMS" >> pure(llvm::omp::Directive::OMPD_teams),
-        "WORKSHARE" >> pure(llvm::omp::Directive::OMPD_workshare))))
+        "WORKSHARE" >> pure(llvm::omp::Directive::OMPD_workshare),
+        "WORKDISTRIBUTE" >> pure(llvm::omp::Directive::OMPD_workdistribute))))
 
 TYPE_PARSER(sourced(construct<OmpBeginBlockDirective>(
     sourced(Parser<OmpBlockDirective>{}), Parser<OmpClauseList>{})))
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index 885c02e6ec74b..2e4e05f9e293b 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -1656,10 +1656,13 @@ bool OmpAttributeVisitor::Pre(const parser::OpenMPBlockConstruct &x) {
   case llvm::omp::Directive::OMPD_task:
   case llvm::omp::Directive::OMPD_taskgroup:
   case llvm::omp::Directive::OMPD_teams:
+  case llvm::omp::Directive::OMPD_workdistribute:
   case llvm::omp::Directive::OMPD_workshare:
   case llvm::omp::Directive::OMPD_parallel_workshare:
   case llvm::omp::Directive::OMPD_target_teams:
+  case llvm::omp::Directive::OMPD_target_teams_workdistribute:
   case llvm::omp::Directive::OMPD_target_parallel:
+  case llvm::omp::Directive::OMPD_teams_workdistribute:
     PushContext(beginDir.source, beginDir.v);
     break;
   default:
@@ -1689,9 +1692,12 @@ void OmpAttributeVisitor::Post(const parser::OpenMPBlockConstruct &x) {
   case llvm::omp::Directive::OMPD_target:
   case llvm::omp::Directive::OMPD_task:
   case llvm::omp::Directive::OMPD_teams:
+  case llvm::omp::Directive::OMPD_workdistribute:
   case llvm::omp::Directive::OMPD_parallel_workshare:
   case llvm::omp::Directive::OMPD_target_teams:
-  case llvm::omp::Directive::OMPD_target_parallel: {
+  case llvm::omp::Directive::OMPD_target_parallel:
+  case llvm::omp::Directive::OMPD_target_teams_workdistribute:
+  case llvm::omp::Directive::OMPD_teams_workdistribute: {
     bool hasPrivate;
     for (const auto *allocName : allocateNames_) {
       hasPrivate = false;
diff --git a/flang/test/Lower/OpenMP/workdistribute.f90 b/flang/test/Lower/OpenMP/workdistribute.f90
new file mode 100644
index 0000000000000..924205bb72e5e
--- /dev/null
+++ b/flang/test/Lower/OpenMP/workdistribute.f90
@@ -0,0 +1,59 @@
+! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
+
+! CHECK-LABEL: func @_QPtarget_teams_workdistribute
+subroutine target_teams_workdistribute()
+  ! CHECK: omp.target
+  ! CHECK: omp.teams
+  ! CHECK: omp.workdistribute
+  !$omp target teams workdistribute
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  !$omp end target teams workdistribute
+end subroutine target_teams_workdistribute
+
+! CHECK-LABEL: func @_QPteams_workdistribute
+subroutine teams_workdistribute()
+  ! CHECK: omp.teams
+  ! CHECK: omp.workdistribute
+  !$omp teams workdistribute
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  !$omp end teams workdistribute
+end subroutine teams_workdistribute
+
+! CHECK-LABEL: func @_QPtarget_teams_workdistribute_m
+subroutine target_teams_workdistribute_m()
+  ! CHECK: omp.target
+  ! CHECK: omp.teams
+  ! CHECK: omp.workdistribute
+  !$omp target
+  !$omp teams
+  !$omp workdistribute
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  !$omp end workdistribute
+  !$omp end teams
+  !$omp end target
+end subroutine target_teams_workdistribute_m
+
+! CHECK-LABEL: func @_QPteams_workdistribute_m
+subroutine teams_workdistribute_m()
+  ! CHECK: omp.teams
+  ! CHECK: omp.workdistribute
+  !$omp teams
+  !$omp workdistribute
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  !$omp end workdistribute
+  !$omp end teams
+end subroutine teams_workdistribute_m
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td
index a87111cb5a11d..d1831db37fc46 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -1286,6 +1286,15 @@ def OMP_EndWorkshare : Directive<[Spelling<"end workshare">]> {
   let category = OMP_Workshare.category;
   let languages = [L_Fortran];
 }
+def OMP_Workdistribute : Directive<[Spelling<"workdistribute">]> {
+  let association = AS_Block;
+  let category = CA_Executable;
+}
+def OMP_EndWorkdistribute : Directive<[Spelling<"end workdistribute">]> {
+  let leafConstructs = OMP_Workdistribute.leafConstructs;
+  let association = OMP_Workdistribute.association;
+  let category = OMP_Workdistribute.category;
+}
 
 //===----------------------------------------------------------------------===//
 // Definitions of OpenMP compound directives
@@ -2429,6 +2438,34 @@ def OMP_TargetTeamsDistributeSimd
   let leafConstructs = [OMP_Target, OMP_Teams, OMP_Distribute, OMP_Simd];
   let category = CA_Executable;
 }
+def OMP_TargetTeamsWorkdistribute : Directive<[Spelling<"target teams workdistribute">]> {
+  let allowedClauses = [
+    VersionedClause<OMPC_Allocate>,
+    VersionedClause<OMPC_Depend>,
+    VersionedClause<OMPC_FirstPrivate>,
+    VersionedClause<OMPC_HasDeviceAddr, 51>,
+    VersionedClause<OMPC_If>,
+    VersionedClause<OMPC_IsDevicePtr>,
+    VersionedClause<OMPC_Map>,
+    VersionedClause<OMPC_OMPX_Attribute>,
+    VersionedClause<OMPC_Private>,
+    VersionedClause<OMPC_Reduction>,
+    VersionedClause<OMPC_Shared>,
+    VersionedClause<OMPC_UsesAllocators, 50>,
+  ];
+  let allowedOnceClauses = [
+    VersionedClause<OMPC_Default>,
+    VersionedClause<OMPC_DefaultMap>,
+    VersionedClause<OMPC_Device>,
+    VersionedClause<OMPC_NoWait>,
+    VersionedClause<OMPC_NumTeams>,
+    VersionedClause<OMPC_OMPX_DynCGroupMem>,
+    VersionedClause<OMPC_OMPX_Bare>,
+    VersionedClause<OMPC_ThreadLimit>,
+  ];
+  let leafConstructs = [OMP_Target, OMP_Teams, OMP_Workdistribute];
+  let category = CA_Executable;
+}
 def OMP_target_teams_loop : Directive<[Spelling<"target teams loop">]> {
   let allowedClauses = [
     VersionedClause<OMPC_Allocate>,
@@ -2659,6 +2696,24 @@ def OMP_TeamsDistributeSimd : Directive<[Spelling<"teams distribute simd">]> {
   let leafConstructs = [OMP_Teams, OMP_Distribute, OMP_Simd];
   let category = CA_Executable;
 }
+def OMP_TeamsWorkdistribute : Directive<[Spelling<"teams workdistribute">]> {
+  let allowedClauses = [
+    VersionedClause<OMPC_Allocate>,
+    VersionedClause<OMPC_FirstPrivate>,
+    VersionedClause<OMPC_OMPX_Attribute>,
+    VersionedClause<OMPC_Private>,
+    VersionedClause<OMPC_Reduction>,
+    VersionedClause<OMPC_Shared>,
+  ];
+  let allowedOnceClauses = [
+    VersionedClause<OMPC_Default>,
+    VersionedClause<OMPC_If, 52>,
+    VersionedClause<OMPC_NumTeams>,
+    VersionedClause<OMPC_ThreadLimit>,
+  ];
+  let leafConstructs = [OMP_Teams, OMP_Workdistribute];
+  let category = CA_Executable;
+}
 def OMP_teams_loop : Directive<[Spelling<"teams loop">]> {
   let allowedClauses = [
     VersionedClause<OMPC_Allocate>,
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index ac80926053a2d..a58e09d7bda71 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1887,4 +1887,27 @@ def MaskedOp : OpenMP_Op<"masked", clauses = [
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// workdistribute Construct
+//===----------------------------------------------------------------------===//
+
+def WorkdistributeOp : OpenMP_Op<"workdistribute"> {
+  let summary = "workdistribute directive";
+  let description = [{
+    workdistribute divides execution of the enclosed structured block into
+    separate units of work, each executed only once by each
+    initial thread in the league.
+    ```
+    !$omp target teams
+        !$omp workdistribute
+        y = a * x + y 
+        !$omp end workdistribute
+    !$omp end target teams
+    ```
+  }];
+  let regions = (region AnyRegion:$region);
+  let hasVerifier = 1;
+  let assemblyFormat = "$region attr-dict";
+}
+
 #endif // OPENMP_OPS
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index e94d570b57122..e2dd338829e76 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -3493,6 +3493,21 @@ LogicalResult ScanOp::verify() {
                    "reduction modifier");
 }
 
+//===----------------------------------------------------------------------===//
+// WorkdistributeOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WorkdistributeOp::verify() {
+  Region &region = getRegion();
+  if (!region.hasOneBlock())
+    return emitOpError("region must contain exactly one block");
+
+  Operation *parentOp = (*this)->getParentOp();
+  if (!llvm::dyn_cast<TeamsOp>(parentOp))
+    return emitOpError("workdistribute must be nested under teams");
+  return success();
+}
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
 
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 060b3cd2455a0..522d20558a2b5 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -2960,3 +2960,24 @@ llvm.func @invalid_mapper(%0 : !llvm.ptr) {
   }
   llvm.return
 }
+
+func.func @invalid_workdistribute_with_multiple_blocks() {
+  // expected-error @below {{workdistribute must be nested under teams}}
+  omp.workdistribute {
+    omp.terminator
+  }
+  return
+}
+
+func.func @invalid_workdistribute_with_multiple_blocks() {
+  omp.teams {
+  // expected-error @below {{region must contain exactly one block}}
+  omp.workdistribute {
+    cf.br ^bb1
+  ^bb1:
+    omp.terminator
+  }
+  omp.terminator
+  }
+  return
+}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 47cfc5278a5d0..af80284e53537 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -3197,3 +3197,16 @@ func.func @omp_workshare_loop_wrapper_attrs(%idx : index) {
   }
   return
 }
+
+// CHECK-LABEL: func @omp_workdistribute
+func.func @omp_workdistribute() {
+  // CHECK: omp.teams
+  omp.teams {
+  // CHECK: omp.workdistribute
+  omp.workdistribute {
+    omp.terminator
+  }
+  omp.terminator
+  }
+  return
+}

@llvmbot
Copy link
Member

llvmbot commented Jul 4, 2025

@llvm/pr-subscribers-mlir-openmp

Author: Chaitanya (skc7)

Changes

This PR introduces wordistribute construct support in flang frontend.
Also adds a workdistribute mlir op.

The work in this PR is C-P and updated from @ivanradanov commit from coexecute implementation:
flang_workdistribute_iwomp_2024


Full diff: https://github.com/llvm/llvm-project/pull/146029.diff

10 Files Affected:

  • (modified) flang/include/flang/Semantics/openmp-directive-sets.h (+14)
  • (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+25-1)
  • (modified) flang/lib/Parser/openmp-parsers.cpp (+6-1)
  • (modified) flang/lib/Semantics/resolve-directives.cpp (+7-1)
  • (added) flang/test/Lower/OpenMP/workdistribute.f90 (+59)
  • (modified) llvm/include/llvm/Frontend/OpenMP/OMP.td (+55)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+23)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+15)
  • (modified) mlir/test/Dialect/OpenMP/invalid.mlir (+21)
  • (modified) mlir/test/Dialect/OpenMP/ops.mlir (+13)
diff --git a/flang/include/flang/Semantics/openmp-directive-sets.h b/flang/include/flang/Semantics/openmp-directive-sets.h
index dd610c9702c28..7ced6ed9b44d6 100644
--- a/flang/include/flang/Semantics/openmp-directive-sets.h
+++ b/flang/include/flang/Semantics/openmp-directive-sets.h
@@ -143,6 +143,7 @@ static const OmpDirectiveSet topTargetSet{
     Directive::OMPD_target_teams_distribute_parallel_do_simd,
     Directive::OMPD_target_teams_distribute_simd,
     Directive::OMPD_target_teams_loop,
+    Directive::OMPD_target_teams_workdistribute,
 };
 
 static const OmpDirectiveSet allTargetSet{topTargetSet};
@@ -172,6 +173,7 @@ static const OmpDirectiveSet topTeamsSet{
     Directive::OMPD_teams_distribute_parallel_do_simd,
     Directive::OMPD_teams_distribute_simd,
     Directive::OMPD_teams_loop,
+    Directive::OMPD_teams_workdistribute,
 };
 
 static const OmpDirectiveSet bottomTeamsSet{
@@ -187,9 +189,16 @@ static const OmpDirectiveSet allTeamsSet{
         Directive::OMPD_target_teams_distribute_parallel_do_simd,
         Directive::OMPD_target_teams_distribute_simd,
         Directive::OMPD_target_teams_loop,
+        Directive::OMPD_target_teams_workdistribute,
     } | topTeamsSet,
 };
 
+static const OmpDirectiveSet allWorkdistributeSet{
+    Directive::OMPD_workdistribute,
+    Directive::OMPD_teams_workdistribute,
+    Directive::OMPD_target_teams_workdistribute,
+};
+
 //===----------------------------------------------------------------------===//
 // Directive sets for groups of multiple directives
 //===----------------------------------------------------------------------===//
@@ -230,6 +239,9 @@ static const OmpDirectiveSet blockConstructSet{
     Directive::OMPD_taskgroup,
     Directive::OMPD_teams,
     Directive::OMPD_workshare,
+    Directive::OMPD_target_teams_workdistribute,
+    Directive::OMPD_teams_workdistribute,
+    Directive::OMPD_workdistribute,
 };
 
 static const OmpDirectiveSet loopConstructSet{
@@ -294,6 +306,7 @@ static const OmpDirectiveSet workShareSet{
         Directive::OMPD_scope,
         Directive::OMPD_sections,
         Directive::OMPD_single,
+        Directive::OMPD_workdistribute,
     } | allDoSet,
 };
 
@@ -376,6 +389,7 @@ static const OmpDirectiveSet nestedReduceWorkshareAllowedSet{
 };
 
 static const OmpDirectiveSet nestedTeamsAllowedSet{
+    Directive::OMPD_workdistribute,
     Directive::OMPD_distribute,
     Directive::OMPD_distribute_parallel_do,
     Directive::OMPD_distribute_parallel_do_simd,
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index ebd1d038716e4..16d58b6be535f 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -585,6 +585,16 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
       cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv);
       break;
 
+    case OMPD_teams_workdistribute:
+      cp.processThreadLimit(stmtCtx, hostInfo.ops);
+      [[fallthrough]];
+    case OMPD_target_teams_workdistribute:
+      cp.processNumTeams(stmtCtx, hostInfo.ops);
+      processSingleNestedIf([](Directive nestedDir) {
+        return topDistributeSet.test(nestedDir) || topLoopSet.test(nestedDir);
+      });
+      break;
+
     // Standalone 'target' case.
     case OMPD_target: {
       processSingleNestedIf(
@@ -2682,6 +2692,17 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
       queue, item, clauseOps);
 }
 
+static mlir::omp::WorkdistributeOp genWorkdistributeOp(
+    lower::AbstractConverter &converter, lower::SymMap &symTable,
+    semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
+    mlir::Location loc, const ConstructQueue &queue,
+    ConstructQueue::const_iterator item) {
+  return genOpWithBody<mlir::omp::WorkdistributeOp>(
+      OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
+                        llvm::omp::Directive::OMPD_workdistribute),
+      queue, item);
+}
+
 //===----------------------------------------------------------------------===//
 // Code generation functions for the standalone version of constructs that can
 // also be a leaf of a composite construct
@@ -3302,7 +3323,10 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
     TODO(loc, "Unhandled loop directive (" +
                   llvm::omp::getOpenMPDirectiveName(dir, version) + ")");
   }
-  // case llvm::omp::Directive::OMPD_workdistribute:
+  case llvm::omp::Directive::OMPD_workdistribute:
+    newOp = genWorkdistributeOp(converter, symTable, semaCtx, eval, loc, queue,
+                                item);
+    break;
   case llvm::omp::Directive::OMPD_workshare:
     newOp = genWorkshareOp(converter, symTable, stmtCtx, semaCtx, eval, loc,
                            queue, item);
diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp
index c55642d969503..ad729932a5f00 100644
--- a/flang/lib/Parser/openmp-parsers.cpp
+++ b/flang/lib/Parser/openmp-parsers.cpp
@@ -1492,12 +1492,17 @@ TYPE_PARSER(
         "SINGLE" >> pure(llvm::omp::Directive::OMPD_single),
         "TARGET DATA" >> pure(llvm::omp::Directive::OMPD_target_data),
         "TARGET PARALLEL" >> pure(llvm::omp::Directive::OMPD_target_parallel),
+        "TARGET TEAMS WORKDISTRIBUTE" >>
+            pure(llvm::omp::Directive::OMPD_target_teams_workdistribute),
         "TARGET TEAMS" >> pure(llvm::omp::Directive::OMPD_target_teams),
         "TARGET" >> pure(llvm::omp::Directive::OMPD_target),
         "TASK"_id >> pure(llvm::omp::Directive::OMPD_task),
         "TASKGROUP" >> pure(llvm::omp::Directive::OMPD_taskgroup),
+        "TEAMS WORKDISTRIBUTE" >>
+            pure(llvm::omp::Directive::OMPD_teams_workdistribute),
         "TEAMS" >> pure(llvm::omp::Directive::OMPD_teams),
-        "WORKSHARE" >> pure(llvm::omp::Directive::OMPD_workshare))))
+        "WORKSHARE" >> pure(llvm::omp::Directive::OMPD_workshare),
+        "WORKDISTRIBUTE" >> pure(llvm::omp::Directive::OMPD_workdistribute))))
 
 TYPE_PARSER(sourced(construct<OmpBeginBlockDirective>(
     sourced(Parser<OmpBlockDirective>{}), Parser<OmpClauseList>{})))
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index 885c02e6ec74b..2e4e05f9e293b 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -1656,10 +1656,13 @@ bool OmpAttributeVisitor::Pre(const parser::OpenMPBlockConstruct &x) {
   case llvm::omp::Directive::OMPD_task:
   case llvm::omp::Directive::OMPD_taskgroup:
   case llvm::omp::Directive::OMPD_teams:
+  case llvm::omp::Directive::OMPD_workdistribute:
   case llvm::omp::Directive::OMPD_workshare:
   case llvm::omp::Directive::OMPD_parallel_workshare:
   case llvm::omp::Directive::OMPD_target_teams:
+  case llvm::omp::Directive::OMPD_target_teams_workdistribute:
   case llvm::omp::Directive::OMPD_target_parallel:
+  case llvm::omp::Directive::OMPD_teams_workdistribute:
     PushContext(beginDir.source, beginDir.v);
     break;
   default:
@@ -1689,9 +1692,12 @@ void OmpAttributeVisitor::Post(const parser::OpenMPBlockConstruct &x) {
   case llvm::omp::Directive::OMPD_target:
   case llvm::omp::Directive::OMPD_task:
   case llvm::omp::Directive::OMPD_teams:
+  case llvm::omp::Directive::OMPD_workdistribute:
   case llvm::omp::Directive::OMPD_parallel_workshare:
   case llvm::omp::Directive::OMPD_target_teams:
-  case llvm::omp::Directive::OMPD_target_parallel: {
+  case llvm::omp::Directive::OMPD_target_parallel:
+  case llvm::omp::Directive::OMPD_target_teams_workdistribute:
+  case llvm::omp::Directive::OMPD_teams_workdistribute: {
     bool hasPrivate;
     for (const auto *allocName : allocateNames_) {
       hasPrivate = false;
diff --git a/flang/test/Lower/OpenMP/workdistribute.f90 b/flang/test/Lower/OpenMP/workdistribute.f90
new file mode 100644
index 0000000000000..924205bb72e5e
--- /dev/null
+++ b/flang/test/Lower/OpenMP/workdistribute.f90
@@ -0,0 +1,59 @@
+! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
+
+! CHECK-LABEL: func @_QPtarget_teams_workdistribute
+subroutine target_teams_workdistribute()
+  ! CHECK: omp.target
+  ! CHECK: omp.teams
+  ! CHECK: omp.workdistribute
+  !$omp target teams workdistribute
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  !$omp end target teams workdistribute
+end subroutine target_teams_workdistribute
+
+! CHECK-LABEL: func @_QPteams_workdistribute
+subroutine teams_workdistribute()
+  ! CHECK: omp.teams
+  ! CHECK: omp.workdistribute
+  !$omp teams workdistribute
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  !$omp end teams workdistribute
+end subroutine teams_workdistribute
+
+! CHECK-LABEL: func @_QPtarget_teams_workdistribute_m
+subroutine target_teams_workdistribute_m()
+  ! CHECK: omp.target
+  ! CHECK: omp.teams
+  ! CHECK: omp.workdistribute
+  !$omp target
+  !$omp teams
+  !$omp workdistribute
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  !$omp end workdistribute
+  !$omp end teams
+  !$omp end target
+end subroutine target_teams_workdistribute_m
+
+! CHECK-LABEL: func @_QPteams_workdistribute_m
+subroutine teams_workdistribute_m()
+  ! CHECK: omp.teams
+  ! CHECK: omp.workdistribute
+  !$omp teams
+  !$omp workdistribute
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  !$omp end workdistribute
+  !$omp end teams
+end subroutine teams_workdistribute_m
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td
index a87111cb5a11d..d1831db37fc46 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -1286,6 +1286,15 @@ def OMP_EndWorkshare : Directive<[Spelling<"end workshare">]> {
   let category = OMP_Workshare.category;
   let languages = [L_Fortran];
 }
+def OMP_Workdistribute : Directive<[Spelling<"workdistribute">]> {
+  let association = AS_Block;
+  let category = CA_Executable;
+}
+def OMP_EndWorkdistribute : Directive<[Spelling<"end workdistribute">]> {
+  let leafConstructs = OMP_Workdistribute.leafConstructs;
+  let association = OMP_Workdistribute.association;
+  let category = OMP_Workdistribute.category;
+}
 
 //===----------------------------------------------------------------------===//
 // Definitions of OpenMP compound directives
@@ -2429,6 +2438,34 @@ def OMP_TargetTeamsDistributeSimd
   let leafConstructs = [OMP_Target, OMP_Teams, OMP_Distribute, OMP_Simd];
   let category = CA_Executable;
 }
+def OMP_TargetTeamsWorkdistribute : Directive<[Spelling<"target teams workdistribute">]> {
+  let allowedClauses = [
+    VersionedClause<OMPC_Allocate>,
+    VersionedClause<OMPC_Depend>,
+    VersionedClause<OMPC_FirstPrivate>,
+    VersionedClause<OMPC_HasDeviceAddr, 51>,
+    VersionedClause<OMPC_If>,
+    VersionedClause<OMPC_IsDevicePtr>,
+    VersionedClause<OMPC_Map>,
+    VersionedClause<OMPC_OMPX_Attribute>,
+    VersionedClause<OMPC_Private>,
+    VersionedClause<OMPC_Reduction>,
+    VersionedClause<OMPC_Shared>,
+    VersionedClause<OMPC_UsesAllocators, 50>,
+  ];
+  let allowedOnceClauses = [
+    VersionedClause<OMPC_Default>,
+    VersionedClause<OMPC_DefaultMap>,
+    VersionedClause<OMPC_Device>,
+    VersionedClause<OMPC_NoWait>,
+    VersionedClause<OMPC_NumTeams>,
+    VersionedClause<OMPC_OMPX_DynCGroupMem>,
+    VersionedClause<OMPC_OMPX_Bare>,
+    VersionedClause<OMPC_ThreadLimit>,
+  ];
+  let leafConstructs = [OMP_Target, OMP_Teams, OMP_Workdistribute];
+  let category = CA_Executable;
+}
 def OMP_target_teams_loop : Directive<[Spelling<"target teams loop">]> {
   let allowedClauses = [
     VersionedClause<OMPC_Allocate>,
@@ -2659,6 +2696,24 @@ def OMP_TeamsDistributeSimd : Directive<[Spelling<"teams distribute simd">]> {
   let leafConstructs = [OMP_Teams, OMP_Distribute, OMP_Simd];
   let category = CA_Executable;
 }
+def OMP_TeamsWorkdistribute : Directive<[Spelling<"teams workdistribute">]> {
+  let allowedClauses = [
+    VersionedClause<OMPC_Allocate>,
+    VersionedClause<OMPC_FirstPrivate>,
+    VersionedClause<OMPC_OMPX_Attribute>,
+    VersionedClause<OMPC_Private>,
+    VersionedClause<OMPC_Reduction>,
+    VersionedClause<OMPC_Shared>,
+  ];
+  let allowedOnceClauses = [
+    VersionedClause<OMPC_Default>,
+    VersionedClause<OMPC_If, 52>,
+    VersionedClause<OMPC_NumTeams>,
+    VersionedClause<OMPC_ThreadLimit>,
+  ];
+  let leafConstructs = [OMP_Teams, OMP_Workdistribute];
+  let category = CA_Executable;
+}
 def OMP_teams_loop : Directive<[Spelling<"teams loop">]> {
   let allowedClauses = [
     VersionedClause<OMPC_Allocate>,
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index ac80926053a2d..a58e09d7bda71 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1887,4 +1887,27 @@ def MaskedOp : OpenMP_Op<"masked", clauses = [
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// workdistribute Construct
+//===----------------------------------------------------------------------===//
+
+def WorkdistributeOp : OpenMP_Op<"workdistribute"> {
+  let summary = "workdistribute directive";
+  let description = [{
+    workdistribute divides execution of the enclosed structured block into
+    separate units of work, each executed only once by each
+    initial thread in the league.
+    ```
+    !$omp target teams
+        !$omp workdistribute
+        y = a * x + y 
+        !$omp end workdistribute
+    !$omp end target teams
+    ```
+  }];
+  let regions = (region AnyRegion:$region);
+  let hasVerifier = 1;
+  let assemblyFormat = "$region attr-dict";
+}
+
 #endif // OPENMP_OPS
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index e94d570b57122..e2dd338829e76 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -3493,6 +3493,21 @@ LogicalResult ScanOp::verify() {
                    "reduction modifier");
 }
 
+//===----------------------------------------------------------------------===//
+// WorkdistributeOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WorkdistributeOp::verify() {
+  Region &region = getRegion();
+  if (!region.hasOneBlock())
+    return emitOpError("region must contain exactly one block");
+
+  Operation *parentOp = (*this)->getParentOp();
+  if (!llvm::dyn_cast<TeamsOp>(parentOp))
+    return emitOpError("workdistribute must be nested under teams");
+  return success();
+}
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
 
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 060b3cd2455a0..522d20558a2b5 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -2960,3 +2960,24 @@ llvm.func @invalid_mapper(%0 : !llvm.ptr) {
   }
   llvm.return
 }
+
+func.func @invalid_workdistribute_with_multiple_blocks() {
+  // expected-error @below {{workdistribute must be nested under teams}}
+  omp.workdistribute {
+    omp.terminator
+  }
+  return
+}
+
+func.func @invalid_workdistribute_with_multiple_blocks() {
+  omp.teams {
+  // expected-error @below {{region must contain exactly one block}}
+  omp.workdistribute {
+    cf.br ^bb1
+  ^bb1:
+    omp.terminator
+  }
+  omp.terminator
+  }
+  return
+}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 47cfc5278a5d0..af80284e53537 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -3197,3 +3197,16 @@ func.func @omp_workshare_loop_wrapper_attrs(%idx : index) {
   }
   return
 }
+
+// CHECK-LABEL: func @omp_workdistribute
+func.func @omp_workdistribute() {
+  // CHECK: omp.teams
+  omp.teams {
+  // CHECK: omp.workdistribute
+  omp.workdistribute {
+    omp.terminator
+  }
+  omp.terminator
+  }
+  return
+}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:openmp OpenMP related changes to Clang flang:fir-hlfir flang:openmp flang:parser flang:semantics flang Flang issues not falling into any other category mlir:openmp mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants