From 63b36adfb0a8e8a487200e84e49df2da523023ed Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Mon, 18 Jan 2016 12:50:32 +0800 Subject: [PATCH 1/5] support order by index and group by index refactor style wasn't caught in local style forgot to enable style checking for test in local override resolved refactor a version refactor --- .../sql/catalyst/analysis/Analyzer.scala | 61 +++++++++++++++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 49 ++++++++++++--- 2 files changed, 100 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 004c1caaffec..5762a2da90bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -73,6 +73,7 @@ class Analyzer( Batch("Resolution", fixedPoint, ResolveRelations :: ResolveReferences :: + ResolveIndexReferences :: ResolveGroupingAnalytics :: ResolvePivot :: ResolveUpCast :: @@ -359,6 +360,62 @@ class Analyzer( } } + /** + * Replace the index with the related attribute for order by and group by. 
+ */ + object ResolveIndexReferences extends Rule[LogicalPlan] { + + def indexToColumn( + sortOrder: SortOrder, + child: LogicalPlan, + index: Int, + direction: SortDirection): SortOrder = { + val orderNodes = child.output + if (index > 0 && index <= orderNodes.size) { + SortOrder(orderNodes(index - 1), direction) + } else { + throw new UnresolvedException(sortOrder, + s"""Order by position: $index does not exist \n + |The Select List is indexed from 1 to ${orderNodes.size}""".stripMargin) + } + } + + def indexToColumn( + agg: Aggregate, + child: LogicalPlan, + index: Int): Expression = { + val attributes = child.output + if (index > 0 && index <= attributes.size) { + attributes(index - 1) + } else { + throw new UnresolvedException(agg, + s"""Order by position: $index does not exist \n + |The Select List is indexed from 1 to ${attributes.size}""".stripMargin) + } + } + + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case s @ Sort(orders, global, child) if child.resolved => + val newOrders = orders map { + case s @ SortOrder(IntegerLiteral(index), direction) => + indexToColumn(s, child, index, direction) + case s @ SortOrder(UnaryMinus(IntegerLiteral(index)), direction) => + indexToColumn(s, child, -index, direction) + case other => other + } + Sort(newOrders, global, child) + case a @ Aggregate(groups, aggs, child) => + val newGroups = groups map { + case IntegerLiteral(index) => + indexToColumn(a, child, index) + case UnaryMinus(IntegerLiteral(index)) => + indexToColumn(a, child, -index) + case other => other + } + Aggregate(newGroups, aggs, child) + } + } + /** * Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from * a logical plan node's children. 
@@ -497,6 +554,10 @@ class Analyzer( ordering.map(order => resolveExpression(order, child).asInstanceOf[SortOrder]) Sort(newOrdering, global, child) + case s @ Sort(ordering, global, child) if child.resolved && !s.resolved => + val newOrdering = resolveSortOrders(ordering, child, throws = false) + Sort(newOrdering, global, child) + // A special case for Generate, because the output of Generate should not be resolved by // ResolveReferences. Attributes in the output will be resolved by ResolveGenerate. case g @ Generate(generator, join, outer, qualifier, output, child) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index b3e179755a19..81a63da9ec7d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -21,10 +21,9 @@ import java.math.MathContext import java.sql.Timestamp import org.apache.spark.AccumulatorSuite -import org.apache.spark.sql.catalyst.CatalystQl -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry -import org.apache.spark.sql.catalyst.parser.ParserConf -import org.apache.spark.sql.execution.{aggregate, SparkQl} +import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, UnresolvedException} +import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.joins.{CartesianProduct, SortMergeJoin} import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.functions._ @@ -462,21 +461,36 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("literal in agg grouping expressions") { checkAnswer( - sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), - Seq(Row(1, 2), Row(2, 2), Row(3, 2))) + sql("SELECT SUM(a) FROM testData2 GROUP BY 2"), + Seq(Row(6), Row(6))) + checkAnswer( - sql("SELECT 
a, count(2) FROM testData2 GROUP BY a, 2"), + sql("SELECT a, SUM(b) FROM testData2 GROUP BY 1"), + Seq(Row(1, 3), Row(2, 3), Row(3, 3))) + + checkAnswer( + sql("SELECT a, SUM(b) FROM testData2 GROUP BY 1"), + Seq(Row(1, 3), Row(2, 3), Row(3, 3))) + + checkAnswer( + sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), Seq(Row(1, 2), Row(2, 2), Row(3, 2))) checkAnswer( sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"), sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) + checkAnswer( - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"), - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) + sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"), + Seq(Row(1, 1), Row(1, 1), Row(2, 1), Row(2, 1), Row(3, 1), Row(3, 1))) + checkAnswer( sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"), - sql("SELECT 1, 2, sum(b) FROM testData2")) + sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY a, b")) + + checkAnswer( + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"), + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) } test("aggregates with nulls") { @@ -499,6 +513,21 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row("1")) } + test("order by number index") { + intercept[UnresolvedException[SortOrder]] { + sql("SELECT * FROM testData2 ORDER BY -1 DESC, b ASC").collect() + } + checkAnswer( + sql("SELECT * FROM testData2 ORDER BY 1 DESC"), + Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2, 2), Row(1, 1), Row(1, 2))) + checkAnswer( + sql("SELECT * FROM testData2 ORDER BY 1 DESC, b ASC"), + sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC")) + checkAnswer( + sql("SELECT * FROM testData2 ORDER BY 1 ASC, b ASC"), + Seq(Row(1, 1), Row(1, 2), Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2))) + } + def sortTest(): Unit = { checkAnswer( sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC"), From 0b1e8a770033a32f58c51db02c0c5bd267a4d076 Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Thu, 21 Jan 2016 16:40:04 +0800 Subject: [PATCH 2/5] 
refactor --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 4 ---- 1 file changed, 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 5762a2da90bf..3a89d4ae388b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -554,10 +554,6 @@ class Analyzer( ordering.map(order => resolveExpression(order, child).asInstanceOf[SortOrder]) Sort(newOrdering, global, child) - case s @ Sort(ordering, global, child) if child.resolved && !s.resolved => - val newOrdering = resolveSortOrders(ordering, child, throws = false) - Sort(newOrdering, global, child) - // A special case for Generate, because the output of Generate should not be resolved by // ResolveReferences. Attributes in the output will be resolved by ResolveGenerate. 
case g @ Generate(generator, join, outer, qualifier, output, child) From 6f880d407967a8079c06e657447c2cd386412f45 Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Thu, 21 Jan 2016 16:44:29 +0800 Subject: [PATCH 3/5] redundant --- .../src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 4 ---- 1 file changed, 4 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 81a63da9ec7d..5609e3ce8b43 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -468,10 +468,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { sql("SELECT a, SUM(b) FROM testData2 GROUP BY 1"), Seq(Row(1, 3), Row(2, 3), Row(3, 3))) - checkAnswer( - sql("SELECT a, SUM(b) FROM testData2 GROUP BY 1"), - Seq(Row(1, 3), Row(2, 3), Row(3, 3))) - checkAnswer( sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), Seq(Row(1, 2), Row(2, 2), Row(3, 2))) From 145482fbaea08046030f86687cf2d6c1812822d7 Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Mon, 25 Jan 2016 12:16:28 +0800 Subject: [PATCH 4/5] typo --- .../scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 3a89d4ae388b..edab0b288988 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -389,7 +389,7 @@ class Analyzer( attributes(index - 1) } else { throw new UnresolvedException(agg, - s"""Order by position: $index does not exist \n + s"""Aggregate by position: $index does not exist \n |The Select List is indexed from 1 to ${attributes.size}""".stripMargin) } } 
From 66c54b14f8e57d181bf78f293fe0de0d899cb2ba Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Wed, 17 Feb 2016 15:33:30 +0800 Subject: [PATCH 5/5] address comments --- .../sql/catalyst/analysis/Analyzer.scala | 93 +++++++------------ .../sql/catalyst/planning/patterns.scala | 12 +++ .../org/apache/spark/sql/SQLQuerySuite.scala | 12 +++ 3 files changed, 60 insertions(+), 57 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index edab0b288988..ebc0b6777035 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatal import org.apache.spark.sql.catalyst.encoders.OuterScopes import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.planning.IntegerIndex import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -73,7 +74,6 @@ class Analyzer( Batch("Resolution", fixedPoint, ResolveRelations :: ResolveReferences :: - ResolveIndexReferences :: ResolveGroupingAnalytics :: ResolvePivot :: ResolveUpCast :: @@ -360,62 +360,6 @@ class Analyzer( } } - /** - * Replace the index with the related attribute for order by and group by. 
- */ - object ResolveIndexReferences extends Rule[LogicalPlan] { - - def indexToColumn( - sortOrder: SortOrder, - child: LogicalPlan, - index: Int, - direction: SortDirection): SortOrder = { - val orderNodes = child.output - if (index > 0 && index <= orderNodes.size) { - SortOrder(orderNodes(index - 1), direction) - } else { - throw new UnresolvedException(sortOrder, - s"""Order by position: $index does not exist \n - |The Select List is indexed from 1 to ${orderNodes.size}""".stripMargin) - } - } - - def indexToColumn( - agg: Aggregate, - child: LogicalPlan, - index: Int): Expression = { - val attributes = child.output - if (index > 0 && index <= attributes.size) { - attributes(index - 1) - } else { - throw new UnresolvedException(agg, - s"""Aggregate by position: $index does not exist \n - |The Select List is indexed from 1 to ${attributes.size}""".stripMargin) - } - } - - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case s @ Sort(orders, global, child) if child.resolved => - val newOrders = orders map { - case s @ SortOrder(IntegerLiteral(index), direction) => - indexToColumn(s, child, index, direction) - case s @ SortOrder(UnaryMinus(IntegerLiteral(index)), direction) => - indexToColumn(s, child, -index, direction) - case other => other - } - Sort(newOrders, global, child) - case a @ Aggregate(groups, aggs, child) => - val newGroups = groups map { - case IntegerLiteral(index) => - indexToColumn(a, child, index) - case UnaryMinus(IntegerLiteral(index)) => - indexToColumn(a, child, -index) - case other => other - } - Aggregate(newGroups, aggs, child) - } - } - /** * Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from * a logical plan node's children. @@ -554,6 +498,41 @@ class Analyzer( ordering.map(order => resolveExpression(order, child).asInstanceOf[SortOrder]) Sort(newOrdering, global, child) + // Replace the index with the related attribute for ORDER BY + // which is a 1-based position of the projection list. 
+ case s @ Sort(orders, global, child) if child.resolved && + orders.exists(o => IntegerIndex.unapply(o.child).nonEmpty) => + val newOrders = orders map { + case s @ SortOrder(IntegerIndex(index), direction) => + if (index > 0 && index <= child.output.size) { + SortOrder(child.output(index - 1), direction) + } else { + throw new UnresolvedException(s, + s"""Order by position: $index does not exist \n + |The Select List is indexed from 1 to ${child.output.size}""".stripMargin) + } + case other => other + } + Sort(newOrders, global, child) + + // Replace the index with the related attribute for GROUP BY + // which is a 1-based position of the underlying columns. + case a @ Aggregate(groups, aggs, child) if child.resolved && + groups.exists(IntegerIndex.unapply(_).nonEmpty) => + val newGroups = groups map { + case IntegerIndex(index) => + val attributes = child.output + if (index > 0 && index <= attributes.size) { + attributes(index - 1) + } else { + throw new UnresolvedException(a, + s"""Aggregate by position: $index does not exist \n + |The Select List is indexed from 1 to ${attributes.size}""".stripMargin) + } + case other => other + } + Aggregate(newGroups, aggs, child) + // A special case for Generate, because the output of Generate should not be resolved by // ResolveReferences. Attributes in the output will be resolved by ResolveGenerate. 
case g @ Generate(generator, join, outer, qualifier, output, child) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 7302b63646d6..edcda3584df5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -24,6 +24,7 @@ import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types.IntegerType /** * A pattern that matches any number of project or filter operations on top of another relational @@ -202,3 +203,14 @@ object Unions { } } } + +/** + * Extractor for retrieving Int value. + */ +object IntegerIndex { + def unapply(a: Any): Option[Int] = a match { + case Literal(a: Int, IntegerType) => Some(a) + case UnaryMinus(IntegerLiteral(v)) => Some(-v) + case _ => None + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5609e3ce8b43..d4379031469d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -23,6 +23,7 @@ import java.sql.Timestamp import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, UnresolvedException} import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.sql.catalyst.plans.logical.Aggregate import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.joins.{CartesianProduct, SortMergeJoin} import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} @@ -460,6 +461,14 @@ class SQLQuerySuite extends QueryTest 
with SharedSQLContext { } test("literal in agg grouping expressions") { + intercept[UnresolvedException[Aggregate]] { + sql("SELECT * FROM testData2 GROUP BY -1").collect() + } + + intercept[UnresolvedException[Aggregate]] { + sql("SELECT * FROM testData2 GROUP BY 3").collect() + } + checkAnswer( sql("SELECT SUM(a) FROM testData2 GROUP BY 2"), Seq(Row(6), Row(6))) @@ -513,6 +522,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { intercept[UnresolvedException[SortOrder]] { sql("SELECT * FROM testData2 ORDER BY -1 DESC, b ASC").collect() } + intercept[UnresolvedException[SortOrder]] { + sql("SELECT * FROM testData2 ORDER BY 3 DESC, b ASC").collect() + } checkAnswer( sql("SELECT * FROM testData2 ORDER BY 1 DESC"), Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2, 2), Row(1, 1), Row(1, 2)))