diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 004c1caaffec..ebc0b6777035 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatal
 import org.apache.spark.sql.catalyst.encoders.OuterScopes
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.planning.IntegerIndex
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
@@ -497,6 +498,41 @@ class Analyzer(
           ordering.map(order => resolveExpression(order, child).asInstanceOf[SortOrder])
         Sort(newOrdering, global, child)
 
+      // Replace the index with the related attribute for ORDER BY,
+      // which is a 1-based position in the projection list.
+      case s @ Sort(orders, global, child) if child.resolved &&
+        orders.exists(o => IntegerIndex.unapply(o.child).nonEmpty) =>
+        val newOrders = orders map {
+          case s @ SortOrder(IntegerIndex(index), direction) =>
+            if (index > 0 && index <= child.output.size) {
+              SortOrder(child.output(index - 1), direction)
+            } else {
+              throw new UnresolvedException(s,
+                s"""Order by position: $index does not exist \n
+                   |The Select List is indexed from 1 to ${child.output.size}""".stripMargin)
+            }
+          case other => other
+        }
+        Sort(newOrders, global, child)
+
+      // Replace the index with the related attribute for GROUP BY,
+      // which is a 1-based position in the underlying columns.
+      case a @ Aggregate(groups, aggs, child) if child.resolved &&
+        groups.exists(IntegerIndex.unapply(_).nonEmpty) =>
+        val newGroups = groups map {
+          case IntegerIndex(index) =>
+            val attributes = child.output
+            if (index > 0 && index <= attributes.size) {
+              attributes(index - 1)
+            } else {
+              throw new UnresolvedException(a,
+                s"""Aggregate by position: $index does not exist \n
+                   |The Select List is indexed from 1 to ${attributes.size}""".stripMargin)
+            }
+          case other => other
+        }
+        Aggregate(newGroups, aggs, child)
+
       // A special case for Generate, because the output of Generate should not be resolved by
       // ResolveReferences. Attributes in the output will be resolved by ResolveGenerate.
       case g @ Generate(generator, join, outer, qualifier, output, child)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 7302b63646d6..edcda3584df5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -24,6 +24,7 @@ import org.apache.spark.Logging
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.types.IntegerType
 
 /**
  * A pattern that matches any number of project or filter operations on top of another relational
@@ -202,3 +203,14 @@ object Unions {
     }
   }
 }
+
+/**
+ * Extractor for retrieving an Int value.
+ */
+object IntegerIndex {
+  def unapply(a: Any): Option[Int] = a match {
+    case Literal(a: Int, IntegerType) => Some(a)
+    case UnaryMinus(IntegerLiteral(v)) => Some(-v)
+    case _ => None
+  }
+}
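For reference, here is a minimal sketch (not part of the patch) of how the `IntegerIndex` extractor added above is expected to behave, assuming the catalyst module with this change applied is on the classpath:

```scala
// Sketch only: illustrates the intended behavior of the IntegerIndex extractor
// added in patterns.scala above.
import org.apache.spark.sql.catalyst.expressions.{Literal, UnaryMinus}
import org.apache.spark.sql.catalyst.planning.IntegerIndex

// A plain integer literal, e.g. the "1" in ORDER BY 1, matches directly.
Literal(1) match {
  case IntegerIndex(i) => println(s"ordinal $i")   // prints "ordinal 1"
  case _ => println("not an integer index")
}

// A negated literal, e.g. the "-1" in ORDER BY -1, can arrive as UnaryMinus(Literal(1));
// it is extracted as -1, so the analyzer rules above can reject it with a clear error.
UnaryMinus(Literal(1)) match {
  case IntegerIndex(i) => println(s"ordinal $i")   // prints "ordinal -1"
  case _ => println("not an integer index")
}

// Anything else, e.g. a string literal, does not match.
Literal("a") match {
  case IntegerIndex(i) => println(s"ordinal $i")
  case _ => println("not an integer index")        // this branch is taken
}
```

Note that a negative ordinal is still extracted (as a negative Int) rather than ignored, so the out-of-range check in the new analyzer rules can raise an `UnresolvedException` for it.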
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index b3e179755a19..d4379031469d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -21,10 +21,10 @@ import java.math.MathContext
 import java.sql.Timestamp
 
 import org.apache.spark.AccumulatorSuite
-import org.apache.spark.sql.catalyst.CatalystQl
-import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
-import org.apache.spark.sql.catalyst.parser.ParserConf
-import org.apache.spark.sql.execution.{aggregate, SparkQl}
+import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, UnresolvedException}
+import org.apache.spark.sql.catalyst.expressions.SortOrder
+import org.apache.spark.sql.catalyst.plans.logical.Aggregate
+import org.apache.spark.sql.execution.aggregate
 import org.apache.spark.sql.execution.joins.{CartesianProduct, SortMergeJoin}
 import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
 import org.apache.spark.sql.functions._
@@ -461,22 +461,41 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
   }
 
   test("literal in agg grouping expressions") {
+    intercept[UnresolvedException[Aggregate]] {
+      sql("SELECT * FROM testData2 GROUP BY -1").collect()
+    }
+
+    intercept[UnresolvedException[Aggregate]] {
+      sql("SELECT * FROM testData2 GROUP BY 3").collect()
+    }
+
     checkAnswer(
-      sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"),
-      Seq(Row(1, 2), Row(2, 2), Row(3, 2)))
+      sql("SELECT SUM(a) FROM testData2 GROUP BY 2"),
+      Seq(Row(6), Row(6)))
+
     checkAnswer(
-      sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"),
+      sql("SELECT a, SUM(b) FROM testData2 GROUP BY 1"),
+      Seq(Row(1, 3), Row(2, 3), Row(3, 3)))
+
+    checkAnswer(
+      sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"),
       Seq(Row(1, 2), Row(2, 2), Row(3, 2)))
 
     checkAnswer(
       sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"),
       sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a"))
+
     checkAnswer(
-      sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"),
-      sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a"))
+      sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"),
+      Seq(Row(1, 1), Row(1, 1), Row(2, 1), Row(2, 1), Row(3, 1), Row(3, 1)))
+
     checkAnswer(
       sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"),
-      sql("SELECT 1, 2, sum(b) FROM testData2"))
+      sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY a, b"))
+
+    checkAnswer(
+      sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"),
+      sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a"))
   }
 
   test("aggregates with nulls") {
@@ -499,6 +518,24 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       Row("1"))
   }
 
+  test("order by number index") {
+    intercept[UnresolvedException[SortOrder]] {
+      sql("SELECT * FROM testData2 ORDER BY -1 DESC, b ASC").collect()
+    }
+    intercept[UnresolvedException[SortOrder]] {
+      sql("SELECT * FROM testData2 ORDER BY 3 DESC, b ASC").collect()
+    }
+    checkAnswer(
+      sql("SELECT * FROM testData2 ORDER BY 1 DESC"),
+      Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2, 2), Row(1, 1), Row(1, 2)))
+    checkAnswer(
+      sql("SELECT * FROM testData2 ORDER BY 1 DESC, b ASC"),
+      sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC"))
+    checkAnswer(
+      sql("SELECT * FROM testData2 ORDER BY 1 ASC, b ASC"),
+      Seq(Row(1, 1), Row(1, 2), Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2)))
+  }
+
   def sortTest(): Unit = {
     checkAnswer(
       sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC"),