8 changes: 8 additions & 0 deletions python/pyspark/sql/group.py
@@ -211,6 +211,8 @@ def pivot(self, pivot_col, values=None):

>>> df4.groupBy("year").pivot("course").sum("earnings").collect()
[Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)]
>>> df5.groupBy("sales.year").pivot("sales.course").sum("sales.earnings").collect()
[Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)]
"""
if values is None:
jgd = self._jgd.pivot(pivot_col)
@@ -296,6 +298,12 @@ def _test():
Row(course="dotNET", year=2012, earnings=5000),
Row(course="dotNET", year=2013, earnings=48000),
Row(course="Java", year=2013, earnings=30000)]).toDF()
globs['df5'] = sc.parallelize([
Row(training="expert", sales=Row(course="dotNET", year=2012, earnings=10000)),
Row(training="junior", sales=Row(course="Java", year=2012, earnings=20000)),
Row(training="expert", sales=Row(course="dotNET", year=2012, earnings=5000)),
Row(training="junior", sales=Row(course="dotNET", year=2013, earnings=48000)),
Row(training="expert", sales=Row(course="Java", year=2013, earnings=30000))]).toDF()

(failure_count, test_count) = doctest.testmod(
pyspark.sql.group, globs=globs,
RelationalGroupedDataset.scala
@@ -314,7 +314,67 @@ class RelationalGroupedDataset protected[sql](
* @param pivotColumn Name of the column to pivot.
* @since 1.6.0
*/
def pivot(pivotColumn: String): RelationalGroupedDataset = {
def pivot(pivotColumn: String): RelationalGroupedDataset = pivot(Column(pivotColumn))

/**
* Pivots a column of the current `DataFrame` and performs the specified aggregation.
* There are two versions of pivot function: one that requires the caller to specify the list
* of distinct values to pivot on, and one that does not. The latter is more concise but less
* efficient, because Spark needs to first compute the list of distinct values internally.
*
* {{{
* // Compute the sum of earnings for each year by course with each course as a separate column
* df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")
*
* // Or without specifying column values (less efficient)
* df.groupBy("year").pivot("course").sum("earnings")
* }}}
*
* @param pivotColumn Name of the column to pivot.
* @param values List of values that will be translated to columns in the output DataFrame.
* @since 1.6.0
*/
def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = {
pivot(Column(pivotColumn), values)
}

/**
* (Java-specific) Pivots a column of the current `DataFrame` and performs the specified
* aggregation.
*
* There are two versions of pivot function: one that requires the caller to specify the list
* of distinct values to pivot on, and one that does not. The latter is more concise but less
* efficient, because Spark needs to first compute the list of distinct values internally.
*
* {{{
* // Compute the sum of earnings for each year by course with each course as a separate column
* df.groupBy("year").pivot("course", Arrays.<Object>asList("dotNET", "Java")).sum("earnings");
*
* // Or without specifying column values (less efficient)
* df.groupBy("year").pivot("course").sum("earnings");
* }}}
*
* @param pivotColumn Name of the column to pivot.
* @param values List of values that will be translated to columns in the output DataFrame.
* @since 1.6.0
*/
def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset = {
pivot(Column(pivotColumn), values)
}

/**
* Pivots a column of the current `DataFrame` and performs the specified aggregation.
* This is an overloaded version of the `pivot` method with `pivotColumn` of the `String` type.
*
* {{{
* // Without specifying column values (less efficient)
* df.groupBy($"year").pivot($"course").sum($"earnings")
* }}}
*
* @param pivotColumn the column to pivot.
* @since 2.4.0
*/
def pivot(pivotColumn: Column): RelationalGroupedDataset = {
// This is to prevent unintended OOM errors when the number of distinct values is large
val maxValues = df.sparkSession.sessionState.conf.dataFramePivotMaxValues
// Get the distinct values of the column and sort them so it's consistent
@@ -339,29 +399,24 @@

/**
* Pivots a column of the current `DataFrame` and performs the specified aggregation.
* There are two versions of pivot function: one that requires the caller to specify the list
Review comment (Member): Shall we note this in the Column API too, or note that this is an overloaded version of the String one?

* of distinct values to pivot on, and one that does not. The latter is more concise but less
* efficient, because Spark needs to first compute the list of distinct values internally.
* This is an overloaded version of the `pivot` method with `pivotColumn` of the `String` type.
*
* {{{
* // Compute the sum of earnings for each year by course with each course as a separate column
* df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")
*
* // Or without specifying column values (less efficient)
* df.groupBy("year").pivot("course").sum("earnings")
* df.groupBy($"year").pivot($"course", Seq("dotNET", "Java")).sum($"earnings")
* }}}
*
* @param pivotColumn Name of the column to pivot.
* @param pivotColumn the column to pivot.
* @param values List of values that will be translated to columns in the output DataFrame.
* @since 1.6.0
* @since 2.4.0
*/
def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = {
def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset = {
Review comment (Member): To make diffs smaller, can you move this under the signature def pivot(pivotColumn: String, values: Seq[Any])?

groupType match {
case RelationalGroupedDataset.GroupByType =>
new RelationalGroupedDataset(
df,
groupingExprs,
RelationalGroupedDataset.PivotType(df.resolve(pivotColumn), values.map(Literal.apply)))
RelationalGroupedDataset.PivotType(pivotColumn.expr, values.map(Literal.apply)))
case _: RelationalGroupedDataset.PivotType =>
throw new UnsupportedOperationException("repeated pivots are not supported")
case _ =>
Expand All @@ -371,25 +426,14 @@ class RelationalGroupedDataset protected[sql](

/**
* (Java-specific) Pivots a column of the current `DataFrame` and performs the specified
* aggregation.
*
* There are two versions of pivot function: one that requires the caller to specify the list
* of distinct values to pivot on, and one that does not. The latter is more concise but less
* efficient, because Spark needs to first compute the list of distinct values internally.
* aggregation. This is an overloaded version of the `pivot` method with `pivotColumn` of
* the `String` type.
*
* {{{
* // Compute the sum of earnings for each year by course with each course as a separate column
* df.groupBy("year").pivot("course", Arrays.<Object>asList("dotNET", "Java")).sum("earnings");
*
* // Or without specifying column values (less efficient)
* df.groupBy("year").pivot("course").sum("earnings");
* }}}
*
* @param pivotColumn Name of the column to pivot.
* @param pivotColumn the column to pivot.
* @param values List of values that will be translated to columns in the output DataFrame.
* @since 1.6.0
* @since 2.4.0
*/
def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset = {
def pivot(pivotColumn: Column, values: java.util.List[Any]): RelationalGroupedDataset = {
Review comment (Contributor @cloud-fan, Aug 8, 2018): I think it's a bad idea to use Any in the API. For the existing ones we can't remove it, but we should not add new ones using Any. In Spark 3.0 we should audit all the APIs in functions.scala and make them consistent (e.g. only use Column in the API).

Review comment (Member): If there's a plan for auditing it in 3.0.0, I am okay with going ahead with Column, but the thing is, we should deprecate them first. For the current status, I think the problem here is that this is an overloaded version of pivot and wouldn't necessarily make much of a difference. I used pivot heavily at my previous company, and I am pretty sure pivot(String, actual values) has been a common pattern and usage so far.

pivot(pivotColumn, values.asScala)
}

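For context, a minimal, self-contained sketch of how the new Column-typed pivot overloads documented above can be used. The SparkSession setup, object name, and sample rows below are illustrative assumptions rather than part of the patch; only the pivot calls mirror the new API and the nested-column use case (SPARK-24722) this change targets.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{lower, sum}

object PivotColumnSketch {
  // Nested schema similar to the trainingSales fixture added in SQLTestData below.
  case class CourseSales(course: String, year: Int, earnings: Double)
  case class TrainingSales(training: String, sales: CourseSales)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("pivot-column-sketch").getOrCreate()
    import spark.implicits._

    val trainingSales = Seq(
      TrainingSales("Experts", CourseSales("dotNET", 2012, 10000)),
      TrainingSales("Experts", CourseSales("Java", 2012, 20000)),
      TrainingSales("Dummies", CourseSales("dotNET", 2013, 48000)),
      TrainingSales("Experts", CourseSales("Java", 2013, 30000))).toDF()

    // The String overload now delegates to the Column one, and the Column overload
    // accepts any expression, including nested fields such as sales.course.
    trainingSales
      .groupBy($"sales.year")
      .pivot($"sales.course", Seq("dotNET", "Java")) // explicit values avoid an extra distinct job
      .agg(sum($"sales.earnings"))
      .show()

    // Arbitrary expressions also work as the pivot column, e.g. a derived column:
    trainingSales
      .groupBy($"sales.year")
      .pivot(lower($"sales.course")) // values inferred internally (less efficient)
      .agg(sum($"sales.earnings"))
      .show()

    spark.stop()
  }
}
```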
DataFramePivotSuite.scala
@@ -27,28 +27,40 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
import testImplicits._

test("pivot courses") {
val expected = Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
checkAnswer(
courseSales.groupBy("year").pivot("course", Seq("dotNET", "Java"))
.agg(sum($"earnings")),
Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
)
expected)
checkAnswer(
courseSales.groupBy($"year").pivot($"course", Seq("dotNET", "Java"))
.agg(sum($"earnings")),
expected)
}

test("pivot year") {
val expected = Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
checkAnswer(
courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).agg(sum($"earnings")),
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
)
expected)
checkAnswer(
courseSales.groupBy('course).pivot('year, Seq(2012, 2013)).agg(sum('earnings)),
expected)
}

test("pivot courses with multiple aggregations") {
val expected = Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) ::
Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil
checkAnswer(
courseSales.groupBy($"year")
.pivot("course", Seq("dotNET", "Java"))
.agg(sum($"earnings"), avg($"earnings")),
Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) ::
Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil
)
expected)
checkAnswer(
courseSales.groupBy($"year")
.pivot($"course", Seq("dotNET", "Java"))
.agg(sum($"earnings"), avg($"earnings")),
expected)
}

test("pivot year with string values (cast)") {
@@ -67,17 +79,23 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {

test("pivot courses with no values") {
// Note Java comes before dotNet in sorted order
val expected = Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil
checkAnswer(
courseSales.groupBy("year").pivot("course").agg(sum($"earnings")),
Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil
)
expected)
checkAnswer(
courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings")),
expected)
}

test("pivot year with no values") {
val expected = Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
checkAnswer(
courseSales.groupBy("course").pivot("year").agg(sum($"earnings")),
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
)
expected)
checkAnswer(
courseSales.groupBy($"course").pivot($"year").agg(sum($"earnings")),
expected)
}

test("pivot max values enforced") {
@@ -181,10 +199,13 @@
}

test("pivot with datatype not supported by PivotFirst") {
val expected = Row(Seq(1, 1, 1), Seq(2, 2, 2)) :: Nil
checkAnswer(
complexData.groupBy().pivot("b", Seq(true, false)).agg(max("a")),
Row(Seq(1, 1, 1), Seq(2, 2, 2)) :: Nil
)
expected)
checkAnswer(
complexData.groupBy().pivot('b, Seq(true, false)).agg(max('a)),
expected)
}

test("pivot with datatype not supported by PivotFirst 2") {
@@ -246,4 +267,45 @@
checkAnswer(df.select($"a".cast(StringType)), Row(tsWithZone))
}
}

test("SPARK-24722: pivoting nested columns") {
Review comment (Member): I usually leave a JIRA number only for a specific regression test when it's a bug, since that explicitly points out which case was broken, but it's not a big deal.

val expected = Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
val df = trainingSales
.groupBy($"sales.year")
.pivot(lower($"sales.course"), Seq("dotNet", "Java").map(_.toLowerCase))
.agg(sum($"sales.earnings"))

checkAnswer(df, expected)
}

test("SPARK-24722: references to multiple columns in the pivot column") {
val expected = Row(2012, 10000.0) :: Row(2013, 48000.0) :: Nil
val df = trainingSales
.groupBy($"sales.year")
.pivot(concat_ws("-", $"training", $"sales.course"), Seq("Experts-dotNET"))
.agg(sum($"sales.earnings"))

checkAnswer(df, expected)
}

test("SPARK-24722: pivoting by a constant") {
val expected = Row(2012, 35000.0) :: Row(2013, 78000.0) :: Nil
val df1 = trainingSales
.groupBy($"sales.year")
.pivot(lit(123), Seq(123))
.agg(sum($"sales.earnings"))

checkAnswer(df1, expected)
}

test("SPARK-24722: aggregate as the pivot column") {
val exception = intercept[AnalysisException] {
trainingSales
.groupBy($"sales.year")
.pivot(min($"training"), Seq("Experts"))
.agg(sum($"sales.earnings"))
}

assert(exception.getMessage.contains("aggregate functions are not allowed"))
}
}
SQLTestData.scala
@@ -268,6 +268,17 @@ private[sql] trait SQLTestData { self =>
df
}

protected lazy val trainingSales: DataFrame = {
val df = spark.sparkContext.parallelize(
TrainingSales("Experts", CourseSales("dotNET", 2012, 10000)) ::
TrainingSales("Experts", CourseSales("JAVA", 2012, 20000)) ::
TrainingSales("Dummies", CourseSales("dotNet", 2012, 5000)) ::
TrainingSales("Experts", CourseSales("dotNET", 2013, 48000)) ::
TrainingSales("Dummies", CourseSales("Java", 2013, 30000)) :: Nil).toDF()
df.createOrReplaceTempView("trainingSales")
df
}

/**
* Initialize all test data such that all temp tables are properly registered.
*/
@@ -323,4 +334,5 @@
case class Salary(personId: Int, salary: Double)
case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean)
case class CourseSales(course: String, year: Int, earnings: Double)
case class TrainingSales(training: String, sales: CourseSales)
}
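For reference, a rough sketch (assuming a spark-shell session with the case classes above in scope) of the schema the new trainingSales fixture produces; the nested sales struct is what lets the tests address fields with dot syntax such as $"sales.year". Nullability flags are omitted from the commented output.

```scala
// Hypothetical spark-shell session; `spark` and the case classes above are assumed to be in scope.
import spark.implicits._

val df = Seq(TrainingSales("Experts", CourseSales("dotNET", 2012, 10000))).toDF()
df.printSchema()
// root
//  |-- training: string
//  |-- sales: struct
//  |    |-- course: string
//  |    |-- year: integer
//  |    |-- earnings: double

// Nested fields are addressed through the struct with dot syntax on a Column:
df.select($"sales.year", $"sales.course").show()
```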