Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1189,27 +1189,38 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
LogicalResult matchAndRewrite(NewOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
const auto dstTp = getSparseTensorType(op.getResult());
const auto encDst = dstTp.getEncoding();
if (!dstTp.hasEncoding() || getCOOStart(encDst) == 0)
auto stt = getSparseTensorType(op.getResult());
auto enc = stt.getEncoding();
if (!stt.hasEncoding() || getCOOStart(enc) == 0)
return failure();

// Implement the NewOp as follows:
// %orderedCoo = sparse_tensor.new %filename
// %t = sparse_tensor.convert %orderedCoo
// with enveloping reinterpreted_map ops for non-permutations.
RankedTensorType dstTp = stt.getRankedTensorType();
RankedTensorType cooTp = getCOOType(dstTp, /*ordered=*/true);
Value cooTensor = rewriter.create<NewOp>(loc, cooTp, op.getSource());
Value convert = rewriter.replaceOpWithNewOp<ConvertOp>(
op, dstTp.getRankedTensorType(), cooTensor);
Value convert = cooTensor;
if (!stt.isPermutation()) { // demap coo, demap dstTp
auto coo = getSparseTensorType(cooTensor).getEncoding().withoutDimToLvl();
convert = rewriter.create<ReinterpretMapOp>(loc, coo, convert);
dstTp = getSparseTensorType(convert).withEncoding(enc.withoutDimToLvl());
}
convert = rewriter.create<ConvertOp>(loc, dstTp, convert);
if (!stt.isPermutation()) // remap to original enc
convert = rewriter.create<ReinterpretMapOp>(loc, enc, convert);
rewriter.replaceOp(op, convert);

// Release the ordered COO tensor.
// Release the temporary ordered COO tensor.
rewriter.setInsertionPointAfterValue(convert);
rewriter.create<DeallocTensorOp>(loc, cooTensor);

return success();
}
};

/// Sparse rewriting rule for the out operator.
struct OutRewriter : public OpRewritePattern<OutOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(OutOp op,
Expand Down Expand Up @@ -1250,6 +1261,7 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
primaryTypeFunctionSuffix(eltTp)};
Value value = genAllocaScalar(rewriter, loc, eltTp);
ModuleOp module = op->getParentOfType<ModuleOp>();

// For each element in the source tensor, output the element.
rewriter.create<ForeachOp>(
loc, src, std::nullopt,
Expand Down