@@ -136,6 +136,7 @@ class Analyzer(
ResolveGroupingAnalytics ::
ResolvePivot ::
ResolveOrdinalInOrderByAndGroupBy ::
ResolveAggAliasInGroupBy ::
Contributor:

is it safer to put it in an individual batch after the resolution batch? Ideally we should only run this rule if we make sure there is no other way to resolve the grouping expressions except this rule. cc @gatorsmile

Member Author:

One idea for putting this rule outside the resolution batch is to skip checking grouping expression resolution in Aggregate.resolved. But I feel this is a bit unsafe.

Contributor:

we have a postHocResolutionRules

Member Author (@maropu, Apr 26, 2017):

aha, ok. I'll move there and check. Thanks!

ResolveMissingReferences ::
ExtractGenerator ::
ResolveGenerate ::
@@ -172,7 +173,7 @@ class Analyzer(
* Analyze cte definitions and substitute child plan with analyzed cte definitions.
*/
object CTESubstitution extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case With(child, relations) =>
substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) {
case (resolved, (name, relation)) =>
@@ -200,7 +201,7 @@ class Analyzer(
* Substitute child plan with WindowSpecDefinitions.
*/
object WindowsSubstitution extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
// Lookup WindowSpecDefinitions. This rule works with unresolved children.
case WithWindowDefinition(windowDefinitions, child) =>
child.transform {
@@ -242,7 +243,7 @@ class Analyzer(
private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) =
exprs.exists(_.find(_.isInstanceOf[UnresolvedAlias]).isDefined)

def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) =>
Aggregate(groups, assignAliases(aggs), child)

@@ -614,7 +615,7 @@ class Analyzer(
case _ => plan
}

def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved =>
EliminateSubqueryAliases(lookupTableFromCatalog(u)) match {
case v: View =>
@@ -786,7 +787,7 @@ class Analyzer(
}
}

def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case p: LogicalPlan if !p.childrenResolved => p

// If the projection list contains Stars, expand it.
@@ -844,11 +845,10 @@ class Analyzer(

case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString}")
q transformExpressionsUp {
q.transformExpressionsUp {
case u @ UnresolvedAttribute(nameParts) =>
// Leave unchanged if resolution fails. Hopefully will be resolved next round.
val result =
withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
// Leave unchanged if resolution fails. Hopefully will be resolved next round.
val result = withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
Member (@gatorsmile, Mar 21, 2017):

Ideally, we should first reach here to resolve them using the children, and then resolve the remaining unresolved expressions in the grouping expressions using the aliases defined in the aggregate expressions.

Currently, the order is wrong, but the guards aggs.forall(_.resolved) and groups.exists(!_.resolved) save us. I think we might need to separate this out and add a new analyzer rule, ResolveAggAliasInGroupBy?

cc @cloud-fan what is your opinion?
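
For concreteness, the scenario under discussion is a query like the one below (a sketch assuming a session named spark and the testData view added in this PR's tests): the grouping key k exists only as an alias in the aggregate list, so it cannot be resolved from the child's output.

// k is defined only in the aggregate list, not in the child's output,
// so it must fall back to the alias a AS k.
spark.sql("SELECT a AS k, COUNT(b) FROM testData GROUP BY k")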

Member Author:

+1. If we add a new rule for that, ResolveGroupByAlias or something?

Contributor:

let's have a separate rule

Member Author:

okay, I'll update

logDebug(s"Resolving $u to $result")
result
case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
@@ -961,7 +961,7 @@ class Analyzer(
* have no effect on the results.
*/
object ResolveOrdinalInOrderByAndGroupBy extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case p if !p.childrenResolved => p
// Replace the index with the related attribute for ORDER BY,
// which is a 1-base position of the projection list.
@@ -997,6 +997,27 @@ class Analyzer(
}
}

/**
* Replace unresolved expressions in grouping keys with resolved ones in SELECT clauses.
Contributor:

add a comment to say that this rule has to be run after ResolveReferences

Member Author:

ok, I'll update

* This rule is expected to run after [[ResolveReferences]] is applied.
Contributor:

we can remove this now. With the new check, the order doesn't matter

Member Author:

ok

*/
object ResolveAggAliasInGroupBy extends Rule[LogicalPlan] {

override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case agg @ Aggregate(groups, aggs, child)
if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) &&
Contributor:

I'm not very confident about this condition; we may mistakenly resolve a grouping expression by the aggregate list while it should be resolved by the child output.

One example is the star. If the aggregate list contains a star, then we will expand the star in ResolveReferences without resolving grouping expressions. When we reach here, the condition will match, but the grouping expression should not be resolved by the aggregate list.

cc @gatorsmile

Member:

I think an UnresolvedAttribute in the grouping expressions here already implicitly indicates it is not in the child plan's output?

Member Author:

I'm not 100% sure though; since the resolution batch currently contains this rule, it seems this rule is applied first to the unresolved grouping keys in the star case @cloud-fan suggested.

Member:

I think an UnresolvedAttribute in the grouping expressions that can be resolved by the child's output should already have been resolved by ResolveReferences.

Member Author (@maropu, Apr 27, 2017):

I got debug output, and it seems ResolveAggAliasInGroupBy was applied first in that case:

scala> sql("SELECT *, count(1) FROM (select 1 AS a, 1 AS b) GROUP BY a, b").show

=== Applying Rule org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveReferences ===
!'Aggregate ['a, 'b], [*, unresolvedalias('count(1), None)]   'Aggregate ['a, 'b], [a#0, b#1, unresolvedalias('count(1), None)]
 +- Project [1 AS a#0, 1 AS b#1]                              +- Project [1 AS a#0, 1 AS b#1]
    +- OneRowRelation$                                           +- OneRowRelation$

<-- ResolveAggAliasInGroupBy applied -->

=== Applying Rule org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveFunctions ===
!'Aggregate ['a, 'b], [a#0, b#1, unresolvedalias('count(1), None)]   'Aggregate ['a, 'b], [a#0, b#1, unresolvedalias(count(1), None)]
 +- Project [1 AS a#0, 1 AS b#1]                                     +- Project [1 AS a#0, 1 AS b#1]
    +- OneRowRelation$                                                  +- OneRowRelation$

=== Applying Rule org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAliases ===
!'Aggregate ['a, 'b], [a#0, b#1, unresolvedalias(count(1), None)]   'Aggregate ['a, 'b], [a#0, b#1, count(1) AS count(1)#3L]
 +- Project [1 AS a#0, 1 AS b#1]                                    +- Project [1 AS a#0, 1 AS b#1]
    +- OneRowRelation$                                                 +- OneRowRelation$

=== Applying Rule org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveReferences ===
!'Aggregate ['a, 'b], [a#0, b#1, count(1) AS count(1)#3L]   Aggregate [a#0, b#1], [a#0, b#1, count(1) AS count(1)#3L]
 +- Project [1 AS a#0, 1 AS b#1]                            +- Project [1 AS a#0, 1 AS b#1]
    +- OneRowRelation$                                         +- OneRowRelation$

<-- ResolveAggAliasInGroupBy applied -->

Contributor:

OK, I have an idea: first, check that there is an UnresolvedAttribute in the grouping expressions; then check that these UnresolvedAttributes can't be resolved by child.output.

Member Author:

Isn't the latest fix enough for your suggestion? It checks whether an UnresolvedAttribute exists there, then filters by child.output.
https://github.com/apache/spark/pull/17191/files#diff-57b3d87be744b7d79a9beacf8e5e5eb2R1014

 +        // This is a strict check though, we put this to apply the rule only in alias expressions
 +        def checkIfChildOutputHasNo(attrName: String): Boolean =
 +          !child.output.exists(a => resolver(a.name, attrName))
 +        agg.copy(groupingExpressions = groups.map {
 +          case u: UnresolvedAttribute if checkIfChildOutputHasNo(u.name) =>
 +            aggs.find(ne => resolver(ne.name, u.name)).getOrElse(u)
 +          case e => e
 +        })

Contributor:

yea it works, but maybe give it a better name like notResolvableByChild

Member Author:

ok

groups.exists(_.isInstanceOf[UnresolvedAttribute]) =>
// This is a strict check though, we put this to apply the rule only in alias expressions
def notResolvableByChild(attrName: String): Boolean =
!child.output.exists(a => resolver(a.name, attrName))
agg.copy(groupingExpressions = groups.map {
case u: UnresolvedAttribute if notResolvableByChild(u.name) =>
aggs.find(ne => resolver(ne.name, u.name)).getOrElse(u)
case e => e
})
}
}
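
To see what the rule's substitution amounts to, here is a minimal standalone sketch (string-based stand-ins for Catalyst expressions, invented for illustration; this is not Spark's API):

// A named expression plays the role of an alias in the aggregate list.
case class NamedExpr(name: String, expr: String)

// Replace grouping keys that are not in the child's output with the
// expression of a matching alias from the aggregate list, if any.
def substituteGroupingKeys(
    groups: Seq[String],
    aggs: Seq[NamedExpr],
    childOutput: Set[String]): Seq[String] =
  groups.map { g =>
    if (!childOutput.contains(g)) {
      aggs.find(_.name == g).map(_.expr).getOrElse(g)
    } else g
  }

// substituteGroupingKeys(Seq("k"), Seq(NamedExpr("k", "a")), Set("a", "b"))
// returns Seq("a"): the grouping key k resolves to the aliased expression a.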

/**
* In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT
* clause. This rule detects such queries and adds the required attributes to the original
@@ -1006,7 +1027,7 @@ class Analyzer(
* The HAVING clause could also use a grouping column that is not present in the SELECT.
*/
object ResolveMissingReferences extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
// Skip sort with aggregate. This will be handled in ResolveAggregateFunctions
case sa @ Sort(_, _, child: Aggregate) => sa

@@ -1130,7 +1151,7 @@ class Analyzer(
* Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s.
*/
object ResolveFunctions extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case q: LogicalPlan =>
q transformExpressions {
case u if !u.childrenResolved => u // Skip until children are resolved.
@@ -1469,7 +1490,7 @@ class Analyzer(
/**
* Resolve and rewrite all subqueries in an operator tree.
*/
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
// In case of HAVING (a filter after an aggregate) we use both the aggregate and
// its child for resolution.
case f @ Filter(_, a: Aggregate) if f.childrenResolved =>
@@ -1484,7 +1505,7 @@ class Analyzer(
* Turns projections that contain aggregate expressions into aggregations.
*/
object GlobalAggregates extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case Project(projectList, child) if containsAggregates(projectList) =>
Aggregate(Nil, projectList, child)
}
@@ -1510,7 +1531,7 @@ class Analyzer(
* underlying aggregate operator and then projected away after the original operator.
*/
object ResolveAggregateFunctions extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case filter @ Filter(havingCondition,
aggregate @ Aggregate(grouping, originalAggExprs, child))
if aggregate.resolved =>
@@ -1682,7 +1703,7 @@ class Analyzer(
}
}

def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case Project(projectList, _) if projectList.exists(hasNestedGenerator) =>
val nestedGenerator = projectList.find(hasNestedGenerator).get
throw new AnalysisException("Generators are not supported when it's nested in " +
@@ -1740,7 +1761,7 @@ class Analyzer(
* that wrap the [[Generator]].
*/
object ResolveGenerate extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case g: Generate if !g.child.resolved || !g.generator.resolved => g
case g: Generate if !g.resolved =>
g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name)))
@@ -2057,7 +2078,7 @@ class Analyzer(
* put them into an inner Project and finally project them away at the outer Project.
*/
object PullOutNondeterministic extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case p if !p.resolved => p // Skip unresolved nodes.
case p: Project => p
case f: Filter => f
@@ -2102,7 +2123,7 @@ class Analyzer(
* and we should return null if the input is null.
*/
object HandleNullInputsForUDF extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case p if !p.resolved => p // Skip unresolved nodes.

case p => p transformExpressionsUp {
@@ -2167,7 +2188,7 @@ class Analyzer(
* Then apply a Project on a normal Join to eliminate natural or using join.
*/
object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case j @ Join(left, right, UsingJoin(joinType, usingCols), condition)
if left.resolved && right.resolved && j.duplicateResolved =>
commonNaturalJoinProcessing(left, right, joinType, usingCols, None)
@@ -2232,7 +2253,7 @@ class Analyzer(
* to the given input attributes.
*/
object ResolveDeserializer extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case p if !p.childrenResolved => p
case p if p.resolved => p

@@ -2318,7 +2339,7 @@ class Analyzer(
* constructed is an inner class.
*/
object ResolveNewInstance extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case p if !p.childrenResolved => p
case p if p.resolved => p

@@ -2352,7 +2373,7 @@ class Analyzer(
"type of the field in the target object")
}

def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case p if !p.childrenResolved => p
case p if p.resolved => p

@@ -2406,7 +2427,7 @@ object CleanupAliases extends Rule[LogicalPlan] {
case other => trimAliases(other)
}

override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case Project(projectList, child) =>
val cleanedProjectList =
projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
@@ -2474,7 +2495,7 @@ object TimeWindowing extends Rule[LogicalPlan] {
* @return the logical plan that will generate the time windows using the Expand operator, with
* the Filter operator for correctness and Project for usability.
*/
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case p: LogicalPlan if p.children.size == 1 =>
val child = p.children.head
val windowExpressions =
@@ -421,6 +421,12 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val GROUP_BY_ALIASES = buildConf("spark.sql.groupByAliases")
.doc("When true, aliases in a select list can be used in group by clauses. When false, " +
"an analysis exception is thrown in the case.")
.booleanConf
.createWithDefault(true)
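
As a usage sketch (the local session setup and the testData view are assumptions for illustration, not part of this diff), the new flag can be toggled like any other SQL conf:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local").getOrCreate()
// With the flag off, a GROUP BY over a SELECT alias fails analysis.
spark.conf.set("spark.sql.groupByAliases", "false")
spark.sql("SELECT a AS k, COUNT(b) FROM testData GROUP BY k") // AnalysisException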

// The output committer class used by data sources. The specified class needs to be a
// subclass of org.apache.hadoop.mapreduce.OutputCommitter.
val OUTPUT_COMMITTER_CLASS =
@@ -1003,6 +1009,8 @@ class SQLConf extends Serializable with Logging {

def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL)

def groupByAliases: Boolean = getConf(GROUP_BY_ALIASES)

def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED)

def sessionLocalTimeZone: String = getConf(SQLConf.SESSION_LOCAL_TIMEZONE)
@@ -49,6 +49,9 @@ select a, count(a) from (select 1 as a) tmp group by 1 order by 1;
-- group by ordinal followed by having
select count(a), a from (select 1 as a) tmp group by 2 having a > 0;

-- mixed cases: group-by ordinals and aliases
select a, a AS k, count(b) from data group by k, 1;

-- turn off group by ordinal
set spark.sql.groupByOrdinal=false;

18 changes: 18 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/group-by.sql
@@ -35,3 +35,21 @@ FROM testData;

-- Aggregate with foldable input and multiple distinct groups.
SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a;

-- Aliases in SELECT could be used in GROUP BY
SELECT a AS k, COUNT(b) FROM testData GROUP BY k;
SELECT a AS k, COUNT(b) FROM testData GROUP BY k HAVING k > 1;

-- Aggregate functions cannot be used in GROUP BY
SELECT COUNT(b) AS k FROM testData GROUP BY k;

-- Test data.
CREATE OR REPLACE TEMPORARY VIEW testDataHasSameNameWithAlias AS SELECT * FROM VALUES
(1, 1, 3), (1, 2, 1) AS testDataHasSameNameWithAlias(k, a, v);
SELECT k AS a, COUNT(v) FROM testDataHasSameNameWithAlias GROUP BY a;

-- turn off group by aliases
set spark.sql.groupByAliases=false;

-- Check analysis exceptions
SELECT a AS k, COUNT(b) FROM testData GROUP BY k;
Member:

Also add an extra EXPLAIN so we can check whether the plan is correct:

EXPLAIN SELECT a AS k, COUNT(b) FROM testData GROUP BY k

@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 19
-- Number of queries: 20


-- !query 0
Expand Down Expand Up @@ -173,16 +173,26 @@ struct<count(a):bigint,a:int>


-- !query 17
set spark.sql.groupByOrdinal=false
select a, a AS k, count(b) from data group by k, 1
-- !query 17 schema
struct<key:string,value:string>
struct<a:int,k:int,count(b):bigint>
-- !query 17 output
spark.sql.groupByOrdinal false
1 1 2
2 2 2
3 3 2


-- !query 18
select sum(b) from data group by -1
set spark.sql.groupByOrdinal=false
-- !query 18 schema
struct<sum(b):bigint>
struct<key:string,value:string>
-- !query 18 output
spark.sql.groupByOrdinal false


-- !query 19
select sum(b) from data group by -1
-- !query 19 schema
struct<sum(b):bigint>
-- !query 19 output
9