diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index 4cc6ee971d4a3..4adb1c19096a2 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -788,24 +788,29 @@ LogicalResult SparseTensorEncodingAttr::verify( return emitError() << "unexpected position bitwidth: " << posWidth; if (!acceptBitWidth(crdWidth)) return emitError() << "unexpected coordinate bitwidth: " << crdWidth; - if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonLT); - it != std::end(lvlTypes)) { + + // Verify every COO segment. + auto *it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonLT); + while (it != lvlTypes.end()) { if (it == lvlTypes.begin() || - (!isCompressedLT(*(it - 1)) && !isLooseCompressedLT(*(it - 1)))) + !(it - 1)->isa()) return emitError() << "expected compressed or loose_compressed level " "before singleton level"; - if (!std::all_of(it, lvlTypes.end(), + + auto *curCOOEnd = std::find_if_not(it, lvlTypes.end(), isSingletonLT); + if (!std::all_of(it, curCOOEnd, [](LevelType i) { return isSingletonLT(i); })) return emitError() << "expected all singleton lvlTypes " "following a singleton level"; // We can potentially support mixed SoA/AoS singleton levels. - if (!std::all_of(it, lvlTypes.end(), [it](LevelType i) { + if (!std::all_of(it, curCOOEnd, [it](LevelType i) { return it->isa() == i.isa(); })) { return emitError() << "expected all singleton lvlTypes stored in the " "same memory layout (SoA vs AoS)."; } + it = std::find_if(curCOOEnd, lvlTypes.end(), isSingletonLT); } auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT); diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir index 7fb1c76c1a1ff..44710cad246c6 100644 --- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir +++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir @@ -156,6 +156,17 @@ func.func private @sparse_coo(tensor) // ----- +#COO_DENSE = #sparse_tensor.encoding<{ + map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton, d2: dense) +}> + +// CHECK-DAG: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton, d2 : dense) }> +// CHECK-LABEL: func private @sparse_coo_trailing_dense( +// CHECK-SAME: tensor) +func.func private @sparse_coo_trailing_dense(tensor) + +// ----- + #BCOO = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : loose_compressed(nonunique), d2 : singleton) }>