diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 38a051c15476..bf74c8a51d27 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -95,6 +95,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
       EliminateOuterJoin,
       PushPredicateThroughJoin,
       PushDownPredicate,
+      PushDownLeftSemiAntiJoin,
       LimitPushDown,
       ColumnPruning,
       InferFiltersFromConstraints,
@@ -1016,24 +1017,13 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
     // This also applies to Aggregate.
     case Filter(condition, project @ Project(fields, grandChild))
       if fields.forall(_.deterministic) && canPushThroughCondition(grandChild, condition) =>
-
-      // Create a map of Aliases to their values from the child projection.
-      // e.g., 'SELECT a + b AS c, d ...' produces Map(c -> a + b).
-      val aliasMap = AttributeMap(fields.collect {
-        case a: Alias => (a.toAttribute, a.child)
-      })
-
+      val aliasMap = getAliasMap(project)
       project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild))
 
     case filter @ Filter(condition, aggregate: Aggregate)
       if aggregate.aggregateExpressions.forall(_.deterministic)
         && aggregate.groupingExpressions.nonEmpty =>
-      // Find all the aliased expressions in the aggregate list that don't include any actual
-      // AggregateExpression, and create a map from the alias to the expression
-      val aliasMap = AttributeMap(aggregate.aggregateExpressions.collect {
-        case a: Alias if a.child.find(_.isInstanceOf[AggregateExpression]).isEmpty =>
-          (a.toAttribute, a.child)
-      })
+      val aliasMap = getAliasMap(aggregate)
 
       // For each filter, expand the alias and check if the filter can be evaluated using
       // attributes produced by the aggregate operator's child operator.
@@ -1131,7 +1121,23 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
     }
   }
 
-  private def canPushThrough(p: UnaryNode): Boolean = p match {
+  def getAliasMap(plan: Project): AttributeMap[Expression] = {
+    // Create a map of Aliases to their values from the child projection.
+    // e.g., 'SELECT a + b AS c, d ...' produces Map(c -> a + b).
+    AttributeMap(plan.projectList.collect { case a: Alias => (a.toAttribute, a.child) })
+  }
+
+  def getAliasMap(plan: Aggregate): AttributeMap[Expression] = {
+    // Find all the aliased expressions in the aggregate list that don't include any actual
+    // AggregateExpression, and create a map from the alias to the expression
+    val aliasMap = plan.aggregateExpressions.collect {
+      case a: Alias if a.child.find(_.isInstanceOf[AggregateExpression]).isEmpty =>
+        (a.toAttribute, a.child)
+    }
+    AttributeMap(aliasMap)
+  }
+
+  def canPushThrough(p: UnaryNode): Boolean = p match {
     // Note that some operators (e.g. project, aggregate, union) are being handled separately
     // (earlier in this rule).
     case _: AppendColumns => true
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala
new file mode 100644
index 000000000000..bc868df3dbb0
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.Rule
+
+/**
+ * This rule is a variant of [[PushDownPredicate]] which can handle
+ * pushing down LeftSemi and LeftAnti joins below the following operators.
+ *  1) Project
+ *  2) Window
+ *  3) Union
+ *  4) Aggregate
+ *  5) Other permissible unary operators. Please see [[PushDownPredicate.canPushThrough]].
+ */
+object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    // LeftSemi/LeftAnti over Project
+    case Join(p @ Project(pList, gChild), rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
+        if pList.forall(_.deterministic) &&
+        !pList.exists(ScalarSubquery.hasCorrelatedScalarSubquery) &&
+        canPushThroughCondition(Seq(gChild), joinCond, rightOp) =>
+      if (joinCond.isEmpty) {
+        // No join condition, just push down the Join below Project
+        p.copy(child = Join(gChild, rightOp, joinType, joinCond, hint))
+      } else {
+        val aliasMap = PushDownPredicate.getAliasMap(p)
+        val newJoinCond = if (aliasMap.nonEmpty) {
+          Option(replaceAlias(joinCond.get, aliasMap))
+        } else {
+          joinCond
+        }
+        p.copy(child = Join(gChild, rightOp, joinType, newJoinCond, hint))
+      }
+
+    // LeftSemi/LeftAnti over Aggregate
+    case join @ Join(agg: Aggregate, rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
+        if agg.aggregateExpressions.forall(_.deterministic) && agg.groupingExpressions.nonEmpty &&
+        !agg.aggregateExpressions.exists(ScalarSubquery.hasCorrelatedScalarSubquery) =>
+      if (joinCond.isEmpty) {
+        // No join condition, just push down Join below Aggregate
+        agg.copy(child = Join(agg.child, rightOp, joinType, joinCond, hint))
+      } else {
+        val aliasMap = PushDownPredicate.getAliasMap(agg)
+
+        // For each join condition, expand the alias and check if the condition can be evaluated
+        // using attributes produced by the aggregate operator's child operator.
+        val (pushDown, stayUp) = splitConjunctivePredicates(joinCond.get).partition { cond =>
+          val replaced = replaceAlias(cond, aliasMap)
+          cond.references.nonEmpty &&
+            replaced.references.subsetOf(agg.child.outputSet ++ rightOp.outputSet)
+        }
+
+        // Check if the remaining predicates do not contain columns from the right
+        // hand side of the join. Since the remaining predicates will be kept
+        // as a filter over aggregate, this check is necessary after the left semi
+        // or left anti join is moved below aggregate. The reason is, for this kind
+        // of join, we only output from the left leg of the join.
+        val rightOpColumns = AttributeSet(stayUp.toSet).intersect(rightOp.outputSet)
+
+        if (pushDown.nonEmpty && rightOpColumns.isEmpty) {
+          val pushDownPredicate = pushDown.reduce(And)
+          val replaced = replaceAlias(pushDownPredicate, aliasMap)
+          val newAgg = agg.copy(child = Join(agg.child, rightOp, joinType, Option(replaced), hint))
+          // If there is no more filter to stay up, just return the Aggregate over Join.
+          // Otherwise, create "Filter(stayUp) <- Aggregate <- Join(pushDownPredicate)".
+          if (stayUp.isEmpty) newAgg else Filter(stayUp.reduce(And), newAgg)
+        } else {
+          // The join condition is not a subset of the Aggregate's GROUP BY columns,
+          // no push down.
+          join
+        }
+      }
+
+    // LeftSemi/LeftAnti over Window
+    case join @ Join(w: Window, rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
+        if w.partitionSpec.forall(_.isInstanceOf[AttributeReference]) =>
+      if (joinCond.isEmpty) {
+        // No join condition, just push down Join below Window
+        w.copy(child = Join(w.child, rightOp, joinType, joinCond, hint))
+      } else {
+        val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references)) ++
+          rightOp.outputSet
+
+        val (pushDown, stayUp) = splitConjunctivePredicates(joinCond.get).partition { cond =>
+          cond.references.subsetOf(partitionAttrs)
+        }
+
+        // Check if the remaining predicates do not contain columns from the right
+        // hand side of the join. Since the remaining predicates will be kept
+        // as a filter over window, this check is necessary after the left semi
+        // or left anti join is moved below window. The reason is, for this kind
+        // of join, we only output from the left leg of the join.
+        val rightOpColumns = AttributeSet(stayUp.toSet).intersect(rightOp.outputSet)
+
+        if (pushDown.nonEmpty && rightOpColumns.isEmpty) {
+          val predicate = pushDown.reduce(And)
+          val newPlan = w.copy(child = Join(w.child, rightOp, joinType, Option(predicate), hint))
+          if (stayUp.isEmpty) newPlan else Filter(stayUp.reduce(And), newPlan)
+        } else {
+          // The join condition is not a subset of the Window's PARTITION BY clause,
+          // no push down.
+          join
+        }
+      }
+
+    // LeftSemi/LeftAnti over Union
+    case join @ Join(union: Union, rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
+        if canPushThroughCondition(union.children, joinCond, rightOp) =>
+      if (joinCond.isEmpty) {
+        // Push down the Join below Union
+        val newGrandChildren = union.children.map { Join(_, rightOp, joinType, joinCond, hint) }
+        union.withNewChildren(newGrandChildren)
+      } else {
+        val output = union.output
+        val newGrandChildren = union.children.map { grandchild =>
+          val newCond = joinCond.get transform {
+            case e if output.exists(_.semanticEquals(e)) =>
+              grandchild.output(output.indexWhere(_.semanticEquals(e)))
+          }
+          assert(newCond.references.subsetOf(grandchild.outputSet ++ rightOp.outputSet))
+          Join(grandchild, rightOp, joinType, Option(newCond), hint)
+        }
+        union.withNewChildren(newGrandChildren)
+      }
+
+    // LeftSemi/LeftAnti over UnaryNode
+    case join @ Join(u: UnaryNode, rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
+        if PushDownPredicate.canPushThrough(u) && u.expressions.forall(_.deterministic) =>
+      pushDownJoin(join, u.child) { joinCond =>
+        u.withNewChildren(Seq(Join(u.child, rightOp, joinType, joinCond, hint)))
+      }
+  }
+
+  /**
+   * Check if we can safely push a join through a project or union by making sure that attributes
+   * referred in join condition do not contain the same attributes as the plan they are moved
+   * into. This can happen when both sides of join refers to the same source (self join). This
+   * function makes sure that the join condition refers to attributes that are not ambiguous (i.e
+   * present in both the legs of the join) or else the resultant plan will be invalid.
+   */
+  private def canPushThroughCondition(
+      plans: Seq[LogicalPlan],
+      condition: Option[Expression],
+      rightOp: LogicalPlan): Boolean = {
+    val attributes = AttributeSet(plans.flatMap(_.output))
+    if (condition.isDefined) {
+      val matched = condition.get.references.intersect(rightOp.outputSet).intersect(attributes)
+      matched.isEmpty
+    } else {
+      true
+    }
+  }
+
+
+  private def pushDownJoin(
+      join: Join,
+      grandchild: LogicalPlan)(insertJoin: Option[Expression] => LogicalPlan): LogicalPlan = {
+    if (join.condition.isEmpty) {
+      insertJoin(None)
+    } else {
+      val (pushDown, stayUp) = splitConjunctivePredicates(join.condition.get)
+        .partition {_.references.subsetOf(grandchild.outputSet ++ join.right.outputSet)}
+
+      val rightOpColumns = AttributeSet(stayUp.toSet).intersect(join.right.outputSet)
+      if (pushDown.nonEmpty && rightOpColumns.isEmpty) {
+        val newChild = insertJoin(Option(pushDown.reduceLeft(And)))
+        if (stayUp.nonEmpty) {
+          Filter(stayUp.reduceLeft(And), newChild)
+        } else {
+          newChild
+        }
+      } else {
+        join
+      }
+    }
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
index c77849035a97..86cdc261329c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
@@ -114,3 +114,10 @@ object LeftExistence {
     case _ => None
   }
 }
+
+object LeftSemiOrAnti {
+  def unapply(joinType: JoinType): Option[JoinType] = joinType match {
+    case LeftSemi | LeftAnti => Some(joinType)
+    case _ => None
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala
new file mode 100644
index 000000000000..1a0231ed2d99
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala
@@ -0,0 +1,279 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.unsafe.types.CalendarInterval
+
+class LeftSemiPushdownSuite extends PlanTest {
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("Subqueries", Once,
+        EliminateSubqueryAliases) ::
+      Batch("Filter Pushdown", FixedPoint(10),
+        CombineFilters,
+        PushDownPredicate,
+        PushDownLeftSemiAntiJoin,
+        BooleanSimplification,
+        CollapseProject) :: Nil
+  }
+
+  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+
+  val testRelation1 = LocalRelation('d.int)
+
+  test("Project: LeftSemiAnti join pushdown") {
+    val originalQuery = testRelation
+      .select(star())
+      .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelation
+      .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
+      .select('a, 'b, 'c)
+      .analyze
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Project: LeftSemiAnti join no pushdown because of non-deterministic proj exprs") {
+    val originalQuery = testRelation
+      .select(Rand('a), 'b, 'c)
+      .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    comparePlans(optimized, originalQuery.analyze)
+  }
+
+  test("Project: LeftSemiAnti join non correlated scalar subq") {
+    val subq = ScalarSubquery(testRelation.groupBy('b)(sum('c).as("sum")).analyze)
+    val originalQuery = testRelation
+      .select(subq.as("sum"))
+      .join(testRelation1, joinType = LeftSemi, condition = Some('sum === 'd))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelation
+      .join(testRelation1, joinType = LeftSemi, condition = Some(subq === 'd))
+      .select(subq.as("sum"))
+      .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Project: LeftSemiAnti join no pushdown - correlated scalar subq in projection list") {
+    val testRelation2 = LocalRelation('e.int, 'f.int)
+    val subqPlan = testRelation2.groupBy('e)(sum('f).as("sum")).where('e === 'a)
+    val subqExpr = ScalarSubquery(subqPlan)
+    val originalQuery = testRelation
+      .select(subqExpr.as("sum"))
+      .join(testRelation1, joinType = LeftSemi, condition = Some('sum === 'd))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    comparePlans(optimized, originalQuery.analyze)
+  }
+
+  test("Aggregate: LeftSemiAnti join pushdown") {
+    val originalQuery = testRelation
+      .groupBy('b)('b, sum('c))
+      .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelation
+      .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
+      .groupBy('b)('b, sum('c))
+      .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Aggregate: LeftSemiAnti join no pushdown due to non-deterministic aggr expressions") {
+    val originalQuery = testRelation
+      .groupBy('b)('b, Rand(10).as('c))
+      .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    comparePlans(optimized, originalQuery.analyze)
+  }
+
+  test("Aggregate: LeftSemiAnti join partial pushdown") {
+    val originalQuery = testRelation
+      .groupBy('b)('b, sum('c).as('sum))
+      .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd && 'sum === 10))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelation
+      .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
+      .groupBy('b)('b, sum('c).as('sum))
+      .where('sum === 10)
+      .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("LeftSemiAnti join over aggregate - no pushdown") {
+    val originalQuery = testRelation
+      .groupBy('b)('b, sum('c).as('sum))
+      .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd && 'sum === 'd))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    comparePlans(optimized, originalQuery.analyze)
+  }
+
+  test("Aggregate: LeftSemiAnti join non-correlated scalar subq aggr exprs") {
+    val subq = ScalarSubquery(testRelation.groupBy('b)(sum('c).as("sum")).analyze)
+    val originalQuery = testRelation
+      .groupBy('a) ('a, subq.as("sum"))
+      .join(testRelation1, joinType = LeftSemi, condition = Some('sum === 'd && 'a === 'd))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelation
+      .join(testRelation1, joinType = LeftSemi, condition = Some(subq === 'd && 'a === 'd))
+      .groupBy('a) ('a, subq.as("sum"))
+      .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("LeftSemiAnti join over Window") {
+    val winExpr = windowExpr(count('b), windowSpec('a :: Nil, 'b.asc :: Nil, UnspecifiedFrame))
+
+    val originalQuery = testRelation
+      .select('a, 'b, 'c, winExpr.as('window))
+      .join(testRelation1, joinType = LeftSemi, condition = Some('a === 'd))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+
+    val correctAnswer = testRelation
+      .join(testRelation1, joinType = LeftSemi, condition = Some('a === 'd))
+      .select('a, 'b, 'c)
+      .window(winExpr.as('window) :: Nil, 'a :: Nil, 'b.asc :: Nil)
+      .select('a, 'b, 'c, 'window).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Window: LeftSemiAnti partial pushdown") {
+    // Attributes from the join condition which do not refer to the window partition spec
+    // are kept up in the plan as a Filter operator above Window.
+    val winExpr = windowExpr(count('b), windowSpec('a :: Nil, 'b.asc :: Nil, UnspecifiedFrame))
+
+    val originalQuery = testRelation
+      .select('a, 'b, 'c, winExpr.as('window))
+      .join(testRelation1, joinType = LeftSemi, condition = Some('a === 'd && 'b > 5))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+
+    val correctAnswer = testRelation
+      .join(testRelation1, joinType = LeftSemi, condition = Some('a === 'd))
+      .select('a, 'b, 'c)
+      .window(winExpr.as('window) :: Nil, 'a :: Nil, 'b.asc :: Nil)
+      .where('b > 5)
+      .select('a, 'b, 'c, 'window).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Union: LeftSemiAnti join pushdown") {
+    val testRelation2 = LocalRelation('x.int, 'y.int, 'z.int)
+
+    val originalQuery = Union(Seq(testRelation, testRelation2))
+      .join(testRelation1, joinType = LeftSemi, condition = Some('a === 'd))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+
+    val correctAnswer = Union(Seq(
+      testRelation.join(testRelation1, joinType = LeftSemi, condition = Some('a === 'd)),
+      testRelation2.join(testRelation1, joinType = LeftSemi, condition = Some('x === 'd))))
+      .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Union: LeftSemiAnti join no pushdown in self join scenario") {
+    val testRelation2 = LocalRelation('x.int, 'y.int, 'z.int)
+
+    val originalQuery = Union(Seq(testRelation, testRelation2))
+      .join(testRelation2, joinType = LeftSemi, condition = Some('a === 'x))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    comparePlans(optimized, originalQuery.analyze)
+  }
+
+  test("Unary: LeftSemiAnti join pushdown") {
+    val originalQuery = testRelation
+      .select(star())
+      .repartition(1)
+      .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelation
+      .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
+      .select('a, 'b, 'c)
+      .repartition(1)
+      .analyze
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Unary: LeftSemiAnti join pushdown - empty join condition") {
+    val originalQuery = testRelation
+      .select(star())
+      .repartition(1)
+      .join(testRelation1, joinType = LeftSemi, condition = None)
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelation
+      .join(testRelation1, joinType = LeftSemi, condition = None)
+      .select('a, 'b, 'c)
+      .repartition(1)
+      .analyze
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Unary: LeftSemiAnti join pushdown - partial pushdown") {
+    val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType))
+    val originalQuery = testRelationWithArrayType
+      .generate(Explode('c_arr), alias = Some("arr"), outputNames = Seq("out_col"))
+      .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd && 'b === 'out_col))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelationWithArrayType
+      .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
+      .generate(Explode('c_arr), alias = Some("arr"), outputNames = Seq("out_col"))
+      .where('b === 'out_col)
+      .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Unary: LeftSemiAnti join pushdown - no pushdown") {
+    val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType))
+    val originalQuery = testRelationWithArrayType
+      .generate(Explode('c_arr), alias = Some("arr"), outputNames = Seq("out_col"))
+      .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd && 'd === 'out_col))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    comparePlans(optimized, originalQuery.analyze)
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 98a8ad5eeb2b..b77048aba779 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -345,10 +345,10 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared
     val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value")
     val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value")
     // Assume the execution plan is
-    // ... -> BroadcastHashJoin(nodeId = 0)
+    // ... -> BroadcastHashJoin(nodeId = 1)
     val df = df1.join(broadcast(df2), $"key" === $"key2", "leftsemi")
     testSparkPlanMetrics(df, 2, Map(
-      0L -> (("BroadcastHashJoin", Map(
+      1L -> (("BroadcastHashJoin", Map(
        "number of output rows" -> 2L))))
     )
   }
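
Reviewer note (not part of the patch): a minimal end-to-end sketch of the Aggregate case of PushDownLeftSemiAntiJoin. It assumes a Spark build that includes this rule and uses made-up view, column, and data values purely for illustration; with the rule in the optimizer batch, the optimized plan printed by explain(true) should show the LeftSemi join below the Aggregate instead of above it.

// Sketch only: observe the LeftSemi-below-Aggregate rewrite from spark-shell.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("semi-join-pushdown").getOrCreate()
import spark.implicits._

// Illustrative data and view names.
Seq((1, 10), (2, 20), (2, 30)).toDF("key", "value").createOrReplaceTempView("t1")
Seq(2, 3).toDF("key").createOrReplaceTempView("t2")

// A LEFT SEMI join on the grouping column, sitting on top of an aggregate.
val df = spark.sql("SELECT key, SUM(value) AS total FROM t1 GROUP BY key")
  .join(spark.table("t2"), Seq("key"), "left_semi")

// Expected optimized shape with this rule: Aggregate <- Join(LeftSemi), not Join(LeftSemi) <- Aggregate.
df.explain(true)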
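
And a Catalyst-level sketch, in the same DSL style as the test suite above, that applies only the new rule to a hand-built plan and prints the rewritten tree; the relation and attribute names are illustrative, and it assumes the catalyst test classpath (for the DSL) plus a branch containing PushDownLeftSemiAntiJoin.

// Sketch only: run PushDownLeftSemiAntiJoin in isolation, outside the RuleExecutor batch used by the suite.
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.optimizer.PushDownLeftSemiAntiJoin
import org.apache.spark.sql.catalyst.plans.LeftSemi
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

object PushDownLeftSemiAntiJoinDemo extends App {
  val left = LocalRelation('a.int, 'b.int)
  val right = LocalRelation('d.int)

  // Project over the left relation, then a LeftSemi join whose condition only needs 'b and 'd.
  val plan = left.select('a, 'b)
    .join(right, joinType = LeftSemi, condition = Some('b === 'd))
    .analyze

  // After the rule fires, the LeftSemi join should appear below the Project in the printed tree.
  println(PushDownLeftSemiAntiJoin(plan).treeString)
}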