Skip to content

Commit 30dbdc6

Browse files
committed
Push predicates through Expand
1 parent d29e429 commit 30dbdc6

File tree

5 files changed

+28
-14
lines changed

5 files changed

+28
-14
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,9 @@ class Analyzer(
300300
a.toAttribute.withNullability((nonNullBitmask & 1 << idx) == 0)
301301
}
302302

303+
val expand = Expand(x.bitmasks, groupByAliases, groupByAttributes, gid, x.child)
304+
val finalGroupingAttrs = expand.output.drop(x.child.output.length)
305+
303306
val aggregations: Seq[NamedExpression] = x.aggregations.map { case expr =>
304307
// collect all the found AggregateExpression, so we can check an expression is part of
305308
// any AggregateExpression or not.
@@ -321,15 +324,12 @@ class Analyzer(
321324
if (index == -1) {
322325
e
323326
} else {
324-
groupByAttributes(index)
327+
finalGroupingAttrs(index)
325328
}
326329
}.asInstanceOf[NamedExpression]
327330
}
328331

329-
Aggregate(
330-
groupByAttributes :+ gid,
331-
aggregations,
332-
Expand(x.bitmasks, groupByAliases, groupByAttributes, gid, x.child))
332+
Aggregate(finalGroupingAttrs, aggregations, expand)
333333

334334
case f @ Filter(cond, child) if hasGroupingFunction(cond) =>
335335
val groupingExprs = findGroupingExprs(child)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1019,8 +1019,6 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
10191019
case filter @ Filter(_, f: Filter) => filter
10201020
// should not push predicates through sample, or will generate different results.
10211021
case filter @ Filter(_, s: Sample) => filter
1022-
// TODO: push predicates through expand
1023-
case filter @ Filter(_, e: Expand) => filter
10241022

10251023
case filter @ Filter(condition, u: UnaryNode) if u.expressions.forall(_.deterministic) =>
10261024
pushDownPredicate(filter, u.child) { predicate =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,10 @@ private[sql] object Expand {
516516
// groupingId is the last output, here we use the bit mask as the concrete value for it.
517517
} :+ Literal.create(bitmask, IntegerType)
518518
}
519-
val output = child.output ++ groupByAttrs :+ gid
519+
520+
// `groupByAttrs` has a different meaning in `Expand.output`: each one could be the original
521+
// grouping expression or null, so here we create a new instance of each attribute.
522+
val output = child.output ++ groupByAttrs.map(_.newInstance) :+ gid
520523
Expand(projections, output, Project(child.output ++ groupByAliases, child))
521524
}
522525
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -743,4 +743,19 @@ class FilterPushdownSuite extends PlanTest {
743743

744744
comparePlans(optimized, correctAnswer)
745745
}
746+
747+
test("expand") {
748+
val agg = testRelation
749+
.groupBy(Cube(Seq('a, 'b)))('a, 'b, sum('c))
750+
.analyze
751+
.asInstanceOf[Aggregate]
752+
753+
val a = agg.output(0)
754+
val b = agg.output(1)
755+
756+
val query = agg.where(a > 1 && b > 2)
757+
val optimized = Optimize.execute(query)
758+
val correctedAnswer = agg.copy(child = agg.child.where(a > 1 && b > 2)).analyze
759+
comparePlans(optimized, correctedAnswer)
760+
}
746761
}

sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -288,8 +288,9 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
288288

289289
private def isGroupingSet(a: Aggregate, e: Expand, p: Project): Boolean = {
290290
assert(a.child == e && e.child == p)
291-
a.groupingExpressions.forall(_.isInstanceOf[Attribute]) &&
292-
sameOutput(e.output, p.child.output ++ a.groupingExpressions.map(_.asInstanceOf[Attribute]))
291+
a.groupingExpressions.forall(_.isInstanceOf[Attribute]) && sameOutput(
292+
e.output.drop(p.child.output.length),
293+
a.groupingExpressions.map(_.asInstanceOf[Attribute]))
293294
}
294295

295296
private def groupingSetToSQL(
@@ -302,13 +303,10 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
302303
val gid = expand.output.last
303304

304305
val numOriginalOutput = project.child.output.length
305-
// Assumption: Aggregate's groupingExpressions is composed of
306-
// 1) the attributes of aliased group by expressions
307-
// 2) gid, which is always the last one
308-
val groupByAttributes = agg.groupingExpressions.dropRight(1).map(_.asInstanceOf[Attribute])
309306
// Assumption: Project's projectList is composed of
310307
// 1) the original output (Project's child.output),
311308
// 2) the aliased group by expressions.
309+
val groupByAttributes = project.output.drop(numOriginalOutput)
312310
val groupByExprs = project.projectList.drop(numOriginalOutput).map(_.asInstanceOf[Alias].child)
313311
val groupingSQL = groupByExprs.map(_.sql).mkString(", ")
314312

0 commit comments

Comments
 (0)