@@ -75,6 +75,75 @@ struct MulIOpInterface
7575 }
7676};
7777
78+ struct SelectOpInterface
79+ : public ValueBoundsOpInterface::ExternalModel<SelectOpInterface,
80+ SelectOp> {
81+
82+ static void populateBounds (SelectOp selectOp, std::optional<int64_t > dim,
83+ ValueBoundsConstraintSet &cstr) {
84+ Value value = selectOp.getResult ();
85+ Value condition = selectOp.getCondition ();
86+ Value trueValue = selectOp.getTrueValue ();
87+ Value falseValue = selectOp.getFalseValue ();
88+
89+ if (isa<ShapedType>(condition.getType ())) {
90+ // If the condition is a shaped type, the condition is applied
91+ // element-wise. All three operands must have the same shape.
92+ cstr.bound (value)[*dim] == cstr.getExpr (trueValue, dim);
93+ cstr.bound (value)[*dim] == cstr.getExpr (falseValue, dim);
94+ cstr.bound (value)[*dim] == cstr.getExpr (condition, dim);
95+ return ;
96+ }
97+
98+ // Populate constraints for the true/false values (and all values on the
99+ // backward slice, as long as the current stop condition is not satisfied).
100+ cstr.populateConstraints (trueValue, dim);
101+ cstr.populateConstraints (falseValue, dim);
102+ auto boundsBuilder = cstr.bound (value);
103+ if (dim)
104+ boundsBuilder[*dim];
105+
106+ // Compare yielded values.
107+ // If trueValue <= falseValue:
108+ // * result <= falseValue
109+ // * result >= trueValue
110+ if (cstr.compare (trueValue, dim,
111+ ValueBoundsConstraintSet::ComparisonOperator::LE,
112+ falseValue, dim)) {
113+ if (dim) {
114+ cstr.bound (value)[*dim] >= cstr.getExpr (trueValue, dim);
115+ cstr.bound (value)[*dim] <= cstr.getExpr (falseValue, dim);
116+ } else {
117+ cstr.bound (value) >= trueValue;
118+ cstr.bound (value) <= falseValue;
119+ }
120+ }
121+ // If falseValue <= trueValue:
122+ // * result <= trueValue
123+ // * result >= falseValue
124+ if (cstr.compare (falseValue, dim,
125+ ValueBoundsConstraintSet::ComparisonOperator::LE,
126+ trueValue, dim)) {
127+ if (dim) {
128+ cstr.bound (value)[*dim] >= cstr.getExpr (falseValue, dim);
129+ cstr.bound (value)[*dim] <= cstr.getExpr (trueValue, dim);
130+ } else {
131+ cstr.bound (value) >= falseValue;
132+ cstr.bound (value) <= trueValue;
133+ }
134+ }
135+ }
136+
137+ void populateBoundsForIndexValue (Operation *op, Value value,
138+ ValueBoundsConstraintSet &cstr) const {
139+ populateBounds (cast<SelectOp>(op), /* dim=*/ std::nullopt , cstr);
140+ }
141+
142+ void populateBoundsForShapedValueDim (Operation *op, Value value, int64_t dim,
143+ ValueBoundsConstraintSet &cstr) const {
144+ populateBounds (cast<SelectOp>(op), dim, cstr);
145+ }
146+ };
78147} // namespace
79148} // namespace arith
80149} // namespace mlir
@@ -86,5 +155,6 @@ void mlir::arith::registerValueBoundsOpInterfaceExternalModels(
86155 arith::ConstantOp::attachInterface<arith::ConstantOpInterface>(*ctx);
87156 arith::SubIOp::attachInterface<arith::SubIOpInterface>(*ctx);
88157 arith::MulIOp::attachInterface<arith::MulIOpInterface>(*ctx);
158+ arith::SelectOp::attachInterface<arith::SelectOpInterface>(*ctx);
89159 });
90160}
0 commit comments