From a3e0430f20ba467121cc20cc1c598b3dd5326c86 Mon Sep 17 00:00:00 2001 From: ulysses-you Date: Fri, 20 May 2022 17:52:23 +0800 Subject: [PATCH 1/3] Pull out complex join keys --- .../optimizer/PullOutComplexJoinKeys.scala | 115 ++++++++++++++++++ .../sql/catalyst/rules/RuleIdCollection.scala | 1 + .../PullOutComplexJoinKeysSuite.scala | 106 ++++++++++++++++ 3 files changed, 222 insertions(+) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PullOutComplexJoinKeys.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutComplexJoinKeysSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PullOutComplexJoinKeys.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PullOutComplexJoinKeys.scala new file mode 100644 index 000000000000..bf4719f52a6b --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PullOutComplexJoinKeys.scala @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.expressions.{Alias, And, EqualTo, Expression, NamedExpression} +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.JOIN + +/** + * This rule pulls out the complex join keys expression if can not broadcast. + * Example: + * + * +- Join Inner, ((c1 % 2) = c2)) +- Join Inner, (_complexjoinkey_0 = c2)) + * :- Relation default.t1[c1] parquet => :- Project [(c1 % 2) AS _complexjoinkey_0] + * +- Relation default.t2[c2] parquet : +- Relation default.t1[c1] parquet + * +- Relation default.t2[c2] parquet + * + * For shuffle based join, we may evaluate the join keys for several times: + * - SMJ: always evaluates the join keys during the join, and possibly again if there is a shuffle or sort + * - SHJ: always evaluates the join keys during the join, and possibly again if there is a shuffle + * So this rule can reduce the cost of repetitive evaluation. 
+ */ +object PullOutComplexJoinKeys extends Rule[LogicalPlan] with JoinSelectionHelper { + + private def isComplexExpression(e: Expression): Boolean = + e.deterministic && !e.foldable && e.children.nonEmpty + + private def hasComplexExpression(joinKeys: Seq[Expression]): Boolean = + joinKeys.exists(isComplexExpression) + + private def extractComplexExpression( + joinKeys: Seq[Expression], + startIndex: Int): mutable.LinkedHashMap[Expression, NamedExpression] = { + val map = new mutable.LinkedHashMap[Expression, NamedExpression]() + var i = startIndex + joinKeys.foreach { + case e: Expression if isComplexExpression(e) => + map.put(e.canonicalized, Alias(e, s"_complexjoinkey_$i")()) + i += 1 + case _ => + } + map + } + + override def apply(plan: LogicalPlan): LogicalPlan = { + plan.transformWithPruning(_.containsPattern(JOIN), ruleId) { + case j @ ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, other, _, left, right, joinHint) + if hasComplexExpression(leftKeys) || hasComplexExpression(rightKeys) => + val leftComplexExprs = extractComplexExpression(leftKeys, 0) + val (newLeftKeys, newLeft) = + if ((!canBuildBroadcastLeft(joinType) || !canBroadcastBySize(left, conf)) && + leftComplexExprs.nonEmpty) { + ( + leftKeys.map { e => + if (leftComplexExprs.contains(e.canonicalized)) { + leftComplexExprs(e.canonicalized).toAttribute + } else { + e + } + }, + Project(left.output ++ leftComplexExprs.values.toSeq, left) + ) + } else { + (leftKeys, left) + } + + val rightComplexExprs = extractComplexExpression(rightKeys, leftComplexExprs.size) + val (newRightKeys, newRight) = + if ((!canBuildBroadcastRight(joinType) || !canBroadcastBySize(right, conf)) && + rightComplexExprs.nonEmpty) { + ( + rightKeys.map { e => + if (rightComplexExprs.contains(e.canonicalized)) { + rightComplexExprs(e.canonicalized).toAttribute + } else { + e + } + }, + Project(right.output ++ rightComplexExprs.values.toSeq, right) + ) + } else { + (rightKeys, right) + } + + if (left.eq(newLeft) && 
right.eq(newRight)) { + j + } else { + val newConditions = newLeftKeys.zip(newRightKeys).map { + case (l, r) => EqualTo(l, r) + } ++ other + + Join(newLeft, newRight, joinType, newConditions.reduceOption(And), joinHint) + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala index 1204fa8c604a..e8ac08cc922d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala @@ -129,6 +129,7 @@ object RuleIdCollection { "org.apache.spark.sql.catalyst.optimizer.OptimizeUpdateFields":: "org.apache.spark.sql.catalyst.optimizer.PropagateEmptyRelation" :: "org.apache.spark.sql.catalyst.optimizer.PruneFilters" :: + "org.apache.spark.sql.catalyst.optimizer.PullOutComplexJoinKeys" :: "org.apache.spark.sql.catalyst.optimizer.PushDownLeftSemiAntiJoin" :: "org.apache.spark.sql.catalyst.optimizer.PushExtraPredicateThroughJoin" :: "org.apache.spark.sql.catalyst.optimizer.PushFoldableIntoBranches" :: diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutComplexJoinKeysSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutComplexJoinKeysSuite.scala new file mode 100644 index 000000000000..a786dc52c3bd --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutComplexJoinKeysSuite.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.internal.SQLConf + +class PullOutComplexJoinKeysSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Batch("PullOutComplexJoinKeys", FixedPoint(1), + PullOutComplexJoinKeys, + CollapseProject) :: Nil + } + + val testRelation1 = LocalRelation($"a".int, $"b".int) + val testRelation2 = LocalRelation($"x".int, $"y".int) + + test("pull out complex join keys") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + // join + // a (complex join key) + // b + val plan1 = testRelation1.join(testRelation2, condition = Some($"a" % 2 === $"x")) + val expected1 = testRelation1.select(($"a" % 2) as "_complexjoinkey_0").join( + testRelation2, condition = Some($"_complexjoinkey_0" === $"x")) + comparePlans(Optimize.execute(plan1.analyze), expected1.analyze) + + // join + // project + // a (complex join key) + // b + val plan2 = testRelation1.select($"a").join( + testRelation2, condition = Some($"a" % 2 === $"x")) + val expected2 = testRelation1.select(($"a" % 2) as "_complexjoinkey_0") + .join(testRelation2, condition = Some($"_complexjoinkey_0" === $"x")) + comparePlans(Optimize.execute(plan2.analyze), expected2.analyze) + + // join 
+ // a (two complex join keys) + // b + val plan3 = testRelation1.join(testRelation2, + condition = Some($"a" % 2 === $"x" && $"b" % 3 === $"y")) + val expected3 = testRelation1.select(($"a" % 2) as "_complexjoinkey_0", + ($"b" % 3) as "_complexjoinkey_1").join(testRelation2, + condition = Some($"_complexjoinkey_0" === $"x" && $"_complexjoinkey_1" === $"y")) + comparePlans(Optimize.execute(plan3.analyze), expected3.analyze) + + // join + // a + // b (complex join key) + val plan4 = testRelation1.join(testRelation2, condition = Some($"a" === $"x" % 2)) + val expected4 = testRelation1.join(testRelation2.select(($"x" % 2) as "_complexjoinkey_0"), + condition = Some($"a" === $"_complexjoinkey_0")) + comparePlans(Optimize.execute(plan4.analyze), expected4.analyze) + + // join + // a (complex join key) + // b (complex join key) + val plan5 = testRelation1.join(testRelation2, condition = Some($"a" % 2 === $"x" % 3)) + val expected5 = testRelation1.select(($"a" % 2) as "_complexjoinkey_0").join( + testRelation2.select(($"x" % 3) as "_complexjoinkey_1"), + condition = Some($"_complexjoinkey_0" === $"_complexjoinkey_1")) + comparePlans(Optimize.execute(plan5.analyze), expected5.analyze) + } + } + + test("do not pull out complex join keys") { + // can broadcast + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "100000") { + val p1 = testRelation1.join(testRelation2, condition = Some($"a" % 2 === $"x")).analyze + comparePlans(Optimize.execute(p1), p1) + + val p2 = testRelation1.join(testRelation2, condition = Some($"a" === $"x" % 2)).analyze + comparePlans(Optimize.execute(p2), p2) + } + + // does not contain a complex expression + val p1 = testRelation1.subquery("t1").join( + testRelation2.subquery("t2"), condition = Some($"a" === $"x")) + comparePlans(Optimize.execute(p1.analyze), p1.analyze) + + // not an equi-join + val p2 = testRelation1.subquery("t1").join(testRelation2.subquery("t2")) + comparePlans(Optimize.execute(p2.analyze), p2.analyze) + } +} From 
c3f5506e7c7fd54bcd7b146415a67c5c768ed603 Mon Sep 17 00:00:00 2001 From: ulysses-you Date: Mon, 6 Jun 2022 20:07:58 +0800 Subject: [PATCH 2/3] nit --- .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 12e21faca9f2..7a09b4d6cf43 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -167,6 +167,8 @@ abstract class Optimizer(catalogManager: CatalogManager) RemoveNoopUnion) :: Batch("OptimizeLimitZero", Once, OptimizeLimitZero) :: + Batch("Pull Out Complex Join Keys", Once, + PullOutComplexJoinKeys) :: // Run this once earlier. This might simplify the plan and reduce cost of optimizer. // For example, a query such as Filter(LocalRelation) would go through all the heavy // optimizer rules that are triggered when there is a filter From f9560d07a845e08f2c9c5bde96754891789a2b46 Mon Sep 17 00:00:00 2001 From: ulysses-you Date: Wed, 15 Jun 2022 19:31:07 +0800 Subject: [PATCH 3/3] fix test --- .../sql/catalyst/optimizer/Optimizer.scala | 4 ++-- .../optimizer/PullOutComplexJoinKeys.scala | 11 +++++++---- .../PullOutComplexJoinKeysSuite.scala | 19 ++++++++++++------- 3 files changed, 21 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 7a09b4d6cf43..9301abc9f225 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -167,8 +167,6 @@ abstract class Optimizer(catalogManager: CatalogManager) RemoveNoopUnion) :: 
Batch("OptimizeLimitZero", Once, OptimizeLimitZero) :: - Batch("Pull Out Complex Join Keys", Once, - PullOutComplexJoinKeys) :: // Run this once earlier. This might simplify the plan and reduce cost of optimizer. // For example, a query such as Filter(LocalRelation) would go through all the heavy // optimizer rules that are triggered when there is a filter @@ -212,6 +210,8 @@ abstract class Optimizer(catalogManager: CatalogManager) // idempotence enforcement on this batch. We thus make it FixedPoint(1) instead of Once. Batch("Join Reorder", FixedPoint(1), CostBasedJoinReorder) :+ + Batch("Pull Out Complex Join Keys", Once, + PullOutComplexJoinKeys) :+ Batch("Eliminate Sorts", Once, EliminateSorts) :+ Batch("Decimal Optimizations", fixedPoint, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PullOutComplexJoinKeys.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PullOutComplexJoinKeys.scala index bf4719f52a6b..9280f2a71382 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PullOutComplexJoinKeys.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PullOutComplexJoinKeys.scala @@ -29,9 +29,10 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.JOIN * This rule pulls out the complex join keys expression if can not broadcast. 
* Example: * - * +- Join Inner, ((c1 % 2) = c2)) +- Join Inner, (_complexjoinkey_0 = c2)) - * :- Relation default.t1[c1] parquet => :- Project [(c1 % 2) AS _complexjoinkey_0] - * +- Relation default.t2[c2] parquet : +- Relation default.t1[c1] parquet + * +- Join Inner, ((c1 % 2) = c2)) - Project [c1, c2] + * :- Relation default.t1[c1] parquet +- Join Inner, (_complexjoinkey_0 = c2)) + * +- Relation default.t2[c2] parquet => :- Project [c1, (c1 % 2) AS _complexjoinkey_0] + * : +- Relation default.t1[c1] parquet * +- Relation default.t2[c2] parquet * * For shuffle based join, we may evaluate the join keys for several times: @@ -108,7 +109,9 @@ object PullOutComplexJoinKeys extends Rule[LogicalPlan] with JoinSelectionHelper case (l, r) => EqualTo(l, r) } ++ other - Join(newLeft, newRight, joinType, newConditions.reduceOption(And), joinHint) + Project( + j.output, + Join(newLeft, newRight, joinType, newConditions.reduceOption(And), joinHint)) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutComplexJoinKeysSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutComplexJoinKeysSuite.scala index a786dc52c3bd..ed82c8e4ed67 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutComplexJoinKeysSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutComplexJoinKeysSuite.scala @@ -41,8 +41,9 @@ class PullOutComplexJoinKeysSuite extends PlanTest { // a (complex join key) // b val plan1 = testRelation1.join(testRelation2, condition = Some($"a" % 2 === $"x")) - val expected1 = testRelation1.select(($"a" % 2) as "_complexjoinkey_0").join( + val expected1 = testRelation1.select($"a", $"b", ($"a" % 2) as "_complexjoinkey_0").join( testRelation2, condition = Some($"_complexjoinkey_0" === $"x")) + .select($"a", $"b", $"x", $"y") comparePlans(Optimize.execute(plan1.analyze), expected1.analyze) // join @@ -51,8 +52,9 @@ class 
PullOutComplexJoinKeysSuite extends PlanTest { // b val plan2 = testRelation1.select($"a").join( testRelation2, condition = Some($"a" % 2 === $"x")) - val expected2 = testRelation1.select(($"a" % 2) as "_complexjoinkey_0") + val expected2 = testRelation1.select($"a", ($"a" % 2) as "_complexjoinkey_0") .join(testRelation2, condition = Some($"_complexjoinkey_0" === $"x")) + .select($"a", $"x", $"y") comparePlans(Optimize.execute(plan2.analyze), expected2.analyze) // join @@ -60,26 +62,29 @@ class PullOutComplexJoinKeysSuite extends PlanTest { // b val plan3 = testRelation1.join(testRelation2, condition = Some($"a" % 2 === $"x" && $"b" % 3 === $"y")) - val expected3 = testRelation1.select(($"a" % 2) as "_complexjoinkey_0", + val expected3 = testRelation1.select($"a", $"b", ($"a" % 2) as "_complexjoinkey_0", ($"b" % 3) as "_complexjoinkey_1").join(testRelation2, condition = Some($"_complexjoinkey_0" === $"x" && $"_complexjoinkey_1" === $"y")) + .select($"a", $"b", $"x", $"y") comparePlans(Optimize.execute(plan3.analyze), expected3.analyze) // join // a // b (complex join key) val plan4 = testRelation1.join(testRelation2, condition = Some($"a" === $"x" % 2)) - val expected4 = testRelation1.join(testRelation2.select(($"x" % 2) as "_complexjoinkey_0"), - condition = Some($"a" === $"_complexjoinkey_0")) + val expected4 = testRelation1.join(testRelation2.select($"x", $"y", + ($"x" % 2) as "_complexjoinkey_0"), condition = Some($"a" === $"_complexjoinkey_0")) + .select($"a", $"b", $"x", $"y") comparePlans(Optimize.execute(plan4.analyze), expected4.analyze) // join // a (complex join key) // b (complex join key) val plan5 = testRelation1.join(testRelation2, condition = Some($"a" % 2 === $"x" % 3)) - val expected5 = testRelation1.select(($"a" % 2) as "_complexjoinkey_0").join( - testRelation2.select(($"x" % 3) as "_complexjoinkey_1"), + val expected5 = testRelation1.select($"a", $"b", ($"a" % 2) as "_complexjoinkey_0").join( + testRelation2.select($"x", $"y", ($"x" % 3) as 
"_complexjoinkey_1"), condition = Some($"_complexjoinkey_0" === $"_complexjoinkey_1")) + .select($"a", $"b", $"x", $"y") comparePlans(Optimize.execute(plan5.analyze), expected5.analyze) } }