From 0bc64ab802e82e51650ac3400a9e6ccdd6e8ab79 Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Thu, 27 Feb 2025 08:59:25 +0530 Subject: [PATCH] [MLIR] NFC. Improve API signature + clang-tidy warning in IntegerRelation --- .../Analysis/Presburger/IntegerRelation.h | 10 +-- .../Analysis/FlatLinearValueConstraints.cpp | 10 +-- .../Analysis/Presburger/IntegerRelation.cpp | 68 +++++++++---------- 3 files changed, 41 insertions(+), 47 deletions(-) diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h index ddc18038e869c..fa29ac23af607 100644 --- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h @@ -738,11 +738,11 @@ class IntegerRelation { /// Same as findSymbolicIntegerLexMin but produces lexmax instead of lexmin SymbolicLexOpt findSymbolicIntegerLexMax() const; - /// Searches for a constraint with a non-zero coefficient at `colIdx` in - /// equality (isEq=true) or inequality (isEq=false) constraints. - /// Returns true and sets row found in search in `rowIdx`, false otherwise. - bool findConstraintWithNonZeroAt(unsigned colIdx, bool isEq, - unsigned *rowIdx) const; + /// Finds a constraint with a non-zero coefficient at `colIdx` in equality + /// (isEq=true) or inequality (isEq=false) constraints. Returns the position + /// of the row if it was found or none otherwise. + std::optional findConstraintWithNonZeroAt(unsigned colIdx, + bool isEq) const; /// Return the set difference of this set and the given set, i.e., /// return `this \ set`. diff --git a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp index 4653eca9887ce..8c179cb2a38ba 100644 --- a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp +++ b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp @@ -635,8 +635,8 @@ static void computeUnknownVars(const FlatLinearConstraints &cst, } // Detect a variable as an expression of other variables. - unsigned idx; - if (!cst.findConstraintWithNonZeroAt(pos, /*isEq=*/true, &idx)) { + std::optional idx; + if (!(idx = cst.findConstraintWithNonZeroAt(pos, /*isEq=*/true))) { continue; } @@ -646,7 +646,7 @@ static void computeUnknownVars(const FlatLinearConstraints &cst, for (j = 0, e = cst.getNumVars(); j < e; ++j) { if (j == pos) continue; - int64_t c = cst.atEq64(idx, j); + int64_t c = cst.atEq64(*idx, j); if (c == 0) continue; // If any of the involved IDs hasn't been found yet, we can't proceed. @@ -660,8 +660,8 @@ static void computeUnknownVars(const FlatLinearConstraints &cst, continue; // Add constant term to AffineExpr. - expr = expr + cst.atEq64(idx, cst.getNumVars()); - int64_t vPos = cst.atEq64(idx, pos); + expr = expr + cst.atEq64(*idx, cst.getNumVars()); + int64_t vPos = cst.atEq64(*idx, pos); assert(vPos != 0 && "expected non-zero here"); if (vPos > 0) expr = (-expr).floorDiv(vPos); diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp index 74cdf567c0e56..5de3fd920e4e0 100644 --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -564,22 +564,18 @@ void IntegerRelation::clearAndCopyFrom(const IntegerRelation &other) { *this = other; } -// Searches for a constraint with a non-zero coefficient at `colIdx` in -// equality (isEq=true) or inequality (isEq=false) constraints. -// Returns true and sets row found in search in `rowIdx`, false otherwise. -bool IntegerRelation::findConstraintWithNonZeroAt(unsigned colIdx, bool isEq, - unsigned *rowIdx) const { +std::optional +IntegerRelation::findConstraintWithNonZeroAt(unsigned colIdx, bool isEq) const { assert(colIdx < getNumCols() && "position out of bounds"); auto at = [&](unsigned rowIdx) -> DynamicAPInt { return isEq ? atEq(rowIdx, colIdx) : atIneq(rowIdx, colIdx); }; unsigned e = isEq ? getNumEqualities() : getNumInequalities(); - for (*rowIdx = 0; *rowIdx < e; ++(*rowIdx)) { - if (at(*rowIdx) != 0) { - return true; - } + for (unsigned rowIdx = 0; rowIdx < e; ++rowIdx) { + if (at(rowIdx) != 0) + return rowIdx; } - return false; + return std::nullopt; } void IntegerRelation::normalizeConstraintsByGCD() { @@ -1088,31 +1084,30 @@ unsigned IntegerRelation::gaussianEliminateVars(unsigned posStart, unsigned pivotCol = 0; for (pivotCol = posStart; pivotCol < posLimit; ++pivotCol) { // Find a row which has a non-zero coefficient in column 'j'. - unsigned pivotRow; - if (!findConstraintWithNonZeroAt(pivotCol, /*isEq=*/true, &pivotRow)) { - // No pivot row in equalities with non-zero at 'pivotCol'. - if (!findConstraintWithNonZeroAt(pivotCol, /*isEq=*/false, &pivotRow)) { - // If inequalities are also non-zero in 'pivotCol', it can be - // eliminated. - continue; - } - break; + std::optional pivotRow = + findConstraintWithNonZeroAt(pivotCol, /*isEq=*/true); + // No pivot row in equalities with non-zero at 'pivotCol'. + if (!pivotRow) { + // If inequalities are also non-zero in 'pivotCol', it can be eliminated. + if ((pivotRow = findConstraintWithNonZeroAt(pivotCol, /*isEq=*/false))) + break; + continue; } // Eliminate variable at 'pivotCol' from each equality row. for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { - eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart, + eliminateFromConstraint(this, i, *pivotRow, pivotCol, posStart, /*isEq=*/true); equalities.normalizeRow(i); } // Eliminate variable at 'pivotCol' from each inequality row. for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { - eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart, + eliminateFromConstraint(this, i, *pivotRow, pivotCol, posStart, /*isEq=*/false); inequalities.normalizeRow(i); } - removeEquality(pivotRow); + removeEquality(*pivotRow); gcdTightenInequalities(); } // Update position limit based on number eliminated. @@ -1125,31 +1120,31 @@ unsigned IntegerRelation::gaussianEliminateVars(unsigned posStart, bool IntegerRelation::gaussianEliminate() { gcdTightenInequalities(); unsigned firstVar = 0, vars = getNumVars(); - unsigned nowDone, eqs, pivotRow; + unsigned nowDone, eqs; + std::optional pivotRow; for (nowDone = 0, eqs = getNumEqualities(); nowDone < eqs; ++nowDone) { // Finds the first non-empty column. for (; firstVar < vars; ++firstVar) { - if (!findConstraintWithNonZeroAt(firstVar, true, &pivotRow)) - continue; - break; + if ((pivotRow = findConstraintWithNonZeroAt(firstVar, /*isEq=*/true))) + break; } // The matrix has been normalized to row echelon form. if (firstVar >= vars) break; // The first pivot row found is below where it should currently be placed. - if (pivotRow > nowDone) { - equalities.swapRows(pivotRow, nowDone); - pivotRow = nowDone; + if (*pivotRow > nowDone) { + equalities.swapRows(*pivotRow, nowDone); + *pivotRow = nowDone; } // Normalize all lower equations and all inequalities. for (unsigned i = nowDone + 1; i < eqs; ++i) { - eliminateFromConstraint(this, i, pivotRow, firstVar, 0, true); + eliminateFromConstraint(this, i, *pivotRow, firstVar, 0, true); equalities.normalizeRow(i); } for (unsigned i = 0, ineqs = getNumInequalities(); i < ineqs; ++i) { - eliminateFromConstraint(this, i, pivotRow, firstVar, 0, false); + eliminateFromConstraint(this, i, *pivotRow, firstVar, 0, false); inequalities.normalizeRow(i); } gcdTightenInequalities(); @@ -2290,9 +2285,8 @@ IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) { } bool IntegerRelation::isColZero(unsigned pos) const { - unsigned rowPos; - return !findConstraintWithNonZeroAt(pos, /*isEq=*/false, &rowPos) && - !findConstraintWithNonZeroAt(pos, /*isEq=*/true, &rowPos); + return !findConstraintWithNonZeroAt(pos, /*isEq=*/false) && + !findConstraintWithNonZeroAt(pos, /*isEq=*/true); } /// Find positions of inequalities and equalities that do not have a coefficient @@ -2600,16 +2594,16 @@ void IntegerRelation::print(raw_ostream &os) const { for (unsigned j = 0, f = getNumCols(); j < f; ++j) updatePrintMetrics(atIneq(i, j), ptm); // Print using PrintMetrics. - unsigned MIN_SPACING = 1; + constexpr unsigned kMinSpacing = 1; for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { for (unsigned j = 0, f = getNumCols(); j < f; ++j) { - printWithPrintMetrics(os, atEq(i, j), MIN_SPACING, ptm); + printWithPrintMetrics(os, atEq(i, j), kMinSpacing, ptm); } os << " = 0\n"; } for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { for (unsigned j = 0, f = getNumCols(); j < f; ++j) { - printWithPrintMetrics(os, atIneq(i, j), MIN_SPACING, ptm); + printWithPrintMetrics(os, atIneq(i, j), kMinSpacing, ptm); } os << " >= 0\n"; }