@@ -110,25 +110,47 @@ AffineExpr ValueBoundsConstraintSet::getExpr(Value value,
110110 assertValidValueDim (value, dim);
111111#endif // NDEBUG
112112
113+ // Check if the value/dim is statically known. In that case, an affine
114+ // constant expression should be returned. This allows us to support
115+ // multiplications with constants. (Multiplications of two columns in the
116+ // constraint set is not supported.)
117+ std::optional<int64_t > constSize = std::nullopt ;
113118 auto shapedType = dyn_cast<ShapedType>(value.getType ());
114119 if (shapedType) {
115- // Static dimension: return constant directly.
116120 if (shapedType.hasRank () && !shapedType.isDynamicDim (*dim))
117- return builder.getAffineConstantExpr (shapedType.getDimSize (*dim));
118- } else {
119- // Constant index value: return directly.
120- if (auto constInt = ::getConstantIntValue (value))
121- return builder.getAffineConstantExpr (*constInt);
121+ constSize = shapedType.getDimSize (*dim);
122+ } else if (auto constInt = ::getConstantIntValue (value)) {
123+ constSize = *constInt;
122124 }
123125
124- // Dynamic value: add to constraint set.
126+ // If the value/dim is already mapped, return the corresponding expression
127+ // directly.
125128 ValueDim valueDim = std::make_pair (value, dim.value_or (kIndexValue ));
126- if (!valueDimToPosition.contains (valueDim))
127- (void )insert (value, dim);
128- int64_t pos = getPos (value, dim);
129- return pos < cstr.getNumDimVars ()
130- ? builder.getAffineDimExpr (pos)
131- : builder.getAffineSymbolExpr (pos - cstr.getNumDimVars ());
129+ if (valueDimToPosition.contains (valueDim)) {
130+ // If it is a constant, return an affine constant expression. Otherwise,
131+ // return an affine expression that represents the respective column in the
132+ // constraint set.
133+ if (constSize)
134+ return builder.getAffineConstantExpr (*constSize);
135+ return getPosExpr (getPos (value, dim));
136+ }
137+
138+ if (constSize) {
139+ // Constant index value/dim: add column to the constraint set, add EQ bound
140+ // and return an affine constant expression without pushing the newly added
141+ // column to the worklist.
142+ (void )insert (value, dim, /* isSymbol=*/ true , /* addToWorklist=*/ false );
143+ if (shapedType)
144+ bound (value)[*dim] == *constSize;
145+ else
146+ bound (value) == *constSize;
147+ return builder.getAffineConstantExpr (*constSize);
148+ }
149+
150+ // Dynamic value/dim: insert column to the constraint set and put it on the
151+ // worklist. Return an affine expression that represents the newly inserted
152+ // column in the constraint set.
153+ return getPosExpr (insert (value, dim, /* isSymbol=*/ true ));
132154}
133155
134156AffineExpr ValueBoundsConstraintSet::getExpr (OpFoldResult ofr) {
@@ -145,7 +167,7 @@ AffineExpr ValueBoundsConstraintSet::getExpr(int64_t constant) {
145167
146168int64_t ValueBoundsConstraintSet::insert (Value value,
147169 std::optional<int64_t > dim,
148- bool isSymbol) {
170+ bool isSymbol, bool addToWorklist ) {
149171#ifndef NDEBUG
150172 assertValidValueDim (value, dim);
151173#endif // NDEBUG
@@ -160,7 +182,12 @@ int64_t ValueBoundsConstraintSet::insert(Value value,
160182 if (positionToValueDim[i].has_value ())
161183 valueDimToPosition[*positionToValueDim[i]] = i;
162184
163- worklist.push (pos);
185+ if (addToWorklist) {
186+ LLVM_DEBUG (llvm::dbgs () << " Push to worklist: " << value
187+ << " (dim: " << dim.value_or (kIndexValue ) << " )\n " );
188+ worklist.push (pos);
189+ }
190+
164191 return pos;
165192}
166193
@@ -190,6 +217,13 @@ int64_t ValueBoundsConstraintSet::getPos(Value value,
190217 return it->second ;
191218}
192219
220+ AffineExpr ValueBoundsConstraintSet::getPosExpr (int64_t pos) {
221+ assert (pos >= 0 && pos < cstr.getNumDimAndSymbolVars () && " invalid position" );
222+ return pos < cstr.getNumDimVars ()
223+ ? builder.getAffineDimExpr (pos)
224+ : builder.getAffineSymbolExpr (pos - cstr.getNumDimVars ());
225+ }
226+
193227static Operation *getOwnerOfValue (Value value) {
194228 if (auto bbArg = dyn_cast<BlockArgument>(value))
195229 return bbArg.getOwner ()->getParentOp ();
@@ -492,15 +526,16 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
492526
493527 // Default stop condition if none was specified: Keep adding constraints until
494528 // a bound could be computed.
495- int64_t pos;
529+ int64_t pos = 0 ;
496530 auto defaultStopCondition = [&](Value v, std::optional<int64_t > dim,
497531 ValueBoundsConstraintSet &cstr) {
498532 return cstr.cstr .getConstantBound64 (type, pos).has_value ();
499533 };
500534
501535 ValueBoundsConstraintSet cstr (
502536 map.getContext (), stopCondition ? stopCondition : defaultStopCondition);
503- cstr.populateConstraintsSet (map, operands, &pos);
537+ pos = cstr.populateConstraints (map, operands);
538+ assert (pos == 0 && " expected `map` is the first column" );
504539
505540 // Compute constant bound for `valueDim`.
506541 int64_t ubAdjustment = closedUB ? 0 : 1 ;
@@ -509,29 +544,28 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
509544 return failure ();
510545}
511546
512- int64_t
513- ValueBoundsConstraintSet::populateConstraintsSet (Value value,
514- std::optional<int64_t > dim) {
547+ void ValueBoundsConstraintSet::populateConstraints (Value value,
548+ std::optional<int64_t > dim) {
515549#ifndef NDEBUG
516550 assertValidValueDim (value, dim);
517551#endif // NDEBUG
518552
519- AffineMap map =
520- AffineMap::get (/* dimCount=*/ 1 , /* symbolCount=*/ 0 ,
521- Builder (value.getContext ()).getAffineDimExpr (0 ));
522- return populateConstraintsSet (map, {{value, dim}});
553+ // `getExpr` pushes the value/dim onto the worklist (unless it was already
554+ // analyzed).
555+ (void )getExpr (value, dim);
556+ // Process all values/dims on the worklist. This may traverse and analyze
557+ // additional IR, depending the current stop function.
558+ processWorklist ();
523559}
524560
525- int64_t ValueBoundsConstraintSet::populateConstraintsSet (AffineMap map,
526- ValueDimList operands,
527- int64_t *posOut) {
561+ int64_t ValueBoundsConstraintSet::populateConstraints (AffineMap map,
562+ ValueDimList operands) {
528563 assert (map.getNumResults () == 1 && " expected affine map with one result" );
529564 int64_t pos = insert (/* isSymbol=*/ false );
530- if (posOut)
531- *posOut = pos;
532565
533566 // Add map and operands to the constraint set. Dimensions are converted to
534- // symbols. All operands are added to the worklist.
567+ // symbols. All operands are added to the worklist (unless they were already
568+ // processed).
535569 auto mapper = [&](std::pair<Value, std::optional<int64_t >> v) {
536570 return getExpr (v.first , v.second );
537571 };
@@ -566,6 +600,55 @@ ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2,
566600 {{value1, dim1}, {value2, dim2}});
567601}
568602
603+ bool ValueBoundsConstraintSet::compare (Value lhs, std::optional<int64_t > lhsDim,
604+ ComparisonOperator cmp, Value rhs,
605+ std::optional<int64_t > rhsDim) {
606+ // This function returns "true" if "lhs CMP rhs" is proven to hold.
607+ //
608+ // Example for ComparisonOperator::LE and index-typed values: We would like to
609+ // prove that lhs <= rhs. Proof by contradiction: add the inverse
610+ // relation (lhs > rhs) to the constraint set and check if the resulting
611+ // constraint set is "empty" (i.e. has no solution). In that case,
612+ // lhs > rhs must be incorrect and we can deduce that lhs <= rhs holds.
613+
614+ // We cannot prove anything if the constraint set is already empty.
615+ if (cstr.isEmpty ()) {
616+ LLVM_DEBUG (
617+ llvm::dbgs ()
618+ << " cannot compare value/dims: constraint system is already empty" );
619+ return false ;
620+ }
621+
622+ // EQ can be expressed as LE and GE.
623+ if (cmp == EQ)
624+ return compare (lhs, lhsDim, ComparisonOperator::LE, rhs, rhsDim) &&
625+ compare (lhs, lhsDim, ComparisonOperator::GE, rhs, rhsDim);
626+
627+ // Construct inequality. For the above example: lhs > rhs.
628+ // `IntegerRelation` inequalities are expressed in the "flattened" form and
629+ // with ">= 0". I.e., lhs - rhs - 1 >= 0.
630+ SmallVector<int64_t > eq (cstr.getNumDimAndSymbolVars () + 1 , 0 );
631+ if (cmp == LT || cmp == LE) {
632+ ++eq[getPos (lhs, lhsDim)];
633+ --eq[getPos (rhs, rhsDim)];
634+ } else if (cmp == GT || cmp == GE) {
635+ --eq[getPos (lhs, lhsDim)];
636+ ++eq[getPos (rhs, rhsDim)];
637+ } else {
638+ llvm_unreachable (" unsupported comparison operator" );
639+ }
640+ if (cmp == LE || cmp == GE)
641+ eq[cstr.getNumDimAndSymbolVars ()] -= 1 ;
642+
643+ // Add inequality to the constraint set and check if it made the constraint
644+ // set empty.
645+ int64_t ineqPos = cstr.getNumInequalities ();
646+ cstr.addInequality (eq);
647+ bool isEmpty = cstr.isEmpty ();
648+ cstr.removeInequality (ineqPos);
649+ return isEmpty;
650+ }
651+
569652FailureOr<bool >
570653ValueBoundsConstraintSet::areEqual (Value value1, Value value2,
571654 std::optional<int64_t > dim1,
0 commit comments