Skip to content

Commit fee3438

Browse files
viirya authored and rxin committed
[SPARK-8218][SQL] Add binary log math function
JIRA: https://issues.apache.org/jira/browse/SPARK-8218

Because a unary `log` function is already defined, the binary log function is called `logarithm` for now.

Author: Liang-Chi Hsieh <[email protected]>

Closes apache#6725 from viirya/expr_binary_log and squashes the following commits:

bf96bd9 [Liang-Chi Hsieh] Compare log result in string.
102070d [Liang-Chi Hsieh] Round log result to better comparing in python test.
fd01863 [Liang-Chi Hsieh] For comments.
beed631 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_binary_log
6089d11 [Liang-Chi Hsieh] Remove unnecessary override.
8cf37b7 [Liang-Chi Hsieh] For comments.
bc89597 [Liang-Chi Hsieh] For comments.
db7dc38 [Liang-Chi Hsieh] Use ctor instead of companion object.
0634ef7 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_binary_log
1750034 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_binary_log
3d75bfc [Liang-Chi Hsieh] Fix scala style.
5b39c02 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_binary_log
23c54a3 [Liang-Chi Hsieh] Fix scala style.
ebc9929 [Liang-Chi Hsieh] Let Logarithm accept one parameter too.
605574d [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_binary_log
21c3bfd [Liang-Chi Hsieh] Fix scala style.
c6c187f [Liang-Chi Hsieh] For comments.
c795342 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_binary_log
f373bac [Liang-Chi Hsieh] Add binary log expression.
1 parent 78a430e commit fee3438

File tree

6 files changed

+85
-1
lines changed

6 files changed

+85
-1
lines changed

python/pyspark/sql/functions.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
"""
1919
A collection of builtin functions
2020
"""
21+
import math
2122
import sys
2223

2324
if sys.version < "3":
@@ -143,7 +144,7 @@ def _():
143144
'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to' +
144145
'polar coordinates (r, theta).',
145146
'hypot': 'Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow.',
146-
'pow': 'Returns the value of the first argument raised to the power of the second argument.'
147+
'pow': 'Returns the value of the first argument raised to the power of the second argument.',
147148
}
148149

149150
_window_functions = {
@@ -403,6 +404,21 @@ def when(condition, value):
403404
return Column(jc)
404405

405406

407+
@since(1.4)
408+
def log(col, base=math.e):
409+
"""Returns the logarithm of the first argument (`col`) in the given `base`, which defaults to e (natural logarithm).
410+
411+
>>> df.select(log(df.age, 10.0).alias('ten')).map(lambda l: str(l.ten)[:7]).collect()
412+
['0.30102', '0.69897']
413+
414+
>>> df.select(log(df.age).alias('e')).map(lambda l: str(l.e)[:7]).collect()
415+
['0.69314', '1.60943']
416+
"""
417+
sc = SparkContext._active_spark_context
418+
jc = sc._jvm.functions.log(base, _to_java_column(col))
419+
return Column(jc)
420+
421+
406422
@since(1.4)
407423
def lag(col, count=1, default=None):
408424
"""

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ object FunctionRegistry {
112112
expression[Expm1]("expm1"),
113113
expression[Floor]("floor"),
114114
expression[Hypot]("hypot"),
115+
expression[Logarithm]("log"),
115116
expression[Log]("ln"),
116117
expression[Log10]("log10"),
117118
expression[Log1p]("log1p"),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,3 +255,23 @@ case class Pow(left: Expression, right: Expression)
255255
"""
256256
}
257257
}
258+
259+
case class Logarithm(left: Expression, right: Expression)
260+
extends BinaryMathExpression((c1, c2) => math.log(c2) / math.log(c1), "LOG") {
261+
def this(child: Expression) = {
262+
this(EulerNumber(), child)
263+
}
264+
265+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
266+
val logCode = if (left.isInstanceOf[EulerNumber]) {
267+
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2)")
268+
} else {
269+
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2) / java.lang.Math.log($c1)")
270+
}
271+
logCode + s"""
272+
if (Double.valueOf(${ev.primitive}).isNaN()) {
273+
${ev.isNull} = true;
274+
}
275+
"""
276+
}
277+
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,4 +204,22 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
204204
testBinary(Atan2, math.atan2)
205205
}
206206

207+
test("binary log") {
208+
val f = (c1: Double, c2: Double) => math.log(c2) / math.log(c1)
209+
val domain = (1 to 20).map(v => (v * 0.1, v * 0.2))
210+
211+
domain.foreach { case (v1, v2) =>
212+
checkEvaluation(Logarithm(Literal(v1), Literal(v2)), f(v1 + 0.0, v2 + 0.0), EmptyRow)
213+
checkEvaluation(Logarithm(Literal(v2), Literal(v1)), f(v2 + 0.0, v1 + 0.0), EmptyRow)
214+
checkEvaluation(new Logarithm(Literal(v1)), f(math.E, v1 + 0.0), EmptyRow)
215+
}
216+
checkEvaluation(
217+
Logarithm(Literal.create(null, DoubleType), Literal(1.0)),
218+
null,
219+
create_row(null))
220+
checkEvaluation(
221+
Logarithm(Literal(1.0), Literal.create(null, DoubleType)),
222+
null,
223+
create_row(null))
224+
}
207225
}

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1083,6 +1083,22 @@ object functions {
10831083
*/
10841084
def log(columnName: String): Column = log(Column(columnName))
10851085

1086+
/**
1087+
* Returns the first argument-base logarithm of the second argument.
1088+
*
1089+
* @group math_funcs
1090+
* @since 1.4.0
1091+
*/
1092+
def log(base: Double, a: Column): Column = Logarithm(lit(base).expr, a.expr)
1093+
1094+
/**
1095+
* Returns the first argument-base logarithm of the second argument.
1096+
*
1097+
* @group math_funcs
1098+
* @since 1.4.0
1099+
*/
1100+
def log(base: Double, columnName: String): Column = log(base, Column(columnName))
1101+
10861102
/**
10871103
* Computes the logarithm of the given value in base 10.
10881104
*

sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,19 @@ class MathExpressionsSuite extends QueryTest {
236236
testOneToOneNonNegativeMathFunction(log1p, math.log1p)
237237
}
238238

239+
test("binary log") {
240+
val df = Seq[(Integer, Integer)]((123, null)).toDF("a", "b")
241+
checkAnswer(
242+
df.select(org.apache.spark.sql.functions.log("a"),
243+
org.apache.spark.sql.functions.log(2.0, "a"),
244+
org.apache.spark.sql.functions.log("b")),
245+
Row(math.log(123), math.log(123) / math.log(2), null))
246+
247+
checkAnswer(
248+
df.selectExpr("log(a)", "log(2.0, a)", "log(b)"),
249+
Row(math.log(123), math.log(123) / math.log(2), null))
250+
}
251+
239252
test("abs") {
240253
val input =
241254
Seq[(java.lang.Double, java.lang.Double)]((null, null), (0.0, 0.0), (1.5, 1.5), (-2.5, 2.5))

0 commit comments

Comments
 (0)