Skip to content

Commit 647a23f

Browse files
committed
constant folding of IntegralType on binaryComparison
1 parent d0b1891 commit 647a23f

File tree

2 files changed

+130
-0
lines changed

2 files changed

+130
-0
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,102 @@ object ConstantFolding extends Rule[LogicalPlan] {
382382
case Literal(candidate, _) if candidate == v => true
383383
case _ => false
384384
} => Literal.create(true, BooleanType)
385+
386+
case EqualTo(c @ Cast(a: Attribute, _), Literal(v, _)) if isUpCastingIntegral(c) &&
387+
(v.asInstanceOf[Number].longValue < minValue(a.dataType) ||
388+
v.asInstanceOf[Number].longValue > maxValue(a.dataType)) =>
389+
Literal.create(false, BooleanType)
390+
case EqualTo(Literal(v, _), c @ Cast(a: Attribute, _)) if isUpCastingIntegral(c) &&
391+
(v.asInstanceOf[Number].longValue < minValue(a.dataType) ||
392+
v.asInstanceOf[Number].longValue > maxValue(a.dataType)) =>
393+
Literal.create(false, BooleanType)
394+
395+
case EqualNullSafe(c @ Cast(a: Attribute, _), Literal(v, _)) if isUpCastingIntegral(c) &&
396+
(v.asInstanceOf[Number].longValue < minValue(a.dataType) ||
397+
v.asInstanceOf[Number].longValue > maxValue(a.dataType)) =>
398+
Literal.create(false, BooleanType)
399+
case EqualNullSafe(Literal(v, _), c @ Cast(a: Attribute, _)) if isUpCastingIntegral(c) &&
400+
(v.asInstanceOf[Number].longValue < minValue(a.dataType) ||
401+
v.asInstanceOf[Number].longValue > maxValue(a.dataType)) =>
402+
Literal.create(false, BooleanType)
403+
404+
case GreaterThan(c @ Cast(a: Attribute, _), Literal(v, _)) if isUpCastingIntegral(c) &&
405+
v.asInstanceOf[Number].longValue < minValue(a.dataType) =>
406+
Literal.create(true, BooleanType)
407+
case GreaterThan(c @ Cast(a: Attribute, _), Literal(v, _)) if isUpCastingIntegral(c) &&
408+
v.asInstanceOf[Number].longValue >= maxValue(a.dataType) =>
409+
Literal.create(false, BooleanType)
410+
case GreaterThan(Literal(v, _), c @ Cast(a: Attribute, _)) if isUpCastingIntegral(c) &&
411+
v.asInstanceOf[Number].longValue <= minValue(a.dataType) =>
412+
Literal.create(false, BooleanType)
413+
case GreaterThan(Literal(v, _), c @ Cast(a: Attribute, _)) if isUpCastingIntegral(c) &&
414+
v.asInstanceOf[Number].longValue > maxValue(a.dataType) =>
415+
Literal.create(true, BooleanType)
416+
417+
case LessThan(c @ Cast(a: Attribute, _), Literal(v, _)) if isUpCastingIntegral(c) &&
418+
v.asInstanceOf[Number].longValue <= minValue(a.dataType) =>
419+
Literal.create(false, BooleanType)
420+
case LessThan(c @ Cast(a: Attribute, _), Literal(v, _)) if isUpCastingIntegral(c) &&
421+
v.asInstanceOf[Number].longValue > maxValue(a.dataType) =>
422+
Literal.create(true, BooleanType)
423+
case LessThan(Literal(v, _), c @ Cast(a: Attribute, _)) if isUpCastingIntegral(c) &&
424+
v.asInstanceOf[Number].longValue < minValue(a.dataType) =>
425+
Literal.create(true, BooleanType)
426+
case LessThan(Literal(v, _), c @ Cast(a: Attribute, _)) if isUpCastingIntegral(c) &&
427+
v.asInstanceOf[Number].longValue >= maxValue(a.dataType) =>
428+
Literal.create(false, BooleanType)
429+
430+
case GreaterThanOrEqual(c @ Cast(a: Attribute, _), Literal(v, _))
431+
if isUpCastingIntegral(c) && v.asInstanceOf[Number].longValue <= minValue(a.dataType) =>
432+
Literal.create(true, BooleanType)
433+
case GreaterThanOrEqual(c @ Cast(a: Attribute, _), Literal(v, _))
434+
if isUpCastingIntegral(c) && v.asInstanceOf[Number].longValue > maxValue(a.dataType) =>
435+
Literal.create(false, BooleanType)
436+
case GreaterThanOrEqual(Literal(v, _), c @ Cast(a: Attribute, _))
437+
if isUpCastingIntegral(c) && v.asInstanceOf[Number].longValue < minValue(a.dataType) =>
438+
Literal.create(false, BooleanType)
439+
case GreaterThanOrEqual(Literal(v, _), c @ Cast(a: Attribute, _))
440+
if isUpCastingIntegral(c) && v.asInstanceOf[Number].longValue >= maxValue(a.dataType) =>
441+
Literal.create(true, BooleanType)
442+
443+
case LessThanOrEqual(c @ Cast(a: Attribute, _), Literal(v, _)) if isUpCastingIntegral(c) &&
444+
v.asInstanceOf[Number].longValue < minValue(a.dataType) =>
445+
Literal.create(false, BooleanType)
446+
case LessThanOrEqual(c @ Cast(a: Attribute, _), Literal(v, _)) if isUpCastingIntegral(c) &&
447+
v.asInstanceOf[Number].longValue >= maxValue(a.dataType) =>
448+
Literal.create(true, BooleanType)
449+
case LessThanOrEqual(Literal(v, _), c @ Cast(a: Attribute, _)) if isUpCastingIntegral(c) &&
450+
v.asInstanceOf[Number].longValue <= minValue(a.dataType) =>
451+
Literal.create(true, BooleanType)
452+
case LessThanOrEqual(Literal(v, _), c @ Cast(a: Attribute, _)) if isUpCastingIntegral(c) &&
453+
v.asInstanceOf[Number].longValue > maxValue(a.dataType) =>
454+
Literal.create(false, BooleanType)
455+
}
456+
}
457+
458+
private val integralPrecedence = Seq(ByteType, ShortType, IntegerType, LongType)
459+
460+
private def isUpCastingIntegral(c: Cast): Boolean = {
461+
(c.child.dataType, c.dataType) match {
462+
case (from: IntegralType, to: IntegralType)
463+
if integralPrecedence.indexOf(from) < integralPrecedence.indexOf(to) => true
464+
case _ => false
465+
}
466+
}
467+
468+
private def maxValue(dataType: DataType): Long = {
469+
dataType match {
470+
case ByteType => Byte.MaxValue.toLong
471+
case ShortType => Short.MaxValue.toLong
472+
case IntegerType => Int.MaxValue.toLong
473+
}
474+
}
475+
476+
private def minValue(dataType: DataType): Long = {
477+
dataType match {
478+
case ByteType => Byte.MinValue.toLong
479+
case ShortType => Short.MinValue.toLong
480+
case IntegerType => Int.MinValue.toLong
385481
}
386482
}
387483
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,4 +280,38 @@ class ConstantFoldingSuite extends PlanTest {
280280

281281
comparePlans(optimized, correctAnswer)
282282
}
283+
284+
test("binary comparison folding") {
285+
val trueQuery = testRelation.select(Literal(true).as("r"))
286+
val falseQuery = testRelation.select(Literal(false).as("r"))
287+
def checkComparisonFolding(l: LogicalPlan, expected: Boolean): Unit = {
288+
val optimized = Optimize.execute(l.analyze)
289+
if (expected) {
290+
comparePlans(optimized, trueQuery)
291+
} else {
292+
comparePlans(optimized, falseQuery)
293+
}
294+
}
295+
296+
checkComparisonFolding(
297+
testRelation.select(EqualTo('a, Int.MaxValue.toLong + 1L).as("r")), false)
298+
checkComparisonFolding(
299+
testRelation.select(EqualTo('a, Int.MinValue.toLong - 1L).as("r")), false)
300+
checkComparisonFolding(
301+
testRelation.select(LessThan('a, Int.MaxValue.toLong + 1L).as("r")), true)
302+
checkComparisonFolding(
303+
testRelation.select(LessThan('a, Int.MinValue.toLong - 1L).as("r")), false)
304+
checkComparisonFolding(
305+
testRelation.select(GreaterThan('a, Int.MaxValue.toLong + 1L).as("r")), false)
306+
checkComparisonFolding(
307+
testRelation.select(GreaterThan('a, Int.MinValue.toLong - 1L).as("r")), true)
308+
checkComparisonFolding(
309+
testRelation.select(LessThanOrEqual('a, Int.MaxValue.toLong).as("r")), true)
310+
checkComparisonFolding(
311+
testRelation.select(LessThanOrEqual('a, Int.MinValue.toLong - 1L).as("r")), false)
312+
checkComparisonFolding(
313+
testRelation.select(GreaterThanOrEqual('a, Int.MaxValue.toLong + 1L).as("r")), false)
314+
checkComparisonFolding(
315+
testRelation.select(GreaterThanOrEqual('a, Int.MinValue.toLong).as("r")), true)
316+
}
283317
}

0 commit comments

Comments
 (0)