@@ -313,85 +313,97 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
313313 */
314314object ColumnPruning extends Rule [LogicalPlan ] {
315315 def apply (plan : LogicalPlan ): LogicalPlan = plan transform {
316- // Prunes the unused columns from project list of Project/Aggregate/Window/Expand
317- case p @ Project (_, p2 : Project ) if (p2.outputSet -- p.references).nonEmpty =>
318- p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains)))
319- case p @ Project (_, a : Aggregate ) if (a.outputSet -- p.references).nonEmpty =>
320- p.copy(
321- child = a.copy(aggregateExpressions = a.aggregateExpressions.filter(p.references.contains)))
322- case p @ Project (_, w : Window ) if (w.outputSet -- p.references).nonEmpty =>
323- p.copy(child = w.copy(
324- projectList = w.projectList.filter(p.references.contains),
325- windowExpressions = w.windowExpressions.filter(p.references.contains)))
326- case a @ Project (_, e @ Expand (_, _, grandChild)) if (e.outputSet -- a.references).nonEmpty =>
327- val newOutput = e.output.filter(a.references.contains(_))
328- val newProjects = e.projections.map { proj =>
329- proj.zip(e.output).filter { case (e, a) =>
316+ case a @ Aggregate (_, _, e @ Expand (projects, output, child))
317+ if (e.outputSet -- a.references).nonEmpty =>
318+ val newOutput = output.filter(a.references.contains(_))
319+ val newProjects = projects.map { proj =>
320+ proj.zip(output).filter { case (e, a) =>
330321 newOutput.contains(a)
331322 }.unzip._1
332323 }
333- a.copy(child = Expand (newProjects, newOutput, grandChild))
334- // TODO: support some logical plan for Dataset
324+ a.copy(child = Expand (newProjects, newOutput, child))
335325
336- // Prunes the unused columns from child of Aggregate/Window/Expand/Generate
326+ case a @ Aggregate (_, _, e @ Expand (_, _, child))
327+ if (child.outputSet -- e.references -- a.references).nonEmpty =>
328+ a.copy(child = e.copy(child = prunedChild(child, e.references ++ a.references)))
329+
330+ // Eliminate attributes that are not needed to calculate the specified aggregates.
337331 case a @ Aggregate (_, _, child) if (child.outputSet -- a.references).nonEmpty =>
338- a.copy(child = prunedChild(child, a.references))
339- case w @ Window (_, _, _, _, child) if (child.outputSet -- w.references).nonEmpty =>
340- w.copy(child = prunedChild(child, w.references))
341- case e @ Expand (_, _, child) if (child.outputSet -- e.references).nonEmpty =>
342- e.copy(child = prunedChild(child, e.references))
332+ a.copy(child = Project (a.references.toSeq, child))
333+
334+ // Eliminate attributes that are not needed to calculate the Generate.
343335 case g : Generate if ! g.join && (g.child.outputSet -- g.references).nonEmpty =>
344- g.copy(child = prunedChild (g.child , g.references ))
336+ g.copy(child = Project (g.references.toSeq , g.child ))
345337
346- // Turn off `join` for Generate if no column from it's child is used
347338 case p @ Project (_, g : Generate ) if g.join && p.references.subsetOf(g.generatedSet) =>
348339 p.copy(child = g.copy(join = false ))
349340
350- // Eliminate unneeded attributes from right side of a LeftSemiJoin.
351- case j @ Join (left, right, LeftSemi , condition) =>
352- j.copy(right = prunedChild(right, j.references))
353-
354- // all the columns will be used to compare, so we can't prune them
355- case p @ Project (_, _ : SetOperation ) => p
356- case p @ Project (_, _ : Distinct ) => p
357- // Eliminate unneeded attributes from children of Union.
358- case p @ Project (_, u : Union ) =>
359- if ((u.outputSet -- p.references).nonEmpty) {
360- val firstChild = u.children.head
361- val newOutput = prunedChild(firstChild, p.references).output
362- // pruning the columns of all children based on the pruned first child.
363- val newChildren = u.children.map { p =>
364- val selected = p.output.zipWithIndex.filter { case (a, i) =>
365- newOutput.contains(firstChild.output(i))
366- }.map(_._1)
367- Project (selected, p)
368- }
369- p.copy(child = u.withNewChildren(newChildren))
370- } else {
341+ case p @ Project (projectList, g : Generate ) if g.join =>
342+ val neededChildOutput = p.references -- g.generatorOutput ++ g.references
343+ if (neededChildOutput == g.child.outputSet) {
371344 p
345+ } else {
346+ Project (projectList, g.copy(child = Project (neededChildOutput.toSeq, g.child)))
372347 }
373348
374- // Can't prune the columns on LeafNode
375- case p @ Project (_, l : LeafNode ) => p
349+ case p @ Project (projectList, a @ Aggregate (groupingExpressions, aggregateExpressions, child))
350+ if (a.outputSet -- p.references).nonEmpty =>
351+ Project (
352+ projectList,
353+ Aggregate (
354+ groupingExpressions,
355+ aggregateExpressions.filter(e => p.references.contains(e)),
356+ child))
376357
377- // Eliminate no-op Projects
378- case p @ Project (projectList, child) if child.output == p.output => child
379-
380- // for all other logical plans that inherits the output from it's children
381- case p @ Project (_, child) =>
382- val required = child.references ++ p.references
383- if ((child.inputSet -- required).nonEmpty) {
384- val newChildren = child.children.map(c => prunedChild(c, required))
385- p.copy(child = child.withNewChildren(newChildren))
358+ // Eliminate unneeded attributes from either side of a Join.
359+ case Project (projectList, Join (left, right, joinType, condition)) =>
360+ // Collect the list of all references required either above or to evaluate the condition.
361+ val allReferences : AttributeSet =
362+ AttributeSet (
363+ projectList.flatMap(_.references.iterator)) ++
364+ condition.map(_.references).getOrElse(AttributeSet (Seq .empty))
365+
366+ /** Applies a projection only when the child is producing unnecessary attributes */
367+ def pruneJoinChild (c : LogicalPlan ): LogicalPlan = prunedChild(c, allReferences)
368+
369+ Project (projectList, Join (pruneJoinChild(left), pruneJoinChild(right), joinType, condition))
370+
371+ // Eliminate unneeded attributes from right side of a LeftSemiJoin.
372+ case Join (left, right, LeftSemi , condition) =>
373+ // Collect the list of all references required to evaluate the condition.
374+ val allReferences : AttributeSet =
375+ condition.map(_.references).getOrElse(AttributeSet (Seq .empty))
376+
377+ Join (left, prunedChild(right, allReferences), LeftSemi , condition)
378+
379+ // Push down project through limit, so that we may have chance to push it further.
380+ case Project (projectList, Limit (exp, child)) =>
381+ Limit (exp, Project (projectList, child))
382+
383+ // Push down project if possible when the child is sort.
384+ case p @ Project (projectList, s @ Sort (_, _, grandChild)) =>
385+ if (s.references.subsetOf(p.outputSet)) {
386+ s.copy(child = Project (projectList, grandChild))
386387 } else {
387- p
388+ val neededReferences = s.references ++ p.references
389+ if (neededReferences == grandChild.outputSet) {
390+ // No column we can prune, return the original plan.
391+ p
392+ } else {
393+ // Do not use neededReferences.toSeq directly, should respect grandChild's output order.
394+ val newProjectList = grandChild.output.filter(neededReferences.contains)
395+ p.copy(child = s.copy(child = Project (newProjectList, grandChild)))
396+ }
388397 }
398+
399+ // Eliminate no-op Projects
400+ case Project (projectList, child) if child.output == projectList => child
389401 }
390402
391403 /** Applies a projection only when the child is producing unnecessary attributes */
392404 private def prunedChild (c : LogicalPlan , allReferences : AttributeSet ) =
393405 if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) {
394- Project (c.output. filter(allReferences. contains), c)
406+ Project (allReferences. filter(c.outputSet. contains).toSeq , c)
395407 } else {
396408 c
397409 }
0 commit comments