Skip to content

Commit 3909f48

Browse files
committed
math function: log2
1 parent 4e42842 commit 3909f48

File tree

5 files changed

+56
-2
lines changed

5 files changed

+56
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ object FunctionRegistry {
115115
expression[Log10]("log10"),
116116
expression[Log1p]("log1p"),
117117
expression[Pi]("pi"),
118+
expression[Log2]("log2"),
118119
expression[Pow]("pow"),
119120
expression[Rint]("rint"),
120121
expression[Signum]("signum"),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,23 @@ case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLO
161161

162162
case class Log(child: Expression) extends UnaryMathExpression(math.log, "LOG")
163163

164+
case class Log2(child: Expression)
165+
extends UnaryMathExpression((x: Double) => math.log(x) / math.log(2), "LOG2") {
166+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
167+
val eval = child.gen(ctx)
168+
eval.code + s"""
169+
boolean ${ev.isNull} = ${eval.isNull};
170+
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
171+
if (!${ev.isNull}) {
172+
${ev.primitive} = java.lang.Math.log(${eval.primitive}) / java.lang.Math.log(2);
173+
if (Double.valueOf(${ev.primitive}).isNaN()) {
174+
${ev.isNull} = true;
175+
}
176+
}
177+
"""
178+
}
179+
}
180+
164181
case class Log10(child: Expression) extends UnaryMathExpression(math.log10, "LOG10")
165182

166183
case class Log1p(child: Expression) extends UnaryMathExpression(math.log1p, "LOG1P")

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,12 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
185185
testUnary(Log1p, math.log1p, (-10 to -2).map(_ * 1.0), expectNull = true)
186186
}
187187

188+
test("log2") {
189+
def f: (Double) => Double = (x: Double) => math.log(x) / math.log(2)
190+
testUnary(Log2, f, (0 to 20).map(_ * 0.1))
191+
testUnary(Log2, f, (-5 to -1).map(_ * 1.0), expectNull = true)
192+
}
193+
188194
test("pow") {
189195
testBinary(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0)))
190196
testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNull = true)

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1083,15 +1083,15 @@ object functions {
10831083
def log(columnName: String): Column = log(Column(columnName))
10841084

10851085
/**
1086-
* Computes the logarithm of the given value in Base 10.
1086+
* Computes the logarithm of the given value in base 10.
10871087
*
10881088
* @group math_funcs
10891089
* @since 1.4.0
10901090
*/
10911091
def log10(e: Column): Column = Log10(e.expr)
10921092

10931093
/**
1094-
* Computes the logarithm of the given value in Base 10.
1094+
* Computes the logarithm of the given value in base 10.
10951095
*
10961096
* @group math_funcs
10971097
* @since 1.4.0
@@ -1123,6 +1123,22 @@ object functions {
11231123
*/
11241124
def pi(): Column = Pi()
11251125

1126+
/**
1127+
* Computes the logarithm of the given column in base 2.
1128+
*
1129+
* @group math_funcs
1130+
* @since 1.5.0
1131+
*/
1132+
def log2(expr: Column): Column = Log2(expr.expr)
1133+
1134+
/**
1135+
* Computes the logarithm of the given value in base 2.
1136+
*
1137+
* @group math_funcs
1138+
* @since 1.5.0
1139+
*/
1140+
def log2(columnName: String): Column = log2(Column(columnName))
1141+
11261142
/**
11271143
* Returns the value of the first argument raised to the power of the second argument.
11281144
*

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,4 +109,18 @@ class DataFrameFunctionsSuite extends QueryTest {
109109
testData2.select(bitwiseNOT($"a")),
110110
testData2.collect().toSeq.map(r => Row(~r.getInt(0))))
111111
}
112+
113+
test("log2 functions test") {
114+
val df = Seq((1, 2)).toDF("a", "b")
115+
checkAnswer(
116+
df.select(log2("b") + log2("a")),
117+
Row(1))
118+
119+
checkAnswer(
120+
ctx.sql("SELECT LOG2(8)"),
121+
Row(3))
122+
checkAnswer(
123+
ctx.sql("SELECT LOG2(null)"),
124+
Row(null))
125+
}
112126
}

0 commit comments

Comments
 (0)