diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 388efd1c454b1..fca2629d72efc 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -96,24 +96,32 @@ class I64BitSet {
     return *this;
   }
 
+  bool isSubSetOf(const I64BitSet p) const {
+    I64BitSet tmp = *this;
+    tmp |= p;
+    return tmp == p;
+  }
+
   // Needed by `llvm::const_set_bits_iterator_impl`.
   int find_first() const { return min(); }
   int find_next(unsigned prev) const {
-    if (prev >= max())
+    if (prev >= max() - 1)
       return -1;
-    uint64_t b = storage >> (prev + 1);
-    if (b == 0)
-      return -1;
+    uint64_t b = storage >> (prev + static_cast<uint64_t>(1));
+    assert(b != 0);
-    return llvm::countr_zero(b) + prev + 1;
+    return llvm::countr_zero(b) + prev + static_cast<uint64_t>(1);
   }
   bool operator[](unsigned i) const {
     assert(i < 64);
-    return (storage & (1 << i)) != 0;
+    return (storage & (static_cast<uint64_t>(1) << i)) != 0;
+  }
+  unsigned min() const {
+    unsigned m = llvm::countr_zero(storage);
+    return m == 64 ? -1 : m;
   }
-  unsigned min() const { return llvm::countr_zero(storage); }
   unsigned max() const { return 64 - llvm::countl_zero(storage); }
   unsigned count() const { return llvm::popcount(storage); }
   bool empty() const { return storage == 0; }
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 2803223354d5e..20512f972e67c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1787,6 +1787,10 @@ def SparseTensor_CoIterateOp : SparseTensor_Op<"coiterate",
                   .take_back(getRegionDefinedSpace(regionIdx).count());
     }
     ValueRange getYieldedValues(unsigned regionIdx);
+
+    // Returns a vector of regions that are the `sub-cases` of the given case
+    // region. E.g., `case %it1, _, %it3` is a sub-case of `case %it1, %it2, %it3`.
+    SmallVector<Region *> getSubCasesOf(unsigned regionIdx);
   }];
 
   let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index a143189c301a4..16856b958d4f1 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2745,6 +2745,16 @@ LogicalResult CoIterateOp::verifyRegions() {
   return success();
 }
 
+SmallVector<Region *> CoIterateOp::getSubCasesOf(unsigned regionIdx) {
+  SmallVector<Region *> ret;
+  I64BitSet caseBit = getRegionDefinedSpace(regionIdx);
+  for (Region &r : getCaseRegions())
+    if (getRegionDefinedSpace(r.getRegionNumber()).isSubSetOf(caseBit))
+      ret.push_back(&r);
+
+  return ret;
+}
+
 //===----------------------------------------------------------------------===//
 // Sparse Tensor Dialect Setups.
 //===----------------------------------------------------------------------===//
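As a reading aid for the bit-set plumbing above: the new isSubSetOf reduces the subset test to "OR-ing the candidate into the superset leaves the superset unchanged", and getSubCasesOf simply keeps every case region whose bit set passes that test. A minimal standalone C++ sketch of the same idea (a hypothetical TinyBitSet, not the MLIR I64BitSet class):

#include <cassert>
#include <cstdint>

// Hypothetical miniature of I64BitSet: one bit per co-iterated space.
struct TinyBitSet {
  uint64_t storage = 0;
  void set(unsigned i) { storage |= (static_cast<uint64_t>(1) << i); }
  // a is a subset of b  <=>  (a | b) == b.
  bool isSubSetOf(TinyBitSet p) const {
    return (storage | p.storage) == p.storage;
  }
};

int main() {
  TinyBitSet full, sub, other;
  full.set(0); full.set(1); full.set(2); // case %it1, %it2, %it3
  sub.set(0);  sub.set(2);               // case %it1, _, %it3
  other.set(3);                          // case over an unrelated space
  assert(sub.isSubSetOf(full));          // kept by getSubCasesOf(full)
  assert(full.isSubSetOf(full));         // a case is a sub-case of itself
  assert(!other.isSubSetOf(full));       // filtered out
  return 0;
}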
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index b1451dee738ac..d6c0da4a9e457 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -1,5 +1,6 @@
 #include "Utils/CodegenUtils.h"
+#include "Utils/LoopEmitter.h"
 #include "Utils/SparseTensorIterator.h"
 
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -49,6 +50,144 @@ convertIteratorType(IteratorType itTp, SmallVectorImpl<Type> &fields) {
   return success();
 }
 
+static ValueRange
+genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
+                       Value loopCrd,
+                       ArrayRef<std::unique_ptr<SparseIterator>> iters,
+                       ArrayRef<Region *> subCases, ArrayRef<Value> userReduc) {
+  if (subCases.empty())
+    return userReduc;
+
+  // The current branch that we are handling.
+  Region *b = subCases.front();
+  Value casePred = constantI1(rewriter, loc, true);
+  I64BitSet caseBits = op.getRegionDefinedSpace(b->getRegionNumber());
+  for (unsigned i : caseBits.bits()) {
+    SparseIterator *it = iters[i].get();
+    Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
+                                                it->getCrd(), loopCrd);
+    casePred = rewriter.create<arith::AndIOp>(loc, casePred, pred);
+  }
+  scf::IfOp ifOp = rewriter.create<scf::IfOp>(
+      loc, ValueRange(userReduc).getTypes(), casePred, /*else=*/true);
+  rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
+
+  // Erase the empty block.
+  rewriter.eraseBlock(&ifOp.getThenRegion().front());
+  // Set up block arguments: user-provided values -> loop coord -> iterators.
+  SmallVector<Value> blockArgs(userReduc);
+  blockArgs.push_back(loopCrd);
+  for (unsigned idx : caseBits.bits())
+    llvm::append_range(blockArgs, iters[idx]->getCursor());
+
+  IRMapping mapping;
+  for (auto [from, to] :
+       llvm::zip_equal(b->front().getArguments(), blockArgs)) {
+    mapping.map(from, to);
+  }
+
+  // Clone the region; we cannot erase the region yet because the same region
+  // might be a sub-case for multiple lattice points.
+  rewriter.cloneRegionBefore(*b, ifOp.getThenRegion(),
+                             ifOp.getThenRegion().begin(), mapping);
+
+  // Replace sparse_tensor::YieldOp -> scf::YieldOp.
+  auto spY = cast<sparse_tensor::YieldOp>(&ifOp.getThenRegion().front().back());
+  ValueRange yields = spY.getResults();
+  rewriter.eraseOp(spY);
+  rewriter.setInsertionPointToEnd(&ifOp.getThenRegion().front());
+  rewriter.create<scf::YieldOp>(loc, yields);
+
+  // Generate the remaining cases recursively.
+  rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
+  ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd, iters,
+                                          subCases.drop_front(), userReduc);
+  if (!res.empty())
+    rewriter.create<scf::YieldOp>(loc, res);
+
+  rewriter.setInsertionPointAfter(ifOp);
+  return ifOp.getResults();
+}
+
+static ValueRange genLoopWithIterator(
+    PatternRewriter &rewriter, Location loc, SparseIterator *it,
+    ValueRange reduc, bool iterFirst,
+    function_ref<SmallVector<Value>(PatternRewriter &rewriter, Location loc,
+                                    Region &loopBody, SparseIterator *it,
+                                    ValueRange reduc)>
+        bodyBuilder) {
+  if (it->iteratableByFor()) {
+    auto [lo, hi] = it->genForCond(rewriter, loc);
+    Value step = constantIndex(rewriter, loc, 1);
+    scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, reduc);
+    {
+      OpBuilder::InsertionGuard guard(rewriter);
+      // Erase the implicit yield operation created by ForOp when there are no
+      // yielding values.
+      if (!forOp.getBody()->empty())
+        rewriter.eraseOp(&forOp.getBody()->front());
+      assert(forOp.getBody()->empty());
+
+      it->linkNewScope(forOp.getInductionVar());
+      rewriter.setInsertionPointToStart(forOp.getBody());
+      SmallVector<Value> ret = bodyBuilder(rewriter, loc, forOp.getBodyRegion(),
+                                           it, forOp.getRegionIterArgs());
+
+      rewriter.setInsertionPointToEnd(forOp.getBody());
+      rewriter.create<scf::YieldOp>(loc, ret);
+    }
+    return forOp.getResults();
+  }
+
+  SmallVector<Value> ivs;
+  // TODO: always put iterator SSA values at the end of argument list to be
+  // consistent with coiterate operation.
+  if (!iterFirst)
+    llvm::append_range(ivs, it->getCursor());
+  // Appends the user-provided values.
+  llvm::append_range(ivs, reduc);
+  if (iterFirst)
+    llvm::append_range(ivs, it->getCursor());
+
+  TypeRange types = ValueRange(ivs).getTypes();
+  auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
+  {
+    OpBuilder::InsertionGuard guard(rewriter);
+    // Generates loop conditions.
+    SmallVector<Location> l(types.size(), loc);
+    Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
+    rewriter.setInsertionPointToStart(before);
+    ValueRange bArgs = before->getArguments();
+    auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs);
+    rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments());
+
+    // Delegates loop body generation.
+    Region &dstRegion = whileOp.getAfter();
+    Block *after = rewriter.createBlock(&dstRegion, {}, types, l);
+    ValueRange aArgs = whileOp.getAfterArguments();
+    if (iterFirst) {
+      aArgs = it->linkNewScope(aArgs);
+    } else {
+      it->linkNewScope(aArgs.drop_front(reduc.size()));
+      aArgs = aArgs.take_front(reduc.size());
+    }
+
+    rewriter.setInsertionPointToStart(after);
+    SmallVector<Value> ret = bodyBuilder(rewriter, loc, dstRegion, it, aArgs);
+    rewriter.setInsertionPointToEnd(after);
+
+    // Forward the iterator and yield the loop-carried values.
+    SmallVector<Value> yields;
+    ValueRange nx = it->forward(rewriter, loc);
+    if (iterFirst)
+      llvm::append_range(yields, nx);
+    llvm::append_range(yields, ret);
+    if (!iterFirst)
+      llvm::append_range(yields, nx);
+    rewriter.create<scf::YieldOp>(loc, yields);
+  }
+  return whileOp.getResults().drop_front(it->getCursor().size());
+}
+
 namespace {
 
 /// Sparse codegen rule for number of entries operator.
@@ -136,6 +275,8 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
       rewriter.replaceOp(op, forOp.getResults(), resultMapping);
     } else {
       SmallVector<Value> ivs;
+      // TODO: put iterator at the end of argument list to be consistent with
+      // coiterate operation.
       llvm::append_range(ivs, it->getCursor());
       for (ValueRange inits : adaptor.getInitArgs())
         llvm::append_range(ivs, inits);
@@ -189,6 +330,153 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
   }
 };
 
+class SparseCoIterateOpConverter
+    : public OneToNOpConversionPattern<CoIterateOp> {
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(CoIterateOp op, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    assert(op.getSpaceDim() == 1 && "Not implemented");
+    Location loc = op.getLoc();
+
+    I64BitSet denseBits(0);
+    for (auto [idx, spaceTp] : llvm::enumerate(op.getIterSpaces().getTypes()))
+      if (all_of(cast<IterSpaceType>(spaceTp).getLvlTypes(), isDenseLT))
+        denseBits.set(idx);
+
+    // If there exists a case that only contains dense spaces (i.e., whose
+    // case bits are a subset of the dense bits), or a fully empty case (due
+    // to complements), we need a universal index to forward the coiteration
+    // loop.
+    bool needUniv =
+        any_of(op.getRegionDefinedSpaces(), [denseBits](I64BitSet caseBits) {
+          // A case for complement.
+          if (caseBits.count() == 0)
+            return true;
+          // An all-dense case.
+          return caseBits.isSubSetOf(denseBits);
+        });
+    assert(!needUniv && "Not implemented");
+    (void)needUniv;
+
+    for (Region &region : op.getCaseRegions()) {
+      // Do a one-shot type conversion on all region blocks, since the same
+      // region might be used multiple times.
+      Block *block = &region.getBlocks().front();
+      OneToNTypeMapping blockTypeMapping(block->getArgumentTypes());
+      if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(),
+                                                     blockTypeMapping)))
+        return rewriter.notifyMatchFailure(
+            op, "failed to convert coiterate region argument types");
+
+      rewriter.applySignatureConversion(block, blockTypeMapping);
+    }
+
+    SmallVector<SparseIterationSpace> spaces;
+    SmallVector<std::unique_ptr<SparseIterator>> iters;
+    for (auto [spaceTp, spaceVals] : llvm::zip_equal(
+             op.getIterSpaces().getTypes(), adaptor.getIterSpaces())) {
+      // TODO: do we really need tid?
+      spaces.push_back(SparseIterationSpace::fromValues(
+          cast<IterSpaceType>(spaceTp), spaceVals, /*tid=*/0));
+      // Extract the iterator.
+      iters.push_back(spaces.back().extractIterator(rewriter, loc));
+    }
+
+    auto getFilteredIters = [&iters](I64BitSet caseBits) {
+      // Retrieves a vector of pointers to the iterators used in the case.
+      SmallVector<SparseIterator *> validIters;
+      for (auto idx : caseBits.bits())
+        validIters.push_back(iters[idx].get());
+      return validIters;
+    };
+
+    // Get the flattened user-provided loop reduction values.
+    SmallVector<Value> userReduc;
+    for (ValueRange r : adaptor.getInitArgs())
+      llvm::append_range(userReduc, r);
+
+    // TODO: we need to sort the cases such that they appear in lexical order.
+    // Although sparsification always generates cases in that order, it might
+    // not be the case for human-written code.
+
+    // Generates a loop sequence, one loop per case.
+    for (auto [r, caseBits] :
+         llvm::zip_equal(op.getCaseRegions(), op.getRegionDefinedSpaces())) {
+      assert(caseBits.count() > 0 && "Complement space not implemented");
+
+      // Retrieves a vector of pointers to the iterators used in the case.
+      SmallVector<SparseIterator *> validIters = getFilteredIters(caseBits);
+
+      if (validIters.size() > 1) {
+        auto [loop, loopCrd] =
+            genCoIteration(rewriter, loc, validIters, userReduc,
+                           /*uniIdx=*/nullptr, /*userReducFirst=*/true);
+
+        // 1st, find all the cases that are a subset of the current case
+        // condition, for which we generate one branch per case inside the
+        // loop. The sub-cases are never empty; they contain at least the
+        // current region itself.
+        // TODO: these cases should be sorted.
+        SmallVector<Region *> subCases = op.getSubCasesOf(r.getRegionNumber());
+        assert(!subCases.empty());
+
+        ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd,
+                                                iters, subCases, userReduc);
+
+        SmallVector<Value> nextIterYields(res);
+        // 2nd, forward the loop.
+        for (SparseIterator *it : validIters) {
+          Value cmp = rewriter.create<arith::CmpIOp>(
+              loc, arith::CmpIPredicate::eq, it->getCrd(), loopCrd);
+          it->forwardIf(rewriter, loc, cmp);
+          llvm::append_range(nextIterYields, it->getCursor());
+        }
+        rewriter.create<scf::YieldOp>(loc, nextIterYields);
+
+        // Exit the loop and relink the iterator SSA values.
+        rewriter.setInsertionPointAfter(loop);
+        ValueRange iterVals = loop->getResults().drop_front(userReduc.size());
+        for (SparseIterator *it : validIters)
+          iterVals = it->linkNewScope(iterVals);
+        assert(iterVals.empty());
+
+        ValueRange curResult = loop->getResults().take_front(userReduc.size());
+        userReduc.assign(curResult.begin(), curResult.end());
+      } else {
+        // This is a simple iteration loop.
+        assert(caseBits.count() == 1);
+
+        Block *block = &r.getBlocks().front();
+        ValueRange curResult = genLoopWithIterator(
+            rewriter, loc, validIters.front(), userReduc, /*iterFirst=*/false,
+            /*bodyBuilder=*/
+            [block](PatternRewriter &rewriter, Location loc, Region &dstRegion,
+                    SparseIterator *it,
+                    ValueRange reduc) -> SmallVector<Value> {
+              SmallVector<Value> blockArgs(reduc);
+              blockArgs.push_back(it->deref(rewriter, loc));
+              llvm::append_range(blockArgs, it->getCursor());
+
+              Block *dstBlock = &dstRegion.getBlocks().front();
+              rewriter.inlineBlockBefore(
+                  block, dstBlock, rewriter.getInsertionPoint(), blockArgs);
+              auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
+              SmallVector<Value> result(yield.getResults());
+              rewriter.eraseOp(yield);
+              return result;
+            });
+
+        userReduc.assign(curResult.begin(), curResult.end());
+      }
+    }
+
+    rewriter.replaceOp(op, userReduc);
+    return success();
+  }
+};
+
 } // namespace
 
 mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
@@ -210,5 +498,6 @@ void mlir::populateLowerSparseIterationToSCFPatterns(
   IterateOp::getCanonicalizationPatterns(patterns, patterns.getContext());
   patterns.add<ExtractIterSpaceConverter, ExtractValOpConverter,
-               SparseIterateOpConverter>(converter, patterns.getContext());
+               SparseIterateOpConverter, SparseCoIterateOpConverter>(
+      converter, patterns.getContext());
 }
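To make the emitted scf.if nest easier to picture, here is a small standalone C++ sketch (hypothetical Iter/Case types, not the rewriter code above) of the runtime behavior genCoIterateBranchNest encodes: sub-cases are tried in order, a case fires only when every iterator named by its bit set currently sits at the loop coordinate, and at most one case body runs per coordinate.

#include <cstdint>
#include <functional>
#include <vector>

// Hypothetical stand-ins for the lowered runtime state.
struct Iter {
  int64_t crd; // coordinate the iterator currently sits at
};
struct Case {
  uint64_t bits;                     // which iterators the case covers
  std::function<void(int64_t)> body; // the case region
};

// Mirrors the generated if/else nest: sub-cases are tried in order; a case
// fires only if all of its iterators are at `loopCrd`; at most one body runs.
void dispatch(const std::vector<Case> &subCases, const std::vector<Iter> &iters,
              int64_t loopCrd) {
  for (const Case &c : subCases) {
    bool casePred = true; // starts as the constant true predicate
    for (unsigned i = 0; i < iters.size(); ++i)
      if (c.bits & (static_cast<uint64_t>(1) << i))
        casePred = casePred && (iters[i].crd == loopCrd);
    if (casePred) { // "then" region: run the cloned case body and stop
      c.body(loopCrd);
      return;
    }
    // "else" region: fall through to the next sub-case.
  }
}

The recursion over subCases in the pattern corresponds to the for-loop over cases here: each missed predicate falls through to the else region that holds the next candidate.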
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index efb3295fb2a4b..cb5874ff45068 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -524,84 +524,8 @@ std::pair<Operation *, Value> LoopEmitter::emitForLoopOverTensorAtLvl(
 std::pair<Operation *, Value> LoopEmitter::emitWhileLoopOverTensorsAtLvls(
     OpBuilder &builder, Location loc, ArrayRef<SparseIterator *> spIters,
     MutableArrayRef<Value> reduc, bool needsUniv) {
-  // NOTE: the slice driven tensor-related reduction variable must
-  // appear before normal tensors.
-
-  // The set of induction variables for the while loop.
-  SmallVector<Value> ivs;
-
-  // Construct the while-loop with a parameter for each coordinate.
-  for (SparseIterator *it : spIters) {
-    ValueRange itVals = it->getCursor();
-    ivs.append(itVals.begin(), itVals.end());
-  }
-
-  // The position where user-supplied reduction variable starts.
-  ivs.append(reduc.begin(), reduc.end());
-  // Update universal index.
-  if (needsUniv)
-    ivs.push_back(loopSeqStack.back().first);
-
-  // Ensures all operands are valid.
-  assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
-  TypeRange types = ValueRange(ivs).getTypes();
-  auto whileOp = builder.create<scf::WhileOp>(loc, types, ivs);
-
-  SmallVector<Location> locs(types.size(), loc);
-  Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
-  Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);
-
-  // Generates loop conditions.
-  builder.setInsertionPointToStart(before);
-  ValueRange bArgs = before->getArguments();
-  Value whileCond = nullptr; // bool values for loop condition.
-
-  for (SparseIterator *it : spIters) {
-    auto [cond, remArgs] = it->genWhileCond(builder, loc, bArgs);
-    whileCond = !whileCond ? cond : ANDI(whileCond, cond);
-    bArgs = remArgs;
-  }
-  // The remaining block arguments are user-provided reduction values and an
-  // optional universal index. Make sure their sizes match.
-  assert(bArgs.size() == reduc.size() + needsUniv);
-  builder.create<scf::ConditionOp>(loc, whileCond, before->getArguments());
-
-  // Generates loop body.
-  builder.setInsertionPointToStart(after);
-  ValueRange aArgs = after->getArguments();
-  // Since some LoopCondKind might need extra checks to filter out invalid
-  // iterations, we maintains another array to hold the iteration arguments to
-  // yield if the checks fails.
-  SmallVector<Value> nextArgs(aArgs.begin(), aArgs.end());
-
-  for (SparseIterator *it : spIters) {
-    aArgs = it->linkNewScope(aArgs);
-    // Dereference the iterator to cache the coordinate.
-    it->deref(builder, loc);
-  }
-
-  // In-place update on reduction variable.
-  assert(aArgs.size() == reduc.size() + needsUniv);
-  for (unsigned i = 0, e = reduc.size(); i < e; i++)
-    reduc[i] = aArgs[i];
-
-  Value min;
-  // Finds the minimum coordinate
-  if (!needsUniv) {
-    for (SparseIterator *it : spIters) {
-      if (min) {
-        Value cmp = CMPI(ult, it->getCrd(), min);
-        min = SELECT(cmp, it->getCrd(), min);
-      } else {
-        min = it->getCrd();
-      }
-    }
-  } else {
-    // Otherwise, universal index is the minimal pos.
-    min = whileOp.getAfterArguments().back();
-  }
-
-  return {whileOp, min};
+  return genCoIteration(builder, loc, spIters, reduc,
+                        needsUniv ? loopSeqStack.back().first : nullptr);
 }
 
 bool LoopEmitter::shouldIteratedByForLoop(ArrayRef<SparseIterator *> spIters) {
@@ -972,6 +896,100 @@ void LoopEmitter::exitCurrentLoop(RewriterBase &rewriter, Location loc,
   loopStack.pop_back();
 }
 
+//===----------------------------------------------------------------------===//
+// Loop generation utils
+//===----------------------------------------------------------------------===//
+
+std::pair<Operation *, Value> sparse_tensor::genCoIteration(
+    OpBuilder &builder, Location loc, ArrayRef<SparseIterator *> spIters,
+    MutableArrayRef<Value> reduc, Value uniIdx, bool userReducFirst) {
+  // NOTE: the slice-driven tensor-related reduction variable must
+  // appear before normal tensors.
+
+  // The set of induction variables for the while loop.
+  SmallVector<Value> ivs;
+
+  // TODO: remove the flag after full migration. Currently the
+  // `sparse_tensor.coiterate` operation (must) puts user-provided reduction
+  // values at the front of the block argument list, while direct
+  // sparsification to scf loops puts them at the end.
+  if (userReducFirst)
+    ivs.append(reduc.begin(), reduc.end());
+
+  // Construct the while-loop with a parameter for each coordinate.
+  for (SparseIterator *it : spIters) {
+    ValueRange itVals = it->getCursor();
+    ivs.append(itVals.begin(), itVals.end());
+  }
+
+  if (!userReducFirst)
+    ivs.append(reduc.begin(), reduc.end());
+
+  // Update universal index.
+  if (uniIdx)
+    ivs.push_back(uniIdx);
+
+  // Ensures all operands are valid.
+  assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
+  TypeRange types = ValueRange(ivs).getTypes();
+  auto whileOp = builder.create<scf::WhileOp>(loc, types, ivs);
+
+  SmallVector<Location> locs(types.size(), loc);
+  Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
+  Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);
+
+  // Generates loop conditions.
+  builder.setInsertionPointToStart(before);
+  ValueRange bArgs = before->getArguments();
+  Value whileCond = nullptr; // bool values for loop condition.
+
+  for (SparseIterator *it : spIters) {
+    auto [cond, remArgs] = it->genWhileCond(builder, loc, bArgs);
+    whileCond = !whileCond ? cond : ANDI(whileCond, cond);
+    bArgs = remArgs;
+  }
+  // The remaining block arguments are user-provided reduction values and an
+  // optional universal index. Make sure their sizes match.
+  assert(bArgs.size() == reduc.size() + (uniIdx ? 1 : 0));
+  builder.create<scf::ConditionOp>(loc, whileCond, before->getArguments());
+
+  // Generates loop body.
+  builder.setInsertionPointToStart(after);
+  ValueRange aArgs = after->getArguments();
+  // Since some LoopCondKind might need extra checks to filter out invalid
+  // iterations, we maintain another array to hold the iteration arguments to
+  // yield if the checks fail.
+  SmallVector<Value> nextArgs(aArgs.begin(), aArgs.end());
+
+  for (SparseIterator *it : spIters) {
+    aArgs = it->linkNewScope(aArgs);
+    // Dereference the iterator to cache the coordinate.
+    it->deref(builder, loc);
+  }
+
+  // In-place update on reduction variable.
+  for (unsigned i = 0, e = reduc.size(); i < e; i++)
+    reduc[i] = aArgs[i];
+
+  Value min;
+  // Finds the minimum coordinate.
+  if (!uniIdx) {
+    for (SparseIterator *it : spIters) {
+      if (min) {
+        Value cmp = CMPI(ult, it->getCrd(), min);
+        min = SELECT(cmp, it->getCrd(), min);
+      } else {
+        min = it->getCrd();
+      }
    }
+  } else {
+    // Otherwise, the universal index is the minimal position.
+    min = whileOp.getAfterArguments().back();
+  }
+
+  return {whileOp, min};
+}
+
 #undef CMPI
 #undef C_IDX
 #undef YIELD
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index a9eb888c8b6be..3e61b5f27fcc2 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -436,6 +436,17 @@ class LoopEmitter {
   std::vector<std::vector<Value>> spIterVals;
 };
 
+//
+// Utility functions for generating sparse loops.
+//
+
+// Generates a while loop that co-iterates over a set of iterators.
+std::pair<Operation *, Value> genCoIteration(OpBuilder &builder, Location loc,
+                                             ArrayRef<SparseIterator *> iters,
+                                             MutableArrayRef<Value> reduc,
+                                             Value uniIdx,
+                                             bool userReducFirst = false);
+
 } // namespace sparse_tensor
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index 91f363db93f1d..642cb1afa156b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -95,6 +95,8 @@ enum class IterKind : uint8_t {
 class SparseIterationSpace {
 public:
   SparseIterationSpace() = default;
+  SparseIterationSpace(SparseIterationSpace &) = delete;
+  SparseIterationSpace(SparseIterationSpace &&) = default;
 
   // Constructs a N-D iteration space.
   SparseIterationSpace(Location loc, OpBuilder &b, Value t, unsigned tid,
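The FileCheck test below pins down the scf.while that this path produces for two compressed vectors; as a reading aid, a scalar C++ sketch of the same control structure (assuming two iterators and no universal index) looks like this:

#include <algorithm>
#include <cstdint>

// Hypothetical minimal iterator: a position cursor over a coordinate array.
struct Iter {
  const int64_t *crd; // level coordinates
  int64_t pos, end;   // cursor and its upper bound
  bool inRange() const { return pos < end; }
  int64_t curCrd() const { return crd[pos]; }
};

// Scalar sketch of the loop skeleton genCoIteration builds: the "before"
// region ANDs the in-range tests, the "after" region takes the minimum
// coordinate as the loop coordinate, runs the case dispatch, and forwards
// exactly the iterators whose coordinate matched (forwardIf).
template <typename Body>
void coiterate(Iter a, Iter b, Body body) {
  while (a.inRange() && b.inRange()) { // scf.condition operand
    int64_t loopCrd = std::min(a.curCrd(), b.curCrd());
    body(loopCrd, a, b);               // sub-case dispatch (see earlier sketch)
    if (a.curCrd() == loopCrd)         // forwardIf on iterator a
      ++a.pos;
    if (b.curCrd() == loopCrd)         // forwardIf on iterator b
      ++b.pos;
  }
}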
diff --git a/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir b/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
index 2487156a9a2e4..f819458e03858 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
@@ -1,7 +1,5 @@
 // RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="sparse-emit-strategy=sparse-iterator" --cse | FileCheck %s --check-prefix="ITER"
 
-// TODO: temporarilly disabled since there is no lowering rules from `coiterate` to `scf`.
-// R_U_N: mlir-opt %s --sparse-reinterpret-map -sparsification="sparse-emit-strategy=sparse-iterator" --cse --sparse-space-collapse --lower-sparse-iteration-to-scf --loop-invariant-code-motion | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="sparse-emit-strategy=sparse-iterator" --cse --sparse-space-collapse --lower-sparse-iteration-to-scf --loop-invariant-code-motion -cse --canonicalize | FileCheck %s
 
@@ -79,6 +77,79 @@ func.func @sqsum(%arg0: tensor<?x?x?x?xi32, #COO>) -> tensor<i32> {
 // ITER: bufferization.to_tensor
 // ITER: return
 // ITER: }
+
+// CHECK-LABEL: func.func @add(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<10xi32, #sparse{{.*}}>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<10xi32, #sparse{{.*}}>) -> tensor<10xi32> {
+// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = arith.constant 0 : i32
+// CHECK: %[[VAL_5:.*]] = arith.constant dense<0> : tensor<10xi32>
+// CHECK: %[[VAL_6:.*]] = bufferization.to_memref %[[VAL_5]] : memref<10xi32>
+// CHECK: linalg.fill ins(%[[VAL_4]] : i32) outs(%[[VAL_6]] : memref<10xi32>)
+// CHECK: %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<10xi32, #sparse{{.*}}> to memref<?xindex>
+// CHECK: %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<10xi32, #sparse{{.*}}> to memref<?xindex>
+// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK: %[[VAL_11:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<10xi32, #sparse{{.*}}> to memref<?xindex>
+// CHECK: %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 0 : index} : tensor<10xi32, #sparse{{.*}}> to memref<?xindex>
+// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK: %[[VAL_15:.*]]:2 = scf.while (%[[VAL_16:.*]] = %[[VAL_9]], %[[VAL_17:.*]] = %[[VAL_13]]) : (index, index) -> (index, index) {
+// CHECK: %[[VAL_18:.*]] = arith.cmpi ult, %[[VAL_16]], %[[VAL_10]] : index
+// CHECK: %[[VAL_19:.*]] = arith.cmpi ult, %[[VAL_17]], %[[VAL_14]] : index
+// CHECK: %[[VAL_20:.*]] = arith.andi %[[VAL_18]], %[[VAL_19]] : i1
+// CHECK: scf.condition(%[[VAL_20]]) %[[VAL_16]], %[[VAL_17]] : index, index
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_21:.*]]: index, %[[VAL_22:.*]]: index):
+// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_21]]] : memref<?xindex>
+// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_22]]] : memref<?xindex>
+// CHECK: %[[VAL_25:.*]] = arith.cmpi ult, %[[VAL_24]], %[[VAL_23]] : index
+// CHECK: %[[VAL_26:.*]] = arith.select %[[VAL_25]], %[[VAL_24]], %[[VAL_23]] : index
+// CHECK: %[[VAL_27:.*]] = arith.cmpi eq, %[[VAL_23]], %[[VAL_26]] : index
+// CHECK: %[[VAL_28:.*]] = arith.cmpi eq, %[[VAL_24]], %[[VAL_26]] : index
+// CHECK: %[[VAL_29:.*]] = arith.andi %[[VAL_27]], %[[VAL_28]] : i1
+// CHECK: scf.if %[[VAL_29]] {
+// CHECK: %[[VAL_30:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
+// CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_30]]{{\[}}%[[VAL_21]]] : memref<?xi32>
+// CHECK: %[[VAL_32:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
+// CHECK: %[[VAL_33:.*]] = memref.load %[[VAL_32]]{{\[}}%[[VAL_22]]] : memref<?xi32>
+// CHECK: %[[VAL_34:.*]] = arith.addi %[[VAL_31]], %[[VAL_33]] : i32
+// CHECK: memref.store %[[VAL_34]], %[[VAL_6]]{{\[}}%[[VAL_26]]] : memref<10xi32>
+// CHECK: } else {
+// CHECK: scf.if %[[VAL_27]] {
+// CHECK: %[[VAL_35:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
+// CHECK: %[[VAL_36:.*]] = memref.load %[[VAL_35]]{{\[}}%[[VAL_21]]] : memref<?xi32>
+// CHECK: memref.store %[[VAL_36]], %[[VAL_6]]{{\[}}%[[VAL_26]]] : memref<10xi32>
+// CHECK: } else {
+// CHECK: scf.if %[[VAL_28]] {
+// CHECK: %[[VAL_37:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
+// CHECK: %[[VAL_38:.*]] = memref.load %[[VAL_37]]{{\[}}%[[VAL_22]]] : memref<?xi32>
+// CHECK: memref.store %[[VAL_38]], %[[VAL_6]]{{\[}}%[[VAL_26]]] : memref<10xi32>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: %[[VAL_39:.*]] = arith.addi %[[VAL_21]], %[[VAL_2]] : index
+// CHECK: %[[VAL_40:.*]] = arith.select %[[VAL_27]], %[[VAL_39]], %[[VAL_21]] : index
+// CHECK: %[[VAL_41:.*]] = arith.addi %[[VAL_22]], %[[VAL_2]] : index
+// CHECK: %[[VAL_42:.*]] = arith.select %[[VAL_28]], %[[VAL_41]], %[[VAL_22]] : index
+// CHECK: scf.yield %[[VAL_40]], %[[VAL_42]] : index, index
+// CHECK: }
+// CHECK: %[[VAL_43:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
+// CHECK: scf.for %[[VAL_44:.*]] = %[[VAL_45:.*]]#0 to %[[VAL_10]] step %[[VAL_2]] {
+// CHECK: %[[VAL_46:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_44]]] : memref<?xindex>
+// CHECK: %[[VAL_47:.*]] = memref.load %[[VAL_43]]{{\[}}%[[VAL_44]]] : memref<?xi32>
+// CHECK: memref.store %[[VAL_47]], %[[VAL_6]]{{\[}}%[[VAL_46]]] : memref<10xi32>
+// CHECK: }
+// CHECK: %[[VAL_48:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
+// CHECK: scf.for %[[VAL_49:.*]] = %[[VAL_50:.*]]#1 to %[[VAL_14]] step %[[VAL_2]] {
+// CHECK: %[[VAL_51:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_49]]] : memref<?xindex>
+// CHECK: %[[VAL_52:.*]] = memref.load %[[VAL_48]]{{\[}}%[[VAL_49]]] : memref<?xi32>
+// CHECK: memref.store %[[VAL_52]], %[[VAL_6]]{{\[}}%[[VAL_51]]] : memref<10xi32>
+// CHECK: }
+// CHECK: %[[VAL_53:.*]] = bufferization.to_tensor %[[VAL_6]] : memref<10xi32>
+// CHECK: return %[[VAL_53]] : tensor<10xi32>
+// CHECK: }
 func.func @add(%arg0: tensor<10xi32, #VEC>, %arg1: tensor<10xi32, #VEC>) -> tensor<10xi32> {
   %cst = arith.constant dense<0> : tensor<10xi32>
   %0 = linalg.generic {
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/iterator-based-sqsum.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/iterator-based-kernel.mlir
similarity index 63%
rename from mlir/test/Integration/Dialect/SparseTensor/CPU/iterator-based-sqsum.mlir
rename to mlir/test/Integration/Dialect/SparseTensor/CPU/iterator-based-kernel.mlir
index 6d03565f8f7b2..6cca4fa86a162 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/iterator-based-sqsum.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/iterator-based-kernel.mlir
@@ -35,9 +35,13 @@
   explicitVal = 1 : i32
 }>
 
-// An example of vector reductions.
-module {
+#VEC = #sparse_tensor.encoding<{
+  map = (d0) -> (d0 : compressed)
+}>
+
+module {
+  // An example of vector reductions (lowered through sparse_tensor.iterate).
   func.func @sqsum(%arg0: tensor<2x3x4x5xi32, #COO>) -> tensor<i32> {
     %cst = arith.constant dense<0> : tensor<i32>
     %0 = linalg.generic {
@@ -55,7 +59,30 @@ module {
     return %0 : tensor<i32>
   }
 
+  // An example of vector addition (lowered through sparse_tensor.coiterate).
+  func.func @vec_add(%arg0: tensor<4xi32, #VEC>, %arg1: tensor<4xi32, #VEC>) -> tensor<4xi32> {
+    %cst = arith.constant dense<0> : tensor<4xi32>
+    %0 = linalg.generic {
+      indexing_maps = [
+        affine_map<(d0) -> (d0)>,
+        affine_map<(d0) -> (d0)>,
+        affine_map<(d0) -> (d0)>
+      ],
+      iterator_types = ["parallel"]
+    }
+    ins(%arg0, %arg1 : tensor<4xi32, #VEC>, tensor<4xi32, #VEC>)
+    outs(%cst : tensor<4xi32>) {
+    ^bb0(%in1: i32, %in2: i32, %out: i32):
+      %2 = arith.addi %in1, %in2 : i32
+      linalg.yield %2 : i32
+    } -> tensor<4xi32>
+    return %0 : tensor<4xi32>
+  }
+
   func.func @main() {
+    %c0 = arith.constant 0 : index
+    %i0 = arith.constant 0 : i32
+
     %cst = arith.constant sparse<
       [ [0, 1, 2, 3],
@@ -66,15 +93,33 @@
       [1, 1, 1, 1]
     > : tensor<2x3x4x5xi32>
 
+    %l = arith.constant dense<
+      [0, 1, 2, 3]
+    > : tensor<4xi32>
+    %r = arith.constant dense<
+      [1, 0, 3, 0]
+    > : tensor<4xi32>
+
     %input = sparse_tensor.convert %cst : tensor<2x3x4x5xi32> to tensor<2x3x4x5xi32, #COO>
     %0 = call @sqsum(%input) : (tensor<2x3x4x5xi32, #COO>) -> tensor<i32>
     %v = tensor.extract %0[] : tensor<i32>
+    %lhs = sparse_tensor.convert %l : tensor<4xi32> to tensor<4xi32, #VEC>
+    %rhs = sparse_tensor.convert %r : tensor<4xi32> to tensor<4xi32, #VEC>
+    %add = call @vec_add(%lhs, %rhs) : (tensor<4xi32, #VEC>, tensor<4xi32, #VEC>) -> tensor<4xi32>
+
     // CHECK: 4
     vector.print %v : i32
+    // CHECK-NEXT: ( 1, 1, 5, 3 )
+    %vec = vector.transfer_read %add[%c0], %i0 : tensor<4xi32>, vector<4xi32>
+    vector.print %vec : vector<4xi32>
 
     bufferization.dealloc_tensor %input : tensor<2x3x4x5xi32, #COO>
    bufferization.dealloc_tensor %0 : tensor<i32>
+
+    bufferization.dealloc_tensor %lhs : tensor<4xi32, #VEC>
+    bufferization.dealloc_tensor %rhs : tensor<4xi32, #VEC>
+    bufferization.dealloc_tensor %add : tensor<4xi32>
     return
   }
 }
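As a cross-check on the expected integration-test output, the following standalone C++ sketch replays @vec_add over the compressed storage of its two operands (the dense constants [0, 1, 2, 3] and [1, 0, 3, 0] with zeros dropped by the conversion): it merges while both iterators are in range and then drains each remaining tail with a separate loop, which is exactly the loop sequence checked above, and it prints 1 1 5 3.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Compressed storage of the two #VEC operands after sparse_tensor.convert
  // (zeros of the dense constants [0, 1, 2, 3] and [1, 0, 3, 0] are dropped).
  std::vector<int64_t> crdA = {1, 2, 3};
  std::vector<int32_t> valA = {1, 2, 3};
  std::vector<int64_t> crdB = {0, 2};
  std::vector<int32_t> valB = {1, 3};
  std::vector<int32_t> out(4, 0); // dense output, zero-initialized

  // Merge loop: runs while both iterators are in range (the scf.while).
  size_t ia = 0, ib = 0;
  while (ia < crdA.size() && ib < crdB.size()) {
    int64_t crd = std::min(crdA[ia], crdB[ib]);
    bool hitA = crdA[ia] == crd, hitB = crdB[ib] == crd;
    if (hitA && hitB) // case %it1, %it2
      out[crd] = valA[ia] + valB[ib];
    else if (hitA)    // case %it1, _
      out[crd] = valA[ia];
    else              // case _, %it2
      out[crd] = valB[ib];
    ia += hitA;       // forwardIf on the lhs iterator
    ib += hitB;       // forwardIf on the rhs iterator
  }
  // Tail loops: copy through whichever operand is not exhausted yet
  // (the two trailing scf.for loops in the FileCheck output).
  for (; ia < crdA.size(); ++ia)
    out[crdA[ia]] = valA[ia];
  for (; ib < crdB.size(); ++ib)
    out[crdB[ib]] = valB[ib];

  for (int32_t v : out)
    std::printf("%d ", v); // prints: 1 1 5 3
  std::printf("\n");
  return 0;
}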