-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-12789]Support order by index and group by index #10731
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatal | |
| import org.apache.spark.sql.catalyst.encoders.OuterScopes | ||
| import org.apache.spark.sql.catalyst.expressions._ | ||
| import org.apache.spark.sql.catalyst.expressions.aggregate._ | ||
| import org.apache.spark.sql.catalyst.planning.IntegerIndex | ||
| import org.apache.spark.sql.catalyst.plans._ | ||
| import org.apache.spark.sql.catalyst.plans.logical._ | ||
| import org.apache.spark.sql.catalyst.rules._ | ||
|
|
@@ -497,6 +498,41 @@ class Analyzer( | |
| ordering.map(order => resolveExpression(order, child).asInstanceOf[SortOrder]) | ||
| Sort(newOrdering, global, child) | ||
|
|
||
| // Replace the index with the related attribute for ORDER BY | ||
| // which is a 1-base position of the projection list. | ||
| case s @ Sort(orders, global, child) if child.resolved && | ||
| orders.exists(o => IntegerIndex.unapply(o.child).nonEmpty) => | ||
| val newOrders = orders map { | ||
| case s @ SortOrder(IntegerIndex(index), direction) => | ||
| if (index > 0 && index <= child.output.size) { | ||
| SortOrder(child.output(index - 1), direction) | ||
| } else { | ||
| throw new UnresolvedException(s, | ||
| s"""Order by position: $index does not exist \n | ||
| |The Select List is indexed from 1 to ${child.output.size}""".stripMargin) | ||
| } | ||
| case other => other | ||
| } | ||
| Sort(newOrders, global, child) | ||
|
|
||
| // Replace the index with the related attribute for Group BY | ||
| // which is a 1-base position of the underlying columns. | ||
| case a @ Aggregate(groups, aggs, child) if child.resolved && | ||
| groups.exists(IntegerIndex.unapply(_).nonEmpty) => | ||
| val newGroups = groups map { | ||
| case IntegerIndex(index) => | ||
| val attributes = child.output | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm not sure if this is correct. The 1 and 2 in the group by here refer to positions 1 and 2 in the select list, not the underlying query output.
Contributor
| There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We also need to check whether the position is an aggregate function. In Postgres |
||
| if (index > 0 && index <= attributes.size) { | ||
| attributes(index - 1) | ||
| } else { | ||
| throw new UnresolvedException(a, | ||
| s"""Aggregate by position: $index does not exist \n | ||
| |The Select List is indexed from 1 to ${attributes.size}""".stripMargin) | ||
| } | ||
| case other => other | ||
| } | ||
| Aggregate(newGroups, aggs, child) | ||
|
|
||
| // A special case for Generate, because the output of Generate should not be resolved by | ||
| // ResolveReferences. Attributes in the output will be resolved by ResolveGenerate. | ||
| case g @ Generate(generator, join, outer, qualifier, output, child) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,6 +24,7 @@ import org.apache.spark.Logging | |
| import org.apache.spark.sql.catalyst.expressions._ | ||
| import org.apache.spark.sql.catalyst.plans._ | ||
| import org.apache.spark.sql.catalyst.plans.logical._ | ||
| import org.apache.spark.sql.types.IntegerType | ||
|
|
||
| /** | ||
| * A pattern that matches any number of project or filter operations on top of another relational | ||
|
|
@@ -202,3 +203,14 @@ object Unions { | |
| } | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Extractor for retrieving Int value. | ||
| */ | ||
| object IntegerIndex { | ||
| def unapply(a: Any): Option[Int] = a match { | ||
| case Literal(a: Int, IntegerType) => Some(a) | ||
| case UnaryMinus(IntegerLiteral(v)) => Some(-v) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is it standard to support -(-1)? I see Postgres supports it, but it seems somewhat strange to me.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This line is used for catching the illegal case: sql("SELECT * FROM testData2 ORDER BY -1 DESC, b ASC").collect()I plan to keep it untouched in the PR. Thanks! |
||
| case _ => None | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,10 +21,10 @@ import java.math.MathContext | |
| import java.sql.Timestamp | ||
|
|
||
| import org.apache.spark.AccumulatorSuite | ||
| import org.apache.spark.sql.catalyst.CatalystQl | ||
| import org.apache.spark.sql.catalyst.analysis.FunctionRegistry | ||
| import org.apache.spark.sql.catalyst.parser.ParserConf | ||
| import org.apache.spark.sql.execution.{aggregate, SparkQl} | ||
| import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, UnresolvedException} | ||
| import org.apache.spark.sql.catalyst.expressions.SortOrder | ||
| import org.apache.spark.sql.catalyst.plans.logical.Aggregate | ||
| import org.apache.spark.sql.execution.aggregate | ||
| import org.apache.spark.sql.execution.joins.{CartesianProduct, SortMergeJoin} | ||
| import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} | ||
| import org.apache.spark.sql.functions._ | ||
|
|
@@ -461,22 +461,41 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { | |
| } | ||
|
|
||
| test("literal in agg grouping expressions") { | ||
| intercept[UnresolvedException[Aggregate]] { | ||
| sql("SELECT * FROM testData2 GROUP BY -1").collect() | ||
| } | ||
|
|
||
| intercept[UnresolvedException[Aggregate]] { | ||
| sql("SELECT * FROM testData2 GROUP BY 3").collect() | ||
| } | ||
|
|
||
| checkAnswer( | ||
| sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), | ||
| Seq(Row(1, 2), Row(2, 2), Row(3, 2))) | ||
| sql("SELECT SUM(a) FROM testData2 GROUP BY 2"), | ||
| Seq(Row(6), Row(6))) | ||
|
|
||
| checkAnswer( | ||
| sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"), | ||
| sql("SELECT a, SUM(b) FROM testData2 GROUP BY 1"), | ||
| Seq(Row(1, 3), Row(2, 3), Row(3, 3))) | ||
|
|
||
| checkAnswer( | ||
| sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), | ||
| Seq(Row(1, 2), Row(2, 2), Row(3, 2))) | ||
|
|
||
| checkAnswer( | ||
| sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"), | ||
|
||
| sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) | ||
|
|
||
| checkAnswer( | ||
| sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"), | ||
| sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) | ||
| sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"), | ||
| Seq(Row(1, 1), Row(1, 1), Row(2, 1), Row(2, 1), Row(3, 1), Row(3, 1))) | ||
|
|
||
| checkAnswer( | ||
| sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"), | ||
| sql("SELECT 1, 2, sum(b) FROM testData2")) | ||
| sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY a, b")) | ||
|
|
||
| checkAnswer( | ||
| sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"), | ||
| sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) | ||
| } | ||
|
|
||
| test("aggregates with nulls") { | ||
|
|
@@ -499,6 +518,24 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { | |
| Row("1")) | ||
| } | ||
|
|
||
| test("order by number index") { | ||
| intercept[UnresolvedException[SortOrder]] { | ||
| sql("SELECT * FROM testData2 ORDER BY -1 DESC, b ASC").collect() | ||
| } | ||
| intercept[UnresolvedException[SortOrder]] { | ||
| sql("SELECT * FROM testData2 ORDER BY 3 DESC, b ASC").collect() | ||
| } | ||
| checkAnswer( | ||
| sql("SELECT * FROM testData2 ORDER BY 1 DESC"), | ||
| Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2, 2), Row(1, 1), Row(1, 2))) | ||
| checkAnswer( | ||
| sql("SELECT * FROM testData2 ORDER BY 1 DESC, b ASC"), | ||
| sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC")) | ||
| checkAnswer( | ||
| sql("SELECT * FROM testData2 ORDER BY 1 ASC, b ASC"), | ||
| Seq(Row(1, 1), Row(1, 2), Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2))) | ||
| } | ||
|
|
||
| def sortTest(): Unit = { | ||
| checkAnswer( | ||
| sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC"), | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this rule is getting pretty long -- i wonder if there are ways to break it down
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will move it to the rule `ResolveSortReferences`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am unable to find a good place for
`group by` ordinal resolution, after placing `order by` ordinal resolution in `ResolveSortReferences`. Two options are in my mind: (1) split `ResolveReferences` into two rules, `ResolveStar` and `ResolveReferences` — then `ResolveReferences` is not very long, and maybe we can keep the resolution of ordinals there; or (2) create a new rule `ResolveOrdinal` for both cases. In the next PR, I will first pick the second option, if nobody is against it. : ) Thanks!