@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatal
import org.apache.spark.sql.catalyst.encoders.OuterScopes
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.planning.IntegerIndex
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
@@ -497,6 +498,41 @@ class Analyzer(
ordering.map(order => resolveExpression(order, child).asInstanceOf[SortOrder])
Sort(newOrdering, global, child)

// Replace the index with the corresponding attribute for ORDER BY,
// which is a 1-based position in the projection list.
case s @ Sort(orders, global, child) if child.resolved &&
Contributor: this rule is getting pretty long -- I wonder if there are ways to break it down.

Member: I will move it to the rule ResolveSortReferences.

Member: I am unable to find a good place for GROUP BY ordinal resolution after placing ORDER BY ordinal resolution in ResolveSortReferences. Two options are in my mind. In the next PR, I will pick the second option, if nobody is against it. :) Thanks!
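For illustration, a minimal sketch of the same ordinal logic packaged as a standalone analyzer rule. Only the destination rule name ResolveSortReferences is confirmed in this thread; the object name below is invented, and the out-of-range error handling of the real arm is elided.

import org.apache.spark.sql.catalyst.expressions.SortOrder
import org.apache.spark.sql.catalyst.planning.IntegerIndex
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Sort}
import org.apache.spark.sql.catalyst.rules.Rule

// Hypothetical standalone rule mirroring the case arm below.
object ResolveOrderByOrdinals extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
    case Sort(orders, global, child) if child.resolved =>
      val newOrders = orders map {
        // Substitute a valid 1-based ordinal with the matching output attribute;
        // out-of-range ordinals are left untouched here (the real arm throws).
        case SortOrder(IntegerIndex(index), direction)
            if index > 0 && index <= child.output.size =>
          SortOrder(child.output(index - 1), direction)
        case other => other
      }
      Sort(newOrders, global, child)
  }
}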

orders.exists(o => IntegerIndex.unapply(o.child).nonEmpty) =>
val newOrders = orders map {
case s @ SortOrder(IntegerIndex(index), direction) =>
if (index > 0 && index <= child.output.size) {
SortOrder(child.output(index - 1), direction)
} else {
throw new UnresolvedException(s,
s"""Order by position: $index does not exist \n
|The Select List is indexed from 1 to ${child.output.size}""".stripMargin)
}
case other => other
}
Sort(newOrders, global, child)
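In effect, the arm above substitutes a valid positive ordinal with the corresponding output attribute and raises UnresolvedException for anything out of range; illustratively (these mirror the tests added to SQLQuerySuite below):

sql("SELECT * FROM testData2 ORDER BY 1 DESC")  // rewritten to ORDER BY a DESC
sql("SELECT * FROM testData2 ORDER BY 3 DESC")  // UnresolvedException: position 3 does not exist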

// Replace the index with the corresponding attribute for GROUP BY,
// which is a 1-based position in the underlying columns.
case a @ Aggregate(groups, aggs, child) if child.resolved &&
groups.exists(IntegerIndex.unapply(_).nonEmpty) =>
val newGroups = groups map {
case IntegerIndex(index) =>
val attributes = child.output
Contributor: I'm not sure if this is correct. The 1 and 2 in the GROUP BY here refer to positions 1 and 2 in the select list, not the underlying query output.

Contributor: we also need to check whether the position refers to an aggregate function. In Postgres:

rxin=# select 'one', 'two', count(*) from r1 group by 1, 3;
ERROR:  aggregate functions are not allowed in GROUP BY
LINE 1: select 'one', 'two', count(*) from r1 group by 1, 3;
                             ^
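A sketch of the kind of guard being requested, assuming GROUP BY ordinals are resolved against the select list (per the comment above) rather than child.output; the helper name and error wording are hypothetical, not code from this patch:

import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression

// Hypothetical helper: resolve a 1-based GROUP BY ordinal against the select
// list and reject positions that point at aggregate functions, as Postgres does.
def resolveGroupByOrdinal(selectList: Seq[NamedExpression], index: Int): Expression = {
  require(index > 0 && index <= selectList.size,
    s"GROUP BY position $index is not in the select list (1 to ${selectList.size})")
  selectList(index - 1) match {
    case e if e.find(_.isInstanceOf[AggregateExpression]).isDefined =>
      sys.error(s"aggregate functions are not allowed in GROUP BY (position $index)")
    case Alias(child, _) => child  // group by the aliased expression, not the alias
    case other => other
  }
}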

if (index > 0 && index <= attributes.size) {
attributes(index - 1)
} else {
throw new UnresolvedException(a,
s"""Aggregate by position: $index does not exist \n
|The Select List is indexed from 1 to ${attributes.size}""".stripMargin)
}
case other => other
}
Aggregate(newGroups, aggs, child)
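To make the select-list-versus-output distinction above concrete, consider a projection that reorders the columns (illustrative, using the testData2 table from the test suite below):

sql("SELECT b, a FROM testData2 GROUP BY 1")
// The arm above resolves ordinal 1 against child.output: grouping on a.
// Postgres resolves it against the select list: grouping on b.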

// A special case for Generate, because the output of Generate should not be resolved by
// ResolveReferences. Attributes in the output will be resolved by ResolveGenerate.
case g @ Generate(generator, join, outer, qualifier, output, child)
@@ -24,6 +24,7 @@ import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.IntegerType

/**
* A pattern that matches any number of project or filter operations on top of another relational
@@ -202,3 +203,14 @@ object Unions {
}
}
}

/**
* Extractor that retrieves an Int value from an integer literal, including a negated literal.
*/
object IntegerIndex {
def unapply(a: Any): Option[Int] = a match {
case Literal(a: Int, IntegerType) => Some(a)
case UnaryMinus(IntegerLiteral(v)) => Some(-v)
Contributor: is it standard to support -(-1)? I see Postgres supports it, but it seems somewhat strange to me.

Member: This line is used for catching the illegal case:

      sql("SELECT * FROM testData2 ORDER BY -1 DESC, b ASC").collect()

I plan to keep it untouched in this PR. Thanks!

case _ => None
}
}
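For reference, a quick sketch of what the extractor accepts; the two expressions are hand-built with the standard Catalyst constructors:

import org.apache.spark.sql.catalyst.expressions.{Literal, UnaryMinus}

// A plain integer literal, e.g. the 1 in ORDER BY 1:
Literal(1) match { case IntegerIndex(i) => i }             // 1
// A parsed negation, e.g. the -1 in ORDER BY -1, which the analyzer
// then rejects as an out-of-range position:
UnaryMinus(Literal(1)) match { case IntegerIndex(i) => i } // -1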
57 changes: 47 additions & 10 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -21,10 +21,10 @@ import java.math.MathContext
import java.sql.Timestamp

import org.apache.spark.AccumulatorSuite
import org.apache.spark.sql.catalyst.CatalystQl
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.parser.ParserConf
import org.apache.spark.sql.execution.{aggregate, SparkQl}
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, UnresolvedException}
import org.apache.spark.sql.catalyst.expressions.SortOrder
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.execution.joins.{CartesianProduct, SortMergeJoin}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.functions._
@@ -461,22 +461,41 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}

test("literal in agg grouping expressions") {
intercept[UnresolvedException[Aggregate]] {
sql("SELECT * FROM testData2 GROUP BY -1").collect()
}

intercept[UnresolvedException[Aggregate]] {
sql("SELECT * FROM testData2 GROUP BY 3").collect()
}

checkAnswer(
sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"),
Seq(Row(1, 2), Row(2, 2), Row(3, 2)))
sql("SELECT SUM(a) FROM testData2 GROUP BY 2"),
Seq(Row(6), Row(6)))

checkAnswer(
sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"),
sql("SELECT a, SUM(b) FROM testData2 GROUP BY 1"),
Seq(Row(1, 3), Row(2, 3), Row(3, 3)))

checkAnswer(
sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"),
Seq(Row(1, 2), Row(2, 2), Row(3, 2)))

checkAnswer(
sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"),
Contributor: what is the meaning of 1 in this case?

Contributor (author): 1 refers to the first column in the GROUP BY, which is a in this case; it is kind of redundant here.

Contributor: If we have a table with a single column, what will happen if we call GROUP BY 1, 2?

Contributor (author): It would throw an exception like "Invalid call to Aggregate by position: 2 does not exist"; let me add that as a unit test.

sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a"))

checkAnswer(
sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"),
sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a"))
sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"),
Seq(Row(1, 1), Row(1, 1), Row(2, 1), Row(2, 1), Row(3, 1), Row(3, 1)))

checkAnswer(
sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"),
sql("SELECT 1, 2, sum(b) FROM testData2"))
sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY a, b"))

checkAnswer(
sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"),
sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a"))
}

test("aggregates with nulls") {
Expand All @@ -499,6 +518,24 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
Row("1"))
}

test("order by number index") {
intercept[UnresolvedException[SortOrder]] {
sql("SELECT * FROM testData2 ORDER BY -1 DESC, b ASC").collect()
}
intercept[UnresolvedException[SortOrder]] {
sql("SELECT * FROM testData2 ORDER BY 3 DESC, b ASC").collect()
}
checkAnswer(
sql("SELECT * FROM testData2 ORDER BY 1 DESC"),
Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2, 2), Row(1, 1), Row(1, 2)))
checkAnswer(
sql("SELECT * FROM testData2 ORDER BY 1 DESC, b ASC"),
sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC"))
checkAnswer(
sql("SELECT * FROM testData2 ORDER BY 1 ASC, b ASC"),
Seq(Row(1, 1), Row(1, 2), Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2)))
}

def sortTest(): Unit = {
checkAnswer(
sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC"),
Expand Down