@@ -788,24 +788,29 @@ LogicalResult SparseTensorEncodingAttr::verify(
788788 return emitError () << " unexpected position bitwidth: " << posWidth;
789789 if (!acceptBitWidth (crdWidth))
790790 return emitError () << " unexpected coordinate bitwidth: " << crdWidth;
791- if (auto it = std::find_if (lvlTypes.begin (), lvlTypes.end (), isSingletonLT);
792- it != std::end (lvlTypes)) {
791+
792+ // Verify every COO segment.
793+ auto *it = std::find_if (lvlTypes.begin (), lvlTypes.end (), isSingletonLT);
794+ while (it != lvlTypes.end ()) {
793795 if (it == lvlTypes.begin () ||
794- (! isCompressedLT (*( it - 1 )) && ! isLooseCompressedLT (*(it - 1 )) ))
796+ !( it - 1 )-> isa <LevelFormat::Compressed, LevelFormat::LooseCompressed>( ))
795797 return emitError () << " expected compressed or loose_compressed level "
796798 " before singleton level" ;
797- if (!std::all_of (it, lvlTypes.end (),
799+
800+ auto *curCOOEnd = std::find_if_not (it, lvlTypes.end (), isSingletonLT);
801+ if (!std::all_of (it, curCOOEnd,
798802 [](LevelType i) { return isSingletonLT (i); }))
799803 return emitError () << " expected all singleton lvlTypes "
800804 " following a singleton level" ;
801805 // We can potentially support mixed SoA/AoS singleton levels.
802- if (!std::all_of (it, lvlTypes. end () , [it](LevelType i) {
806+ if (!std::all_of (it, curCOOEnd , [it](LevelType i) {
803807 return it->isa <LevelPropNonDefault::SoA>() ==
804808 i.isa <LevelPropNonDefault::SoA>();
805809 })) {
806810 return emitError () << " expected all singleton lvlTypes stored in the "
807811 " same memory layout (SoA vs AoS)." ;
808812 }
813+ it = std::find_if (curCOOEnd, lvlTypes.end (), isSingletonLT);
809814 }
810815
811816 auto lastBatch = std::find_if (lvlTypes.rbegin (), lvlTypes.rend (), isBatchLT);
0 commit comments