From cce48076151403cf456493e5b3756bc7aa6e1713 Mon Sep 17 00:00:00 2001 From: Yucai Yu Date: Tue, 2 Feb 2016 10:52:06 +0800 Subject: [PATCH 1/2] Decimal support for pow --- .../expressions/codegen/UnsafeRowWriter.java | 1 + .../expressions/mathExpressions.scala | 44 ++++++++++++++++--- .../org/apache/spark/sql/types/Decimal.scala | 2 + .../expressions/MathFunctionsSuite.scala | 35 ++++++++++----- 4 files changed, 66 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index 4776617043878..5bb0800537264 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -170,6 +170,7 @@ public void write(int ordinal, double value) { } public void write(int ordinal, Decimal input, int precision, int scale) { + input = input.clone(); if (precision <= Decimal.MAX_LONG_DIGITS()) { // make sure Decimal object has the same scale as DecimalType if (input.changePrecision(precision, scale)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 8b9a60f97ce6e..e5d834de9b39d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -109,7 +109,7 @@ abstract class UnaryLogExpression(f: Double => Double, name: String) abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) extends BinaryExpression with Serializable with ImplicitCastInputTypes { - override def inputTypes: Seq[DataType] = Seq(DoubleType, DoubleType) + override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) override def toString: String = s"$name($left, $right)" @@ -523,11 +523,45 @@ case class Atan2(left: Expression, right: Expression) case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") { - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { - defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") - } -} + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, NumericType) + + override def dataType: DataType = (left.dataType, right.dataType) match { + case (dt: DecimalType, ByteType | ShortType | IntegerType) => dt + case _ => DoubleType + } + + protected override def nullSafeEval(input1: Any, input2: Any): Any = + (left.dataType, right.dataType) match { + case (dt: DecimalType, ByteType) => + input1.asInstanceOf[Decimal].pow(input2.asInstanceOf[Byte]) + case (dt: DecimalType, ShortType) => + input1.asInstanceOf[Decimal].pow(input2.asInstanceOf[Short]) + case (dt: DecimalType, IntegerType) => + input1.asInstanceOf[Decimal].pow(input2.asInstanceOf[Int]) + case (dt: DecimalType, FloatType) => + math.pow(input1.asInstanceOf[Decimal].toDouble, input2.asInstanceOf[Float]) + case (dt: DecimalType, DoubleType) => + math.pow(input1.asInstanceOf[Decimal].toDouble, input2.asInstanceOf[Double]) + case (dt1: DecimalType, dt2: DecimalType) => + math.pow(input1.asInstanceOf[Decimal].toDouble, input2.asInstanceOf[Decimal].toDouble) + case _ => + math.pow(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) + } + override def genCode(ctx: CodegenContext, ev: ExprCode): String = + (left.dataType, right.dataType) match { + case (dt: DecimalType, ByteType | ShortType | IntegerType) => + defineCodeGen(ctx, ev, (c1, c2) => s"$c1.pow($c2)") + case (dt1: DecimalType, dt2: DecimalType) => + defineCodeGen(ctx, ev, (c1, c2) => + s"java.lang.Math.pow($c1.toDouble(),$c2.toDouble())") + case (dt: DecimalType, _) => + defineCodeGen(ctx, ev, (c1, c2) => + s"java.lang.Math.pow($c1.toDouble(),$c2)") + case _ => + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") + } +} /** * Bitwise unsigned left shift. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 38ce1604b1ede..e56dc9fa024c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -318,6 +318,8 @@ final class Decimal extends Ordered[Decimal] with Serializable { } } + def pow(n: Int): Decimal = Decimal(toJavaBigDecimal.pow(n, MATH_CONTEXT)) + def abs: Decimal = if (this.compare(Decimal.ZERO) < 0) this.unary_- else this def floor: Decimal = if (scale == 0) this else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 4ad65db0977c7..adaec1206b3bc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -88,10 +88,10 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { * @param expectNull Whether the given values should return null or not * @param expectNaN Whether the given values should eval to NaN or not */ - private def testBinary( + private def testBinary[T, U, V]( c: (Expression, Expression) => Expression, - f: (Double, Double) => Double, - domain: Iterable[(Double, Double)] = (-20 to 20).map(v => (v * 0.1, v * -0.1)), + f: (T, U) => V, + domain: Iterable[(T, U)] = (-20 to 20).map(v => (v * 0.1, v * -0.1)), expectNull: Boolean = false, expectNaN: Boolean = false): Unit = { if (expectNull) { domain.foreach { case (v1, v2) => @@ -103,8 +103,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } else { domain.foreach { case (v1, v2) => - checkEvaluation(c(Literal(v1), Literal(v2)), f(v1 + 0.0, v2 + 0.0), EmptyRow) - checkEvaluation(c(Literal(v2), Literal(v1)), f(v2 + 0.0, v1 + 0.0), EmptyRow) + checkEvaluation(c(Literal(v1), Literal(v2)), f(v1, v2), EmptyRow) } } checkEvaluation(c(Literal.create(null, DoubleType), Literal(1.0)), null, create_row(null)) @@ -154,6 +153,26 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkNaNWithoutCodegen(optimizedPlan.expressions.head, inputRow) } + test("pow") { + testBinary(Pow, (d: Decimal, n: Byte) => d.pow(n), + (-5 to 5).map(v => (Decimal(v * 1.0), v.toByte))) + testBinary(Pow, (d: Decimal, n: Short) => d.pow(n), + (-5 to 5).map(v => (Decimal(v * 1.0), v.toShort))) + testBinary(Pow, (d: Decimal, n: Int) => d.pow(n), + (-5 to 5).map(v => (Decimal(v * 1.0), v))) + testBinary(Pow, (d1: Decimal, d2: Float) => math.pow(d1.toDouble, d2), + (-5 to 5).map(v => (Decimal(v * 1.0), (v * 1.0).toFloat))) + testBinary(Pow, (d1: Decimal, d2: Double) => math.pow(d1.toDouble, d2), + (-5 to 5).map(v => (Decimal(v * 1.0), v * 1.0))) + testBinary(Pow, (d1: Decimal, d2: Decimal) => math.pow(d1.toDouble, d2.toDouble), + (-5 to 5).map(v => (Decimal(v * 1.0), Decimal(v * 1.0)))) + testBinary(Pow, (d1: Decimal, d2: Double) => math.pow(d1.toDouble, d2), + Seq((Decimal("-1.0"), 0.9)), expectNaN = true) + testBinary(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0))) + testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNaN = true) + checkConsistencyBetweenInterpretedAndCodegen(Pow, DoubleType, DoubleType) + } + test("conv") { checkEvaluation(Conv(Literal("3"), Literal(10), Literal(2)), "11") checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(-16)), "-F") @@ -350,12 +369,6 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkConsistencyBetweenInterpretedAndCodegen(Sqrt, DoubleType) } - test("pow") { - testBinary(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0))) - testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNaN = true) - checkConsistencyBetweenInterpretedAndCodegen(Pow, DoubleType, DoubleType) - } - test("shift left") { checkEvaluation(ShiftLeft(Literal.create(null, IntegerType), Literal(1)), null) checkEvaluation(ShiftLeft(Literal(21), Literal.create(null, IntegerType)), null) From 08d64b1608bd50671f8d2d76b65e975a9072d95c Mon Sep 17 00:00:00 2001 From: Yucai Yu Date: Tue, 16 Feb 2016 10:51:27 +0800 Subject: [PATCH 2/2] relocation suite --- .../expressions/MathFunctionsSuite.scala | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index adaec1206b3bc..f81fb880c2b55 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -153,26 +153,6 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkNaNWithoutCodegen(optimizedPlan.expressions.head, inputRow) } - test("pow") { - testBinary(Pow, (d: Decimal, n: Byte) => d.pow(n), - (-5 to 5).map(v => (Decimal(v * 1.0), v.toByte))) - testBinary(Pow, (d: Decimal, n: Short) => d.pow(n), - (-5 to 5).map(v => (Decimal(v * 1.0), v.toShort))) - testBinary(Pow, (d: Decimal, n: Int) => d.pow(n), - (-5 to 5).map(v => (Decimal(v * 1.0), v))) - testBinary(Pow, (d1: Decimal, d2: Float) => math.pow(d1.toDouble, d2), - (-5 to 5).map(v => (Decimal(v * 1.0), (v * 1.0).toFloat))) - testBinary(Pow, (d1: Decimal, d2: Double) => math.pow(d1.toDouble, d2), - (-5 to 5).map(v => (Decimal(v * 1.0), v * 1.0))) - testBinary(Pow, (d1: Decimal, d2: Decimal) => math.pow(d1.toDouble, d2.toDouble), - (-5 to 5).map(v => (Decimal(v * 1.0), Decimal(v * 1.0)))) - testBinary(Pow, (d1: Decimal, d2: Double) => math.pow(d1.toDouble, d2), - Seq((Decimal("-1.0"), 0.9)), expectNaN = true) - testBinary(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0))) - testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNaN = true) - checkConsistencyBetweenInterpretedAndCodegen(Pow, DoubleType, DoubleType) - } - test("conv") { checkEvaluation(Conv(Literal("3"), Literal(10), Literal(2)), "11") checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(-16)), "-F") @@ -369,6 +349,26 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkConsistencyBetweenInterpretedAndCodegen(Sqrt, DoubleType) } + test("pow") { + testBinary(Pow, (d: Decimal, n: Byte) => d.pow(n), + (-5 to 5).map(v => (Decimal(v * 1.0), v.toByte))) + testBinary(Pow, (d: Decimal, n: Short) => d.pow(n), + (-5 to 5).map(v => (Decimal(v * 1.0), v.toShort))) + testBinary(Pow, (d: Decimal, n: Int) => d.pow(n), + (-5 to 5).map(v => (Decimal(v * 1.0), v))) + testBinary(Pow, (d1: Decimal, d2: Float) => math.pow(d1.toDouble, d2), + (-5 to 5).map(v => (Decimal(v * 1.0), (v * 1.0).toFloat))) + testBinary(Pow, (d1: Decimal, d2: Double) => math.pow(d1.toDouble, d2), + (-5 to 5).map(v => (Decimal(v * 1.0), v * 1.0))) + testBinary(Pow, (d1: Decimal, d2: Decimal) => math.pow(d1.toDouble, d2.toDouble), + (-5 to 5).map(v => (Decimal(v * 1.0), Decimal(v * 1.0)))) + testBinary(Pow, (d1: Decimal, d2: Double) => math.pow(d1.toDouble, d2), + Seq((Decimal("-1.0"), 0.9)), expectNaN = true) + testBinary(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0))) + testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNaN = true) + checkConsistencyBetweenInterpretedAndCodegen(Pow, DoubleType, DoubleType) + } + test("shift left") { checkEvaluation(ShiftLeft(Literal.create(null, IntegerType), Literal(1)), null) checkEvaluation(ShiftLeft(Literal(21), Literal.create(null, IntegerType)), null)