Skip to content

Commit 88dd513

Browse files
committed
Address comments in Analyzer
1 parent 04d643c commit 88dd513

File tree

1 file changed

+16
-12
lines changed
  • sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis

1 file changed

+16
-12
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 16 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -258,20 +258,25 @@ class Analyzer(
258258
case p: Pivot if !p.childrenResolved => p
259259
case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) =>
260260
val singleAgg = aggregates.size == 1
261-
val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap{ value =>
262-
aggregates.map{ aggregate =>
263-
val filteredAggregate = aggregate.transformDown{
264-
case u: UnaryExpression if u.isInstanceOf[AggregateExpression] =>
265-
u.withNewChildren(Seq(
266-
If(EqualTo(pivotColumn, Literal(value)), u.child, Literal(null))
267-
))
268-
case other: AggregateExpression =>
269-
throw new AnalysisException(
270-
s"Pivot does not support non unary aggregate expressions, found $other")
261+
val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value =>
262+
def ifExpr(expr: Expression) = {
263+
If(EqualTo(pivotColumn, Literal(value)), expr, Literal(null))
264+
}
265+
aggregates.map { aggregate =>
266+
val filteredAggregate = aggregate.transformDown {
267+
// Assumption is the aggregate function ignores nulls. This is true for all current
268+
// AggregateFunction's with the exception of First and Last in their default mode
269+
// (which we handle) and possibly some Hive UDAF's.
270+
case First(expr, _) =>
271+
First(ifExpr(expr), Literal(true))
272+
case Last(expr, _) =>
273+
Last(ifExpr(expr), Literal(true))
274+
case a: AggregateFunction =>
275+
a.withNewChildren(a.children.map(ifExpr))
271276
}
272277
if (filteredAggregate.fastEquals(aggregate)) {
273278
throw new AnalysisException(
274-
s"Unary aggregate expression required for pivot, found '$aggregate'")
279+
s"Aggregate expression required for pivot, found '$aggregate'")
275280
}
276281
val name = if (singleAgg) value else value + " " + aggregate.prettyString
277282
Alias(filteredAggregate, name)()
@@ -1034,7 +1039,6 @@ class Analyzer(
10341039
case p if !p.resolved => p // Skip unresolved nodes.
10351040
case p: Project => p
10361041
case f: Filter => f
1037-
case p: Pivot => p
10381042

10391043
// todo: It's hard to write a general rule to pull out nondeterministic expressions
10401044
// from LogicalPlan, currently we only do it for UnaryNode which has same output

0 commit comments

Comments
 (0)