11
22#include " Utils/CodegenUtils.h"
3+ #include " Utils/LoopEmitter.h"
34#include " Utils/SparseTensorIterator.h"
45
56#include " mlir/Dialect/MemRef/IR/MemRef.h"
@@ -49,6 +50,144 @@ convertIteratorType(IteratorType itTp, SmallVectorImpl<Type> &fields) {
4950 return success ();
5051}
5152
/// Generates a nest of scf.if branches, one per sub-case region of a
/// `sparse_tensor.coiterate` op, dispatching to the first case whose
/// iterators are all positioned at the current loop coordinate `loopCrd`.
/// Each recursion level emits one scf.if; the else-branch holds the nest for
/// the remaining cases. `userReduc` carries the flattened user reduction
/// values threaded through every branch; the returned range is the scf.if
/// results (the updated reductions), or `userReduc` itself when `subCases`
/// is empty.
static ValueRange
genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
                       Value loopCrd,
                       ArrayRef<std::unique_ptr<SparseIterator>> iters,
                       ArrayRef<Region *> subCases, ArrayRef<Value> userReduc) {
  // Base case of the recursion: no more cases to emit.
  if (subCases.empty())
    return userReduc;

  // The current branch that we are handling.
  Region *b = subCases.front();
  // Build the branch predicate: this case fires iff every iterator named in
  // its bit set currently sits exactly at `loopCrd`.
  Value casePred = constantI1(rewriter, loc, true);
  I64BitSet caseBits = op.getRegionDefinedSpace(b->getRegionNumber());
  for (unsigned i : caseBits.bits()) {
    SparseIterator *it = iters[i].get();
    Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                it->getCrd(), loopCrd);
    casePred = rewriter.create<arith::AndIOp>(loc, casePred, pred);
  }
  // scf.if with an else region so the remaining cases have somewhere to go.
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(
      loc, ValueRange(userReduc).getTypes(), casePred, /*else=*/true);
  rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());

  // Erase the empty block; the cloned case region below provides the then
  // block instead.
  rewriter.eraseBlock(&ifOp.getThenRegion().front());
  // Set up block arguments: user-provided values -> loop coord -> iterators.
  SmallVector<Value> blockArgs(userReduc);
  blockArgs.push_back(loopCrd);
  for (unsigned idx : caseBits.bits())
    llvm::append_range(blockArgs, iters[idx]->getCursor());

  // Map the case region's block arguments onto the values computed here.
  IRMapping mapping;
  for (auto [from, to] :
       llvm::zip_equal(b->front().getArguments(), blockArgs)) {
    mapping.map(from, to);
  }

  // Clone the region; we can not erase the region now because the same region
  // might be a subcase for multiple lattice points.
  rewriter.cloneRegionBefore(*b, ifOp.getThenRegion(),
                             ifOp.getThenRegion().begin(), mapping);

  // Replace the cloned terminator: sparse_tensor::YieldOp -> scf::YieldOp.
  auto spY = cast<sparse_tensor::YieldOp>(&ifOp.getThenRegion().front().back());
  ValueRange yields = spY.getResults();
  rewriter.eraseOp(spY);
  rewriter.setInsertionPointToEnd(&ifOp.getThenRegion().front());
  rewriter.create<scf::YieldOp>(loc, yields);

  // Generates the remaining cases recursively inside the else branch.
  rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
  ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd, iters,
                                          subCases.drop_front(), userReduc);
  // `res` is empty only at the recursion base with no reductions; otherwise
  // forward the nested results out of the else branch.
  if (!res.empty())
    rewriter.create<scf::YieldOp>(loc, res);

  // Leave the insertion point after the whole nest for the caller.
  rewriter.setInsertionPointAfter(ifOp);
  return ifOp.getResults();
}
111+
/// Generates a loop driven by a single sparse iterator, delegating body
/// construction to `bodyBuilder`. Emits an scf.for when the iterator can be
/// advanced by a simple induction variable, and an scf.while otherwise.
/// `reduc` holds the user reduction values threaded through the loop;
/// `iterFirst` selects whether the iterator cursor values precede or follow
/// the reductions in the while-loop argument list. Returns only the final
/// user reduction values (cursor results are stripped).
static ValueRange genLoopWithIterator(
    PatternRewriter &rewriter, Location loc, SparseIterator *it,
    ValueRange reduc, bool iterFirst,
    function_ref<SmallVector<Value>(PatternRewriter &rewriter, Location loc,
                                    Region &loopBody, SparseIterator *it,
                                    ValueRange reduc)>
        bodyBuilder) {
  // Simple case: the iterator can be lowered to a unit-step scf.for.
  if (it->iteratableByFor()) {
    auto [lo, hi] = it->genForCond(rewriter, loc);
    Value step = constantIndex(rewriter, loc, 1);
    scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, reduc);
    {
      OpBuilder::InsertionGuard guard(rewriter);
      // Erase the implicit yield operation created by ForOp when there is no
      // yielding values.
      if (!forOp.getBody()->empty())
        rewriter.eraseOp(&forOp.getBody()->front());
      assert(forOp.getBody()->empty());

      // Rebind the iterator to the loop's induction variable, then let the
      // callback fill in the body.
      it->linkNewScope(forOp.getInductionVar());
      rewriter.setInsertionPointToStart(forOp.getBody());
      SmallVector<Value> ret = bodyBuilder(rewriter, loc, forOp.getBodyRegion(),
                                           it, forOp.getRegionIterArgs());

      // Terminate with the reductions produced by the body.
      rewriter.setInsertionPointToEnd(forOp.getBody());
      rewriter.create<scf::YieldOp>(loc, ret);
    }
    return forOp.getResults();
  }
  // General case: lower to an scf.while carrying cursor + reduction values.
  SmallVector<Value> ivs;
  // TODO: always put iterator SSA values at the end of argument list to be
  // consistent with coiterate operation.
  if (!iterFirst)
    llvm::append_range(ivs, it->getCursor());
  // Appends the user-provided values.
  llvm::append_range(ivs, reduc);
  if (iterFirst)
    llvm::append_range(ivs, it->getCursor());

  TypeRange types = ValueRange(ivs).getTypes();
  auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
  {
    OpBuilder::InsertionGuard guard(rewriter);
    // Generates loop conditions in the "before" region.
    SmallVector<Location> l(types.size(), loc);
    Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
    rewriter.setInsertionPointToStart(before);
    ValueRange bArgs = before->getArguments();
    auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs);
    rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments());

    // Delegates loop body generation to the callback in the "after" region.
    Region &dstRegion = whileOp.getAfter();
    Block *after = rewriter.createBlock(&dstRegion, {}, types, l);
    ValueRange aArgs = whileOp.getAfterArguments();
    // Relink the iterator to the block arguments and narrow `aArgs` down to
    // just the user reduction values, honoring the chosen argument order.
    if (iterFirst) {
      aArgs = it->linkNewScope(aArgs);
    } else {
      aArgs = aArgs.take_front(reduc.size());
      it->linkNewScope(aArgs.drop_front(reduc.size()));
    }

    rewriter.setInsertionPointToStart(after);
    SmallVector<Value> ret = bodyBuilder(rewriter, loc, dstRegion, it, aArgs);
    rewriter.setInsertionPointToEnd(after);

    // Forward the iterator and yield cursor + reductions in the same order
    // they were packed into `ivs` above.
    SmallVector<Value> yields;
    ValueRange nx = it->forward(rewriter, loc);
    if (iterFirst)
      llvm::append_range(yields, nx);
    llvm::append_range(yields, ret);
    if (!iterFirst)
      llvm::append_range(yields, nx);
    rewriter.create<scf::YieldOp>(loc, yields);
  }
  // NOTE(review): drop_front here assumes the cursor values lead the result
  // list, i.e. the iterFirst layout — confirm for the !iterFirst case.
  return whileOp.getResults().drop_front(it->getCursor().size());
}
190+
52191namespace {
53192
54193// / Sparse codegen rule for number of entries operator.
@@ -136,6 +275,8 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
136275 rewriter.replaceOp (op, forOp.getResults (), resultMapping);
137276 } else {
138277 SmallVector<Value> ivs;
278+ // TODO: put iterator at the end of argument list to be consistent with
279+ // coiterate operation.
139280 llvm::append_range (ivs, it->getCursor ());
140281 for (ValueRange inits : adaptor.getInitArgs ())
141282 llvm::append_range (ivs, inits);
@@ -189,6 +330,153 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
189330 }
190331};
191332
/// Sparse codegen rule for the coiterate operator: lowers a 1-D
/// `sparse_tensor.coiterate` into a sequence of loops (one per case region),
/// each threading the flattened user reduction values to the next.
class SparseCoIterateOpConverter
    : public OneToNOpConversionPattern<CoIterateOp> {
  using OneToNOpConversionPattern::OneToNOpConversionPattern;

  LogicalResult
  matchAndRewrite(CoIterateOp op, OpAdaptor adaptor,
                  OneToNPatternRewriter &rewriter) const override {
    // Only single-dimensional iteration spaces are handled so far.
    assert(op.getSpaceDim() == 1 && "Not implemented");
    Location loc = op.getLoc();

    // Collect which iteration spaces are all-dense.
    I64BitSet denseBits(0);
    for (auto [idx, spaceTp] : llvm::enumerate(op.getIterSpaces().getTypes()))
      if (all_of(cast<IterSpaceType>(spaceTp).getLvlTypes(), isDenseLT))
        denseBits.set(idx);

    // If there exists a case that only contains dense spaces. I.e., case
    // bits is a subset of dense bits, or when there is a full empty case (due
    // to complements), we need a universal pointer to forward the coiteration
    // loop.
    bool needUniv =
        any_of(op.getRegionDefinedSpaces(), [denseBits](I64BitSet caseBits) {
          // A case for complement.
          if (caseBits.count() == 0)
            return true;
          // An all-dense case.
          return caseBits.isSubSetOf(denseBits);
        });
    // Universal-index coiteration is not supported yet.
    assert(!needUniv && "Not implemented");
    (void)needUniv;

    for (Region &region : op.getCaseRegions()) {
      // Do a one-shot type conversion on all region blocks, since the same
      // region might be used multiple times.
      Block *block = &region.getBlocks().front();
      OneToNTypeMapping blockTypeMapping(block->getArgumentTypes());
      if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(),
                                                     blockTypeMapping)))
        return rewriter.notifyMatchFailure(
            op, "failed to convert coiterate region argurment types");

      rewriter.applySignatureConversion(block, blockTypeMapping);
    }

    // Materialize an iteration space and an iterator for every operand.
    SmallVector<SparseIterationSpace> spaces;
    SmallVector<std::unique_ptr<SparseIterator>> iters;
    for (auto [spaceTp, spaceVals] : llvm::zip_equal(
             op.getIterSpaces().getTypes(), adaptor.getIterSpaces())) {
      // TODO: do we really need tid?
      spaces.push_back(SparseIterationSpace::fromValues(
          cast<IterSpaceType>(spaceTp), spaceVals, /*tid=*/0));
      // Extract the iterator.
      iters.push_back(spaces.back().extractIterator(rewriter, loc));
    }

    auto getFilteredIters = [&iters](I64BitSet caseBits) {
      // Retrieves a vector of pointers to the iterators used in the case.
      SmallVector<SparseIterator *> validIters;
      for (auto idx : caseBits.bits())
        validIters.push_back(iters[idx].get());
      return validIters;
    };

    // Get a flattened user-provided loop reduction values.
    SmallVector<Value> userReduc;
    for (ValueRange r : adaptor.getInitArgs())
      llvm::append_range(userReduc, r);

    // TODO: we need to sort the cases such that they appears in lexical order.
    // Although sparsification always generates cases in that order, it might
    // not be the case for human-written code.

    // Generates a loop sequence, one loop per case; each loop's results seed
    // the reductions of the next.
    for (auto [r, caseBits] :
         llvm::zip_equal(op.getCaseRegions(), op.getRegionDefinedSpaces())) {
      assert(caseBits.count() > 0 && "Complement space not implemented");

      // Retrieves a vector of pointers to the iterators used in the case.
      SmallVector<SparseIterator *> validIters = getFilteredIters(caseBits);

      if (validIters.size() > 1) {
        // True coiteration: a while-loop stepping all case iterators.
        auto [loop, loopCrd] =
            genCoIteration(rewriter, loc, validIters, userReduc,
                           /*uniIdx=*/nullptr, /*userReducFirst=*/true);

        // 1st. find all the cases that is a strict subset of the current case
        // condition, for which we generate one branch per case inside the loop.
        // The subcases are never empty, it must contains at least the current
        // region itself.
        // TODO: these cases should be sorted.
        SmallVector<Region *> subCases = op.getSubCasesOf(r.getRegionNumber());
        assert(!subCases.empty());

        ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd,
                                                iters, subCases, userReduc);

        SmallVector<Value> nextIterYields(res);
        // 2nd. forward the loop: advance exactly those iterators that are
        // positioned at the current loop coordinate.
        for (SparseIterator *it : validIters) {
          Value cmp = rewriter.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::eq, it->getCrd(), loopCrd);
          it->forwardIf(rewriter, loc, cmp);
          llvm::append_range(nextIterYields, it->getCursor());
        }
        rewriter.create<scf::YieldOp>(loc, nextIterYields);

        // Exit the loop, relink the iterator SSA values.
        rewriter.setInsertionPointAfter(loop);
        ValueRange iterVals = loop->getResults().drop_front(userReduc.size());
        for (SparseIterator *it : validIters)
          iterVals = it->linkNewScope(iterVals);
        // All cursor results must have been consumed by the relinking.
        assert(iterVals.empty());

        // Thread the reductions into the next case's loop.
        ValueRange curResult = loop->getResults().take_front(userReduc.size());
        userReduc.assign(curResult.begin(), curResult.end());
      } else {
        // This is a simple iteration loop over a single iterator.
        assert(caseBits.count() == 1);

        Block *block = &r.getBlocks().front();
        ValueRange curResult = genLoopWithIterator(
            rewriter, loc, validIters.front(), userReduc, /*iterFirst=*/false,
            /*bodyBuilder=*/
            [block](PatternRewriter &rewriter, Location loc, Region &dstRegion,
                    SparseIterator *it,
                    ValueRange reduc) -> SmallVector<Value> {
              // Block arguments: reductions -> coordinate -> cursor values,
              // matching the converted case-region signature.
              SmallVector<Value> blockArgs(reduc);
              blockArgs.push_back(it->deref(rewriter, loc));
              llvm::append_range(blockArgs, it->getCursor());

              // Inline the (already signature-converted) case body and
              // strip its sparse_tensor.yield terminator.
              Block *dstBlock = &dstRegion.getBlocks().front();
              rewriter.inlineBlockBefore(
                  block, dstBlock, rewriter.getInsertionPoint(), blockArgs);
              auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
              SmallVector<Value> result(yield.getResults());
              rewriter.eraseOp(yield);
              return result;
            });

        // Thread the reductions into the next case's loop.
        userReduc.assign(curResult.begin(), curResult.end());
      }
    }

    // The final reductions replace the coiterate op's results.
    rewriter.replaceOp(op, userReduc);
    return success();
  }
};
479+
192480} // namespace
193481
194482mlir::SparseIterationTypeConverter::SparseIterationTypeConverter () {
@@ -210,5 +498,6 @@ void mlir::populateLowerSparseIterationToSCFPatterns(
210498
211499 IterateOp::getCanonicalizationPatterns (patterns, patterns.getContext ());
212500 patterns.add <ExtractIterSpaceConverter, ExtractValOpConverter,
213- SparseIterateOpConverter>(converter, patterns.getContext ());
501+ SparseIterateOpConverter, SparseCoIterateOpConverter>(
502+ converter, patterns.getContext ());
214503}
0 commit comments