Skip to content

Commit 2cac317

Browse files
10110346cloud-fan
authored andcommitted
[SPARK-20665][SQL] Bround" and "Round" function return NULL
## What changes were proposed in this pull request? spark-sql>select bround(12.3, 2); spark-sql>NULL For this case, the expected result is 12.3, but it is null. So ,when the second parameter is bigger than "decimal.scala", the result is not we expected. "round" function has the same problem. This PR can solve the problem for both of them. ## How was this patch tested? unit test cases in MathExpressionsSuite and MathFunctionsSuite Author: liuxian <[email protected]> Closes #17906 from 10110346/wip_lx_0509. (cherry picked from commit 2b36eb6) Signed-off-by: Wenchen Fan <[email protected]>
1 parent 3d1908f commit 2cac317

File tree

3 files changed

+22
-10
lines changed

3 files changed

+22
-10
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,10 +1023,10 @@ abstract class RoundBase(child: Expression, scale: Expression,
10231023

10241024
// not overriding since _scale is a constant int at runtime
10251025
def nullSafeEval(input1: Any): Any = {
1026-
child.dataType match {
1027-
case _: DecimalType =>
1026+
dataType match {
1027+
case DecimalType.Fixed(_, s) =>
10281028
val decimal = input1.asInstanceOf[Decimal]
1029-
decimal.toPrecision(decimal.precision, _scale, mode).orNull
1029+
decimal.toPrecision(decimal.precision, s, mode).orNull
10301030
case ByteType =>
10311031
BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte
10321032
case ShortType =>
@@ -1055,10 +1055,10 @@ abstract class RoundBase(child: Expression, scale: Expression,
10551055
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
10561056
val ce = child.genCode(ctx)
10571057

1058-
val evaluationCode = child.dataType match {
1059-
case _: DecimalType =>
1058+
val evaluationCode = dataType match {
1059+
case DecimalType.Fixed(_, s) =>
10601060
s"""
1061-
if (${ce.value}.changePrecision(${ce.value}.precision(), ${_scale},
1061+
if (${ce.value}.changePrecision(${ce.value}.precision(), ${s},
10621062
java.math.BigDecimal.${modeStr})) {
10631063
${ev.value} = ${ce.value};
10641064
} else {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -546,15 +546,14 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
546546
val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14),
547547
BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159),
548548
BigDecimal(3.141593), BigDecimal(3.1415927))
549-
// round_scale > current_scale would result in precision increase
550-
// and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null
549+
551550
(0 to 7).foreach { i =>
552551
checkEvaluation(Round(bdPi, i), bdResults(i), EmptyRow)
553552
checkEvaluation(BRound(bdPi, i), bdResults(i), EmptyRow)
554553
}
555554
(8 to 10).foreach { scale =>
556-
checkEvaluation(Round(bdPi, scale), null, EmptyRow)
557-
checkEvaluation(BRound(bdPi, scale), null, EmptyRow)
555+
checkEvaluation(Round(bdPi, scale), bdPi, EmptyRow)
556+
checkEvaluation(BRound(bdPi, scale), bdPi, EmptyRow)
558557
}
559558

560559
DataTypeTestUtils.numericTypes.foreach { dataType =>

sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,19 @@ class MathFunctionsSuite extends QueryTest with SharedSQLContext {
231231
Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3),
232232
BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142")))
233233
)
234+
235+
val bdPi: BigDecimal = BigDecimal(31415925L, 7)
236+
checkAnswer(
237+
sql(s"SELECT round($bdPi, 7), round($bdPi, 8), round($bdPi, 9), round($bdPi, 10), " +
238+
s"round($bdPi, 100), round($bdPi, 6), round(null, 8)"),
239+
Seq(Row(bdPi, bdPi, bdPi, bdPi, bdPi, BigDecimal("3.141593"), null))
240+
)
241+
242+
checkAnswer(
243+
sql(s"SELECT bround($bdPi, 7), bround($bdPi, 8), bround($bdPi, 9), bround($bdPi, 10), " +
244+
s"bround($bdPi, 100), bround($bdPi, 6), bround(null, 8)"),
245+
Seq(Row(bdPi, bdPi, bdPi, bdPi, bdPi, BigDecimal("3.141592"), null))
246+
)
234247
}
235248

236249
test("round/bround with data frame from a local Seq of Product") {

0 commit comments

Comments
 (0)