-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-8280][SPARK-8281][SQL]Handle NaN, null and Infinity in math #6835
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
a06c924
2d1dfc1
ef8c28d
ec8bee2
4be400a
307ba7e
a150de5
f19f651
0c96d86
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -69,8 +69,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) | |
| if (evalE == null) { | ||
| null | ||
| } else { | ||
| val result = f(evalE.asInstanceOf[Double]) | ||
| if (result.isNaN) null else result | ||
| f(evalE.asInstanceOf[Double]) | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -84,9 +83,37 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) | |
| ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; | ||
| if (!${ev.isNull}) { | ||
| ${ev.primitive} = java.lang.Math.${funcName}(${eval.primitive}); | ||
| if (Double.valueOf(${ev.primitive}).isNaN()) { | ||
| ${ev.isNull} = true; | ||
| } | ||
| } | ||
| """ | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * An expression specifically for unary log functions. | ||
| * @param f The math function for non codegen evaluation | ||
| * @param name The short name of the log function | ||
| * @param yAsymptote inputs less than or equal to yAsymptote evaluate to null | ||
| */ | ||
| abstract class UnaryLogarithmExpression(f: Double => Double, name: String, yAsymptote: Double) | ||
| extends UnaryMathExpression(f, name) { | ||
| self: Product => | ||
|
|
||
| override def eval(input: InternalRow): Any = { | ||
| val evalE = child.eval(input) | ||
| if (evalE == null || evalE.asInstanceOf[Double] <= yAsymptote) { | ||
| null | ||
| } else { | ||
| f(evalE.asInstanceOf[Double]) | ||
| } | ||
| } | ||
|
|
||
| override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { | ||
| val eval = child.gen(ctx) | ||
| eval.code + s""" | ||
| boolean ${ev.isNull} = ${eval.isNull} || ${eval.primitive} <= $yAsymptote; | ||
| ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; | ||
| if (!${ev.isNull}) { | ||
| ${ev.primitive} = java.lang.Math.${funcName}(${eval.primitive}); | ||
| } | ||
| """ | ||
| } | ||
|
|
@@ -116,14 +143,15 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) | |
| if (evalE2 == null) { | ||
| null | ||
| } else { | ||
| val result = f(evalE1.asInstanceOf[Double], evalE2.asInstanceOf[Double]) | ||
| if (result.isNaN) null else result | ||
| f(evalE1.asInstanceOf[Double], evalE2.asInstanceOf[Double]) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| def funcName: String = name.toLowerCase | ||
|
|
||
| override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { | ||
| defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.${name.toLowerCase}($c1, $c2)") | ||
| defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.${funcName}($c1, $c2)") | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -163,29 +191,27 @@ case class Expm1(child: Expression) extends UnaryMathExpression(math.expm1, "EXP | |
|
|
||
| case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR") | ||
|
|
||
| case class Log(child: Expression) extends UnaryMathExpression(math.log, "LOG") | ||
| case class Log(child: Expression) extends UnaryLogarithmExpression(math.log, "LOG", 0.0) | ||
|
|
||
| case class Log10(child: Expression) extends UnaryLogarithmExpression(math.log10, "LOG10", 0.0) | ||
|
|
||
| case class Log1p(child: Expression) extends UnaryLogarithmExpression(math.log1p, "LOG1P", -1.0) | ||
|
|
||
| case class Log2(child: Expression) | ||
| extends UnaryMathExpression((x: Double) => math.log(x) / math.log(2), "LOG2") { | ||
| extends UnaryLogarithmExpression((x: Double) => math.log(x) / math.log(2), "LOG2", 0.0) { | ||
|
|
||
| override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { | ||
| val eval = child.gen(ctx) | ||
| eval.code + s""" | ||
| boolean ${ev.isNull} = ${eval.isNull}; | ||
| boolean ${ev.isNull} = ${eval.isNull} || ${eval.primitive} <= 0.0; | ||
| ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; | ||
| if (!${ev.isNull}) { | ||
| ${ev.primitive} = java.lang.Math.log(${eval.primitive}) / java.lang.Math.log(2); | ||
| if (Double.valueOf(${ev.primitive}).isNaN()) { | ||
| ${ev.isNull} = true; | ||
| } | ||
| } | ||
| """ | ||
| } | ||
| } | ||
|
|
||
| case class Log10(child: Expression) extends UnaryMathExpression(math.log10, "LOG10") | ||
|
|
||
| case class Log1p(child: Expression) extends UnaryMathExpression(math.log1p, "LOG1P") | ||
|
|
||
| case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND") { | ||
| override def funcName: String = "rint" | ||
| } | ||
|
|
@@ -259,19 +285,14 @@ case class Atan2(left: Expression, right: Expression) | |
| null | ||
| } else { | ||
| // With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0 | ||
| val result = math.atan2(evalE1.asInstanceOf[Double] + 0.0, | ||
| math.atan2(evalE1.asInstanceOf[Double] + 0.0, | ||
| evalE2.asInstanceOf[Double] + 0.0) | ||
| if (result.isNaN) null else result | ||
| } | ||
| } | ||
| } | ||
|
|
||
| override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { | ||
| defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") + s""" | ||
| if (Double.valueOf(${ev.primitive}).isNaN()) { | ||
| ${ev.isNull} = true; | ||
| } | ||
| """ | ||
| defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -281,34 +302,81 @@ case class Hypot(left: Expression, right: Expression) | |
| case class Pow(left: Expression, right: Expression) | ||
| extends BinaryMathExpression(math.pow, "POWER") { | ||
| override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { | ||
| defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") + s""" | ||
| if (Double.valueOf(${ev.primitive}).isNaN()) { | ||
| ${ev.isNull} = true; | ||
| } | ||
| """ | ||
| defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") | ||
| } | ||
| } | ||
|
|
||
| case class Logarithm(left: Expression, right: Expression) | ||
| extends BinaryMathExpression((c1, c2) => math.log(c2) / math.log(c1), "LOG") { | ||
|
|
||
| /** | ||
| * Natural log, i.e. using e as the base. | ||
| */ | ||
| def this(child: Expression) = { | ||
| this(EulerNumber(), child) | ||
| override def eval(input: InternalRow): Any = { | ||
| val evalE1 = left.eval(input) | ||
| if (evalE1 == null || evalE1.asInstanceOf[Double] <= 0.0) { | ||
| null | ||
| } else { | ||
| val evalE2 = right.eval(input) | ||
| if (evalE2 == null || evalE2.asInstanceOf[Double] <= 0.0) { | ||
| null | ||
| } else { | ||
| math.log(evalE2.asInstanceOf[Double]) / math.log(evalE1.asInstanceOf[Double]) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { | ||
| val logCode = if (left.isInstanceOf[EulerNumber]) { | ||
| defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2)") | ||
| } else { | ||
| defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2) / java.lang.Math.log($c1)") | ||
| } | ||
| logCode + s""" | ||
| if (Double.valueOf(${ev.primitive}).isNaN()) { | ||
| val eval1 = left.gen(ctx) | ||
| val eval2 = right.gen(ctx) | ||
| s""" | ||
| ${eval1.code} | ||
| boolean ${ev.isNull} = ${eval1.isNull} || ${eval1.primitive} <= 0.0; | ||
| ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; | ||
| if (${ev.isNull}) { | ||
| ${ev.isNull} = true; | ||
| } else { | ||
| ${eval2.code} | ||
| if (${eval2.isNull} || ${eval2.primitive} <= 0.0) { | ||
| ${ev.isNull} = true; | ||
| } else { | ||
| ${ev.primitive} = java.lang.Math.${funcName}(${eval2.primitive}) / | ||
| java.lang.Math.${funcName}(${eval1.primitive}); | ||
| } | ||
| } | ||
| """ | ||
| } | ||
|
|
||
| // TODO: Hive's UDFLog doesn't support base in range (0.0, 1.0] | ||
| // If we want to behave just like Hive, use the code below and turn `udf_7` on | ||
|
|
||
| // override def eval(input: InternalRow): Any = { | ||
| // val evalE1 = left.eval(input) | ||
| // val evalE2 = right.eval(input) | ||
| // if (evalE1 == null || evalE2 == null) { | ||
| // null | ||
| // } else { | ||
| // if (evalE1.asInstanceOf[Double] <= 1.0 || evalE2.asInstanceOf[Double] <= 0.0) { | ||
| // null | ||
| // } else { | ||
| // math.log(evalE2.asInstanceOf[Double]) / math.log(evalE1.asInstanceOf[Double]) | ||
| // } | ||
| // } | ||
| // } | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment The reason will be displayed to describe this comment to others. Learn more. Hive's UDFLog doesn't support a base in the range (0.0, 1.0]; it just evaluates to null in that case, exactly as the commented-out code above does. I'm not sure if we also want to behave like this. Any ideas?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment The reason will be displayed to describe this comment to others. Learn more. We should support these. Just remove the commented-out code, and add an inline comment for log noting that we support bases in (0.0, 1.0], unlike Hive. |
||
| // | ||
| // override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { | ||
| // val eval1 = left.gen(ctx) | ||
| // val eval2 = right.gen(ctx) | ||
| // s""" | ||
| // ${eval1.code} | ||
| // ${eval2.code} | ||
| // boolean ${ev.isNull} = ${eval1.isNull} || ${eval2.isNull}; | ||
| // ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; | ||
| // if (!${ev.isNull}) { | ||
| // if (${eval1.primitive} <= 1.0 || ${eval2.primitive} <= 0.0) { | ||
| // ${ev.isNull} = true; | ||
| // } else { | ||
| // ${ev.primitive} = java.lang.Math.${funcName}(${eval2.primitive}) / | ||
| // java.lang.Math.${funcName}(${eval1.primitive}); | ||
| // } | ||
| // } | ||
| // """ | ||
| // } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think you need to wrap here.