Skip to content

Commit 0924765

Browse files
committed
fix and unit test
1 parent 01a7d33 commit 0924765

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -463,14 +463,15 @@ class Analyzer(
463463
.toAggregateExpression()
464464
, "__pivot_" + a.sql)()
465465
}
466-
val secondAgg = Aggregate(groupByExprs, groupByExprs ++ pivotAggs, firstAgg)
466+
val groupByExprsAttr = groupByExprs.map(_.toAttribute)
467+
val secondAgg = Aggregate(groupByExprsAttr, groupByExprsAttr ++ pivotAggs, firstAgg)
467468
val pivotAggAttribute = pivotAggs.map(_.toAttribute)
468469
val pivotOutputs = pivotValues.zipWithIndex.flatMap { case (value, i) =>
469470
aggregates.zip(pivotAggAttribute).map { case (aggregate, pivotAtt) =>
470471
Alias(ExtractValue(pivotAtt, Literal(i), resolver), outputName(value, aggregate))()
471472
}
472473
}
473-
Project(groupByExprs ++ pivotOutputs, secondAgg)
474+
Project(groupByExprsAttr ++ pivotOutputs, secondAgg)
474475
} else {
475476
val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value =>
476477
def ifExpr(expr: Expression) = {

sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,4 +208,12 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{
208208
)
209209
}
210210

211+
test("pivot with column definition in groupby") {
212+
checkAnswer(
213+
courseSales.groupBy(substring(col("course"), 0, 1).as("foo"))
214+
.pivot("year", Seq(2012, 2013))
215+
.sum("earnings"),
216+
Row("d", 15000.0, 48000.0) :: Row("J", 20000.0, 30000.0) :: Nil
217+
)
218+
}
211219
}

0 commit comments

Comments
 (0)