Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 36 additions & 82 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,88 +244,41 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
std::unique_ptr<SparseIterator> it =
iterSpace.extractIterator(rewriter, loc);

if (it->iteratableByFor()) {
auto [lo, hi] = it->genForCond(rewriter, loc);
Value step = constantIndex(rewriter, loc, 1);
SmallVector<Value> ivs;
for (ValueRange inits : adaptor.getInitArgs())
llvm::append_range(ivs, inits);
scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, ivs);

Block *loopBody = op.getBody();
OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
if (failed(typeConverter->convertSignatureArgs(
loopBody->getArgumentTypes(), bodyTypeMapping)))
return failure();
rewriter.applySignatureConversion(loopBody, bodyTypeMapping);

rewriter.eraseBlock(forOp.getBody());
Region &dstRegion = forOp.getRegion();
rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());

auto yieldOp =
llvm::cast<sparse_tensor::YieldOp>(forOp.getBody()->getTerminator());

rewriter.setInsertionPointToEnd(forOp.getBody());
// replace sparse_tensor.yield with scf.yield.
rewriter.create<scf::YieldOp>(loc, yieldOp.getResults());
rewriter.eraseOp(yieldOp);

const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
rewriter.replaceOp(op, forOp.getResults(), resultMapping);
} else {
SmallVector<Value> ivs;
// TODO: put iterator at the end of argument list to be consistent with
// coiterate operation.
llvm::append_range(ivs, it->getCursor());
for (ValueRange inits : adaptor.getInitArgs())
llvm::append_range(ivs, inits);

assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));

TypeRange types = ValueRange(ivs).getTypes();
auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
SmallVector<Location> l(types.size(), op.getIterator().getLoc());

// Generates loop conditions.
Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
rewriter.setInsertionPointToStart(before);
ValueRange bArgs = before->getArguments();
auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs);
assert(remArgs.size() == adaptor.getInitArgs().size());
rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments());

// Generates loop body.
Block *loopBody = op.getBody();
OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
if (failed(typeConverter->convertSignatureArgs(
loopBody->getArgumentTypes(), bodyTypeMapping)))
return failure();
rewriter.applySignatureConversion(loopBody, bodyTypeMapping);

Region &dstRegion = whileOp.getAfter();
// TODO: handle uses of coordinate!
rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
ValueRange aArgs = whileOp.getAfterArguments();
auto yieldOp = llvm::cast<sparse_tensor::YieldOp>(
whileOp.getAfterBody()->getTerminator());

rewriter.setInsertionPointToEnd(whileOp.getAfterBody());
SmallVector<Value> ivs;
for (ValueRange inits : adaptor.getInitArgs())
llvm::append_range(ivs, inits);

// Type conversion on iterate op block.
OneToNTypeMapping blockTypeMapping(op.getBody()->getArgumentTypes());
if (failed(typeConverter->convertSignatureArgs(
op.getBody()->getArgumentTypes(), blockTypeMapping)))
return rewriter.notifyMatchFailure(
op, "failed to convert iterate region argurment types");
rewriter.applySignatureConversion(op.getBody(), blockTypeMapping);

Block *block = op.getBody();
ValueRange ret = genLoopWithIterator(
rewriter, loc, it.get(), ivs, /*iterFirst=*/true,
[block](PatternRewriter &rewriter, Location loc, Region &loopBody,
SparseIterator *it, ValueRange reduc) -> SmallVector<Value> {
SmallVector<Value> blockArgs(it->getCursor());
// TODO: Also appends coordinates if used.
// blockArgs.push_back(it->deref(rewriter, loc));
llvm::append_range(blockArgs, reduc);

Block *dstBlock = &loopBody.getBlocks().front();
rewriter.inlineBlockBefore(block, dstBlock, dstBlock->end(),
blockArgs);
auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
// We can not use ValueRange as the operation holding the values will
// be destoryed.
SmallVector<Value> result(yield.getResults());
rewriter.eraseOp(yield);
return result;
});

aArgs = it->linkNewScope(aArgs);
ValueRange nx = it->forward(rewriter, loc);
SmallVector<Value> yields;
llvm::append_range(yields, nx);
llvm::append_range(yields, yieldOp.getResults());

// replace sparse_tensor.yield with scf.yield.
rewriter.eraseOp(yieldOp);
rewriter.create<scf::YieldOp>(loc, yields);
const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
rewriter.replaceOp(
op, whileOp.getResults().drop_front(it->getCursor().size()),
resultMapping);
}
const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
rewriter.replaceOp(op, ret, resultMapping);
return success();
}
};
Expand Down Expand Up @@ -366,9 +319,10 @@ class SparseCoIterateOpConverter
Block *block = &region.getBlocks().front();
OneToNTypeMapping blockTypeMapping(block->getArgumentTypes());
if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(),
blockTypeMapping)))
blockTypeMapping))) {
return rewriter.notifyMatchFailure(
op, "failed to convert coiterate region argurment types");
}

rewriter.applySignatureConversion(block, blockTypeMapping);
}
Expand Down