Skip to content

Commit f0e1297

Browse files
yjshen authored and rxin committed
[SPARK-8279][SQL]Add math function round
JIRA: https://issues.apache.org/jira/browse/SPARK-8279 Author: Yijie Shen <[email protected]> Closes apache#6938 from yijieshen/udf_round_3 and squashes the following commits: 07a124c [Yijie Shen] remove useless def children 392b65b [Yijie Shen] add negative scale test in DecimalSuite 61760ee [Yijie Shen] address reviews 302a78a [Yijie Shen] Add dataframe function test 31dfe7c [Yijie Shen] refactor round to make it readable 8c7a949 [Yijie Shen] rebase & inputTypes update 9555e35 [Yijie Shen] tiny style fix d10be4a [Yijie Shen] use TypeCollection to specify wanted input and implicit cast c3b9839 [Yijie Shen] rely on implict cast to handle string input b0bff79 [Yijie Shen] make round's inner method's name more meaningful 9bd6930 [Yijie Shen] revert accidental change e6f44c4 [Yijie Shen] refactor eval and genCode 1b87540 [Yijie Shen] modify checkInputDataTypes using foldable 5486b2d [Yijie Shen] DataFrame API modification 2077888 [Yijie Shen] codegen versioned eval 6cd9a64 [Yijie Shen] refactor Round's constructor 9be894e [Yijie Shen] add round functions in o.a.s.sql.functions 7c83e13 [Yijie Shen] more tests on round 56db4bb [Yijie Shen] Add decimal support to Round 7e163ae [Yijie Shen] style fix 653d047 [Yijie Shen] Add math function round
1 parent 3f6296f commit f0e1297

File tree

8 files changed

+329
-13
lines changed

8 files changed

+329
-13
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ object FunctionRegistry {
117117
expression[Pow]("power"),
118118
expression[UnaryPositive]("positive"),
119119
expression[Rint]("rint"),
120+
expression[Round]("round"),
120121
expression[ShiftLeft]("shiftleft"),
121122
expression[ShiftRight]("shiftright"),
122123
expression[ShiftRightUnsigned]("shiftrightunsigned"),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala

Lines changed: 202 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@ package org.apache.spark.sql.catalyst.expressions
1919

2020
import java.{lang => jl}
2121

22-
import org.apache.spark.sql.catalyst.InternalRow
22+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
23+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckSuccess, TypeCheckFailure}
2324
import org.apache.spark.sql.catalyst.expressions.codegen._
25+
import org.apache.spark.sql.catalyst.InternalRow
2426
import org.apache.spark.sql.types._
2527
import org.apache.spark.unsafe.types.UTF8String
2628

@@ -520,3 +522,202 @@ case class Logarithm(left: Expression, right: Expression)
520522
"""
521523
}
522524
}
525+
526+
/**
 * Rounds the `child`'s result to `scale` decimal places when `scale` >= 0,
 * or rounds at the integral part when `scale` < 0.
 * For example, round(31.415, 2) evaluates to 31.42 and round(31.415, -1) evaluates to 30.
 *
 * A child of IntegralType evaluates to itself when `scale` >= 0.
 * A child of FractionalType whose value is NaN or Infinite always evaluates to itself.
 *
 * Round's dataType always equals the `child`'s dataType, except for [[DecimalType.Fixed]],
 * which leads to a scale update in the DecimalType's [[PrecisionInfo]].
 *
 * @param child the expression to be rounded; any [[NumericType]] is allowed as input
 * @param scale the new scale to round to; this must be a foldable (constant) int expression
 */
case class Round(child: Expression, scale: Expression)
  extends BinaryExpression with ExpectsInputTypes {

  import BigDecimal.RoundingMode.HALF_UP

  def this(child: Expression) = this(child, Literal(0))

  override def left: Expression = child
  override def right: Expression = scale

  // Rounding a Decimal evaluates to null when `changePrecision` fails,
  // hence the expression is always nullable.
  override def nullable: Boolean = true

  override def foldable: Boolean = child.foldable

  override lazy val dataType: DataType = child.dataType match {
    // If the new scale is bigger, we would be scaling up; keep the original
    // scale instead, as `Decimal` itself does.
    case DecimalType.Fixed(p, s) => DecimalType(p, if (_scale > s) s else _scale)
    case t => t
  }

  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType)

  override def checkInputDataTypes(): TypeCheckResult = {
    super.checkInputDataTypes() match {
      case TypeCheckSuccess =>
        // `scale` must be a constant: `dataType` and code generation both need
        // its value before any input row is processed.
        if (scale.foldable) {
          TypeCheckSuccess
        } else {
          TypeCheckFailure("Only foldable Expression is allowed for scale arguments")
        }
      case f => f
    }
  }

  // Evaluate `scale` only once since it is a constant int. Checking
  // scaleV == null also lets both the codegen and non-codegen paths skip
  // evaluating `child` entirely when the scale is null.
  private lazy val scaleV: Any = scale.eval(EmptyRow)
  private lazy val _scale: Int = scaleV.asInstanceOf[Int]

  override def eval(input: InternalRow): Any = {
    if (scaleV == null) { // if scale is null, no need to eval its child at all
      null
    } else {
      val evalE = child.eval(input)
      if (evalE == null) {
        null
      } else {
        nullSafeEval(evalE)
      }
    }
  }

  // Not an override: `_scale` is a constant int at runtime, so only the
  // child's value needs to be passed in.
  def nullSafeEval(input1: Any): Any = {
    child.dataType match {
      case _: DecimalType =>
        val decimal = input1.asInstanceOf[Decimal]
        // `changePrecision` returns false when the requested scale cannot be
        // represented within the current precision; we map that to null.
        if (decimal.changePrecision(decimal.precision, _scale)) decimal else null
      case ByteType =>
        BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, HALF_UP).toByte
      case ShortType =>
        BigDecimal(input1.asInstanceOf[Short]).setScale(_scale, HALF_UP).toShort
      case IntegerType =>
        BigDecimal(input1.asInstanceOf[Int]).setScale(_scale, HALF_UP).toInt
      case LongType =>
        BigDecimal(input1.asInstanceOf[Long]).setScale(_scale, HALF_UP).toLong
      case FloatType =>
        val f = input1.asInstanceOf[Float]
        if (f.isNaN || f.isInfinite) {
          f // NaN / Infinity round to themselves
        } else {
          BigDecimal(f).setScale(_scale, HALF_UP).toFloat
        }
      case DoubleType =>
        val d = input1.asInstanceOf[Double]
        if (d.isNaN || d.isInfinite) {
          d // NaN / Infinity round to themselves
        } else {
          BigDecimal(d).setScale(_scale, HALF_UP).toDouble
        }
    }
  }

  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    val ce = child.gen(ctx)

    val evaluationCode = child.dataType match {
      case _: DecimalType =>
        s"""
        if (${ce.primitive}.changePrecision(${ce.primitive}.precision(), ${_scale})) {
          ${ev.primitive} = ${ce.primitive};
        } else {
          ${ev.isNull} = true;
        }"""
      case ByteType =>
        if (_scale < 0) {
          s"""
          ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
            setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).byteValue();"""
        } else {
          // scale >= 0 never changes an integral value
          s"${ev.primitive} = ${ce.primitive};"
        }
      case ShortType =>
        if (_scale < 0) {
          s"""
          ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
            setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).shortValue();"""
        } else {
          s"${ev.primitive} = ${ce.primitive};"
        }
      case IntegerType =>
        if (_scale < 0) {
          s"""
          ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
            setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).intValue();"""
        } else {
          s"${ev.primitive} = ${ce.primitive};"
        }
      case LongType =>
        if (_scale < 0) {
          s"""
          ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
            setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).longValue();"""
        } else {
          s"${ev.primitive} = ${ce.primitive};"
        }
      case FloatType => // if child evals to NaN or Infinity, just return it.
        // Fix: the previous `_scale == 0` fast path used Math.round, which
        // rounds ties toward positive infinity (Math.round(-2.5f) == -2),
        // while the interpreted path uses HALF_UP (-2.5 -> -3). Always go
        // through BigDecimal so codegen and interpreted eval agree.
        s"""
        if (Float.isNaN(${ce.primitive}) || Float.isInfinite(${ce.primitive})){
          ${ev.primitive} = ${ce.primitive};
        } else {
          ${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}).
            setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).floatValue();
        }"""
      case DoubleType => // if child evals to NaN or Infinity, just return it.
        // Same tie-breaking fix as the FloatType case above.
        s"""
        if (Double.isNaN(${ce.primitive}) || Double.isInfinite(${ce.primitive})){
          ${ev.primitive} = ${ce.primitive};
        } else {
          ${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}).
            setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).doubleValue();
        }"""
    }

    if (scaleV == null) { // if scale is null, no need to eval its child at all
      s"""
        boolean ${ev.isNull} = true;
        ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
      """
    } else {
      s"""
        ${ce.code}
        boolean ${ev.isNull} = ${ce.isNull};
        ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
        if (!${ev.isNull}) {
          $evaluationCode
        }
      """
    }
  }

  override def prettyName: String = "round"
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,13 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
5252
s"differing types in '${expr.prettyString}' (int and boolean)")
5353
}
5454

55+
/**
 * Asserts that analyzing `expr` fails with an [[AnalysisException]] whose
 * message contains `errorMessage` (used when the failure is only surfaced
 * after implicit casts are attempted).
 */
def assertErrorWithImplicitCast(expr: Expression, errorMessage: String): Unit = {
  val caught = intercept[AnalysisException](assertSuccess(expr))
  assert(caught.getMessage.contains(errorMessage))
}
61+
5562
test("check types for unary arithmetic") {
5663
assertError(UnaryMinus('stringField), "operator - accepts numeric type")
5764
assertError(Abs('stringField), "function abs accepts numeric type")
@@ -171,4 +178,14 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
171178
CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)),
172179
"Odd position only allow foldable and not-null StringType expressions")
173180
}
181+
182+
test("check types for ROUND") {
183+
assertErrorWithImplicitCast(Round(Literal(null), 'booleanField),
184+
"data type mismatch: argument 2 is expected to be of type int")
185+
assertErrorWithImplicitCast(Round(Literal(null), 'complexField),
186+
"data type mismatch: argument 2 is expected to be of type int")
187+
assertSuccess(Round(Literal(null), Literal(null)))
188+
assertError(Round('booleanField, 'intField),
189+
"data type mismatch: argument 1 is expected to be of type numeric")
190+
}
174191
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20+
import scala.math.BigDecimal.RoundingMode
21+
2022
import com.google.common.math.LongMath
2123

2224
import org.apache.spark.SparkFunSuite
@@ -336,4 +338,46 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
336338
null,
337339
create_row(null))
338340
}
341+
342+
test("round") {
  // Scales to exercise for each numeric input (negative scales round the
  // integral part; positive scales round fractional digits).
  val domain = -6 to 6
  val doublePi: Double = math.Pi
  val shortPi: Short = 31415
  val intPi: Int = 314159265
  val longPi: Long = 31415926535897932L
  val bdPi: BigDecimal = BigDecimal(31415927L, 7)

  // Expected values, aligned index-by-index with `domain`.
  val doubleResults: Seq[Double] = Seq(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142,
    3.1416, 3.14159, 3.141593)

  val shortResults: Seq[Short] = Seq[Short](0, 0, 30000, 31000, 31400, 31420) ++
    Seq.fill[Short](7)(31415)

  val intResults: Seq[Int] = Seq(314000000, 314200000, 314160000, 314159000, 314159300,
    314159270) ++ Seq.fill(7)(314159265)

  val longResults: Seq[Long] = Seq(31415926536000000L, 31415926535900000L,
    31415926535900000L, 31415926535898000L, 31415926535897900L, 31415926535897930L) ++
    Seq.fill(7)(31415926535897932L)

  val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14),
    BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159),
    BigDecimal(3.141593), BigDecimal(3.1415927))

  for ((scale, idx) <- domain.zipWithIndex) {
    checkEvaluation(Round(doublePi, scale), doubleResults(idx), EmptyRow)
    checkEvaluation(Round(shortPi, scale), shortResults(idx), EmptyRow)
    checkEvaluation(Round(intPi, scale), intResults(idx), EmptyRow)
    checkEvaluation(Round(longPi, scale), longResults(idx), EmptyRow)
  }

  // round_scale > current_scale would need a precision increase, which
  // o.a.s.s.types.Decimal.changePrecision rejects, therefore the result is null.
  for (scale <- 0 to 7) {
    checkEvaluation(Round(bdPi, scale), bdResults(scale), EmptyRow)
  }
  for (scale <- 8 to 10) {
    checkEvaluation(Round(bdPi, scale), null, EmptyRow)
  }
}
339383
}

sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,14 @@ import org.scalatest.PrivateMethodTester
2424
import scala.language.postfixOps
2525

2626
class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
27-
test("creating decimals") {
28-
/** Check that a Decimal has the given string representation, precision and scale */
29-
def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = {
30-
assert(d.toString === string)
31-
assert(d.precision === precision)
32-
assert(d.scale === scale)
33-
}
27+
/** Check that a Decimal has the given string representation, precision and scale. */
private def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = {
  // All three properties must hold simultaneously for the Decimal to be correct.
  assert(d.toString === string)
  assert(d.precision === precision)
  assert(d.scale === scale)
}
3433

34+
test("creating decimals") {
3535
checkDecimal(new Decimal(), "0", 1, 0)
3636
checkDecimal(Decimal(BigDecimal("10.030")), "10.030", 5, 3)
3737
checkDecimal(Decimal(BigDecimal("10.030"), 4, 1), "10.0", 4, 1)
@@ -53,6 +53,15 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
5353
intercept[IllegalArgumentException](Decimal(1e17.toLong, 17, 0))
5454
}
5555

56+
test("creating decimals with negative scale") {
57+
checkDecimal(Decimal(BigDecimal("98765"), 5, -3), "9.9E+4", 5, -3)
58+
checkDecimal(Decimal(BigDecimal("314.159"), 6, -2), "3E+2", 6, -2)
59+
checkDecimal(Decimal(BigDecimal(1.579e12), 4, -9), "1.579E+12", 4, -9)
60+
checkDecimal(Decimal(BigDecimal(1.579e12), 4, -10), "1.58E+12", 4, -10)
61+
checkDecimal(Decimal(103050709L, 9, -10), "1.03050709E+18", 9, -10)
62+
checkDecimal(Decimal(1e8.toLong, 10, -10), "1.00000000E+18", 10, -10)
63+
}
64+
5665
test("double and long values") {
5766
/** Check that a Decimal converts to the given double and long values */
5867
def checkValues(d: Decimal, doubleValue: Double, longValue: Long): Unit = {

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1389,6 +1389,38 @@ object functions {
13891389
*/
13901390
def rint(columnName: String): Column = rint(Column(columnName))
13911391

1392+
/**
 * Returns the value of the column `e` rounded to 0 decimal places.
 *
 * @group math_funcs
 * @since 1.5.0
 */
def round(e: Column): Column = round(e, 0)
1399+
1400+
/**
 * Returns the value of the given column rounded to 0 decimal places.
 *
 * @group math_funcs
 * @since 1.5.0
 */
def round(columnName: String): Column = round(Column(columnName))
1407+
1408+
/**
 * Returns the value of `e` rounded to `scale` decimal places
 * (HALF_UP rounding, via the [[Round]] expression).
 *
 * @group math_funcs
 * @since 1.5.0
 */
def round(e: Column, scale: Int): Column = Round(e.expr, Literal(scale))
1415+
1416+
/**
 * Returns the value of the given column rounded to `scale` decimal places.
 *
 * @group math_funcs
 * @since 1.5.0
 */
def round(columnName: String, scale: Int): Column = round(Column(columnName), scale)
1423+
13921424
/**
13931425
* Shift the the given value numBits left. If the given value is a long value, this function
13941426
* will return a long value else it will return an integer value.

0 commit comments

Comments
 (0)