From fd235025a2099aa1d239f4b84014301ec41c815d Mon Sep 17 00:00:00 2001 From: maryannxue Date: Tue, 3 Jul 2018 13:54:46 -0700 Subject: [PATCH 1/4] spark-24164 --- .../spark/sql/catalyst/parser/SqlBase.g4 | 11 +- .../sql/catalyst/analysis/Analyzer.scala | 41 ++-- .../expressions/namedExpressions.scala | 2 +- .../sql/catalyst/parser/AstBuilder.scala | 22 +- .../plans/logical/basicLogicalOperators.scala | 2 +- .../test/resources/sql-tests/inputs/pivot.sql | 78 ++++++- .../resources/sql-tests/results/pivot.sql.out | 194 ++++++++++++++---- 7 files changed, 294 insertions(+), 56 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index dc95751bf905c..1b43874af6feb 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -414,7 +414,16 @@ groupingSet ; pivotClause - : PIVOT '(' aggregates=namedExpressionSeq FOR pivotColumn=identifier IN '(' pivotValues+=constant (',' pivotValues+=constant)* ')' ')' + : PIVOT '(' aggregates=namedExpressionSeq FOR pivotColumn IN '(' pivotValues+=pivotValue (',' pivotValues+=pivotValue)* ')' ')' + ; + +pivotColumn + : identifiers+=identifier + | '(' identifiers+=identifier (',' identifiers+=identifier)* ')' + ; + +pivotValue + : expression (AS? identifier)? ; lateralView diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index e187133d03b17..ad64a77e9a00c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -506,7 +506,7 @@ class Analyzer( def apply(plan: LogicalPlan): LogicalPlan = plan transform { case p: Pivot if !p.childrenResolved || !p.aggregates.forall(_.resolved) || (p.groupByExprsOpt.isDefined && !p.groupByExprsOpt.get.forall(_.resolved)) - || !p.pivotColumn.resolved => p + || !p.pivotColumn.resolved || !p.pivotValues.forall(_.resolved) => p case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) => // Check all aggregate expressions. aggregates.foreach { e => @@ -515,13 +515,33 @@ class Analyzer( s"Aggregate expression required for pivot, found '$e'") } } + // Check all pivot values are literal and match pivot column data type. + val evalPivotValues = pivotValues.map { value => + if (!Cast.canCast(value.dataType, pivotColumn.dataType)) { + throw new AnalysisException(s"Invalid pivot value '$value': " + + s"value data type ${value.dataType.simpleString} does not match " + + s"pivot column data type ${pivotColumn.dataType.simpleString}") + } + try { + Cast(value, pivotColumn.dataType).eval(EmptyRow) + } catch { + case _: UnsupportedOperationException => + throw new AnalysisException( + s"Literal expressions required for pivot values, found '$value'") + } + } // Group-by expressions coming from SQL are implicit and need to be deduced. val groupByExprs = groupByExprsOpt.getOrElse( (child.outputSet -- aggregates.flatMap(_.references) -- pivotColumn.references).toSeq) val singleAgg = aggregates.size == 1 - def outputName(value: Literal, aggregate: Expression): String = { - val utf8Value = Cast(value, StringType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow) - val stringValue: String = Option(utf8Value).map(_.toString).getOrElse("null") + def outputName(value: Expression, aggregate: Expression): String = { + val stringValue = value match { + case n: NamedExpression => n.name + case _ => + val utf8Value = + Cast(value, StringType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow) + Option(utf8Value).map(_.toString).getOrElse("null") + } if (singleAgg) { stringValue } else { @@ -536,15 +556,10 @@ class Analyzer( // Since evaluating |pivotValues| if statements for each input row can get slow this is an // alternate plan that instead uses two steps of aggregation. val namedAggExps: Seq[NamedExpression] = aggregates.map(a => Alias(a, a.sql)()) - val namedPivotCol = pivotColumn match { - case n: NamedExpression => n - case _ => Alias(pivotColumn, "__pivot_col")() - } - val bigGroup = groupByExprs :+ namedPivotCol + val bigGroup = groupByExprs ++ pivotColumn.references val firstAgg = Aggregate(bigGroup, bigGroup ++ namedAggExps, child) - val castPivotValues = pivotValues.map(Cast(_, pivotColumn.dataType).eval(EmptyRow)) val pivotAggs = namedAggExps.map { a => - Alias(PivotFirst(namedPivotCol.toAttribute, a.toAttribute, castPivotValues) + Alias(PivotFirst(pivotColumn, a.toAttribute, evalPivotValues) .toAggregateExpression() , "__pivot_" + a.sql)() } @@ -559,8 +574,8 @@ class Analyzer( Project(groupByExprsAttr ++ pivotOutputs, secondAgg) } else { val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value => - def ifExpr(expr: Expression) = { - If(EqualNullSafe(pivotColumn, value), expr, Literal(null)) + def ifExpr(e: Expression) = { + If(EqualNullSafe(pivotColumn, Cast(value, pivotColumn.dataType)), e, Literal(null)) } aggregates.map { aggregate => val filteredAggregate = aggregate.transformDown { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 8df870468c2ad..789a2ae594b04 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -122,7 +122,7 @@ abstract class Attribute extends LeafExpression with NamedExpression with NullIn * @param exprId A globally unique id used to check if an [[AttributeReference]] refers to this * alias. Auto-assigned if left blank. * @param qualifier An optional string that can be used to referred to this attribute in a fully - * qualified way. Consider the examples tableName.name, subQueryAlias.name. + * qualified way. Consider the examples tableName.name, subQueryAlias.name.li * tableName and subQueryAlias are possible qualifiers. * @param explicitMetadata Explicit metadata associated with this alias that overwrites child's. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 383ebde3229d6..46cde7492f85a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -630,11 +630,29 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging val aggregates = Option(ctx.aggregates).toSeq .flatMap(_.namedExpression.asScala) .map(typedVisit[Expression]) - val pivotColumn = UnresolvedAttribute.quoted(ctx.pivotColumn.getText) - val pivotValues = ctx.pivotValues.asScala.map(typedVisit[Expression]).map(Literal.apply) + val pivotColumn = if (ctx.pivotColumn.identifiers.size == 1) { + UnresolvedAttribute.quoted(ctx.pivotColumn.identifier.getText) + } else { + CreateStruct( + ctx.pivotColumn.identifiers.asScala.map( + identifier => UnresolvedAttribute.quoted(identifier.getText))) + } + val pivotValues = ctx.pivotValues.asScala.map(visitPivotValue) Pivot(None, pivotColumn, pivotValues, aggregates, query) } + /** + * Create a Pivot column value with or without an alias. + */ + override def visitPivotValue(ctx: PivotValueContext): Expression = withOrigin(ctx) { + val e = expression(ctx.expression) + if (ctx.identifier != null) { + Alias(e, ctx.identifier.getText)() + } else { + e + } + } + /** * Add a [[Generate]] (Lateral View) to a logical plan. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 3bf32ef7884e5..ea5a9b8ed5542 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -700,7 +700,7 @@ case class GroupingSets( case class Pivot( groupByExprsOpt: Option[Seq[NamedExpression]], pivotColumn: Expression, - pivotValues: Seq[Literal], + pivotValues: Seq[Expression], aggregates: Seq[Expression], child: LogicalPlan) extends UnaryNode { override lazy val resolved = false // Pivot will be replaced after being resolved. diff --git a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql index 01dea6c81c11b..651046cc2b2e6 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql @@ -11,6 +11,11 @@ create temporary view years as select * from values (2013, 2) as years(y, s); +create temporary view yearsWithArray as select * from values + (2012, array(1, 1)), + (2013, array(2, 2)) + as yearsWithArray(y, a); + -- pivot courses SELECT * FROM ( SELECT year, course, earnings FROM courseSales @@ -88,12 +93,12 @@ PIVOT ( ); -- pivot with aliases and projection -SELECT 2012_s, 2013_s, 2012_a, 2013_a, c FROM ( +SELECT firstYear_s, secondYear_s, firstYear_a, secondYear_a, c FROM ( SELECT year y, course c, earnings e FROM courseSales ) PIVOT ( sum(e) s, avg(e) a - FOR y IN (2012, 2013) + FOR y IN (2012 as firstYear, 2013 secondYear) ); -- pivot years with non-aggregate function @@ -111,3 +116,72 @@ PIVOT ( sum(earnings) FOR year IN (2012, 2013) ); + +-- pivot on multiple pivot columns +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, year) IN (('dotNET', 2012), ('Java', 2013)) +); + +-- pivot on multiple pivot columns with aliased values +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, s) IN (('dotNET', 2) as c1, ('Java', 1) as c2) +); + +-- pivot on multiple pivot columns with values of wrong data types +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, year) IN ('dotNET', 'Java') +); + +-- pivot with unresolvable values +SELECT * FROM courseSales +PIVOT ( + sum(earnings) + FOR year IN (s, 2013) +); + +-- pivot with non-literal values +SELECT * FROM courseSales +PIVOT ( + sum(earnings) + FOR year IN (course, 2013) +); + +-- pivot on join query with columns of complex data types +SELECT * FROM ( + SELECT course, year, a + FROM courseSales + JOIN yearsWithArray ON year = y +) +PIVOT ( + min(a) + FOR course IN ('dotNET', 'Java') +); + +-- pivot on multiple pivot columns with agg columns of complex data types +SELECT * FROM ( + SELECT course, year, y, a + FROM courseSales + JOIN yearsWithArray ON year = y +) +PIVOT ( + max(a) + FOR (y, course) IN ((2012, 'dotNET'), (2013, 'Java')) +); diff --git a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out index 85e3488990e20..fd6926ac12e5f 100644 --- a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 13 +-- Number of queries: 21 -- !query 0 @@ -28,6 +28,17 @@ struct<> -- !query 2 +create temporary view yearsWithArray as select * from values + (2012, array(1, 1)), + (2013, array(2, 2)) + as yearsWithArray(y, a) +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 SELECT * FROM ( SELECT year, course, earnings FROM courseSales ) @@ -35,27 +46,27 @@ PIVOT ( sum(earnings) FOR course IN ('dotNET', 'Java') ) --- !query 2 schema +-- !query 3 schema struct --- !query 2 output +-- !query 3 output 2012 15000 20000 2013 48000 30000 --- !query 3 +-- !query 4 SELECT * FROM courseSales PIVOT ( sum(earnings) FOR year IN (2012, 2013) ) --- !query 3 schema +-- !query 4 schema struct --- !query 3 output +-- !query 4 output Java 20000 30000 dotNET 15000 48000 --- !query 4 +-- !query 5 SELECT * FROM ( SELECT year, course, earnings FROM courseSales ) @@ -63,14 +74,14 @@ PIVOT ( sum(earnings), avg(earnings) FOR course IN ('dotNET', 'Java') ) --- !query 4 schema +-- !query 5 schema struct --- !query 4 output +-- !query 5 output 2012 15000 7500.0 20000 20000.0 2013 48000 48000.0 30000 30000.0 --- !query 5 +-- !query 6 SELECT * FROM ( SELECT course, earnings FROM courseSales ) @@ -78,13 +89,13 @@ PIVOT ( sum(earnings) FOR course IN ('dotNET', 'Java') ) --- !query 5 schema +-- !query 6 schema struct --- !query 5 output +-- !query 6 output 63000 50000 --- !query 6 +-- !query 7 SELECT * FROM ( SELECT year, course, earnings FROM courseSales ) @@ -92,13 +103,13 @@ PIVOT ( sum(earnings), min(year) FOR course IN ('dotNET', 'Java') ) --- !query 6 schema +-- !query 7 schema struct --- !query 6 output +-- !query 7 output 63000 2012 50000 2012 --- !query 7 +-- !query 8 SELECT * FROM ( SELECT course, year, earnings, s FROM courseSales @@ -108,16 +119,16 @@ PIVOT ( sum(earnings) FOR s IN (1, 2) ) --- !query 7 schema +-- !query 8 schema struct --- !query 7 output +-- !query 8 output Java 2012 20000 NULL Java 2013 NULL 30000 dotNET 2012 15000 NULL dotNET 2013 NULL 48000 --- !query 8 +-- !query 9 SELECT * FROM ( SELECT course, year, earnings, s FROM courseSales @@ -127,14 +138,14 @@ PIVOT ( sum(earnings), min(s) FOR course IN ('dotNET', 'Java') ) --- !query 8 schema +-- !query 9 schema struct --- !query 8 output +-- !query 9 output 2012 15000 1 20000 1 2013 48000 2 30000 2 --- !query 9 +-- !query 10 SELECT * FROM ( SELECT course, year, earnings, s FROM courseSales @@ -144,42 +155,42 @@ PIVOT ( sum(earnings * s) FOR course IN ('dotNET', 'Java') ) --- !query 9 schema +-- !query 10 schema struct --- !query 9 output +-- !query 10 output 2012 15000 20000 2013 96000 60000 --- !query 10 -SELECT 2012_s, 2013_s, 2012_a, 2013_a, c FROM ( +-- !query 11 +SELECT firstYear_s, secondYear_s, firstYear_a, secondYear_a, c FROM ( SELECT year y, course c, earnings e FROM courseSales ) PIVOT ( sum(e) s, avg(e) a - FOR y IN (2012, 2013) + FOR y IN (2012 as firstYear, 2013 secondYear) ) --- !query 10 schema -struct<2012_s:bigint,2013_s:bigint,2012_a:double,2013_a:double,c:string> --- !query 10 output +-- !query 11 schema +struct +-- !query 11 output 15000 48000 7500.0 48000.0 dotNET 20000 30000 20000.0 30000.0 Java --- !query 11 +-- !query 12 SELECT * FROM courseSales PIVOT ( abs(earnings) FOR year IN (2012, 2013) ) --- !query 11 schema +-- !query 12 schema struct<> --- !query 11 output +-- !query 12 output org.apache.spark.sql.AnalysisException Aggregate expression required for pivot, found 'abs(earnings#x)'; --- !query 12 +-- !query 13 SELECT * FROM ( SELECT course, earnings FROM courseSales ) @@ -187,8 +198,119 @@ PIVOT ( sum(earnings) FOR year IN (2012, 2013) ) --- !query 12 schema +-- !query 13 schema struct<> --- !query 12 output +-- !query 13 output org.apache.spark.sql.AnalysisException cannot resolve '`year`' given input columns: [__auto_generated_subquery_name.course, __auto_generated_subquery_name.earnings]; line 4 pos 0 + + +-- !query 14 +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, year) IN (('dotNET', 2012), ('Java', 2013)) +) +-- !query 14 schema +struct +-- !query 14 output +1 15000 NULL +2 NULL 30000 + + +-- !query 15 +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, s) IN (('dotNET', 2) as c1, ('Java', 1) as c2) +) +-- !query 15 schema +struct +-- !query 15 output +2012 NULL 20000 +2013 48000 NULL + + +-- !query 16 +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, year) IN ('dotNET', 'Java') +) +-- !query 16 schema +struct<> +-- !query 16 output +org.apache.spark.sql.AnalysisException +Invalid pivot value 'dotNET': value data type string does not match pivot column data type struct; + + +-- !query 17 +SELECT * FROM courseSales +PIVOT ( + sum(earnings) + FOR year IN (s, 2013) +) +-- !query 17 schema +struct<> +-- !query 17 output +org.apache.spark.sql.AnalysisException +cannot resolve '`s`' given input columns: [coursesales.course, coursesales.year, coursesales.earnings]; line 4 pos 15 + + +-- !query 18 +SELECT * FROM courseSales +PIVOT ( + sum(earnings) + FOR year IN (course, 2013) +) +-- !query 18 schema +struct<> +-- !query 18 output +org.apache.spark.sql.AnalysisException +Literal expressions required for pivot values, found 'course#x'; + + +-- !query 19 +SELECT * FROM ( + SELECT course, year, a + FROM courseSales + JOIN yearsWithArray ON year = y +) +PIVOT ( + min(a) + FOR course IN ('dotNET', 'Java') +) +-- !query 19 schema +struct<> +-- !query 19 output +org.apache.spark.SparkException +Job 17 cancelled because SparkContext was shut down + + +-- !query 20 +SELECT * FROM ( + SELECT course, year, y, a + FROM courseSales + JOIN yearsWithArray ON year = y +) +PIVOT ( + max(a) + FOR (y, course) IN ((2012, 'dotNET'), (2013, 'Java')) +) +-- !query 20 schema +struct<> +-- !query 20 output +org.apache.spark.SparkException +Exception thrown in awaitResult: From 942a30dfc0fd070c067aa8d157075909610d3aaa Mon Sep 17 00:00:00 2001 From: maryannxue Date: Thu, 5 Jul 2018 15:23:59 -0700 Subject: [PATCH 2/4] revert accidental changes --- .../spark/sql/catalyst/expressions/namedExpressions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 789a2ae594b04..8df870468c2ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -122,7 +122,7 @@ abstract class Attribute extends LeafExpression with NamedExpression with NullIn * @param exprId A globally unique id used to check if an [[AttributeReference]] refers to this * alias. Auto-assigned if left blank. * @param qualifier An optional string that can be used to referred to this attribute in a fully - * qualified way. Consider the examples tableName.name, subQueryAlias.name.li + * qualified way. Consider the examples tableName.name, subQueryAlias.name. * tableName and subQueryAlias are possible qualifiers. * @param explicitMetadata Explicit metadata associated with this alias that overwrites child's. */ From 58e99ab82ef5fa43c28e0a8471926e987bc2e404 Mon Sep 17 00:00:00 2001 From: maryannxue Date: Mon, 9 Jul 2018 18:17:29 -0700 Subject: [PATCH 3/4] fix ref file --- .../test/resources/sql-tests/results/pivot.sql.out | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out index fd6926ac12e5f..798db6a522d6d 100644 --- a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out @@ -293,10 +293,10 @@ PIVOT ( FOR course IN ('dotNET', 'Java') ) -- !query 19 schema -struct<> +struct,Java:array> -- !query 19 output -org.apache.spark.SparkException -Job 17 cancelled because SparkContext was shut down +2012 [1,1] [1,1] +2013 [2,2] [2,2] -- !query 20 @@ -310,7 +310,7 @@ PIVOT ( FOR (y, course) IN ((2012, 'dotNET'), (2013, 'Java')) ) -- !query 20 schema -struct<> +struct,[2013, Java]:array> -- !query 20 output -org.apache.spark.SparkException -Exception thrown in awaitResult: +2012 [1,1] NULL +2013 NULL [2,2] From d468821db03644c1535aa3aece55c7bcb1b211c2 Mon Sep 17 00:00:00 2001 From: maryannxue Date: Mon, 9 Jul 2018 21:04:23 -0700 Subject: [PATCH 4/4] address review comments --- .../sql/catalyst/analysis/Analyzer.scala | 24 +++-- .../test/resources/sql-tests/inputs/pivot.sql | 18 ++++ .../resources/sql-tests/results/pivot.sql.out | 90 ++++++++++++------- 3 files changed, 93 insertions(+), 39 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ad64a77e9a00c..6ebc488732980 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -517,18 +517,20 @@ class Analyzer( } // Check all pivot values are literal and match pivot column data type. val evalPivotValues = pivotValues.map { value => + val foldable = value match { + case Alias(v, _) => v.foldable + case _ => value.foldable + } + if (!foldable) { + throw new AnalysisException( + s"Literal expressions required for pivot values, found '$value'") + } if (!Cast.canCast(value.dataType, pivotColumn.dataType)) { throw new AnalysisException(s"Invalid pivot value '$value': " + s"value data type ${value.dataType.simpleString} does not match " + - s"pivot column data type ${pivotColumn.dataType.simpleString}") - } - try { - Cast(value, pivotColumn.dataType).eval(EmptyRow) - } catch { - case _: UnsupportedOperationException => - throw new AnalysisException( - s"Literal expressions required for pivot values, found '$value'") + s"pivot column data type ${pivotColumn.dataType.catalogString}") } + Cast(value, pivotColumn.dataType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow) } // Group-by expressions coming from SQL are implicit and need to be deduced. val groupByExprs = groupByExprsOpt.getOrElse( @@ -575,7 +577,11 @@ class Analyzer( } else { val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value => def ifExpr(e: Expression) = { - If(EqualNullSafe(pivotColumn, Cast(value, pivotColumn.dataType)), e, Literal(null)) + If( + EqualNullSafe( + pivotColumn, + Cast(value, pivotColumn.dataType, Some(conf.sessionLocalTimeZone))), + e, Literal(null)) } aggregates.map { aggregate => val filteredAggregate = aggregate.transformDown { diff --git a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql index 651046cc2b2e6..8eb3806a99682 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql @@ -93,6 +93,15 @@ PIVOT ( ); -- pivot with aliases and projection +SELECT 2012_s, 2013_s, 2012_a, 2013_a, c FROM ( + SELECT year y, course c, earnings e FROM courseSales +) +PIVOT ( + sum(e) s, avg(e) a + FOR y IN (2012, 2013) +); + +-- pivot with projection and value aliases SELECT firstYear_s, secondYear_s, firstYear_a, secondYear_a, c FROM ( SELECT year y, course c, earnings e FROM courseSales ) @@ -108,6 +117,15 @@ PIVOT ( FOR year IN (2012, 2013) ); +-- pivot with one of the expressions as non-aggregate function +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings), year + FOR course IN ('dotNET', 'Java') +); + -- pivot with unresolvable columns SELECT * FROM ( SELECT course, earnings FROM courseSales diff --git a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out index 798db6a522d6d..accab94dd04c4 100644 --- a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 21 +-- Number of queries: 23 -- !query 0 @@ -163,34 +163,64 @@ struct -- !query 11 -SELECT firstYear_s, secondYear_s, firstYear_a, secondYear_a, c FROM ( +SELECT 2012_s, 2013_s, 2012_a, 2013_a, c FROM ( SELECT year y, course c, earnings e FROM courseSales ) PIVOT ( sum(e) s, avg(e) a - FOR y IN (2012 as firstYear, 2013 secondYear) + FOR y IN (2012, 2013) ) -- !query 11 schema -struct +struct<2012_s:bigint,2013_s:bigint,2012_a:double,2013_a:double,c:string> -- !query 11 output 15000 48000 7500.0 48000.0 dotNET 20000 30000 20000.0 30000.0 Java -- !query 12 +SELECT firstYear_s, secondYear_s, firstYear_a, secondYear_a, c FROM ( + SELECT year y, course c, earnings e FROM courseSales +) +PIVOT ( + sum(e) s, avg(e) a + FOR y IN (2012 as firstYear, 2013 secondYear) +) +-- !query 12 schema +struct +-- !query 12 output +15000 48000 7500.0 48000.0 dotNET +20000 30000 20000.0 30000.0 Java + + +-- !query 13 SELECT * FROM courseSales PIVOT ( abs(earnings) FOR year IN (2012, 2013) ) --- !query 12 schema +-- !query 13 schema struct<> --- !query 12 output +-- !query 13 output org.apache.spark.sql.AnalysisException Aggregate expression required for pivot, found 'abs(earnings#x)'; --- !query 13 +-- !query 14 +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings), year + FOR course IN ('dotNET', 'Java') +) +-- !query 14 schema +struct<> +-- !query 14 output +org.apache.spark.sql.AnalysisException +Aggregate expression required for pivot, found 'year#x'; + + +-- !query 15 SELECT * FROM ( SELECT course, earnings FROM courseSales ) @@ -198,14 +228,14 @@ PIVOT ( sum(earnings) FOR year IN (2012, 2013) ) --- !query 13 schema +-- !query 15 schema struct<> --- !query 13 output +-- !query 15 output org.apache.spark.sql.AnalysisException cannot resolve '`year`' given input columns: [__auto_generated_subquery_name.course, __auto_generated_subquery_name.earnings]; line 4 pos 0 --- !query 14 +-- !query 16 SELECT * FROM ( SELECT course, year, earnings, s FROM courseSales @@ -215,14 +245,14 @@ PIVOT ( sum(earnings) FOR (course, year) IN (('dotNET', 2012), ('Java', 2013)) ) --- !query 14 schema +-- !query 16 schema struct --- !query 14 output +-- !query 16 output 1 15000 NULL 2 NULL 30000 --- !query 15 +-- !query 17 SELECT * FROM ( SELECT course, year, earnings, s FROM courseSales @@ -232,14 +262,14 @@ PIVOT ( sum(earnings) FOR (course, s) IN (('dotNET', 2) as c1, ('Java', 1) as c2) ) --- !query 15 schema +-- !query 17 schema struct --- !query 15 output +-- !query 17 output 2012 NULL 20000 2013 48000 NULL --- !query 16 +-- !query 18 SELECT * FROM ( SELECT course, year, earnings, s FROM courseSales @@ -249,40 +279,40 @@ PIVOT ( sum(earnings) FOR (course, year) IN ('dotNET', 'Java') ) --- !query 16 schema +-- !query 18 schema struct<> --- !query 16 output +-- !query 18 output org.apache.spark.sql.AnalysisException Invalid pivot value 'dotNET': value data type string does not match pivot column data type struct; --- !query 17 +-- !query 19 SELECT * FROM courseSales PIVOT ( sum(earnings) FOR year IN (s, 2013) ) --- !query 17 schema +-- !query 19 schema struct<> --- !query 17 output +-- !query 19 output org.apache.spark.sql.AnalysisException cannot resolve '`s`' given input columns: [coursesales.course, coursesales.year, coursesales.earnings]; line 4 pos 15 --- !query 18 +-- !query 20 SELECT * FROM courseSales PIVOT ( sum(earnings) FOR year IN (course, 2013) ) --- !query 18 schema +-- !query 20 schema struct<> --- !query 18 output +-- !query 20 output org.apache.spark.sql.AnalysisException Literal expressions required for pivot values, found 'course#x'; --- !query 19 +-- !query 21 SELECT * FROM ( SELECT course, year, a FROM courseSales @@ -292,14 +322,14 @@ PIVOT ( min(a) FOR course IN ('dotNET', 'Java') ) --- !query 19 schema +-- !query 21 schema struct,Java:array> --- !query 19 output +-- !query 21 output 2012 [1,1] [1,1] 2013 [2,2] [2,2] --- !query 20 +-- !query 22 SELECT * FROM ( SELECT course, year, y, a FROM courseSales @@ -309,8 +339,8 @@ PIVOT ( max(a) FOR (y, course) IN ((2012, 'dotNET'), (2013, 'Java')) ) --- !query 20 schema +-- !query 22 schema struct,[2013, Java]:array> --- !query 20 output +-- !query 22 output 2012 [1,1] NULL 2013 NULL [2,2]