
Commit 05f652d

gatorsmile authored and cloud-fan committed
[SPARK-13957][SQL] Support Group By Ordinal in SQL
#### What changes were proposed in this pull request?

This PR adds support for group by position (ordinal) in SQL. For example, when users submit the following query:

```sql
select c1 as a, c2, c3, sum(*) from tbl group by 1, 3, c4
```

the ordinals are recognized as positions in the select list, so the `Analyzer` converts it to:

```sql
select c1, c2, c3, sum(*) from tbl group by c1, c3, c4
```

This is controlled by the config option `spark.sql.groupByOrdinal`:
- When true, the ordinal numbers in group by clauses are treated as positions in the select list.
- When false, the ordinal numbers are ignored.
- Only integer literals are converted (not foldable expressions); foldable expressions are ignored.
- When a position specified in the group by clause refers to an aggregate function in the select list, an exception is raised.
- A star (`*`) is not allowed in the select list when ordinals are used in group by.

Note: This PR is taken from apache#10731. When merging this PR, please give the credit to zhichao-li.

Also cc all the people who are involved in the previous discussion: rxin cloud-fan marmbrus yhuai hvanhovell adrian-wang chenghao-intel tejasapatil

#### How was this patch tested?

Added test cases for both positive and negative scenarios.

Author: gatorsmile <[email protected]>
Author: xiaoli <[email protected]>
Author: Xiao Li <[email protected]>

Closes apache#11846 from gatorsmile/groupByOrdinal.
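The behaviour described above can be tried interactively. Below is a minimal sketch (not part of this commit), assuming a spark-shell session of this era where `sqlContext` is predefined and a table `tbl` with columns `c1`..`c4` has been registered:

```scala
// Sketch only: `tbl` and its columns c1..c4 are placeholder names.
sqlContext.sql("SET spark.sql.groupByOrdinal=true")

// '1' and '3' are resolved to c1 and c3 from the select list,
// so this is equivalent to GROUP BY c1, c3, c4.
sqlContext.sql("SELECT c1 AS a, c2, c3, sum(c4) FROM tbl GROUP BY 1, 3, c4").show()

// With the option disabled, the integers are treated as ordinary literals again.
sqlContext.sql("SET spark.sql.groupByOrdinal=false")
```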
1 parent 0874ff3 commit 05f652d

File tree: 5 files changed, +156 −25 lines


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala

Lines changed: 7 additions & 1 deletion
@@ -23,6 +23,7 @@ private[spark] trait CatalystConf {
   def caseSensitiveAnalysis: Boolean
 
   def orderByOrdinal: Boolean
+  def groupByOrdinal: Boolean
 
   /**
    * Returns the [[Resolver]] for the current configuration, which can be used to determin if two
@@ -48,11 +49,16 @@ object EmptyConf extends CatalystConf {
   override def orderByOrdinal: Boolean = {
     throw new UnsupportedOperationException
   }
+  override def groupByOrdinal: Boolean = {
+    throw new UnsupportedOperationException
+  }
 }
 
 /** A CatalystConf that can be used for local testing. */
 case class SimpleCatalystConf(
     caseSensitiveAnalysis: Boolean,
-    orderByOrdinal: Boolean = true)
+    orderByOrdinal: Boolean = true,
+    groupByOrdinal: Boolean = true)
   extends CatalystConf {
 }
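As a quick illustration of the new trait member, a test-only configuration could be built like this (a sketch, not part of the commit; `SimpleCatalystConf` is Spark-internal, so this only compiles from inside Spark's own code or tests):

```scala
import org.apache.spark.sql.catalyst.SimpleCatalystConf

// Sketch: a local conf with group-by-ordinal resolution switched off,
// relying on the default orderByOrdinal = true from the case class above.
val conf = SimpleCatalystConf(caseSensitiveAnalysis = true, groupByOrdinal = false)
assert(conf.orderByOrdinal && !conf.groupByOrdinal)
```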

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 56 additions & 16 deletions
@@ -85,6 +85,7 @@ class Analyzer(
       ResolveGroupingAnalytics ::
       ResolvePivot ::
       ResolveUpCast ::
+      ResolveOrdinalInOrderByAndGroupBy ::
       ResolveSortReferences ::
       ResolveGenerate ::
       ResolveFunctions ::
@@ -385,7 +386,13 @@
         p.copy(projectList = buildExpandedProjectList(p.projectList, p.child))
       // If the aggregate function argument contains Stars, expand it.
       case a: Aggregate if containsStar(a.aggregateExpressions) =>
-        a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child))
+        if (conf.groupByOrdinal && a.groupingExpressions.exists(IntegerIndex.unapply(_).nonEmpty)) {
+          failAnalysis(
+            "Group by position: star is not allowed to use in the select list " +
+              "when using ordinals in group by")
+        } else {
+          a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child))
+        }
       // If the script transformation input contains Stars, expand it.
       case t: ScriptTransformation if containsStar(t.input) =>
         t.copy(
@@ -634,21 +641,23 @@
       }
     }
 
-  /**
-   * In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT
-   * clause. This rule detects such queries and adds the required attributes to the original
-   * projection, so that they will be available during sorting. Another projection is added to
-   * remove these attributes after sorting.
-   *
-   * This rule also resolves the position number in sort references. This support is introduced
-   * in Spark 2.0. Before Spark 2.0, the integers in Order By has no effect on output sorting.
-   * - When the sort references are not integer but foldable expressions, ignore them.
-   * - When spark.sql.orderByOrdinal is set to false, ignore the position numbers too.
-   */
-  object ResolveSortReferences extends Rule[LogicalPlan] {
+  /**
+   * In many dialects of SQL it is valid to use ordinal positions in order/sort by and group by
+   * clauses. This rule is to convert ordinal positions to the corresponding expressions in the
+   * select list. This support is introduced in Spark 2.0.
+   *
+   * - When the sort references or group by expressions are not integer but foldable expressions,
+   *   just ignore them.
+   * - When spark.sql.orderByOrdinal/spark.sql.groupByOrdinal is set to false, ignore the position
+   *   numbers too.
+   *
+   * Before the release of Spark 2.0, the literals in order/sort by and group by clauses
+   * have no effect on the results.
+   */
+  object ResolveOrdinalInOrderByAndGroupBy extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-      case s: Sort if !s.child.resolved => s
-      // Replace the index with the related attribute for ORDER BY
+      case p if !p.childrenResolved => p
+      // Replace the index with the related attribute for ORDER BY,
       // which is a 1-base position of the projection list.
       case s @ Sort(orders, global, child)
         if conf.orderByOrdinal && orders.exists(o => IntegerIndex.unapply(o.child).nonEmpty) =>
@@ -665,10 +674,41 @@
         }
         Sort(newOrders, global, child)
 
+      // Replace the index with the corresponding expression in aggregateExpressions. The index is
+      // a 1-base position of aggregateExpressions, which is output columns (select expression)
+      case a @ Aggregate(groups, aggs, child)
+        if conf.groupByOrdinal && aggs.forall(_.resolved) &&
+          groups.exists(IntegerIndex.unapply(_).nonEmpty) =>
+        val newGroups = groups.map {
+          case IntegerIndex(index) if index > 0 && index <= aggs.size =>
+            aggs(index - 1) match {
+              case e if ResolveAggregateFunctions.containsAggregate(e) =>
+                throw new UnresolvedException(a,
+                  s"Group by position: the '$index'th column in the select contains an " +
+                  s"aggregate function: ${e.sql}. Aggregate functions are not allowed in GROUP BY")
+              case o => o
+            }
+          case IntegerIndex(index) =>
+            throw new UnresolvedException(a,
+              s"Group by position: '$index' exceeds the size of the select list '${aggs.size}'.")
+          case o => o
+        }
+        Aggregate(newGroups, aggs, child)
+    }
+  }
+
+  /**
+   * In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT
+   * clause. This rule detects such queries and adds the required attributes to the original
+   * projection, so that they will be available during sorting. Another projection is added to
+   * remove these attributes after sorting.
+   */
+  object ResolveSortReferences extends Rule[LogicalPlan] {
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
       // Skip sort with aggregate. This will be handled in ResolveAggregateFunctions
       case sa @ Sort(_, _, child: Aggregate) => sa
 
-      case s @ Sort(order, _, child) if !s.resolved =>
+      case s @ Sort(order, _, child) if !s.resolved && child.resolved =>
         try {
           val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder])
           val requiredAttrs = AttributeSet(newOrder).filter(_.resolved)
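The core of the new rule is replacing a positive integer ordinal with the expression at that position in the select list. Stripped of Catalyst's types, the idea can be sketched in plain Scala (illustrative stand-ins only, not the actual rule):

```scala
// Int stands in for an integer literal and String for a select-list expression.
def substituteOrdinals(groups: Seq[Any], selectList: Seq[String]): Seq[Any] =
  groups.map {
    case index: Int if index > 0 && index <= selectList.size =>
      selectList(index - 1)  // replace the 1-based ordinal with the select-list entry
    case index: Int =>
      sys.error(s"Group by position: '$index' exceeds the size of the select list")
    case other => other      // non-ordinal grouping expressions are left untouched
  }

// substituteOrdinals(Seq(1, 3, "c4"), Seq("c1", "c2", "c3")) == Seq("c1", "c3", "c4")
```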

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala

Lines changed: 2 additions & 1 deletion
@@ -210,7 +210,8 @@ object Unions {
 object IntegerIndex {
   def unapply(a: Any): Option[Int] = a match {
     case Literal(a: Int, IntegerType) => Some(a)
-    // When resolving ordinal in Sort, negative values are extracted for issuing error messages.
+    // When resolving ordinal in Sort and Group By, negative values are extracted
+    // for issuing error messages.
     case UnaryMinus(IntegerLiteral(v)) => Some(-v)
     case _ => None
   }
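For context, `IntegerIndex` is the extractor the analyzer matches grouping (and sort) expressions against to find ordinals. A rough usage sketch follows; it assumes the extractor is reachable from `org.apache.spark.sql.catalyst.planning` and it is internal API, so treat it as illustration only:

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.planning.IntegerIndex

// Literal(3) builds an integer literal; negative ordinals such as -1 are also
// extracted so the analyzer can issue a meaningful error message.
Literal(3) match {
  case IntegerIndex(i) => println(s"group by ordinal $i")
  case _               => println("not an ordinal")
}
```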

sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 6 additions & 0 deletions
@@ -445,6 +445,11 @@ object SQLConf {
     doc = "When true, the ordinal numbers are treated as the position in the select list. " +
       "When false, the ordinal numbers in order/sort By clause are ignored.")
 
+  val GROUP_BY_ORDINAL = booleanConf("spark.sql.groupByOrdinal",
+    defaultValue = Some(true),
+    doc = "When true, the ordinal numbers in group by clauses are treated as the position " +
+      "in the select list. When false, the ordinal numbers are ignored.")
+
   // The output committer class used by HadoopFsRelation. The specified class needs to be a
   // subclass of org.apache.hadoop.mapreduce.OutputCommitter.
   //
@@ -668,6 +673,7 @@ class SQLConf extends Serializable with CatalystConf with ParserConf with Loggin
 
   override def orderByOrdinal: Boolean = getConf(ORDER_BY_ORDINAL)
 
+  override def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL)
   /** ********************** SQLConf functionality methods ************ */
 
   /** Set Spark SQL configuration properties. */
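On the user-facing side the new entry behaves like any other SQL conf. A sketch of toggling it at runtime (assuming a spark-shell `sqlContext` and a registered table named `testData2`, as in the test suite below):

```scala
// Sketch only: sqlContext and testData2 are assumed to exist.
sqlContext.setConf("spark.sql.groupByOrdinal", "false")
// '1' is now just the literal 1, so all rows fall into a single group.
sqlContext.sql("SELECT sum(b) FROM testData2 GROUP BY 1")

sqlContext.setConf("spark.sql.groupByOrdinal", "true")
// '1' now refers to the first item of the select list, i.e. GROUP BY a.
sqlContext.sql("SELECT a, sum(b) FROM testData2 GROUP BY 1")
```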

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 85 additions & 7 deletions
@@ -23,6 +23,7 @@ import java.sql.Timestamp
 import org.apache.spark.AccumulatorSuite
 import org.apache.spark.sql.catalyst.analysis.UnresolvedException
 import org.apache.spark.sql.catalyst.expressions.SortOrder
+import org.apache.spark.sql.catalyst.plans.logical.Aggregate
 import org.apache.spark.sql.execution.aggregate
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, CartesianProduct, SortMergeJoin}
 import org.apache.spark.sql.functions._
@@ -459,25 +460,103 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       Seq(Row(1, 3), Row(2, 3), Row(3, 3)))
   }
 
-  test("literal in agg grouping expressions") {
+  test("Group By Ordinal - basic") {
     checkAnswer(
-      sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"),
-      Seq(Row(1, 2), Row(2, 2), Row(3, 2)))
-    checkAnswer(
-      sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"),
-      Seq(Row(1, 2), Row(2, 2), Row(3, 2)))
+      sql("SELECT a, sum(b) FROM testData2 GROUP BY 1"),
+      sql("SELECT a, sum(b) FROM testData2 GROUP BY a"))
 
+    // duplicate group-by columns
     checkAnswer(
       sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"),
       sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a"))
+
+    checkAnswer(
+      sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY 1, 2"),
+      sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a"))
+  }
+
+  test("Group By Ordinal - non aggregate expressions") {
+    checkAnswer(
+      sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, 2"),
+      sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, b + 2"))
+
+    checkAnswer(
+      sql("SELECT a, b + 2 as c, count(2) FROM testData2 GROUP BY a, 2"),
+      sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, b + 2"))
+  }
+
+  test("Group By Ordinal - non-foldable constant expression") {
+    checkAnswer(
+      sql("SELECT a, b, sum(b) FROM testData2 GROUP BY a, b, 1 + 0"),
+      sql("SELECT a, b, sum(b) FROM testData2 GROUP BY a, b"))
+
     checkAnswer(
       sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"),
       sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a"))
+  }
+
+  test("Group By Ordinal - alias") {
+    checkAnswer(
+      sql("SELECT a, (b + 2) as c, count(2) FROM testData2 GROUP BY a, 2"),
+      sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, b + 2"))
+
+    checkAnswer(
+      sql("SELECT a as b, b as a, sum(b) FROM testData2 GROUP BY 1, 2"),
+      sql("SELECT a, b, sum(b) FROM testData2 GROUP BY a, b"))
+  }
+
+  test("Group By Ordinal - constants") {
     checkAnswer(
       sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"),
       sql("SELECT 1, 2, sum(b) FROM testData2"))
   }
 
+  test("Group By Ordinal - negative cases") {
+    intercept[UnresolvedException[Aggregate]] {
+      sql("SELECT a, b FROM testData2 GROUP BY -1")
+    }
+
+    intercept[UnresolvedException[Aggregate]] {
+      sql("SELECT a, b FROM testData2 GROUP BY 3")
+    }
+
+    var e = intercept[UnresolvedException[Aggregate]](
+      sql("SELECT SUM(a) FROM testData2 GROUP BY 1"))
+    assert(e.getMessage contains
+      "Invalid call to Group by position: the '1'th column in the select contains " +
+      "an aggregate function")
+
+    e = intercept[UnresolvedException[Aggregate]](
+      sql("SELECT SUM(a) + 1 FROM testData2 GROUP BY 1"))
+    assert(e.getMessage contains
+      "Invalid call to Group by position: the '1'th column in the select contains " +
+      "an aggregate function")
+
+    var ae = intercept[AnalysisException](
+      sql("SELECT a, rand(0), sum(b) FROM testData2 GROUP BY a, 2"))
+    assert(ae.getMessage contains
+      "nondeterministic expression rand(0) should not appear in grouping expression")
+
+    ae = intercept[AnalysisException](
+      sql("SELECT * FROM testData2 GROUP BY a, b, 1"))
+    assert(ae.getMessage contains
+      "Group by position: star is not allowed to use in the select list " +
+      "when using ordinals in group by")
+  }
+
+  test("Group By Ordinal: spark.sql.groupByOrdinal=false") {
+    withSQLConf(SQLConf.GROUP_BY_ORDINAL.key -> "false") {
+      // If spark.sql.groupByOrdinal=false, ignore the position number.
+      intercept[AnalysisException] {
+        sql("SELECT a, sum(b) FROM testData2 GROUP BY 1")
+      }
+      // '*' is not allowed to use in the select list when users specify ordinals in group by
+      checkAnswer(
+        sql("SELECT * FROM testData2 GROUP BY a, b, 1"),
+        sql("SELECT * FROM testData2 GROUP BY a, b"))
+    }
+  }
+
   test("aggregates with nulls") {
     checkAnswer(
       sql("SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a)," +
@@ -2174,7 +2253,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     checkAnswer(
       sql("SELECT * FROM testData2 ORDER BY 1 + 0 DESC, b ASC"),
      sql("SELECT * FROM testData2 ORDER BY b ASC"))
-
     checkAnswer(
       sql("SELECT * FROM testData2 ORDER BY 1 DESC, b ASC"),
       sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC"))
