@@ -24,8 +24,8 @@ import org.apache.spark.annotation.Experimental
2424import org .apache .spark .sql .catalyst .analysis .{UnresolvedFunction , UnresolvedAlias , UnresolvedAttribute , Star }
2525import org .apache .spark .sql .catalyst .expressions ._
2626import org .apache .spark .sql .catalyst .expressions .aggregate ._
27- import org .apache .spark .sql .catalyst .plans .logical .{Rollup , Cube , Aggregate }
28- import org .apache .spark .sql .types .NumericType
27+ import org .apache .spark .sql .catalyst .plans .logical .{Pivot , Rollup , Cube , Aggregate }
28+ import org .apache .spark .sql .types .{ StringType , NumericType }
2929
3030
3131/**
@@ -50,14 +50,8 @@ class GroupedData protected[sql](
5050 aggExprs
5151 }
5252
53- val aliasedAgg = aggregates.map {
54- // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we
55- // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to
56- // make it a NamedExpression.
57- case u : UnresolvedAttribute => UnresolvedAlias (u)
58- case expr : NamedExpression => expr
59- case expr : Expression => Alias (expr, expr.prettyString)()
60- }
53+ val aliasedAgg = aggregates.map(alias)
54+
6155 groupType match {
6256 case GroupedData .GroupByType =>
6357 DataFrame (
@@ -68,9 +62,22 @@ class GroupedData protected[sql](
6862 case GroupedData .CubeType =>
6963 DataFrame (
7064 df.sqlContext, Cube (groupingExprs, df.logicalPlan, aliasedAgg))
65+ case GroupedData .PivotType (pivotCol, values) =>
66+ val aliasedGrps = groupingExprs.map(alias)
67+ DataFrame (
68+ df.sqlContext, Pivot (aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan))
7169 }
7270 }
7371
72+ // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we
73+ // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to
74+ // make it a NamedExpression.
75+ private [this ] def alias (expr : Expression ): NamedExpression = expr match {
76+ case u : UnresolvedAttribute => UnresolvedAlias (u)
77+ case expr : NamedExpression => expr
78+ case expr : Expression => Alias (expr, expr.prettyString)()
79+ }
80+
7481 private [this ] def aggregateNumericColumns (colNames : String * )(f : Expression => AggregateFunction )
7582 : DataFrame = {
7683
@@ -273,6 +280,77 @@ class GroupedData protected[sql](
273280 def sum (colNames : String * ): DataFrame = {
274281 aggregateNumericColumns(colNames : _* )(Sum )
275282 }
283+
284+ /**
285+ * (Scala-specific) Pivots a column of the current [[DataFrame ]] and preform the specified
286+ * aggregation.
287+ * {{{
288+ * // Compute the sum of earnings for each year by course with each course as a separate column
289+ * df.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings"))
290+ * // Or without specifying column values
291+ * df.groupBy($"year").pivot($"course").agg(sum($"earnings"))
292+ * }}}
293+ * @param pivotColumn Column to pivot
294+ * @param values Optional list of values of pivotColumn that will be translated to columns in the
295+ * output data frame. If values are not provided the method with do an immediate
296+ * call to .distinct() on the pivot column.
297+ * @since 1.6.0
298+ */
299+ @ scala.annotation.varargs
300+ def pivot (pivotColumn : Column , values : Column * ): GroupedData = groupType match {
301+ case _ : GroupedData .PivotType =>
302+ throw new UnsupportedOperationException (" repeated pivots are not supported" )
303+ case GroupedData .GroupByType =>
304+ val pivotValues = if (values.nonEmpty) {
305+ values.map {
306+ case Column (literal : Literal ) => literal
307+ case other =>
308+ throw new UnsupportedOperationException (
309+ s " The values of a pivot must be literals, found $other" )
310+ }
311+ } else {
312+ // This is to prevent unintended OOM errors when the number of distinct values is large
313+ val maxValues = df.sqlContext.conf.getConf(SQLConf .DATAFRAME_PIVOT_MAX_VALUES )
314+ // Get the distinct values of the column and sort them so its consistent
315+ val values = df.select(pivotColumn)
316+ .distinct()
317+ .sort(pivotColumn)
318+ .map(_.get(0 ))
319+ .take(maxValues + 1 )
320+ .map(Literal (_)).toSeq
321+ if (values.length > maxValues) {
322+ throw new RuntimeException (
323+ s " The pivot column $pivotColumn has more than $maxValues distinct values, " +
324+ " this could indicate an error. " +
325+ " If this was intended, set \" " + SQLConf .DATAFRAME_PIVOT_MAX_VALUES .key + " \" " +
326+ s " to at least the number of distinct values of the pivot column. " )
327+ }
328+ values
329+ }
330+ new GroupedData (df, groupingExprs, GroupedData .PivotType (pivotColumn.expr, pivotValues))
331+ case _ =>
332+ throw new UnsupportedOperationException (" pivot is only supported after a groupBy" )
333+ }
334+
335+ /**
336+ * Pivots a column of the current [[DataFrame ]] and preform the specified aggregation.
337+ * {{{
338+ * // Compute the sum of earnings for each year by course with each course as a separate column
339+ * df.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings")
340+ * // Or without specifying column values
341+ * df.groupBy("year").pivot("course").sum("earnings")
342+ * }}}
343+ * @param pivotColumn Column to pivot
344+ * @param values Optional list of values of pivotColumn that will be translated to columns in the
345+ * output data frame. If values are not provided the method with do an immediate
346+ * call to .distinct() on the pivot column.
347+ * @since 1.6.0
348+ */
349+ @ scala.annotation.varargs
350+ def pivot (pivotColumn : String , values : Any * ): GroupedData = {
351+ val resolvedPivotColumn = Column (df.resolve(pivotColumn))
352+ pivot(resolvedPivotColumn, values.map(functions.lit): _* )
353+ }
276354}
277355
278356
@@ -307,4 +385,9 @@ private[sql] object GroupedData {
307385 * To indicate it's the ROLLUP
308386 */
309387 private [sql] object RollupType extends GroupType
388+
389+ /**
390+ * To indicate it's the PIVOT
391+ */
392+ private [sql] case class PivotType (pivotCol : Expression , values : Seq [Literal ]) extends GroupType
310393}
0 commit comments