Skip to content

Commit f9056ac

Browse files
committed
[SPARK-6117] [SQL] create one aggregation and split it locally into resulting DF, colocate test data with test case
1 parent ddb3950 commit f9056ac

File tree

3 files changed

+47
-43
lines changed

3 files changed

+47
-43
lines changed

sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
4141
import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD}
4242
import org.apache.spark.sql.jdbc.JDBCWriteDetails
4343
import org.apache.spark.sql.json.JsonRDD
44-
import org.apache.spark.sql.types.{NumericType, StructType}
44+
import org.apache.spark.sql.types.{NumericType, StructType, StructField, StringType}
4545
import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect}
4646
import org.apache.spark.util.Utils
4747

@@ -772,32 +772,34 @@ class DataFrame private[sql](
772772
@scala.annotation.varargs
773773
def describe(cols: String*): DataFrame = {
774774

775-
def aggCol(name: String = "") = s"'$name' as summary"
775+
def stddevExpr(expr: Expression) =
776+
Sqrt(Subtract(Average(Multiply(expr, expr)), Multiply(Average(expr), Average(expr))))
777+
776778
val statistics = List[(String, Expression => Expression)](
777-
"count" -> (expr => Count(expr)),
778-
"mean" -> (expr => Average(expr)),
779-
"stddev" -> (expr => Sqrt(Subtract(Average(Multiply(expr, expr)),
780-
Multiply(Average(expr), Average(expr))))),
781-
"min" -> (expr => Min(expr)),
782-
"max" -> (expr => Max(expr)))
783-
784-
val numCols = if (cols.isEmpty) numericColumns.map(_.prettyString) else cols
785-
786-
// union all statistics starting from empty one
787-
var description = selectExpr(aggCol()::numCols.toList:_*).limit(0)
788-
for ((name, colToAgg) <- statistics) {
789-
// generate next statistic aggregation
790-
val nextAgg = if (numCols.nonEmpty) {
791-
val aggCols = numCols.map(c => Column(colToAgg(Column(c).expr)).as(c))
792-
agg(aggCols.head, aggCols.tail:_*)
793-
} else {
794-
sqlContext.emptyDataFrame
779+
"count" -> Count,
780+
"mean" -> Average,
781+
"stddev" -> stddevExpr,
782+
"min" -> Min,
783+
"max" -> Max)
784+
785+
val aggCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList
786+
787+
val localAgg = if (aggCols.nonEmpty) {
788+
val aggExprs = statistics.flatMap { case (_, colToAgg) =>
789+
aggCols.map(c => Column(colToAgg(Column(c).expr)).as(c))
795790
}
796-
// add statistic name column
797-
val nextStat = nextAgg.selectExpr(aggCol(name)::numCols.toList:_*)
798-
description = description.unionAll(nextStat)
791+
792+
agg(aggExprs.head, aggExprs.tail: _*).head().toSeq
793+
.grouped(aggCols.size).toSeq.zip(statistics).map { case (aggregation, (statistic, _)) =>
794+
Row(statistic :: aggregation.toList: _*)
795+
}
796+
} else {
797+
statistics.map { case (name, _) => Row(name) }
799798
}
800-
description
799+
800+
val schema = StructType(("summary" :: aggCols).map(StructField(_, StringType)))
801+
val rowRdd = sqlContext.sparkContext.parallelize(localAgg)
802+
sqlContext.createDataFrame(rowRdd, schema)
801803
}
802804

803805
/**

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,27 @@ class DataFrameSuite extends QueryTest {
437437
}
438438

439439
test("describe") {
440+
441+
val describeTestData = Seq(
442+
("Bob", 16, 176),
443+
("Alice", 32, 164),
444+
("David", 60, 192),
445+
("Amy", 24, 180)).toDF("name", "age", "height")
446+
447+
val describeResult = Seq(
448+
Row("count", 4, 4),
449+
Row("mean", 33.0, 178.0),
450+
Row("stddev", 16.583123951777, 10.0),
451+
Row("min", 16, 164),
452+
Row("max", 60, 192))
453+
454+
val emptyDescribeResult = Seq(
455+
Row("count", 0, 0),
456+
Row("mean", null, null),
457+
Row("stddev", null, null),
458+
Row("min", null, null),
459+
Row("max", null, null))
460+
440461
def getSchemaAsSeq(df: DataFrame) = df.schema.map(_.name).toSeq
441462

442463
val describeTwoCols = describeTestData.describe("age", "height")

sql/core/src/test/scala/org/apache/spark/sql/TestData.scala

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -199,25 +199,6 @@ object TestData {
199199
Salary(1, 1000.0) :: Nil).toDF()
200200
salary.registerTempTable("salary")
201201

202-
case class PersonToDescribe(name: String, age: Int, height: Double)
203-
val describeTestData = TestSQLContext.sparkContext.parallelize(
204-
PersonToDescribe("Bob", 16, 176) ::
205-
PersonToDescribe("Alice", 32, 164) ::
206-
PersonToDescribe("David", 60, 192) ::
207-
PersonToDescribe("Amy", 24, 180) :: Nil).toDF()
208-
val describeResult =
209-
Row("count", 4.0, 4.0) ::
210-
Row("mean", 33.0, 178.0) ::
211-
Row("stddev", 16.583123951777, 10.0) ::
212-
Row("min", 16.0, 164.0) ::
213-
Row("max", 60.0, 192.0) :: Nil
214-
val emptyDescribeResult =
215-
Row("count", 0, 0) ::
216-
Row("mean", null, null) ::
217-
Row("stddev", null, null) ::
218-
Row("min", null, null) ::
219-
Row("max", null, null) :: Nil
220-
221202
case class ComplexData(m: Map[Int, String], s: TestData, a: Seq[Int], b: Boolean)
222203
val complexData =
223204
TestSQLContext.sparkContext.parallelize(

0 commit comments

Comments
 (0)