
Commit 20c0efe

cloud-fan authored and gatorsmile committed
[SPARK-23214][SQL] cached data should not carry extra hint info
## What changes were proposed in this pull request?

This is a regression introduced by #19864.

When we look up the cache, we should not carry the hint info, because the cache entry might have been added by a plan that carried hint info, while the input plan for this lookup may have no hint info, or different hint info.

## How was this patch tested?

A new test.

Author: Wenchen Fan <[email protected]>

Closes #20394 from cloud-fan/cache.

(cherry picked from commit 5b5447c)
Signed-off-by: gatorsmile <[email protected]>
1 parent 7aaf23c commit 20c0efe
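
To make the scenario concrete, here is a minimal spark-shell sketch of the regression described above (assuming a Spark 2.3-era build and the shell's implicits in scope): the cache entry is created from a broadcast-hinted plan, while the later lookup uses the same plan without the hint.

import org.apache.spark.sql.functions.broadcast

// Keep auto-broadcast off so any broadcast join can only come from a hint.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")

val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value")
val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")

broadcast(df2).cache()                  // the cached entry is built from a hinted plan
val joined = df1.join(df2, Seq("key"))  // this lookup plan carries no hint

// Before the fix, the cached entry leaked its broadcast hint into `joined`, producing a
// broadcast join the user never asked for here; after the fix, the cached data is reused
// without the extra hint, as the new BroadcastJoinSuite test below verifies.
joined.explain()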

5 files changed (+94, -59 lines)

sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala

Lines changed: 10 additions & 7 deletions
@@ -169,14 +169,17 @@ class CacheManager extends Logging {
   /** Replaces segments of the given logical plan with cached versions where possible. */
   def useCachedData(plan: LogicalPlan): LogicalPlan = {
     val newPlan = plan transformDown {
+      // Do not lookup the cache by hint node. Hint node is special, we should ignore it when
+      // canonicalizing plans, so that plans which are same except hint can hit the same cache.
+      // However, we also want to keep the hint info after cache lookup. Here we skip the hint
+      // node, so that the returned caching plan won't replace the hint node and drop the hint info
+      // from the original plan.
+      case hint: ResolvedHint => hint
+
       case currentFragment =>
-        lookupCachedData(currentFragment).map { cached =>
-          val cachedPlan = cached.cachedRepresentation.withOutput(currentFragment.output)
-          currentFragment match {
-            case hint: ResolvedHint => ResolvedHint(cachedPlan, hint.hints)
-            case _ => cachedPlan
-          }
-        }.getOrElse(currentFragment)
+        lookupCachedData(currentFragment)
+          .map(_.cachedRepresentation.withOutput(currentFragment.output))
+          .getOrElse(currentFragment)
     }

     newPlan transformAllExpressions {
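
A side note on why matching the hint node first is sufficient: `transformDown` applies the rule to a node and only then recurses into that node's children, so returning the `ResolvedHint` unchanged keeps it in place while the fragment beneath it can still be swapped for the cached relation. A minimal standalone sketch of that traversal order (a toy tree for illustration, not Catalyst code):

object TransformDownSketch extends App {
  sealed trait Node { def mapChildren(f: Node => Node): Node }
  case class Hint(child: Node) extends Node {
    def mapChildren(f: Node => Node): Node = Hint(f(child))
  }
  case class Scan(table: String) extends Node {
    def mapChildren(f: Node => Node): Node = this
  }
  case class CachedScan(table: String) extends Node {
    def mapChildren(f: Node => Node): Node = this
  }

  // Apply the rule to the node first, then recurse into the children --
  // the same top-down order CacheManager.useCachedData relies on.
  def transformDown(node: Node)(rule: PartialFunction[Node, Node]): Node = {
    val applied = rule.applyOrElse(node, identity[Node])
    applied.mapChildren(transformDown(_)(rule))
  }

  val plan = Hint(Scan("t"))
  val rewritten = transformDown(plan) {
    case h: Hint     => h                  // skip the hint node, keep its info
    case Scan(table) => CachedScan(table)  // pretend the cache lookup hit
  }

  assert(rewritten == Hint(CachedScan("t")))  // hint retained, child replaced
  println(rewritten)
}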

sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala

Lines changed: 16 additions & 11 deletions
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.Statistics
+import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, Statistics}
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.LongAccumulator
@@ -62,8 +62,8 @@ case class InMemoryRelation(
     @transient child: SparkPlan,
     tableName: Option[String])(
     @transient var _cachedColumnBuffers: RDD[CachedBatch] = null,
-    val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator,
-    statsOfPlanToCache: Statistics = null)
+    val sizeInBytesStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator,
+    statsOfPlanToCache: Statistics)
   extends logical.LeafNode with MultiInstanceRelation {

   override protected def innerChildren: Seq[SparkPlan] = Seq(child)
@@ -73,11 +73,16 @@ case class InMemoryRelation(
   @transient val partitionStatistics = new PartitionStatistics(output)

   override def computeStats(): Statistics = {
-    if (batchStats.value == 0L) {
-      // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache
-      statsOfPlanToCache
+    if (sizeInBytesStats.value == 0L) {
+      // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache.
+      // Note that we should drop the hint info here. We may cache a plan whose root node is a hint
+      // node. When we lookup the cache with a semantically same plan without hint info, the plan
+      // returned by cache lookup should not have hint info. If we lookup the cache with a
+      // semantically same plan with a different hint info, `CacheManager.useCachedData` will take
+      // care of it and retain the hint info in the lookup input plan.
+      statsOfPlanToCache.copy(hints = HintInfo())
     } else {
-      Statistics(sizeInBytes = batchStats.value.longValue)
+      Statistics(sizeInBytes = sizeInBytesStats.value.longValue)
     }
   }

@@ -122,7 +127,7 @@ case class InMemoryRelation(
            rowCount += 1
          }

-          batchStats.add(totalSize)
+          sizeInBytesStats.add(totalSize)

          val stats = InternalRow.fromSeq(
            columnBuilders.flatMap(_.columnStats.collectedStatistics))
@@ -144,7 +149,7 @@ case class InMemoryRelation(
   def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = {
     InMemoryRelation(
       newOutput, useCompression, batchSize, storageLevel, child, tableName)(
-        _cachedColumnBuffers, batchStats, statsOfPlanToCache)
+        _cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache)
   }

   override def newInstance(): this.type = {
@@ -156,12 +161,12 @@ case class InMemoryRelation(
       child,
       tableName)(
         _cachedColumnBuffers,
-        batchStats,
+        sizeInBytesStats,
         statsOfPlanToCache).asInstanceOf[this.type]
   }

   def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers

   override protected def otherCopyArgs: Seq[AnyRef] =
-    Seq(_cachedColumnBuffers, batchStats, statsOfPlanToCache)
+    Seq(_cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache)
 }
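
The observable effect of resetting the hint in `computeStats` can be checked from spark-shell; a hedged sketch, assuming the Spark 2.3-era `Statistics`/`HintInfo` API (in particular the `hints.broadcast` flag) and the shell's implicits:

import org.apache.spark.sql.functions.broadcast
import org.apache.spark.sql.execution.columnar.InMemoryRelation

val df = Seq((1, "a"), (2, "b")).toDF("key", "value")
broadcast(df).cache()   // the plan that gets cached is rooted at a hint node

// Look up the cached entry with a plan that carries no hint.
val relation = df.queryExecution.withCachedData.collect {
  case r: InMemoryRelation => r
}.head

// Before materialization the stats come from statsOfPlanToCache with the hint reset,
// so the cached relation itself no longer advertises a broadcast hint.
assert(!relation.computeStats().hints.broadcast)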

sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala

Lines changed: 2 additions & 2 deletions
@@ -368,12 +368,12 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
     val toBeCleanedAccIds = new HashSet[Long]

     val accId1 = spark.table("t1").queryExecution.withCachedData.collect {
-      case i: InMemoryRelation => i.batchStats.id
+      case i: InMemoryRelation => i.sizeInBytesStats.id
     }.head
     toBeCleanedAccIds += accId1

     val accId2 = spark.table("t1").queryExecution.withCachedData.collect {
-      case i: InMemoryRelation => i.batchStats.id
+      case i: InMemoryRelation => i.sizeInBytesStats.id
     }.head
     toBeCleanedAccIds += accId2

sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala

Lines changed: 1 addition & 1 deletion
@@ -336,7 +336,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
     checkAnswer(cached, expectedAnswer)

     // Check that the right size was calculated.
-    assert(cached.batchStats.value === expectedAnswer.size * INT.defaultSize)
+    assert(cached.sizeInBytesStats.value === expectedAnswer.size * INT.defaultSize)
   }

   test("access primitive-type columns in CachedBatch without whole stage codegen") {

sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala

Lines changed: 65 additions & 38 deletions
@@ -22,7 +22,8 @@ import scala.reflect.ClassTag
 import org.apache.spark.AccumulatorSuite
 import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession}
 import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft}
-import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.exchange.EnsureRequirements
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -70,8 +71,8 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
   private def testBroadcastJoin[T: ClassTag](
       joinType: String,
       forceBroadcast: Boolean = false): SparkPlan = {
-    val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
-    val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
+    val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value")
+    val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")

     // Comparison at the end is for broadcast left semi join
     val joinExpression = df1("key") === df2("key") && df1("value") > df2("value")
@@ -109,61 +110,89 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
     }
   }

-  test("broadcast hint is retained after using the cached data") {
+  test("SPARK-23192: broadcast hint should be retained after using the cached data") {
     withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
-      val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
-      val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
-      df2.cache()
-      val df3 = df1.join(broadcast(df2), Seq("key"), "inner")
-      val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect {
-        case b: BroadcastHashJoinExec => b
-      }.size
-      assert(numBroadCastHashJoin === 1)
+      try {
+        val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value")
+        val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")
+        df2.cache()
+        val df3 = df1.join(broadcast(df2), Seq("key"), "inner")
+        val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect {
+          case b: BroadcastHashJoinExec => b
+        }.size
+        assert(numBroadCastHashJoin === 1)
+      } finally {
+        spark.catalog.clearCache()
+      }
+    }
+  }
+
+  test("SPARK-23214: cached data should not carry extra hint info") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      try {
+        val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value")
+        val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")
+        broadcast(df2).cache()
+
+        val df3 = df1.join(df2, Seq("key"), "inner")
+        val numCachedPlan = df3.queryExecution.executedPlan.collect {
+          case i: InMemoryTableScanExec => i
+        }.size
+        // df2 should be cached.
+        assert(numCachedPlan === 1)
+
+        val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect {
+          case b: BroadcastHashJoinExec => b
+        }.size
+        // df2 should not be broadcasted.
+        assert(numBroadCastHashJoin === 0)
+      } finally {
+        spark.catalog.clearCache()
+      }
     }
   }

   test("broadcast hint isn't propagated after a join") {
     withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
-      val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
-      val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
+      val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value")
+      val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")
       val df3 = df1.join(broadcast(df2), Seq("key"), "inner").drop(df2("key"))

-      val df4 = spark.createDataFrame(Seq((1, "5"), (2, "5"))).toDF("key", "value")
+      val df4 = Seq((1, "5"), (2, "5")).toDF("key", "value")
       val df5 = df4.join(df3, Seq("key"), "inner")

-      val plan =
-        EnsureRequirements(spark.sessionState.conf).apply(df5.queryExecution.sparkPlan)
+      val plan = EnsureRequirements(spark.sessionState.conf).apply(df5.queryExecution.sparkPlan)

       assert(plan.collect { case p: BroadcastHashJoinExec => p }.size === 1)
       assert(plan.collect { case p: SortMergeJoinExec => p }.size === 1)
     }
   }

   private def assertBroadcastJoin(df : Dataset[Row]) : Unit = {
-    val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
+    val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value")
     val joined = df1.join(df, Seq("key"), "inner")

-    val plan =
-      EnsureRequirements(spark.sessionState.conf).apply(joined.queryExecution.sparkPlan)
+    val plan = EnsureRequirements(spark.sessionState.conf).apply(joined.queryExecution.sparkPlan)

     assert(plan.collect { case p: BroadcastHashJoinExec => p }.size === 1)
   }

   test("broadcast hint programming API") {
     withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
-      val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"), (3, "2"))).toDF("key", "value")
+      val df2 = Seq((1, "1"), (2, "2"), (3, "2")).toDF("key", "value")
       val broadcasted = broadcast(df2)
-      val df3 = spark.createDataFrame(Seq((2, "2"), (3, "3"))).toDF("key", "value")
-
-      val cases = Seq(broadcasted.limit(2),
-                      broadcasted.filter("value < 10"),
-                      broadcasted.sample(true, 0.5),
-                      broadcasted.distinct(),
-                      broadcasted.groupBy("value").agg(min($"key").as("key")),
-                      // except and intersect are semi/anti-joins which won't return more data then
-                      // their left argument, so the broadcast hint should be propagated here
-                      broadcasted.except(df3),
-                      broadcasted.intersect(df3))
+      val df3 = Seq((2, "2"), (3, "3")).toDF("key", "value")
+
+      val cases = Seq(
+        broadcasted.limit(2),
+        broadcasted.filter("value < 10"),
+        broadcasted.sample(true, 0.5),
+        broadcasted.distinct(),
+        broadcasted.groupBy("value").agg(min($"key").as("key")),
+        // except and intersect are semi/anti-joins which won't return more data then
+        // their left argument, so the broadcast hint should be propagated here
+        broadcasted.except(df3),
+        broadcasted.intersect(df3))

       cases.foreach(assertBroadcastJoin)
     }
@@ -240,9 +269,8 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
   test("Shouldn't change broadcast join buildSide if user clearly specified") {

     withTempView("t1", "t2") {
-      spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("t1")
-      spark.createDataFrame(Seq((1, "1"), (2, "12.3"), (2, "123"))).toDF("key", "value")
-        .createTempView("t2")
+      Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1")
+      Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key", "value").createTempView("t2")

       val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes
       val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes
@@ -292,9 +320,8 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
   test("Shouldn't bias towards build right if user didn't specify") {

     withTempView("t1", "t2") {
-      spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("t1")
-      spark.createDataFrame(Seq((1, "1"), (2, "12.3"), (2, "123"))).toDF("key", "value")
-        .createTempView("t2")
+      Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1")
+      Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key", "value").createTempView("t2")

       val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes
       val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes
