diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index dc95751bf905c..1b43874af6feb 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -414,7 +414,16 @@ groupingSet ; pivotClause - : PIVOT '(' aggregates=namedExpressionSeq FOR pivotColumn=identifier IN '(' pivotValues+=constant (',' pivotValues+=constant)* ')' ')' + : PIVOT '(' aggregates=namedExpressionSeq FOR pivotColumn IN '(' pivotValues+=pivotValue (',' pivotValues+=pivotValue)* ')' ')' + ; + +pivotColumn + : identifiers+=identifier + | '(' identifiers+=identifier (',' identifiers+=identifier)* ')' + ; + +pivotValue + : expression (AS? identifier)? ; lateralView diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 36f14ccdc6989..59c371eb1557b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -509,17 +509,39 @@ class Analyzer( def apply(plan: LogicalPlan): LogicalPlan = plan transform { case p: Pivot if !p.childrenResolved || !p.aggregates.forall(_.resolved) || (p.groupByExprsOpt.isDefined && !p.groupByExprsOpt.get.forall(_.resolved)) - || !p.pivotColumn.resolved => p + || !p.pivotColumn.resolved || !p.pivotValues.forall(_.resolved) => p case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) => // Check all aggregate expressions. aggregates.foreach(checkValidAggregateExpression) + // Check all pivot values are literal and match pivot column data type. + val evalPivotValues = pivotValues.map { value => + val foldable = value match { + case Alias(v, _) => v.foldable + case _ => value.foldable + } + if (!foldable) { + throw new AnalysisException( + s"Literal expressions required for pivot values, found '$value'") + } + if (!Cast.canCast(value.dataType, pivotColumn.dataType)) { + throw new AnalysisException(s"Invalid pivot value '$value': " + + s"value data type ${value.dataType.simpleString} does not match " + + s"pivot column data type ${pivotColumn.dataType.catalogString}") + } + Cast(value, pivotColumn.dataType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow) + } // Group-by expressions coming from SQL are implicit and need to be deduced. val groupByExprs = groupByExprsOpt.getOrElse( (child.outputSet -- aggregates.flatMap(_.references) -- pivotColumn.references).toSeq) val singleAgg = aggregates.size == 1 - def outputName(value: Literal, aggregate: Expression): String = { - val utf8Value = Cast(value, StringType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow) - val stringValue: String = Option(utf8Value).map(_.toString).getOrElse("null") + def outputName(value: Expression, aggregate: Expression): String = { + val stringValue = value match { + case n: NamedExpression => n.name + case _ => + val utf8Value = + Cast(value, StringType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow) + Option(utf8Value).map(_.toString).getOrElse("null") + } if (singleAgg) { stringValue } else { @@ -534,15 +556,10 @@ class Analyzer( // Since evaluating |pivotValues| if statements for each input row can get slow this is an // alternate plan that instead uses two steps of aggregation. val namedAggExps: Seq[NamedExpression] = aggregates.map(a => Alias(a, a.sql)()) - val namedPivotCol = pivotColumn match { - case n: NamedExpression => n - case _ => Alias(pivotColumn, "__pivot_col")() - } - val bigGroup = groupByExprs :+ namedPivotCol + val bigGroup = groupByExprs ++ pivotColumn.references val firstAgg = Aggregate(bigGroup, bigGroup ++ namedAggExps, child) - val castPivotValues = pivotValues.map(Cast(_, pivotColumn.dataType).eval(EmptyRow)) val pivotAggs = namedAggExps.map { a => - Alias(PivotFirst(namedPivotCol.toAttribute, a.toAttribute, castPivotValues) + Alias(PivotFirst(pivotColumn, a.toAttribute, evalPivotValues) .toAggregateExpression() , "__pivot_" + a.sql)() } @@ -557,8 +574,12 @@ class Analyzer( Project(groupByExprsAttr ++ pivotOutputs, secondAgg) } else { val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value => - def ifExpr(expr: Expression) = { - If(EqualNullSafe(pivotColumn, value), expr, Literal(null)) + def ifExpr(e: Expression) = { + If( + EqualNullSafe( + pivotColumn, + Cast(value, pivotColumn.dataType, Some(conf.sessionLocalTimeZone))), + e, Literal(null)) } aggregates.map { aggregate => val filteredAggregate = aggregate.transformDown { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index f398b479dc273..49f578a24aaeb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -630,11 +630,29 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging val aggregates = Option(ctx.aggregates).toSeq .flatMap(_.namedExpression.asScala) .map(typedVisit[Expression]) - val pivotColumn = UnresolvedAttribute.quoted(ctx.pivotColumn.getText) - val pivotValues = ctx.pivotValues.asScala.map(typedVisit[Expression]).map(Literal.apply) + val pivotColumn = if (ctx.pivotColumn.identifiers.size == 1) { + UnresolvedAttribute.quoted(ctx.pivotColumn.identifier.getText) + } else { + CreateStruct( + ctx.pivotColumn.identifiers.asScala.map( + identifier => UnresolvedAttribute.quoted(identifier.getText))) + } + val pivotValues = ctx.pivotValues.asScala.map(visitPivotValue) Pivot(None, pivotColumn, pivotValues, aggregates, query) } + /** + * Create a Pivot column value with or without an alias. + */ + override def visitPivotValue(ctx: PivotValueContext): Expression = withOrigin(ctx) { + val e = expression(ctx.expression) + if (ctx.identifier != null) { + Alias(e, ctx.identifier.getText)() + } else { + e + } + } + /** * Add a [[Generate]] (Lateral View) to a logical plan. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 3bf32ef7884e5..ea5a9b8ed5542 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -700,7 +700,7 @@ case class GroupingSets( case class Pivot( groupByExprsOpt: Option[Seq[NamedExpression]], pivotColumn: Expression, - pivotValues: Seq[Literal], + pivotValues: Seq[Expression], aggregates: Seq[Expression], child: LogicalPlan) extends UnaryNode { override lazy val resolved = false // Pivot will be replaced after being resolved. diff --git a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql index b3d53adfbebe7..a6c8d4854ff38 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql @@ -11,6 +11,11 @@ create temporary view years as select * from values (2013, 2) as years(y, s); +create temporary view yearsWithArray as select * from values + (2012, array(1, 1)), + (2013, array(2, 2)) + as yearsWithArray(y, a); + -- pivot courses SELECT * FROM ( SELECT year, course, earnings FROM courseSales @@ -96,6 +101,15 @@ PIVOT ( FOR y IN (2012, 2013) ); +-- pivot with projection and value aliases +SELECT firstYear_s, secondYear_s, firstYear_a, secondYear_a, c FROM ( + SELECT year y, course c, earnings e FROM courseSales +) +PIVOT ( + sum(e) s, avg(e) a + FOR y IN (2012 as firstYear, 2013 secondYear) +); + -- pivot years with non-aggregate function SELECT * FROM courseSales PIVOT ( @@ -103,6 +117,15 @@ PIVOT ( FOR year IN (2012, 2013) ); +-- pivot with one of the expressions as non-aggregate function +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings), year + FOR course IN ('dotNET', 'Java') +); + -- pivot with unresolvable columns SELECT * FROM ( SELECT course, earnings FROM courseSales @@ -129,3 +152,72 @@ PIVOT ( sum(avg(earnings)) FOR course IN ('dotNET', 'Java') ); + +-- pivot on multiple pivot columns +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, year) IN (('dotNET', 2012), ('Java', 2013)) +); + +-- pivot on multiple pivot columns with aliased values +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, s) IN (('dotNET', 2) as c1, ('Java', 1) as c2) +); + +-- pivot on multiple pivot columns with values of wrong data types +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, year) IN ('dotNET', 'Java') +); + +-- pivot with unresolvable values +SELECT * FROM courseSales +PIVOT ( + sum(earnings) + FOR year IN (s, 2013) +); + +-- pivot with non-literal values +SELECT * FROM courseSales +PIVOT ( + sum(earnings) + FOR year IN (course, 2013) +); + +-- pivot on join query with columns of complex data types +SELECT * FROM ( + SELECT course, year, a + FROM courseSales + JOIN yearsWithArray ON year = y +) +PIVOT ( + min(a) + FOR course IN ('dotNET', 'Java') +); + +-- pivot on multiple pivot columns with agg columns of complex data types +SELECT * FROM ( + SELECT course, year, y, a + FROM courseSales + JOIN yearsWithArray ON year = y +) +PIVOT ( + max(a) + FOR (y, course) IN ((2012, 'dotNET'), (2013, 'Java')) +); diff --git a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out index 922d8b9f9152c..6bb51b946f960 100644 --- a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 15 +-- Number of queries: 25 -- !query 0 @@ -28,6 +28,17 @@ struct<> -- !query 2 +create temporary view yearsWithArray as select * from values + (2012, array(1, 1)), + (2013, array(2, 2)) + as yearsWithArray(y, a) +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 SELECT * FROM ( SELECT year, course, earnings FROM courseSales ) @@ -35,27 +46,27 @@ PIVOT ( sum(earnings) FOR course IN ('dotNET', 'Java') ) --- !query 2 schema +-- !query 3 schema struct --- !query 2 output +-- !query 3 output 2012 15000 20000 2013 48000 30000 --- !query 3 +-- !query 4 SELECT * FROM courseSales PIVOT ( sum(earnings) FOR year IN (2012, 2013) ) --- !query 3 schema +-- !query 4 schema struct --- !query 3 output +-- !query 4 output Java 20000 30000 dotNET 15000 48000 --- !query 4 +-- !query 5 SELECT * FROM ( SELECT year, course, earnings FROM courseSales ) @@ -63,14 +74,14 @@ PIVOT ( sum(earnings), avg(earnings) FOR course IN ('dotNET', 'Java') ) --- !query 4 schema +-- !query 5 schema struct --- !query 4 output +-- !query 5 output 2012 15000 7500.0 20000 20000.0 2013 48000 48000.0 30000 30000.0 --- !query 5 +-- !query 6 SELECT * FROM ( SELECT course, earnings FROM courseSales ) @@ -78,13 +89,13 @@ PIVOT ( sum(earnings) FOR course IN ('dotNET', 'Java') ) --- !query 5 schema +-- !query 6 schema struct --- !query 5 output +-- !query 6 output 63000 50000 --- !query 6 +-- !query 7 SELECT * FROM ( SELECT year, course, earnings FROM courseSales ) @@ -92,13 +103,13 @@ PIVOT ( sum(earnings), min(year) FOR course IN ('dotNET', 'Java') ) --- !query 6 schema +-- !query 7 schema struct --- !query 6 output +-- !query 7 output 63000 2012 50000 2012 --- !query 7 +-- !query 8 SELECT * FROM ( SELECT course, year, earnings, s FROM courseSales @@ -108,16 +119,16 @@ PIVOT ( sum(earnings) FOR s IN (1, 2) ) --- !query 7 schema +-- !query 8 schema struct --- !query 7 output +-- !query 8 output Java 2012 20000 NULL Java 2013 NULL 30000 dotNET 2012 15000 NULL dotNET 2013 NULL 48000 --- !query 8 +-- !query 9 SELECT * FROM ( SELECT course, year, earnings, s FROM courseSales @@ -127,14 +138,14 @@ PIVOT ( sum(earnings), min(s) FOR course IN ('dotNET', 'Java') ) --- !query 8 schema +-- !query 9 schema struct --- !query 8 output +-- !query 9 output 2012 15000 1 20000 1 2013 48000 2 30000 2 --- !query 9 +-- !query 10 SELECT * FROM ( SELECT course, year, earnings, s FROM courseSales @@ -144,14 +155,14 @@ PIVOT ( sum(earnings * s) FOR course IN ('dotNET', 'Java') ) --- !query 9 schema +-- !query 10 schema struct --- !query 9 output +-- !query 10 output 2012 15000 20000 2013 96000 60000 --- !query 10 +-- !query 11 SELECT 2012_s, 2013_s, 2012_a, 2013_a, c FROM ( SELECT year y, course c, earnings e FROM courseSales ) @@ -159,27 +170,57 @@ PIVOT ( sum(e) s, avg(e) a FOR y IN (2012, 2013) ) --- !query 10 schema +-- !query 11 schema struct<2012_s:bigint,2013_s:bigint,2012_a:double,2013_a:double,c:string> --- !query 10 output +-- !query 11 output 15000 48000 7500.0 48000.0 dotNET 20000 30000 20000.0 30000.0 Java --- !query 11 +-- !query 12 +SELECT firstYear_s, secondYear_s, firstYear_a, secondYear_a, c FROM ( + SELECT year y, course c, earnings e FROM courseSales +) +PIVOT ( + sum(e) s, avg(e) a + FOR y IN (2012 as firstYear, 2013 secondYear) +) +-- !query 12 schema +struct +-- !query 12 output +15000 48000 7500.0 48000.0 dotNET +20000 30000 20000.0 30000.0 Java + + +-- !query 13 SELECT * FROM courseSales PIVOT ( abs(earnings) FOR year IN (2012, 2013) ) --- !query 11 schema +-- !query 13 schema struct<> --- !query 11 output +-- !query 13 output org.apache.spark.sql.AnalysisException Aggregate expression required for pivot, but 'coursesales.`earnings`' did not appear in any aggregate function.; --- !query 12 +-- !query 14 +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings), year + FOR course IN ('dotNET', 'Java') +) +-- !query 14 schema +struct<> +-- !query 14 output +org.apache.spark.sql.AnalysisException +Aggregate expression required for pivot, but '__auto_generated_subquery_name.`year`' did not appear in any aggregate function.; + + +-- !query 15 SELECT * FROM ( SELECT course, earnings FROM courseSales ) @@ -187,14 +228,14 @@ PIVOT ( sum(earnings) FOR year IN (2012, 2013) ) --- !query 12 schema +-- !query 15 schema struct<> --- !query 12 output +-- !query 15 output org.apache.spark.sql.AnalysisException cannot resolve '`year`' given input columns: [__auto_generated_subquery_name.course, __auto_generated_subquery_name.earnings]; line 4 pos 0 --- !query 13 +-- !query 16 SELECT * FROM ( SELECT year, course, earnings FROM courseSales ) @@ -202,14 +243,14 @@ PIVOT ( ceil(sum(earnings)), avg(earnings) + 1 as a1 FOR course IN ('dotNET', 'Java') ) --- !query 13 schema +-- !query 16 schema struct --- !query 13 output +-- !query 16 output 2012 15000 7501.0 20000 20001.0 2013 48000 48001.0 30000 30001.0 --- !query 14 +-- !query 17 SELECT * FROM ( SELECT year, course, earnings FROM courseSales ) @@ -217,8 +258,119 @@ PIVOT ( sum(avg(earnings)) FOR course IN ('dotNET', 'Java') ) --- !query 14 schema +-- !query 17 schema struct<> --- !query 14 output +-- !query 17 output org.apache.spark.sql.AnalysisException It is not allowed to use an aggregate function in the argument of another aggregate function. Please use the inner aggregate function in a sub-query.; + + +-- !query 18 +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, year) IN (('dotNET', 2012), ('Java', 2013)) +) +-- !query 18 schema +struct +-- !query 18 output +1 15000 NULL +2 NULL 30000 + + +-- !query 19 +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, s) IN (('dotNET', 2) as c1, ('Java', 1) as c2) +) +-- !query 19 schema +struct +-- !query 19 output +2012 NULL 20000 +2013 48000 NULL + + +-- !query 20 +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, year) IN ('dotNET', 'Java') +) +-- !query 20 schema +struct<> +-- !query 20 output +org.apache.spark.sql.AnalysisException +Invalid pivot value 'dotNET': value data type string does not match pivot column data type struct; + + +-- !query 21 +SELECT * FROM courseSales +PIVOT ( + sum(earnings) + FOR year IN (s, 2013) +) +-- !query 21 schema +struct<> +-- !query 21 output +org.apache.spark.sql.AnalysisException +cannot resolve '`s`' given input columns: [coursesales.course, coursesales.year, coursesales.earnings]; line 4 pos 15 + + +-- !query 22 +SELECT * FROM courseSales +PIVOT ( + sum(earnings) + FOR year IN (course, 2013) +) +-- !query 22 schema +struct<> +-- !query 22 output +org.apache.spark.sql.AnalysisException +Literal expressions required for pivot values, found 'course#x'; + + +-- !query 23 +SELECT * FROM ( + SELECT course, year, a + FROM courseSales + JOIN yearsWithArray ON year = y +) +PIVOT ( + min(a) + FOR course IN ('dotNET', 'Java') +) +-- !query 23 schema +struct,Java:array> +-- !query 23 output +2012 [1,1] [1,1] +2013 [2,2] [2,2] + + +-- !query 24 +SELECT * FROM ( + SELECT course, year, y, a + FROM courseSales + JOIN yearsWithArray ON year = y +) +PIVOT ( + max(a) + FOR (y, course) IN ((2012, 'dotNET'), (2013, 'Java')) +) +-- !query 24 schema +struct,[2013, Java]:array> +-- !query 24 output +2012 [1,1] NULL +2013 NULL [2,2]