@@ -89,7 +89,7 @@ struct LoopPipelinerInternal {
8989 bool initializeLoopInfo (ForOp op, const triton::PipeliningOption &options);
9090 // / Emits the prologue, this creates `maxStage - 1` part which will contain
9191 // / operations from stages [0; i], where i is the part index.
92- void emitPrologue (RewriterBase &rewriter);
92+ LogicalResult emitPrologue (RewriterBase &rewriter);
9393 // / Gather liverange information for Values that are used in a different stage
9494 // / than its definition.
9595 llvm::MapVector<Value, LiverangeInfo> analyzeCrossStageValues ();
@@ -275,7 +275,7 @@ cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op,
275275 return clone;
276276}
277277
278- void LoopPipelinerInternal::emitPrologue (RewriterBase &rewriter) {
278+ LogicalResult LoopPipelinerInternal::emitPrologue (RewriterBase &rewriter) {
279279 // Initialize the iteration argument to the loop initiale values.
280280 for (auto [arg, operand] :
281281 llvm::zip (forOp.getRegionIterArgs (), forOp.getInitsMutable ())) {
@@ -323,7 +323,8 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
323323 if (predicates[predicateIdx]) {
324324 OpBuilder::InsertionGuard insertGuard (rewriter);
325325 newOp = predicateFn (rewriter, newOp, predicates[predicateIdx]);
326- assert (newOp && " failed to predicate op." );
326+ if (newOp == nullptr )
327+ return failure ();
327328 }
328329 if (annotateFn)
329330 annotateFn (newOp, triton::PipeliningOption::PipelinerPart::Prologue, i);
@@ -351,6 +352,7 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
351352 }
352353 }
353354 }
355+ return success ();
354356}
355357
356358llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
@@ -787,7 +789,8 @@ mlir::triton::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
787789 *modifiedIR = true ;
788790
789791 // 1. Emit prologue.
790- pipeliner.emitPrologue (rewriter);
792+ if (failed (pipeliner.emitPrologue (rewriter)))
793+ return failure ();
791794
792795 // 2. Track values used across stages. When a value cross stages it will
793796 // need to be passed as loop iteration arguments.
0 commit comments