Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ object FunctionRegistry {
expression[Pow]("power"),
expression[UnaryPositive]("positive"),
expression[Rint]("rint"),
expression[Round]("round"),
expression[Signum]("sign"),
expression[Signum]("signum"),
expression[Sin]("sin"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,17 @@
package org.apache.spark.sql.catalyst.expressions

import java.lang.{Long => JLong}
import java.math.RoundingMode

import org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StringType}
import org.apache.spark.unsafe.types.UTF8String

import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.types._

/**
* A leaf expression specifically for math constants. Math constants expect no input.
* @param c The math constant.
Expand Down Expand Up @@ -312,3 +317,77 @@ case class Logarithm(left: Expression, right: Expression)
"""
}
}

case class Round(valueExpr: Expression, scaleExpr: Expression)
extends Expression with trees.BinaryNode[Expression] {

def this(left: Expression) = {
this(left, Literal(0))
}

override def nullable: Boolean = valueExpr.nullable || scaleExpr.nullable

override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)

override def checkInputDataTypes(): TypeCheckResult = {
if ((valueExpr.dataType.isInstanceOf[NumericType] || valueExpr.dataType.isInstanceOf[NullType])
&&
(scaleExpr.dataType.isInstanceOf[IntegerType] || scaleExpr.dataType.isInstanceOf[NullType])) {
TypeCheckResult.TypeCheckSuccess
} else {
TypeCheckResult.TypeCheckFailure(
s"round accepts numeric types as the value and integer type as the scale")
}
}

override def toString: String = s"round($valueExpr, $scaleExpr)"

override def dataType: DataType = valueExpr.dataType

override def eval(input: InternalRow): Any = {
val valueEval = valueExpr.eval(input)
val scaleEval = scaleExpr.eval(input)
if (valueEval == null || scaleEval == null) {
null
} else {
dataType match {
case _: DecimalType =>
round(valueEval.asInstanceOf[Decimal], scaleEval.asInstanceOf[Integer])
case FloatType =>
round(valueEval.asInstanceOf[Float].toDouble,
scaleEval.asInstanceOf[Integer]).floatValue()
case DoubleType =>
round(valueEval.asInstanceOf[Double], scaleEval.asInstanceOf[Integer]).doubleValue()
case LongType =>
round(valueEval.asInstanceOf[Long], scaleEval.asInstanceOf[Integer]).longValue()
case IntegerType =>
round(valueEval.asInstanceOf[Integer].toLong, scaleEval.asInstanceOf[Integer]).intValue()
case ShortType =>
round(valueEval.asInstanceOf[Short].toLong, scaleEval.asInstanceOf[Integer]).shortValue()
case ByteType =>
round(valueEval.asInstanceOf[Byte].toLong, scaleEval.asInstanceOf[Integer]).byteValue()
}
}
}

private def round(value: Long, scale: Int): BigDecimal = {
java.math.BigDecimal.valueOf(value).setScale(scale, RoundingMode.HALF_UP)
}

private def round(value: Double, scale: Int): BigDecimal = {
if (java.lang.Double.isNaN(value) || java.lang.Double.isInfinite(value)) {
value
} else {
java.math.BigDecimal.valueOf(value).setScale(scale, RoundingMode.HALF_UP)
}
}

private def round(value: Decimal, scale: Int): Decimal = {
value.set(value.toBigDecimal, value.precision, scale.asInstanceOf[Integer])
}

override def left: Expression = valueExpr

override def right: Expression = scaleExpr
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types.{DataType, DoubleType, LongType}
import org.apache.spark.sql.types.{DataType, LongType}
import org.apache.spark.sql.types.{Decimal, DoubleType}

class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {

Expand Down Expand Up @@ -252,4 +252,15 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
null,
create_row(null))
}

test("round") {
checkEvaluation(Round(Literal(Decimal(1.26)), Literal(1)), Decimal(1.3, 3, 1))
checkEvaluation(Round(Literal(1.23D), Literal(1)), 1.2)
checkEvaluation(Round(Literal(1.25D), Literal(1)), 1.3)
checkEvaluation(Round(Literal(1.5F), Literal(0)), 2.0F)
checkEvaluation(Round(Literal(1.toShort), Literal(0)), 1.toShort)
checkEvaluation(Round(Literal(2.toByte), Literal(0)), 2.toByte)
checkEvaluation(Round(Literal(9223372036854775807L), 0), 9223372036854775807L)
checkEvaluation(Round(Literal(123), Literal(0)), 123)
}
}
32 changes: 32 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1282,6 +1282,38 @@ object functions {
*/
def rint(columnName: String): Column = rint(Column(columnName))

/**
* Computes rounded value of the given input.
*
* @group math_funcs
* @since 1.5.0
*/
def round(e: Column): Column = round(e.expr, 0)

/**
* Computes rounded value of the given input.
*
* @group math_funcs
* @since 1.5.0
*/
def round(columnName: String): Column = round(Column(columnName))

/**
* Computes rounded value of the given input.
*
* @group math_funcs
* @since 1.5.0
*/
def round(e: Column, scale: Int): Column = Round(e.expr, lit(scale).expr)

/**
* Computes rounded value of the given input.
*
* @group math_funcs
* @since 1.5.0
*/
def round(columnName: String, scale: Int): Column = round(Column(columnName), scale)

/**
* Computes the signum of the given value.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@

package org.apache.spark.sql

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.functions._
import org.apache.spark.sql.functions.{log => logarithm}
import org.apache.spark.sql.types.{Decimal, DecimalType, DoubleType}


private object MathExpressionsTestData {
Expand Down Expand Up @@ -292,4 +294,19 @@ class MathExpressionsSuite extends QueryTest {
checkAnswer(df.selectExpr("positive(b)"), Row(-1))
checkAnswer(df.selectExpr("positive(c)"), Row("abc"))
}

test("round") {
val df = Seq((1.53, 0.62, 12345L, 0.67.toFloat)).toDF("a", "b", "c", "d")
checkAnswer(df.select(round('a)), Row(2.0))
checkAnswer(df.select(round('b, 1)), Row(0.6))
checkAnswer(df.selectExpr("round(a)"), Row(2))
checkAnswer(df.selectExpr("round(b, 1)"), Row(0.6))
checkAnswer(df.selectExpr("round(c, 1)"), Row(12345L))
checkAnswer(df.selectExpr("round(d, 1)"), Row(0.7f))
checkAnswer(df.selectExpr("round(null)"), Row(null))
checkAnswer(df.selectExpr("round(null, 1)"), Row(null))
checkAnswer(df.selectExpr("round(145.23, -1)"), Row(150.0)) // same as hive
checkAnswer(df.selectExpr("round(20, 1)"), Row(20))
checkAnswer(df.selectExpr("round(1.0/0.0, 1)"), Row(null))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -919,7 +919,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"udf_repeat",
"udf_rlike",
"udf_round",
// "udf_round_3", TODO: FIX THIS failed due to cast exception
"udf_round_3",
"udf_rpad",
"udf_rtrim",
"udf_second",
Expand Down