From 188be516aa0e932491ff067527586188efa96e6b Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Fri, 17 Jul 2015 02:34:06 +0800 Subject: [PATCH 1/3] null to nan in Math Expression --- .../catalyst/analysis/FunctionRegistry.scala | 2 +- .../spark/sql/catalyst/expressions/math.scala | 46 ++------ .../expressions/MathFunctionsSuite.scala | 103 ++++++++++++++---- .../spark/sql/MathExpressionsSuite.scala | 5 - .../execution/HiveCompatibilitySuite.scala | 12 +- 5 files changed, 97 insertions(+), 71 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index a45181712dbd..d39c9491b1c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -111,9 +111,9 @@ object FunctionRegistry { expression[Log]("ln"), expression[Log10]("log10"), expression[Log1p]("log1p"), + expression[Log2]("log2"), expression[UnaryMinus]("negative"), expression[Pi]("pi"), - expression[Log2]("log2"), expression[Pow]("pow"), expression[Pow]("power"), expression[Pmod]("pmod"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 7a543ff36afd..c4978bc791ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -66,22 +66,14 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) override def toString: String = s"$name($child)" protected override def nullSafeEval(input: Any): Any = { - val result = f(input.asInstanceOf[Double]) - if (result.isNaN) null else result + f(input.asInstanceOf[Double]) } // name of function in java.lang.Math def funcName: String = name.toLowerCase override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, eval => { - s""" - ${ev.primitive} = java.lang.Math.${funcName}($eval); - if (Double.valueOf(${ev.primitive}).isNaN()) { - ${ev.isNull} = true; - } - """ - }) + defineCodeGen(ctx, ev, c => s"java.lang.Math.${funcName}($c)") } } @@ -101,8 +93,7 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) override def dataType: DataType = DoubleType protected override def nullSafeEval(input1: Any, input2: Any): Any = { - val result = f(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) - if (result.isNaN) null else result + f(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { @@ -404,14 +395,7 @@ case class Log(child: Expression) extends UnaryMathExpression(math.log, "LOG") case class Log2(child: Expression) extends UnaryMathExpression((x: Double) => math.log(x) / math.log(2), "LOG2") { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, eval => { - s""" - ${ev.primitive} = java.lang.Math.log($eval) / java.lang.Math.log(2); - if (Double.valueOf(${ev.primitive}).isNaN()) { - ${ev.isNull} = true; - } - """ - }) + defineCodeGen(ctx, ev, c => s"java.lang.Math.log($c) / java.lang.Math.log(2)") } } @@ -578,27 +562,18 @@ case class Atan2(left: Expression, right: Expression) protected override def nullSafeEval(input1: Any, input2: Any): Any = { // With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0 - val result = math.atan2(input1.asInstanceOf[Double] + 0.0, input2.asInstanceOf[Double] + 0.0) - if (result.isNaN) null else result + math.atan2(input1.asInstanceOf[Double] + 0.0, input2.asInstanceOf[Double] + 0.0) } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") + s""" - if (Double.valueOf(${ev.primitive}).isNaN()) { - ${ev.isNull} = true; - } - """ + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") } } case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") + s""" - if (Double.valueOf(${ev.primitive}).isNaN()) { - ${ev.isNull} = true; - } - """ + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") } } @@ -701,16 +676,11 @@ case class Logarithm(left: Expression, right: Expression) } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val logCode = if (left.isInstanceOf[EulerNumber]) { + if (left.isInstanceOf[EulerNumber]) { defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2)") } else { defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2) / java.lang.Math.log($c1)") } - logCode + s""" - if (Double.isNaN(${ev.primitive})) { - ${ev.isNull} = true; - } - """ } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index ca35c7ef8ae5..5750f3f028fd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -21,6 +21,10 @@ import com.google.common.math.LongMath import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer +import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} import org.apache.spark.sql.types._ @@ -46,7 +50,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { * @param c expression * @param f The functions in scala.math or elsewhere used to generate expected results * @param domain The set of values to run the function with - * @param expectNull Whether the given values should return null or not + * @param expectNaN Whether the given values should eval to NaN or not * @tparam T Generic type for primitives * @tparam U Generic type for the output of the given function `f` */ @@ -54,11 +58,11 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { c: Expression => Expression, f: T => U, domain: Iterable[T] = (-20 to 20).map(_ * 0.1), - expectNull: Boolean = false, + expectNaN: Boolean = false, evalType: DataType = DoubleType): Unit = { - if (expectNull) { + if (expectNaN) { domain.foreach { value => - checkEvaluation(c(Literal(value)), null, EmptyRow) + checkNaN(c(Literal(value)), EmptyRow) } } else { domain.foreach { value => @@ -74,15 +78,16 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { * @param c The DataFrame function * @param f The functions in scala.math * @param domain The set of values to run the function with + * @param expectNaN Whether the given values should eval to NaN or not */ private def testBinary( c: (Expression, Expression) => Expression, f: (Double, Double) => Double, domain: Iterable[(Double, Double)] = (-20 to 20).map(v => (v * 0.1, v * -0.1)), - expectNull: Boolean = false): Unit = { - if (expectNull) { + expectNaN: Boolean = false): Unit = { + if (expectNaN) { domain.foreach { case (v1, v2) => - checkEvaluation(c(Literal(v1), Literal(v2)), null, create_row(null)) + checkNaN(c(Literal(v1), Literal(v2)), EmptyRow) } } else { domain.foreach { case (v1, v2) => @@ -112,6 +117,62 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { Conv(Literal("11abc"), Literal(10), Literal(16)), "B") } + private def checkNaN( + expression: Expression, inputRow: InternalRow = EmptyRow): Unit = { + checkNaNWithoutCodegen(expression, inputRow) + checkNaNWithGeneratedProjection(expression, inputRow) + checkNaNWithOptimization(expression, inputRow) + } + + private def checkNaNWithoutCodegen( + expression: Expression, + expected: Any, + inputRow: InternalRow = EmptyRow): Unit = { + val actual = try evaluate(expression, inputRow) catch { + case e: Exception => fail(s"Exception evaluating $expression", e) + } + if (!actual.asInstanceOf[Double].isNaN) { + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect evaluation (codegen off): $expression, " + + s"actual: $actual, " + + s"expected: NaN") + } + } + + + private def checkNaNWithGeneratedProjection( + expression: Expression, + inputRow: InternalRow = EmptyRow): Unit = { + + val plan = try { + GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)() + } catch { + case e: Throwable => + val ctx = GenerateProjection.newCodeGenContext() + val evaluated = expression.gen(ctx) + fail( + s""" + |Code generation of $expression failed: + |${evaluated.code} + |$e + """.stripMargin) + } + + val actual = plan(inputRow).apply(0) + if (!actual.asInstanceOf[Double].isNaN) { + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: NaN") + } + } + + private def checkNaNWithOptimization( + expression: Expression, + inputRow: InternalRow = EmptyRow): Unit = { + val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) + val optimizedPlan = DefaultOptimizer.execute(plan) + checkNaNWithoutCodegen(optimizedPlan.expressions.head, inputRow) + } + test("e") { testLeaf(EulerNumber, math.E) } @@ -126,7 +187,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("asin") { testUnary(Asin, math.asin, (-10 to 10).map(_ * 0.1)) - testUnary(Asin, math.asin, (11 to 20).map(_ * 0.1), expectNull = true) + testUnary(Asin, math.asin, (11 to 20).map(_ * 0.1), expectNaN = true) } test("sinh") { @@ -139,7 +200,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("acos") { testUnary(Acos, math.acos, (-10 to 10).map(_ * 0.1)) - testUnary(Acos, math.acos, (11 to 20).map(_ * 0.1), expectNull = true) + testUnary(Acos, math.acos, (11 to 20).map(_ * 0.1), expectNaN = true) } test("cosh") { @@ -205,17 +266,17 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("log") { testUnary(Log, math.log, (0 to 20).map(_ * 0.1)) - testUnary(Log, math.log, (-5 to -1).map(_ * 0.1), expectNull = true) + testUnary(Log, math.log, (-5 to -1).map(_ * 0.1), expectNaN = true) } test("log10") { testUnary(Log10, math.log10, (0 to 20).map(_ * 0.1)) - testUnary(Log10, math.log10, (-5 to -1).map(_ * 0.1), expectNull = true) + testUnary(Log10, math.log10, (-5 to -1).map(_ * 0.1), expectNaN = true) } test("log1p") { testUnary(Log1p, math.log1p, (-1 to 20).map(_ * 0.1)) - testUnary(Log1p, math.log1p, (-10 to -2).map(_ * 1.0), expectNull = true) + testUnary(Log1p, math.log1p, (-10 to -2).map(_ * 1.0), expectNaN = true) } test("bin") { @@ -238,21 +299,21 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("log2") { def f: (Double) => Double = (x: Double) => math.log(x) / math.log(2) testUnary(Log2, f, (0 to 20).map(_ * 0.1)) - testUnary(Log2, f, (-5 to -1).map(_ * 1.0), expectNull = true) + testUnary(Log2, f, (-5 to -1).map(_ * 1.0), expectNaN = true) } test("sqrt") { testUnary(Sqrt, math.sqrt, (0 to 20).map(_ * 0.1)) - testUnary(Sqrt, math.sqrt, (-5 to -1).map(_ * 1.0), expectNull = true) + testUnary(Sqrt, math.sqrt, (-5 to -1).map(_ * 1.0), expectNaN = true) checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null)) - checkEvaluation(Sqrt(Literal(-1.0)), null, EmptyRow) - checkEvaluation(Sqrt(Literal(-1.5)), null, EmptyRow) + checkNaN(Sqrt(Literal(-1.0)), EmptyRow) + checkNaN(Sqrt(Literal(-1.5)), EmptyRow) } test("pow") { testBinary(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0))) - testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNull = true) + testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNaN = true) } test("shift left") { @@ -345,14 +406,12 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { null, create_row(null)) - // negative input should yield null output - checkEvaluation( + // negative input should yield NaN output + checkNaN( Logarithm(Literal(-1.0), Literal(1.0)), - null, create_row(null)) - checkEvaluation( + checkNaN( Logarithm(Literal(1.0), Literal(-1.0)), - null, create_row(null)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 8eb3fec756b4..6db9cbe6e690 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -70,11 +70,6 @@ class MathExpressionsSuite extends QueryTest { nnDoubleData.select(c('b)), (1 to 9).map(n => Row(f(n * -0.1))) :+ Row(Double.NegativeInfinity) ) - } else { - checkAnswer( - nnDoubleData.select(c('b)), - (1 to 10).map(n => Row(null)) - ) } checkAnswer( diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 4ada64bc2196..6b8f2f6217a5 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -254,7 +254,10 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // Spark SQL use Long for TimestampType, lose the precision under 1us "timestamp_1", "timestamp_2", - "timestamp_udf" + "timestamp_udf", + + // Unlike Hive, we do support log base in (0, 1.0], therefore disable this + "udf7" ) /** @@ -816,19 +819,18 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf2", "udf5", "udf6", - // "udf7", turn this on after we figure out null vs nan vs infinity "udf8", "udf9", "udf_10_trims", "udf_E", "udf_PI", "udf_abs", - // "udf_acos", turn this on after we figure out null vs nan vs infinity + "udf_acos", "udf_add", "udf_array", "udf_array_contains", "udf_ascii", - // "udf_asin", turn this on after we figure out null vs nan vs infinity + "udf_asin", "udf_atan", "udf_avg", "udf_bigint", @@ -915,7 +917,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_regexp_replace", "udf_repeat", "udf_rlike", - // "udf_round", turn this on after we figure out null vs nan vs infinity + "udf_round", "udf_round_3", "udf_rpad", "udf_rtrim", From 63dee441de741ff4ef3856c032994ba029f65c69 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Fri, 17 Jul 2015 14:57:46 +0800 Subject: [PATCH 2/3] handle log expressions similar to Hive --- .../spark/sql/catalyst/expressions/math.scala | 69 +++++++++++++++++-- .../expressions/MathFunctionsSuite.scala | 41 +++++++---- .../spark/sql/MathExpressionsSuite.scala | 2 +- 3 files changed, 90 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index c4978bc791ed..3b6ee1693e0c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -77,6 +77,30 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) } } +abstract class UnaryLogExpression(f: Double => Double, name: String) + extends UnaryMathExpression(f, name) { self: Product => + + // values less than or equal to yAsymptote eval to null in Hive, instead of NaN or -Infinity + protected val yAsymptote: Double = 0.0 + + protected override def nullSafeEval(input: Any): Any = { + val d = input.asInstanceOf[Double] + if (d <= yAsymptote) null else f(d) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, c => + s""" + if ($c <= $yAsymptote) { + ${ev.isNull} = true; + } else { + ${ev.primitive} = java.lang.Math.${funcName}($c); + } + """ + ) + } +} + /** * A binary expression specifically for math functions that take two `Double`s as input and returns * a `Double`. @@ -390,18 +414,28 @@ case class Factorial(child: Expression) extends UnaryExpression with ImplicitCas } } -case class Log(child: Expression) extends UnaryMathExpression(math.log, "LOG") +case class Log(child: Expression) extends UnaryLogExpression(math.log, "LOG") case class Log2(child: Expression) - extends UnaryMathExpression((x: Double) => math.log(x) / math.log(2), "LOG2") { + extends UnaryLogExpression((x: Double) => math.log(x) / math.log(2), "LOG2") { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, c => s"java.lang.Math.log($c) / java.lang.Math.log(2)") + nullSafeCodeGen(ctx, ev, c => + s""" + if ($c <= $yAsymptote) { + ${ev.isNull} = true; + } else { + ${ev.primitive} = java.lang.Math.log($c) / java.lang.Math.log(2); + } + """ + ) } } -case class Log10(child: Expression) extends UnaryMathExpression(math.log10, "LOG10") +case class Log10(child: Expression) extends UnaryLogExpression(math.log10, "LOG10") -case class Log1p(child: Expression) extends UnaryMathExpression(math.log1p, "LOG1P") +case class Log1p(child: Expression) extends UnaryLogExpression(math.log1p, "LOG1P") { + protected override val yAsymptote: Double = -1.0 +} case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND") { override def funcName: String = "rint" @@ -675,11 +709,32 @@ case class Logarithm(left: Expression, right: Expression) this(EulerNumber(), child) } + protected override def nullSafeEval(input1: Any, input2: Any): Any = { + val dLeft = input1.asInstanceOf[Double] + val dRight = input2.asInstanceOf[Double] + // Unlike Hive, we support Log base in (0.0, 1.0] + if (dLeft <= 0.0 || dRight <= 0.0) null else math.log(dRight) / math.log(dLeft) + } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { if (left.isInstanceOf[EulerNumber]) { - defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2)") + nullSafeCodeGen(ctx, ev, (c1, c2) => + s""" + if ($c2 <= 0.0) { + ${ev.isNull} = true; + } else { + ${ev.primitive} = java.lang.Math.log($c2); + } + """) } else { - defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2) / java.lang.Math.log($c1)") + nullSafeCodeGen(ctx, ev, (c1, c2) => + s""" + if ($c1 <= 0.0 || $c2 <= 0.0) { + ${ev.isNull} = true; + } else { + ${ev.primitive} = java.lang.Math.log($c2) / java.lang.Math.log($c1); + } + """) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 5750f3f028fd..c839ad2879ae 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -50,6 +50,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { * @param c expression * @param f The functions in scala.math or elsewhere used to generate expected results * @param domain The set of values to run the function with + * @param expectNull Whether the given values should return null or not * @param expectNaN Whether the given values should eval to NaN or not * @tparam T Generic type for primitives * @tparam U Generic type for the output of the given function `f` @@ -58,9 +59,14 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { c: Expression => Expression, f: T => U, domain: Iterable[T] = (-20 to 20).map(_ * 0.1), + expectNull: Boolean = false, expectNaN: Boolean = false, evalType: DataType = DoubleType): Unit = { - if (expectNaN) { + if (expectNull) { + domain.foreach { case value => + checkEvaluation(c(Literal(value)), null, EmptyRow) + } + } else if (expectNaN) { domain.foreach { value => checkNaN(c(Literal(value)), EmptyRow) } @@ -78,14 +84,19 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { * @param c The DataFrame function * @param f The functions in scala.math * @param domain The set of values to run the function with + * @param expectNull Whether the given values should return null or not * @param expectNaN Whether the given values should eval to NaN or not */ private def testBinary( c: (Expression, Expression) => Expression, f: (Double, Double) => Double, domain: Iterable[(Double, Double)] = (-20 to 20).map(v => (v * 0.1, v * -0.1)), - expectNaN: Boolean = false): Unit = { - if (expectNaN) { + expectNull: Boolean = false, expectNaN: Boolean = false): Unit = { + if (expectNull) { + domain.foreach { case (v1, v2) => + checkEvaluation(c(Literal(v1), Literal(v2)), null, create_row(null)) + } + } else if (expectNaN) { domain.foreach { case (v1, v2) => checkNaN(c(Literal(v1), Literal(v2)), EmptyRow) } @@ -265,18 +276,18 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("log") { - testUnary(Log, math.log, (0 to 20).map(_ * 0.1)) - testUnary(Log, math.log, (-5 to -1).map(_ * 0.1), expectNaN = true) + testUnary(Log, math.log, (1 to 20).map(_ * 0.1)) + testUnary(Log, math.log, (-5 to 0).map(_ * 0.1), expectNull = true) } test("log10") { - testUnary(Log10, math.log10, (0 to 20).map(_ * 0.1)) - testUnary(Log10, math.log10, (-5 to -1).map(_ * 0.1), expectNaN = true) + testUnary(Log10, math.log10, (1 to 20).map(_ * 0.1)) + testUnary(Log10, math.log10, (-5 to 0).map(_ * 0.1), expectNull = true) } test("log1p") { - testUnary(Log1p, math.log1p, (-1 to 20).map(_ * 0.1)) - testUnary(Log1p, math.log1p, (-10 to -2).map(_ * 1.0), expectNaN = true) + testUnary(Log1p, math.log1p, (0 to 20).map(_ * 0.1)) + testUnary(Log1p, math.log1p, (-10 to -1).map(_ * 1.0), expectNull = true) } test("bin") { @@ -298,8 +309,8 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("log2") { def f: (Double) => Double = (x: Double) => math.log(x) / math.log(2) - testUnary(Log2, f, (0 to 20).map(_ * 0.1)) - testUnary(Log2, f, (-5 to -1).map(_ * 1.0), expectNaN = true) + testUnary(Log2, f, (1 to 20).map(_ * 0.1)) + testUnary(Log2, f, (-5 to 0).map(_ * 1.0), expectNull = true) } test("sqrt") { @@ -406,12 +417,14 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { null, create_row(null)) - // negative input should yield NaN output - checkNaN( + // negative input should yield null output + checkEvaluation( Logarithm(Literal(-1.0), Literal(1.0)), + null, create_row(null)) - checkNaN( + checkEvaluation( Logarithm(Literal(1.0), Literal(-1.0)), + null, create_row(null)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 6db9cbe6e690..a51523f1a7a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -68,7 +68,7 @@ class MathExpressionsSuite extends QueryTest { if (f(-1) === math.log1p(-1)) { checkAnswer( nnDoubleData.select(c('b)), - (1 to 9).map(n => Row(f(n * -0.1))) :+ Row(Double.NegativeInfinity) + (1 to 9).map(n => Row(f(n * -0.1))) :+ Row(null) ) } From 47a529dc6bc62ae8dac568515f6b5e44df9b00f8 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Fri, 17 Jul 2015 15:13:56 +0800 Subject: [PATCH 3/3] style fix --- .../org/apache/spark/sql/catalyst/expressions/math.scala | 6 +++--- .../spark/sql/catalyst/expressions/MathFunctionsSuite.scala | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 3b6ee1693e0c..e3c185c59e75 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -96,7 +96,7 @@ abstract class UnaryLogExpression(f: Double => Double, name: String) } else { ${ev.primitive} = java.lang.Math.${funcName}($c); } - """ + """ ) } } @@ -426,7 +426,7 @@ case class Log2(child: Expression) } else { ${ev.primitive} = java.lang.Math.log($c) / java.lang.Math.log(2); } - """ + """ ) } } @@ -732,7 +732,7 @@ case class Logarithm(left: Expression, right: Expression) if ($c1 <= 0.0 || $c2 <= 0.0) { ${ev.isNull} = true; } else { - ${ev.primitive} = java.lang.Math.log($c2) / java.lang.Math.log($c1); + ${ev.primitive} = java.lang.Math.log($c2) / java.lang.Math.log($c1); } """) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index c839ad2879ae..df988f57fbfd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -63,7 +63,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { expectNaN: Boolean = false, evalType: DataType = DoubleType): Unit = { if (expectNull) { - domain.foreach { case value => + domain.foreach { value => checkEvaluation(c(Literal(value)), null, EmptyRow) } } else if (expectNaN) {