Skip to content

Commit 34b22e8

Browse files
committed
addressed comments
1 parent 941bb9e commit 34b22e8

File tree

9 files changed

+171
-28
lines changed

9 files changed

+171
-28
lines changed

python/pyspark/sql/functions.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,10 @@ def _():
8787
'sum': 'Aggregate function: returns the sum of all values in the expression.',
8888
'avg': 'Aggregate function: returns the average of the values in a group.',
8989
'mean': 'Aggregate function: returns the average of the values in a group.',
90-
'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
90+
'stddev': 'Aggregate function: returns the sample standard deviation in a group.',
91+
'stddevSamp': 'Aggregate function: returns the sample standard deviation in a group.',
92+
'stddevPop': 'Aggregate function: returns the population standard deviation in a group.',
93+
'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.'
9194
}
9295

9396
_functions_1_4 = {

python/pyspark/sql/group.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,48 @@ def min(self, *cols):
154154
[Row(min(age)=2, min(height)=80)]
155155
"""
156156

157+
@df_varargs_api
158+
@since(1.5)
159+
def stddev(self, *cols):
160+
"""Computes the sample standard deviation for each numeric column for each group.
161+
Alias for stddevSamp.
162+
163+
:param cols: list of column names (string). Non-numeric columns are ignored.
164+
165+
>>> df.groupBy().stddev('age').collect()
166+
[Row(stddev_samp(age)=2.12...)]
167+
>>> df3.groupBy().stddev('age', 'height').collect()
168+
[Row(stddev_samp(age)=2.12..., stddev_samp(height)=3.53...)]
169+
"""
170+
171+
@df_varargs_api
172+
@since(1.5)
173+
def stddevPop(self, *cols):
174+
"""Computes the population standard deviation for each numeric column for each group.
175+
Divides by the group size n, whereas stddevSamp divides by n - 1.
176+
177+
:param cols: list of column names (string). Non-numeric columns are ignored.
178+
179+
>>> df.groupBy().stddevPop('age').collect()
180+
[Row(stddev_pop(age)=1.06...)]
181+
>>> df3.groupBy().stddevPop('age', 'height').collect()
182+
[Row(stddev_pop(age)=1.06..., stddev_pop(height)=1.76...)]
183+
"""
184+
185+
@df_varargs_api
186+
@since(1.5)
187+
def stddevSamp(self, *cols):
188+
"""Computes the sample standard deviation for each numeric column for each group.
189+
stddev is an alias for this method.
190+
191+
:param cols: list of column names (string). Non-numeric columns are ignored.
192+
193+
>>> df.groupBy().stddevSamp('age').collect()
194+
[Row(stddev_samp(age)=2.12...)]
195+
>>> df3.groupBy().stddevSamp('age', 'height').collect()
196+
[Row(stddev_samp(age)=2.12..., stddev_samp(height)=3.53...)]
197+
"""
198+
157199
@df_varargs_api
158200
@since(1.3)
159201
def sum(self, *cols):

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -304,10 +304,11 @@ case class Sum(child: Expression) extends AlgebraicAggregate {
304304
}
305305

306306
/**
307-
* Calculates the unbiased Standard Deviation using the online formula here:
307+
* Calculates the Standard Deviation using the online formula here:
308308
* https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
309+
* If sample is true, then we will return the unbiased standard deviation.
309310
*/
310-
case class StandardDeviation(child: Expression) extends AlgebraicAggregate {
311+
case class StandardDeviation(child: Expression, sample: Boolean) extends AlgebraicAggregate {
311312

312313
override def children: Seq[Expression] = child :: Nil
313314

@@ -388,8 +389,14 @@ case class StandardDeviation(child: Expression) extends AlgebraicAggregate {
388389
}
389390

390391
override lazy val evaluateExpression = {
391-
val count = If(EqualTo(currentCount, Cast(Literal(0L), LongType)),
392-
currentCount, currentCount - Cast(Literal(1L), LongType))
392+
val count =
393+
if (sample) {
394+
If(EqualTo(currentCount, Cast(Literal(0L), LongType)), currentCount,
395+
currentCount - Cast(Literal(1L), LongType))
396+
} else {
397+
currentCount
398+
}
399+
393400
child.dataType match {
394401
case DecimalType.Fixed(p, s) =>
395402
// increase the precision and scale to prevent precision loss

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -170,11 +170,13 @@ object Utils {
170170
* format, but standard deviation uses the new format directly. We wrap it here in one place,
171171
* and use an Alias so that the column name looks pretty as well instead of a long identifier.
172172
*/
173-
def standardDeviation(e: Expression): Expression = {
173+
def standardDeviation(e: Expression, sample: Boolean, name: String): Expression = {
174174
val std = aggregate.AggregateExpression2(
175-
aggregateFunction = aggregate.StandardDeviation(e),
175+
aggregateFunction = aggregate.StandardDeviation(e, sample),
176176
mode = aggregate.Complete,
177177
isDistinct = false)
178-
Alias(std, s"std(${e.prettyString})")()
178+
Alias(std, s"$name(${e.prettyString})")()
179179
}
180+
181+
def sampleStandardDeviation(e: Expression): Expression = {
  // sample = true selects the unbiased form; the result is aliased as stddev_samp(<expr>).
  standardDeviation(e, sample = true, name = "stddev_samp")
}
180182
}

sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1273,7 +1273,7 @@ class DataFrame private[sql](
12731273
val statistics = List[(String, Expression => Expression)](
12741274
"count" -> Count,
12751275
"mean" -> Average,
1276-
"stddev" -> aggregate.Utils.standardDeviation,
1276+
"stddev" -> aggregate.Utils.sampleStandardDeviation,
12771277
"min" -> Min,
12781278
"max" -> Max)
12791279

sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -291,8 +291,33 @@ class GroupedData protected[sql](
291291
* @since 1.5.0
292292
*/
293293
@scala.annotation.varargs
294-
def std(colNames: String*): DataFrame = {
295-
aggregateNumericColumns(colNames : _*)(aggregate.Utils.standardDeviation)
294+
def stddev(colNames: String*): DataFrame = {
295+
stddevSamp(colNames : _*)
296+
}
297+
298+
/**
299+
* Compute the population standard deviation for each numeric column for each group.
300+
* The resulting [[DataFrame]] will also contain the grouping columns.
301+
* When specified columns are given, only compute the standard deviation for them.
302+
*
303+
* @since 1.5.0
304+
*/
305+
@scala.annotation.varargs
306+
def stddevPop(colNames: String*): DataFrame = {
307+
aggregateNumericColumns(colNames : _*)(aggregate.Utils.standardDeviation(_, sample = false,
308+
"stddev_pop"))
309+
}
310+
311+
/**
312+
* Compute the sample standard deviation for each numeric column for each group.
313+
* The resulting [[DataFrame]] will also contain the grouping columns.
314+
* When specified columns are given, only compute the standard deviation for them.
315+
*
316+
* @since 1.5.0
317+
*/
318+
@scala.annotation.varargs
319+
def stddevSamp(colNames: String*): DataFrame = {
320+
aggregateNumericColumns(colNames : _*)(aggregate.Utils.sampleStandardDeviation)
296321
}
297322

298323
/**

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ import scala.reflect.runtime.universe.{TypeTag, typeTag}
2222
import scala.util.Try
2323

2424
import org.apache.spark.annotation.Experimental
25-
import org.apache.spark.sql.catalyst.expressions.aggregate.StandardDeviation
2625
import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
2726
import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star}
2827
import org.apache.spark.sql.catalyst.expressions._
@@ -297,19 +296,53 @@ object functions {
297296

298297
/**
299298
* Aggregate function: returns the sample standard deviation of the values in a group.
299+
* Alias for stddevSamp.
300300
*
301301
* @group agg_funcs
302302
* @since 1.5.0
303303
*/
304-
def std(e: Column): Column = aggregate.Utils.standardDeviation(e.expr)
304+
def stddev(e: Column): Column = stddevSamp(e)
305305

306306
/**
307307
* Aggregate function: returns the sample standard deviation of the values in a group.
308+
* Alias for stddevSamp.
308309
*
309310
* @group agg_funcs
310311
* @since 1.5.0
311312
*/
312-
def std(columnName: String): Column = std(Column(columnName))
313+
def stddev(columnName: String): Column = stddev(Column(columnName))
314+
315+
/**
316+
* Aggregate function: returns the population standard deviation of the values in a group.
317+
*
318+
* @group agg_funcs
319+
* @since 1.5.0
320+
*/
321+
def stddevPop(e: Column): Column =
  // sample = false keeps the divisor at the full count (population form). The alias must be
  // "stddev_pop" so the output column matches the name produced by GroupedData.stddevPop.
  aggregate.Utils.standardDeviation(e.expr, sample = false, "stddev_pop")
322+
323+
/**
324+
* Aggregate function: returns the population standard deviation of the values in a group.
325+
*
326+
* @group agg_funcs
327+
* @since 1.5.0
328+
*/
329+
def stddevPop(columnName: String): Column = stddevPop(Column(columnName))
330+
331+
/**
332+
* Aggregate function: returns the sample standard deviation of the values in a group.
333+
*
334+
* @group agg_funcs
335+
* @since 1.5.0
336+
*/
337+
def stddevSamp(e: Column): Column = aggregate.Utils.sampleStandardDeviation(e.expr)
338+
339+
/**
340+
* Aggregate function: returns the sample standard deviation of the values in a group.
341+
*
342+
* @group agg_funcs
343+
* @since 1.5.0
344+
*/
345+
def stddevSamp(columnName: String): Column =
  // Delegate to the Column overload of stddevSamp itself rather than going through the
  // stddev alias, which only forwards back to stddevSamp.
  stddevSamp(Column(columnName))
313346

314347
/**
315348
* Aggregate function: returns the sum of all values in the expression.

sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,6 @@ object QueryTest {
116116
Row.fromSeq(s.toSeq.map {
117117
case d: java.math.BigDecimal => BigDecimal(d)
118118
case b: Array[Byte] => b.toSeq
119-
case d: Double if !d.isNaN && !d.isInfinity =>
120-
BigDecimal(d).setScale(10, BigDecimal.RoundingMode.HALF_UP)
121119
case o => o
122120
})
123121
}

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import org.apache.spark.sql.hive.test.TestHive
2222
import org.apache.spark.sql.test.SQLTestUtils
2323
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
2424
import org.apache.spark.sql._
25-
import org.apache.spark.sql.functions.std
25+
import org.apache.spark.sql.functions.{stddev, stddevPop}
2626
import org.scalatest.BeforeAndAfterAll
2727
import _root_.test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum}
2828

@@ -285,30 +285,63 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be
285285
Row(11.125) :: Nil)
286286
}
287287

288+
/** For resilience against rounding mismatches. */
289+
private def about(d: Double): BigDecimal = {
  // Expected values are compared at a fixed scale of 10 decimal places, rounded half up.
  val scale = 10
  BigDecimal(d).setScale(scale, BigDecimal.RoundingMode.HALF_UP)
}
290+
288291
test("test standard deviation") {
289292
// All results generated in R. Comparisons will be performed up to 10 digits of precision.
290293
val df = Seq.tabulate(10)(i => (i, 1)).toDF("val", "key")
291294
checkAnswer(
292-
df.select(std("val")),
293-
Row(3.0276503540974917) :: Nil)
295+
df.select(stddev("val").cast("decimal(12, 10)")),
296+
Row(about(3.0276503540974917)) :: Nil)
297+
298+
checkAnswer(
299+
df.select(stddevPop("val").cast("decimal(12, 10)")),
300+
Row(about(2.8722813232690148)) :: Nil)
301+
302+
checkAnswer(
303+
sqlContext.table("agg1").groupBy("key").stddev("value")
304+
.select($"key", $"stddev_samp(value)".cast("decimal(12, 10)")),
305+
Row(1, about(10.0)) :: Row(2, about(0.7071067811865476)) :: Row(3, null) ::
306+
Row(null, about(81.8535277187245)) :: Nil)
307+
308+
checkAnswer(
309+
sqlContext.table("agg1").groupBy("key").stddevPop("value")
310+
.select($"key", $"stddev_pop(value)".cast("decimal(12, 10)")),
311+
Row(1, about(8.16496580927726)) :: Row(2, about(0.5)) :: Row(3, null) ::
312+
Row(null, about(66.83312551921139)) :: Nil)
294313

295314
checkAnswer(
296-
sqlContext.table("agg1").groupBy("key").std("value"),
297-
Row(1, 10.0) :: Row(2, 0.7071067811865476) :: Row(3, null) ::
298-
Row(null, 81.8535277187245) :: Nil)
315+
sqlContext.table("agg1").select(stddev("key").cast("decimal(12, 10)"),
316+
stddev("value").cast("decimal(12, 10)")),
317+
Row(about(0.7817359599705717), about(44.898098909801135)) :: Nil)
299318

300319
checkAnswer(
301-
sqlContext.table("agg1").select(std("key"), std("value")),
302-
Row(0.7817359599705717, 44.898098909801135) :: Nil)
320+
sqlContext.table("agg1").select(stddevPop("key").cast("decimal(12, 10)"),
321+
stddevPop("value").cast("decimal(12, 10)")),
322+
Row(about(0.7370277311900889), about(41.99832585949111)) :: Nil)
303323

304324
checkAnswer(
305-
sqlContext.table("agg2").groupBy("key", "value1").std("value2"),
306-
Row(1, 10, null) :: Row(1, 30, 42.42640687119285) :: Row(2, -1, null) ::
307-
Row(2, 1, 0.0) :: Row(2, null, null) :: Row(3, null, null) :: Row(null, -10, null) ::
325+
sqlContext.table("agg2").groupBy("key", "value1").stddev("value2")
326+
.select($"key", $"value1", $"stddev_samp(value2)".cast("decimal(12, 10)")),
327+
Row(1, 10, null) :: Row(1, 30, about(42.42640687119285)) :: Row(2, -1, null) ::
328+
Row(2, 1, about(0.0)) :: Row(2, null, null) :: Row(3, null, null) :: Row(null, -10, null) ::
308329
Row(null, -60, null) :: Row(null, 100, null) :: Row(null, null, null) :: Nil)
309330

310331
checkAnswer(
311-
sqlContext.table("emptyTable").select(std("value")),
332+
sqlContext.table("agg2").groupBy("key", "value1").stddevPop("value2")
333+
.select($"key", $"value1", $"stddev_pop(value2)".cast("decimal(12, 10)")),
334+
Row(1, 10, about(0.0)) :: Row(1, 30, about(30.0)) :: Row(2, -1, null) ::
335+
Row(2, 1, about(0.0)) :: Row(2, null, about(0.0)) :: Row(3, null, about(0.0)) ::
336+
Row(null, -10, about(0.0)) :: Row(null, -60, about(0.0)) :: Row(null, 100, about(0.0)) ::
337+
Row(null, null, null) :: Nil)
338+
339+
checkAnswer(
340+
sqlContext.table("emptyTable").select(stddev("value")),
341+
Row(null) :: Nil)
342+
343+
checkAnswer(
344+
sqlContext.table("emptyTable").select(stddevPop("value")),
312345
Row(null) :: Nil)
313346
}
314347

0 commit comments

Comments
 (0)