@@ -53,7 +53,7 @@ class GenericLoopConversionPattern
5353
5454 switch (combinedInfo) {
5555 case GenericLoopCombinedInfo::Standalone:
56- rewriteToSimdLoop (loopOp, rewriter);
56+ rewriteStandaloneLoop (loopOp, rewriter);
5757 break ;
5858 case GenericLoopCombinedInfo::TargetParallelLoop:
5959 llvm_unreachable (" not yet implemented: `parallel loop` direcitve" );
@@ -87,7 +87,10 @@ class GenericLoopConversionPattern
8787 << loopOp->getName () << " operation" ;
8888 };
8989
90- if (loopOp.getBindKind ())
90+ // For standalone directives, `bind` is already supported. Other combined
91+ // forms will be supported in a follow-up PR.
92+ if (combinedInfo != GenericLoopCombinedInfo::Standalone &&
93+ loopOp.getBindKind ())
9194 return todo (" bind" );
9295
9396 if (loopOp.getOrder ())
@@ -119,7 +122,27 @@ class GenericLoopConversionPattern
119122 return result;
120123 }
121124
122- // / Rewrites standalone `loop` directives to equivalent `simd` constructs.
125+ void rewriteStandaloneLoop (mlir::omp::LoopOp loopOp,
126+ mlir::ConversionPatternRewriter &rewriter) const {
127+ using namespace mlir ::omp;
128+ std::optional<ClauseBindKind> bindKind = loopOp.getBindKind ();
129+
130+ if (!bindKind.has_value ())
131+ return rewriteToSimdLoop (loopOp, rewriter);
132+
133+ switch (*loopOp.getBindKind ()) {
134+ case ClauseBindKind::Parallel:
135+ return rewriteToWsloop (loopOp, rewriter);
136+ case ClauseBindKind::Teams:
137+ return rewriteToDistrbute (loopOp, rewriter);
138+ case ClauseBindKind::Thread:
139+ return rewriteToSimdLoop (loopOp, rewriter);
140+ }
141+ }
142+
143+ // / Rewrites standalone `loop` (with `bind` clause with `bind(parallel)`)
144+ // / directives to equivalent `simd` constructs.
145+ // /
123146 // / The reasoning behind this decision is that according to the spec (version
124147 // / 5.2, section 11.7.1):
125148 // /
@@ -147,30 +170,51 @@ class GenericLoopConversionPattern
147170 // / the directive.
148171 void rewriteToSimdLoop (mlir::omp::LoopOp loopOp,
149172 mlir::ConversionPatternRewriter &rewriter) const {
150- loopOp.emitWarning (" Detected standalone OpenMP `loop` directive, the "
151- " associated loop will be rewritten to `simd`." );
152- mlir::omp::SimdOperands simdClauseOps;
153- simdClauseOps.privateVars = loopOp.getPrivateVars ();
173+ loopOp.emitWarning (
174+ " Detected standalone OpenMP `loop` directive with thread binding, "
175+ " the associated loop will be rewritten to `simd`." );
176+ rewriteToSingleWrapperOp<mlir::omp::SimdOp, mlir::omp::SimdOperands>(
177+ loopOp, rewriter);
178+ }
179+
180+ void rewriteToDistrbute (mlir::omp::LoopOp loopOp,
181+ mlir::ConversionPatternRewriter &rewriter) const {
182+ rewriteToSingleWrapperOp<mlir::omp::DistributeOp,
183+ mlir::omp::DistributeOperands>(loopOp, rewriter);
184+ }
185+
186+ void rewriteToWsloop (mlir::omp::LoopOp loopOp,
187+ mlir::ConversionPatternRewriter &rewriter) const {
188+ rewriteToSingleWrapperOp<mlir::omp::WsloopOp, mlir::omp::WsloopOperands>(
189+ loopOp, rewriter);
190+ }
191+
192+ template <typename OpTy, typename OpOperandsTy>
193+ void
194+ rewriteToSingleWrapperOp (mlir::omp::LoopOp loopOp,
195+ mlir::ConversionPatternRewriter &rewriter) const {
196+ OpOperandsTy distributeClauseOps;
197+ distributeClauseOps.privateVars = loopOp.getPrivateVars ();
154198
155199 auto privateSyms = loopOp.getPrivateSyms ();
156200 if (privateSyms)
157- simdClauseOps .privateSyms .assign (privateSyms->begin (),
158- privateSyms->end ());
201+ distributeClauseOps .privateSyms .assign (privateSyms->begin (),
202+ privateSyms->end ());
159203
160- Fortran::common::openmp::EntryBlockArgs simdArgs ;
161- simdArgs .priv .vars = simdClauseOps .privateVars ;
204+ Fortran::common::openmp::EntryBlockArgs distributeArgs ;
205+ distributeArgs .priv .vars = distributeClauseOps .privateVars ;
162206
163- auto simdOp =
164- rewriter.create <mlir::omp::SimdOp >(loopOp.getLoc (), simdClauseOps );
165- mlir::Block *simdBlock =
166- genEntryBlock (rewriter, simdArgs, simdOp .getRegion ());
207+ auto distributeOp =
208+ rewriter.create <OpTy >(loopOp.getLoc (), distributeClauseOps );
209+ mlir::Block *distributeBlock =
210+ genEntryBlock (rewriter, distributeArgs, distributeOp .getRegion ());
167211
168212 mlir::IRMapping mapper;
169213 mlir::Block &loopBlock = *loopOp.getRegion ().begin ();
170214
171- for (auto [loopOpArg, simdopArg ] :
172- llvm::zip_equal ( loopBlock.getArguments (), simdBlock ->getArguments ()))
173- mapper.map (loopOpArg, simdopArg );
215+ for (auto [loopOpArg, distributeOpArg ] : llvm::zip_equal (
216+ loopBlock.getArguments (), distributeBlock ->getArguments ()))
217+ mapper.map (loopOpArg, distributeOpArg );
174218
175219 rewriter.clone (*loopOp.begin (), mapper);
176220 }
0 commit comments