Skip to content

Commit ebc9929

Browse files
committed
Let Logarithm accept one parameter too.
1 parent 605574d commit ebc9929

File tree

2 files changed

+50
-4
lines changed

2 files changed

+50
-4
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -261,11 +261,47 @@ object Logarithm {
261261

262262
case class Logarithm(left: Expression, right: Expression)
263263
extends BinaryMathExpression((c1, c2) => math.log(c2) / math.log(c1), "LOG") {
264+
override def eval(input: Row): Any = {
265+
val evalE2 = right.eval(input)
266+
if (evalE2 == null) {
267+
null
268+
} else {
269+
val evalE1 = left.eval(input)
270+
var result: Double = 0.0
271+
if (evalE1 == null) {
272+
result = math.log(evalE2.asInstanceOf[Double])
273+
} else {
274+
result = math.log(evalE2.asInstanceOf[Double]) / math.log(evalE1.asInstanceOf[Double])
275+
}
276+
if (result.isNaN) null else result
277+
}
278+
}
279+
264280
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
265-
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2) / java.lang.Math.log($c1)") + s"""
281+
if (left.dataType != right.dataType) {
282+
// log.warn(s"${left.dataType} != ${right.dataType}")
283+
}
284+
285+
val eval1 = left.gen(ctx)
286+
val eval2 = right.gen(ctx)
287+
val resultCode =
288+
s"java.lang.Math.log(${eval2.primitive}) / java.lang.Math.log(${eval1.primitive})"
289+
290+
s"""
291+
${eval2.code}
292+
boolean ${ev.isNull} = ${eval2.isNull};
293+
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
294+
if (!${ev.isNull}) {
295+
${eval1.code}
296+
if (!${eval1.isNull}) {
297+
${ev.primitive} = ${resultCode};
298+
} else {
299+
${ev.primitive} = java.lang.Math.log(${eval2.primitive});
300+
}
301+
}
266302
if (Double.valueOf(${ev.primitive}).isNaN()) {
267303
${ev.isNull} = true;
268304
}
269-
"""
305+
"""
270306
}
271307
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,17 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
205205
}
206206

207207
test("binary log") {
208-
testBinary((e1, e2) => Logarithm(e1, e2), (c1, c2) => math.log(c2) / math.log(c1),
209-
(1 to 20).map(v => (v * 0.1, v * 0.2)))
208+
val f = (c1: Double, c2: Double) => math.log(c2) / math.log(c1)
209+
val domain = (1 to 20).map(v => (v * 0.1, v * 0.2))
210+
211+
domain.foreach { case (v1, v2) =>
212+
checkEvaluation(Logarithm(Literal(v1), Literal(v2)), f(v1 + 0.0, v2 + 0.0), EmptyRow)
213+
checkEvaluation(Logarithm(Literal(v2), Literal(v1)), f(v2 + 0.0, v1 + 0.0), EmptyRow)
214+
}
215+
// When base is null, Logarithm is as same as Log
216+
checkEvaluation(Logarithm(Literal.create(null, DoubleType), Literal(1.0)),
217+
math.log(1.0), create_row(null))
218+
checkEvaluation(Logarithm(Literal(1.0), Literal.create(null, DoubleType)),
219+
null, create_row(null))
210220
}
211221
}

0 commit comments

Comments
 (0)