@@ -143,9 +143,10 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
143143 val schema = new StructType ().add(" k" , IntegerType ).add(" v" , StringType )
144144 val smallDF = spark.createDataFrame(rdd, schema)
145145 val df = spark.range(10 ).join(broadcast(smallDF), col(" k" ) === col(" id" ))
146- assert(df.queryExecution.executedPlan.find(p =>
147- p.isInstanceOf [WholeStageCodegenExec ] &&
148- p.asInstanceOf [WholeStageCodegenExec ].child.isInstanceOf [BroadcastHashJoinExec ]).isDefined)
146+ val broadcastHashJoin = df.queryExecution.executedPlan.find {
147+ case WholeStageCodegenExec (ProjectExec (_, _ : BroadcastHashJoinExec )) => true
148+ }
149+ assert(broadcastHashJoin.isDefined)
149150 assert(df.collect() === Array (Row (1 , 1 , " 1" ), Row (1 , 1 , " 1" ), Row (2 , 2 , " 2" )))
150151 }
151152
@@ -187,7 +188,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
187188 // test one join with non-unique key from build side
188189 val joinNonUniqueDF = df1.join(df2.hint(" SHUFFLE_HASH" ), $" k1" === $" k2" % 3 , " full_outer" )
189190 assert(joinNonUniqueDF.queryExecution.executedPlan.collect {
190- case WholeStageCodegenExec (_ : ShuffledHashJoinExec ) => true
191+ case WholeStageCodegenExec (ProjectExec (_, _ : ShuffledHashJoinExec ) ) => true
191192 }.size === 1 )
192193 checkAnswer(joinNonUniqueDF, Seq (Row (0 , 0 ), Row (0 , 3 ), Row (0 , 6 ), Row (0 , 9 ), Row (1 , 1 ),
193194 Row (1 , 4 ), Row (1 , 7 ), Row (2 , 2 ), Row (2 , 5 ), Row (2 , 8 ), Row (3 , null ), Row (4 , null )))
@@ -196,7 +197,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
196197 val joinWithNonEquiDF = df1.join(df2.hint(" SHUFFLE_HASH" ),
197198 $" k1" === $" k2" % 3 && $" k1" + 3 =!= $" k2" , " full_outer" )
198199 assert(joinWithNonEquiDF.queryExecution.executedPlan.collect {
199- case WholeStageCodegenExec (_ : ShuffledHashJoinExec ) => true
200+ case WholeStageCodegenExec (ProjectExec (_, _ : ShuffledHashJoinExec ) ) => true
200201 }.size === 1 )
201202 checkAnswer(joinWithNonEquiDF, Seq (Row (0 , 0 ), Row (0 , 6 ), Row (0 , 9 ), Row (1 , 1 ),
202203 Row (1 , 7 ), Row (2 , 2 ), Row (2 , 8 ), Row (3 , null ), Row (4 , null ), Row (null , 3 ), Row (null , 4 ),
0 commit comments