Skip to content

Commit d563c8f

Browse files
committed
Revert "[SPARK-13376] [SQL] improve column pruning"
This reverts commit e9533b4.
1 parent 382b27b commit d563c8f

File tree

4 files changed

+156
-187
lines changed

4 files changed

+156
-187
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 70 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -313,85 +313,97 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
313313
*/
314314
object ColumnPruning extends Rule[LogicalPlan] {
315315
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
316-
// Prunes the unused columns from project list of Project/Aggregate/Window/Expand
317-
case p @ Project(_, p2: Project) if (p2.outputSet -- p.references).nonEmpty =>
318-
p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains)))
319-
case p @ Project(_, a: Aggregate) if (a.outputSet -- p.references).nonEmpty =>
320-
p.copy(
321-
child = a.copy(aggregateExpressions = a.aggregateExpressions.filter(p.references.contains)))
322-
case p @ Project(_, w: Window) if (w.outputSet -- p.references).nonEmpty =>
323-
p.copy(child = w.copy(
324-
projectList = w.projectList.filter(p.references.contains),
325-
windowExpressions = w.windowExpressions.filter(p.references.contains)))
326-
case a @ Project(_, e @ Expand(_, _, grandChild)) if (e.outputSet -- a.references).nonEmpty =>
327-
val newOutput = e.output.filter(a.references.contains(_))
328-
val newProjects = e.projections.map { proj =>
329-
proj.zip(e.output).filter { case (e, a) =>
316+
case a @ Aggregate(_, _, e @ Expand(projects, output, child))
317+
if (e.outputSet -- a.references).nonEmpty =>
318+
val newOutput = output.filter(a.references.contains(_))
319+
val newProjects = projects.map { proj =>
320+
proj.zip(output).filter { case (e, a) =>
330321
newOutput.contains(a)
331322
}.unzip._1
332323
}
333-
a.copy(child = Expand(newProjects, newOutput, grandChild))
334-
// TODO: support some logical plan for Dataset
324+
a.copy(child = Expand(newProjects, newOutput, child))
335325

336-
// Prunes the unused columns from child of Aggregate/Window/Expand/Generate
326+
case a @ Aggregate(_, _, e @ Expand(_, _, child))
327+
if (child.outputSet -- e.references -- a.references).nonEmpty =>
328+
a.copy(child = e.copy(child = prunedChild(child, e.references ++ a.references)))
329+
330+
// Eliminate attributes that are not needed to calculate the specified aggregates.
337331
case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
338-
a.copy(child = prunedChild(child, a.references))
339-
case w @ Window(_, _, _, _, child) if (child.outputSet -- w.references).nonEmpty =>
340-
w.copy(child = prunedChild(child, w.references))
341-
case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty =>
342-
e.copy(child = prunedChild(child, e.references))
332+
a.copy(child = Project(a.references.toSeq, child))
333+
334+
// Eliminate attributes that are not needed to calculate the Generate.
343335
case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty =>
344-
g.copy(child = prunedChild(g.child, g.references))
336+
g.copy(child = Project(g.references.toSeq, g.child))
345337

346-
// Turn off `join` for Generate if no column from its child is used
347338
case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) =>
348339
p.copy(child = g.copy(join = false))
349340

350-
// Eliminate unneeded attributes from right side of a LeftSemiJoin.
351-
case j @ Join(left, right, LeftSemi, condition) =>
352-
j.copy(right = prunedChild(right, j.references))
353-
354-
// all the columns will be used to compare, so we can't prune them
355-
case p @ Project(_, _: SetOperation) => p
356-
case p @ Project(_, _: Distinct) => p
357-
// Eliminate unneeded attributes from children of Union.
358-
case p @ Project(_, u: Union) =>
359-
if ((u.outputSet -- p.references).nonEmpty) {
360-
val firstChild = u.children.head
361-
val newOutput = prunedChild(firstChild, p.references).output
362-
// pruning the columns of all children based on the pruned first child.
363-
val newChildren = u.children.map { p =>
364-
val selected = p.output.zipWithIndex.filter { case (a, i) =>
365-
newOutput.contains(firstChild.output(i))
366-
}.map(_._1)
367-
Project(selected, p)
368-
}
369-
p.copy(child = u.withNewChildren(newChildren))
370-
} else {
341+
case p @ Project(projectList, g: Generate) if g.join =>
342+
val neededChildOutput = p.references -- g.generatorOutput ++ g.references
343+
if (neededChildOutput == g.child.outputSet) {
371344
p
345+
} else {
346+
Project(projectList, g.copy(child = Project(neededChildOutput.toSeq, g.child)))
372347
}
373348

374-
// Can't prune the columns on LeafNode
375-
case p @ Project(_, l: LeafNode) => p
349+
case p @ Project(projectList, a @ Aggregate(groupingExpressions, aggregateExpressions, child))
350+
if (a.outputSet -- p.references).nonEmpty =>
351+
Project(
352+
projectList,
353+
Aggregate(
354+
groupingExpressions,
355+
aggregateExpressions.filter(e => p.references.contains(e)),
356+
child))
376357

377-
// Eliminate no-op Projects
378-
case p @ Project(projectList, child) if child.output == p.output => child
379-
380-
// for all other logical plans that inherit the output from their children
381-
case p @ Project(_, child) =>
382-
val required = child.references ++ p.references
383-
if ((child.inputSet -- required).nonEmpty) {
384-
val newChildren = child.children.map(c => prunedChild(c, required))
385-
p.copy(child = child.withNewChildren(newChildren))
358+
// Eliminate unneeded attributes from either side of a Join.
359+
case Project(projectList, Join(left, right, joinType, condition)) =>
360+
// Collect the list of all references required either above or to evaluate the condition.
361+
val allReferences: AttributeSet =
362+
AttributeSet(
363+
projectList.flatMap(_.references.iterator)) ++
364+
condition.map(_.references).getOrElse(AttributeSet(Seq.empty))
365+
366+
/** Applies a projection only when the child is producing unnecessary attributes */
367+
def pruneJoinChild(c: LogicalPlan): LogicalPlan = prunedChild(c, allReferences)
368+
369+
Project(projectList, Join(pruneJoinChild(left), pruneJoinChild(right), joinType, condition))
370+
371+
// Eliminate unneeded attributes from right side of a LeftSemiJoin.
372+
case Join(left, right, LeftSemi, condition) =>
373+
// Collect the list of all references required to evaluate the condition.
374+
val allReferences: AttributeSet =
375+
condition.map(_.references).getOrElse(AttributeSet(Seq.empty))
376+
377+
Join(left, prunedChild(right, allReferences), LeftSemi, condition)
378+
379+
// Push down project through limit, so that we may have chance to push it further.
380+
case Project(projectList, Limit(exp, child)) =>
381+
Limit(exp, Project(projectList, child))
382+
383+
// Push down project if possible when the child is sort.
384+
case p @ Project(projectList, s @ Sort(_, _, grandChild)) =>
385+
if (s.references.subsetOf(p.outputSet)) {
386+
s.copy(child = Project(projectList, grandChild))
386387
} else {
387-
p
388+
val neededReferences = s.references ++ p.references
389+
if (neededReferences == grandChild.outputSet) {
390+
// No column we can prune, return the original plan.
391+
p
392+
} else {
393+
// Do not use neededReferences.toSeq directly; the new project list must respect grandChild's output order.
394+
val newProjectList = grandChild.output.filter(neededReferences.contains)
395+
p.copy(child = s.copy(child = Project(newProjectList, grandChild)))
396+
}
388397
}
398+
399+
// Eliminate no-op Projects
400+
case Project(projectList, child) if child.output == projectList => child
389401
}
390402

391403
/** Applies a projection only when the child is producing unnecessary attributes */
392404
private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) =
393405
if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) {
394-
Project(c.output.filter(allReferences.contains), c)
406+
Project(allReferences.filter(c.outputSet.contains).toSeq, c)
395407
} else {
396408
c
397409
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala

Lines changed: 2 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,9 @@
1717

1818
package org.apache.spark.sql.catalyst.optimizer
1919

20-
import org.apache.spark.sql.catalyst.analysis
2120
import org.apache.spark.sql.catalyst.dsl.expressions._
2221
import org.apache.spark.sql.catalyst.dsl.plans._
23-
import org.apache.spark.sql.catalyst.expressions.{Ascending, Explode, Literal, SortOrder}
22+
import org.apache.spark.sql.catalyst.expressions.{Explode, Literal}
2423
import org.apache.spark.sql.catalyst.plans.PlanTest
2524
import org.apache.spark.sql.catalyst.plans.logical._
2625
import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -120,134 +119,11 @@ class ColumnPruningSuite extends PlanTest {
120119
Seq('c, Literal.create(null, StringType), 1),
121120
Seq('c, 'a, 2)),
122121
Seq('c, 'aa.int, 'gid.int),
123-
Project(Seq('a, 'c),
122+
Project(Seq('c, 'a),
124123
input))).analyze
125124

126125
comparePlans(optimized, expected)
127126
}
128127

129-
test("Column pruning on Filter") {
130-
val input = LocalRelation('a.int, 'b.string, 'c.double)
131-
val query = Project('a :: Nil, Filter('c > Literal(0.0), input)).analyze
132-
val expected =
133-
Project('a :: Nil,
134-
Filter('c > Literal(0.0),
135-
Project(Seq('a, 'c), input))).analyze
136-
comparePlans(Optimize.execute(query), expected)
137-
}
138-
139-
test("Column pruning on except/intersect/distinct") {
140-
val input = LocalRelation('a.int, 'b.string, 'c.double)
141-
val query = Project('a :: Nil, Except(input, input)).analyze
142-
comparePlans(Optimize.execute(query), query)
143-
144-
val query2 = Project('a :: Nil, Intersect(input, input)).analyze
145-
comparePlans(Optimize.execute(query2), query2)
146-
val query3 = Project('a :: Nil, Distinct(input)).analyze
147-
comparePlans(Optimize.execute(query3), query3)
148-
}
149-
150-
test("Column pruning on Project") {
151-
val input = LocalRelation('a.int, 'b.string, 'c.double)
152-
val query = Project('a :: Nil, Project(Seq('a, 'b), input)).analyze
153-
val expected = Project(Seq('a), input).analyze
154-
comparePlans(Optimize.execute(query), expected)
155-
}
156-
157-
test("column pruning for group") {
158-
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
159-
val originalQuery =
160-
testRelation
161-
.groupBy('a)('a, count('b))
162-
.select('a)
163-
164-
val optimized = Optimize.execute(originalQuery.analyze)
165-
val correctAnswer =
166-
testRelation
167-
.select('a)
168-
.groupBy('a)('a).analyze
169-
170-
comparePlans(optimized, correctAnswer)
171-
}
172-
173-
test("column pruning for group with alias") {
174-
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
175-
176-
val originalQuery =
177-
testRelation
178-
.groupBy('a)('a as 'c, count('b))
179-
.select('c)
180-
181-
val optimized = Optimize.execute(originalQuery.analyze)
182-
val correctAnswer =
183-
testRelation
184-
.select('a)
185-
.groupBy('a)('a as 'c).analyze
186-
187-
comparePlans(optimized, correctAnswer)
188-
}
189-
190-
test("column pruning for Project(ne, Limit)") {
191-
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
192-
193-
val originalQuery =
194-
testRelation
195-
.select('a, 'b)
196-
.limit(2)
197-
.select('a)
198-
199-
val optimized = Optimize.execute(originalQuery.analyze)
200-
val correctAnswer =
201-
testRelation
202-
.select('a)
203-
.limit(2).analyze
204-
205-
comparePlans(optimized, correctAnswer)
206-
}
207-
208-
test("push down project past sort") {
209-
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
210-
val x = testRelation.subquery('x)
211-
212-
// push down valid
213-
val originalQuery = {
214-
x.select('a, 'b)
215-
.sortBy(SortOrder('a, Ascending))
216-
.select('a)
217-
}
218-
219-
val optimized = Optimize.execute(originalQuery.analyze)
220-
val correctAnswer =
221-
x.select('a)
222-
.sortBy(SortOrder('a, Ascending)).analyze
223-
224-
comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer))
225-
226-
// push down invalid
227-
val originalQuery1 = {
228-
x.select('a, 'b)
229-
.sortBy(SortOrder('a, Ascending))
230-
.select('b)
231-
}
232-
233-
val optimized1 = Optimize.execute(originalQuery1.analyze)
234-
val correctAnswer1 =
235-
x.select('a, 'b)
236-
.sortBy(SortOrder('a, Ascending))
237-
.select('b).analyze
238-
239-
comparePlans(optimized1, analysis.EliminateSubqueryAliases(correctAnswer1))
240-
}
241-
242-
test("Column pruning on Union") {
243-
val input1 = LocalRelation('a.int, 'b.string, 'c.double)
244-
val input2 = LocalRelation('c.int, 'd.string, 'e.double)
245-
val query = Project('b :: Nil,
246-
Union(input1 :: input2 :: Nil)).analyze
247-
val expected = Project('b :: Nil,
248-
Union(Project('b :: Nil, input1) :: Project('d :: Nil, input2) :: Nil)).analyze
249-
comparePlans(Optimize.execute(query), expected)
250-
}
251-
252128
// todo: add more tests for column pruning
253129
}

0 commit comments

Comments
 (0)