
Commit d4f12b1

add first test case
Squashed fix-up commits: fix compilation of tests; fix tests; revise the test; fix test; add missing file; plus further rounds of revising the test.
1 parent 963ca0a commit d4f12b1

File tree

3 files changed: +31 lines, -2 lines


sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala

Lines changed: 0 additions & 1 deletion
@@ -152,7 +152,6 @@ case class InMemoryRelation(
   private def buildBuffers(): Unit = {
     val output = child.output
 
-    // TODO: need better abstraction for two iterators here
     val batchedRDD = child.execute().mapPartitionsInternal { rowIterator =>
       new CachedBatchIterator(rowIterator, output, batchSize, useCompression, batchStats,
         usePartitionLevelMetadata)
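
The removed TODO concerned the two layered iterators here: a row iterator consumed by a batch-producing iterator. A minimal sketch of that batching pattern, assuming a hypothetical SimpleBatch type and plain Scala collections in place of Spark's CachedBatchIterator internals:

// Sketch only: batch a row iterator into fixed-size chunks, keeping
// per-batch min/max so a scan can prune batches later. SimpleBatch and
// batched are illustrative names, not Spark APIs.
object BatchingSketch {
  case class SimpleBatch(rows: Array[Int], min: Int, max: Int)

  def batched(rows: Iterator[Int], batchSize: Int): Iterator[SimpleBatch] =
    rows.grouped(batchSize).map { chunk =>
      SimpleBatch(chunk.toArray, chunk.min, chunk.max)
    }

  def main(args: Array[String]): Unit = {
    // Three batches: 0..3, 4..7, 8..9, each with its own statistics.
    batched((0 until 10).iterator, batchSize = 4).foreach { b =>
      println(s"rows=${b.rows.length} min=${b.min} max=${b.max}")
    }
  }
}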

sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala

Lines changed: 1 addition & 1 deletion
@@ -223,10 +223,10 @@ case class InMemoryTableScanExec(
             if !partitionFilter.eval(cachedIter.partitionStats) =>
           // scalastyle:off
           println(s"skipped partition $index")
-          // scalastyle:on
           Iterator()
         case _ =>
           doFilterCachedBatches(cachedBatchIterator, schema, partitionFilter)
+          // scalastyle:on
       }
     }
   }
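
The hunk above shows the scan-side decision: when the partition filter, evaluated against partitionStats, proves that no row in a cached partition can match, the scan returns an empty iterator instead of reading any cached batches. A minimal standalone sketch of that skip-or-scan shape, where PartitionStats and the value >= threshold predicate are illustrative stand-ins for Spark's partitionFilter machinery:

// Sketch only: skip a whole partition when its statistics rule out any
// match; otherwise fall through to per-batch filtering. Types here are
// illustrative, not Spark internals.
object PruneSketch {
  case class PartitionStats(min: Int, max: Int)

  def scanPartition(
      batches: Iterator[Array[Int]],
      stats: PartitionStats,
      threshold: Int): Iterator[Array[Int]] = {
    if (stats.max < threshold) {
      Iterator.empty // whole-partition skip, like Iterator() above
    } else {
      batches.map(_.filter(_ >= threshold)) // like doFilterCachedBatches
    }
  }

  def main(args: Array[String]): Unit = {
    val pruned = scanPartition(Iterator(Array(1, 2, 3)), PartitionStats(1, 3), threshold = 10)
    println(pruned.isEmpty) // true: the batches were never touched
  }
}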

sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala

Lines changed: 30 additions & 0 deletions
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.columnar
 import java.nio.charset.StandardCharsets
 import java.sql.{Date, Timestamp}
 
+import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
 import org.apache.spark.sql.{DataFrame, QueryTest, Row}
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, In}
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
@@ -479,4 +480,33 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
       }
     }
   }
+
+  test("table cache can prune unnecessary partitions correctly") {
+    // scalastyle:off
+    var bytesReadWithoutPruning = 0L
+    var bytesReadWithPruning = 0L
+    var inMemoryPartitionMetadata = false
+    sparkContext.addSparkListener(new SparkListener() {
+      override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+        val metrics = taskEnd.taskMetrics
+        if (inMemoryPartitionMetadata) {
+          bytesReadWithPruning += metrics.inputMetrics.bytesRead
+        } else {
+          bytesReadWithoutPruning += metrics.inputMetrics.bytesRead
+        }
+      }
+    })
+    Seq("true", "false").foreach { enabled =>
+      withSQLConf(SQLConf.IN_MEMORY_PARTITION_METADATA.key -> enabled) {
+        inMemoryPartitionMetadata = conf.inMemoryPartitionMetadata
+        val df1 = (0 until 1000000).toDF("value").repartition(4).cache()
+        df1.where("value >= 999999").collect()
+        val resultArr = df1.where("value >= 999999").collect()
+        assert(resultArr.length == 1)
+        assert(resultArr.head.getInt(0) == 999999)
+        df1.unpersist(true)
+      }
+    }
+    assert(bytesReadWithoutPruning > bytesReadWithPruning * 3)
+  }
 }
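
A note on how the new test measures pruning: the registered SparkListener adds each task's inputMetrics.bytesRead to one of two counters, chosen by whether IN_MEMORY_PARTITION_METADATA was enabled for that run. With one million rows repartitioned into four partitions, only the partition that actually contains 999999 can satisfy value >= 999999; the other three should have maximums below the threshold and be skippable, which is why the final assertion expects the unpruned run to read more than three times as many bytes. The first collect() appears to be a warm-up that populates the cache so the second, checked collect() exercises the cached scan path, and df1.unpersist(true) blocks until the cached data is dropped so the two configurations stay independent.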
