@@ -672,8 +672,10 @@ static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region ®ion,
672672 return parser.parseRegion (region, entryBlockArgs);
673673}
674674
675- static ParseResult parseInReductionMapPrivateRegion (
675+ static ParseResult parseHostEvalInReductionMapPrivateRegion (
676676 OpAsmParser &parser, Region ®ion,
677+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &hostEvalVars,
678+ SmallVectorImpl<Type> &hostEvalTypes,
677679 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
678680 SmallVectorImpl<Type> &inReductionTypes,
679681 DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
@@ -682,6 +684,7 @@ static ParseResult parseInReductionMapPrivateRegion(
682684 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
683685 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
684686 AllRegionParseArgs args;
687+ args.hostEvalArgs .emplace (hostEvalVars, hostEvalTypes);
685688 args.inReductionArgs .emplace (inReductionVars, inReductionTypes,
686689 inReductionByref, inReductionSyms);
687690 args.mapArgs .emplace (mapVars, mapTypes);
@@ -896,12 +899,14 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region ®ion,
896899 p.printRegion (region, /* printEntryBlockArgs=*/ false );
897900}
898901
899- static void printInReductionMapPrivateRegion (
900- OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange inReductionVars,
902+ static void printHostEvalInReductionMapPrivateRegion (
903+ OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange hostEvalVars,
904+ TypeRange hostEvalTypes, ValueRange inReductionVars,
901905 TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
902906 ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes,
903907 ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms) {
904908 AllRegionPrintArgs args;
909+ args.hostEvalArgs .emplace (hostEvalVars, hostEvalTypes);
905910 args.inReductionArgs .emplace (inReductionVars, inReductionTypes,
906911 inReductionByref, inReductionSyms);
907912 args.mapArgs .emplace (mapVars, mapTypes);
@@ -1685,7 +1690,8 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
16851690 // inReductionByref, inReductionSyms.
16861691 TargetOp::build (builder, state, /* allocate_vars=*/ {}, /* allocator_vars=*/ {},
16871692 makeArrayAttr (ctx, clauses.dependKinds ), clauses.dependVars ,
1688- clauses.device , clauses.hasDeviceAddrVars , clauses.ifExpr ,
1693+ clauses.device , clauses.hasDeviceAddrVars ,
1694+ clauses.hostEvalVars , clauses.ifExpr ,
16891695 /* in_reduction_vars=*/ {}, /* in_reduction_byref=*/ nullptr ,
16901696 /* in_reduction_syms=*/ nullptr , clauses.isDevicePtrVars ,
16911697 clauses.mapVars , clauses.nowait , clauses.privateVars ,
@@ -1699,6 +1705,159 @@ LogicalResult TargetOp::verify() {
16991705 : verifyMapClause (*this , getMapVars ());
17001706}
17011707
1708+ LogicalResult TargetOp::verifyRegions () {
1709+ auto teamsOps = getOps<TeamsOp>();
1710+ if (std::distance (teamsOps.begin (), teamsOps.end ()) > 1 )
1711+ return emitError (" target containing multiple 'omp.teams' nested ops" );
1712+
1713+ // Check that host_eval values are only used in legal ways.
1714+ bool isTargetSPMD = isTargetSPMDLoop ();
1715+ for (Value hostEvalArg :
1716+ cast<BlockArgOpenMPOpInterface>(getOperation ()).getHostEvalBlockArgs ()) {
1717+ for (Operation *user : hostEvalArg.getUsers ()) {
1718+ if (auto teamsOp = dyn_cast<TeamsOp>(user)) {
1719+ if (llvm::is_contained ({teamsOp.getNumTeamsLower (),
1720+ teamsOp.getNumTeamsUpper (),
1721+ teamsOp.getThreadLimit ()},
1722+ hostEvalArg))
1723+ continue ;
1724+
1725+ return emitOpError () << " host_eval argument only legal as 'num_teams' "
1726+ " and 'thread_limit' in 'omp.teams'" ;
1727+ }
1728+ if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
1729+ if (isTargetSPMD && hostEvalArg == parallelOp.getNumThreads ())
1730+ continue ;
1731+
1732+ return emitOpError ()
1733+ << " host_eval argument only legal as 'num_threads' in "
1734+ " 'omp.parallel' when representing target SPMD" ;
1735+ }
1736+ if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
1737+ if (isTargetSPMD &&
1738+ (llvm::is_contained (loopNestOp.getLoopLowerBounds (), hostEvalArg) ||
1739+ llvm::is_contained (loopNestOp.getLoopUpperBounds (), hostEvalArg) ||
1740+ llvm::is_contained (loopNestOp.getLoopSteps (), hostEvalArg)))
1741+ continue ;
1742+
1743+ return emitOpError ()
1744+ << " host_eval argument only legal as loop bounds and steps in "
1745+ " 'omp.loop_nest' when representing target SPMD" ;
1746+ }
1747+
1748+ return emitOpError () << " host_eval argument illegal use in '"
1749+ << user->getName () << " ' operation" ;
1750+ }
1751+ }
1752+ return success ();
1753+ }
1754+
1755+ // / Only allow OpenMP terminators and non-OpenMP ops that have known memory
1756+ // / effects, but don't include a memory write effect.
1757+ static bool siblingAllowedInCapture (Operation *op) {
1758+ if (!op)
1759+ return false ;
1760+
1761+ bool isOmpDialect =
1762+ op->getContext ()->getLoadedDialect <omp::OpenMPDialect>() ==
1763+ op->getDialect ();
1764+
1765+ if (isOmpDialect)
1766+ return op->hasTrait <OpTrait::IsTerminator>();
1767+
1768+ if (auto memOp = dyn_cast<MemoryEffectOpInterface>(op)) {
1769+ SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4 > effects;
1770+ memOp.getEffects (effects);
1771+ return !llvm::any_of (effects, [&](MemoryEffects::EffectInstance &effect) {
1772+ return isa<MemoryEffects::Write>(effect.getEffect ()) &&
1773+ isa<SideEffects::AutomaticAllocationScopeResource>(
1774+ effect.getResource ());
1775+ });
1776+ }
1777+ return true ;
1778+ }
1779+
1780+ Operation *TargetOp::getInnermostCapturedOmpOp () {
1781+ Dialect *ompDialect = (*this )->getDialect ();
1782+ Operation *capturedOp = nullptr ;
1783+
1784+ // Process in pre-order to check operations from outermost to innermost,
1785+ // ensuring we only enter the region of an operation if it meets the criteria
1786+ // for being captured. We stop the exploration of nested operations as soon as
1787+ // we process a region holding no operations to be captured.
1788+ walk<WalkOrder::PreOrder>([&](Operation *op) {
1789+ if (op == *this )
1790+ return WalkResult::advance ();
1791+
1792+ // Ignore operations of other dialects or omp operations with no regions,
1793+ // because these will only be checked if they are siblings of an omp
1794+ // operation that can potentially be captured.
1795+ bool isOmpDialect = op->getDialect () == ompDialect;
1796+ bool hasRegions = op->getNumRegions () > 0 ;
1797+ if (!isOmpDialect || !hasRegions)
1798+ return WalkResult::skip ();
1799+
1800+ // Don't capture this op if it has a not-allowed sibling, and stop recursing
1801+ // into nested operations.
1802+ for (Operation &sibling : op->getParentRegion ()->getOps ())
1803+ if (&sibling != op && !siblingAllowedInCapture (&sibling))
1804+ return WalkResult::interrupt ();
1805+
1806+ // Don't continue capturing nested operations if we reach an omp.loop_nest.
1807+ // Otherwise, process the contents of this operation.
1808+ capturedOp = op;
1809+ return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt ()
1810+ : WalkResult::advance ();
1811+ });
1812+
1813+ return capturedOp;
1814+ }
1815+
1816+ bool TargetOp::isTargetSPMDLoop () {
1817+ // The expected MLIR representation for a target SPMD loop is:
1818+ // omp.target {
1819+ // omp.teams {
1820+ // omp.parallel {
1821+ // omp.distribute {
1822+ // omp.wsloop {
1823+ // omp.loop_nest ... { ... }
1824+ // } {omp.composite}
1825+ // } {omp.composite}
1826+ // omp.terminator
1827+ // } {omp.composite}
1828+ // omp.terminator
1829+ // }
1830+ // omp.terminator
1831+ // }
1832+
1833+ Operation *capturedOp = getInnermostCapturedOmpOp ();
1834+ if (!isa_and_present<LoopNestOp>(capturedOp))
1835+ return false ;
1836+
1837+ Operation *workshareOp = capturedOp->getParentOp ();
1838+
1839+ // Accept an optional omp.simd loop wrapper as part of the SPMD pattern.
1840+ if (isa_and_present<SimdOp>(workshareOp))
1841+ workshareOp = workshareOp->getParentOp ();
1842+
1843+ if (!isa_and_present<WsloopOp>(workshareOp))
1844+ return false ;
1845+
1846+ Operation *distributeOp = workshareOp->getParentOp ();
1847+ if (!isa_and_present<DistributeOp>(distributeOp))
1848+ return false ;
1849+
1850+ Operation *parallelOp = distributeOp->getParentOp ();
1851+ if (!isa_and_present<ParallelOp>(parallelOp))
1852+ return false ;
1853+
1854+ Operation *teamsOp = parallelOp->getParentOp ();
1855+ if (!isa_and_present<TeamsOp>(teamsOp))
1856+ return false ;
1857+
1858+ return teamsOp->getParentOp () == (*this );
1859+ }
1860+
17021861// ===----------------------------------------------------------------------===//
17031862// ParallelOp
17041863// ===----------------------------------------------------------------------===//
0 commit comments