import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD}
import org.apache.spark.sql.jdbc.JDBCWriteDetails
import org.apache.spark.sql.json.JsonRDD
import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, ResolvedDataSource}
import org.apache.spark.sql.types.{NumericType, StringType, StructField, StructType}
import org.apache.spark.util.Utils

@@ -772,32 +772,34 @@ class DataFrame private[sql](
772772 @ scala.annotation.varargs
773773 def describe (cols : String * ): DataFrame = {
774774
775- def aggCol (name : String = " " ) = s " ' $name' as summary "
775+ def stddevExpr (expr : Expression ) =
776+ Sqrt (Subtract (Average (Multiply (expr, expr)), Multiply (Average (expr), Average (expr))))
777+
776778 val statistics = List [(String , Expression => Expression )](
777- " count" -> (expr => Count (expr)),
778- " mean" -> (expr => Average (expr)),
779- " stddev" -> (expr => Sqrt (Subtract (Average (Multiply (expr, expr)),
780- Multiply (Average (expr), Average (expr))))),
781- " min" -> (expr => Min (expr)),
782- " max" -> (expr => Max (expr)))
783-
784- val numCols = if (cols.isEmpty) numericColumns.map(_.prettyString) else cols
785-
786- // union all statistics starting from empty one
787- var description = selectExpr(aggCol():: numCols.toList:_* ).limit(0 )
788- for ((name, colToAgg) <- statistics) {
789- // generate next statistic aggregation
790- val nextAgg = if (numCols.nonEmpty) {
791- val aggCols = numCols.map(c => Column (colToAgg(Column (c).expr)).as(c))
792- agg(aggCols.head, aggCols.tail:_* )
793- } else {
794- sqlContext.emptyDataFrame
779+ " count" -> Count ,
780+ " mean" -> Average ,
781+ " stddev" -> stddevExpr,
782+ " min" -> Min ,
783+ " max" -> Max )
784+
785+ val aggCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList
786+
787+ val localAgg = if (aggCols.nonEmpty) {
788+ val aggExprs = statistics.flatMap { case (_, colToAgg) =>
789+ aggCols.map(c => Column (colToAgg(Column (c).expr)).as(c))
795790 }
796- // add statistic name column
797- val nextStat = nextAgg.selectExpr(aggCol(name):: numCols.toList:_* )
798- description = description.unionAll(nextStat)
791+
792+ agg(aggExprs.head, aggExprs.tail: _* ).head().toSeq
793+ .grouped(aggCols.size).toSeq.zip(statistics).map { case (aggregation, (statistic, _)) =>
794+ Row (statistic :: aggregation.toList: _* )
795+ }
796+ } else {
797+ statistics.map { case (name, _) => Row (name) }
799798 }
800- description
799+
800+ val schema = StructType ((" summary" :: aggCols).map(StructField (_, StringType )))
801+ val rowRdd = sqlContext.sparkContext.parallelize(localAgg)
802+ sqlContext.createDataFrame(rowRdd, schema)
801803 }
802804
803805 /**
0 commit comments