Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ object FunctionRegistry {
expression[Log]("ln"),
expression[Log10]("log10"),
expression[Log1p]("log1p"),
expression[Log2]("log2"),
expression[UnaryMinus]("negative"),
expression[Pi]("pi"),
expression[Log2]("log2"),
expression[Pow]("pow"),
expression[Pow]("power"),
expression[Pmod]("pmod"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,22 +66,38 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)
override def toString: String = s"$name($child)"

protected override def nullSafeEval(input: Any): Any = {
val result = f(input.asInstanceOf[Double])
if (result.isNaN) null else result
f(input.asInstanceOf[Double])
}

// name of function in java.lang.Math
def funcName: String = name.toLowerCase

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
nullSafeCodeGen(ctx, ev, eval => {
defineCodeGen(ctx, ev, c => s"java.lang.Math.${funcName}($c)")
}
}

abstract class UnaryLogExpression(f: Double => Double, name: String)
extends UnaryMathExpression(f, name) { self: Product =>

// values less than or equal to yAsymptote eval to null in Hive, instead of NaN or -Infinity
protected val yAsymptote: Double = 0.0

protected override def nullSafeEval(input: Any): Any = {
val d = input.asInstanceOf[Double]
if (d <= yAsymptote) null else f(d)
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
nullSafeCodeGen(ctx, ev, c =>
s"""
${ev.primitive} = java.lang.Math.${funcName}($eval);
if (Double.valueOf(${ev.primitive}).isNaN()) {
if ($c <= $yAsymptote) {
${ev.isNull} = true;
} else {
${ev.primitive} = java.lang.Math.${funcName}($c);
}
"""
})
)
}
}

Expand All @@ -101,8 +117,7 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
override def dataType: DataType = DoubleType

protected override def nullSafeEval(input1: Any, input2: Any): Any = {
val result = f(input1.asInstanceOf[Double], input2.asInstanceOf[Double])
if (result.isNaN) null else result
f(input1.asInstanceOf[Double], input2.asInstanceOf[Double])
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
Expand Down Expand Up @@ -399,25 +414,28 @@ case class Factorial(child: Expression) extends UnaryExpression with ImplicitCas
}
}

case class Log(child: Expression) extends UnaryMathExpression(math.log, "LOG")
case class Log(child: Expression) extends UnaryLogExpression(math.log, "LOG")

case class Log2(child: Expression)
extends UnaryMathExpression((x: Double) => math.log(x) / math.log(2), "LOG2") {
extends UnaryLogExpression((x: Double) => math.log(x) / math.log(2), "LOG2") {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
nullSafeCodeGen(ctx, ev, eval => {
nullSafeCodeGen(ctx, ev, c =>
s"""
${ev.primitive} = java.lang.Math.log($eval) / java.lang.Math.log(2);
if (Double.valueOf(${ev.primitive}).isNaN()) {
if ($c <= $yAsymptote) {
${ev.isNull} = true;
} else {
${ev.primitive} = java.lang.Math.log($c) / java.lang.Math.log(2);
}
"""
})
)
}
}

case class Log10(child: Expression) extends UnaryMathExpression(math.log10, "LOG10")
case class Log10(child: Expression) extends UnaryLogExpression(math.log10, "LOG10")

case class Log1p(child: Expression) extends UnaryMathExpression(math.log1p, "LOG1P")
case class Log1p(child: Expression) extends UnaryLogExpression(math.log1p, "LOG1P") {
protected override val yAsymptote: Double = -1.0
}

case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND") {
override def funcName: String = "rint"
Expand Down Expand Up @@ -578,27 +596,18 @@ case class Atan2(left: Expression, right: Expression)

protected override def nullSafeEval(input1: Any, input2: Any): Any = {
// With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0
val result = math.atan2(input1.asInstanceOf[Double] + 0.0, input2.asInstanceOf[Double] + 0.0)
if (result.isNaN) null else result
math.atan2(input1.asInstanceOf[Double] + 0.0, input2.asInstanceOf[Double] + 0.0)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still need + 0.0 here, math.atan2 is calling java.lang.Math.atan2 inside

}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") + s"""
if (Double.valueOf(${ev.primitive}).isNaN()) {
${ev.isNull} = true;
}
"""
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)")
}
}

case class Pow(left: Expression, right: Expression)
extends BinaryMathExpression(math.pow, "POWER") {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") + s"""
if (Double.valueOf(${ev.primitive}).isNaN()) {
${ev.isNull} = true;
}
"""
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)")
}
}

Expand Down Expand Up @@ -700,17 +709,33 @@ case class Logarithm(left: Expression, right: Expression)
this(EulerNumber(), child)
}

protected override def nullSafeEval(input1: Any, input2: Any): Any = {
val dLeft = input1.asInstanceOf[Double]
val dRight = input2.asInstanceOf[Double]
// Unlike Hive, we support Log base in (0.0, 1.0]
if (dLeft <= 0.0 || dRight <= 0.0) null else math.log(dRight) / math.log(dLeft)
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val logCode = if (left.isInstanceOf[EulerNumber]) {
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2)")
if (left.isInstanceOf[EulerNumber]) {
nullSafeCodeGen(ctx, ev, (c1, c2) =>
s"""
if ($c2 <= 0.0) {
${ev.isNull} = true;
} else {
${ev.primitive} = java.lang.Math.log($c2);
}
""")
} else {
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2) / java.lang.Math.log($c1)")
nullSafeCodeGen(ctx, ev, (c1, c2) =>
s"""
if ($c1 <= 0.0 || $c2 <= 0.0) {
${ev.isNull} = true;
} else {
${ev.primitive} = java.lang.Math.log($c2) / java.lang.Math.log($c1);
}
""")
}
logCode + s"""
if (Double.isNaN(${ev.primitive})) {
${ev.isNull} = true;
}
"""
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ import com.google.common.math.LongMath

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
import org.apache.spark.sql.types._


Expand All @@ -47,6 +51,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
* @param f The functions in scala.math or elsewhere used to generate expected results
* @param domain The set of values to run the function with
* @param expectNull Whether the given values should return null or not
* @param expectNaN Whether the given values should eval to NaN or not
* @tparam T Generic type for primitives
* @tparam U Generic type for the output of the given function `f`
*/
Expand All @@ -55,11 +60,16 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
f: T => U,
domain: Iterable[T] = (-20 to 20).map(_ * 0.1),
expectNull: Boolean = false,
expectNaN: Boolean = false,
evalType: DataType = DoubleType): Unit = {
if (expectNull) {
domain.foreach { value =>
checkEvaluation(c(Literal(value)), null, EmptyRow)
}
} else if (expectNaN) {
domain.foreach { value =>
checkNaN(c(Literal(value)), EmptyRow)
}
} else {
domain.foreach { value =>
checkEvaluation(c(Literal(value)), f(value), EmptyRow)
Expand All @@ -74,16 +84,22 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
* @param c The DataFrame function
* @param f The functions in scala.math
* @param domain The set of values to run the function with
* @param expectNull Whether the given values should return null or not
* @param expectNaN Whether the given values should eval to NaN or not
*/
private def testBinary(
c: (Expression, Expression) => Expression,
f: (Double, Double) => Double,
domain: Iterable[(Double, Double)] = (-20 to 20).map(v => (v * 0.1, v * -0.1)),
expectNull: Boolean = false): Unit = {
expectNull: Boolean = false, expectNaN: Boolean = false): Unit = {
if (expectNull) {
domain.foreach { case (v1, v2) =>
checkEvaluation(c(Literal(v1), Literal(v2)), null, create_row(null))
}
} else if (expectNaN) {
domain.foreach { case (v1, v2) =>
checkNaN(c(Literal(v1), Literal(v2)), EmptyRow)
}
} else {
domain.foreach { case (v1, v2) =>
checkEvaluation(c(Literal(v1), Literal(v2)), f(v1 + 0.0, v2 + 0.0), EmptyRow)
Expand Down Expand Up @@ -112,6 +128,62 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
Conv(Literal("11abc"), Literal(10), Literal(16)), "B")
}

private def checkNaN(
expression: Expression, inputRow: InternalRow = EmptyRow): Unit = {
checkNaNWithoutCodegen(expression, inputRow)
checkNaNWithGeneratedProjection(expression, inputRow)
checkNaNWithOptimization(expression, inputRow)
}

private def checkNaNWithoutCodegen(
expression: Expression,
expected: Any,
inputRow: InternalRow = EmptyRow): Unit = {
val actual = try evaluate(expression, inputRow) catch {
case e: Exception => fail(s"Exception evaluating $expression", e)
}
if (!actual.asInstanceOf[Double].isNaN) {
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
fail(s"Incorrect evaluation (codegen off): $expression, " +
s"actual: $actual, " +
s"expected: NaN")
}
}


private def checkNaNWithGeneratedProjection(
expression: Expression,
inputRow: InternalRow = EmptyRow): Unit = {

val plan = try {
GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)()
} catch {
case e: Throwable =>
val ctx = GenerateProjection.newCodeGenContext()
val evaluated = expression.gen(ctx)
fail(
s"""
|Code generation of $expression failed:
|${evaluated.code}
|$e
""".stripMargin)
}

val actual = plan(inputRow).apply(0)
if (!actual.asInstanceOf[Double].isNaN) {
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: NaN")
}
}

private def checkNaNWithOptimization(
expression: Expression,
inputRow: InternalRow = EmptyRow): Unit = {
val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation)
val optimizedPlan = DefaultOptimizer.execute(plan)
checkNaNWithoutCodegen(optimizedPlan.expressions.head, inputRow)
}

test("e") {
testLeaf(EulerNumber, math.E)
}
Expand All @@ -126,7 +198,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {

test("asin") {
testUnary(Asin, math.asin, (-10 to 10).map(_ * 0.1))
testUnary(Asin, math.asin, (11 to 20).map(_ * 0.1), expectNull = true)
testUnary(Asin, math.asin, (11 to 20).map(_ * 0.1), expectNaN = true)
}

test("sinh") {
Expand All @@ -139,7 +211,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {

test("acos") {
testUnary(Acos, math.acos, (-10 to 10).map(_ * 0.1))
testUnary(Acos, math.acos, (11 to 20).map(_ * 0.1), expectNull = true)
testUnary(Acos, math.acos, (11 to 20).map(_ * 0.1), expectNaN = true)
}

test("cosh") {
Expand Down Expand Up @@ -204,18 +276,18 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}

test("log") {
testUnary(Log, math.log, (0 to 20).map(_ * 0.1))
testUnary(Log, math.log, (-5 to -1).map(_ * 0.1), expectNull = true)
testUnary(Log, math.log, (1 to 20).map(_ * 0.1))
testUnary(Log, math.log, (-5 to 0).map(_ * 0.1), expectNull = true)
}

test("log10") {
testUnary(Log10, math.log10, (0 to 20).map(_ * 0.1))
testUnary(Log10, math.log10, (-5 to -1).map(_ * 0.1), expectNull = true)
testUnary(Log10, math.log10, (1 to 20).map(_ * 0.1))
testUnary(Log10, math.log10, (-5 to 0).map(_ * 0.1), expectNull = true)
}

test("log1p") {
testUnary(Log1p, math.log1p, (-1 to 20).map(_ * 0.1))
testUnary(Log1p, math.log1p, (-10 to -2).map(_ * 1.0), expectNull = true)
testUnary(Log1p, math.log1p, (0 to 20).map(_ * 0.1))
testUnary(Log1p, math.log1p, (-10 to -1).map(_ * 1.0), expectNull = true)
}

test("bin") {
Expand All @@ -237,22 +309,22 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {

test("log2") {
def f: (Double) => Double = (x: Double) => math.log(x) / math.log(2)
testUnary(Log2, f, (0 to 20).map(_ * 0.1))
testUnary(Log2, f, (-5 to -1).map(_ * 1.0), expectNull = true)
testUnary(Log2, f, (1 to 20).map(_ * 0.1))
testUnary(Log2, f, (-5 to 0).map(_ * 1.0), expectNull = true)
}

test("sqrt") {
testUnary(Sqrt, math.sqrt, (0 to 20).map(_ * 0.1))
testUnary(Sqrt, math.sqrt, (-5 to -1).map(_ * 1.0), expectNull = true)
testUnary(Sqrt, math.sqrt, (-5 to -1).map(_ * 1.0), expectNaN = true)

checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null))
checkEvaluation(Sqrt(Literal(-1.0)), null, EmptyRow)
checkEvaluation(Sqrt(Literal(-1.5)), null, EmptyRow)
checkNaN(Sqrt(Literal(-1.0)), EmptyRow)
checkNaN(Sqrt(Literal(-1.5)), EmptyRow)
}

test("pow") {
testBinary(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0)))
testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNull = true)
testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNaN = true)
}

test("shift left") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,7 @@ class MathExpressionsSuite extends QueryTest {
if (f(-1) === math.log1p(-1)) {
checkAnswer(
nnDoubleData.select(c('b)),
(1 to 9).map(n => Row(f(n * -0.1))) :+ Row(Double.NegativeInfinity)
)
} else {
checkAnswer(
nnDoubleData.select(c('b)),
(1 to 10).map(n => Row(null))
(1 to 9).map(n => Row(f(n * -0.1))) :+ Row(null)
)
}

Expand Down
Loading