@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.columnar
 import java.nio.charset.StandardCharsets
 import java.sql.{Date, Timestamp}
 
+import org.apache.spark.SparkEnv
 import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
 import org.apache.spark.sql.{DataFrame, QueryTest, Row}
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, In}
@@ -30,6 +31,7 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.test.SQLTestData._
 import org.apache.spark.sql.types._
+import org.apache.spark.storage.RDDPartitionMetadataBlockId
 import org.apache.spark.storage.StorageLevel._
 
 class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
@@ -482,7 +484,6 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
   }
 
   test("table cache can prune unnecessary partitions correctly") {
-    // scalastyle:off
     var bytesReadWithoutPruning = 0L
     var bytesReadWithPruning = 0L
     var inMemoryPartitionMetadata = false
@@ -499,14 +500,60 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
     Seq("true", "false").foreach { enabled =>
       withSQLConf(SQLConf.IN_MEMORY_PARTITION_METADATA.key -> enabled) {
         inMemoryPartitionMetadata = conf.inMemoryPartitionMetadata
-        val df1 = (0 until 1000000).toDF("value").repartition(4).cache()
-        df1.where("value >= 999999").collect()
-        val resultArr = df1.where("value >= 999999").collect()
+        val df1 = (0 until 100000).toDF("value").repartition(4).cache()
+        df1.where("value >= 99999").collect()
+        val resultArr = df1.where("value >= 99999").collect()
         assert(resultArr.length == 1)
-        assert(resultArr.head.getInt(0) == 999999)
+        assert(resultArr.head.getInt(0) == 99999)
         df1.unpersist(true)
       }
     }
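+    // Both counters must be non-zero for the ratio comparison below to be meaningful.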
+    assert(bytesReadWithoutPruning > 0)
+    assert(bytesReadWithPruning > 0)
     assert(bytesReadWithoutPruning > bytesReadWithPruning * 3)
   }
+
+  test("generate correct results when metadata block is removed") {
+    var bytesReadWithMetadata = 0L
+    var bytesReadWithoutMetadata = 0L
+    @volatile var removePartitionMetadata = false
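+    // Record bytes read per task, bucketed by whether the metadata blocks have been dropped.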
+    sparkContext.addSparkListener(new SparkListener() {
+      override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+        val metrics = taskEnd.taskMetrics
+        if (removePartitionMetadata) {
+          bytesReadWithoutMetadata += metrics.inputMetrics.bytesRead
+        } else {
+          bytesReadWithMetadata += metrics.inputMetrics.bytesRead
+        }
+      }
+    })
+    Seq("true").foreach { enabled =>
+      withSQLConf(SQLConf.IN_MEMORY_PARTITION_METADATA.key -> enabled) {
+        removePartitionMetadata = true
+        val df1 = (0 until 100000).toDF("value").repartition(4).cache()
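+        // Locate the InMemoryRelation so we can address its cached RDD's metadata blocks.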
+        val inMemoryRelation = df1.queryExecution.optimizedPlan.collect {
+          case m: InMemoryRelation => m
+        }
+        df1.where("value >= 99999").collect()
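+        // Drop every partition's metadata block; the next scan can no longer prune.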
+        (0 until 4).foreach(partitionId => SparkEnv.get.blockManager.removeBlock(
+          RDDPartitionMetadataBlockId(inMemoryRelation.head.cachedColumnBuffers.id, partitionId)))
+        var resultArr = df1.where("value >= 99999").collect()
+        assert(resultArr.length === 1)
+        assert(resultArr.head.getInt(0) === 99999)
+        // From here on, reads are attributed to the with-metadata bucket.
+        removePartitionMetadata = false
+        resultArr = df1.where("value >= 99999").collect()
+        assert(resultArr.length === 1)
+        assert(resultArr.head.getInt(0) === 99999)
+        df1.unpersist(blocking = true)
+        assert(bytesReadWithMetadata > 0)
+        assert(bytesReadWithoutMetadata > 0)
+        assert(bytesReadWithoutMetadata > bytesReadWithMetadata * 3)
+      }
+    }
+  }
 }