Skip to content

Commit e61429f

Browse files
committed
support order by index and group by index
refactor style didn't catched in local style forgot to enbale style checking for test in local override resolved refactor a version refactor
1 parent d741599 commit e61429f

File tree

2 files changed

+100
-10
lines changed

2 files changed

+100
-10
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ class Analyzer(
7171
Batch("Resolution", fixedPoint,
7272
ResolveRelations ::
7373
ResolveReferences ::
74+
ResolveIndexReferences ::
7475
ResolveGroupingAnalytics ::
7576
ResolvePivot ::
7677
ResolveUpCast ::
@@ -321,6 +322,62 @@ class Analyzer(
321322
}
322323
}
323324

325+
/**
326+
* Replace the index with the related attribute for order by and group by.
327+
*/
328+
object ResolveIndexReferences extends Rule[LogicalPlan] {
329+
330+
def indexToColumn(
331+
sortOrder: SortOrder,
332+
child: LogicalPlan,
333+
index: Int,
334+
direction: SortDirection): SortOrder = {
335+
val orderNodes = child.output
336+
if (index > 0 && index <= orderNodes.size) {
337+
SortOrder(orderNodes(index - 1), direction)
338+
} else {
339+
throw new UnresolvedException(sortOrder,
340+
s"""Order by position: $index does not exist \n
341+
|The Select List is indexed from 1 to ${orderNodes.size}""".stripMargin)
342+
}
343+
}
344+
345+
def indexToColumn(
346+
agg: Aggregate,
347+
child: LogicalPlan,
348+
index: Int): Expression = {
349+
val attributes = child.output
350+
if (index > 0 && index <= attributes.size) {
351+
attributes(index - 1)
352+
} else {
353+
throw new UnresolvedException(agg,
354+
s"""Order by position: $index does not exist \n
355+
|The Select List is indexed from 1 to ${attributes.size}""".stripMargin)
356+
}
357+
}
358+
359+
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
360+
case s @ Sort(orders, global, child) if child.resolved =>
361+
val newOrders = orders map {
362+
case s @ SortOrder(IntegerLiteral(index), direction) =>
363+
indexToColumn(s, child, index, direction)
364+
case s @ SortOrder(UnaryMinus(IntegerLiteral(index)), direction) =>
365+
indexToColumn(s, child, -index, direction)
366+
case other => other
367+
}
368+
Sort(newOrders, global, child)
369+
case a @ Aggregate(groups, aggs, child) =>
370+
val newGroups = groups map {
371+
case IntegerLiteral(index) =>
372+
indexToColumn(a, child, index)
373+
case UnaryMinus(IntegerLiteral(index)) =>
374+
indexToColumn(a, child, -index)
375+
case other => other
376+
}
377+
Aggregate(newGroups, aggs, child)
378+
}
379+
}
380+
324381
/**
325382
* Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from
326383
* a logical plan node's children.
@@ -446,6 +503,10 @@ class Analyzer(
446503
val newOrdering = resolveSortOrders(ordering, child, throws = false)
447504
Sort(newOrdering, global, child)
448505

506+
case s @ Sort(ordering, global, child) if child.resolved && !s.resolved =>
507+
val newOrdering = resolveSortOrders(ordering, child, throws = false)
508+
Sort(newOrdering, global, child)
509+
449510
// A special case for Generate, because the output of Generate should not be resolved by
450511
// ResolveReferences. Attributes in the output will be resolved by ResolveGenerate.
451512
case g @ Generate(generator, join, outer, qualifier, output, child)

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,9 @@ import java.math.MathContext
2121
import java.sql.Timestamp
2222

2323
import org.apache.spark.AccumulatorSuite
24-
import org.apache.spark.sql.catalyst.CatalystQl
25-
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
26-
import org.apache.spark.sql.catalyst.parser.ParserConf
27-
import org.apache.spark.sql.execution.{aggregate, SparkQl}
24+
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, UnresolvedException}
25+
import org.apache.spark.sql.catalyst.expressions.SortOrder
26+
import org.apache.spark.sql.execution.aggregate
2827
import org.apache.spark.sql.execution.joins.{CartesianProduct, SortMergeJoin}
2928
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
3029
import org.apache.spark.sql.functions._
@@ -456,21 +455,36 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
456455

457456
test("literal in agg grouping expressions") {
458457
checkAnswer(
459-
sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"),
460-
Seq(Row(1, 2), Row(2, 2), Row(3, 2)))
458+
sql("SELECT SUM(a) FROM testData2 GROUP BY 2"),
459+
Seq(Row(6), Row(6)))
460+
461461
checkAnswer(
462-
sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"),
462+
sql("SELECT a, SUM(b) FROM testData2 GROUP BY 1"),
463+
Seq(Row(1, 3), Row(2, 3), Row(3, 3)))
464+
465+
checkAnswer(
466+
sql("SELECT a, SUM(b) FROM testData2 GROUP BY 1"),
467+
Seq(Row(1, 3), Row(2, 3), Row(3, 3)))
468+
469+
checkAnswer(
470+
sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"),
463471
Seq(Row(1, 2), Row(2, 2), Row(3, 2)))
464472

465473
checkAnswer(
466474
sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"),
467475
sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a"))
476+
468477
checkAnswer(
469-
sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"),
470-
sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a"))
478+
sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"),
479+
Seq(Row(1, 1), Row(1, 1), Row(2, 1), Row(2, 1), Row(3, 1), Row(3, 1)))
480+
471481
checkAnswer(
472482
sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"),
473-
sql("SELECT 1, 2, sum(b) FROM testData2"))
483+
sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY a, b"))
484+
485+
checkAnswer(
486+
sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"),
487+
sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a"))
474488
}
475489

476490
test("aggregates with nulls") {
@@ -493,6 +507,21 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
493507
Row("1"))
494508
}
495509

510+
test("order by number index") {
511+
intercept[UnresolvedException[SortOrder]] {
512+
sql("SELECT * FROM testData2 ORDER BY -1 DESC, b ASC").collect()
513+
}
514+
checkAnswer(
515+
sql("SELECT * FROM testData2 ORDER BY 1 DESC"),
516+
Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2, 2), Row(1, 1), Row(1, 2)))
517+
checkAnswer(
518+
sql("SELECT * FROM testData2 ORDER BY 1 DESC, b ASC"),
519+
sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC"))
520+
checkAnswer(
521+
sql("SELECT * FROM testData2 ORDER BY 1 ASC, b ASC"),
522+
Seq(Row(1, 1), Row(1, 2), Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2)))
523+
}
524+
496525
def sortTest(): Unit = {
497526
checkAnswer(
498527
sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC"),

0 commit comments

Comments
 (0)