Skip to content

Commit 4be400a

Browse files
committed
Refactor Unary logs as well as fix the binary one
1 parent ec8bee2 commit 4be400a

File tree

4 files changed

+37
-44
lines changed

4 files changed

+37
-44
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,11 @@ object FunctionRegistry {
114114
expression[Hypot]("hypot"),
115115
expression[Logarithm]("log"),
116116
expression[Log]("ln"),
117-
expression[Log10]("log10"),
118117
expression[Log1p]("log1p"),
118+
expression[Log10]("log10"),
119+
expression[Log2]("log2"),
119120
expression[UnaryMinus]("negative"),
120121
expression[Pi]("pi"),
121-
expression[Log2]("log2"),
122122
expression[Pow]("pow"),
123123
expression[Pow]("power"),
124124
expression[UnaryPositive]("positive"),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala

Lines changed: 34 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -85,13 +85,19 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)
8585
}
8686
}
8787

88-
abstract class UnaryMathLogExpression(f: Double => Double, name: String)
88+
/**
89+
* A expression specifically for unary log functions.
90+
* @param f The math function for non codegen evaluation
91+
* @param name The short name of the log function
92+
* @param yAsymptote values less than or equal to yAsymptote are considered eval to null
93+
*/
94+
abstract class UnaryLogarithmExpression(f: Double => Double, name: String, yAsymptote: Double)
8995
extends UnaryMathExpression(f, name) {
9096
self: Product =>
9197

9298
override def eval(input: InternalRow): Any = {
9399
val evalE = child.eval(input)
94-
if (evalE == null || evalE.asInstanceOf[Double] <= 0.0) {
100+
if (evalE == null || evalE.asInstanceOf[Double] <= yAsymptote) {
95101
null
96102
} else {
97103
f(evalE.asInstanceOf[Double])
@@ -101,7 +107,7 @@ abstract class UnaryMathLogExpression(f: Double => Double, name: String)
101107
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
102108
val eval = child.gen(ctx)
103109
eval.code + s"""
104-
boolean ${ev.isNull} = ${eval.isNull} || Double.valueOf(${eval.primitive}) <= 0.0;
110+
boolean ${ev.isNull} = ${eval.isNull} || Double.valueOf(${eval.primitive}) <= $yAsymptote;
105111
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
106112
if (!${ev.isNull}) {
107113
${ev.primitive} = java.lang.Math.${funcName}(${eval.primitive});
@@ -139,8 +145,10 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
139145
}
140146
}
141147

148+
def funcName = name.toLowerCase
149+
142150
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
143-
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.${name.toLowerCase}($c1, $c2)")
151+
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.${funcName}($c1, $c2)")
144152
}
145153
}
146154

@@ -180,41 +188,22 @@ case class Expm1(child: Expression) extends UnaryMathExpression(math.expm1, "EXP
180188

181189
case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR")
182190

183-
case class Log(child: Expression) extends UnaryMathLogExpression(math.log, "LOG")
191+
case class Log(child: Expression) extends UnaryLogarithmExpression(math.log, "LOG", 0.0)
184192

185-
case class Log2(child: Expression)
186-
extends UnaryMathLogExpression((x: Double) => math.log(x) / math.log(2), "LOG2") {
187-
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
188-
val eval = child.gen(ctx)
189-
eval.code + s"""
190-
boolean ${ev.isNull} = ${eval.isNull} || Double.valueOf(${eval.primitive}) <= 0.0;
191-
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
192-
if (!${ev.isNull}) {
193-
${ev.primitive} = java.lang.Math.log(${eval.primitive}) / java.lang.Math.log(2);
194-
}
195-
"""
196-
}
197-
}
193+
case class Log10(child: Expression) extends UnaryLogarithmExpression(math.log10, "LOG10", 0.0)
198194

199-
case class Log10(child: Expression) extends UnaryMathLogExpression(math.log10, "LOG10")
195+
case class Log1p(child: Expression) extends UnaryLogarithmExpression(math.log1p, "LOG1P", -1.0)
200196

201-
case class Log1p(child: Expression) extends UnaryMathLogExpression(math.log1p, "LOG1P") {
202-
override def eval(input: InternalRow): Any = {
203-
val evalE = child.eval(input)
204-
if (evalE == null || evalE.asInstanceOf[Double] + 1 <= 0.0) {
205-
null
206-
} else {
207-
math.log1p(evalE.asInstanceOf[Double])
208-
}
209-
}
197+
case class Log2(child: Expression)
198+
extends UnaryLogarithmExpression((x: Double) => math.log(x) / math.log(2), "LOG2", 0.0) {
210199

211200
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
212201
val eval = child.gen(ctx)
213202
eval.code + s"""
214-
boolean ${ev.isNull} = ${eval.isNull} || Double.valueOf(${eval.primitive}) + 1 <= 0.0;
203+
boolean ${ev.isNull} = ${eval.isNull} || Double.valueOf(${eval.primitive}) <= 0.0;
215204
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
216205
if (!${ev.isNull}) {
217-
${ev.primitive} = java.lang.Math.${funcName}(${eval.primitive});
206+
${ev.primitive} = java.lang.Math.log(${eval.primitive}) / java.lang.Math.log(2);
218207
}
219208
"""
220209
}
@@ -286,19 +275,24 @@ case class Pow(left: Expression, right: Expression)
286275

287276
case class Logarithm(left: Expression, right: Expression)
288277
extends BinaryMathExpression((c1, c2) => math.log(c2) / math.log(c1), "LOG") {
289-
def this(child: Expression) = {
290-
this(EulerNumber(), child)
291-
}
292278

293279
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
294-
val logCode = if (left.isInstanceOf[EulerNumber]) {
295-
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2)")
296-
} else {
297-
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2) / java.lang.Math.log($c1)")
298-
}
299-
logCode + s"""
300-
if (Double.valueOf(${ev.primitive}).isNaN()) {
280+
val eval1 = left.gen(ctx)
281+
val eval2 = right.gen(ctx)
282+
s"""
283+
${eval1.code}
284+
boolean ${ev.isNull} = ${eval1.isNull} || ${eval1.primitive} <= 0.0;
285+
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
286+
if (${ev.isNull}) {
301287
${ev.isNull} = true;
288+
} else {
289+
${eval2.code}
290+
if (${eval2.isNull} || ${eval2.primitive} <= 0.0) {
291+
${ev.isNull} = true;
292+
} else {
293+
${ev.primitive} = java.lang.Math.${funcName}(${eval2.primitive}) /
294+
java.lang.Math.${funcName}(${eval1.primitive});
295+
}
302296
}
303297
"""
304298
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,6 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
280280
domain.foreach { case (v1, v2) =>
281281
checkEvaluation(Logarithm(Literal(v1), Literal(v2)), f(v1 + 0.0, v2 + 0.0), EmptyRow)
282282
checkEvaluation(Logarithm(Literal(v2), Literal(v1)), f(v2 + 0.0, v1 + 0.0), EmptyRow)
283-
checkEvaluation(new Logarithm(Literal(v1)), f(math.E, v1 + 0.0), EmptyRow)
284283
}
285284
checkEvaluation(
286285
Logarithm(Literal.create(null, DoubleType), Literal(1.0)),

sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ class MathExpressionsSuite extends QueryTest {
245245
Row(math.log(123), math.log(123) / math.log(2), null))
246246

247247
checkAnswer(
248-
df.selectExpr("log(a)", "log(2.0, a)", "log(b)"),
248+
df.selectExpr("ln(a)", "log(2.0, a)", "ln(b)"),
249249
Row(math.log(123), math.log(123) / math.log(2), null))
250250
}
251251

0 commit comments

Comments
 (0)