diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 2de92d06ec83..9783a953ea69 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -899,23 +899,25 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelpe
         case a: Attribute if aliasMap.contains(a) => aliasMap(a)
       }.forall(_.deterministic))
 
-      // If there is no nondeterministic conditions, push down the whole condition.
-      if (nondeterministic.isEmpty) {
-        project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild))
+      // Only push the conditions that do not already exist in the project child's constraints.
+      val pushedConditions =
+        deterministic.map(replaceAlias(_, aliasMap)).filterNot {
+          c => project.child.constraints.contains(c) || c.references.isEmpty
+        }
+
+      // If no condition can be pushed down, leave the filter unchanged.
+      if (pushedConditions.isEmpty) {
+        filter
       } else {
-        // If they are all nondeterministic conditions, leave it un-changed.
-        if (deterministic.isEmpty) {
-          filter
+        val newConditions = pushedConditions.reduce(And)
+        if (nondeterministic.isEmpty) {
+          project.copy(child = Filter(newConditions, grandChild))
         } else {
-          // Push down the small conditions without nondeterministic expressions.
-          val pushedCondition =
-            deterministic.map(replaceAlias(_, aliasMap)).reduce(And)
           Filter(nondeterministic.reduce(And),
-            project.copy(child = Filter(pushedCondition, grandChild)))
+            project.copy(child = Filter(newConditions, grandChild)))
         }
       }
-  }
 
 
 /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferredFiltersPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferredFiltersPushDownSuite.scala
new file mode 100644
index 000000000000..49c4e449ba1d
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferredFiltersPushDownSuite.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+
+class InferredFiltersPushDownSuite extends PlanTest {
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("NullFiltering", FixedPoint(15),
+        SetOperationPushDown,
+        PushPredicateThroughJoin,
+        PushPredicateThroughProject,
+        PushPredicateThroughGenerate,
+        PushPredicateThroughAggregate,
+        NullFiltering,
+        CollapseProject,
+        PruneFilters,
+        CombineFilters) :: Nil
+  }
+
+  val testRelation1 = LocalRelation('a.string, 'b.int, 'c.int)
+  val testRelation2 = LocalRelation('a.string, 'b.int, 'c.int)
+
+  test("filter: push down inferred IsNotNull predicates") {
+    val x1 = testRelation1.groupBy('a)('a.as("x2a"), 'b + 1)
+    val x2 = testRelation1.select("tst1".as("key"), 'b)
+    val y = testRelation1
+    val union1 = x1.unionAll(x2)
+
+    val originalQuery = union1.join(y, condition = Some('x2a === 'a)).analyze
+
+    val x1Optimized =
+      testRelation1.where(IsNotNull('a)).groupBy('a)('a.as("x2a"), 'b + 1)
+    val x2Optimized = testRelation1.select("tst1".as("key"), 'b).where(IsNotNull('key))
+    val unionOptimized = x1Optimized.unionAll(x2Optimized)
+    val correctAnswer =
+      unionOptimized.join(y.where(IsNotNull('a)), condition = Some('x2a === 'a)).analyze
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, correctAnswer)
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala
index 0ee7cf92097e..8b1362c4ccc4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala
@@ -130,7 +130,6 @@ class PruneFiltersSuite extends PlanTest {
   test("Nondeterministic predicate is not pruned") {
     val originalQuery = testRelation.where(Rand(10) > 5).select('a).where(Rand(10) > 5).analyze
     val optimized = Optimize.execute(originalQuery)
-    val correctAnswer = testRelation.where(Rand(10) > 5).where(Rand(10) > 5).select('a).analyze
-    comparePlans(optimized, correctAnswer)
+    comparePlans(optimized, originalQuery)
   }
 }
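The behavioral core of the patch is the new pushedConditions filter in PushPredicateThroughProject: a deterministic condition is no longer pushed below the Project when the child already guarantees it through its constraints, or when it references no attributes. Below is a minimal sketch of that check, using plain Scala collections in place of Catalyst's Expression and constraint types; childConstraints and referenceFree are illustrative stand-ins, not Catalyst APIs.

    // Toy model of the patch's filterNot step:
    //   deterministic.map(replaceAlias(_, aliasMap)).filterNot { c =>
    //     project.child.constraints.contains(c) || c.references.isEmpty
    //   }
    object PushedConditionsSketch extends App {
      // Deterministic filter conditions, already rewritten through the alias map.
      val deterministic = Seq("isnotnull(a)", "b > 1", "true")

      // Stand-in for project.child.constraints: predicates the child plan
      // is already known to enforce.
      val childConstraints = Set("isnotnull(a)")

      // Stand-in for c.references.isEmpty: conditions that touch no columns,
      // such as bare literals.
      val referenceFree = Set("true")

      // Keep only the conditions that are worth pushing below the Project.
      val pushedConditions = deterministic.filterNot { c =>
        childConstraints.contains(c) || referenceFree.contains(c)
      }

      // "isnotnull(a)" is dropped because the child already guarantees it;
      // "true" is dropped because it constrains no column.
      println(pushedConditions) // prints: List(b > 1)
    }

The apparent motivation, judging from the new test suite, is fixed-point stability: NullFiltering infers IsNotNull predicates from those same constraints, so unconditionally re-pushing them would keep reintroducing filters that the next iteration infers again, leaving redundant Filter nodes instead of the single pushed IsNotNull filter per branch that InferredFiltersPushDownSuite expects.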