Skip to content

Commit 56db4bb

Browse files
committed
Add decimal support to Round
1 parent 7e163ae commit 56db4bb

File tree

2 files changed

+10
-18
lines changed

2 files changed

+10
-18
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -528,9 +528,13 @@ case class Round(children: Seq[Expression]) extends Expression {
528528

529529
def nullable: Boolean = true
530530

531-
def dataType: DataType = {
531+
private lazy val evalE2 = if (children.size == 2) children(1).eval(EmptyRow) else null
532+
private lazy val _scale = if (evalE2 != null) evalE2.asInstanceOf[Int] else 0
533+
534+
override lazy val dataType: DataType = {
532535
children(0).dataType match {
533536
case StringType | BinaryType => DoubleType
537+
case DecimalType.Fixed(p, s) => DecimalType(p, _scale)
534538
case t => t
535539
}
536540
}
@@ -564,23 +568,14 @@ case class Round(children: Seq[Expression]) extends Expression {
564568

565569
def eval(input: InternalRow): Any = {
566570
val evalE1 = children(0).eval(input)
567-
if (evalE1 == null) {
568-
return null
569-
}
570571

571-
var _scale: Int = 0
572-
if (children.size == 2) {
573-
val evalE2 = children(1).eval(input)
574-
if (evalE2 == null) {
575-
return null
576-
} else {
577-
_scale = evalE2.asInstanceOf[Int]
578-
}
579-
}
572+
if (evalE1 == null) return null
573+
if (children.size == 2 && evalE2 == null) return null
580574

581575
children(0).dataType match {
582576
case decimalType: DecimalType =>
583-
// TODO: Support Decimal Round
577+
val decimal = evalE1.asInstanceOf[Decimal]
578+
if (decimal.changePrecision(decimal.precision, _scale)) decimal else null
584579
case ByteType =>
585580
round(evalE1.asInstanceOf[Byte], _scale)
586581
case ShortType =>

sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -221,9 +221,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
221221
"udf_when",
222222
"udf_case",
223223

224-
// Needs constant object inspectors
225-
"udf_round",
226-
227224
// the table src(key INT, value STRING) is not the same as HIVE unittest. In Hive
228225
// is src(key STRING, value STRING), and in the reflect.q, it failed in
229226
// Integer.valueOf, which expect the first argument passed as STRING type not INT.
@@ -918,7 +915,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
918915
"udf_regexp_replace",
919916
"udf_repeat",
920917
"udf_rlike",
921-
"udf_round",
918+
// "udf_round", turn this on after we figure out null vs nan vs infinity
922919
"udf_round_3",
923920
"udf_rpad",
924921
"udf_rtrim",

0 commit comments

Comments
 (0)