@@ -1356,50 +1356,54 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
13561356 // See buildLattices() for an explanation of rejecting certain
13571357 // division and shift operations.
13581358 if (def->getNumOperands () == 2 ) {
1359- const auto [x, xDepSp] = buildTensorExp (op, def->getOperand (0 ));
1360- const auto [y, yDepSp] = buildTensorExp (op, def->getOperand (1 ));
1361- bool hasSpDep = xDepSp || yDepSp;
1359+ const auto [x, xSpVals] = buildTensorExp (op, def->getOperand (0 ));
1360+ const auto [y, ySpVals] = buildTensorExp (op, def->getOperand (1 ));
1361+ // For a conjunctive operation, it yields a "sparse" result if any operand
1362+ // is sparse. For a disjunctive operation, it yields a "sparse" result if
1363+ // all operands are sparse.
1364+ bool conjSpVals = xSpVals || ySpVals;
1365+ bool disjSpVals = xSpVals && ySpVals;
13621366 if (x.has_value () && y.has_value ()) {
13631367 const ExprId e0 = *x;
13641368 const ExprId e1 = *y;
13651369 if (isa<arith::MulFOp>(def))
1366- return {addExp (TensorExp::Kind::kMulF , e0 , e1 ), hasSpDep };
1370+ return {addExp (TensorExp::Kind::kMulF , e0 , e1 ), conjSpVals };
13671371 if (isa<complex ::MulOp>(def))
1368- return {addExp (TensorExp::Kind::kMulC , e0 , e1 ), hasSpDep };
1372+ return {addExp (TensorExp::Kind::kMulC , e0 , e1 ), conjSpVals };
13691373 if (isa<arith::MulIOp>(def))
1370- return {addExp (TensorExp::Kind::kMulI , e0 , e1 ), hasSpDep };
1374+ return {addExp (TensorExp::Kind::kMulI , e0 , e1 ), conjSpVals };
13711375 if (isa<arith::DivFOp>(def) && !maybeZero (e1 ))
1372- return {addExp (TensorExp::Kind::kDivF , e0 , e1 ), hasSpDep };
1376+ return {addExp (TensorExp::Kind::kDivF , e0 , e1 ), conjSpVals };
13731377 if (isa<complex ::DivOp>(def) && !maybeZero (e1 ))
1374- return {addExp (TensorExp::Kind::kDivC , e0 , e1 ), hasSpDep };
1378+ return {addExp (TensorExp::Kind::kDivC , e0 , e1 ), conjSpVals };
13751379 if (isa<arith::DivSIOp>(def) && !maybeZero (e1 ))
1376- return {addExp (TensorExp::Kind::kDivS , e0 , e1 ), hasSpDep };
1380+ return {addExp (TensorExp::Kind::kDivS , e0 , e1 ), conjSpVals };
13771381 if (isa<arith::DivUIOp>(def) && !maybeZero (e1 ))
1378- return {addExp (TensorExp::Kind::kDivU , e0 , e1 ), hasSpDep };
1382+ return {addExp (TensorExp::Kind::kDivU , e0 , e1 ), conjSpVals };
13791383 if (isa<arith::AddFOp>(def))
1380- return {addExp (TensorExp::Kind::kAddF , e0 , e1 ), hasSpDep };
1384+ return {addExp (TensorExp::Kind::kAddF , e0 , e1 ), disjSpVals };
13811385 if (isa<complex ::AddOp>(def))
1382- return {addExp (TensorExp::Kind::kAddC , e0 , e1 ), hasSpDep };
1386+ return {addExp (TensorExp::Kind::kAddC , e0 , e1 ), disjSpVals };
13831387 if (isa<arith::AddIOp>(def))
1384- return {addExp (TensorExp::Kind::kAddI , e0 , e1 ), hasSpDep };
1388+ return {addExp (TensorExp::Kind::kAddI , e0 , e1 ), disjSpVals };
13851389 if (isa<arith::SubFOp>(def))
1386- return {addExp (TensorExp::Kind::kSubF , e0 , e1 ), hasSpDep };
1390+ return {addExp (TensorExp::Kind::kSubF , e0 , e1 ), disjSpVals };
13871391 if (isa<complex ::SubOp>(def))
1388- return {addExp (TensorExp::Kind::kSubC , e0 , e1 ), hasSpDep };
1392+ return {addExp (TensorExp::Kind::kSubC , e0 , e1 ), disjSpVals };
13891393 if (isa<arith::SubIOp>(def))
1390- return {addExp (TensorExp::Kind::kSubI , e0 , e1 ), hasSpDep };
1394+ return {addExp (TensorExp::Kind::kSubI , e0 , e1 ), disjSpVals };
13911395 if (isa<arith::AndIOp>(def))
1392- return {addExp (TensorExp::Kind::kAndI , e0 , e1 ), hasSpDep };
1396+ return {addExp (TensorExp::Kind::kAndI , e0 , e1 ), conjSpVals };
13931397 if (isa<arith::OrIOp>(def))
1394- return {addExp (TensorExp::Kind::kOrI , e0 , e1 ), hasSpDep };
1398+ return {addExp (TensorExp::Kind::kOrI , e0 , e1 ), disjSpVals };
13951399 if (isa<arith::XOrIOp>(def))
1396- return {addExp (TensorExp::Kind::kXorI , e0 , e1 ), hasSpDep };
1400+ return {addExp (TensorExp::Kind::kXorI , e0 , e1 ), disjSpVals };
13971401 if (isa<arith::ShRSIOp>(def) && isInvariant (e1 ))
1398- return {addExp (TensorExp::Kind::kShrS , e0 , e1 ), hasSpDep };
1402+ return {addExp (TensorExp::Kind::kShrS , e0 , e1 ), conjSpVals };
13991403 if (isa<arith::ShRUIOp>(def) && isInvariant (e1 ))
1400- return {addExp (TensorExp::Kind::kShrU , e0 , e1 ), hasSpDep };
1404+ return {addExp (TensorExp::Kind::kShrU , e0 , e1 ), conjSpVals };
14011405 if (isa<arith::ShLIOp>(def) && isInvariant (e1 ))
1402- return {addExp (TensorExp::Kind::kShlI , e0 , e1 ), hasSpDep };
1406+ return {addExp (TensorExp::Kind::kShlI , e0 , e1 ), conjSpVals };
14031407 if (auto ci = dyn_cast<arith::CmpIOp>(def)) {
14041408 if (ci.getPredicate () == arith::CmpIPredicate::eq &&
14051409 ci.getPredicate () == arith::CmpIPredicate::sle &&
@@ -1413,7 +1417,7 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
14131417
14141418 auto e = addExp (TensorExp::Kind::kCmpI , e0 , e1 , nullptr ,
14151419 ci.getPredicateAttr ());
1416- return {e, hasSpDep };
1420+ return {e, conjSpVals };
14171421 }
14181422 if (auto cf = dyn_cast<arith::CmpFOp>(def)) {
14191423 if (cf.getPredicate () == arith::CmpFPredicate::OEQ &&
@@ -1431,15 +1435,15 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
14311435 }
14321436 auto e = addExp (TensorExp::Kind::kCmpF , e0 , e1 , nullptr ,
14331437 cf.getPredicateAttr ());
1434- return {e, hasSpDep };
1438+ return {e, conjSpVals };
14351439 }
14361440 if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
14371441 if (isAdmissibleBranch (binop, binop.getOverlapRegion ()) &&
14381442 (binop.getLeftIdentity () ||
14391443 isAdmissibleBranch (binop, binop.getLeftRegion ())) &&
14401444 (binop.getRightIdentity () ||
14411445 isAdmissibleBranch (binop, binop.getRightRegion ())))
1442- return {addExp (TensorExp::Kind::kBinary , e0 , e1 , def), hasSpDep };
1446+ return {addExp (TensorExp::Kind::kBinary , e0 , e1 , def), conjSpVals };
14431447 }
14441448 }
14451449 }
0 commit comments