diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 5fb3369f85d1..988c3e02faf7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -124,6 +124,7 @@ object FunctionRegistry {
     expression[Pow]("power"),
     expression[UnaryPositive]("positive"),
     expression[Rint]("rint"),
+    expression[Round]("round"),
     expression[Signum]("sign"),
     expression[Signum]("signum"),
     expression[Sin]("sin"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index 250564dc4b81..6b1f003775ed 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -18,12 +18,17 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import java.lang.{Long => JLong}
+import java.math.RoundingMode
 
 import org.apache.spark.sql.catalyst
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StringType}
 import org.apache.spark.unsafe.types.UTF8String
 
+import org.apache.spark.sql.catalyst.trees
+import org.apache.spark.sql.types._
+
 /**
  * A leaf expression specifically for math constants. Math constants expect no input.
  * @param c The math constant.
@@ -312,3 +317,92 @@ case class Logarithm(left: Expression, right: Expression)
     """
   }
 }
+
+/**
+ * Rounds `valueExpr` to `scaleExpr` decimal places. Integral and floating point values are
+ * rounded with RoundingMode.HALF_UP; decimals are rescaled through [[Decimal]].
+ * The result has the same data type as `valueExpr` and is null when either input is null.
+ * NaN and infinite floating point values are returned unchanged.
+ *
+ * @param valueExpr the numeric value to round.
+ * @param scaleExpr the integer number of decimal places; the one-argument constructor uses 0.
+ */
+case class Round(valueExpr: Expression, scaleExpr: Expression)
+  extends Expression with trees.BinaryNode[Expression] {
+
+  def this(left: Expression) = {
+    this(left, Literal(0))
+  }
+
+  override def nullable: Boolean = valueExpr.nullable || scaleExpr.nullable
+
+  override lazy val resolved =
+    childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if ((valueExpr.dataType.isInstanceOf[NumericType] || valueExpr.dataType.isInstanceOf[NullType])
+      &&
+      (scaleExpr.dataType.isInstanceOf[IntegerType] || scaleExpr.dataType.isInstanceOf[NullType])) {
+      TypeCheckResult.TypeCheckSuccess
+    } else {
+      TypeCheckResult.TypeCheckFailure(
+        s"round accepts numeric types as the value and integer type as the scale")
+    }
+  }
+
+  override def toString: String = s"round($valueExpr, $scaleExpr)"
+
+  override def dataType: DataType = valueExpr.dataType
+
+  override def eval(input: InternalRow): Any = {
+    val valueEval = valueExpr.eval(input)
+    val scaleEval = scaleExpr.eval(input)
+    if (valueEval == null || scaleEval == null) {
+      null
+    } else {
+      val scale = scaleEval.asInstanceOf[Int]
+      dataType match {
+        case _: DecimalType =>
+          round(valueEval.asInstanceOf[Decimal], scale)
+        case FloatType =>
+          val value = valueEval.asInstanceOf[Float]
+          // NaN and infinities cannot be represented as a BigDecimal; return them as-is.
+          if (isFinite(value)) round(value.toDouble, scale).floatValue() else value
+        case DoubleType =>
+          val value = valueEval.asInstanceOf[Double]
+          if (isFinite(value)) round(value, scale).doubleValue() else value
+        case LongType =>
+          round(valueEval.asInstanceOf[Long], scale).longValue()
+        case IntegerType =>
+          round(valueEval.asInstanceOf[Int].toLong, scale).intValue()
+        case ShortType =>
+          round(valueEval.asInstanceOf[Short].toLong, scale).shortValue()
+        case ByteType =>
+          round(valueEval.asInstanceOf[Byte].toLong, scale).byteValue()
+      }
+    }
+  }
+
+  /** Returns true when `value` is neither NaN nor infinite. */
+  private def isFinite(value: Double): Boolean = {
+    !java.lang.Double.isNaN(value) && !java.lang.Double.isInfinite(value)
+  }
+
+  private def round(value: Long, scale: Int): BigDecimal = {
+    java.math.BigDecimal.valueOf(value).setScale(scale, RoundingMode.HALF_UP)
+  }
+
+  /** Rounds a finite double; NaN and infinities must be filtered out by the caller. */
+  private def round(value: Double, scale: Int): BigDecimal = {
+    java.math.BigDecimal.valueOf(value).setScale(scale, RoundingMode.HALF_UP)
+  }
+
+  /** Rounds into a new Decimal so that the input value is never mutated. */
+  private def round(value: Decimal, scale: Int): Decimal = {
+    Decimal(value.toBigDecimal, value.precision, scale)
+  }
+
+  override def left: Expression = valueExpr
+
+  override def right: Expression = scaleExpr
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index 0d1d5ebdff2d..f7874c1eb948 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -17,10 +17,10 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.types.{DataType, DoubleType, LongType}
+import org.apache.spark.sql.types.{DataType, LongType}
+import org.apache.spark.sql.types.{Decimal, DoubleType}
 
 class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 
@@ -252,4 +252,21 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       null,
       create_row(null))
   }
+
+  test("round") {
+    val decimalInput = Decimal(1.26)
+    checkEvaluation(Round(Literal(decimalInput), Literal(1)), Decimal(1.3, 3, 1))
+    // Rounding must not modify the Decimal held by the input literal.
+    assert(decimalInput == Decimal(1.26))
+    checkEvaluation(Round(Literal(1.23D), Literal(1)), 1.2)
+    checkEvaluation(Round(Literal(1.25D), Literal(1)), 1.3)
+    checkEvaluation(Round(Literal(1.5F), Literal(0)), 2.0F)
+    checkEvaluation(Round(Literal(1.toShort), Literal(0)), 1.toShort)
+    checkEvaluation(Round(Literal(2.toByte), Literal(0)), 2.toByte)
+    checkEvaluation(Round(Literal(9223372036854775807L), 0), 9223372036854775807L)
+    checkEvaluation(Round(Literal(123), Literal(0)), 123)
+    // Infinite inputs are returned unchanged rather than failing the BigDecimal conversion.
+    checkEvaluation(Round(Literal(Double.PositiveInfinity), Literal(1)), Double.PositiveInfinity)
+    checkEvaluation(Round(Literal(Float.NegativeInfinity), Literal(1)), Float.NegativeInfinity)
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 38d9085a505f..0a97bd614dd9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1282,6 +1282,38 @@ object functions {
    */
   def rint(columnName: String): Column = rint(Column(columnName))
 
+  /**
+   * Returns the value of the column `e` rounded to 0 decimal places.
+   *
+   * @group math_funcs
+   * @since 1.5.0
+   */
+  def round(e: Column): Column = round(e, 0)
+
+  /**
+   * Returns the value of the named column rounded to 0 decimal places.
+   *
+   * @group math_funcs
+   * @since 1.5.0
+   */
+  def round(columnName: String): Column = round(Column(columnName))
+
+  /**
+   * Returns the value of the column `e` rounded to `scale` decimal places.
+   *
+   * @group math_funcs
+   * @since 1.5.0
+   */
+  def round(e: Column, scale: Int): Column = Round(e.expr, lit(scale).expr)
+
+  /**
+   * Returns the value of the named column rounded to `scale` decimal places.
+   *
+   * @group math_funcs
+   * @since 1.5.0
+   */
+  def round(columnName: String, scale: Int): Column = round(Column(columnName), scale)
+
   /**
    * Computes the signum of the given value.
    *
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
index 2768d7dfc803..e02084a13ab3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
@@ -17,8 +17,10 @@
 
 package org.apache.spark.sql
 
+import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.functions.{log => logarithm}
+import org.apache.spark.sql.types.{Decimal, DecimalType, DoubleType}
 
 private object MathExpressionsTestData {
 
@@ -292,4 +294,20 @@ class MathExpressionsSuite extends QueryTest {
     checkAnswer(df.selectExpr("positive(b)"), Row(-1))
     checkAnswer(df.selectExpr("positive(c)"), Row("abc"))
   }
+
+  test("round") {
+    val df = Seq((1.53, 0.62, 12345L, 0.67.toFloat)).toDF("a", "b", "c", "d")
+    checkAnswer(df.select(round('a)), Row(2.0))
+    checkAnswer(df.select(round('b, 1)), Row(0.6))
+    checkAnswer(df.selectExpr("round(a)"), Row(2.0))
+    checkAnswer(df.selectExpr("round(b, 1)"), Row(0.6))
+    checkAnswer(df.selectExpr("round(c, 1)"), Row(12345L))
+    checkAnswer(df.selectExpr("round(d, 1)"), Row(0.7f))
+    checkAnswer(df.selectExpr("round(null)"), Row(null))
+    checkAnswer(df.selectExpr("round(null, 1)"), Row(null))
+    checkAnswer(df.selectExpr("round(145.23, -1)"), Row(150.0)) // same as hive
+    checkAnswer(df.selectExpr("round(20, 1)"), Row(20))
+    // NOTE(review): the null presumably comes from 1.0/0.0 itself rather than from round; confirm.
+    checkAnswer(df.selectExpr("round(1.0/0.0, 1)"), Row(null))
+  }
 }
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index f88e62763ca7..c2e7701c9c2d 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -919,7 +919,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
     "udf_repeat",
     "udf_rlike",
     "udf_round",
-    // "udf_round_3", TODO: FIX THIS failed due to cast exception
+    "udf_round_3",
     "udf_rpad",
     "udf_rtrim",
     "udf_second",