Skip to content

Commit 3164112

Browse files
viiryarxin
authored andcommitted
[SPARK-8363][SQL] Move sqrt to math and extend UnaryMathExpression
JIRA: https://issues.apache.org/jira/browse/SPARK-8363 Author: Liang-Chi Hsieh <[email protected]> Closes apache#6823 from viirya/move_sqrt and squashes the following commits: 8977e11 [Liang-Chi Hsieh] Remove unnecessary old tests. d23e79e [Liang-Chi Hsieh] Explicitly indicate sqrt value sequence. 699f48b [Liang-Chi Hsieh] Use correct @SInCE tag. 8dff6d1 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into move_sqrt bc2ed77 [Liang-Chi Hsieh] Remove/move arithmetic expression test and expression type checking test. Remove unnecessary Sqrt type rule. d38492f [Liang-Chi Hsieh] Now sqrt accepts boolean because type casting is handled by HiveTypeCoercion. 297cc90 [Liang-Chi Hsieh] Sqrt only accepts double input. ef4a21a [Liang-Chi Hsieh] Move sqrt to math.
1 parent ddc5baf commit 3164112

File tree

8 files changed

+31
-51
lines changed

8 files changed

+31
-51
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,6 @@ trait HiveTypeCoercion {
307307

308308
case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
309309
case Average(e @ StringType()) => Average(Cast(e, DoubleType))
310-
case Sqrt(e @ StringType()) => Sqrt(Cast(e, DoubleType))
311310
}
312311
}
313312

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -67,38 +67,6 @@ case class UnaryPositive(child: Expression) extends UnaryArithmetic {
6767
protected override def evalInternal(evalE: Any) = evalE
6868
}
6969

70-
case class Sqrt(child: Expression) extends UnaryArithmetic {
71-
override def dataType: DataType = DoubleType
72-
override def nullable: Boolean = true
73-
override def toString: String = s"SQRT($child)"
74-
75-
override def checkInputDataTypes(): TypeCheckResult =
76-
TypeUtils.checkForNumericExpr(child.dataType, "function sqrt")
77-
78-
private lazy val numeric = TypeUtils.getNumeric(child.dataType)
79-
80-
protected override def evalInternal(evalE: Any) = {
81-
val value = numeric.toDouble(evalE)
82-
if (value < 0) null
83-
else math.sqrt(value)
84-
}
85-
86-
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
87-
val eval = child.gen(ctx)
88-
eval.code + s"""
89-
boolean ${ev.isNull} = ${eval.isNull};
90-
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
91-
if (!${ev.isNull}) {
92-
if (${eval.primitive} < 0.0) {
93-
${ev.isNull} = true;
94-
} else {
95-
${ev.primitive} = java.lang.Math.sqrt(${eval.primitive});
96-
}
97-
}
98-
"""
99-
}
100-
}
101-
10270
/**
10371
* A function that get the absolute value of the numeric value.
10472
*/

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,8 @@ case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN")
193193

194194
case class Sinh(child: Expression) extends UnaryMathExpression(math.sinh, "SINH")
195195

196+
case class Sqrt(child: Expression) extends UnaryMathExpression(math.sqrt, "SQRT")
197+
196198
case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN")
197199

198200
case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH")

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -142,19 +142,4 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
142142
checkEvaluation(MinOf(Literal.create(null, IntegerType), 1), 1)
143143
checkEvaluation(MinOf(1, Literal.create(null, IntegerType)), 1)
144144
}
145-
146-
test("SQRT") {
147-
val inputSequence = (1 to (1<<24) by 511).map(_ * (1L<<24))
148-
val expectedResults = inputSequence.map(l => math.sqrt(l.toDouble))
149-
val rowSequence = inputSequence.map(l => create_row(l.toDouble))
150-
val d = 'a.double.at(0)
151-
152-
for ((row, expected) <- rowSequence zip expectedResults) {
153-
checkEvaluation(Sqrt(d), expected, row)
154-
}
155-
156-
checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null))
157-
checkEvaluation(Sqrt(-1), null, EmptyRow)
158-
checkEvaluation(Sqrt(-1.5), null, EmptyRow)
159-
}
160145
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,6 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
5454

5555
test("check types for unary arithmetic") {
5656
assertError(UnaryMinus('stringField), "operator - accepts numeric type")
57-
assertSuccess(Sqrt('stringField)) // We will cast String to Double for sqrt
58-
assertError(Sqrt('booleanField), "function sqrt accepts numeric type")
5957
assertError(Abs('stringField), "function abs accepts numeric type")
6058
assertError(BitwiseNot('stringField), "operator ~ accepts integral type")
6159
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20+
import org.apache.spark.sql.catalyst.dsl.expressions._
2021
import org.apache.spark.SparkFunSuite
2122
import org.apache.spark.sql.types.DoubleType
2223

@@ -191,6 +192,15 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
191192
testUnary(Log2, f, (-5 to -1).map(_ * 1.0), expectNull = true)
192193
}
193194

195+
test("sqrt") {
196+
testUnary(Sqrt, math.sqrt, (0 to 20).map(_ * 0.1))
197+
testUnary(Sqrt, math.sqrt, (-5 to -1).map(_ * 1.0), expectNull = true)
198+
199+
checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null))
200+
checkEvaluation(Sqrt(Literal(-1.0)), null, EmptyRow)
201+
checkEvaluation(Sqrt(Literal(-1.5)), null, EmptyRow)
202+
}
203+
194204
test("pow") {
195205
testBinary(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0)))
196206
testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNull = true)

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -707,11 +707,19 @@ object functions {
707707
/**
708708
* Computes the square root of the specified float value.
709709
*
710-
* @group normal_funcs
710+
* @group math_funcs
711711
* @since 1.3.0
712712
*/
713713
def sqrt(e: Column): Column = Sqrt(e.expr)
714714

715+
/**
716+
* Computes the square root of the specified float value.
717+
*
718+
* @group math_funcs
719+
* @since 1.5.0
720+
*/
721+
def sqrt(colName: String): Column = sqrt(Column(colName))
722+
715723
/**
716724
* Creates a new struct column. The input column must be a column in a [[DataFrame]], or
717725
* a derived column expression that is named (i.e. aliased).

sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,16 @@ class MathExpressionsSuite extends QueryTest {
270270
checkAnswer(ctx.sql("SELECT LOG2(8), LOG2(null)"), Row(3, null))
271271
}
272272

273+
test("sqrt") {
274+
val df = Seq((1, 4)).toDF("a", "b")
275+
checkAnswer(
276+
df.select(sqrt("a"), sqrt("b")),
277+
Row(1.0, 2.0))
278+
279+
checkAnswer(ctx.sql("SELECT SQRT(4.0), SQRT(null)"), Row(2.0, null))
280+
checkAnswer(df.selectExpr("sqrt(a)", "sqrt(b)", "sqrt(null)"), Row(1.0, 2.0, null))
281+
}
282+
273283
test("negative") {
274284
checkAnswer(
275285
ctx.sql("SELECT negative(1), negative(0), negative(-1)"),

0 commit comments

Comments
 (0)