-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-13332][SQL] Decimal datatype support for SQL pow #11212
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -109,7 +109,7 @@ abstract class UnaryLogExpression(f: Double => Double, name: String) | |
| abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) | ||
| extends BinaryExpression with Serializable with ImplicitCastInputTypes { | ||
|
|
||
| override def inputTypes: Seq[DataType] = Seq(DoubleType, DoubleType) | ||
| override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) | ||
|
|
||
| override def toString: String = s"$name($left, $right)" | ||
|
|
||
|
|
@@ -523,11 +523,45 @@ case class Atan2(left: Expression, right: Expression) | |
|
|
||
| case class Pow(left: Expression, right: Expression) | ||
| extends BinaryMathExpression(math.pow, "POWER") { | ||
| override def genCode(ctx: CodegenContext, ev: ExprCode): String = { | ||
| defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") | ||
| } | ||
| } | ||
| override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, NumericType) | ||
|
|
||
| override def dataType: DataType = (left.dataType, right.dataType) match { | ||
| case (dt: DecimalType, ByteType | ShortType | IntegerType) => dt | ||
| case _ => DoubleType | ||
| } | ||
|
|
||
| protected override def nullSafeEval(input1: Any, input2: Any): Any = | ||
| (left.dataType, right.dataType) match { | ||
| case (dt: DecimalType, ByteType) => | ||
| input1.asInstanceOf[Decimal].pow(input2.asInstanceOf[Byte]) | ||
| case (dt: DecimalType, ShortType) => | ||
| input1.asInstanceOf[Decimal].pow(input2.asInstanceOf[Short]) | ||
| case (dt: DecimalType, IntegerType) => | ||
| input1.asInstanceOf[Decimal].pow(input2.asInstanceOf[Int]) | ||
| case (dt: DecimalType, FloatType) => | ||
| math.pow(input1.asInstanceOf[Decimal].toDouble, input2.asInstanceOf[Float]) | ||
| case (dt: DecimalType, DoubleType) => | ||
| math.pow(input1.asInstanceOf[Decimal].toDouble, input2.asInstanceOf[Double]) | ||
| case (dt1: DecimalType, dt2: DecimalType) => | ||
| math.pow(input1.asInstanceOf[Decimal].toDouble, input2.asInstanceOf[Decimal].toDouble) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shall we cast the result of |
||
| case _ => | ||
| math.pow(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) | ||
| } | ||
|
|
||
| override def genCode(ctx: CodegenContext, ev: ExprCode): String = | ||
| (left.dataType, right.dataType) match { | ||
| case (dt: DecimalType, ByteType | ShortType | IntegerType) => | ||
| defineCodeGen(ctx, ev, (c1, c2) => s"$c1.pow($c2)") | ||
| case (dt1: DecimalType, dt2: DecimalType) => | ||
| defineCodeGen(ctx, ev, (c1, c2) => | ||
| s"java.lang.Math.pow($c1.toDouble(),$c2.toDouble())") | ||
| case (dt: DecimalType, _) => | ||
| defineCodeGen(ctx, ev, (c1, c2) => | ||
| s"java.lang.Math.pow($c1.toDouble(),$c2)") | ||
| case _ => | ||
| defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Bitwise unsigned left shift. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -88,10 +88,10 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { | |
| * @param expectNull Whether the given values should return null or not | ||
| * @param expectNaN Whether the given values should eval to NaN or not | ||
| */ | ||
| private def testBinary( | ||
| private def testBinary[T, U, V]( | ||
| c: (Expression, Expression) => Expression, | ||
| f: (Double, Double) => Double, | ||
| domain: Iterable[(Double, Double)] = (-20 to 20).map(v => (v * 0.1, v * -0.1)), | ||
| f: (T, U) => V, | ||
| domain: Iterable[(T, U)] = (-20 to 20).map(v => (v * 0.1, v * -0.1)), | ||
| expectNull: Boolean = false, expectNaN: Boolean = false): Unit = { | ||
| if (expectNull) { | ||
| domain.foreach { case (v1, v2) => | ||
|
|
@@ -103,8 +103,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { | |
| } | ||
| } else { | ||
| domain.foreach { case (v1, v2) => | ||
| checkEvaluation(c(Literal(v1), Literal(v2)), f(v1 + 0.0, v2 + 0.0), EmptyRow) | ||
| checkEvaluation(c(Literal(v2), Literal(v1)), f(v2 + 0.0, v1 + 0.0), EmptyRow) | ||
| checkEvaluation(c(Literal(v1), Literal(v2)), f(v1, v2), EmptyRow) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. keep the test of |
||
| } | ||
| } | ||
| checkEvaluation(c(Literal.create(null, DoubleType), Literal(1.0)), null, create_row(null)) | ||
|
|
@@ -351,6 +350,20 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { | |
| } | ||
|
|
||
| test("pow") { | ||
| testBinary(Pow, (d: Decimal, n: Byte) => d.pow(n), | ||
| (-5 to 5).map(v => (Decimal(v * 1.0), v.toByte))) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe |
||
| testBinary(Pow, (d: Decimal, n: Short) => d.pow(n), | ||
| (-5 to 5).map(v => (Decimal(v * 1.0), v.toShort))) | ||
| testBinary(Pow, (d: Decimal, n: Int) => d.pow(n), | ||
| (-5 to 5).map(v => (Decimal(v * 1.0), v))) | ||
| testBinary(Pow, (d1: Decimal, d2: Float) => math.pow(d1.toDouble, d2), | ||
| (-5 to 5).map(v => (Decimal(v * 1.0), (v * 1.0).toFloat))) | ||
| testBinary(Pow, (d1: Decimal, d2: Double) => math.pow(d1.toDouble, d2), | ||
| (-5 to 5).map(v => (Decimal(v * 1.0), v * 1.0))) | ||
| testBinary(Pow, (d1: Decimal, d2: Decimal) => math.pow(d1.toDouble, d2.toDouble), | ||
| (-5 to 5).map(v => (Decimal(v * 1.0), Decimal(v * 1.0)))) | ||
| testBinary(Pow, (d1: Decimal, d2: Double) => math.pow(d1.toDouble, d2), | ||
| Seq((Decimal("-1.0"), 0.9)), expectNaN = true) | ||
| testBinary(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0))) | ||
| testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNaN = true) | ||
| checkConsistencyBetweenInterpretedAndCodegen(Pow, DoubleType, DoubleType) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Better add a comment that explains why we need to clone before write.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this necessary? Seems like a really bad idea.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here we'll call
changePrecisiononinputhere, which would affect the orignal data. I agree that this is a bad idea, maybe we need to propose a separate pr to work around this.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As Adrian mentioned, we need a copy of input, otherwise
changePrecisionwould change the original input.In our case, this means
catalystValue(expected value) would be changed whencheckEvalutionWithUnsafeProjectionis invoked, and then all tests after checkEvalutionWithUnsafeProjection will fail.Does it make sense? Any suggestion is great helpful.