Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
protected val SUBSTR = Keyword("SUBSTR")
protected val SUBSTRING = Keyword("SUBSTRING")
protected val SQRT = Keyword("SQRT")
protected val POW = Keyword("POW")
protected val POWER = Keyword("POWER")

// Use reflection to find the reserved words defined in this class.
protected val reservedWords =
Expand Down Expand Up @@ -326,6 +328,9 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
case s ~ "," ~ p ~ "," ~ l => Substring(s,p,l)
} |
SQRT ~> "(" ~> expression <~ ")" ^^ { case exp => Sqrt(exp) } |
(POW | POWER) ~> "(" ~> expression ~ "," ~ expression <~ ")" ^^ {
case s ~ "," ~ p => Power(s,p)
} |
ident ~ "(" ~ repsep(expression, ",") <~ ")" ^^ {
case udfName ~ _ ~ exprs => UnresolvedFunction(udfName, exprs)
}
Expand Down
Empty file.
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.types._
import scala.math.pow

case class UnaryMinus(child: Expression) extends UnaryExpression {
type EvaluatedType = Any
Expand Down Expand Up @@ -129,3 +130,40 @@ case class MaxOf(left: Expression, right: Expression) extends Expression {

override def toString = s"MaxOf($left, $right)"
}

/**
* A function that get the power value of two parameters.
* First one is taken as base while second one taken as exponent
*/
case class Power(base: Expression, exponent: Expression) extends Expression {
type EvaluatedType = Any

def dataType: DataType = {
if (!resolved) {
throw new UnresolvedException(this, s"Cannot resolve since $children are not resolved")
}
DoubleType
}
override def foldable = base.foldable && exponent.foldable
def nullable: Boolean = base.nullable || exponent.nullable
override def toString = s"Power($base, $exponent)"

override def children = base :: exponent :: Nil

override def eval(input: Row): Any = {
def convertToDouble(num: EvaluatedType): Double = {
num match {
case d:Double => d
case i:Integer => i.doubleValue()
case f:Float => f.toDouble
}
}

val base_v = base.eval(input)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Let's check if base_v is null first and then compute exponent_v, probably performance gains in some cases.

val exponent_v = exponent.eval(input)

if ((base_v == null) || (exponent_v == null)) null
else pow(convertToDouble(base_v), convertToDouble(exponent_v))
}

}
37 changes: 35 additions & 2 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,39 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
}


test("SPARK-3176 Added Parser of SQL POWER()") {
checkAnswer(
sql("SELECT POWER(0, 512.0)"),
0.0)
checkAnswer(
sql("SELECT POW(1.0, 256.0)"),
1.0)
checkAnswer(
sql("SELECT POWER(1, -128)"),
1.0)
checkAnswer(
sql("SELECT POW(-1.0, -63)"),
-1.0)
checkAnswer(
sql("SELECT POWER(-1, 32.0)"),
1.0)
checkAnswer(
sql("SELECT POW(2, 8)"),
256.0)
checkAnswer(
sql("SELECT POWER(0.5, 2)"),
0.25)
checkAnswer(
sql("SELECT POW(2, -2)"),
0.25)
checkAnswer(
sql("SELECT POWER(8, 1)"),
8.0)
checkAnswer(
sql("SELECT POW(16, 0.5)"),
4.0)
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add null as input also?


test("SPARK-2041 column name equals tablename") {
checkAnswer(
sql("SELECT tableName FROM tableName"),
Expand All @@ -53,14 +86,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
(1 to 100).map(x => Row(math.sqrt(x.toDouble))).toSeq
)
}

test("SQRT with automatic string casts") {
checkAnswer(
sql("SELECT SQRT(CAST(key AS STRING)) FROM testData"),
(1 to 100).map(x => Row(math.sqrt(x.toDouble))).toSeq
)
}

test("SPARK-2407 Added Parser of SQL SUBSTR()") {
checkAnswer(
sql("SELECT substr(tableName, 1, 2) FROM tableName"),
Expand Down