
Commit 46d68db

test for remove metadata block
fix the test fix the test fix the test
1 parent d4f12b1 commit 46d68db


2 files changed: 60 additions & 19 deletions


sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/CachedColumnarRDD.scala

Lines changed: 12 additions & 14 deletions
@@ -30,19 +30,7 @@ class CachedColumnarRDD(
   extends RDD[CachedBatch](_sc, Seq(new OneToOneDependency(dataRDD))) {
 
   override def compute(split: Partition, context: TaskContext): Iterator[CachedBatch] = {
-    if (containsPartitionMetadata) {
-      val parentIterator = dataRDD.iterator(split, context)
-      if (!parentIterator.hasNext) {
-        Iterator()
-      } else {
-        val cachedBatch = parentIterator.next()
-        SparkEnv.get.blockManager.putSingle(RDDPartitionMetadataBlockId(id, split.index),
-          cachedBatch.stats, expectedStorageLevel)
-        Iterator(cachedBatch)
-      }
-    } else {
-      firstParent.iterator(split, context)
-    }
+    firstParent.iterator(split, context)
   }
 
   override protected def getPartitions: Array[Partition] = dataRDD.partitions
@@ -54,7 +42,17 @@ class CachedColumnarRDD(
     SparkEnv.get.blockManager.getSingle[InternalRow](metadataBlockId).map(metadataBlock =>
       new InterruptibleIterator[CachedBatch](context,
         new CachedColumnarIterator(metadataBlock, split, context, superGetOrCompute))
-    ).getOrElse(superGetOrCompute(split, context))
+    ).getOrElse {
+      val batchIter = superGetOrCompute(split, context)
+      if (containsPartitionMetadata && getStorageLevel != StorageLevel.NONE && batchIter.hasNext) {
+        val cachedBatch = batchIter.next()
+        SparkEnv.get.blockManager.putSingle(metadataBlockId, cachedBatch.stats,
+          expectedStorageLevel)
+        new InterruptibleIterator[CachedBatch](context, Iterator(cachedBatch))
+      } else {
+        batchIter
+      }
+    }
   }
 }
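Read together, the two hunks relocate the metadata write: compute() now simply delegates to the parent iterator, and getOrCompute() republishes a partition's stats whenever the per-partition metadata block has been evicted or removed, falling back to superGetOrCompute for the data itself. Below is a minimal, self-contained sketch of that look-up-or-recompute-and-republish pattern; Batch, the mutable Map standing in for the block manager, and every other name here are illustrative stand-ins, not the patch's actual classes.

import scala.collection.mutable

object MetadataFallbackSketch {
  // Stand-in for CachedBatch: a handful of rows plus a "stats" summary (here just the max).
  final case class Batch(rows: Seq[Int]) { def stats: Int = rows.max }

  // Stand-in for the block manager: partition index -> published stats.
  private val metadataStore = mutable.Map.empty[Int, Int]

  // Stand-in for superGetOrCompute(split, context): recompute the partition's batches.
  private def recompute(partition: Int): Iterator[Batch] =
    Iterator(Batch(Seq(partition, partition + 1)))

  def getOrCompute(partition: Int): Iterator[Batch] =
    metadataStore.get(partition) match {
      case Some(_) =>
        // Stats present: the real CachedColumnarIterator can use them to skip the
        // partition entirely; this sketch just returns the data.
        recompute(partition)
      case None =>
        // Stats missing (evicted or removed): fall back to the recompute path, then
        // republish the stats taken from the first batch so later scans can prune again.
        val batchIter = recompute(partition)
        if (batchIter.hasNext) {
          val first = batchIter.next()
          metadataStore(partition) = first.stats
          Iterator(first) ++ batchIter
        } else {
          batchIter
        }
    }

  def main(args: Array[String]): Unit = {
    println(getOrCompute(0).toList)  // metadata missing: recomputed and republished
    println(metadataStore)           // now holds the stats for partition 0
    println(getOrCompute(0).toList)  // metadata found on the second pass
  }
}

The point of the fallback is that a missing metadata block degrades to a plain recompute rather than a wrong or empty result, which is exactly what the new test below exercises.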

sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala

Lines changed: 48 additions & 5 deletions
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.columnar
 import java.nio.charset.StandardCharsets
 import java.sql.{Date, Timestamp}
 
+import org.apache.spark.SparkEnv
 import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
 import org.apache.spark.sql.{DataFrame, QueryTest, Row}
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, In}
@@ -30,6 +31,7 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.test.SQLTestData._
 import org.apache.spark.sql.types._
+import org.apache.spark.storage.RDDPartitionMetadataBlockId
 import org.apache.spark.storage.StorageLevel._
 
 class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
@@ -482,7 +484,6 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
   }
 
   test("table cache can prune unnecessary partitions correctly") {
-    // scalastyle:off
     var bytesReadWithoutPruning = 0L
     var bytesReadWithPruning = 0L
     var inMemoryPartitionMetadata = false
@@ -499,14 +500,56 @@
     Seq("true", "false").foreach { enabled =>
       withSQLConf(SQLConf.IN_MEMORY_PARTITION_METADATA.key -> enabled) {
         inMemoryPartitionMetadata = conf.inMemoryPartitionMetadata
-        val df1 = (0 until 1000000).toDF("value").repartition(4).cache()
-        df1.where("value >= 999999").collect()
-        val resultArr = df1.where("value >= 999999").collect()
+        val df1 = (0 until 100000).toDF("value").repartition(4).cache()
+        df1.where("value >= 99999").collect()
+        val resultArr = df1.where("value >= 99999").collect()
         assert(resultArr.length == 1)
-        assert(resultArr.head.getInt(0) == 999999)
+        assert(resultArr.head.getInt(0) == 99999)
         df1.unpersist(true)
       }
     }
+    assert(bytesReadWithoutPruning > 0)
+    assert(bytesReadWithPruning > 0)
     assert(bytesReadWithoutPruning > bytesReadWithPruning * 3)
   }
+
+  test("generate correct results when metadata block is removed") {
+    var bytesReadWithMetadata = 0L
+    var bytesReadWithoutMetadata = 0L
+    @volatile var removePartitionMetadata = false
+    sparkContext.addSparkListener(new SparkListener() {
+      override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+        val metrics = taskEnd.taskMetrics
+        if (removePartitionMetadata) {
+          bytesReadWithoutMetadata += metrics.inputMetrics.bytesRead
+        } else {
+          bytesReadWithMetadata += metrics.inputMetrics.bytesRead
+        }
+      }
+    })
+    Seq("true").foreach { enabled =>
+      withSQLConf(SQLConf.IN_MEMORY_PARTITION_METADATA.key -> enabled) {
+        removePartitionMetadata = true
+        val df1 = (0 until 100000).toDF("value").repartition(4).cache()
+        val inMemoryRelation = df1.queryExecution.optimizedPlan.collect {
+          case m: InMemoryRelation => m
+        }
+        df1.where("value >= 99999").collect()
+        (0 until 4).foreach(partitionId => SparkEnv.get.blockManager.removeBlock(
+          RDDPartitionMetadataBlockId(inMemoryRelation.head.cachedColumnBuffers.id, partitionId)))
+        var resultArr = df1.where("value >= 99999").collect()
+        assert(resultArr.length === 1)
+        assert(resultArr.head.getInt(0) === 99999)
+        // scalastyle:off
+        removePartitionMetadata = false
+        resultArr = df1.where("value >= 99999").collect()
+        assert(resultArr.length === 1)
+        assert(resultArr.head.getInt(0) === 99999)
+        df1.unpersist(blocking = true)
+        assert(bytesReadWithMetadata > 0)
+        assert(bytesReadWithoutMetadata > 0)
+        assert(bytesReadWithoutMetadata > bytesReadWithMetadata * 3)
+      }
+    }
+  }
 }
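The new test flips removePartitionMetadata around the block removal and compares the bytes read by the two scans through a SparkListener. Here is a minimal, self-contained sketch of that measurement technique using only stock Spark APIs; the local session, the object name, and the Thread.sleep wait are assumptions for illustration, not part of the suite.

import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.sql.SparkSession

object BytesReadListenerSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("bytes-read-sketch").getOrCreate()
    import spark.implicits._

    @volatile var bytesRead = 0L
    spark.sparkContext.addSparkListener(new SparkListener {
      // Sum the input bytes of every finished task; listener delivery is asynchronous,
      // so the total should only be read after the events have had time to arrive.
      override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
        bytesRead += taskEnd.taskMetrics.inputMetrics.bytesRead
      }
    })

    val df = (0 until 100000).toDF("value").repartition(4).cache()
    df.where("value >= 99999").collect()   // first pass materializes the cached partitions
    df.where("value >= 99999").collect()   // second pass reads them back from the cache
    Thread.sleep(1000)                     // crude wait for the asynchronous listener events
    println(s"bytes read across both passes: $bytesRead")
    spark.stop()
  }
}

In the suite itself, SharedSQLContext supplies the SparkContext and withSQLConf scopes the IN_MEMORY_PARTITION_METADATA flag, so no explicit session setup or shutdown is needed.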
