From 9daf31edc74b6c2dd7b83ea55af4cc8e96690169 Mon Sep 17 00:00:00 2001
From: azagrebin
Date: Tue, 17 Mar 2015 13:21:05 +0100
Subject: [PATCH 1/3] [SPARK-6117] [SQL] add describe function to DataFrame for
 summary statistics

---
 .../org/apache/spark/sql/DataFrame.scala      | 80 +++++++++++++++++++
 .../org/apache/spark/sql/DataFrameSuite.scala | 20 +++++
 .../scala/org/apache/spark/sql/TestData.scala | 19 +++++
 3 files changed, 119 insertions(+)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 46f50708a918..accf72b15170 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -751,6 +751,86 @@ class DataFrame private[sql](
     select(colNames :_*)
   }
 
+  /**
+   * Computes the specified aggregations for the given columns of this [[DataFrame]].
+   * Each row of the resulting [[DataFrame]] contains a column with the aggregation name
+   * and a column with the aggregation result for each given column.
+   * The aggregations are described as a List that maps each aggregation name to a
+   * function which builds the aggregation expression from a column name.
+   *
+   * Note: only simple aggregation expressions that can be parsed by the Spark
+   * [[SqlParser]] are supported.
+   *
+   * {{{
+   *   val aggregations = List(
+   *     "max" -> (col => s"max($col)"), // expression computes max
+   *     "avg" -> (col => s"sum($col)/count($col)")) // expression computes average
+   *   df.multipleAggExpr("summary", aggregations, "age", "height")
+   *
+   *   // summary age   height
+   *   // max     92.0  192.0
+   *   // avg     53.0  178.0
+   * }}}
+   */
+  @scala.annotation.varargs
+  private def multipleAggExpr(
+      aggCol: String,
+      aggregations: List[(String, String => String)],
+      cols: String*): DataFrame = {
+
+    val sqlParser = new SqlParser()
+
+    def addAggNameCol(aggDF: DataFrame, aggName: String = "") =
+      aggDF.selectExpr(s"'$aggName' as $aggCol"::cols.toList:_*)
+
+    def unionWithNextAgg(aggSoFarDF: DataFrame, nextAgg: (String, String => String)) =
+      nextAgg match { case (aggName, colToAggExpr) =>
+        val nextAggDF = if (cols.nonEmpty) {
+          def colToAggCol(col: String) =
+            Column(sqlParser.parseExpression(colToAggExpr(col))).as(col)
+          val aggCols = cols.map(colToAggCol)
+          agg(aggCols.head, aggCols.tail:_*)
+        } else {
+          sqlContext.emptyDataFrame
+        }
+        val nextAggWithNameDF = addAggNameCol(nextAggDF, aggName)
+        aggSoFarDF.unionAll(nextAggWithNameDF)
+      }
+
+    val emptyAgg = addAggNameCol(this).limit(0)
+    aggregations.foldLeft(emptyAgg)(unionWithNextAgg)
+  }
+
+  /**
+   * Computes numerical statistics for the given columns of this [[DataFrame]]:
+   * count, mean (avg), stddev (standard deviation), min, and max.
+   * Each row of the resulting [[DataFrame]] contains a column with the statistic name
+   * and a column with the statistic result for each given column.
+   * If no columns are given, statistics are computed for all numerical columns.
+   *
+   * {{{
+   *   df.describe("age", "height")
+   *
+   *   // summary age   height
+   *   // count   10.0  10.0
+   *   // mean    53.3  178.05
+   *   // stddev  11.6  15.7
+   *   // min     18.0  163.0
+   *   // max     92.0  192.0
+   * }}}
+   */
+  @scala.annotation.varargs
+  def describe(cols: String*): DataFrame = {
+    val numCols = if (cols.isEmpty) numericColumns.map(_.prettyString) else cols
+    val aggregations = List[(String, String => String)](
+      "count" -> (col => s"count($col)"),
+      "mean" -> (col => s"avg($col)"),
+      "stddev" -> (col => s"sqrt(avg($col*$col) - avg($col)*avg($col))"),
+      "min" -> (col => s"min($col)"),
+      "max" -> (col => s"max($col)"))
+    multipleAggExpr("summary", aggregations, numCols:_*)
+  }
+
   /**
    * Returns the first `n` rows.
    * @group action
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index ff441ef26f9c..bee5a49b0bb6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -436,6 +436,26 @@ class DataFrameSuite extends QueryTest {
     assert(df.schema.map(_.name).toSeq === Seq("key", "valueRenamed", "newCol"))
   }
 
+  test("describe") {
+    def getSchemaAsSeq(df: DataFrame) = df.schema.map(_.name).toSeq
+
+    val describeAllCols = describeTestData.describe("age", "height")
+    assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "age", "height"))
+    checkAnswer(describeAllCols, describeResult)
+
+    val describeNoCols = describeTestData.describe()
+    assert(getSchemaAsSeq(describeNoCols) === Seq("summary", "age", "height"))
+    checkAnswer(describeNoCols, describeResult)
+
+    val describeOneCol = describeTestData.describe("age")
+    assert(getSchemaAsSeq(describeOneCol) === Seq("summary", "age"))
+    checkAnswer(describeOneCol, describeResult.map { case Row(s, d, _) => Row(s, d)} )
+
+    val emptyDescription = describeTestData.limit(0).describe()
+    assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "age", "height"))
+    checkAnswer(emptyDescription, emptyDescribeResult)
+  }
+
   test("apply on query results (SPARK-5462)") {
     val df = testData.sqlContext.sql("select key from testData")
     checkAnswer(df.select(df("key")), testData.select('key).collect().toSeq)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index 637f59b2e68c..d96b5be9aa9b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -199,6 +199,25 @@ object TestData {
     Salary(1, 1000.0) :: Nil).toDF()
   salary.registerTempTable("salary")
 
+  case class PersonToDescribe(name: String, age: Int, height: Double)
+  val describeTestData = TestSQLContext.sparkContext.parallelize(
+    PersonToDescribe("Bob", 16, 176) ::
+    PersonToDescribe("Alice", 32, 164) ::
+    PersonToDescribe("David", 60, 192) ::
+    PersonToDescribe("Amy", 24, 180) :: Nil).toDF()
+  val describeResult =
+    Row("count", 4.0, 4.0) ::
+    Row("mean", 33.0, 178.0) ::
+    Row("stddev", 16.583123951777, 10.0) ::
+    Row("min", 16.0, 164) ::
+    Row("max", 60.0, 192) :: Nil
+  val emptyDescribeResult =
+    Row("count", 0, 0) ::
+    Row("mean", null, null) ::
+    Row("stddev", null, null) ::
+    Row("min", null, null) ::
+    Row("max", null, null) :: Nil
+
   case class ComplexData(m: Map[Int, String], s: TestData, a: Seq[Int], b: Boolean)
 
   val complexData = TestSQLContext.sparkContext.parallelize(

From ddb3950327058b5133644766bf35325b153e8866 Mon Sep 17 00:00:00 2001
From: azagrebin
Date: Wed, 18 Mar 2015 01:12:49 +0100
Subject: [PATCH 2/3] [SPARK-6117] [SQL] simplify implementation, add test for
 DF without numeric columns

---
 .../org/apache/spark/sql/DataFrame.scala      | 83 ++++++-------------
 .../org/apache/spark/sql/DataFrameSuite.scala | 14 ++--
 .../scala/org/apache/spark/sql/TestData.scala |  4 +-
 3 files changed, 37 insertions(+), 64 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index accf72b15170..1746cf0b2717 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -751,56 +751,6 @@ class DataFrame private[sql](
     select(colNames :_*)
   }
 
-  /**
-   * Computes the specified aggregations for the given columns of this [[DataFrame]].
-   * Each row of the resulting [[DataFrame]] contains a column with the aggregation name
-   * and a column with the aggregation result for each given column.
-   * The aggregations are described as a List that maps each aggregation name to a
-   * function which builds the aggregation expression from a column name.
-   *
-   * Note: only simple aggregation expressions that can be parsed by the Spark
-   * [[SqlParser]] are supported.
-   *
-   * {{{
-   *   val aggregations = List(
-   *     "max" -> (col => s"max($col)"), // expression computes max
-   *     "avg" -> (col => s"sum($col)/count($col)")) // expression computes average
-   *   df.multipleAggExpr("summary", aggregations, "age", "height")
-   *
-   *   // summary age   height
-   *   // max     92.0  192.0
-   *   // avg     53.0  178.0
-   * }}}
-   */
-  @scala.annotation.varargs
-  private def multipleAggExpr(
-      aggCol: String,
-      aggregations: List[(String, String => String)],
-      cols: String*): DataFrame = {
-
-    val sqlParser = new SqlParser()
-
-    def addAggNameCol(aggDF: DataFrame, aggName: String = "") =
-      aggDF.selectExpr(s"'$aggName' as $aggCol"::cols.toList:_*)
-
-    def unionWithNextAgg(aggSoFarDF: DataFrame, nextAgg: (String, String => String)) =
-      nextAgg match { case (aggName, colToAggExpr) =>
-        val nextAggDF = if (cols.nonEmpty) {
-          def colToAggCol(col: String) =
-            Column(sqlParser.parseExpression(colToAggExpr(col))).as(col)
-          val aggCols = cols.map(colToAggCol)
-          agg(aggCols.head, aggCols.tail:_*)
-        } else {
-          sqlContext.emptyDataFrame
-        }
-        val nextAggWithNameDF = addAggNameCol(nextAggDF, aggName)
-        aggSoFarDF.unionAll(nextAggWithNameDF)
-      }
-
-    val emptyAgg = addAggNameCol(this).limit(0)
-    aggregations.foldLeft(emptyAgg)(unionWithNextAgg)
-  }
-
   /**
    * Computes numerical statistics for the given columns of this [[DataFrame]]:
    * count, mean (avg), stddev (standard deviation), min, and max.
@@ -821,14 +771,33 @@ class DataFrame private[sql](
    */
   @scala.annotation.varargs
   def describe(cols: String*): DataFrame = {
+
+    def aggCol(name: String = "") = s"'$name' as summary"
+    val statistics = List[(String, Expression => Expression)](
+      "count" -> (expr => Count(expr)),
+      "mean" -> (expr => Average(expr)),
+      "stddev" -> (expr => Sqrt(Subtract(Average(Multiply(expr, expr)),
+        Multiply(Average(expr), Average(expr))))),
+      "min" -> (expr => Min(expr)),
+      "max" -> (expr => Max(expr)))
+
     val numCols = if (cols.isEmpty) numericColumns.map(_.prettyString) else cols
-    val aggregations = List[(String, String => String)](
-      "count" -> (col => s"count($col)"),
-      "mean" -> (col => s"avg($col)"),
-      "stddev" -> (col => s"sqrt(avg($col*$col) - avg($col)*avg($col))"),
-      "min" -> (col => s"min($col)"),
-      "max" -> (col => s"max($col)"))
-    multipleAggExpr("summary", aggregations, numCols:_*)
+
+    // union all statistics starting from empty one
+    var description = selectExpr(aggCol()::numCols.toList:_*).limit(0)
+    for ((name, colToAgg) <- statistics) {
+      // generate next statistic aggregation
+      val nextAgg = if (numCols.nonEmpty) {
+        val aggCols = numCols.map(c => Column(colToAgg(Column(c).expr)).as(c))
+        agg(aggCols.head, aggCols.tail:_*)
+      } else {
+        sqlContext.emptyDataFrame
+      }
+      // add statistic name column
+      val nextStat = nextAgg.selectExpr(aggCol(name)::numCols.toList:_*)
+      description = description.unionAll(nextStat)
+    }
+    description
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index bee5a49b0bb6..0f37664ce1b0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -439,18 +439,22 @@ class DataFrameSuite extends QueryTest {
   test("describe") {
     def getSchemaAsSeq(df: DataFrame) = df.schema.map(_.name).toSeq
 
-    val describeAllCols = describeTestData.describe("age", "height")
+    val describeTwoCols = describeTestData.describe("age", "height")
+    assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "age", "height"))
+    checkAnswer(describeTwoCols, describeResult)
+
+    val describeAllCols = describeTestData.describe()
     assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "age", "height"))
     checkAnswer(describeAllCols, describeResult)
 
-    val describeNoCols = describeTestData.describe()
-    assert(getSchemaAsSeq(describeNoCols) === Seq("summary", "age", "height"))
-    checkAnswer(describeNoCols, describeResult)
-
     val describeOneCol = describeTestData.describe("age")
     assert(getSchemaAsSeq(describeOneCol) === Seq("summary", "age"))
     checkAnswer(describeOneCol, describeResult.map { case Row(s, d, _) => Row(s, d)} )
 
+    val describeNoCol = describeTestData.select("name").describe()
+    assert(getSchemaAsSeq(describeNoCol) === Seq("summary"))
+    checkAnswer(describeNoCol, describeResult.map { case Row(s, _, _) => Row(s)} )
+
     val emptyDescription = describeTestData.limit(0).describe()
     assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "age", "height"))
     checkAnswer(emptyDescription, emptyDescribeResult)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index d96b5be9aa9b..e4446cd5e081 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -209,8 +209,8 @@ object TestData {
     Row("count", 4.0, 4.0) ::
     Row("mean", 33.0, 178.0) ::
Row("stddev", 16.583123951777, 10.0) :: - Row("min", 16.0, 164) :: - Row("max", 60.0, 192) :: Nil + Row("min", 16.0, 164.0) :: + Row("max", 60.0, 192.0) :: Nil val emptyDescribeResult = Row("count", 0, 0) :: Row("mean", null, null) :: From f9056ac510868cb3dc878b34e9259155c7aa9d88 Mon Sep 17 00:00:00 2001 From: azagrebin Date: Wed, 18 Mar 2015 10:47:20 +0100 Subject: [PATCH 3/3] [SPARK-6117] [SQL] create one aggregation and split it locally into resulting DF, colocate test data with test case --- .../org/apache/spark/sql/DataFrame.scala | 50 ++++++++++--------- .../org/apache/spark/sql/DataFrameSuite.scala | 21 ++++++++ .../scala/org/apache/spark/sql/TestData.scala | 19 ------- 3 files changed, 47 insertions(+), 43 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 1746cf0b2717..10bcd7a3f171 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD} import org.apache.spark.sql.jdbc.JDBCWriteDetails import org.apache.spark.sql.json.JsonRDD -import org.apache.spark.sql.types.{NumericType, StructType} +import org.apache.spark.sql.types.{NumericType, StructType, StructField, StringType} import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect} import org.apache.spark.util.Utils @@ -772,32 +772,34 @@ class DataFrame private[sql]( @scala.annotation.varargs def describe(cols: String*): DataFrame = { - def aggCol(name: String = "") = s"'$name' as summary" + def stddevExpr(expr: Expression) = + Sqrt(Subtract(Average(Multiply(expr, expr)), Multiply(Average(expr), Average(expr)))) + val statistics = List[(String, Expression => Expression)]( - "count" -> (expr => Count(expr)), - "mean" -> (expr => Average(expr)), - "stddev" -> (expr => Sqrt(Subtract(Average(Multiply(expr, expr)), - Multiply(Average(expr), Average(expr))))), - "min" -> (expr => Min(expr)), - "max" -> (expr => Max(expr))) - - val numCols = if (cols.isEmpty) numericColumns.map(_.prettyString) else cols - - // union all statistics starting from empty one - var description = selectExpr(aggCol()::numCols.toList:_*).limit(0) - for ((name, colToAgg) <- statistics) { - // generate next statistic aggregation - val nextAgg = if (numCols.nonEmpty) { - val aggCols = numCols.map(c => Column(colToAgg(Column(c).expr)).as(c)) - agg(aggCols.head, aggCols.tail:_*) - } else { - sqlContext.emptyDataFrame + "count" -> Count, + "mean" -> Average, + "stddev" -> stddevExpr, + "min" -> Min, + "max" -> Max) + + val aggCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList + + val localAgg = if (aggCols.nonEmpty) { + val aggExprs = statistics.flatMap { case (_, colToAgg) => + aggCols.map(c => Column(colToAgg(Column(c).expr)).as(c)) } - // add statistic name column - val nextStat = nextAgg.selectExpr(aggCol(name)::numCols.toList:_*) - description = description.unionAll(nextStat) + + agg(aggExprs.head, aggExprs.tail: _*).head().toSeq + .grouped(aggCols.size).toSeq.zip(statistics).map { case (aggregation, (statistic, _)) => + Row(statistic :: aggregation.toList: _*) + } + } else { + statistics.map { case (name, _) => Row(name) } } - description + + val schema = StructType(("summary" :: aggCols).map(StructField(_, StringType))) + val rowRdd = 
+    sqlContext.createDataFrame(rowRdd, schema)
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 0f37664ce1b0..ab9d1b93d05d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -437,6 +437,27 @@ class DataFrameSuite extends QueryTest {
   }
 
   test("describe") {
+
+    val describeTestData = Seq(
+      ("Bob", 16, 176),
+      ("Alice", 32, 164),
+      ("David", 60, 192),
+      ("Amy", 24, 180)).toDF("name", "age", "height")
+
+    val describeResult = Seq(
+      Row("count", 4, 4),
+      Row("mean", 33.0, 178.0),
+      Row("stddev", 16.583123951777, 10.0),
+      Row("min", 16, 164),
+      Row("max", 60, 192))
+
+    val emptyDescribeResult = Seq(
+      Row("count", 0, 0),
+      Row("mean", null, null),
+      Row("stddev", null, null),
+      Row("min", null, null),
+      Row("max", null, null))
+
     def getSchemaAsSeq(df: DataFrame) = df.schema.map(_.name).toSeq
 
     val describeTwoCols = describeTestData.describe("age", "height")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index e4446cd5e081..637f59b2e68c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -199,25 +199,6 @@ object TestData {
     Salary(1, 1000.0) :: Nil).toDF()
   salary.registerTempTable("salary")
 
-  case class PersonToDescribe(name: String, age: Int, height: Double)
-  val describeTestData = TestSQLContext.sparkContext.parallelize(
-    PersonToDescribe("Bob", 16, 176) ::
-    PersonToDescribe("Alice", 32, 164) ::
-    PersonToDescribe("David", 60, 192) ::
-    PersonToDescribe("Amy", 24, 180) :: Nil).toDF()
-  val describeResult =
-    Row("count", 4.0, 4.0) ::
-    Row("mean", 33.0, 178.0) ::
-    Row("stddev", 16.583123951777, 10.0) ::
-    Row("min", 16.0, 164.0) ::
-    Row("max", 60.0, 192.0) :: Nil
-  val emptyDescribeResult =
-    Row("count", 0, 0) ::
-    Row("mean", null, null) ::
-    Row("stddev", null, null) ::
-    Row("min", null, null) ::
-    Row("max", null, null) :: Nil
-
   case class ComplexData(m: Map[Int, String], s: TestData, a: Seq[Int], b: Boolean)
 
   val complexData = TestSQLContext.sparkContext.parallelize(
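
Editor's note: the core idea of PATCH 3/3 is to run a single aggregation that yields one flat row laid out as statistics x columns, then regroup that row locally into one Row per statistic via `grouped` and `zip`. Below is a minimal standalone Scala sketch of just that regrouping step, runnable without Spark; the sample values and the simplified `Row` stand-in are illustrative assumptions, not part of the patch.

```scala
// Hypothetical demo of the grouped/zip regrouping used in describe().
object DescribeRegroupDemo {
  // Simplified stand-in for org.apache.spark.sql.Row.
  type Row = Seq[Any]

  def main(args: Array[String]): Unit = {
    val statistics = List("count", "mean", "stddev", "min", "max")
    val aggCols = List("age", "height")

    // Stand-in for agg(aggExprs.head, aggExprs.tail: _*).head().toSeq:
    // one value per (statistic, column) pair, columns varying fastest, i.e.
    // count(age), count(height), mean(age), mean(height), ...
    val flatAggRow: Seq[Any] =
      Seq(4, 4, 33.0, 178.0, 16.583123951777, 10.0, 16, 164, 60, 192)

    // grouped(aggCols.size) yields one chunk per statistic; zipping with the
    // statistic names prepends the "summary" value to each chunk.
    val localAgg: Seq[Row] =
      flatAggRow.grouped(aggCols.size).toSeq.zip(statistics).map {
        case (aggregation, statistic) => statistic +: aggregation
      }

    // Prints the same table shape describe() returns: summary, age, height.
    localAgg.foreach(row => println(row.mkString("\t")))
  }
}
```

One consequence of this design, visible in the patch itself: all statistics come back in a single job instead of one union per statistic, and the resulting schema is all StringType, which is why the expected count/min/max rows in the final test hold unformatted integer values.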