Skip to content

Commit 61760ee

Browse files
committed
address reviews
1 parent 302a78a commit 61760ee

File tree

2 files changed

+141
-55
lines changed

2 files changed

+141
-55
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala

Lines changed: 110 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,20 @@ case class Logarithm(left: Expression, right: Expression)
523523
}
524524
}
525525

526+
/**
527+
* Round the `child`'s result to `scale` decimal place when `scale` >= 0
528+
* or round at integral part when `scale` < 0.
529+
* For example, round(31.415, 2) would eval to 31.42 and round(31.415, -1) would eval to 30.
530+
*
531+
* Child of IntegralType would eval to itself when `scale` >= 0.
532+
* Child of FractionalType whose value is NaN or Infinite would always eval to itself.
533+
*
534+
* Round's dataType would always equal to `child`'s dataType except for [[DecimalType.Fixed]],
535+
* which leads to scale update in DecimalType's [[PrecisionInfo]]
536+
*
537+
* @param child expr to be round, all [[NumericType]] is allowed as Input
538+
* @param scale new scale to be round to, this should be a constant int at runtime
539+
*/
526540
case class Round(child: Expression, scale: Expression)
527541
extends BinaryExpression with ExpectsInputTypes {
528542

@@ -559,10 +573,27 @@ case class Round(child: Expression, scale: Expression)
559573
}
560574
}
561575

562-
private lazy val scaleV = scale.eval(EmptyRow)
563-
private lazy val _scale = if (scaleV != null) scaleV.asInstanceOf[Int] else 0
576+
// Avoid repeated evaluation since `scale` is a constant int,
577+
// avoid unnecessary `child` evaluation in both codegen and non-codegen eval
578+
// by checking if scaleV == null as well.
579+
private lazy val scaleV: Any = scale.eval(EmptyRow)
580+
private lazy val _scale: Int = scaleV.asInstanceOf[Int]
564581

565-
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
582+
override def eval(input: InternalRow): Any = {
583+
if (scaleV == null) { // if scale is null, no need to eval its child at all
584+
null
585+
} else {
586+
val evalE = child.eval(input)
587+
if (evalE == null) {
588+
null
589+
} else {
590+
nullSafeEval(evalE)
591+
}
592+
}
593+
}
594+
595+
// not overriding since _scale is a constant int at runtime
596+
def nullSafeEval(input1: Any): Any = {
566597
child.dataType match {
567598
case _: DecimalType =>
568599
val decimal = input1.asInstanceOf[Decimal]
@@ -604,45 +635,89 @@ case class Round(child: Expression, scale: Expression)
604635
${ev.isNull} = true;
605636
}"""
606637
case ByteType =>
607-
s"""
608-
${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
609-
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).byteValue();"""
638+
if (_scale < 0) {
639+
s"""
640+
${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
641+
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).byteValue();"""
642+
} else {
643+
s"${ev.primitive} = ${ce.primitive};"
644+
}
610645
case ShortType =>
611-
s"""
612-
${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
613-
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).shortValue();"""
646+
if (_scale < 0) {
647+
s"""
648+
${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
649+
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).shortValue();"""
650+
} else {
651+
s"${ev.primitive} = ${ce.primitive};"
652+
}
614653
case IntegerType =>
615-
s"""
616-
${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
617-
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).intValue();"""
654+
if (_scale < 0) {
655+
s"""
656+
${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
657+
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).intValue();"""
658+
} else {
659+
s"${ev.primitive} = ${ce.primitive};"
660+
}
618661
case LongType =>
619-
s"""
620-
${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
621-
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).longValue();"""
622-
case FloatType =>
623-
s"""
624-
if (Float.isNaN(${ce.primitive}) || Float.isInfinite(${ce.primitive})){
625-
${ev.primitive} = ${ce.primitive};
662+
if (_scale < 0) {
663+
s"""
664+
${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
665+
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).longValue();"""
626666
} else {
627-
${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}).
628-
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).floatValue();
629-
}"""
630-
case DoubleType =>
631-
s"""
632-
if (Double.isNaN(${ce.primitive}) || Double.isInfinite(${ce.primitive})){
633-
${ev.primitive} = ${ce.primitive};
667+
s"${ev.primitive} = ${ce.primitive};"
668+
}
669+
case FloatType => // if child eval to NaN or Infinity, just return it.
670+
if (_scale == 0) {
671+
s"""
672+
if (Float.isNaN(${ce.primitive}) || Float.isInfinite(${ce.primitive})){
673+
${ev.primitive} = ${ce.primitive};
674+
} else {
675+
${ev.primitive} = Math.round(${ce.primitive});
676+
}"""
634677
} else {
635-
${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}).
636-
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).doubleValue();
637-
}"""
678+
s"""
679+
if (Float.isNaN(${ce.primitive}) || Float.isInfinite(${ce.primitive})){
680+
${ev.primitive} = ${ce.primitive};
681+
} else {
682+
${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}).
683+
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).floatValue();
684+
}"""
685+
}
686+
case DoubleType => // if child eval to NaN or Infinity, just return it.
687+
if (_scale == 0) {
688+
s"""
689+
if (Double.isNaN(${ce.primitive}) || Double.isInfinite(${ce.primitive})){
690+
${ev.primitive} = ${ce.primitive};
691+
} else {
692+
${ev.primitive} = Math.round(${ce.primitive});
693+
}"""
694+
} else {
695+
s"""
696+
if (Double.isNaN(${ce.primitive}) || Double.isInfinite(${ce.primitive})){
697+
${ev.primitive} = ${ce.primitive};
698+
} else {
699+
${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}).
700+
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).doubleValue();
701+
}"""
702+
}
638703
}
639704

640-
ce.code + s"""
641-
boolean ${ev.isNull} = ${ce.isNull} || ${scaleV == null};
642-
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
643-
if (!${ev.isNull}) {
644-
${evaluationCode}
645-
}
705+
if (scaleV == null) { // if scale is null, no need to eval its child at all
706+
s"""
707+
boolean ${ev.isNull} = true;
708+
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
709+
"""
710+
} else {
711+
s"""
712+
${ce.code}
713+
boolean ${ev.isNull} = ${ce.isNull};
714+
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
715+
if (!${ev.isNull}) {
716+
$evaluationCode
717+
}
646718
"""
719+
}
647720
}
721+
722+
override def prettyName: String = "round"
648723
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -340,32 +340,43 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
340340
}
341341

342342
test("round") {
343-
val domain = -16 to 16
344-
val doublePi = math.Pi
343+
val domain = -6 to 6
344+
val doublePi: Double = math.Pi
345345
val shortPi: Short = 31415
346-
val intPi = 314159265
347-
val longPi = 31415926535897932L
348-
val bdPi = BigDecimal(31415926535897932L, 10)
349-
350-
domain.foreach { scale =>
351-
checkEvaluation(Round(doublePi, scale),
352-
BigDecimal.valueOf(doublePi).setScale(scale, RoundingMode.HALF_UP).toDouble, EmptyRow)
353-
checkEvaluation(Round(shortPi, scale),
354-
BigDecimal.valueOf(shortPi).setScale(scale, RoundingMode.HALF_UP).toShort, EmptyRow)
355-
checkEvaluation(Round(intPi, scale),
356-
BigDecimal.valueOf(intPi).setScale(scale, RoundingMode.HALF_UP).toInt, EmptyRow)
357-
checkEvaluation(Round(longPi, scale),
358-
BigDecimal.valueOf(longPi).setScale(scale, RoundingMode.HALF_UP).toLong, EmptyRow)
346+
val intPi: Int = 314159265
347+
val longPi: Long = 31415926535897932L
348+
val bdPi: BigDecimal = BigDecimal(31415927L, 7)
349+
350+
val doubleResults: Seq[Double] = Seq(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142,
351+
3.1416, 3.14159, 3.141593)
352+
353+
val shortResults: Seq[Short] = Seq[Short](0, 0, 30000, 31000, 31400, 31420) ++
354+
Seq.fill[Short](7)(31415)
355+
356+
val intResults: Seq[Int] = Seq(314000000, 314200000, 314160000, 314159000, 314159300,
357+
314159270) ++ Seq.fill(7)(314159265)
358+
359+
val longResults: Seq[Long] = Seq(31415926536000000L, 31415926535900000L,
360+
31415926535900000L, 31415926535898000L, 31415926535897900L, 31415926535897930L) ++
361+
Seq.fill(7)(31415926535897932L)
362+
363+
val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14),
364+
BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159),
365+
BigDecimal(3.141593), BigDecimal(3.1415927))
366+
367+
domain.zipWithIndex.foreach { case (scale, i) =>
368+
checkEvaluation(Round(doublePi, scale), doubleResults(i), EmptyRow)
369+
checkEvaluation(Round(shortPi, scale), shortResults(i), EmptyRow)
370+
checkEvaluation(Round(intPi, scale), intResults(i), EmptyRow)
371+
checkEvaluation(Round(longPi, scale), longResults(i), EmptyRow)
359372
}
360373

361374
// round_scale > current_scale would result in precision increase
362375
// and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null
363-
val (validScales, nullScales) = domain.splitAt(27)
364-
validScales.foreach { scale =>
365-
checkEvaluation(Round(bdPi, scale),
366-
Decimal(bdPi.setScale(scale, RoundingMode.HALF_UP)), EmptyRow)
376+
(0 to 7).foreach { i =>
377+
checkEvaluation(Round(bdPi, i), bdResults(i), EmptyRow)
367378
}
368-
nullScales.foreach { scale =>
379+
(8 to 10).foreach { scale =>
369380
checkEvaluation(Round(bdPi, scale), null, EmptyRow)
370381
}
371382
}

0 commit comments

Comments
 (0)