From 8e6a0fddf384439756434df8bbf97249efa07fda Mon Sep 17 00:00:00 2001
From: jiaan Geng
Date: Tue, 4 Apr 2023 19:57:51 +0800
Subject: [PATCH] [SPARK-43025][SQL] Eliminate Union if filters have the same
 child plan

---
 .../optimizer/CombineUnionedSubquery.scala    | 249 ++++++++++++++++++
 .../sql/catalyst/optimizer/Optimizer.scala    |   4 +-
 .../sql/catalyst/rules/RuleIdCollection.scala |   1 +
 .../apache/spark/sql/internal/SQLConf.scala   |  11 +
 .../CombineUnionedSubquerySuite.scala         | 160 +++++++++++
 .../sql/DataFrameSetOperationsSuite.scala     | 109 ++++++++
 6 files changed, 533 insertions(+), 1 deletion(-)
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CombineUnionedSubquery.scala
 create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineUnionedSubquerySuite.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CombineUnionedSubquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CombineUnionedSubquery.scala
new file mode 100644
index 0000000000000..4aacbd30aefb0
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CombineUnionedSubquery.scala
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, AliasHelper, Attribute, AttributeMap, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, NamedExpression, Not, Or}
+import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Distinct, Filter, LogicalPlan, Project, SerializeFromObject, Union}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreeNodeTag
+import org.apache.spark.sql.catalyst.trees.TreePattern.{FILTER, UNION}
+import org.apache.spark.sql.types.{DataType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType}
+
+/**
+ * This rule optimizes [[Union]] operators by:
+ * 1. Combining children of a [[Union]] that are filters with provably disjoint conditions
+ *    over the same child plan.
+ * 2. Eliminating the [[Union]] operator entirely if all of its children can be merged into one.
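+ *
+ * For example (illustrative SQL; the table and column names are hypothetical):
+ * {{{
+ *   SELECT * FROM t WHERE a > 4
+ *   UNION ALL
+ *   SELECT * FROM t WHERE a < 2
+ * }}}
+ * becomes a single scan of `t`:
+ * {{{
+ *   SELECT * FROM t WHERE a > 4 OR a < 2
+ * }}}
+ * The rewrite only fires when the two conditions cannot hold for the same row, so the
+ * duplicate-preserving semantics of UNION ALL are unaffected.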
+ */
+object CombineUnionedSubquery extends Rule[LogicalPlan] with AliasHelper {
+
+  /**
+   * A tag to identify whether the [[Union]] is the child of [[Distinct]] or [[Deduplicate]].
+   */
+  val UNION_ELIMINATE_DISABLED = TreeNodeTag[Unit]("union_eliminate_disabled")
+
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    if (!conf.combineUnionedSubqueryEnabled) return plan
+
+    plan.transformDownWithPruning(
+      _.containsAnyPattern(FILTER, UNION), ruleId) {
+      case d @ Distinct(u: Union) =>
+        u.setTagValue(UNION_ELIMINATE_DISABLED, ())
+        d
+      case d @ Deduplicate(_, u: Union) =>
+        u.setTagValue(UNION_ELIMINATE_DISABLED, ())
+        d
+      case u: Union if u.getTagValue(UNION_ELIMINATE_DISABLED).isEmpty => eliminateUnion(u)
+    }
+  }
+
+  private def eliminateUnion(union: Union): LogicalPlan = {
+    val cache = mutable.ArrayBuffer.empty[LogicalPlan]
+    union.children.foreach(subPlan => mergeSubPlan(subPlan, cache))
+
+    assert(cache.nonEmpty)
+
+    if (cache.size == union.children.size) {
+      union
+    } else if (cache.size > 1) {
+      union.copy(children = cache.toSeq)
+    } else {
+      cache.head
+    }
+  }
+
+  private def mergeSubPlan(subPlan: LogicalPlan, cache: mutable.ArrayBuffer[LogicalPlan]): Unit = {
+    cache.zipWithIndex.collectFirst(Function.unlift {
+      // Identical plans must not be merged: merging them would drop the duplicate rows
+      // that UNION ALL is required to keep, so they are appended as separate children below.
+      case (cachedPlan, subqueryIndex) if subPlan.canonicalized != cachedPlan.canonicalized =>
+        tryMergePlans(subPlan, cachedPlan).map {
+          case (mergedPlan, _) =>
+            cache(subqueryIndex) = mergedPlan
+        }
+      case _ => None
+    }).getOrElse {
+      cache += subPlan
+    }
+  }
+
+  // Recursively traverse down and try merging two plans. If merging is possible, return the
+  // merged plan with the attribute mapping from the new version to the merged version.
+  private def tryMergePlans(
+      newPlan: LogicalPlan,
+      cachedPlan: LogicalPlan): Option[(LogicalPlan, AttributeMap[Attribute])] = {
+    checkIdenticalPlans(newPlan, cachedPlan).map(cachedPlan -> _).orElse(
+      (newPlan, cachedPlan) match {
+        case (np: Project, cp: Project) =>
+          tryMergePlans(np.child, cp.child).flatMap { case (mergedChild, outputMap) =>
+            val projectTuples = np.projectList.map { ne =>
+              val mapped = mapAttributes(ne, outputMap)
+              ne.canonicalized -> mapped
+            }
+
+            if (checkIdenticalProjectList(projectTuples.map(_._2), cp.projectList)) {
+              val projectMap = projectTuples.toMap
+              val mergedProjectList = mutable.ArrayBuffer[NamedExpression](cp.projectList: _*)
+              val newOutputMap = AttributeMap(np.projectList.map { ne =>
+                val mapped = projectMap(ne.canonicalized)
+                val withoutAlias = mapped match {
+                  case Alias(child, _) => child
+                  case e => e
+                }
+                ne.toAttribute -> mergedProjectList.find {
+                  case Alias(child, _) => child semanticEquals withoutAlias
+                  case e => e semanticEquals withoutAlias
+                }.getOrElse {
+                  mergedProjectList += mapped
+                  mapped
+                }.toAttribute
+              })
+              val mergedPlan = Project(mergedProjectList.toSeq, mergedChild)
+              Some(mergedPlan -> newOutputMap)
+            } else {
+              None
+            }
+          }
+
+        case (nf: Filter, cf: Filter) =>
+          tryMergePlans(nf.child, cf.child).flatMap { case (mergedChild, outputMap) =>
+            val mappedNewCondition = mapAttributes(nf.condition, outputMap)
+            // Only merge if the two conditions are provably disjoint, so that replacing the two
+            // filters with a single OR filter neither duplicates nor loses any row.
+            if (checkCondition(mappedNewCondition, cf.condition)) {
+              val combinedCondition = Or(cf.condition, mappedNewCondition)
+              val mergedPlan = cf.copy(condition = combinedCondition, child = mergedChild)
+              Some(mergedPlan -> outputMap)
+            } else {
+              None
+            }
+          }
+
+        case (ns: SerializeFromObject, cs: SerializeFromObject) =>
+          checkIdenticalPlans(ns, cs).map { outputMap =>
+            (cs, outputMap)
+          }
+
+        case _ => None
+      }
+    )
+  }
+
+  private def checkIdenticalProjectList(
+      nes: Seq[NamedExpression], ces: Seq[NamedExpression]): Boolean = {
+    val npAliases = getAliasMap(nes)
+    val cpAliases = getAliasMap(ces)
+    nes.zip(ces).forall {
+      case (ne1, ne2) =>
+        replaceAlias(ne1, npAliases).semanticEquals(replaceAlias(ne2, cpAliases))
+    }
+  }
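+
+  // A minimal sketch of the disjointness check below (values are illustrative):
+  //   checkCondition(a > 4, a < 2) => true,  because no value satisfies both conditions,
+  //                                   so the filters merge into Filter(a > 4 OR a < 2);
+  //   checkCondition(a > 1, a < 4) => false, because rows with 1 < a < 4 satisfy both and
+  //                                   UNION ALL must emit them twice, which a merged filter
+  //                                   would emit only once.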
+  private def checkCondition(leftCondition: Expression, rightCondition: Expression): Boolean = {
+    val normalizedLeft = normalizeExpression(leftCondition)
+    val normalizedRight = normalizeExpression(rightCondition)
+    if (normalizedLeft.isDefined && normalizedRight.isDefined) {
+      (normalizedLeft.get, normalizedRight.get) match {
+        case (a GreaterThan b, c LessThan d) if a.semanticEquals(c) =>
+          isGreaterOrEqualTo(b, d, a.dataType)
+        case (a LessThan b, c GreaterThan d) if a.semanticEquals(c) =>
+          isGreaterOrEqualTo(d, b, a.dataType)
+        case (a GreaterThanOrEqual b, c LessThan d) if a.semanticEquals(c) =>
+          isGreaterOrEqualTo(b, d, a.dataType)
+        case (a LessThan b, c GreaterThanOrEqual d) if a.semanticEquals(c) =>
+          isGreaterOrEqualTo(d, b, a.dataType)
+        case (a GreaterThan b, c LessThanOrEqual d) if a.semanticEquals(c) =>
+          isGreaterOrEqualTo(b, d, a.dataType)
+        case (a LessThanOrEqual b, c GreaterThan d) if a.semanticEquals(c) =>
+          isGreaterOrEqualTo(d, b, a.dataType)
+        case (a EqualTo b, Not(c EqualTo d)) if a.semanticEquals(c) =>
+          isEqualTo(b, d, a.dataType)
+        // Symmetric to the case above: `a != b` and `a = d` are disjoint when b equals d.
+        case (Not(a EqualTo b), c EqualTo d) if a.semanticEquals(c) =>
+          isEqualTo(b, d, a.dataType)
+        case _ => false
+      }
+    } else {
+      false
+    }
+  }
+
+  // Mirror a comparison whose foldable (constant) operand is on the left so that the constant
+  // always ends up on the right. Note that mirroring preserves strictness: c > a <=> a < c.
+  private def normalizeExpression(expr: Expression): Option[Expression] = {
+    expr match {
+      case gt @ GreaterThan(_, r) if r.foldable =>
+        Some(gt)
+      case l GreaterThan r if l.foldable =>
+        Some(LessThan(r, l))
+      case lt @ LessThan(_, r) if r.foldable =>
+        Some(lt)
+      case l LessThan r if l.foldable =>
+        Some(GreaterThan(r, l))
+      case gte @ GreaterThanOrEqual(_, r) if r.foldable =>
+        Some(gte)
+      case l GreaterThanOrEqual r if l.foldable =>
+        Some(LessThanOrEqual(r, l))
+      case lte @ LessThanOrEqual(_, r) if r.foldable =>
+        Some(lte)
+      case l LessThanOrEqual r if l.foldable =>
+        Some(GreaterThanOrEqual(r, l))
+      case eq @ EqualTo(_, r) if r.foldable =>
+        Some(eq)
+      case l EqualTo r if l.foldable =>
+        Some(EqualTo(r, l))
+      case not @ Not(EqualTo(l, r)) if r.foldable =>
+        Some(not)
+      case Not(l EqualTo r) if l.foldable =>
+        Some(Not(EqualTo(r, l)))
+      case _ => None
+    }
+  }
+
+  private def isGreaterOrEqualTo(
+      left: Expression, right: Expression, dataType: DataType): Boolean = dataType match {
+    case ShortType => left.eval().asInstanceOf[Short] >= right.eval().asInstanceOf[Short]
+    case IntegerType => left.eval().asInstanceOf[Int] >= right.eval().asInstanceOf[Int]
+    case LongType => left.eval().asInstanceOf[Long] >= right.eval().asInstanceOf[Long]
+    case FloatType => left.eval().asInstanceOf[Float] >= right.eval().asInstanceOf[Float]
+    case DoubleType => left.eval().asInstanceOf[Double] >= right.eval().asInstanceOf[Double]
+    case DecimalType.Fixed(_, _) =>
+      left.eval().asInstanceOf[Decimal] >= right.eval().asInstanceOf[Decimal]
+    case _ => false
+  }
+
+  private def isEqualTo(
+      left: Expression, right: Expression, dataType: DataType): Boolean = dataType match {
+    case ShortType => left.eval().asInstanceOf[Short] == right.eval().asInstanceOf[Short]
+    case IntegerType => left.eval().asInstanceOf[Int] == right.eval().asInstanceOf[Int]
+    case LongType => left.eval().asInstanceOf[Long] == right.eval().asInstanceOf[Long]
+    case FloatType => left.eval().asInstanceOf[Float] == right.eval().asInstanceOf[Float]
+    case DoubleType => left.eval().asInstanceOf[Double] == right.eval().asInstanceOf[Double]
+    case DecimalType.Fixed(_, _) =>
+      left.eval().asInstanceOf[Decimal] == right.eval().asInstanceOf[Decimal]
+    case _ => false
+  }
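+
+  // Returns the mapping from the new plan's output attributes to the cached plan's if the two
+  // plans are semantically identical (same canonicalized form), so that parent expressions of
+  // the new plan can be rewritten against the cached plan.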
+  private def checkIdenticalPlans(
+      newPlan: LogicalPlan,
+      cachedPlan: LogicalPlan): Option[AttributeMap[Attribute]] = {
+    if (newPlan.canonicalized == cachedPlan.canonicalized) {
+      Some(AttributeMap(newPlan.output.zip(cachedPlan.output)))
+    } else {
+      None
+    }
+  }
+
+  private def mapAttributes[T <: Expression](expr: T, outputMap: AttributeMap[Attribute]) = {
+    expr.transform {
+      case a: Attribute => outputMap.getOrElse(a, a)
+    }.asInstanceOf[T]
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index e4548ee69dc4a..847168f95e81e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -97,6 +97,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
       EliminateOffsets,
       EliminateLimits,
       CombineUnions,
+      CombineUnionedSubquery,
       // Constant folding and strength reduction
       OptimizeRepartition,
       EliminateWindowPartitions,
@@ -169,7 +170,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
     Batch("Union", fixedPoint,
       RemoveNoopOperators,
       CombineUnions,
-      RemoveNoopUnion) ::
+      RemoveNoopUnion,
+      CombineUnionedSubquery) ::
     // Run this once earlier. This might simplify the plan and reduce cost of optimizer.
     // For example, a query such as Filter(LocalRelation) would go through all the heavy
     // optimizer rules that are triggered when there is a filter
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
index 11df764ebb03b..1a59951c336fa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
@@ -114,6 +114,7 @@ object RuleIdCollection {
       "org.apache.spark.sql.catalyst.optimizer.CombineFilters" ::
       "org.apache.spark.sql.catalyst.optimizer.CombineTypedFilters" ::
+      "org.apache.spark.sql.catalyst.optimizer.CombineUnionedSubquery" ::
       "org.apache.spark.sql.catalyst.optimizer.CombineUnions" ::
       "org.apache.spark.sql.catalyst.optimizer.ConstantFolding" ::
       "org.apache.spark.sql.catalyst.optimizer.ConstantPropagation" ::
       "org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation" ::
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 428ad052eba81..d281ffdb4e769 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3769,6 +3769,14 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)
 
+  val COMBINE_UNIONED_SUBQUERY_ENABLED =
+    buildConf("spark.sql.optimizer.combineUnionedSubquery.enabled")
+      .doc("When true, attempt to eliminate a Union by combining its unioned subqueries " +
+        "into one, to reduce the number of table scans and avoid extra shuffles.")
+      .version("3.5.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val LEGACY_SIZE_OF_NULL = buildConf("spark.sql.legacy.sizeOfNull")
     .internal()
     .doc(s"If it is set to false, or ${ANSI_ENABLED.key} is true, then size of null returns " +
@@ -5613,6 +5621,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
 
   def csvColumnPruning: Boolean = getConf(SQLConf.CSV_PARSER_COLUMN_PRUNING)
 
+  def combineUnionedSubqueryEnabled: Boolean =
+    getConf(SQLConf.COMBINE_UNIONED_SUBQUERY_ENABLED)
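+
+  // Illustrative usage of the flag defined above (editor's sketch; `spark` denotes a
+  // hypothetical SparkSession):
+  //   spark.conf.set("spark.sql.optimizer.combineUnionedSubquery.enabled", "true")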
+
   def legacySizeOfNull: Boolean = {
     // size(null) should return null under ansi mode.
     getConf(SQLConf.LEGACY_SIZE_OF_NULL) && !getConf(ANSI_ENABLED)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineUnionedSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineUnionedSubquerySuite.scala
new file mode 100644
index 0000000000000..cebbab8e63875
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineUnionedSubquerySuite.scala
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.scalactic.source.Position
+import org.scalatest.Tag
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Distinct, LocalRelation, LogicalPlan, Union}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.internal.SQLConf
+
+class CombineUnionedSubquerySuite extends PlanTest {
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("Eliminate Unions by combining subqueries", Once, CombineUnionedSubquery) :: Nil
+  }
+
+  override def test(testName: String, testTags: Tag*)(testFun: => Any)
+    (implicit pos: Position): Unit = {
+    super.test(testName, testTags: _*) {
+      withSQLConf(SQLConf.COMBINE_UNIONED_SUBQUERY_ENABLED.key -> "true") {
+        testFun
+      }
+    }
+  }
+
+  private val testRelation = LocalRelation.fromExternalRows(
+    Seq("a".attr.int, "b".attr.int, "c".attr.int),
+    1.to(6).map(i => Row(i, 2 * i, 3 * i)))
+  private val a = testRelation.output(0)
+  private val b = testRelation.output(1)
+  private val c = testRelation.output(2)
+
+  test("Union where both sides are plans without a Filter") {
+    val originalQuery1 = testRelation.union(testRelation)
+    comparePlans(Optimize.execute(originalQuery1.analyze), originalQuery1)
+
+    val originalQuery2 = testRelation.select(b, c).union(testRelation.select(b, c))
+    comparePlans(Optimize.execute(originalQuery2.analyze), originalQuery2)
+  }
+
+  test("Union of two subqueries where only one side has a Filter") {
+    val originalQuery1 = testRelation.union(testRelation.where(a < 2))
+    comparePlans(Optimize.execute(originalQuery1), originalQuery1)
+
+    val originalQuery2 = testRelation.where(a > 4).union(testRelation)
+    comparePlans(Optimize.execute(originalQuery2), originalQuery2)
+  }
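+
+  // Overlapping filters must not be merged: e.g. rows with 1 < a < 4 satisfy both `a > 1`
+  // and `a < 4`, and UNION ALL has to emit them twice, so the plans below stay unchanged.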
+  test("Union of two subqueries whose filters intersect") {
+    val originalQuery1 = testRelation.where(a > 1).union(testRelation.where(a < 4))
+    comparePlans(Optimize.execute(originalQuery1), originalQuery1)
+
+    val originalQuery2 =
+      testRelation.where(a > 3).union(testRelation.where(a < 4)).union(testRelation.where(a < 2))
+    comparePlans(Optimize.execute(originalQuery2), originalQuery2)
+  }
+
+  test("Union of two subqueries whose filters do not intersect") {
+    val originalQuery1 = testRelation.where(a > 4).union(testRelation.where(a < 2))
+    val correctAnswer1 = testRelation.where(a > 4 || a < 2)
+    comparePlans(Optimize.execute(originalQuery1), correctAnswer1)
+
+    val originalQuery2 =
+      testRelation.where(a > 4).select(b, c).union(testRelation.where(a < 2).select(b, c))
+    val correctAnswer2 = testRelation.where(a > 4 || a < 2).select(b, c)
+    comparePlans(Optimize.execute(originalQuery2), correctAnswer2)
+
+    val originalQuery3 =
+      testRelation.where(a > 4).select((b + 1).as("b1"), c).union(
+        testRelation.where(a < 2).select((b + 1).as("b1"), c))
+    val correctAnswer3 = testRelation.where(a > 4 || a < 2).select((b + 1).as("b1"), c)
+    comparePlans(Optimize.execute(originalQuery3), correctAnswer3)
+
+    val originalQuery4 =
+      testRelation.select(a, b).where(a > 4).union(testRelation.select(a, b).where(a < 2))
+    val correctAnswer4 = testRelation.select(a, b).where(a > 4 || a < 2)
+    comparePlans(Optimize.execute(originalQuery4), correctAnswer4)
+
+    val originalQuery5 =
+      testRelation.select(a, (b + 1).as("b1")).where(a > 4).union(
+        testRelation.select(a, (b + 1).as("b1")).where(a < 2))
+    val correctAnswer5 = testRelation.select(a, (b + 1).as("b1")).where(a > 4 || a < 2)
+    comparePlans(Optimize.execute(originalQuery5), correctAnswer5)
+
+    val originalQuery6 =
+      testRelation.where(a > 4).select((b + 1).as("b1"), c).union(
+        testRelation.where(a < 2).select((b + 2).as("b1"), c))
+    comparePlans(Optimize.execute(originalQuery6), originalQuery6)
+
+    val originalQuery7 =
+      testRelation.select(a, (b + 1).as("b1")).where(a > 4).union(
+        testRelation.select(a, (b + 2).as("b1")).where(a < 2))
+    comparePlans(Optimize.execute(originalQuery7), originalQuery7)
+  }
+
+  test("upstream Union can be optimized") {
+    val originalQuery1 =
+      testRelation.where(a > 4).union(testRelation.where(a < 2)).union(testRelation.where(a < 3))
+    val correctAnswer1 = testRelation.where(a > 4 || a < 2).union(testRelation.where(a < 3))
+    comparePlans(Optimize.execute(originalQuery1), correctAnswer1)
+
+    val originalQuery2 =
+      testRelation.where(a > 4).union(testRelation.where(a < 2)).union(testRelation.select(a, b, c))
+    val correctAnswer2 = testRelation.where(a > 4 || a < 2).union(testRelation.select(a, b, c))
+    comparePlans(Optimize.execute(originalQuery2), correctAnswer2)
+
+    val originalQuery3 =
+      Union(Seq(testRelation.where(a > 4), testRelation.where(a < 2), testRelation.where(a < 3)))
+    val correctAnswer3 = testRelation.where(a > 4 || a < 2).union(testRelation.where(a < 3))
+    comparePlans(Optimize.execute(originalQuery3), correctAnswer3)
+
+    val originalQuery4 =
+      Union(Seq(testRelation.where(a < 3), testRelation.where(a < 2), testRelation.where(a > 4)))
+    val correctAnswer4 = testRelation.where(a < 3 || a > 4).union(testRelation.where(a < 2))
+    comparePlans(Optimize.execute(originalQuery4), correctAnswer4)
+
+    val originalQuery5 =
+      Union(Seq(testRelation.where(a > 1), testRelation.where(a < 2), testRelation.where(a > 4)))
+    val correctAnswer5 = testRelation.where(a > 1).union(testRelation.where(a < 2 || a > 4))
+    comparePlans(Optimize.execute(originalQuery5), correctAnswer5)
+  }
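+
+  // This rule does not look through a child that is itself a Union; in the full optimizer,
+  // CombineUnions (which runs in the same batch, see Optimizer.scala) flattens nested Unions
+  // first, so the nested plan below is only left unchanged when the rule runs in isolation.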
+  test("nested Union with intersecting filters") {
+    val originalQuery1 =
+      testRelation.where(a > 4).union(testRelation.where(a < 2).union(testRelation.where(a < 3)))
+    comparePlans(Optimize.execute(originalQuery1), originalQuery1)
+  }
+
+  test("Distinct over a Union of two filtered plans") {
+    val originalQuery1 = Distinct(testRelation.where(a > 4).union(testRelation.where(a < 2)))
+    comparePlans(Optimize.execute(originalQuery1), originalQuery1)
+  }
+
+  test("Deduplicate over a Union of two filtered plans") {
+    val originalQuery1 =
+      Deduplicate(testRelation.output, testRelation.where(a > 4).union(testRelation.where(a < 2)))
+    comparePlans(Optimize.execute(originalQuery1), originalQuery1)
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala
index cbc39557ce4cc..36b0f234e68f9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala
@@ -985,6 +985,115 @@ class DataFrameSetOperationsSuite extends QueryTest
     }
   }
 
+  test("SPARK-43025: Eliminate Union if filters have the same child plan") {
+    // Run with the flag both off and on: the optimization must never change query results.
+    Seq(false, true).foreach { enabled =>
+      withSQLConf(SQLConf.COMBINE_UNIONED_SUBQUERY_ENABLED.key -> enabled.toString) {
+        val union1 =
+          testData2.where($"a" > 2).union(
+            testData2.where($"a" < 2))
+        checkAnswer(union1, Seq(Row(1, 1), Row(1, 2), Row(3, 1), Row(3, 2)))
+
+        val union2 =
+          testData2.where($"a" > 2).select($"b").union(
+            testData2.where($"a" < 2).select($"b"))
+        checkAnswer(union2, Seq(Row(1), Row(1), Row(2), Row(2)))
+
+        val union3 =
+          testData2.where($"a" > 2).select($"b" + 1).union(
+            testData2.where($"a" < 2).select($"b" + 1))
+        checkAnswer(union3, Seq(Row(2), Row(2), Row(3), Row(3)))
+
+        val union4 =
+          testData2.where($"a" > 2).select(($"b" + 1).as("b1")).union(
+            testData2.where($"a" < 2).select(($"b" + 1).as("b1")))
+        checkAnswer(union4, Seq(Row(2), Row(2), Row(3), Row(3)))
+
+        val union5 =
+          testData2.select($"a").where($"a" > 2).union(
+            testData2.select($"a").where($"a" < 2))
+        checkAnswer(union5, Seq(Row(1), Row(1), Row(3), Row(3)))
+
+        val union6 =
+          testData2.select($"a", ($"b" + 1).as("b1")).where($"a" > 2).union(
+            testData2.select($"a", ($"b" + 1).as("b1")).where($"a" < 2))
+        checkAnswer(union6, Seq(Row(1, 2), Row(1, 3), Row(3, 2), Row(3, 3)))
+
+        val union7 =
+          testData2.where($"a" > 2).union(
+            testData2.where($"a" < 2)).union(testData2.where($"a" < 3))
+        checkAnswer(union7,
+          Seq(
+            Row(1, 1),
+            Row(1, 1),
+            Row(1, 2),
+            Row(1, 2),
+            Row(2, 1),
+            Row(2, 2),
+            Row(3, 1),
+            Row(3, 2)))
+
+        val union8 =
+          testData2.where($"a" > 2).union(
+            testData2.where($"a" < 2)).union(testData2.select($"a", $"b"))
+        checkAnswer(union8,
+          Seq(
+            Row(1, 1),
+            Row(1, 1),
+            Row(1, 2),
+            Row(1, 2),
+            Row(2, 1),
+            Row(2, 2),
+            Row(3, 1),
+            Row(3, 1),
+            Row(3, 2),
+            Row(3, 2)))
+
+        val union9 = Union(Seq(
+          testData2.where($"a" > 2).logicalPlan,
+          testData2.where($"a" < 2).logicalPlan,
+          testData2.where($"a" < 3).logicalPlan))
+        checkAnswer(union9,
+          Seq(
+            Row(1, 1),
+            Row(1, 1),
+            Row(1, 2),
+            Row(1, 2),
+            Row(2, 1),
+            Row(2, 2),
+            Row(3, 1),
+            Row(3, 2)))
+
+        val union10 = Union(Seq(
+          testData2.where($"a" < 3).logicalPlan,
+          testData2.where($"a" < 2).logicalPlan,
+          testData2.where($"a" > 4).logicalPlan))
+        checkAnswer(union10,
+          Seq(
+            Row(1, 1),
+            Row(1, 1),
+            Row(1, 2),
+            Row(1, 2),
+            Row(2, 1),
+            Row(2, 2)))
+
+        val union11 = Union(Seq(
+          testData2.where($"a" < 3).logicalPlan,
+          testData2.where($"a" < 2).logicalPlan,
+          testData2.where($"a" > 2).logicalPlan))
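+        // With the flag on, the second and third branches (a < 2 and a > 2) are provably
+        // disjoint and merge into one scan, while a < 3 overlaps both and stays separate;
+        // the answer below must be identical either way.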
+        checkAnswer(union11,
+          Seq(
+            Row(1, 1),
+            Row(1, 1),
+            Row(1, 2),
+            Row(1, 2),
+            Row(2, 1),
+            Row(2, 2),
+            Row(3, 1),
+            Row(3, 2)))
+      }
+    }
+  }
+
   test("SPARK-34548: Remove unnecessary children from Union") {
     Seq(RemoveNoopUnion.ruleName, "").map { ruleName =>
       withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> ruleName) {