Skip to content

Commit ef4a21a

Browse files
committed
Move sqrt to math.
1 parent 4eb48ed commit ef4a21a

File tree

5 files changed

+26
-33
lines changed

5 files changed

+26
-33
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -58,38 +58,6 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic {
5858
protected override def evalInternal(evalE: Any) = numeric.negate(evalE)
5959
}
6060

61-
case class Sqrt(child: Expression) extends UnaryArithmetic {
62-
override def dataType: DataType = DoubleType
63-
override def nullable: Boolean = true
64-
override def toString: String = s"SQRT($child)"
65-
66-
override def checkInputDataTypes(): TypeCheckResult =
67-
TypeUtils.checkForNumericExpr(child.dataType, "function sqrt")
68-
69-
private lazy val numeric = TypeUtils.getNumeric(child.dataType)
70-
71-
protected override def evalInternal(evalE: Any) = {
72-
val value = numeric.toDouble(evalE)
73-
if (value < 0) null
74-
else math.sqrt(value)
75-
}
76-
77-
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
78-
val eval = child.gen(ctx)
79-
eval.code + s"""
80-
boolean ${ev.isNull} = ${eval.isNull};
81-
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
82-
if (!${ev.isNull}) {
83-
if (${eval.primitive} < 0.0) {
84-
${ev.isNull} = true;
85-
} else {
86-
${ev.primitive} = java.lang.Math.sqrt(${eval.primitive});
87-
}
88-
}
89-
"""
90-
}
91-
}
92-
9361
/**
9462
* A function that get the absolute value of the numeric value.
9563
*/

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,8 @@ case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN")
193193

194194
case class Sinh(child: Expression) extends UnaryMathExpression(math.sinh, "SINH")
195195

196+
case class Sqrt(child: Expression) extends UnaryMathExpression(math.sqrt, "SQRT")
197+
196198
case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN")
197199

198200
case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH")

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,11 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
191191
testUnary(Log2, f, (-5 to -1).map(_ * 1.0), expectNull = true)
192192
}
193193

194+
test("sqrt") {
195+
testUnary(Sqrt, math.sqrt, (0 to 20).map(_ * 0.1))
196+
testUnary(Sqrt, math.sqrt, (-5 to -1).map(_ * 1.0), expectNull = true)
197+
}
198+
194199
test("pow") {
195200
testBinary(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0)))
196201
testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNull = true)

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -707,11 +707,19 @@ object functions {
707707
/**
708708
* Computes the square root of the specified float value.
709709
*
710-
* @group normal_funcs
710+
* @group math_funcs
711711
* @since 1.3.0
712712
*/
713713
def sqrt(e: Column): Column = Sqrt(e.expr)
714714

715+
/**
716+
* Computes the square root of the specified float value.
717+
*
718+
* @group math_funcs
719+
* @since 1.4.0
720+
*/
721+
def sqrt(colName: String): Column = sqrt(Column(colName))
722+
715723
/**
716724
* Creates a new struct column. The input column must be a column in a [[DataFrame]], or
717725
* a derived column expression that is named (i.e. aliased).

sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,16 @@ class MathExpressionsSuite extends QueryTest {
257257
checkAnswer(ctx.sql("SELECT LOG2(8), LOG2(null)"), Row(3, null))
258258
}
259259

260+
test("sqrt") {
261+
val df = Seq((1, 4)).toDF("a", "b")
262+
checkAnswer(
263+
df.select(sqrt("a"), sqrt("b")),
264+
Row(1.0, 2.0))
265+
266+
checkAnswer(ctx.sql("SELECT SQRT(4.0), SQRT(null)"), Row(2.0, null))
267+
checkAnswer(df.selectExpr("sqrt(a)", "sqrt(b)", "sqrt(null)"), Row(1.0, 2.0, null))
268+
}
269+
260270
test("negative") {
261271
checkAnswer(
262272
ctx.sql("SELECT negative(1), negative(0), negative(-1)"),

0 commit comments

Comments
 (0)