diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index e040ad0c0a6b..52f3af49c6c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -232,18 +232,20 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL" } override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(DoubleType, DecimalType)) + Seq(TypeCollection(DoubleType, DecimalType, LongType)) protected override def nullSafeEval(input: Any): Any = child.dataType match { + case LongType => input.asInstanceOf[Long] case DoubleType => f(input.asInstanceOf[Double]).toLong - case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].ceil + case DecimalType.Fixed(_, _) => input.asInstanceOf[Decimal].ceil } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { child.dataType match { case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c") - case DecimalType.Fixed(precision, scale) => + case DecimalType.Fixed(_, _) => defineCodeGen(ctx, ev, c => s"$c.ceil()") + case LongType => defineCodeGen(ctx, ev, c => s"$c") case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))") } } @@ -347,18 +349,20 @@ case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLO } override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(DoubleType, DecimalType)) + Seq(TypeCollection(DoubleType, DecimalType, LongType)) protected override def nullSafeEval(input: Any): Any = child.dataType match { + case LongType => input.asInstanceOf[Long] case DoubleType => f(input.asInstanceOf[Double]).toLong - case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].floor + case DecimalType.Fixed(_, _) => input.asInstanceOf[Decimal].floor } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { child.dataType match { case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c") - case DecimalType.Fixed(precision, scale) => + case DecimalType.Fixed(_, _) => defineCodeGen(ctx, ev, c => s"$c.floor()") + case LongType => defineCodeGen(ctx, ev, c => s"$c") case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala index 1555dd1cf58d..69ada8216515 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -252,6 +252,16 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 3)) checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 0)) checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(5, 0)) + + val doublePi: Double = 3.1415 + val floatPi: Float = 3.1415f + val longLit: Long = 12345678901234567L + checkEvaluation(Ceil(doublePi), 4L, EmptyRow) + checkEvaluation(Ceil(floatPi.toDouble), 4L, EmptyRow) + checkEvaluation(Ceil(longLit), longLit, EmptyRow) + checkEvaluation(Ceil(-doublePi), -3L, EmptyRow) + checkEvaluation(Ceil(-floatPi.toDouble), -3L, EmptyRow) + checkEvaluation(Ceil(-longLit), -longLit, EmptyRow) } test("floor") { @@ -262,6 +272,16 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 3)) checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 0)) checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(5, 0)) + + val doublePi: Double = 3.1415 + val floatPi: Float = 3.1415f + val longLit: Long = 12345678901234567L + checkEvaluation(Floor(doublePi), 3L, EmptyRow) + checkEvaluation(Floor(floatPi.toDouble), 3L, EmptyRow) + checkEvaluation(Floor(longLit), longLit, EmptyRow) + checkEvaluation(Floor(-doublePi), -4L, EmptyRow) + checkEvaluation(Floor(-floatPi.toDouble), -4L, EmptyRow) + checkEvaluation(Floor(-longLit), -longLit, EmptyRow) } test("factorial") {