Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ object FunctionRegistry {
expression[Log10]("log10"),
expression[Log1p]("log1p"),
expression[Pi]("pi"),
expression[Log2]("log2"),
expression[Pow]("pow"),
expression[Rint]("rint"),
expression[Signum]("signum"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,23 @@ case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLO

case class Log(child: Expression) extends UnaryMathExpression(math.log, "LOG")

case class Log2(child: Expression)
extends UnaryMathExpression((x: Double) => math.log(x) / math.log(2), "LOG2") {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val eval = child.gen(ctx)
eval.code + s"""
boolean ${ev.isNull} = ${eval.isNull};
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.primitive} = java.lang.Math.log(${eval.primitive}) / java.lang.Math.log(2);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you define a new defineGenCode for UnaryMathExpression to check NaN? Then other math function can benefit from it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's hold off on that because we might want to have both NaN and null (Hive does that).

if (Double.valueOf(${ev.primitive}).isNaN()) {
${ev.isNull} = true;
}
}
"""
}
}

case class Log10(child: Expression) extends UnaryMathExpression(math.log10, "LOG10")

case class Log1p(child: Expression) extends UnaryMathExpression(math.log1p, "LOG1P")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,12 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
testUnary(Log1p, math.log1p, (-10 to -2).map(_ * 1.0), expectNull = true)
}

test("log2") {
def f: (Double) => Double = (x: Double) => math.log(x) / math.log(2)
testUnary(Log2, f, (0 to 20).map(_ * 0.1))
testUnary(Log2, f, (-5 to -1).map(_ * 1.0), expectNull = true)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you also test if null value passed in?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Anyway, it probably cause test failed due to some bug in codegen. I am trying to solve that in #6724 .

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added null test in SQLQuerySuite

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would be great to add a test here for the expression as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but this suite never use unfoldable expression.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i realized testUnary already test for null.

}

test("pow") {
testBinary(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0)))
testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNull = true)
Expand Down
20 changes: 18 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1083,15 +1083,15 @@ object functions {
def log(columnName: String): Column = log(Column(columnName))

/**
* Computes the logarithm of the given value in Base 10.
* Computes the logarithm of the given value in base 10.
*
* @group math_funcs
* @since 1.4.0
*/
def log10(e: Column): Column = Log10(e.expr)

/**
* Computes the logarithm of the given value in Base 10.
* Computes the logarithm of the given value in base 10.
*
* @group math_funcs
* @since 1.4.0
Expand Down Expand Up @@ -1123,6 +1123,22 @@ object functions {
*/
def pi(): Column = Pi()

/**
* Computes the logarithm of the given column in base 2.
*
* @group math_funcs
* @since 1.5.0
*/
def log2(expr: Column): Column = Log2(expr.expr)

/**
* Computes the logarithm of the given value in base 2.
*
* @group math_funcs
* @since 1.5.0
*/
def log2(columnName: String): Column = log2(Column(columnName))

/**
* Returns the value of the first argument raised to the power of the second argument.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,18 @@ class DataFrameFunctionsSuite extends QueryTest {
testData2.select(bitwiseNOT($"a")),
testData2.collect().toSeq.map(r => Row(~r.getInt(0))))
}

test("log2 functions test") {
val df = Seq((1, 2)).toDF("a", "b")
checkAnswer(
df.select(log2("b") + log2("a")),
Row(1))

checkAnswer(
ctx.sql("SELECT LOG2(8)"),
Row(3))
checkAnswer(
ctx.sql("SELECT LOG2(null)"),
Row(null))
}
}