@@ -22,7 +22,8 @@ import scala.reflect.ClassTag
 import org.apache.spark.AccumulatorSuite
 import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession}
 import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft}
-import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.exchange.EnsureRequirements
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -70,8 +71,8 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
7071 private def testBroadcastJoin [T : ClassTag ](
7172 joinType : String ,
7273 forceBroadcast : Boolean = false ): SparkPlan = {
73- val df1 = spark.createDataFrame( Seq ((1 , " 4" ), (2 , " 2" ) )).toDF(" key" , " value" )
74- val df2 = spark.createDataFrame( Seq ((1 , " 1" ), (2 , " 2" ) )).toDF(" key" , " value" )
74+ val df1 = Seq ((1 , " 4" ), (2 , " 2" )).toDF(" key" , " value" )
75+ val df2 = Seq ((1 , " 1" ), (2 , " 2" )).toDF(" key" , " value" )
7576
7677 // Comparison at the end is for broadcast left semi join
7778 val joinExpression = df1(" key" ) === df2(" key" ) && df1(" value" ) > df2(" value" )
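
Note on the refactor in this hunk: the Seq(...).toDF(...) shorthand only compiles when the session's implicit encoders are in scope; in Spark's own suites that typically comes from the testImplicits member mixed in by the test harness. A minimal sketch of the same idiom outside the harness, assuming a local session (the session setup here is illustrative, not part of the patch):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[2]").appName("toDF-sketch").getOrCreate()
    // This import provides the Seq[(Int, String)] => DataFrame conversion behind .toDF(...)
    import spark.implicits._

    val df = Seq((1, "4"), (2, "2")).toDF("key", "value")
    df.show()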
@@ -109,61 +110,89 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
     }
   }
 
112- test(" broadcast hint is retained after using the cached data" ) {
113+ test(" SPARK-23192: broadcast hint should be retained after using the cached data" ) {
113114 withSQLConf(SQLConf .AUTO_BROADCASTJOIN_THRESHOLD .key -> " -1" ) {
114- val df1 = spark.createDataFrame(Seq ((1 , " 4" ), (2 , " 2" ))).toDF(" key" , " value" )
115- val df2 = spark.createDataFrame(Seq ((1 , " 1" ), (2 , " 2" ))).toDF(" key" , " value" )
116- df2.cache()
117- val df3 = df1.join(broadcast(df2), Seq (" key" ), " inner" )
118- val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect {
119- case b : BroadcastHashJoinExec => b
120- }.size
121- assert(numBroadCastHashJoin === 1 )
115+ try {
116+ val df1 = Seq ((1 , " 4" ), (2 , " 2" )).toDF(" key" , " value" )
117+ val df2 = Seq ((1 , " 1" ), (2 , " 2" )).toDF(" key" , " value" )
118+ df2.cache()
119+ val df3 = df1.join(broadcast(df2), Seq (" key" ), " inner" )
120+ val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect {
121+ case b : BroadcastHashJoinExec => b
122+ }.size
123+ assert(numBroadCastHashJoin === 1 )
124+ } finally {
125+ spark.catalog.clearCache()
126+ }
127+ }
128+ }
+
+  test("SPARK-23214: cached data should not carry extra hint info") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      try {
+        val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value")
+        val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")
+        broadcast(df2).cache()
+
+        val df3 = df1.join(df2, Seq("key"), "inner")
+        val numCachedPlan = df3.queryExecution.executedPlan.collect {
+          case i: InMemoryTableScanExec => i
+        }.size
+        // df2 should be cached.
+        assert(numCachedPlan === 1)
+
+        val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect {
+          case b: BroadcastHashJoinExec => b
+        }.size
+        // df2 should not be broadcasted.
+        assert(numBroadCastHashJoin === 0)
+      } finally {
+        spark.catalog.clearCache()
+      }
     }
   }
 
   test("broadcast hint isn't propagated after a join") {
     withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
-      val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
-      val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
+      val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value")
+      val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")
       val df3 = df1.join(broadcast(df2), Seq("key"), "inner").drop(df2("key"))
 
-      val df4 = spark.createDataFrame(Seq((1, "5"), (2, "5"))).toDF("key", "value")
+      val df4 = Seq((1, "5"), (2, "5")).toDF("key", "value")
       val df5 = df4.join(df3, Seq("key"), "inner")
 
-      val plan =
-        EnsureRequirements(spark.sessionState.conf).apply(df5.queryExecution.sparkPlan)
+      val plan = EnsureRequirements(spark.sessionState.conf).apply(df5.queryExecution.sparkPlan)
 
       assert(plan.collect { case p: BroadcastHashJoinExec => p }.size === 1)
       assert(plan.collect { case p: SortMergeJoinExec => p }.size === 1)
     }
   }
 
   private def assertBroadcastJoin(df: Dataset[Row]): Unit = {
-    val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
+    val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value")
     val joined = df1.join(df, Seq("key"), "inner")
 
-    val plan =
-      EnsureRequirements(spark.sessionState.conf).apply(joined.queryExecution.sparkPlan)
+    val plan = EnsureRequirements(spark.sessionState.conf).apply(joined.queryExecution.sparkPlan)
 
     assert(plan.collect { case p: BroadcastHashJoinExec => p }.size === 1)
   }
 
   test("broadcast hint programming API") {
     withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
-      val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"), (3, "2"))).toDF("key", "value")
+      val df2 = Seq((1, "1"), (2, "2"), (3, "2")).toDF("key", "value")
       val broadcasted = broadcast(df2)
-      val df3 = spark.createDataFrame(Seq((2, "2"), (3, "3"))).toDF("key", "value")
-
-      val cases = Seq(broadcasted.limit(2),
-        broadcasted.filter("value < 10"),
-        broadcasted.sample(true, 0.5),
-        broadcasted.distinct(),
-        broadcasted.groupBy("value").agg(min($"key").as("key")),
-        // except and intersect are semi/anti-joins which won't return more data than
-        // their left argument, so the broadcast hint should be propagated here
-        broadcasted.except(df3),
-        broadcasted.intersect(df3))
+      val df3 = Seq((2, "2"), (3, "3")).toDF("key", "value")
+
+      val cases = Seq(
+        broadcasted.limit(2),
+        broadcasted.filter("value < 10"),
+        broadcasted.sample(true, 0.5),
+        broadcasted.distinct(),
+        broadcasted.groupBy("value").agg(min($"key").as("key")),
+        // except and intersect are semi/anti-joins which won't return more data than
+        // their left argument, so the broadcast hint should be propagated here
+        broadcasted.except(df3),
+        broadcasted.intersect(df3))
 
       cases.foreach(assertBroadcastJoin)
     }
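
All the assertions added in this hunk share one plan-inspection idiom: execute the query, then collect over the physical plan and count operators of a given type. A condensed sketch of that idiom (the helper name is illustrative, not part of the patch):

    import org.apache.spark.sql.Dataset
    import org.apache.spark.sql.execution.SparkPlan

    // Count the physical operators matched by pf in a Dataset's executed plan.
    def countPlanNodes(df: Dataset[_])(pf: PartialFunction[SparkPlan, SparkPlan]): Int =
      df.queryExecution.executedPlan.collect(pf).size

    // Mirrors the assertions above:
    //   countPlanNodes(df3) { case b: BroadcastHashJoinExec => b }  // 1 when the hint is retained
    //   countPlanNodes(df3) { case i: InMemoryTableScanExec => i }  // 1 when the cache is used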
@@ -240,9 +269,8 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
   test("Shouldn't change broadcast join buildSide if user clearly specified") {
 
     withTempView("t1", "t2") {
-      spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("t1")
-      spark.createDataFrame(Seq((1, "1"), (2, "12.3"), (2, "123"))).toDF("key", "value")
-        .createTempView("t2")
+      Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1")
+      Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key", "value").createTempView("t2")
 
       val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes
       val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes
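
For context on "user clearly specified": the user-facing ways to pin the broadcast side are the broadcast() function from org.apache.spark.sql.functions or a SQL hint comment. A small sketch against the temp views registered above (BROADCAST, BROADCASTJOIN and MAPJOIN are accepted spellings of the hint):

    // DataFrame API: ask for t1 to be broadcast, i.e. to be the build side.
    val apiHinted = broadcast(spark.table("t1")).join(spark.table("t2"), "key")

    // SQL API: the same request expressed as a hint comment.
    val sqlHinted = spark.sql(
      "SELECT /*+ BROADCAST(t1) */ t1.key, t2.value FROM t1 JOIN t2 ON t1.key = t2.key")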
@@ -292,9 +320,8 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
   test("Shouldn't bias towards build right if user didn't specify") {
 
     withTempView("t1", "t2") {
-      spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("t1")
-      spark.createDataFrame(Seq((1, "1"), (2, "12.3"), (2, "123"))).toDF("key", "value")
-        .createTempView("t2")
+      Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1")
+      Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key", "value").createTempView("t2")
 
       val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes
       val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes
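
The t1Size/t2Size probes in the last two hunks matter because, without an explicit hint, the planner should build the hash table on the statistically smaller side; t2 has more rows and longer strings, so its estimate should exceed t1's. A hedged sketch of checking that premise directly, mirroring the probes above:

    // Size estimates from the analyzed logical plan; the join planner compares
    // these when choosing a build side in the absence of a user hint.
    val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes
    val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes
    assert(t1Size < t2Size)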