[SPARK-24722][SQL] pivot() with Column type argument #21699
```diff
@@ -314,7 +314,67 @@ class RelationalGroupedDataset protected[sql](
    * @param pivotColumn Name of the column to pivot.
    * @since 1.6.0
    */
-  def pivot(pivotColumn: String): RelationalGroupedDataset = {
+  def pivot(pivotColumn: String): RelationalGroupedDataset = pivot(Column(pivotColumn))
+
+  /**
+   * Pivots a column of the current `DataFrame` and performs the specified aggregation.
+   * There are two versions of pivot function: one that requires the caller to specify the list
+   * of distinct values to pivot on, and one that does not. The latter is more concise but less
+   * efficient, because Spark needs to first compute the list of distinct values internally.
+   *
+   * {{{
+   *   // Compute the sum of earnings for each year by course with each course as a separate column
+   *   df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")
+   *
+   *   // Or without specifying column values (less efficient)
+   *   df.groupBy("year").pivot("course").sum("earnings")
+   * }}}
+   *
+   * @param pivotColumn Name of the column to pivot.
+   * @param values List of values that will be translated to columns in the output DataFrame.
+   * @since 1.6.0
+   */
+  def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = {
+    pivot(Column(pivotColumn), values)
+  }
+
+  /**
+   * (Java-specific) Pivots a column of the current `DataFrame` and performs the specified
+   * aggregation.
+   *
+   * There are two versions of pivot function: one that requires the caller to specify the list
+   * of distinct values to pivot on, and one that does not. The latter is more concise but less
+   * efficient, because Spark needs to first compute the list of distinct values internally.
+   *
+   * {{{
+   *   // Compute the sum of earnings for each year by course with each course as a separate column
+   *   df.groupBy("year").pivot("course", Arrays.<Object>asList("dotNET", "Java")).sum("earnings");
+   *
+   *   // Or without specifying column values (less efficient)
+   *   df.groupBy("year").pivot("course").sum("earnings");
+   * }}}
+   *
+   * @param pivotColumn Name of the column to pivot.
+   * @param values List of values that will be translated to columns in the output DataFrame.
+   * @since 1.6.0
+   */
+  def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset = {
+    pivot(Column(pivotColumn), values)
+  }
+
+  /**
+   * Pivots a column of the current `DataFrame` and performs the specified aggregation.
+   * This is an overloaded version of the `pivot` method with `pivotColumn` of the `String` type.
+   *
+   * {{{
+   *   // Or without specifying column values (less efficient)
+   *   df.groupBy($"year").pivot($"course").sum($"earnings")
+   * }}}
+   *
+   * @param pivotColumn the column to pivot.
+   * @since 2.4.0
+   */
+  def pivot(pivotColumn: Column): RelationalGroupedDataset = {
     // This is to prevent unintended OOM errors when the number of distinct values is large
     val maxValues = df.sparkSession.sessionState.conf.dataFramePivotMaxValues
     // Get the distinct values of the column and sort them so it's consistent
```
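A note on the one-line delegation above: `pivot(String)` now routes through the new `Column` overload, so name-based and `Column`-based calls share one code path and produce the same result. A minimal sketch, assuming a local `SparkSession` and made-up data shaped like the suite's `courseSales` fixture:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("pivot-sketch").getOrCreate()
import spark.implicits._

// Hypothetical data mirroring the courseSales test fixture.
val courseSales = Seq(
  ("dotNET", 2012, 10000.0), ("Java", 2012, 20000.0), ("dotNET", 2012, 5000.0),
  ("dotNET", 2013, 48000.0), ("Java", 2013, 30000.0)
).toDF("course", "year", "earnings")

// Existing API: pivot by column name; now delegates to pivot(Column(pivotColumn)).
courseSales.groupBy("year").pivot("course").sum("earnings").show()

// New in this PR: pivot by Column; same result, same code path.
courseSales.groupBy($"year").pivot($"course").sum("earnings").show()
```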
|
|
```diff
@@ -339,29 +399,24 @@ class RelationalGroupedDataset protected[sql](

   /**
    * Pivots a column of the current `DataFrame` and performs the specified aggregation.
-   * There are two versions of pivot function: one that requires the caller to specify the list
-   * of distinct values to pivot on, and one that does not. The latter is more concise but less
-   * efficient, because Spark needs to first compute the list of distinct values internally.
+   * This is an overloaded version of the `pivot` method with `pivotColumn` of the `String` type.
    *
    * {{{
    *   // Compute the sum of earnings for each year by course with each course as a separate column
-   *   df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")
-   *
-   *   // Or without specifying column values (less efficient)
-   *   df.groupBy("year").pivot("course").sum("earnings")
+   *   df.groupBy($"year").pivot($"course", Seq("dotNET", "Java")).sum($"earnings")
    * }}}
    *
-   * @param pivotColumn Name of the column to pivot.
+   * @param pivotColumn the column to pivot.
    * @param values List of values that will be translated to columns in the output DataFrame.
-   * @since 1.6.0
+   * @since 2.4.0
    */
-  def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = {
+  def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset = {
```
|
Member

To make diffs smaller, can you move this under the signature?
```diff
     groupType match {
       case RelationalGroupedDataset.GroupByType =>
         new RelationalGroupedDataset(
           df,
           groupingExprs,
-          RelationalGroupedDataset.PivotType(df.resolve(pivotColumn), values.map(Literal.apply)))
+          RelationalGroupedDataset.PivotType(pivotColumn.expr, values.map(Literal.apply)))
       case _: RelationalGroupedDataset.PivotType =>
         throw new UnsupportedOperationException("repeated pivots are not supported")
       case _ =>
```
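The payoff of the `Column`-typed overload shows up in the hunk above: `PivotType` now receives `pivotColumn.expr` rather than `df.resolve(pivotColumn)`, so the pivot column may be an arbitrary expression instead of a resolvable attribute name. A sketch, reusing the hypothetical `courseSales` frame from the earlier example:

```scala
import org.apache.spark.sql.functions._

// Pivot on an expression; note the supplied values must match the
// post-expression (lower-cased) form, not the raw column values.
courseSales
  .groupBy($"year")
  .pivot(lower($"course"), Seq("dotnet", "java"))
  .agg(sum($"earnings"))
  .show()
```

Supplying post-expression values keeps the efficient explicit-values path available even when the pivot column is computed.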
|
|
```diff
@@ -371,25 +426,14 @@ class RelationalGroupedDataset protected[sql](

   /**
    * (Java-specific) Pivots a column of the current `DataFrame` and performs the specified
-   * aggregation.
-   *
-   * There are two versions of pivot function: one that requires the caller to specify the list
-   * of distinct values to pivot on, and one that does not. The latter is more concise but less
-   * efficient, because Spark needs to first compute the list of distinct values internally.
-   *
-   * {{{
-   *   // Compute the sum of earnings for each year by course with each course as a separate column
-   *   df.groupBy("year").pivot("course", Arrays.<Object>asList("dotNET", "Java")).sum("earnings");
-   *
-   *   // Or without specifying column values (less efficient)
-   *   df.groupBy("year").pivot("course").sum("earnings");
-   * }}}
+   * aggregation. This is an overloaded version of the `pivot` method with `pivotColumn` of
+   * the `String` type.
    *
-   * @param pivotColumn Name of the column to pivot.
+   * @param pivotColumn the column to pivot.
    * @param values List of values that will be translated to columns in the output DataFrame.
-   * @since 1.6.0
+   * @since 2.4.0
    */
-  def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset = {
+  def pivot(pivotColumn: Column, values: java.util.List[Any]): RelationalGroupedDataset = {
```
|
Contributor

I think it's a bad idea to use […]. In Spark 3.0 we should audit all the APIs in […].

Member

If there's a plan for auditing it in 3.0.0, I am okay with going ahead with […]. For the current status, I think the problem here is, this is an overloaded version of pivot and wouldn't necessarily make the differences. I used […].
||
```diff
     pivot(pivotColumn, values.asScala)
   }
```
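Because the Java-specific overload only converts its list with `values.asScala` and delegates, it behaves exactly like the `Seq` version. A sketch in Scala, again reusing the hypothetical `courseSales`; the `java.util.List[Any]` plumbing is illustrative:

```scala
// Arrays.asList[Any] yields the java.util.List[Any] the Java-friendly
// overload expects; it is converted via values.asScala and delegated.
val javaValues: java.util.List[Any] = java.util.Arrays.asList[Any]("dotNET", "Java")
courseSales.groupBy($"year").pivot($"course", javaValues).agg(sum($"earnings"))

// Equivalent Scala-native call:
courseSales.groupBy($"year").pivot($"course", Seq("dotNET", "Java")).agg(sum($"earnings"))
```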
|
|
```diff
@@ -27,28 +27,40 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
   import testImplicits._

   test("pivot courses") {
+    val expected = Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
     checkAnswer(
       courseSales.groupBy("year").pivot("course", Seq("dotNET", "Java"))
         .agg(sum($"earnings")),
-      Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
-    )
+      expected)
+    checkAnswer(
+      courseSales.groupBy($"year").pivot($"course", Seq("dotNET", "Java"))
+        .agg(sum($"earnings")),
+      expected)
   }

   test("pivot year") {
+    val expected = Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
     checkAnswer(
       courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).agg(sum($"earnings")),
-      Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
-    )
+      expected)
+    checkAnswer(
+      courseSales.groupBy('course).pivot('year, Seq(2012, 2013)).agg(sum('earnings)),
+      expected)
   }

   test("pivot courses with multiple aggregations") {
+    val expected = Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) ::
+      Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil
     checkAnswer(
       courseSales.groupBy($"year")
         .pivot("course", Seq("dotNET", "Java"))
         .agg(sum($"earnings"), avg($"earnings")),
-      Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) ::
-        Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil
-    )
+      expected)
+    checkAnswer(
+      courseSales.groupBy($"year")
+        .pivot($"course", Seq("dotNET", "Java"))
+        .agg(sum($"earnings"), avg($"earnings")),
+      expected)
   }

   test("pivot year with string values (cast)") {
```
|
|
```diff
@@ -67,17 +79,23 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {

   test("pivot courses with no values") {
     // Note Java comes before dotNet in sorted order
+    val expected = Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil
     checkAnswer(
       courseSales.groupBy("year").pivot("course").agg(sum($"earnings")),
-      Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil
-    )
+      expected)
+    checkAnswer(
+      courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings")),
+      expected)
   }

   test("pivot year with no values") {
+    val expected = Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
     checkAnswer(
       courseSales.groupBy("course").pivot("year").agg(sum($"earnings")),
-      Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
-    )
+      expected)
+    checkAnswer(
+      courseSales.groupBy($"course").pivot($"year").agg(sum($"earnings")),
+      expected)
   }

   test("pivot max values enforced") {
```
|
|
```diff
@@ -181,10 +199,13 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
   }

   test("pivot with datatype not supported by PivotFirst") {
+    val expected = Row(Seq(1, 1, 1), Seq(2, 2, 2)) :: Nil
     checkAnswer(
       complexData.groupBy().pivot("b", Seq(true, false)).agg(max("a")),
-      Row(Seq(1, 1, 1), Seq(2, 2, 2)) :: Nil
-    )
+      expected)
+    checkAnswer(
+      complexData.groupBy().pivot('b, Seq(true, false)).agg(max('a)),
+      expected)
   }

   test("pivot with datatype not supported by PivotFirst 2") {
```
|
|
```diff
@@ -246,4 +267,45 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
       checkAnswer(df.select($"a".cast(StringType)), Row(tsWithZone))
     }
   }
+
+  test("SPARK-24722: pivoting nested columns") {
```
Member

I usually leave a JIRA number for one specific regression test when it's a bug since that's going to explicitly point out which case was broken ... but not a big deal though.
```diff
+    val expected = Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
+    val df = trainingSales
+      .groupBy($"sales.year")
+      .pivot(lower($"sales.course"), Seq("dotNet", "Java").map(_.toLowerCase))
+      .agg(sum($"sales.earnings"))
+
+    checkAnswer(df, expected)
+  }
+
+  test("SPARK-24722: references to multiple columns in the pivot column") {
+    val expected = Row(2012, 10000.0) :: Row(2013, 48000.0) :: Nil
+    val df = trainingSales
+      .groupBy($"sales.year")
+      .pivot(concat_ws("-", $"training", $"sales.course"), Seq("Experts-dotNET"))
+      .agg(sum($"sales.earnings"))
+
+    checkAnswer(df, expected)
+  }
+
+  test("SPARK-24722: pivoting by a constant") {
+    val expected = Row(2012, 35000.0) :: Row(2013, 78000.0) :: Nil
+    val df1 = trainingSales
+      .groupBy($"sales.year")
+      .pivot(lit(123), Seq(123))
+      .agg(sum($"sales.earnings"))
+
+    checkAnswer(df1, expected)
+  }
+
+  test("SPARK-24722: aggregate as the pivot column") {
+    val exception = intercept[AnalysisException] {
+      trainingSales
+        .groupBy($"sales.year")
+        .pivot(min($"training"), Seq("Experts"))
+        .agg(sum($"sales.earnings"))
+    }
+
+    assert(exception.getMessage.contains("aggregate functions are not allowed"))
+  }
 }
```
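The `trainingSales` fixture these tests rely on is not part of this diff; a hypothetical equivalent consistent with the expected rows above (a top-level `training` column plus a nested `sales` struct) might look like:

```scala
import org.apache.spark.sql.functions._

// Hypothetical fixture: "training" string plus a nested "sales" struct
// with course/year/earnings fields, so $"sales.course" etc. resolve.
case class CourseSales(course: String, year: Int, earnings: Double)
val trainingSales = Seq(
  ("Experts", CourseSales("dotNET", 2012, 10000.0)),
  ("Experts", CourseSales("Java", 2012, 20000.0)),
  ("Dummies", CourseSales("dotNET", 2012, 5000.0)),
  ("Experts", CourseSales("dotNET", 2013, 48000.0)),
  ("Dummies", CourseSales("Java", 2013, 30000.0))
).toDF("training", "sales")

// Nested fields work as the pivot column because pivot now accepts a Column:
trainingSales.groupBy($"sales.year").pivot(lower($"sales.course")).agg(sum($"sales.earnings"))
```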
Shall we note this in `Column` API too, or note that this is an overloaded version of string's?