Skip to content

Commit a81d0fc

Browse files
committed
string promotion & null value handling
1 parent 0902ceb commit a81d0fc

File tree

3 files changed

+24
-7
lines changed

3 files changed

+24
-7
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,9 @@ object HiveTypeCoercion {
297297
case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
298298
case SumDistinct(e @ StringType()) => Sum(Cast(e, DoubleType))
299299
case Average(e @ StringType()) => Average(Cast(e, DoubleType))
300+
case Stddev(e @ StringType()) => Stddev(Cast(e, DoubleType))
301+
case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType))
302+
case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType))
300303
}
301304
}
302305

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -253,18 +253,21 @@ case class Min(child: Expression) extends AlgebraicAggregate {
253253
case class Stddev(child: Expression) extends StddevAgg(child) {
254254

255255
override def isSample: Boolean = true
256+
override def prettyName: String = "stddev"
256257
}
257258

258259
// Compute the population standard deviation of a column
259260
case class StddevPop(child: Expression) extends StddevAgg(child) {
260261

261262
override def isSample: Boolean = false
263+
override def prettyName: String = "stddev_pop"
262264
}
263265

264266
// Compute the sample standard deviation of a column
265267
case class StddevSamp(child: Expression) extends StddevAgg(child) {
266268

267269
override def isSample: Boolean = true
270+
override def prettyName: String = "stddev_samp"
268271
}
269272

270273
// Compute standard deviation based on online algorithm specified here:
@@ -362,11 +365,15 @@ abstract class StddevAgg(child: Expression) extends AlgebraicAggregate {
362365
}
363366

364367
Seq(
365-
/* preCount = */ currentCount.left,
366-
/* currentCount = */ countMerge,
367-
/* preAvg = */ currentAvg.left,
368-
/* currentAvg = */ avgMerge,
369-
/* currentMk = */ mkMerge
368+
/* preCount = */ If(IsNull(currentCount.left),
369+
Cast(Literal(0), resultType), currentCount.left),
370+
/* currentCount = */ If(IsNull(currentCount.left), currentCount.right,
371+
If(IsNull(currentCount.right), currentCount.left, countMerge)),
372+
/* preAvg = */ If(IsNull(currentAvg.left), Cast(Literal(0), resultType), currentAvg.left),
373+
/* currentAvg = */ If(IsNull(currentAvg.left), currentAvg.right,
374+
If(IsNull(currentAvg.right), currentAvg.left, avgMerge)),
375+
/* currentMk = */ If(IsNull(currentMk.left), currentMk.right,
376+
If(IsNull(currentMk.right), currentMk.left, mkMerge))
370377
)
371378
}
372379

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
328328
testCodeGen(
329329
"SELECT min(key) FROM testData3x",
330330
Row(1) :: Nil)
331+
// STDDEV
332+
testCodeGen(
333+
"SELECT a, stddev(b), stddev_pop(b) FROM testData2 GROUP BY a",
334+
(1 to 3).map(i => Row(i, math.sqrt(0.5), math.sqrt(0.25))))
335+
testCodeGen(
336+
"SELECT stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2",
337+
Row(math.sqrt(1.5 / 5), math.sqrt(1.5 / 6), math.sqrt(1.5 / 5)) :: Nil)
331338
// Some combinations.
332339
testCodeGen(
333340
"""
@@ -348,8 +355,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
348355
Row(100, 1, 50.5, 300, 100) :: Nil)
349356
// Aggregate with Code generation handling all null values
350357
testCodeGen(
351-
"SELECT sum('a'), avg('a'), count(null) FROM testData",
352-
Row(null, null, 0) :: Nil)
358+
"SELECT sum('a'), avg('a'), stddev('a'), count(null) FROM testData",
359+
Row(null, null, null, 0) :: Nil)
353360
} finally {
354361
sqlContext.dropTempTable("testData3x")
355362
sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue)

0 commit comments

Comments
 (0)