3131#include " llvm/ADT/StringRef.h"
3232#include " llvm/ADT/TypeSwitch.h"
3333#include " llvm/Frontend/OpenMP/OMPConstants.h"
34+ #include " llvm/Frontend/OpenMP/OMPDeviceConstants.h"
3435#include < cstddef>
3536#include < iterator>
3637#include < optional>
@@ -691,8 +692,10 @@ static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region ®ion,
691692 return parser.parseRegion (region, entryBlockArgs);
692693}
693694
694- static ParseResult parseInReductionMapPrivateRegion (
695+ static ParseResult parseHostEvalInReductionMapPrivateRegion (
695696 OpAsmParser &parser, Region ®ion,
697+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &hostEvalVars,
698+ SmallVectorImpl<Type> &hostEvalTypes,
696699 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
697700 SmallVectorImpl<Type> &inReductionTypes,
698701 DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
@@ -702,6 +705,7 @@ static ParseResult parseInReductionMapPrivateRegion(
702705 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
703706 DenseI64ArrayAttr &privateMaps) {
704707 AllRegionParseArgs args;
708+ args.hostEvalArgs .emplace (hostEvalVars, hostEvalTypes);
705709 args.inReductionArgs .emplace (inReductionVars, inReductionTypes,
706710 inReductionByref, inReductionSyms);
707711 args.mapArgs .emplace (mapVars, mapTypes);
@@ -931,13 +935,15 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region ®ion,
931935 p.printRegion (region, /* printEntryBlockArgs=*/ false );
932936}
933937
934- static void printInReductionMapPrivateRegion (
935- OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange inReductionVars,
938+ static void printHostEvalInReductionMapPrivateRegion (
939+ OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange hostEvalVars,
940+ TypeRange hostEvalTypes, ValueRange inReductionVars,
936941 TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
937942 ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes,
938943 ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms,
939944 DenseI64ArrayAttr privateMaps) {
940945 AllRegionPrintArgs args;
946+ args.hostEvalArgs .emplace (hostEvalVars, hostEvalTypes);
941947 args.inReductionArgs .emplace (inReductionVars, inReductionTypes,
942948 inReductionByref, inReductionSyms);
943949 args.mapArgs .emplace (mapVars, mapTypes);
@@ -1720,11 +1726,12 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
17201726 TargetOp::build (builder, state, /* allocate_vars=*/ {}, /* allocator_vars=*/ {},
17211727 clauses.bare , makeArrayAttr (ctx, clauses.dependKinds ),
17221728 clauses.dependVars , clauses.device , clauses.hasDeviceAddrVars ,
1723- clauses.ifExpr , /* in_reduction_vars=*/ {},
1724- /* in_reduction_byref=*/ nullptr , /* in_reduction_syms=*/ nullptr ,
1725- clauses.isDevicePtrVars , clauses.mapVars , clauses.nowait ,
1726- clauses.privateVars , makeArrayAttr (ctx, clauses.privateSyms ),
1727- clauses.threadLimit , /* private_maps=*/ nullptr );
1729+ clauses.hostEvalVars , clauses.ifExpr ,
1730+ /* in_reduction_vars=*/ {}, /* in_reduction_byref=*/ nullptr ,
1731+ /* in_reduction_syms=*/ nullptr , clauses.isDevicePtrVars ,
1732+ clauses.mapVars , clauses.nowait , clauses.privateVars ,
1733+ makeArrayAttr (ctx, clauses.privateSyms ), clauses.threadLimit ,
1734+ /* private_maps=*/ nullptr );
17281735}
17291736
17301737LogicalResult TargetOp::verify () {
@@ -1742,6 +1749,189 @@ LogicalResult TargetOp::verify() {
17421749 return verifyPrivateVarsMapping (*this );
17431750}
17441751
1752+ LogicalResult TargetOp::verifyRegions () {
1753+ auto teamsOps = getOps<TeamsOp>();
1754+ if (std::distance (teamsOps.begin (), teamsOps.end ()) > 1 )
1755+ return emitError (" target containing multiple 'omp.teams' nested ops" );
1756+
1757+ // Check that host_eval values are only used in legal ways.
1758+ llvm::omp::OMPTgtExecModeFlags execFlags = getKernelExecFlags ();
1759+ for (Value hostEvalArg :
1760+ cast<BlockArgOpenMPOpInterface>(getOperation ()).getHostEvalBlockArgs ()) {
1761+ for (Operation *user : hostEvalArg.getUsers ()) {
1762+ if (auto teamsOp = dyn_cast<TeamsOp>(user)) {
1763+ if (llvm::is_contained ({teamsOp.getNumTeamsLower (),
1764+ teamsOp.getNumTeamsUpper (),
1765+ teamsOp.getThreadLimit ()},
1766+ hostEvalArg))
1767+ continue ;
1768+
1769+ return emitOpError () << " host_eval argument only legal as 'num_teams' "
1770+ " and 'thread_limit' in 'omp.teams'" ;
1771+ }
1772+ if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
1773+ if (execFlags == llvm::omp::OMP_TGT_EXEC_MODE_SPMD &&
1774+ hostEvalArg == parallelOp.getNumThreads ())
1775+ continue ;
1776+
1777+ return emitOpError ()
1778+ << " host_eval argument only legal as 'num_threads' in "
1779+ " 'omp.parallel' when representing target SPMD" ;
1780+ }
1781+ if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
1782+ if (execFlags != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC &&
1783+ (llvm::is_contained (loopNestOp.getLoopLowerBounds (), hostEvalArg) ||
1784+ llvm::is_contained (loopNestOp.getLoopUpperBounds (), hostEvalArg) ||
1785+ llvm::is_contained (loopNestOp.getLoopSteps (), hostEvalArg)))
1786+ continue ;
1787+
1788+ return emitOpError () << " host_eval argument only legal as loop bounds "
1789+ " and steps in 'omp.loop_nest' when "
1790+ " representing target SPMD or Generic-SPMD" ;
1791+ }
1792+
1793+ return emitOpError () << " host_eval argument illegal use in '"
1794+ << user->getName () << " ' operation" ;
1795+ }
1796+ }
1797+ return success ();
1798+ }
1799+
1800+ // / Only allow OpenMP terminators and non-OpenMP ops that have known memory
1801+ // / effects, but don't include a memory write effect.
1802+ static bool siblingAllowedInCapture (Operation *op) {
1803+ if (!op)
1804+ return false ;
1805+
1806+ bool isOmpDialect =
1807+ op->getContext ()->getLoadedDialect <omp::OpenMPDialect>() ==
1808+ op->getDialect ();
1809+
1810+ if (isOmpDialect)
1811+ return op->hasTrait <OpTrait::IsTerminator>();
1812+
1813+ if (auto memOp = dyn_cast<MemoryEffectOpInterface>(op)) {
1814+ SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4 > effects;
1815+ memOp.getEffects (effects);
1816+ return !llvm::any_of (effects, [&](MemoryEffects::EffectInstance &effect) {
1817+ return isa<MemoryEffects::Write>(effect.getEffect ()) &&
1818+ isa<SideEffects::AutomaticAllocationScopeResource>(
1819+ effect.getResource ());
1820+ });
1821+ }
1822+ return true ;
1823+ }
1824+
1825+ Operation *TargetOp::getInnermostCapturedOmpOp () {
1826+ Dialect *ompDialect = (*this )->getDialect ();
1827+ Operation *capturedOp = nullptr ;
1828+ DominanceInfo domInfo;
1829+
1830+ // Process in pre-order to check operations from outermost to innermost,
1831+ // ensuring we only enter the region of an operation if it meets the criteria
1832+ // for being captured. We stop the exploration of nested operations as soon as
1833+ // we process a region holding no operations to be captured.
1834+ walk<WalkOrder::PreOrder>([&](Operation *op) {
1835+ if (op == *this )
1836+ return WalkResult::advance ();
1837+
1838+ // Ignore operations of other dialects or omp operations with no regions,
1839+ // because these will only be checked if they are siblings of an omp
1840+ // operation that can potentially be captured.
1841+ bool isOmpDialect = op->getDialect () == ompDialect;
1842+ bool hasRegions = op->getNumRegions () > 0 ;
1843+ if (!isOmpDialect || !hasRegions)
1844+ return WalkResult::skip ();
1845+
1846+ // This operation cannot be captured if it can be executed more than once
1847+ // (i.e. its block's successors can reach it) or if it's not guaranteed to
1848+ // be executed before all exits of the region (i.e. it doesn't dominate all
1849+ // blocks with no successors reachable from the entry block).
1850+ Region *parentRegion = op->getParentRegion ();
1851+ Block *parentBlock = op->getBlock ();
1852+
1853+ for (Block *successor : parentBlock->getSuccessors ())
1854+ if (successor->isReachable (parentBlock))
1855+ return WalkResult::interrupt ();
1856+
1857+ for (Block &block : *parentRegion)
1858+ if (domInfo.isReachableFromEntry (&block) && block.hasNoSuccessors () &&
1859+ !domInfo.dominates (parentBlock, &block))
1860+ return WalkResult::interrupt ();
1861+
1862+ // Don't capture this op if it has a not-allowed sibling, and stop recursing
1863+ // into nested operations.
1864+ for (Operation &sibling : op->getParentRegion ()->getOps ())
1865+ if (&sibling != op && !siblingAllowedInCapture (&sibling))
1866+ return WalkResult::interrupt ();
1867+
1868+ // Don't continue capturing nested operations if we reach an omp.loop_nest.
1869+ // Otherwise, process the contents of this operation.
1870+ capturedOp = op;
1871+ return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt ()
1872+ : WalkResult::advance ();
1873+ });
1874+
1875+ return capturedOp;
1876+ }
1877+
1878+ llvm::omp::OMPTgtExecModeFlags TargetOp::getKernelExecFlags () {
1879+ using namespace llvm ::omp;
1880+
1881+ // Make sure this region is capturing a loop. Otherwise, it's a generic
1882+ // kernel.
1883+ Operation *capturedOp = getInnermostCapturedOmpOp ();
1884+ if (!isa_and_present<LoopNestOp>(capturedOp))
1885+ return OMP_TGT_EXEC_MODE_GENERIC;
1886+
1887+ SmallVector<LoopWrapperInterface> wrappers;
1888+ cast<LoopNestOp>(capturedOp).gatherWrappers (wrappers);
1889+ assert (!wrappers.empty ());
1890+
1891+ // Ignore optional SIMD leaf construct.
1892+ auto *innermostWrapper = wrappers.begin ();
1893+ if (isa<SimdOp>(innermostWrapper))
1894+ innermostWrapper = std::next (innermostWrapper);
1895+
1896+ long numWrappers = std::distance (innermostWrapper, wrappers.end ());
1897+
1898+ // Detect Generic-SPMD: target-teams-distribute[-simd].
1899+ if (numWrappers == 1 ) {
1900+ if (!isa<DistributeOp>(innermostWrapper))
1901+ return OMP_TGT_EXEC_MODE_GENERIC;
1902+
1903+ Operation *teamsOp = (*innermostWrapper)->getParentOp ();
1904+ if (!isa_and_present<TeamsOp>(teamsOp))
1905+ return OMP_TGT_EXEC_MODE_GENERIC;
1906+
1907+ if (teamsOp->getParentOp () == *this )
1908+ return OMP_TGT_EXEC_MODE_GENERIC_SPMD;
1909+ }
1910+
1911+ // Detect SPMD: target-teams-distribute-parallel-wsloop[-simd].
1912+ if (numWrappers == 2 ) {
1913+ if (!isa<WsloopOp>(innermostWrapper))
1914+ return OMP_TGT_EXEC_MODE_GENERIC;
1915+
1916+ innermostWrapper = std::next (innermostWrapper);
1917+ if (!isa<DistributeOp>(innermostWrapper))
1918+ return OMP_TGT_EXEC_MODE_GENERIC;
1919+
1920+ Operation *parallelOp = (*innermostWrapper)->getParentOp ();
1921+ if (!isa_and_present<ParallelOp>(parallelOp))
1922+ return OMP_TGT_EXEC_MODE_GENERIC;
1923+
1924+ Operation *teamsOp = parallelOp->getParentOp ();
1925+ if (!isa_and_present<TeamsOp>(teamsOp))
1926+ return OMP_TGT_EXEC_MODE_GENERIC;
1927+
1928+ if (teamsOp->getParentOp () == *this )
1929+ return OMP_TGT_EXEC_MODE_SPMD;
1930+ }
1931+
1932+ return OMP_TGT_EXEC_MODE_GENERIC;
1933+ }
1934+
17451935// ===----------------------------------------------------------------------===//
17461936// ParallelOp
17471937// ===----------------------------------------------------------------------===//
0 commit comments