diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index d4ebdb139fe0f..474ec592201d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -310,7 +310,7 @@ object CatalystTypeConverters { case d: JavaBigInteger => Decimal(d) case d: Decimal => d } - decimal.toPrecision(dataType.precision, dataType.scale).orNull + decimal.toPrecision(dataType.precision, dataType.scale) } override def toScala(catalystValue: Decimal): JavaBigDecimal = { if (catalystValue == null) null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index d949b8f1d6696..bc809f559d586 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -387,10 +387,9 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String /** * Create new `Decimal` with precision and scale given in `decimalType` (if any), * returning null if it overflows or creating a new `value` and returning it if successful. - * */ private[this] def toPrecision(value: Decimal, decimalType: DecimalType): Decimal = - value.toPrecision(decimalType.precision, decimalType.scale).orNull + value.toPrecision(decimalType.precision, decimalType.scale) private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index c2211ae5d594b..752dea23e1f7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -85,7 +85,7 @@ case class CheckOverflow(child: Expression, dataType: DecimalType) extends Unary override def nullable: Boolean = true override def nullSafeEval(input: Any): Any = - input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale).orNull + input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale) override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, eval => { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 547d5be0e908e..d8dc0862f1141 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -1044,7 +1044,7 @@ abstract class RoundBase(child: Expression, scale: Expression, dataType match { case DecimalType.Fixed(_, s) => val decimal = input1.asInstanceOf[Decimal] - decimal.toPrecision(decimal.precision, s, mode).orNull + decimal.toPrecision(decimal.precision, s, mode) case ByteType => BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte case ShortType => @@ -1076,12 +1076,8 @@ abstract class RoundBase(child: Expression, scale: Expression, val evaluationCode = dataType match { case DecimalType.Fixed(_, s) => s""" - if (${ce.value}.changePrecision(${ce.value}.precision(), ${s}, - java.math.BigDecimal.${modeStr})) { - ${ev.value} = ${ce.value}; - } else { - ${ev.isNull} = true; - }""" + ${ev.value} = ${ce.value}.toPrecision(${ce.value}.precision(), $s, Decimal.$modeStr()); + ${ev.isNull} = ${ev.value} == null;""" case ByteType => if (_scale < 0) { s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 1f1fb51addfd8..6da4f28b12962 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -234,22 +234,17 @@ final class Decimal extends Ordered[Decimal] with Serializable { changePrecision(precision, scale, ROUND_HALF_UP) } - def changePrecision(precision: Int, scale: Int, mode: Int): Boolean = mode match { - case java.math.BigDecimal.ROUND_HALF_UP => changePrecision(precision, scale, ROUND_HALF_UP) - case java.math.BigDecimal.ROUND_HALF_EVEN => changePrecision(precision, scale, ROUND_HALF_EVEN) - } - /** * Create new `Decimal` with given precision and scale. * - * @return `Some(decimal)` if successful or `None` if overflow would occur + * @return a non-null `Decimal` value if successful or `null` if overflow would occur. */ private[sql] def toPrecision( precision: Int, scale: Int, - roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP): Option[Decimal] = { + roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP): Decimal = { val copy = clone() - if (copy.changePrecision(precision, scale, roundMode)) Some(copy) else None + if (copy.changePrecision(precision, scale, roundMode)) copy else null } /** @@ -257,8 +252,10 @@ final class Decimal extends Ordered[Decimal] with Serializable { * * @return true if successful, false if overflow would occur */ - private[sql] def changePrecision(precision: Int, scale: Int, - roundMode: BigDecimal.RoundingMode.Value): Boolean = { + private[sql] def changePrecision( + precision: Int, + scale: Int, + roundMode: BigDecimal.RoundingMode.Value): Boolean = { // fast path for UnsafeProjection if (precision == this.precision && scale == this.scale) { return true @@ -393,14 +390,20 @@ final class Decimal extends Ordered[Decimal] with Serializable { def floor: Decimal = if (scale == 0) this else { val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision - toPrecision(newPrecision, 0, ROUND_FLOOR).getOrElse( - throw new AnalysisException(s"Overflow when setting precision to $newPrecision")) + val res = toPrecision(newPrecision, 0, ROUND_FLOOR) + if (res == null) { + throw new AnalysisException(s"Overflow when setting precision to $newPrecision") + } + res } def ceil: Decimal = if (scale == 0) this else { val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision - toPrecision(newPrecision, 0, ROUND_CEILING).getOrElse( - throw new AnalysisException(s"Overflow when setting precision to $newPrecision")) + val res = toPrecision(newPrecision, 0, ROUND_CEILING) + if (res == null) { + throw new AnalysisException(s"Overflow when setting precision to $newPrecision") + } + res } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index 3193d1320ad9d..10de90c6a44ca 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -213,7 +213,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { assert(d.changePrecision(10, 0, mode)) assert(d.toString === bd.setScale(0, mode).toString(), s"num: $sign$n, mode: $mode") - val copy = d.toPrecision(10, 0, mode).orNull + val copy = d.toPrecision(10, 0, mode) assert(copy !== null) assert(d.ne(copy)) assert(d === copy) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala index c2d08a06569bf..5be8c581e9ddb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala @@ -258,6 +258,18 @@ class MathFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("round/bround with table columns") { + withTable("t") { + Seq(BigDecimal("5.9")).toDF("i").write.saveAsTable("t") + checkAnswer( + sql("select i, round(i) from t"), + Seq(Row(BigDecimal("5.9"), BigDecimal("6")))) + checkAnswer( + sql("select i, bround(i) from t"), + Seq(Row(BigDecimal("5.9"), BigDecimal("6")))) + } + } + test("exp") { testOneToOneMathFunction(exp, math.exp) }