Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ public void write(int ordinal, double value) {
}

public void write(int ordinal, Decimal input, int precision, int scale) {
input = input.clone();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better add a comment that explains why we need to clone before write.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this necessary? Seems like a really bad idea.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We call changePrecision on input here, which would affect the original data. I agree that this is a bad idea; maybe we need to propose a separate PR to work around this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As Adrian mentioned, we need a copy of input, otherwise changePrecision would change the original input.
In our case, this means catalystValue(expected value) would be changed when checkEvalutionWithUnsafeProjection is invoked, and then all tests after checkEvalutionWithUnsafeProjection will fail.

  /**
   * Checks that `expression` evaluates to `expected` through each evaluation path below.
   *
   * `expected` is converted once with `CatalystTypeConverters.convertToCatalyst`, and that single
   * converted value is passed to every check in turn.
   *
   * @param expression by-name expression under test
   * @param expected   expected result, before conversion to its Catalyst form
   * @param inputRow   row the expression is evaluated against; defaults to `EmptyRow`
   */
  protected def checkEvaluation(
      expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = {
    val catalystValue = CatalystTypeConverters.convertToCatalyst(expected)
    checkEvaluationWithoutCodegen(expression, catalystValue, inputRow)
    checkEvaluationWithGeneratedMutableProjection(expression, catalystValue, inputRow)
    // The unsafe-projection check is skipped for data types GenerateUnsafeProjection cannot handle.
    if (GenerateUnsafeProjection.canSupport(expression.dataType)) {
      // NOTE(review): catalystValue is shared with the check that follows; if this path mutates
      // it (e.g. a Decimal whose precision is changed in place), the later check sees the
      // modified value — confirm.
      checkEvalutionWithUnsafeProjection(expression, catalystValue, inputRow)
    }
    checkEvaluationWithOptimization(expression, catalystValue, inputRow)
  }

Does that make sense? Any suggestions would be very helpful.

if (precision <= Decimal.MAX_LONG_DIGITS()) {
// make sure Decimal object has the same scale as DecimalType
if (input.changePrecision(precision, scale)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ abstract class UnaryLogExpression(f: Double => Double, name: String)
abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
extends BinaryExpression with Serializable with ImplicitCastInputTypes {

override def inputTypes: Seq[DataType] = Seq(DoubleType, DoubleType)
override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)

override def toString: String = s"$name($left, $right)"

Expand Down Expand Up @@ -523,11 +523,45 @@ case class Atan2(left: Expression, right: Expression)

case class Pow(left: Expression, right: Expression)
extends BinaryMathExpression(math.pow, "POWER") {
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)")
}
}
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, NumericType)

/**
 * Result type, chosen from the operand types.
 *
 * A decimal base with a Byte, Short or Integer exponent keeps the base's decimal type;
 * every other combination of operand types yields DoubleType.
 */
override def dataType: DataType = (left.dataType, right.dataType) match {
// NOTE(review): the base's DecimalType is reused unchanged for the result — confirm the
// computed power always fits that precision and scale.
case (dt: DecimalType, ByteType | ShortType | IntegerType) => dt
case _ => DoubleType
}

protected override def nullSafeEval(input1: Any, input2: Any): Any =
(left.dataType, right.dataType) match {
case (dt: DecimalType, ByteType) =>
input1.asInstanceOf[Decimal].pow(input2.asInstanceOf[Byte])
case (dt: DecimalType, ShortType) =>
input1.asInstanceOf[Decimal].pow(input2.asInstanceOf[Short])
case (dt: DecimalType, IntegerType) =>
input1.asInstanceOf[Decimal].pow(input2.asInstanceOf[Int])
case (dt: DecimalType, FloatType) =>
math.pow(input1.asInstanceOf[Decimal].toDouble, input2.asInstanceOf[Float])
case (dt: DecimalType, DoubleType) =>
math.pow(input1.asInstanceOf[Decimal].toDouble, input2.asInstanceOf[Double])
case (dt1: DecimalType, dt2: DecimalType) =>
math.pow(input1.asInstanceOf[Decimal].toDouble, input2.asInstanceOf[Decimal].toDouble)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we cast the result of math.pow back to DecimalType for these three cases?

case _ =>
math.pow(input1.asInstanceOf[Double], input2.asInstanceOf[Double])
}

/**
 * Generates the code for this expression, picking the emitted expression from the operand types.
 *
 * Every case delegates to `defineCodeGen` with a function that builds the expression text from
 * the two operand code fragments. Case order matters: the decimal/decimal case is listed before
 * the more general "decimal base, any exponent" case.
 */
override def genCode(ctx: CodegenContext, ev: ExprCode): String =
(left.dataType, right.dataType) match {
// Decimal base with a Byte/Short/Integer exponent: call pow on the base operand.
case (dt: DecimalType, ByteType | ShortType | IntegerType) =>
defineCodeGen(ctx, ev, (c1, c2) => s"$c1.pow($c2)")
// Both operands decimal: convert each with toDouble() and use java.lang.Math.pow.
case (dt1: DecimalType, dt2: DecimalType) =>
defineCodeGen(ctx, ev, (c1, c2) =>
s"java.lang.Math.pow($c1.toDouble(),$c2.toDouble())")
// Decimal base with any remaining exponent type: only the base is converted with toDouble().
case (dt: DecimalType, _) =>
defineCodeGen(ctx, ev, (c1, c2) =>
s"java.lang.Math.pow($c1.toDouble(),$c2)")
// Non-decimal base: both operands are passed to java.lang.Math.pow as they are.
// NOTE(review): inputTypes admits any NumericType pair here — confirm non-decimal
// operands are already doubles by the time this code is emitted.
case _ =>
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)")
}
}

/**
* Bitwise unsigned left shift.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,8 @@ final class Decimal extends Ordered[Decimal] with Serializable {
}
}

/**
 * Raises this decimal to the integer power `n`.
 *
 * The value from `toJavaBigDecimal` is raised with `pow(n, MATH_CONTEXT)` and the result is
 * wrapped in a new `Decimal`.
 * NOTE(review): presumably this is `java.math.BigDecimal.pow(int, MathContext)`, which can throw
 * ArithmeticException for out-of-range exponents — confirm against the definitions of
 * `toJavaBigDecimal` and `MATH_CONTEXT`.
 *
 * @param n the exponent
 * @return a new `Decimal` holding the result
 */
def pow(n: Int): Decimal = Decimal(toJavaBigDecimal.pow(n, MATH_CONTEXT))

def abs: Decimal = if (this.compare(Decimal.ZERO) < 0) this.unary_- else this

def floor: Decimal = if (scale == 0) this else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,10 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
* @param expectNull Whether the given values should return null or not
* @param expectNaN Whether the given values should eval to NaN or not
*/
private def testBinary(
private def testBinary[T, U, V](
c: (Expression, Expression) => Expression,
f: (Double, Double) => Double,
domain: Iterable[(Double, Double)] = (-20 to 20).map(v => (v * 0.1, v * -0.1)),
f: (T, U) => V,
domain: Iterable[(T, U)] = (-20 to 20).map(v => (v * 0.1, v * -0.1)),
expectNull: Boolean = false, expectNaN: Boolean = false): Unit = {
if (expectNull) {
domain.foreach { case (v1, v2) =>
Expand All @@ -103,8 +103,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
} else {
domain.foreach { case (v1, v2) =>
checkEvaluation(c(Literal(v1), Literal(v2)), f(v1 + 0.0, v2 + 0.0), EmptyRow)
checkEvaluation(c(Literal(v2), Literal(v1)), f(v2 + 0.0, v1 + 0.0), EmptyRow)
checkEvaluation(c(Literal(v1), Literal(v2)), f(v1, v2), EmptyRow)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

keep the test of f(v2, v1)

}
}
checkEvaluation(c(Literal.create(null, DoubleType), Literal(1.0)), null, create_row(null))
Expand Down Expand Up @@ -351,6 +350,20 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}

test("pow") {
testBinary(Pow, (d: Decimal, n: Byte) => d.pow(n),
(-5 to 5).map(v => (Decimal(v * 1.0), v.toByte)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe v.toDouble is better

testBinary(Pow, (d: Decimal, n: Short) => d.pow(n),
(-5 to 5).map(v => (Decimal(v * 1.0), v.toShort)))
testBinary(Pow, (d: Decimal, n: Int) => d.pow(n),
(-5 to 5).map(v => (Decimal(v * 1.0), v)))
testBinary(Pow, (d1: Decimal, d2: Float) => math.pow(d1.toDouble, d2),
(-5 to 5).map(v => (Decimal(v * 1.0), (v * 1.0).toFloat)))
testBinary(Pow, (d1: Decimal, d2: Double) => math.pow(d1.toDouble, d2),
(-5 to 5).map(v => (Decimal(v * 1.0), v * 1.0)))
testBinary(Pow, (d1: Decimal, d2: Decimal) => math.pow(d1.toDouble, d2.toDouble),
(-5 to 5).map(v => (Decimal(v * 1.0), Decimal(v * 1.0))))
testBinary(Pow, (d1: Decimal, d2: Double) => math.pow(d1.toDouble, d2),
Seq((Decimal("-1.0"), 0.9)), expectNaN = true)
testBinary(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0)))
testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNaN = true)
checkConsistencyBetweenInterpretedAndCodegen(Pow, DoubleType, DoubleType)
Expand Down