@@ -190,19 +190,33 @@ object PartitionPruning extends Rule[LogicalPlan]
190190 }
191191 }
192192
193+ // Make sure injected filters could push through Shuffle, see PushPredicateThroughNonJoin
194+ private def probablyPushThroughShuffle (exp : Expression , plan : LogicalPlan ): Boolean = {
195+ plan match {
196+ case j : Join if ! canPlanAsBroadcastHashJoin(j, conf) => true
197+ case a @ Aggregate (groupingExps, aggExps, child)
198+ if aggExps.forall(_.deterministic) && groupingExps.nonEmpty &&
199+ replaceAlias(exp, getAliasMap(a)).references.subsetOf(child.outputSet) => true
200+ case w : Window
201+ if w.partitionSpec.forall(_.isInstanceOf [AttributeReference ]) &&
202+ exp.references.subsetOf(AttributeSet (w.partitionSpec.flatMap(_.references))) => true
203+ case p : Project =>
204+ probablyPushThroughShuffle(replaceAlias(exp, getAliasMap(p)), p.child)
205+ case other =>
206+ other.children.exists { p =>
207+ if (exp.references.subsetOf(p.outputSet)) probablyPushThroughShuffle(exp, p) else false
208+ }
209+ }
210+ }
211+
193212 private def dataPruningHasBenefit (
194213 prunRelation : LogicalRelation ,
214+ exp : Expression ,
195215 prunPlan : LogicalPlan ,
196216 otherPlan : LogicalPlan ,
197217 canBuildBroadcast : Boolean ): Boolean = {
198218 if (canBuildBroadcast) {
199- val shuffleStages = prunPlan.collect {
200- case j @ Join (left, right, _, _, hint)
201- if ! canBroadcastBySize(left, SQLConf .get) && ! canBroadcastBySize(right, SQLConf .get)
202- && ! hintToBroadcastLeft(hint) && ! hintToBroadcastRight(hint) => j
203- case a : Aggregate => a
204- }
205- shuffleStages.exists(_.collectLeaves().exists(_.equals(prunRelation))) &&
219+ probablyPushThroughShuffle(exp, prunPlan) &&
206220 prunRelation.stats.sizeInBytes >= SQLConf .get.dynamicDataPruningSideThreshold
207221 } else {
208222 val estimatePruningSideSize =
@@ -251,7 +265,7 @@ object PartitionPruning extends Rule[LogicalPlan]
251265 canPruneLeft(joinType) &&
252266 supportDynamicPruning(right) &&
253267 (canBroadcastBySize(right, conf) || hintToBroadcastRight(hint)) &&
254- dataPruningHasBenefit(scan.logicalRelation, left, right,
268+ dataPruningHasBenefit(scan.logicalRelation, l, left, right,
255269 canBuildBroadcastRight(joinType)) =>
256270 newLeft = insertDataPredicate(l, newLeft, r, right, rightKeys)
257271 case _ =>
@@ -269,7 +283,7 @@ object PartitionPruning extends Rule[LogicalPlan]
269283 canPruneRight(joinType) &&
270284 supportDynamicPruning(left) &&
271285 (canBroadcastBySize(left, conf) || hintToBroadcastLeft(hint)) &&
272- dataPruningHasBenefit(scan.logicalRelation, right, left,
286+ dataPruningHasBenefit(scan.logicalRelation, r, right, left,
273287 canBuildBroadcastLeft(joinType)) =>
274288 newRight = insertDataPredicate(r, newRight, l, left, leftKeys)
275289 case _ =>
0 commit comments