Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,16 @@ groupingSet
;

pivotClause
: PIVOT '(' aggregates=namedExpressionSeq FOR pivotColumn=identifier IN '(' pivotValues+=constant (',' pivotValues+=constant)* ')' ')'
: PIVOT '(' aggregates=namedExpressionSeq FOR pivotColumn IN '(' pivotValues+=pivotValue (',' pivotValues+=pivotValue)* ')' ')'
;

pivotColumn
: identifiers+=identifier
| '(' identifiers+=identifier (',' identifiers+=identifier)* ')'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there any specific reasons to restrict the pivotColumn by identifier? Are there any cases when expressions still don't supported properly with your changes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main reason was that I implemented this pivot SQL support based on ORACLE grammar. Please take a look at https://docs.oracle.com/database/121/SQLRF/img_text/pivot_for_clause.htm. Note that the "column" here is different from "expression" (take this for reference: https://docs.oracle.com/cd/B28359_01/server.111/b28286/expressions002.htm#SQLRF52047).
Another reason was that relaxing it to an "expr" would require a lot more tests and handling of special cases.

;

pivotValue
: expression (AS? identifier)?
;

lateralView
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -509,17 +509,39 @@ class Analyzer(
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case p: Pivot if !p.childrenResolved || !p.aggregates.forall(_.resolved)
|| (p.groupByExprsOpt.isDefined && !p.groupByExprsOpt.get.forall(_.resolved))
|| !p.pivotColumn.resolved => p
|| !p.pivotColumn.resolved || !p.pivotValues.forall(_.resolved) => p
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By which test is the change covered?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before this PR, pivot values can only be single literals (no struct) so they have been converted to Literals in ASTBuilder. Now they are "expressions" and will be handled in this Analyzer rule.

case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) =>
// Check all aggregate expressions.
aggregates.foreach(checkValidAggregateExpression)
// Check all pivot values are literal and match pivot column data type.
val evalPivotValues = pivotValues.map { value =>
val foldable = value match {
case Alias(v, _) => v.foldable
case _ => value.foldable
}
if (!foldable) {
throw new AnalysisException(
s"Literal expressions required for pivot values, found '$value'")
}
if (!Cast.canCast(value.dataType, pivotColumn.dataType)) {
throw new AnalysisException(s"Invalid pivot value '$value': " +
s"value data type ${value.dataType.simpleString} does not match " +
s"pivot column data type ${pivotColumn.dataType.catalogString}")
}
Cast(value, pivotColumn.dataType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow)
}
// Group-by expressions coming from SQL are implicit and need to be deduced.
val groupByExprs = groupByExprsOpt.getOrElse(
(child.outputSet -- aggregates.flatMap(_.references) -- pivotColumn.references).toSeq)
val singleAgg = aggregates.size == 1
def outputName(value: Literal, aggregate: Expression): String = {
val utf8Value = Cast(value, StringType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow)
val stringValue: String = Option(utf8Value).map(_.toString).getOrElse("null")
def outputName(value: Expression, aggregate: Expression): String = {
val stringValue = value match {
case n: NamedExpression => n.name
case _ =>
val utf8Value =
Cast(value, StringType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow)
Option(utf8Value).map(_.toString).getOrElse("null")
}
if (singleAgg) {
stringValue
} else {
Expand All @@ -534,15 +556,10 @@ class Analyzer(
// Since evaluating |pivotValues| if statements for each input row can get slow this is an
// alternate plan that instead uses two steps of aggregation.
val namedAggExps: Seq[NamedExpression] = aggregates.map(a => Alias(a, a.sql)())
val namedPivotCol = pivotColumn match {
case n: NamedExpression => n
case _ => Alias(pivotColumn, "__pivot_col")()
}
val bigGroup = groupByExprs :+ namedPivotCol
val bigGroup = groupByExprs ++ pivotColumn.references
val firstAgg = Aggregate(bigGroup, bigGroup ++ namedAggExps, child)
val castPivotValues = pivotValues.map(Cast(_, pivotColumn.dataType).eval(EmptyRow))
val pivotAggs = namedAggExps.map { a =>
Alias(PivotFirst(namedPivotCol.toAttribute, a.toAttribute, castPivotValues)
Alias(PivotFirst(pivotColumn, a.toAttribute, evalPivotValues)
.toAggregateExpression()
, "__pivot_" + a.sql)()
}
Expand All @@ -557,8 +574,12 @@ class Analyzer(
Project(groupByExprsAttr ++ pivotOutputs, secondAgg)
} else {
val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value =>
def ifExpr(expr: Expression) = {
If(EqualNullSafe(pivotColumn, value), expr, Literal(null))
def ifExpr(e: Expression) = {
If(
EqualNullSafe(
pivotColumn,
Cast(value, pivotColumn.dataType, Some(conf.sessionLocalTimeZone))),
e, Literal(null))
}
aggregates.map { aggregate =>
val filteredAggregate = aggregate.transformDown {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -630,11 +630,29 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
val aggregates = Option(ctx.aggregates).toSeq
.flatMap(_.namedExpression.asScala)
.map(typedVisit[Expression])
val pivotColumn = UnresolvedAttribute.quoted(ctx.pivotColumn.getText)
val pivotValues = ctx.pivotValues.asScala.map(typedVisit[Expression]).map(Literal.apply)
val pivotColumn = if (ctx.pivotColumn.identifiers.size == 1) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there any reasons to handle one pivot column separately? And what happens if size == 0?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cannot be "0" as required by the parser rule. if size == 1, then it's single column as before, otherwise it's a construct.

UnresolvedAttribute.quoted(ctx.pivotColumn.identifier.getText)
} else {
CreateStruct(
ctx.pivotColumn.identifiers.asScala.map(
identifier => UnresolvedAttribute.quoted(identifier.getText)))
}
val pivotValues = ctx.pivotValues.asScala.map(visitPivotValue)
Pivot(None, pivotColumn, pivotValues, aggregates, query)
}

/**
* Create a Pivot column value with or without an alias.
*/
override def visitPivotValue(ctx: PivotValueContext): Expression = withOrigin(ctx) {
val e = expression(ctx.expression)
if (ctx.identifier != null) {
Alias(e, ctx.identifier.getText)()
} else {
e
}
}

/**
* Add a [[Generate]] (Lateral View) to a logical plan.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,7 @@ case class GroupingSets(
case class Pivot(
groupByExprsOpt: Option[Seq[NamedExpression]],
pivotColumn: Expression,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am asking just for my understanding. If you support multiple pivot columns, why it is not declared here explicitly: pivotColumns: Seq[Expression] like for pivotValues?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. Pivot column is one "expression" which can be either 1) a single column reference or 2) a struct of multiple columns. Either way the list of pivot values are many-to-one mapping for the pivot column.

pivotValues: Seq[Literal],
pivotValues: Seq[Expression],
aggregates: Seq[Expression],
child: LogicalPlan) extends UnaryNode {
override lazy val resolved = false // Pivot will be replaced after being resolved.
Expand Down
92 changes: 92 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/pivot.sql
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ create temporary view years as select * from values
(2013, 2)
as years(y, s);

create temporary view yearsWithArray as select * from values
(2012, array(1, 1)),
(2013, array(2, 2))
as yearsWithArray(y, a);

-- pivot courses
SELECT * FROM (
SELECT year, course, earnings FROM courseSales
Expand Down Expand Up @@ -96,13 +101,31 @@ PIVOT (
FOR y IN (2012, 2013)
);

-- pivot with projection and value aliases
SELECT firstYear_s, secondYear_s, firstYear_a, secondYear_a, c FROM (
SELECT year y, course c, earnings e FROM courseSales
)
PIVOT (
sum(e) s, avg(e) a
FOR y IN (2012 as firstYear, 2013 secondYear)
);

-- pivot years with non-aggregate function
SELECT * FROM courseSales
PIVOT (
abs(earnings)
FOR year IN (2012, 2013)
);

-- pivot with one of the expressions as non-aggregate function
SELECT * FROM (
SELECT year, course, earnings FROM courseSales
)
PIVOT (
sum(earnings), year
FOR course IN ('dotNET', 'Java')
);

-- pivot with unresolvable columns
SELECT * FROM (
SELECT course, earnings FROM courseSales
Expand All @@ -129,3 +152,72 @@ PIVOT (
sum(avg(earnings))
FOR course IN ('dotNET', 'Java')
);

-- pivot on multiple pivot columns
SELECT * FROM (
SELECT course, year, earnings, s
FROM courseSales
JOIN years ON year = y
)
PIVOT (
sum(earnings)
FOR (course, year) IN (('dotNET', 2012), ('Java', 2013))
);

-- pivot on multiple pivot columns with aliased values
SELECT * FROM (
SELECT course, year, earnings, s
FROM courseSales
JOIN years ON year = y
)
PIVOT (
sum(earnings)
FOR (course, s) IN (('dotNET', 2) as c1, ('Java', 1) as c2)
);

-- pivot on multiple pivot columns with values of wrong data types
SELECT * FROM (
SELECT course, year, earnings, s
FROM courseSales
JOIN years ON year = y
)
PIVOT (
sum(earnings)
FOR (course, year) IN ('dotNET', 'Java')
);

-- pivot with unresolvable values
SELECT * FROM courseSales
PIVOT (
sum(earnings)
FOR year IN (s, 2013)
);

-- pivot with non-literal values
SELECT * FROM courseSales
PIVOT (
sum(earnings)
FOR year IN (course, 2013)
);

-- pivot on join query with columns of complex data types
SELECT * FROM (
SELECT course, year, a
FROM courseSales
JOIN yearsWithArray ON year = y
)
PIVOT (
min(a)
FOR course IN ('dotNET', 'Java')
);

-- pivot on multiple pivot columns with agg columns of complex data types
SELECT * FROM (
SELECT course, year, y, a
FROM courseSales
JOIN yearsWithArray ON year = y
)
PIVOT (
max(a)
FOR (y, course) IN ((2012, 'dotNET'), (2013, 'Java'))
);
Loading