@@ -1189,27 +1189,38 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
11891189 LogicalResult matchAndRewrite (NewOp op,
11901190 PatternRewriter &rewriter) const override {
11911191 Location loc = op.getLoc ();
1192- const auto dstTp = getSparseTensorType (op.getResult ());
1193- const auto encDst = dstTp .getEncoding ();
1194- if (!dstTp .hasEncoding () || getCOOStart (encDst ) == 0 )
1192+ auto stt = getSparseTensorType (op.getResult ());
1193+ auto enc = stt .getEncoding ();
1194+ if (!stt .hasEncoding () || getCOOStart (enc ) == 0 )
11951195 return failure ();
11961196
11971197 // Implement the NewOp as follows:
11981198 // %orderedCoo = sparse_tensor.new %filename
11991199 // %t = sparse_tensor.convert %orderedCoo
1200+ // with enveloping reinterpreted_map ops for non-permutations.
1201+ RankedTensorType dstTp = stt.getRankedTensorType ();
12001202 RankedTensorType cooTp = getCOOType (dstTp, /* ordered=*/ true );
12011203 Value cooTensor = rewriter.create <NewOp>(loc, cooTp, op.getSource ());
1202- Value convert = rewriter.replaceOpWithNewOp <ConvertOp>(
1203- op, dstTp.getRankedTensorType (), cooTensor);
1204+ Value convert = cooTensor;
1205+ if (!stt.isPermutation ()) { // demap coo, demap dstTp
1206+ auto coo = getSparseTensorType (cooTensor).getEncoding ().withoutDimToLvl ();
1207+ convert = rewriter.create <ReinterpretMapOp>(loc, coo, convert);
1208+ dstTp = getSparseTensorType (convert).withEncoding (enc.withoutDimToLvl ());
1209+ }
1210+ convert = rewriter.create <ConvertOp>(loc, dstTp, convert);
1211+ if (!stt.isPermutation ()) // remap to original enc
1212+ convert = rewriter.create <ReinterpretMapOp>(loc, enc, convert);
1213+ rewriter.replaceOp (op, convert);
12041214
1205- // Release the ordered COO tensor.
1215+ // Release the temporary ordered COO tensor.
12061216 rewriter.setInsertionPointAfterValue (convert);
12071217 rewriter.create <DeallocTensorOp>(loc, cooTensor);
12081218
12091219 return success ();
12101220 }
12111221};
12121222
1223+ // / Sparse rewriting rule for the out operator.
12131224struct OutRewriter : public OpRewritePattern <OutOp> {
12141225 using OpRewritePattern::OpRewritePattern;
12151226 LogicalResult matchAndRewrite (OutOp op,
@@ -1250,6 +1261,7 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
12501261 primaryTypeFunctionSuffix (eltTp)};
12511262 Value value = genAllocaScalar (rewriter, loc, eltTp);
12521263 ModuleOp module = op->getParentOfType <ModuleOp>();
1264+
12531265 // For each element in the source tensor, output the element.
12541266 rewriter.create <ForeachOp>(
12551267 loc, src, std::nullopt ,
0 commit comments