Skip to content

Commit 9daf31e — [SPARK-6117] [SQL] add describe function to DataFrame for summary statistics (parent commit: e26db9b)

File tree

3 files changed: +119 −0 lines changed

sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 80 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -751,6 +751,86 @@ class DataFrame private[sql](
751751
select(colNames :_*)
752752
}
753753

754+
/**
755+
* Compute specified aggregations for given columns of this [[DataFrame]].
756+
* Each row of the resulting [[DataFrame]] contains column with aggregation name
757+
* and columns with aggregation results for each given column.
758+
* The aggregations are described as a List of mappings of their name to function
759+
* which generates aggregation expression from column name.
760+
*
761+
* Note: can process only simple aggregation expressions
762+
* which can be parsed by spark [[SqlParser]]
763+
*
764+
* {{{
765+
* val aggregations = List(
766+
* "max" -> (col => s"max($col)"), // expression computes max
767+
* "avg" -> (col => s"sum($col)/count($col)")) // expression computes average
768+
* df.multipleAggExpr("summary", aggregations, "age", "height")
769+
*
770+
* // summary age height
771+
* // max 92.0 192.0
772+
* // avg 53.0 178.0
773+
* }}}
774+
*/
775+
@scala.annotation.varargs
776+
private def multipleAggExpr(
777+
aggCol: String,
778+
aggregations: List[(String, String => String)],
779+
cols: String*): DataFrame = {
780+
781+
val sqlParser = new SqlParser()
782+
783+
def addAggNameCol(aggDF: DataFrame, aggName: String = "") =
784+
aggDF.selectExpr(s"'$aggName' as $aggCol"::cols.toList:_*)
785+
786+
def unionWithNextAgg(aggSoFarDF: DataFrame, nextAgg: (String, String => String)) =
787+
nextAgg match { case (aggName, colToAggExpr) =>
788+
val nextAggDF = if (cols.nonEmpty) {
789+
def colToAggCol(col: String) =
790+
Column(sqlParser.parseExpression(colToAggExpr(col))).as(col)
791+
val aggCols = cols.map(colToAggCol)
792+
agg(aggCols.head, aggCols.tail:_*)
793+
} else {
794+
sqlContext.emptyDataFrame
795+
}
796+
val nextAggWithNameDF = addAggNameCol(nextAggDF, aggName)
797+
aggSoFarDF.unionAll(nextAggWithNameDF)
798+
}
799+
800+
val emptyAgg = addAggNameCol(this).limit(0)
801+
aggregations.foldLeft(emptyAgg)(unionWithNextAgg)
802+
}
803+
804+
/**
805+
* Compute numerical statistics for given columns of this [[DataFrame]]:
806+
* count, mean (avg), stddev (standard deviation), min, max.
807+
* Each row of the resulting [[DataFrame]] contains column with statistic name
808+
* and columns with statistic results for each given column.
809+
* If no columns are given then computes for all numerical columns.
810+
*
811+
* {{{
812+
* df.describe("age", "height")
813+
*
814+
* // summary age height
815+
* // count 10.0 10.0
816+
* // mean 53.3 178.05
817+
* // stddev 11.6 15.7
818+
* // min 18.0 163.0
819+
* // max 92.0 192.0
820+
* }}}
821+
*/
822+
@scala.annotation.varargs
823+
def describe(cols: String*): DataFrame = {
824+
val numCols = if (cols.isEmpty) numericColumns.map(_.prettyString) else cols
825+
val aggregations = List[(String, String => String)](
826+
"count" -> (col => s"count($col)"),
827+
"mean" -> (col => s"avg($col)"),
828+
"stddev" -> (col => s"sqrt(avg($col*$col) - avg($col)*avg($col))"),
829+
"min" -> (col => s"min($col)"),
830+
"max" -> (col => s"max($col)"))
831+
multipleAggExpr("summary", aggregations, numCols:_*)
832+
}
833+
754834
/**
755835
* Returns the first `n` rows.
756836
* @group action

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -436,6 +436,26 @@ class DataFrameSuite extends QueryTest {
436436
assert(df.schema.map(_.name).toSeq === Seq("key", "valueRenamed", "newCol"))
437437
}
438438

439+
test("describe") {
440+
def getSchemaAsSeq(df: DataFrame) = df.schema.map(_.name).toSeq
441+
442+
val describeAllCols = describeTestData.describe("age", "height")
443+
assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "age", "height"))
444+
checkAnswer(describeAllCols, describeResult)
445+
446+
val describeNoCols = describeTestData.describe()
447+
assert(getSchemaAsSeq(describeNoCols) === Seq("summary", "age", "height"))
448+
checkAnswer(describeNoCols, describeResult)
449+
450+
val describeOneCol = describeTestData.describe("age")
451+
assert(getSchemaAsSeq(describeOneCol) === Seq("summary", "age"))
452+
checkAnswer(describeOneCol, describeResult.map { case Row(s, d, _) => Row(s, d)} )
453+
454+
val emptyDescription = describeTestData.limit(0).describe()
455+
assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "age", "height"))
456+
checkAnswer(emptyDescription, emptyDescribeResult)
457+
}
458+
439459
test("apply on query results (SPARK-5462)") {
440460
val df = testData.sqlContext.sql("select key from testData")
441461
checkAnswer(df.select(df("key")), testData.select('key).collect().toSeq)

sql/core/src/test/scala/org/apache/spark/sql/TestData.scala

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -199,6 +199,25 @@ object TestData {
199199
Salary(1, 1000.0) :: Nil).toDF()
200200
salary.registerTempTable("salary")
201201

202+
case class PersonToDescribe(name: String, age: Int, height: Double)
203+
val describeTestData = TestSQLContext.sparkContext.parallelize(
204+
PersonToDescribe("Bob", 16, 176) ::
205+
PersonToDescribe("Alice", 32, 164) ::
206+
PersonToDescribe("David", 60, 192) ::
207+
PersonToDescribe("Amy", 24, 180) :: Nil).toDF()
208+
val describeResult =
209+
Row("count", 4.0, 4.0) ::
210+
Row("mean", 33.0, 178.0) ::
211+
Row("stddev", 16.583123951777, 10.0) ::
212+
Row("min", 16.0, 164) ::
213+
Row("max", 60.0, 192) :: Nil
214+
val emptyDescribeResult =
215+
Row("count", 0, 0) ::
216+
Row("mean", null, null) ::
217+
Row("stddev", null, null) ::
218+
Row("min", null, null) ::
219+
Row("max", null, null) :: Nil
220+
202221
case class ComplexData(m: Map[Int, String], s: TestData, a: Seq[Int], b: Boolean)
203222
val complexData =
204223
TestSQLContext.sparkContext.parallelize(

0 commit comments

Comments
 (0)