Skip to content

Commit 1186ef5

Browse files
committed
fix
1 parent 87ffe7a commit 1186ef5

File tree

2 files changed

+21
-4
lines changed

2 files changed

+21
-4
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import org.apache.hadoop.fs.{FileSystem, Path}
2626
import org.apache.spark.internal.Logging
2727
import org.apache.spark.sql.{Dataset, SparkSession}
2828
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
29-
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
29+
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ResolvedHint}
3030
import org.apache.spark.sql.execution.columnar.InMemoryRelation
3131
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
3232
import org.apache.spark.storage.StorageLevel
@@ -170,9 +170,13 @@ class CacheManager extends Logging {
170170
def useCachedData(plan: LogicalPlan): LogicalPlan = {
171171
val newPlan = plan transformDown {
172172
case currentFragment =>
173-
lookupCachedData(currentFragment)
174-
.map(_.cachedRepresentation.withOutput(currentFragment.output))
175-
.getOrElse(currentFragment)
173+
lookupCachedData(currentFragment).map { cached =>
174+
val cachedPlan = cached.cachedRepresentation.withOutput(currentFragment.output)
175+
currentFragment match {
176+
case hint: ResolvedHint => ResolvedHint(cachedPlan, hint.hints)
177+
case _ => cachedPlan
178+
}
179+
}.getOrElse(currentFragment)
176180
}
177181

178182
newPlan transformAllExpressions {

sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,19 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
109109
}
110110
}
111111

112+
test("broadcast hint is lost") {
113+
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
114+
val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
115+
val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
116+
df2.cache()
117+
val df3 = df1.join(broadcast(df2), Seq("key"), "inner")
118+
val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect {
119+
case b: BroadcastHashJoinExec => b
120+
}.size
121+
assert(numBroadCastHashJoin === 1)
122+
}
123+
}
124+
112125
test("broadcast hint isn't propagated after a join") {
113126
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
114127
val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")

0 commit comments

Comments
 (0)