diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala old mode 100644 new mode 100755 index a88bd859fc85e..c478b3810d6a0 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -124,6 +124,8 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected val SUBSTR = Keyword("SUBSTR") protected val SUBSTRING = Keyword("SUBSTRING") protected val SQRT = Keyword("SQRT") + protected val POW = Keyword("POW") + protected val POWER = Keyword("POWER") // Use reflection to find the reserved words defined in this class. protected val reservedWords = @@ -326,6 +328,9 @@ class SqlParser extends StandardTokenParsers with PackratParsers { case s ~ "," ~ p ~ "," ~ l => Substring(s,p,l) } | SQRT ~> "(" ~> expression <~ ")" ^^ { case exp => Sqrt(exp) } | + (POW | POWER) ~> "(" ~> expression ~ "," ~ expression <~ ")" ^^ { + case s ~ "," ~ p => Power(s,p) + } | ident ~ "(" ~ repsep(expression, ",") <~ ")" ^^ { case udfName ~ _ ~ exprs => UnresolvedFunction(udfName, exprs) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala old mode 100644 new mode 100755 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala old mode 100644 new mode 100755 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index f988fb010b107..6c6c1b4dbb858 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.types._ +import scala.math.pow case class UnaryMinus(child: Expression) extends UnaryExpression { type EvaluatedType = Any @@ -129,3 +130,40 @@ case class MaxOf(left: Expression, right: Expression) extends Expression { override def toString = s"MaxOf($left, $right)" } + +/** + * A function that get the power value of two parameters. + * First one is taken as base while second one taken as exponent + */ +case class Power(base: Expression, exponent: Expression) extends Expression { + type EvaluatedType = Any + + def dataType: DataType = { + if (!resolved) { + throw new UnresolvedException(this, s"Cannot resolve since $children are not resolved") + } + DoubleType + } + override def foldable = base.foldable && exponent.foldable + def nullable: Boolean = base.nullable || exponent.nullable + override def toString = s"Power($base, $exponent)" + + override def children = base :: exponent :: Nil + + override def eval(input: Row): Any = { + def convertToDouble(num: EvaluatedType): Double = { + num match { + case d:Double => d + case i:Integer => i.doubleValue() + case f:Float => f.toDouble + } + } + + val base_v = base.eval(input) + val exponent_v = exponent.eval(input) + + if ((base_v == null) || (exponent_v == null)) null + else pow(convertToDouble(base_v), convertToDouble(exponent_v)) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 1ac205937714c..7c49020df905f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -41,6 +41,39 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } + test("SPARK-3176 Added Parser of SQL POWER()") { + checkAnswer( + sql("SELECT POWER(0, 512.0)"), + 0.0) + checkAnswer( + sql("SELECT POW(1.0, 256.0)"), + 1.0) + checkAnswer( + sql("SELECT POWER(1, -128)"), + 1.0) + checkAnswer( + sql("SELECT POW(-1.0, -63)"), + -1.0) + checkAnswer( + sql("SELECT POWER(-1, 32.0)"), + 1.0) + checkAnswer( + sql("SELECT POW(2, 8)"), + 256.0) + checkAnswer( + sql("SELECT POWER(0.5, 2)"), + 0.25) + checkAnswer( + sql("SELECT POW(2, -2)"), + 0.25) + checkAnswer( + sql("SELECT POWER(8, 1)"), + 8.0) + checkAnswer( + sql("SELECT POW(16, 0.5)"), + 4.0) + } + test("SPARK-2041 column name equals tablename") { checkAnswer( sql("SELECT tableName FROM tableName"), @@ -53,14 +86,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { (1 to 100).map(x => Row(math.sqrt(x.toDouble))).toSeq ) } - + test("SQRT with automatic string casts") { checkAnswer( sql("SELECT SQRT(CAST(key AS STRING)) FROM testData"), (1 to 100).map(x => Row(math.sqrt(x.toDouble))).toSeq ) } - + test("SPARK-2407 Added Parser of SQL SUBSTR()") { checkAnswer( sql("SELECT substr(tableName, 1, 2) FROM tableName"),