@@ -430,68 +430,102 @@ static void printScheduleClause(OpAsmPrinter &p, Operation *op,
430430// Parser, printer and verifier for ReductionVarList
431431// ===----------------------------------------------------------------------===//
432432
433- ParseResult
434- parseReductionClause (OpAsmParser &parser, Region ®ion,
435- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
436- SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols,
437- SmallVectorImpl<OpAsmParser::Argument> &privates) {
438- if (failed (parser.parseOptionalKeyword (" reduction" )))
439- return failure ();
440-
433+ ParseResult parseClauseWithRegionArgs (
434+ OpAsmParser &parser, Region ®ion,
435+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
436+ SmallVectorImpl<Type> &types, ArrayAttr &symbols,
437+ SmallVectorImpl<OpAsmParser::Argument> ®ionPrivateArgs) {
441438 SmallVector<SymbolRefAttr> reductionVec;
439+ unsigned regionArgOffset = regionPrivateArgs.size ();
442440
443441 if (failed (
444442 parser.parseCommaSeparatedList (OpAsmParser::Delimiter::Paren, [&]() {
445443 if (parser.parseAttribute (reductionVec.emplace_back ()) ||
446444 parser.parseOperand (operands.emplace_back ()) ||
447445 parser.parseArrow () ||
448- parser.parseArgument (privates .emplace_back ()) ||
446+ parser.parseArgument (regionPrivateArgs .emplace_back ()) ||
449447 parser.parseColonType (types.emplace_back ()))
450448 return failure ();
451449 return success ();
452450 })))
453451 return failure ();
454452
455- for (auto [prv, type] : llvm::zip_equal (privates, types)) {
453+ auto *argsBegin = regionPrivateArgs.begin ();
454+ MutableArrayRef argsSubrange (argsBegin + regionArgOffset,
455+ argsBegin + regionArgOffset + types.size ());
456+ for (auto [prv, type] : llvm::zip_equal (argsSubrange, types)) {
456457 prv.type = type;
457458 }
458459 SmallVector<Attribute> reductions (reductionVec.begin (), reductionVec.end ());
459- reductionSymbols = ArrayAttr::get (parser.getContext (), reductions);
460+ symbols = ArrayAttr::get (parser.getContext (), reductions);
460461 return success ();
461462}
462463
463- static void printReductionClause (OpAsmPrinter &p, Operation *op,
464- ValueRange reductionArgs, ValueRange operands,
465- TypeRange types, ArrayAttr reductionSymbols) {
466- p << " reduction(" ;
464+ static void printClauseWithRegionArgs (OpAsmPrinter &p, Operation *op,
465+ ValueRange argsSubrange,
466+ StringRef clauseName, ValueRange operands,
467+ TypeRange types, ArrayAttr symbols) {
468+ p << clauseName << " (" ;
467469 llvm::interleaveComma (
468- llvm::zip_equal (reductionSymbols, operands, reductionArgs, types), p,
469- [&p](auto t) {
470+ llvm::zip_equal (symbols, operands, argsSubrange, types), p, [&p](auto t) {
470471 auto [sym, op, arg, type] = t;
471472 p << sym << " " << op << " -> " << arg << " : " << type;
472473 });
473474 p << " ) " ;
474475}
475476
476- static ParseResult
477- parseParallelRegion (OpAsmParser &parser, Region ®ion,
478- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
479- SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols) {
477+ static ParseResult parseParallelRegion (
478+ OpAsmParser &parser, Region ®ion,
479+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVarOperands,
480+ SmallVectorImpl<Type> &reductionVarTypes, ArrayAttr &reductionSymbols,
481+ llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVarOperands,
482+ llvm::SmallVectorImpl<Type> &privateVarsTypes,
483+ ArrayAttr &privatizerSymbols) {
484+ llvm::SmallVector<OpAsmParser::Argument> regionPrivateArgs;
480485
481- llvm::SmallVector<OpAsmParser::Argument> privates;
482- if (succeeded (parseReductionClause (parser, region, operands, types,
483- reductionSymbols, privates)))
484- return parser.parseRegion (region, privates);
486+ if (succeeded (parser.parseOptionalKeyword (" reduction" ))) {
487+ if (failed (parseClauseWithRegionArgs (parser, region, reductionVarOperands,
488+ reductionVarTypes, reductionSymbols,
489+ regionPrivateArgs)))
490+ return failure ();
491+ }
485492
486- return parser.parseRegion (region);
493+ if (succeeded (parser.parseOptionalKeyword (" private" ))) {
494+ if (failed (parseClauseWithRegionArgs (parser, region, privateVarOperands,
495+ privateVarsTypes, privatizerSymbols,
496+ regionPrivateArgs)))
497+ return failure ();
498+ }
499+
500+ return parser.parseRegion (region, regionPrivateArgs);
487501}
488502
489503static void printParallelRegion (OpAsmPrinter &p, Operation *op, Region ®ion,
490- ValueRange operands, TypeRange types,
491- ArrayAttr reductionSymbols) {
492- if (reductionSymbols)
493- printReductionClause (p, op, region.front ().getArguments (), operands, types,
494- reductionSymbols);
504+ ValueRange reductionVarOperands,
505+ TypeRange reductionVarTypes,
506+ ArrayAttr reductionSymbols,
507+ ValueRange privateVarOperands,
508+ TypeRange privateVarTypes,
509+ ArrayAttr privatizerSymbols) {
510+ if (reductionSymbols) {
511+ auto *argsBegin = region.front ().getArguments ().begin ();
512+ MutableArrayRef argsSubrange (argsBegin,
513+ argsBegin + reductionVarTypes.size ());
514+ printClauseWithRegionArgs (p, op, argsSubrange, " reduction" ,
515+ reductionVarOperands, reductionVarTypes,
516+ reductionSymbols);
517+ }
518+
519+ if (privatizerSymbols) {
520+ auto *argsBegin = region.front ().getArguments ().begin ();
521+ MutableArrayRef argsSubrange (argsBegin + reductionVarOperands.size (),
522+ argsBegin + reductionVarOperands.size () +
523+ privateVarTypes.size ());
524+ printClauseWithRegionArgs (p, op, argsSubrange, " private" ,
525+ privateVarOperands, privateVarTypes,
526+ privatizerSymbols);
527+ }
528+
495529 p.printRegion (region, /* printEntryBlockArgs=*/ false );
496530}
497531
@@ -1008,9 +1042,8 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands) {
10081042 }
10091043
10101044 if (always || close || implicit) {
1011- return emitError (
1012- op->getLoc (),
1013- " present, mapper and iterator map type modifiers are permitted" );
1045+ return emitError (op->getLoc (), " present, mapper and iterator map "
1046+ " type modifiers are permitted" );
10141047 }
10151048
10161049 to ? updateToVars.insert (updateVar) : updateFromVars.insert (updateVar);
@@ -1070,14 +1103,63 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
10701103 builder, state, /* if_expr_var=*/ nullptr , /* num_threads_var=*/ nullptr ,
10711104 /* allocate_vars=*/ ValueRange (), /* allocators_vars=*/ ValueRange (),
10721105 /* reduction_vars=*/ ValueRange (), /* reductions=*/ nullptr ,
1073- /* proc_bind_val=*/ nullptr );
1106+ /* proc_bind_val=*/ nullptr , /* private_vars=*/ ValueRange (),
1107+ /* privatizers=*/ nullptr );
10741108 state.addAttributes (attributes);
10751109}
10761110
1111+ static LogicalResult verifyPrivateVarList (ParallelOp &op) {
1112+ auto privateVars = op.getPrivateVars ();
1113+ auto privatizers = op.getPrivatizersAttr ();
1114+
1115+ if (privateVars.empty () && (privatizers == nullptr || privatizers.empty ()))
1116+ return success ();
1117+
1118+ auto numPrivateVars = privateVars.size ();
1119+ auto numPrivatizers = (privatizers == nullptr ) ? 0 : privatizers.size ();
1120+
1121+ if (numPrivateVars != numPrivatizers)
1122+ return op.emitError () << " inconsistent number of private variables and "
1123+ " privatizer op symbols, private vars: "
1124+ << numPrivateVars
1125+ << " vs. privatizer op symbols: " << numPrivatizers;
1126+
1127+ for (auto privateVarInfo : llvm::zip (privateVars, privatizers)) {
1128+ Type varType = std::get<0 >(privateVarInfo).getType ();
1129+ SymbolRefAttr privatizerSym =
1130+ std::get<1 >(privateVarInfo).cast <SymbolRefAttr>();
1131+ PrivateClauseOp privatizerOp =
1132+ SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op,
1133+ privatizerSym);
1134+
1135+ if (privatizerOp == nullptr )
1136+ return op.emitError () << " failed to lookup privatizer op with symbol: '"
1137+ << privatizerSym << " '" ;
1138+
1139+ Type privatizerType = privatizerOp.getType ();
1140+
1141+ if (varType != privatizerType)
1142+ return op.emitError ()
1143+ << " type mismatch between a "
1144+ << (privatizerOp.getDataSharingType () ==
1145+ DataSharingClauseType::Private
1146+ ? " private"
1147+ : " firstprivate" )
1148+ << " variable and its privatizer op, var type: " << varType
1149+ << " vs. privatizer op type: " << privatizerType;
1150+ }
1151+
1152+ return success ();
1153+ }
1154+
10771155LogicalResult ParallelOp::verify () {
10781156 if (getAllocateVars ().size () != getAllocatorsVars ().size ())
10791157 return emitError (
10801158 " expected equal sizes for allocate and allocator variables" );
1159+
1160+ if (failed (verifyPrivateVarList (*this )))
1161+ return failure ();
1162+
10811163 return verifyReductionVarList (*this , getReductions (), getReductionVars ());
10821164}
10831165
@@ -1111,8 +1193,8 @@ LogicalResult TeamsOp::verify() {
11111193 return emitError (" expected num_teams upper bound to be defined if the "
11121194 " lower bound is defined" );
11131195 if (numTeamsLowerBound.getType () != numTeamsUpperBound.getType ())
1114- return emitError (
1115- " expected num_teams upper bound and lower bound to be the same type" );
1196+ return emitError (" expected num_teams upper bound and lower bound to be "
1197+ " the same type" );
11161198 }
11171199
11181200 // Check for allocate clause restrictions
@@ -1174,9 +1256,10 @@ parseWsLoop(OpAsmParser &parser, Region ®ion,
11741256
11751257 // Parse an optional reduction clause
11761258 llvm::SmallVector<OpAsmParser::Argument> privates;
1177- bool hasReduction = succeeded (
1178- parseReductionClause (parser, region, reductionOperands, reductionTypes,
1179- reductionSymbols, privates));
1259+ bool hasReduction = succeeded (parser.parseOptionalKeyword (" reduction" )) &&
1260+ succeeded (parseClauseWithRegionArgs (
1261+ parser, region, reductionOperands, reductionTypes,
1262+ reductionSymbols, privates));
11801263
11811264 if (parser.parseKeyword (" for" ))
11821265 return failure ();
@@ -1223,8 +1306,9 @@ void printWsLoop(OpAsmPrinter &p, Operation *op, Region ®ion,
12231306 if (reductionSymbols) {
12241307 auto reductionArgs =
12251308 region.front ().getArguments ().drop_front (loopVarTypes.size ());
1226- printReductionClause (p, op, reductionArgs, reductionOperands,
1227- reductionTypes, reductionSymbols);
1309+ printClauseWithRegionArgs (p, op, reductionArgs, " reduction" ,
1310+ reductionOperands, reductionTypes,
1311+ reductionSymbols);
12281312 }
12291313
12301314 p << " for " ;
@@ -1464,9 +1548,9 @@ LogicalResult TaskLoopOp::verify() {
14641548 }
14651549
14661550 if (getGrainSize () && getNumTasks ()) {
1467- return emitError (
1468- " the grainsize clause and num_tasks clause are mutually exclusive and "
1469- " may not appear on the same taskloop directive" );
1551+ return emitError (" the grainsize clause and num_tasks clause are mutually "
1552+ " exclusive and "
1553+ " may not appear on the same taskloop directive" );
14701554 }
14711555 return success ();
14721556}
@@ -1535,7 +1619,8 @@ LogicalResult OrderedOp::verify() {
15351619}
15361620
15371621LogicalResult OrderedRegionOp::verify () {
1538- // TODO: The code generation for ordered simd directive is not supported yet.
1622+ // TODO: The code generation for ordered simd directive is not supported
1623+ // yet.
15391624 if (getSimd ())
15401625 return failure ();
15411626
0 commit comments