Skip to content

Commit 616a78a

Browse files
committed
[SPARK-18969][SQL] Support grouping by nondeterministic expressions
## What changes were proposed in this pull request?
Currently, nondeterministic expressions are allowed in `Aggregate` (see the [comment](https://github.com/apache/spark/blob/v2.0.2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala#L249-L251)), but the `PullOutNondeterministic` analyzer rule fails to handle `Aggregate`; this PR fixes it. Closes #16379. There is still one remaining issue: `SELECT a + rand() FROM t GROUP BY a + rand()` is not allowed, because the two `rand()` calls are different (a random seed is generated as the default seed for each `rand()`). https://issues.apache.org/jira/browse/SPARK-19035 is tracking this issue.
## How was this patch tested?
A new test suite. Author: Wenchen Fan <[email protected]> Closes #16404 from cloud-fan/groupby. (cherry picked from commit 871d266) Signed-off-by: Wenchen Fan <[email protected]>
1 parent 9b9867e commit 616a78a

File tree

3 files changed

+86
-17
lines changed

3 files changed

+86
-17
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1859,28 +1859,37 @@ class Analyzer(
18591859
case p: Project => p
18601860
case f: Filter => f
18611861

1862+
case a: Aggregate if a.groupingExpressions.exists(!_.deterministic) =>
1863+
val nondeterToAttr = getNondeterToAttr(a.groupingExpressions)
1864+
val newChild = Project(a.child.output ++ nondeterToAttr.values, a.child)
1865+
a.transformExpressions { case e =>
1866+
nondeterToAttr.get(e).map(_.toAttribute).getOrElse(e)
1867+
}.copy(child = newChild)
1868+
18621869
// todo: It's hard to write a general rule to pull out nondeterministic expressions
18631870
// from LogicalPlan, currently we only do it for UnaryNode which has same output
18641871
// schema with its child.
18651872
case p: UnaryNode if p.output == p.child.output && p.expressions.exists(!_.deterministic) =>
1866-
val nondeterministicExprs = p.expressions.filterNot(_.deterministic).flatMap { expr =>
1867-
val leafNondeterministic = expr.collect {
1868-
case n: Nondeterministic => n
1869-
}
1870-
leafNondeterministic.map { e =>
1871-
val ne = e match {
1872-
case n: NamedExpression => n
1873-
case _ => Alias(e, "_nondeterministic")(isGenerated = true)
1874-
}
1875-
new TreeNodeRef(e) -> ne
1876-
}
1877-
}.toMap
1873+
val nondeterToAttr = getNondeterToAttr(p.expressions)
18781874
val newPlan = p.transformExpressions { case e =>
1879-
nondeterministicExprs.get(new TreeNodeRef(e)).map(_.toAttribute).getOrElse(e)
1875+
nondeterToAttr.get(e).map(_.toAttribute).getOrElse(e)
18801876
}
1881-
val newChild = Project(p.child.output ++ nondeterministicExprs.values, p.child)
1877+
val newChild = Project(p.child.output ++ nondeterToAttr.values, p.child)
18821878
Project(p.output, newPlan.withNewChildren(newChild :: Nil))
18831879
}
1880+
1881+
/**
 * Builds a lookup from each [[Nondeterministic]] node found inside the nondeterministic
 * members of `exprs` to the named expression that will stand in for it.
 *
 * A node that is already a [[NamedExpression]] maps to itself; any other node is wrapped
 * in a generated [[Alias]] called `_nondeterministic`. Deterministic members of `exprs`
 * are ignored, and repeated nodes within one expression are visited only once.
 *
 * @param exprs expressions to scan
 * @return map from nondeterministic node to its named replacement
 */
private def getNondeterToAttr(exprs: Seq[Expression]): Map[Expression, NamedExpression] = {
  val replacements = for {
    nondeterExpr <- exprs.filterNot(_.deterministic)
    // Only the distinct Nondeterministic nodes of this expression need a replacement.
    leaf <- nondeterExpr.collect { case n: Nondeterministic => n }.distinct
  } yield {
    val named = leaf match {
      case alreadyNamed: NamedExpression => alreadyNamed
      case _ => Alias(leaf, "_nondeterministic")(isGenerated = true)
    }
    leaf -> named
  }
  replacements.toMap
}
18841893
}
18851894

18861895
/**
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.analysis
19+
20+
import org.apache.spark.sql.catalyst.dsl.expressions._
21+
import org.apache.spark.sql.catalyst.dsl.plans._
22+
import org.apache.spark.sql.catalyst.expressions._
23+
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
24+
25+
/**
26+
* Test suite for moving non-deterministic expressions into Project.
27+
*/
28+
class PullOutNondeterministicSuite extends AnalysisTest {

  // Two integer columns and the local relation exposing them.
  private lazy val colA = 'a.int
  private lazy val colB = 'b.int
  private lazy val relation = LocalRelation(colA, colB)

  // `Rand(10)` aliased as `_nondeterministic`, plus the attribute referring to that alias.
  private lazy val aliasedRand = Rand(10).as('_nondeterministic)
  private lazy val randAttr = aliasedRand.toAttribute

  test("no-op on filter") {
    // The expected plan is built exactly like the input plan.
    def filtered = relation.where(Rand(10) > Literal(1.0))
    checkAnalysis(filtered, filtered)
  }

  test("sort") {
    val input = relation.sortBy(SortOrder(Rand(10), Ascending))
    // Expected: project the aliased rand, sort on its attribute, then project a and b only.
    val expected = relation
      .select(colA, colB, aliasedRand)
      .sortBy(SortOrder(randAttr, Ascending))
      .select(colA, colB)
    checkAnalysis(input, expected)
  }

  test("aggregate") {
    val input = relation.groupBy(Rand(10))(Rand(10).as("rnd"))
    // Expected: project the aliased rand, then group by and output its attribute.
    val expected = relation
      .select(colA, colB, aliasedRand)
      .groupBy(randAttr)(randAttr.as("rnd"))
    checkAnalysis(input, expected)
  }
}

sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,14 @@ GROUP BY position 3 is an aggregate function, and aggregate functions are not al
137137
-- !query 13
138138
select a, rand(0), sum(b) from data group by a, 2
139139
-- !query 13 schema
140-
struct<>
140+
struct<a:int,rand(0):double,sum(b):bigint>
141141
-- !query 13 output
142-
org.apache.spark.sql.AnalysisException
143-
nondeterministic expression rand(0) should not appear in grouping expression.;
142+
1 0.4048454303385226 2
143+
1 0.8446490682263027 1
144+
2 0.5871875724155838 1
145+
2 0.8865128837019473 2
146+
3 0.742083829230211 1
147+
3 0.9179913208300406 2
144148

145149

146150
-- !query 14

0 commit comments

Comments
 (0)