Skip to content

Commit d82e93e

Browse files
authored
[mlir][sparse] add merger support on Batch LevelType. (#83186)
1 parent f7a9966 commit d82e93e

File tree

4 files changed

+59
-28
lines changed

4 files changed

+59
-28
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -333,16 +333,28 @@ struct LevelType {
333333
return lvlBits & static_cast<uint64_t>(p);
334334
}
335335

336+
/// Check if the `LevelType` is considered to be sparse.
337+
constexpr bool hasSparseSemantic() const {
338+
return isa<LevelFormat::Compressed, LevelFormat::Singleton,
339+
LevelFormat::LooseCompressed, LevelFormat::NOutOfM>();
340+
}
341+
342+
/// Check if the `LevelType` is considered to be dense-like.
343+
constexpr bool hasDenseSemantic() const {
344+
return isa<LevelFormat::Dense, LevelFormat::Batch>();
345+
}
346+
336347
/// Check if the `LevelType` needs positions array.
337348
constexpr bool isWithPosLT() const {
338-
return isa<LevelFormat::Compressed>() ||
339-
isa<LevelFormat::LooseCompressed>();
349+
assert(!isa<LevelFormat::Undef>());
350+
return isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>();
340351
}
341352

342353
/// Check if the `LevelType` needs coordinates array.
343354
constexpr bool isWithCrdLT() const {
355+
assert(!isa<LevelFormat::Undef>());
344356
// All sparse levels have a coordinate array.
345-
return !isa<LevelFormat::Dense, LevelFormat::Batch>();
357+
return hasSparseSemantic();
346358
}
347359

348360
std::string toMLIRString() const {

mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -509,8 +509,7 @@ class Merger {
509509
bool isSparseLvlWithNonTrivialIdxExp(TensorLoopId b) const {
510510
if (isLvlWithNonTrivialIdxExp(b)) {
511511
auto lt = getLoopDependentLevelType(b);
512-
return isCompressedLT(lt) || isSingletonLT(lt) ||
513-
isLooseCompressedLT(lt) || isNOutOfMLT(lt);
512+
return lt.hasSparseSemantic();
514513
}
515514
return false;
516515
}

mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,7 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
476476
// Starts resetting from a dense level, so that the first bit (if kept)
477477
// is not undefined level-type.
478478
for (unsigned b = 0; b < be; b++) {
479-
if (simple[b] && isDenseLT(getLvlType(TensorLoopId{b}))) {
479+
if (simple[b] && getLvlType(TensorLoopId{b}).hasDenseSemantic()) {
480480
offset = be - b - 1; // relative to the end
481481
break;
482482
}
@@ -489,8 +489,7 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
489489
// Slice on dense level has `locate` property as well, and can be optimized.
490490
if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
491491
const auto lt = getLvlType(b);
492-
if (!isCompressedLT(lt) && !isSingletonLT(lt) &&
493-
!isLooseCompressedLT(lt) && !isNOutOfMLT(lt)) {
492+
if (!lt.hasSparseSemantic()) {
494493
if (reset)
495494
simple.reset(b);
496495
reset = true;
@@ -670,8 +669,7 @@ bool Merger::isSingleCondition(TensorId t, ExprId e) const {
670669
bool Merger::hasAnySparse(const BitVector &bits) const {
671670
for (TensorLoopId b : bits.set_bits()) {
672671
const auto lt = getLvlType(b);
673-
if (isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
674-
isNOutOfMLT(lt))
672+
if (lt.hasSparseSemantic())
675673
return true;
676674
}
677675
return hasSparseIdxReduction(bits);

mlir/unittests/Dialect/SparseTensor/MergerTest.cpp

Lines changed: 40 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,8 @@ static Match synZeroMatch() { return Match(); }
120120
FOREVERY_BINOP(IMPL_BINOP_PATTERN)
121121
#undef IMPL_BINOP_PATTERN
122122

123-
class MergerTestBase : public ::testing::Test {
123+
// Parameterize LevelFormat to test both Dense and Batch LevelFormat.
124+
class MergerTestBase : public ::testing::TestWithParam<LevelFormat> {
124125
protected:
125126
MergerTestBase(unsigned numTensors, unsigned numLoops)
126127
: merger(numTensors, numLoops, /*maxRank=*/numLoops) {
@@ -317,10 +318,14 @@ class MergerTest3T1L : public MergerTestBase {
317318
// Tensor 1: sparse input vector.
318319
merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Compressed);
319320
// Tensor 2: dense output vector.
320-
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Dense);
321+
merger.setLevelAndType(tid(2), lid(0), 0, GetParam());
321322
}
322323
};
323324

325+
INSTANTIATE_TEST_SUITE_P(Test3T1L, MergerTest3T1L,
326+
::testing::Values(LevelFormat::Dense,
327+
LevelFormat::Batch));
328+
324329
/// Four tensors (three inputs, one output); and a single loop.
325330
class MergerTest4T1L : public MergerTestBase {
326331
protected:
@@ -333,10 +338,14 @@ class MergerTest4T1L : public MergerTestBase {
333338
// Tensor 2: sparse input vector
334339
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Compressed);
335340
// Tensor 3: dense output vector
336-
merger.setLevelAndType(tid(3), lid(0), 0, LevelFormat::Dense);
341+
merger.setLevelAndType(tid(3), lid(0), 0, GetParam());
337342
}
338343
};
339344

345+
INSTANTIATE_TEST_SUITE_P(Test4T1L, MergerTest4T1L,
346+
::testing::Values(LevelFormat::Dense,
347+
LevelFormat::Batch));
348+
340349
///
341350
/// Tests with both sparse and dense input.
342351
///
@@ -349,12 +358,16 @@ class MergerTest3T1LD : public MergerTestBase {
349358
// Tensor 0: sparse input vector.
350359
merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed);
351360
// Tensor 1: dense input vector.
352-
merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Dense);
361+
merger.setLevelAndType(tid(1), lid(0), 0, GetParam());
353362
// Tensor 2: dense output vector.
354-
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Dense);
363+
merger.setLevelAndType(tid(2), lid(0), 0, GetParam());
355364
}
356365
};
357366

367+
INSTANTIATE_TEST_SUITE_P(Test3T1LD, MergerTest3T1LD,
368+
::testing::Values(LevelFormat::Dense,
369+
LevelFormat::Batch));
370+
358371
///
359372
/// Tests with both undef and dense input.
360373
///
@@ -367,14 +380,18 @@ class MergerTest4T1LU : public MergerTestBase {
367380
// Tensor 0: undef input vector.
368381
merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Undef);
369382
// Tensor 1: dense input vector.
370-
merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Dense);
383+
merger.setLevelAndType(tid(1), lid(0), 0, GetParam());
371384
// Tensor 2: undef input vector.
372385
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Undef);
373386
// Tensor 3: dense output vector.
374-
merger.setLevelAndType(tid(3), lid(0), 0, LevelFormat::Dense);
387+
merger.setLevelAndType(tid(3), lid(0), 0, GetParam());
375388
}
376389
};
377390

391+
INSTANTIATE_TEST_SUITE_P(Test4T1LU, MergerTest4T1LU,
392+
::testing::Values(LevelFormat::Dense,
393+
LevelFormat::Batch));
394+
378395
///
379396
/// Tests with operation on sparse output.
380397
///
@@ -395,6 +412,11 @@ class MergerTest3T1LSo : public MergerTestBase {
395412
}
396413
};
397414

415+
// This test suite does not use any dense-like format; just one of {Dense, Batch}
416+
// is enough.
417+
INSTANTIATE_TEST_SUITE_P(Test3T1LSo, MergerTest3T1LSo,
418+
::testing::Values(LevelFormat::Dense));
419+
398420
} // namespace
399421

400422
/// Vector multiplication (conjunction) of 3 vectors, i.e.;
@@ -409,7 +431,7 @@ class MergerTest3T1LSo : public MergerTestBase {
409431
/// lat( i_01_D / (tensor_0 * tensor_1 * tensor2) )
410432
/// }
411433
#define IMPL_MERGER_TEST_CONJ_CONJ_UNDEF(CONJ1, CONJ2) \
412-
TEST_F(MergerTest4T1LU, vector_##CONJ1##_##CONJ2) { \
434+
TEST_P(MergerTest4T1LU, vector_##CONJ1##_##CONJ2) { \
413435
const auto em = CONJ1##Expr(tensor(0), tensor(1)); \
414436
const auto e = CONJ2##Expr(em, tensor(2)); \
415437
const auto l0 = lid(0); \
@@ -443,7 +465,7 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_UNDEF)
443465
/// lat( i_03_U / (tensor_0 * tensor_1 * output_tensor2) )
444466
/// }
445467
#define IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT(CONJ1, CONJ2) \
446-
TEST_F(MergerTest3T1LSo, vector_##CONJ1##_##CONJ2) { \
468+
TEST_P(MergerTest3T1LSo, vector_##CONJ1##_##CONJ2) { \
447469
const auto em = CONJ1##Expr(tensor(0), tensor(1)); \
448470
const auto e = CONJ2##Expr(em, tensor(2)); \
449471
const auto l0 = lid(0); \
@@ -482,7 +504,7 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT)
482504
/// lat( i_01 / tensor_1 )
483505
/// }
484506
#define IMPL_MERGER_TEST_DISJ(OP, UNUSED) \
485-
TEST_F(MergerTest3T1L, vector_##OP) { \
507+
TEST_P(MergerTest3T1L, vector_##OP) { \
486508
const auto e = OP##Expr(tensor(0), tensor(1)); \
487509
const auto l0 = lid(0); \
488510
const auto t0 = tid(0); \
@@ -514,7 +536,7 @@ FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_DISJ)
514536
/// lat( i_00 i_01 / (tensor_0 * tensor_1) )
515537
/// }
516538
#define IMPL_MERGER_TEST_CONJ(OP, UNUSED) \
517-
TEST_F(MergerTest3T1L, vector_##OP) { \
539+
TEST_P(MergerTest3T1L, vector_##OP) { \
518540
const auto e = OP##Expr(tensor(0), tensor(1)); \
519541
const auto l0 = lid(0); \
520542
const auto t0 = tid(0); \
@@ -544,7 +566,7 @@ FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_CONJ)
544566
/// lat( i_02 / tensor_2 )
545567
/// }
546568
#define IMPL_MERGER_TEST_CONJ_DISJ(CONJ, DISJ) \
547-
TEST_F(MergerTest4T1L, vector_##CONJ##_##DISJ) { \
569+
TEST_P(MergerTest4T1L, vector_##CONJ##_##DISJ) { \
548570
const auto em = CONJ##Expr(tensor(0), tensor(1)); \
549571
const auto e = DISJ##Expr(em, tensor(2)); \
550572
const auto l0 = lid(0); \
@@ -587,7 +609,7 @@ FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(IMPL_MERGER_TEST_CONJ_DISJ)
587609
/// lat( i_00 / tensor_0 )
588610
/// }
589611
#define IMPL_MERGER_TEST_DISJ_DISJ(DISJ1, DISJ2) \
590-
TEST_F(MergerTest4T1L, Vector_##DISJ1##_##DISJ2) { \
612+
TEST_P(MergerTest4T1L, Vector_##DISJ1##_##DISJ2) { \
591613
const auto em = DISJ1##Expr(tensor(0), tensor(1)); \
592614
const auto e = DISJ2##Expr(em, tensor(2)); \
593615
const auto l0 = lid(0); \
@@ -636,7 +658,7 @@ FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(IMPL_MERGER_TEST_DISJ_DISJ)
636658
/// lat( i_00 i_01 i_02 / tensor_0 * tensor_1 * tensor_2 )
637659
/// }
638660
#define IMPL_MERGER_TEST_CONJ_CONJ(CONJ1, CONJ2) \
639-
TEST_F(MergerTest4T1L, vector_##CONJ1##_##CONJ2) { \
661+
TEST_P(MergerTest4T1L, vector_##CONJ1##_##CONJ2) { \
640662
const auto em = CONJ1##Expr(tensor(0), tensor(1)); \
641663
const auto e = CONJ2##Expr(em, tensor(2)); \
642664
const auto l0 = lid(0); \
@@ -675,7 +697,7 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ)
675697
/// lat( i_00 / sparse_tensor_0 ) should be opted out as it only has dense diff
676698
/// with lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ).
677699
#define IMPL_MERGER_TEST_OPTIMIZED_DISJ(OP, UNUSED) \
678-
TEST_F(MergerTest3T1LD, vector_opted_##OP) { \
700+
TEST_P(MergerTest3T1LD, vector_opted_##OP) { \
679701
const auto e = OP##Expr(tensor(0), tensor(1)); \
680702
const auto l0 = lid(0); \
681703
const auto t0 = tid(0); \
@@ -711,7 +733,7 @@ FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_DISJ)
711733
/// }
712734
/// since i_01 is a dense dimension.
713735
#define IMPL_MERGER_TEST_OPTIMIZED_CONJ(OP, UNUSED) \
714-
TEST_F(MergerTest3T1LD, vector_opted_##OP) { \
736+
TEST_P(MergerTest3T1LD, vector_opted_##OP) { \
715737
const auto e = OP##Expr(tensor(0), tensor(1)); \
716738
const auto l0 = lid(0); \
717739
const auto t0 = tid(0); \
@@ -746,7 +768,7 @@ FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_CONJ)
746768
/// lat( i_00 / tensor_0 cmp 0 )
747769
/// lat( i_01 / 0 cmp tensor_1 )
748770
/// }
749-
TEST_F(MergerTest3T1L, vector_cmp) {
771+
TEST_P(MergerTest3T1L, vector_cmp) {
750772
const auto e = cmpiExpr(tensor(0), tensor(1));
751773
const auto l0 = lid(0);
752774
const auto t0 = tid(0);
@@ -784,7 +806,7 @@ TEST_F(MergerTest3T1L, vector_cmp) {
784806
///
785807
/// lat( i_00 / sparse_tensor_0 ) should be opted out as it only has dense diff
786808
/// with lat( i_00 i_01 / (sparse_tensor_0 cmp dense_tensor_1) ).
787-
TEST_F(MergerTest3T1LD, vector_cmp) {
809+
TEST_P(MergerTest3T1LD, vector_cmp) {
788810
const auto e = cmpiExpr(tensor(0), tensor(1));
789811
const auto l0 = lid(0);
790812
const auto t0 = tid(0);

0 commit comments

Comments
 (0)