diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py
index 0906c9c6b329a..cc1da8e7c1f72 100644
--- a/python/pyspark/sql/group.py
+++ b/python/pyspark/sql/group.py
@@ -211,6 +211,8 @@ def pivot(self, pivot_col, values=None):
         >>> df4.groupBy("year").pivot("course").sum("earnings").collect()
         [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)]
+        >>> df5.groupBy("sales.year").pivot("sales.course").sum("sales.earnings").collect()
+        [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)]
         """
         if values is None:
             jgd = self._jgd.pivot(pivot_col)
@@ -296,6 +298,12 @@ def _test():
         Row(course="dotNET", year=2012, earnings=5000),
         Row(course="dotNET", year=2013, earnings=48000),
         Row(course="Java", year=2013, earnings=30000)]).toDF()
+    globs['df5'] = sc.parallelize([
+        Row(training="expert", sales=Row(course="dotNET", year=2012, earnings=10000)),
+        Row(training="junior", sales=Row(course="Java", year=2012, earnings=20000)),
+        Row(training="expert", sales=Row(course="dotNET", year=2012, earnings=5000)),
+        Row(training="junior", sales=Row(course="dotNET", year=2013, earnings=48000)),
+        Row(training="expert", sales=Row(course="Java", year=2013, earnings=30000))]).toDF()
     (failure_count, test_count) = doctest.testmod(
         pyspark.sql.group, globs=globs,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 8412219b1250b..ed39eac5598d2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -314,7 +314,67 @@ class RelationalGroupedDataset protected[sql](
    * @param pivotColumn Name of the column to pivot.
    * @since 1.6.0
    */
-  def pivot(pivotColumn: String): RelationalGroupedDataset = {
+  def pivot(pivotColumn: String): RelationalGroupedDataset = pivot(Column(pivotColumn))
+
+  /**
+   * Pivots a column of the current `DataFrame` and performs the specified aggregation.
+   * There are two versions of pivot function: one that requires the caller to specify the list
+   * of distinct values to pivot on, and one that does not. The latter is more concise but less
+   * efficient, because Spark needs to first compute the list of distinct values internally.
+   *
+   * {{{
+   *   // Compute the sum of earnings for each year by course with each course as a separate column
+   *   df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")
+   *
+   *   // Or without specifying column values (less efficient)
+   *   df.groupBy("year").pivot("course").sum("earnings")
+   * }}}
+   *
+   * @param pivotColumn Name of the column to pivot.
+   * @param values List of values that will be translated to columns in the output DataFrame.
+   * @since 1.6.0
+   */
+  def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = {
+    pivot(Column(pivotColumn), values)
+  }
+
+  /**
+   * (Java-specific) Pivots a column of the current `DataFrame` and performs the specified
+   * aggregation.
+   *
+   * There are two versions of pivot function: one that requires the caller to specify the list
+   * of distinct values to pivot on, and one that does not. The latter is more concise but less
+   * efficient, because Spark needs to first compute the list of distinct values internally.
+   *
+   * {{{
+   *   // Compute the sum of earnings for each year by course with each course as a separate column
+   *   df.groupBy("year").pivot("course", Arrays.asList("dotNET", "Java")).sum("earnings");
+   *
+   *   // Or without specifying column values (less efficient)
+   *   df.groupBy("year").pivot("course").sum("earnings");
+   * }}}
+   *
+   * @param pivotColumn Name of the column to pivot.
+   * @param values List of values that will be translated to columns in the output DataFrame.
+   * @since 1.6.0
+   */
+  def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset = {
+    pivot(Column(pivotColumn), values)
+  }
+
+  /**
+   * Pivots a column of the current `DataFrame` and performs the specified aggregation.
+   * This is an overloaded version of the `pivot` method with `pivotColumn` of the `String` type.
+   *
+   * {{{
+   *   // Or without specifying column values (less efficient)
+   *   df.groupBy($"year").pivot($"course").sum($"earnings")
+   * }}}
+   *
+   * @param pivotColumn the column to pivot.
+   * @since 2.4.0
+   */
+  def pivot(pivotColumn: Column): RelationalGroupedDataset = {
     // This is to prevent unintended OOM errors when the number of distinct values is large
     val maxValues = df.sparkSession.sessionState.conf.dataFramePivotMaxValues
     // Get the distinct values of the column and sort them so its consistent
@@ -339,29 +399,24 @@ class RelationalGroupedDataset protected[sql](
 
   /**
    * Pivots a column of the current `DataFrame` and performs the specified aggregation.
-   * There are two versions of pivot function: one that requires the caller to specify the list
-   * of distinct values to pivot on, and one that does not. The latter is more concise but less
-   * efficient, because Spark needs to first compute the list of distinct values internally.
+   * This is an overloaded version of the `pivot` method with `pivotColumn` of the `String` type.
    *
    * {{{
    *   // Compute the sum of earnings for each year by course with each course as a separate column
-   *   df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")
-   *
-   *   // Or without specifying column values (less efficient)
-   *   df.groupBy("year").pivot("course").sum("earnings")
+   *   df.groupBy($"year").pivot($"course", Seq("dotNET", "Java")).sum($"earnings")
    * }}}
    *
-   * @param pivotColumn Name of the column to pivot.
+   * @param pivotColumn the column to pivot.
    * @param values List of values that will be translated to columns in the output DataFrame.
-   * @since 1.6.0
+   * @since 2.4.0
    */
-  def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = {
+  def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset = {
     groupType match {
       case RelationalGroupedDataset.GroupByType =>
         new RelationalGroupedDataset(
           df,
           groupingExprs,
-          RelationalGroupedDataset.PivotType(df.resolve(pivotColumn), values.map(Literal.apply)))
+          RelationalGroupedDataset.PivotType(pivotColumn.expr, values.map(Literal.apply)))
       case _: RelationalGroupedDataset.PivotType =>
         throw new UnsupportedOperationException("repeated pivots are not supported")
       case _ =>
@@ -371,25 +426,14 @@ class RelationalGroupedDataset protected[sql](
 
   /**
    * (Java-specific) Pivots a column of the current `DataFrame` and performs the specified
-   * aggregation.
-   *
-   * There are two versions of pivot function: one that requires the caller to specify the list
-   * of distinct values to pivot on, and one that does not. The latter is more concise but less
-   * efficient, because Spark needs to first compute the list of distinct values internally.
+   * aggregation. This is an overloaded version of the `pivot` method with `pivotColumn` of
+   * the `String` type.
    *
-   * {{{
-   *   // Compute the sum of earnings for each year by course with each course as a separate column
-   *   df.groupBy("year").pivot("course", Arrays.asList("dotNET", "Java")).sum("earnings");
-   *
-   *   // Or without specifying column values (less efficient)
-   *   df.groupBy("year").pivot("course").sum("earnings");
-   * }}}
-   *
-   * @param pivotColumn Name of the column to pivot.
+   * @param pivotColumn the column to pivot.
   * @param values List of values that will be translated to columns in the output DataFrame.
-   * @since 1.6.0
+   * @since 2.4.0
    */
-  def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset = {
+  def pivot(pivotColumn: Column, values: java.util.List[Any]): RelationalGroupedDataset = {
     pivot(pivotColumn, values.asScala)
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
index 6ca9ee57e8f49..b972b9ef93e5e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
@@ -27,28 +27,40 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
 
   test("pivot courses") {
+    val expected = Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
     checkAnswer(
       courseSales.groupBy("year").pivot("course", Seq("dotNET", "Java"))
         .agg(sum($"earnings")),
-      Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
-    )
+      expected)
+    checkAnswer(
+      courseSales.groupBy($"year").pivot($"course", Seq("dotNET", "Java"))
+        .agg(sum($"earnings")),
+      expected)
   }
 
   test("pivot year") {
+    val expected = Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
     checkAnswer(
       courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).agg(sum($"earnings")),
-      Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
-    )
+      expected)
+    checkAnswer(
+      courseSales.groupBy('course).pivot('year, Seq(2012, 2013)).agg(sum('earnings)),
+      expected)
   }
 
   test("pivot courses with multiple aggregations") {
+    val expected = Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) ::
+      Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil
     checkAnswer(
       courseSales.groupBy($"year")
         .pivot("course", Seq("dotNET", "Java"))
         .agg(sum($"earnings"), avg($"earnings")),
-      Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) ::
-        Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil
-    )
+      expected)
+    checkAnswer(
+      courseSales.groupBy($"year")
+        .pivot($"course", Seq("dotNET", "Java"))
+        .agg(sum($"earnings"), avg($"earnings")),
+      expected)
   }
 
   test("pivot year with string values (cast)") {
@@ -67,17 +79,23 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
 
   test("pivot courses with no values") {
     // Note Java comes before dotNet in sorted order
+    val expected = Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil
     checkAnswer(
       courseSales.groupBy("year").pivot("course").agg(sum($"earnings")),
-      Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil
-    )
+      expected)
+    checkAnswer(
+      courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings")),
+      expected)
   }
 
   test("pivot year with no values") {
+    val expected = Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
     checkAnswer(
       courseSales.groupBy("course").pivot("year").agg(sum($"earnings")),
-      Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
-    )
+      expected)
+    checkAnswer(
+      courseSales.groupBy($"course").pivot($"year").agg(sum($"earnings")),
+      expected)
   }
 
   test("pivot max values enforced") {
@@ -181,10 +199,13 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
   }
 
   test("pivot with datatype not supported by PivotFirst") {
+    val expected = Row(Seq(1, 1, 1), Seq(2, 2, 2)) :: Nil
     checkAnswer(
       complexData.groupBy().pivot("b", Seq(true, false)).agg(max("a")),
-      Row(Seq(1, 1, 1), Seq(2, 2, 2)) :: Nil
-    )
+      expected)
+    checkAnswer(
+      complexData.groupBy().pivot('b, Seq(true, false)).agg(max('a)),
+      expected)
   }
 
   test("pivot with datatype not supported by PivotFirst 2") {
@@ -246,4 +267,45 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
       checkAnswer(df.select($"a".cast(StringType)), Row(tsWithZone))
     }
   }
+
+  test("SPARK-24722: pivoting nested columns") {
+    val expected = Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
+    val df = trainingSales
+      .groupBy($"sales.year")
+      .pivot(lower($"sales.course"), Seq("dotNet", "Java").map(_.toLowerCase))
+      .agg(sum($"sales.earnings"))
+
+    checkAnswer(df, expected)
+  }
+
+  test("SPARK-24722: references to multiple columns in the pivot column") {
+    val expected = Row(2012, 10000.0) :: Row(2013, 48000.0) :: Nil
+    val df = trainingSales
+      .groupBy($"sales.year")
+      .pivot(concat_ws("-", $"training", $"sales.course"), Seq("Experts-dotNET"))
+      .agg(sum($"sales.earnings"))
+
+    checkAnswer(df, expected)
+  }
+
+  test("SPARK-24722: pivoting by a constant") {
+    val expected = Row(2012, 35000.0) :: Row(2013, 78000.0) :: Nil
+    val df1 = trainingSales
+      .groupBy($"sales.year")
+      .pivot(lit(123), Seq(123))
+      .agg(sum($"sales.earnings"))
+
+    checkAnswer(df1, expected)
+  }
+
+  test("SPARK-24722: aggregate as the pivot column") {
+    val exception = intercept[AnalysisException] {
+      trainingSales
+        .groupBy($"sales.year")
+        .pivot(min($"training"), Seq("Experts"))
+        .agg(sum($"sales.earnings"))
+    }
+
+    assert(exception.getMessage.contains("aggregate functions are not allowed"))
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
index deea9dbb30aae..615923fe02d6c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -268,6 +268,17 @@ private[sql] trait SQLTestData { self =>
     df
   }
 
+  protected lazy val trainingSales: DataFrame = {
+    val df = spark.sparkContext.parallelize(
+      TrainingSales("Experts", CourseSales("dotNET", 2012, 10000)) ::
+      TrainingSales("Experts", CourseSales("JAVA", 2012, 20000)) ::
+      TrainingSales("Dummies", CourseSales("dotNet", 2012, 5000)) ::
+      TrainingSales("Experts", CourseSales("dotNET", 2013, 48000)) ::
+      TrainingSales("Dummies", CourseSales("Java", 2013, 30000)) :: Nil).toDF()
+    df.createOrReplaceTempView("trainingSales")
+    df
+  }
+
   /**
    * Initialize all test data such that all temp tables are properly registered.
    */
@@ -323,4 +334,5 @@ private[sql] object SQLTestData {
   case class Salary(personId: Int, salary: Double)
   case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean)
   case class CourseSales(course: String, year: Int, earnings: Double)
+  case class TrainingSales(training: String, sales: CourseSales)
 }
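
Not part of the patch itself: a minimal, self-contained usage sketch of the Column-based pivot overloads introduced above, for reviewers who want to try them end to end. It assumes a Spark build that contains this change (pivot(Column) and pivot(Column, Seq[Any]), @since 2.4.0) and a local SparkSession; the object name and sample rows are illustrative only and merely mirror the trainingSales fixture added in SQLTestData.scala.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

// Illustrative sketch only; names below are not part of the patch.
object PivotByColumnExample {
  case class CourseSales(course: String, year: Int, earnings: Double)
  case class TrainingSales(training: String, sales: CourseSales)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("pivot-by-column-sketch")
      .getOrCreate()
    import spark.implicits._

    // Sample data shaped like the trainingSales fixture: a nested `sales` struct.
    val df = Seq(
      TrainingSales("Experts", CourseSales("dotNET", 2012, 10000)),
      TrainingSales("Experts", CourseSales("Java", 2012, 20000)),
      TrainingSales("Dummies", CourseSales("dotNET", 2013, 48000)),
      TrainingSales("Dummies", CourseSales("Java", 2013, 30000))).toDF()

    // Pivot on a nested column, supplying the distinct values up front (more efficient).
    df.groupBy($"sales.year")
      .pivot($"sales.course", Seq("dotNET", "Java"))
      .agg(sum($"sales.earnings"))
      .show()

    // Pivot on an arbitrary Column expression; Spark computes the distinct values itself.
    df.groupBy($"sales.year")
      .pivot(concat_ws("-", $"training", $"sales.course"))
      .agg(sum($"sales.earnings"))
      .show()

    spark.stop()
  }
}

Supplying the value list explicitly, as in the first query, avoids the extra job that computes the distinct pivot values, which is the efficiency point made in the method scaladoc above.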