Skip to content

Commit 12a8270

Browse files
committed
Address remaining comments
- Use Literal's for the pivot column values instead of strings. - Change seperator when using multiple aggregates to `_` instead of space. - Some additional unit testing
1 parent 88dd513 commit 12a8270

File tree

4 files changed

+37
-22
lines changed

4 files changed

+37
-22
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ class Analyzer(
260260
val singleAgg = aggregates.size == 1
261261
val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value =>
262262
def ifExpr(expr: Expression) = {
263-
If(EqualTo(pivotColumn, Literal(value)), expr, Literal(null))
263+
If(EqualTo(pivotColumn, value), expr, Literal(null))
264264
}
265265
aggregates.map { aggregate =>
266266
val filteredAggregate = aggregate.transformDown {
@@ -278,7 +278,7 @@ class Analyzer(
278278
throw new AnalysisException(
279279
s"Aggregate expression required for pivot, found '$aggregate'")
280280
}
281-
val name = if (singleAgg) value else value + " " + aggregate.prettyString
281+
val name = if (singleAgg) value.toString else value + "_" + aggregate.prettyString
282282
Alias(filteredAggregate, name)()
283283
}
284284
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -388,13 +388,13 @@ case class Rollup(
388388
case class Pivot(
389389
groupByExprs: Seq[NamedExpression],
390390
pivotColumn: Expression,
391-
pivotValues: Seq[String],
391+
pivotValues: Seq[Literal],
392392
aggregates: Seq[Expression],
393393
child: LogicalPlan) extends UnaryNode {
394394
override def output: Seq[Attribute] = groupByExprs.map(_.toAttribute) ++ aggregates match {
395-
case aggregate :: Nil => pivotValues.map(AttributeReference(_, aggregate.dataType)())
395+
case agg :: Nil => pivotValues.map(value => AttributeReference(value.toString, agg.dataType)())
396396
case _ => pivotValues.flatMap{ value =>
397-
aggregates.map(agg => AttributeReference(value + " " + agg.prettyString, agg.dataType)())
397+
aggregates.map(agg => AttributeReference(value + "_" + agg.prettyString, agg.dataType)())
398398
}
399399
}
400400
}

sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -297,18 +297,25 @@ class GroupedData protected[sql](
297297
* @since 1.6.0
298298
*/
299299
@scala.annotation.varargs
300-
def pivot(pivotColumn: Column, values: String*): GroupedData = groupType match {
300+
def pivot(pivotColumn: Column, values: Column*): GroupedData = groupType match {
301301
case _: GroupedData.PivotType =>
302302
throw new UnsupportedOperationException("repeated pivots are not supported")
303303
case GroupedData.GroupByType =>
304304
val pivotValues = if (values.nonEmpty) {
305-
values
305+
values.map {
306+
case Column(literal: Literal) => literal
307+
case other =>
308+
throw new UnsupportedOperationException(
309+
s"The values of a pivot must be literals, found $other")
310+
}
306311
} else {
307312
// Get the distinct values of the column and sort them so its consistent
308-
df.select(pivotColumn.cast(StringType))
313+
df.select(pivotColumn)
309314
.distinct()
310-
.map(_.getString(0))
311-
.collect().sorted.toSeq
315+
.sort(pivotColumn)
316+
.map(_.get(0))
317+
.collect()
318+
.map(Literal(_)).toSeq
312319
}
313320
new GroupedData(df, groupingExprs, GroupedData.PivotType(pivotColumn.expr, pivotValues))
314321
case _ =>
@@ -330,9 +337,9 @@ class GroupedData protected[sql](
330337
* @since 1.6.0
331338
*/
332339
@scala.annotation.varargs
333-
def pivot(pivotColumn: String, values: String*): GroupedData = {
340+
def pivot(pivotColumn: String, values: Any*): GroupedData = {
334341
val resolvedPivotColumn = Column(df.resolve(pivotColumn))
335-
pivot(resolvedPivotColumn, values: _*)
342+
pivot(resolvedPivotColumn, values.map(functions.lit): _*)
336343
}
337344
}
338345

@@ -372,5 +379,5 @@ private[sql] object GroupedData {
372379
/**
373380
* To indicate it's the PIVOT
374381
*/
375-
private[sql] case class PivotType(pivotCol: Expression, values: Seq[String]) extends GroupType
382+
private[sql] case class PivotType(pivotCol: Expression, values: Seq[Literal]) extends GroupType
376383
}

sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,45 +23,53 @@ import org.apache.spark.sql.test.SharedSQLContext
2323
class DataFramePivotSuite extends QueryTest with SharedSQLContext{
2424
import testImplicits._
2525

26-
test("pivot courses groupBy") {
26+
test("pivot courses with literals") {
2727
checkAnswer(
28-
courseSales.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings")),
28+
courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java"))
29+
.agg(sum($"earnings")),
2930
Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
3031
)
3132
}
3233

33-
test("pivot year groupBy") {
34+
test("pivot year with literals") {
3435
checkAnswer(
35-
courseSales.groupBy($"course").pivot($"year", "2012", "2013").agg(sum($"earnings")),
36+
courseSales.groupBy($"course").pivot($"year", lit(2012), lit(2013)).agg(sum($"earnings")),
3637
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
3738
)
3839
}
3940

40-
test("pivot courses groupBy multiple") {
41+
test("pivot courses with literals and multiple aggregations") {
4142
checkAnswer(
42-
courseSales.groupBy($"year").pivot($"course", "dotNET", "Java")
43+
courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java"))
4344
.agg(sum($"earnings"), avg($"earnings")),
4445
Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) ::
4546
Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil
4647
)
4748
}
4849

49-
test("pivot year groupBy with strings") {
50+
test("pivot year with string values (cast)") {
5051
checkAnswer(
5152
courseSales.groupBy("course").pivot("year", "2012", "2013").sum("earnings"),
5253
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
5354
)
5455
}
5556

56-
test("pivot courses groupBy with no values") {
57+
test("pivot year with int values") {
58+
checkAnswer(
59+
courseSales.groupBy("course").pivot("year", 2012, 2013).sum("earnings"),
60+
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
61+
)
62+
}
63+
64+
test("pivot courses with no values") {
5765
// Note Java comes before dotNet in sorted order
5866
checkAnswer(
5967
courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings")),
6068
Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil
6169
)
6270
}
6371

64-
test("pivot year groupBy with no values") {
72+
test("pivot year with no values") {
6573
checkAnswer(
6674
courseSales.groupBy($"course").pivot($"year").agg(sum($"earnings")),
6775
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil

0 commit comments

Comments
 (0)