diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index 62887c75c872b..4224925147c84 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -34,6 +34,20 @@ convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl<Type> &fields) {
   return success();
 }
 
+static std::optional<LogicalResult>
+convertIteratorType(IteratorType itTp, SmallVectorImpl<Type> &fields) {
+  // The actual iterator values (that are updated every iteration).
+  auto idxTp = IndexType::get(itTp.getContext());
+  // TODO: handle batch dimension.
+  assert(itTp.getEncoding().getBatchLvlRank() == 0);
+  if (!itTp.isUnique()) {
+    // Segment high for a non-unique iterator.
+    fields.push_back(idxTp);
+  }
+  fields.push_back(idxTp);
+  return success();
+}
+
 namespace {
 
 /// Sparse codegen rule for number of entries operator.
@@ -57,10 +71,114 @@ class ExtractIterSpaceConverter
   }
 };
 
+class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
+public:
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+  LogicalResult
+  matchAndRewrite(IterateOp op, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    if (!op.getCrdUsedLvls().empty())
+      return rewriter.notifyMatchFailure(
+          op, "non-empty coordinates list not implemented.");
+
+    Location loc = op.getLoc();
+
+    auto iterSpace = SparseIterationSpace::fromValues(
+        op.getIterSpace().getType(), adaptor.getIterSpace(), 0);
+
+    std::unique_ptr<SparseIterator> it =
+        iterSpace.extractIterator(rewriter, loc);
+
+    if (it->iteratableByFor()) {
+      auto [lo, hi] = it->genForCond(rewriter, loc);
+      Value step = constantIndex(rewriter, loc, 1);
+      SmallVector<Value> ivs;
+      for (ValueRange inits : adaptor.getInitArgs())
+        llvm::append_range(ivs, inits);
+      scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, ivs);
+
+      Block *loopBody = op.getBody();
+      OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
+      if (failed(typeConverter->convertSignatureArgs(
+              loopBody->getArgumentTypes(), bodyTypeMapping)))
+        return failure();
+      rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
+
+      rewriter.eraseBlock(forOp.getBody());
+      Region &dstRegion = forOp.getRegion();
+      rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
+
+      auto yieldOp =
+          llvm::cast<sparse_tensor::YieldOp>(forOp.getBody()->getTerminator());
+
+      rewriter.setInsertionPointToEnd(forOp.getBody());
+      // Replace sparse_tensor.yield with scf.yield.
+      rewriter.create<scf::YieldOp>(loc, yieldOp.getResults());
+      rewriter.eraseOp(yieldOp);
+
+      const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
+      rewriter.replaceOp(op, forOp.getResults(), resultMapping);
+    } else {
+      SmallVector<Value> ivs;
+      llvm::append_range(ivs, it->getCursor());
+      for (ValueRange inits : adaptor.getInitArgs())
+        llvm::append_range(ivs, inits);
+
+      assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
+
+      TypeRange types = ValueRange(ivs).getTypes();
+      auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
+      SmallVector<Location> l(types.size(), op.getIterator().getLoc());
+
+      // Generates loop conditions.
+      Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
+      rewriter.setInsertionPointToStart(before);
+      ValueRange bArgs = before->getArguments();
+      auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs);
+      assert(remArgs.size() == adaptor.getInitArgs().size());
+      rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments());
+
+      // Generates loop body.
+      Block *loopBody = op.getBody();
+      OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
+      if (failed(typeConverter->convertSignatureArgs(
+              loopBody->getArgumentTypes(), bodyTypeMapping)))
+        return failure();
+      rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
+
+      Region &dstRegion = whileOp.getAfter();
+      // TODO: handle uses of coordinate!
+      rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
+      ValueRange aArgs = whileOp.getAfterArguments();
+      auto yieldOp = llvm::cast<sparse_tensor::YieldOp>(
+          whileOp.getAfterBody()->getTerminator());
+
+      rewriter.setInsertionPointToEnd(whileOp.getAfterBody());
+
+      aArgs = it->linkNewScope(aArgs);
+      ValueRange nx = it->forward(rewriter, loc);
+      SmallVector<Value> yields;
+      llvm::append_range(yields, nx);
+      llvm::append_range(yields, yieldOp.getResults());
+
+      // Replace sparse_tensor.yield with scf.yield.
+      rewriter.eraseOp(yieldOp);
+      rewriter.create<scf::YieldOp>(loc, yields);
+
+      const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
+      rewriter.replaceOp(
+          op, whileOp.getResults().drop_front(it->getCursor().size()),
+          resultMapping);
+    }
+    return success();
+  }
+};
+
 } // namespace
 
 mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
   addConversion([](Type type) { return type; });
+  addConversion(convertIteratorType);
   addConversion(convertIterSpaceType);
 
   addSourceMaterialization([](OpBuilder &builder, IterSpaceType spTp,
@@ -74,5 +192,6 @@ mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
 
 void mlir::populateLowerSparseIterationToSCFPatterns(
     TypeConverter &converter, RewritePatternSet &patterns) {
-  patterns.add<ExtractIterSpaceConverter>(converter, patterns.getContext());
+  patterns.add<ExtractIterSpaceConverter, SparseIterateOpConverter>(
+      converter, patterns.getContext());
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
index be8e15d6ae6f4..ef95fcc84bd90 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
@@ -331,6 +331,13 @@ class TrivialIterator : public ConcreteIterator {
   TrivialIterator(const SparseTensorLevel &stl)
       : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1) {}
 
+  TrivialIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl,
+                  Value posLo, Value posHi)
+      : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1), posLo(posLo),
+        posHi(posHi) {
+    seek(posLo);
+  }
+
   std::string getDebugInterfacePrefix() const override {
     return std::string("trivial<") + stl.toString() + ">";
   }
@@ -420,6 +427,14 @@ class DedupIterator : public ConcreteIterator {
       : ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2) {
     assert(!stl.isUnique());
   }
+
+  DedupIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl,
+                Value posLo, Value posHi)
+      : ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2), posHi(posHi) {
+    assert(!stl.isUnique());
+    seek({posLo, genSegmentHigh(b, l, posLo)});
+  }
+
   // For LLVM-style RTTI.
   static bool classof(const SparseIterator *from) {
     return from->kind == IterKind::kDedup;
@@ -1532,6 +1547,11 @@ SparseIterationSpace mlir::sparse_tensor::SparseIterationSpace::fromValues(
   return space;
 }
 
+std::unique_ptr<SparseIterator>
+SparseIterationSpace::extractIterator(OpBuilder &b, Location l) const {
+  return makeSimpleIterator(b, l, *this);
+}
+
 //===----------------------------------------------------------------------===//
 // SparseIterator factory functions.
 //===----------------------------------------------------------------------===//
@@ -1590,6 +1610,26 @@ sparse_tensor::makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl,
   return std::make_pair(std::move(stl), std::move(it));
 }
 
+std::unique_ptr<SparseIterator>
+sparse_tensor::makeSimpleIterator(OpBuilder &b, Location l,
+                                  const SparseIterationSpace &iterSpace) {
+  // assert(iterSpace.getSpaceDim() == 1);
+  std::unique_ptr<SparseIterator> ret;
+  if (!iterSpace.isUnique()) {
+    // We always deduplicate a non-unique level, but this should be optimized
+    // away when possible.
+    ret = std::make_unique<DedupIterator>(b, l, iterSpace.getLastLvl(),
+                                          iterSpace.getBoundLo(),
+                                          iterSpace.getBoundHi());
+  } else {
+    ret = std::make_unique<TrivialIterator>(b, l, iterSpace.getLastLvl(),
+                                            iterSpace.getBoundLo(),
+                                            iterSpace.getBoundHi());
+  }
+  ret->setSparseEmitStrategy(SparseEmitStrategy::kFunctional);
+  return ret;
+}
+
 std::unique_ptr<SparseIterator>
 sparse_tensor::makeSimpleIterator(const SparseTensorLevel &stl,
                                   SparseEmitStrategy strategy) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index 17636af2b2f9d..91f363db93f1d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -132,6 +132,10 @@ class SparseIterationSpace {
   Value getBoundLo() const { return bound.first; }
   Value getBoundHi() const { return bound.second; }
 
+  // Extracts an iterator to iterate over the sparse iteration space.
+  std::unique_ptr<SparseIterator> extractIterator(OpBuilder &b,
+                                                  Location l) const;
+
 private:
   SmallVector<std::unique_ptr<SparseTensorLevel>> lvls;
   std::pair<Value, Value> bound;
@@ -192,6 +196,13 @@ class SparseIterator {
     crd = nullptr;
   }
 
+  // Reconstructs an iterator directly from the provided ValueRange.
+  static std::unique_ptr<SparseIterator>
+  fromValues(IteratorType dstTp, ValueRange values, unsigned tid);
+
+  // The inverse operation of `fromValues`.
+  SmallVector<Value> toValues() const { llvm_unreachable("Not implemented"); }
+
   //
   // Iterator properties.
   //
@@ -345,12 +356,21 @@ std::unique_ptr<SparseTensorLevel>
 makeSparseTensorLevel(OpBuilder &b, Location l, Value t, unsigned tid,
                       Level lvl);
 
-/// Helper function to create a TensorLevel object from given `tensor`.
+/// Helper function to create a TensorLevel object from given ValueRange.
 std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(LevelType lt, Value sz,
                                                          ValueRange buffers,
                                                          unsigned tid, Level l);
-/// Helper function to create a simple SparseIterator object that iterates
-/// over the SparseTensorLevel.
+
+/// Helper function to create a simple SparseIterator object that iterates
+/// over the entire iteration space.
+std::unique_ptr<SparseIterator>
+makeSimpleIterator(OpBuilder &b, Location l,
+                   const SparseIterationSpace &iterSpace);
+
+/// Helper function to create a simple SparseIterator object that iterates
+/// over the sparse tensor level.
+/// TODO: switch to `SparseIterationSpace` (which supports N-D iterators) when
+/// feature-complete.
 std::unique_ptr<SparseIterator> makeSimpleIterator(
     const SparseTensorLevel &stl,
     SparseEmitStrategy strategy = SparseEmitStrategy::kFunctional);
diff --git a/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir b/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir
index 5fcd661bb69b2..77a0e89dc7c81 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt %s --lower-sparse-iteration-to-scf | FileCheck %s
+// RUN: mlir-opt %s --sparse-space-collapse --lower-sparse-iteration-to-scf | FileCheck %s --check-prefix COLLAPSED
 
 #COO = #sparse_tensor.encoding<{
   map = (i, j) -> (
@@ -7,17 +8,44 @@
     i : compressed(nonunique),
     j : singleton(soa)
   )
 }>
 
-// CHECK-LABEL:   func.func @sparse_1D_space(
-// CHECK-SAME:      %[[VAL_0:.*]]: tensor<?x?xf32, #sparse{{[0-9]*}}>) -> !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 0> {
-// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG:       %[[LVL_SIZE:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[C0]] : tensor<?x?xf32, #sparse{{[0-9]*}}>
-// CHECK:           %[[POS_MEM:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[CRD_MEM:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[POS_LO:.*]] = memref.load %[[POS_MEM]]{{\[}}%[[C0]]] : memref<?xindex>
-// CHECK:           %[[POS_HI:.*]] = memref.load %[[POS_MEM]]{{\[}}%[[C1]]] : memref<?xindex>
-// CHECK:           %[[ITER_SPACE:.*]] = builtin.unrealized_conversion_cast %[[POS_MEM]], %[[CRD_MEM]], %[[LVL_SIZE]], %[[POS_LO]], %[[POS_HI]]
-func.func @sparse_1D_space(%sp : tensor<?x?xf32, #COO>) -> !sparse_tensor.iter_space<#COO, lvls = 0> {
-  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<?x?xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
-  return %l1 : !sparse_tensor.iter_space<#COO, lvls = 0>
+// CHECK-LABEL: @sparse_iteration_to_scf
+// // deduplication
+// CHECK: scf.while {{.*}} {
+// CHECK: } do {
+// CHECK: }
+// CHECK: scf.while {{.*}} {
+// CHECK: } do {
+// // actual computation
+// CHECK: scf.for {{.*}} {
+// CHECK: arith.addi
+// CHECK: }
+// // deduplication
+// CHECK: scf.while {{.*}} {
+// CHECK: } do {
+// CHECK: }
+// CHECK: scf.yield
+// CHECK: }
+// CHECK: return
+
+// COLLAPSED-LABEL: @sparse_iteration_to_scf
+// COLLAPSED: %[[RET:.*]] = scf.for {{.*}} {
+// COLLAPSED: %[[VAL:.*]] = arith.addi
+// COLLAPSED: scf.yield %[[VAL]] : index
+// COLLAPSED: }
+// COLLAPSED: return %[[RET]] : index
+func.func @sparse_iteration_to_scf(%sp : tensor<4x8xf32, #COO>) -> index {
+  %i = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0
+      : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
+  %r1 = sparse_tensor.iterate %it1 in %l1 iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
+    %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
+        : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1> -> !sparse_tensor.iter_space<#COO, lvls = 1>
+    %r2 = sparse_tensor.iterate %it2 in %l2 iter_args(%inner = %outer): !sparse_tensor.iter_space<#COO, lvls = 1 to 2> -> index {
+      %k = arith.addi %inner, %c1 : index
+      sparse_tensor.yield %k : index
+    }
+    sparse_tensor.yield %r2 : index
+  }
+  return %r1 : index
 }
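
Note: to make the while-based path of SparseIterateOpConverter concrete, the hand-written sketch below shows the rough shape of the loop it emits for a single non-unique level: the cursor (the position plus the segment-high value added by convertIteratorType) and the user's iter_args are carried through an scf.while. This is illustrative only, not pass output; all value names are invented, and the iterator-forwarding step is stubbed, since the real lowering computes the next segment high with an inner scf.while emitted by genSegmentHigh rather than the +1 shown here.

  func.func @sketch_while_lowering(%lo: index, %hi: index,
                                   %segHi0: index, %init: index) -> index {
    %c1 = arith.constant 1 : index
    %r:3 = scf.while (%pos = %lo, %segHi = %segHi0, %acc = %init)
        : (index, index, index) -> (index, index, index) {
      // Before-region (genWhileCond): continue while the cursor is in bounds.
      %cond = arith.cmpi ult, %pos, %hi : index
      scf.condition(%cond) %pos, %segHi, %acc : index, index, index
    } do {
    ^bb0(%pos: index, %segHi: index, %acc: index):
      // After-region: the inlined sparse_tensor.iterate body runs here.
      %new = arith.addi %acc, %c1 : index
      // forward(): skip the whole duplicate segment by jumping to its end;
      // the real next segment high comes from genSegmentHigh (stubbed as +1).
      %nextSegHi = arith.addi %segHi, %c1 : index
      scf.yield %segHi, %nextSegHi, %new : index, index, index
    }
    // Cursor results are loop-internal state, so only the trailing results
    // replace the original op, mirroring drop_front(it->getCursor().size()).
    return %r#2 : index
  }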