Skip to content

Commit 2077888

Browse files
committed
codegen versioned eval
1 parent 6cd9a64 commit 2077888

File tree

2 files changed

+89
-2
lines changed

2 files changed

+89
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -572,8 +572,8 @@ case class Round(child: Expression, scale: Expression) extends Expression {
572572

573573
if (evalE == null || scaleV == null) return null
574574

575-
children(0).dataType match {
576-
case decimalType: DecimalType =>
575+
child.dataType match {
576+
case _: DecimalType =>
577577
val decimal = evalE.asInstanceOf[Decimal]
578578
if (decimal.changePrecision(decimal.precision, _scale)) decimal else null
579579
case ByteType =>
@@ -595,6 +595,84 @@ case class Round(child: Expression, scale: Expression) extends Expression {
595595
}
596596
}
597597

598+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
599+
val ce = child.gen(ctx)
600+
601+
def integralRound(primitive: String): String = {
602+
s"""
603+
${ev.primitive} = new java.math.BigDecimal(${primitive}).
604+
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP)"""
605+
}
606+
607+
def fractionalRound(primitive: String): String = {
608+
s"""
609+
${ev.primitive} = java.math.BigDecimal.valueOf(${primitive}).
610+
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP)"""
611+
}
612+
613+
def check(primitive: String, function: String): String = {
614+
s"""
615+
if (Double.isNaN(${primitive}) || Double.isInfinite(${primitive})){
616+
${ev.primitive} = ${primitive};
617+
} else {
618+
${fractionalRound(primitive)}.${function};
619+
}"""
620+
}
621+
622+
def convert(primitive: String): String = {
623+
val dName = ctx.freshName("converter")
624+
s"""
625+
Double $dName = 0.0;
626+
try {
627+
$dName = Double.valueOf(${primitive}.toString());
628+
} catch (NumberFormatException e) {
629+
${ev.isNull} = true;
630+
}
631+
${check(dName, "doubleValue()")}
632+
"""
633+
}
634+
635+
def decimalRound(): String = {
636+
s"""
637+
if (${ce.primitive}.changePrecision(${ce.primitive}.precision(), ${_scale})) {
638+
${ev.primitive} = ${ce.primitive};
639+
} else {
640+
${ev.isNull} = true;
641+
}
642+
"""
643+
}
644+
645+
val roundCode = child.dataType match {
646+
case NullType => ";"
647+
case _: DecimalType =>
648+
decimalRound()
649+
case ByteType =>
650+
integralRound(ce.primitive) + ".byteValue();"
651+
case ShortType =>
652+
integralRound(ce.primitive) + ".shortValue();"
653+
case IntegerType =>
654+
integralRound(ce.primitive) + ".intValue();"
655+
case LongType =>
656+
integralRound(ce.primitive) + ".longValue();"
657+
case FloatType =>
658+
check(ce.primitive, "floatValue()")
659+
case DoubleType =>
660+
check(ce.primitive, "doubleValue()")
661+
case StringType =>
662+
convert(ce.primitive)
663+
case BinaryType =>
664+
convert(s"${ctx.stringType}.fromBytes(${ce.primitive})")
665+
}
666+
667+
ce.code + s"""
668+
boolean ${ev.isNull} = ${ce.isNull} || ${scaleV == null};
669+
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
670+
if (!${ev.isNull}) {
671+
${roundCode}
672+
}
673+
"""
674+
}
675+
598676
private def round[T](input: T, scale: Int)(implicit bdc: BigDecimalConverter[T]): T = {
599677
input match {
600678
case f: Float if (f.isNaN || f.isInfinite) => return input

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,16 +343,25 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
343343
val domain = -16 to 16
344344
val doublePi = math.Pi
345345
val stringPi = "3.141592653589793"
346+
val arrayPi: Array[Byte] = stringPi.toCharArray.map(_.toByte)
347+
val shortPi: Short = 31415
346348
val intPi = 314159265
349+
val longPi = 31415926535897932L
347350
val bdPi = BigDecimal(31415926535897932L, 10)
348351

349352
domain.foreach { scale =>
350353
checkEvaluation(Round(doublePi, scale),
351354
BigDecimal.valueOf(doublePi).setScale(scale, RoundingMode.HALF_UP).toDouble, EmptyRow)
352355
checkEvaluation(Round(stringPi, scale),
353356
BigDecimal.valueOf(doublePi).setScale(scale, RoundingMode.HALF_UP).toDouble, EmptyRow)
357+
checkEvaluation(Round(arrayPi, scale),
358+
BigDecimal.valueOf(doublePi).setScale(scale, RoundingMode.HALF_UP).toDouble, EmptyRow)
359+
checkEvaluation(Round(shortPi, scale),
360+
BigDecimal.valueOf(shortPi).setScale(scale, RoundingMode.HALF_UP).toShort, EmptyRow)
354361
checkEvaluation(Round(intPi, scale),
355362
BigDecimal.valueOf(intPi).setScale(scale, RoundingMode.HALF_UP).toInt, EmptyRow)
363+
checkEvaluation(Round(longPi, scale),
364+
BigDecimal.valueOf(longPi).setScale(scale, RoundingMode.HALF_UP).toLong, EmptyRow)
356365
}
357366
checkEvaluation(new Round(Literal("invalid input")), null, EmptyRow)
358367

0 commit comments

Comments
 (0)