Skip to content

Commit 656f6a2

Browse files
author
wangzhenhua
committed
fix boolean type for in set condition
1 parent 4ef05e7 commit 656f6a2

File tree

4 files changed

+41
-40
lines changed

4 files changed

+41
-40
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import scala.math.BigDecimal.RoundingMode
2222
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
2323
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics}
2424
import org.apache.spark.sql.internal.SQLConf
25-
import org.apache.spark.sql.types.{DataType, StringType}
25+
import org.apache.spark.sql.types.{DecimalType, _}
2626

2727

2828
object EstimationUtils {
@@ -75,4 +75,26 @@ object EstimationUtils {
7575
// (simple computation of statistics returns product of children).
7676
if (outputRowCount > 0) outputRowCount * sizePerRow else 1
7777
}
78+
79+
/**
80+
* For simplicity we use Decimal to unify operations for data types whose min/max values can be
81+
* represented as numbers, e.g. Boolean can be represented as 0 (false) or 1 (true).
82+
* The two methods below are the contract of conversion.
83+
*/
84+
def toDecimal(value: Any, dataType: DataType): Decimal = {
85+
dataType match {
86+
case _: NumericType | DateType | TimestampType => Decimal(value.toString)
87+
case BooleanType => if (value.asInstanceOf[Boolean]) Decimal(1) else Decimal(0)
88+
}
89+
}
90+
91+
def fromDecimal(dec: Decimal, dataType: DataType): Any = {
92+
dataType match {
93+
case _: IntegralType | DateType | TimestampType => dec.toLong
94+
case FloatType | DoubleType => dec.toDouble
95+
case _: DecimalType => dec
96+
case BooleanType => dec.toLong == 1
97+
}
98+
}
99+
78100
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -409,8 +409,8 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
409409
return Some(0.0)
410410
}
411411

412-
val newMax = validQuerySet.maxBy(v => BigDecimal(v.toString))
413-
val newMin = validQuerySet.minBy(v => BigDecimal(v.toString))
412+
val newMax = validQuerySet.maxBy(EstimationUtils.toDecimal(_, dataType))
413+
val newMin = validQuerySet.minBy(EstimationUtils.toDecimal(_, dataType))
414414
// newNdv should not be greater than the old ndv. For example, column has only 2 values
415415
// 1 and 6, but the predicate is column IN (1, 2, 3, 4, 5), so validQuerySet.size is 5 while ndv is 2.
416416
newNdv = ndv.min(BigInt(validQuerySet.size))

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala

Lines changed: 8 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,10 @@ object Range {
5454
def apply(min: Option[Any], max: Option[Any], dataType: DataType): Range = dataType match {
5555
case StringType | BinaryType => new DefaultRange()
5656
case _ if min.isEmpty || max.isEmpty => new NullRange()
57-
case _ => toNumericRange(min.get, max.get, dataType)
57+
case _ =>
58+
NumericRange(
59+
min = EstimationUtils.toDecimal(min.get, dataType),
60+
max = EstimationUtils.toDecimal(max.get, dataType))
5861
}
5962

6063
def isIntersected(r1: Range, r2: Range): Boolean = (r1, r2) match {
@@ -79,40 +82,10 @@ object Range {
7982
(None, None)
8083
case (n1: NumericRange, n2: NumericRange) =>
8184
// Choose the maximum of two min values, and the minimum of two max values.
82-
val newRange = NumericRange(
83-
min = if (n1.min <= n2.min) n2.min else n1.min,
84-
max = if (n1.max <= n2.max) n1.max else n2.max)
85-
val (newMin, newMax) = fromNumericRange(newRange, dt)
86-
(Some(newMin), Some(newMax))
85+
val newMin = if (n1.min <= n2.min) n2.min else n1.min
86+
val newMax = if (n1.max <= n2.max) n1.max else n2.max
87+
(Some(EstimationUtils.fromDecimal(newMin, dt)),
88+
Some(EstimationUtils.fromDecimal(newMax, dt)))
8789
}
8890
}
89-
90-
/**
91-
* For simplicity we use decimal to unify operations of numeric types, the two methods below
92-
* are the contract of conversion.
93-
*/
94-
private def toNumericRange(min: Any, max: Any, dataType: DataType): NumericRange = {
95-
dataType match {
96-
case _: NumericType | DateType | TimestampType =>
97-
NumericRange(Decimal(min.toString), Decimal(max.toString))
98-
case BooleanType =>
99-
val min1 = if (min.asInstanceOf[Boolean]) 1 else 0
100-
val max1 = if (max.asInstanceOf[Boolean]) 1 else 0
101-
NumericRange(Decimal(min1), Decimal(max1))
102-
}
103-
}
104-
105-
private def fromNumericRange(n: NumericRange, dataType: DataType): (Any, Any) = {
106-
dataType match {
107-
case _: IntegralType | DateType | TimestampType =>
108-
(n.min.toLong, n.max.toLong)
109-
case FloatType | DoubleType =>
110-
(n.min.toDouble, n.max.toDouble)
111-
case _: DecimalType =>
112-
(n.min, n.max)
113-
case BooleanType =>
114-
(n.min.toLong == 1, n.max.toLong == 1)
115-
}
116-
}
117-
11891
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,6 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
148148

149149
test("cint < 3 OR null") {
150150
val condition = Or(LessThan(attrInt, Literal(3)), Literal(null, IntegerType))
151-
val m = Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)).stats(conf)
152151
validateEstimatedStats(
153152
Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)),
154153
Seq(attrInt -> colStatInt),
@@ -342,6 +341,14 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
342341
expectedRowCount = 7)
343342
}
344343

344+
test("cbool IN (true)") {
345+
validateEstimatedStats(
346+
Filter(InSet(attrBool, Set(true)), childStatsTestPlan(Seq(attrBool), 10L)),
347+
Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true),
348+
nullCount = 0, avgLen = 1, maxLen = 1)),
349+
expectedRowCount = 5)
350+
}
351+
345352
test("cbool = true") {
346353
validateEstimatedStats(
347354
Filter(EqualTo(attrBool, Literal(true)), childStatsTestPlan(Seq(attrBool), 10L)),
@@ -533,7 +540,6 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
533540

534541
test("cint = cint3") {
535542
// no records qualify due to no overlap
536-
val emptyColStats = Seq[(Attribute, ColumnStat)]()
537543
validateEstimatedStats(
538544
Filter(EqualTo(attrInt, attrInt3), childStatsTestPlan(Seq(attrInt, attrInt3), 10L)),
539545
Nil, // set to empty

0 commit comments

Comments
 (0)