@@ -20,18 +20,96 @@ package org.apache.spark.sql.execution.columnar
import java.nio.charset.StandardCharsets
import java.sql.{Date, Timestamp}

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.{DataFrame, QueryTest, Row}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.test.SQLTestData._
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel.MEMORY_ONLY
import org.apache.spark.storage.StorageLevel._

class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
import testImplicits._

setupTestData()

private def cachePrimitiveTest(data: DataFrame, dataType: String): Unit = {
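// Cache `data` as an InMemoryRelation and verify that its column buffers are
// materialized as CachedBatch instances at the expected storage level.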
data.createOrReplaceTempView(s"testData$dataType")
val storageLevel = MEMORY_ONLY
val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan
val inMemoryRelation = InMemoryRelation(useCompression = true, 5, storageLevel, plan, None)

assert(inMemoryRelation.cachedColumnBuffers.getStorageLevel == storageLevel)
inMemoryRelation.cachedColumnBuffers.collect().head match {
case _: CachedBatch =>
case other => fail(s"Unexpected cached batch type: ${other.getClass.getName}")
}
checkAnswer(inMemoryRelation, data.collect().toSeq)
}

private def testPrimitiveType(nullability: Boolean): Unit = {
val dataTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType,
FloatType, DoubleType, DateType, TimestampType, DecimalType(25, 5), DecimalType(6, 5))
val schema = StructType(dataTypes.zipWithIndex.map { case (dataType, index) =>
StructField(s"col$index", dataType, nullability)
})
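// Every third row is null in each column when nullability is enabled, so both
// the null and non-null paths of the column builders are exercised.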
val rdd = spark.sparkContext.parallelize((1 to 10).map(i => Row(
if (nullability && i % 3 == 0) null else if (i % 2 == 0) true else false,
if (nullability && i % 3 == 0) null else i.toByte,
if (nullability && i % 3 == 0) null else i.toShort,
if (nullability && i % 3 == 0) null else i.toInt,
if (nullability && i % 3 == 0) null else i.toLong,
if (nullability && i % 3 == 0) null else (i + 0.25).toFloat,
if (nullability && i % 3 == 0) null else (i + 0.75).toDouble,
if (nullability && i % 3 == 0) null else new Date(i),
if (nullability && i % 3 == 0) null else new Timestamp(i * 1000000L),
if (nullability && i % 3 == 0) null else BigDecimal(Long.MaxValue.toString + ".12345"),
if (nullability && i % 3 == 0) null
else new java.math.BigDecimal(s"${i % 9 + 1}" + ".23456")
)))
cachePrimitiveTest(spark.createDataFrame(rdd, schema), "primitivesDateTimeStamp")
}

private def testNonPrimitiveType(nullability: Boolean): Unit = {
val struct = StructType(StructField("f1", FloatType, false) ::
StructField("f2", ArrayType(BooleanType), true) :: Nil)
val schema = StructType(Seq(
StructField("col0", StringType, nullability),
StructField("col1", ArrayType(IntegerType), nullability),
StructField("col2", ArrayType(ArrayType(IntegerType)), nullability),
StructField("col3", MapType(StringType, IntegerType), nullability),
StructField("col4", struct, nullability)
))
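// Same pattern as the primitive test: every third row is null when nullability is enabled.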
val rdd = spark.sparkContext.parallelize((1 to 10).map(i => Row(
if (nullability && i % 3 == 0) null else s"str${i}: test cache.",
if (nullability && i % 3 == 0) null else (i * 100 to i * 100 + i).toArray,
if (nullability && i % 3 == 0) null
else Array(Array(i, i + 1), Array(i * 100 + 1, i * 100, i * 100 + 2)),
if (nullability && i % 3 == 0) null else (i to i + i).map(j => s"key$j" -> j).toMap,
if (nullability && i % 3 == 0) null else Row((i + 0.25).toFloat, Seq(true, false, null))
)))
cachePrimitiveTest(spark.createDataFrame(rdd, schema), "StringArrayMapStruct")
}

test("primitive type with nullability:true") {
testPrimitiveType(true)
}

test("primitive type with nullability:false") {
testPrimitiveType(false)
}

test("non-primitive type with nullability:true") {
val schemaNull = StructType(Seq(StructField("col", NullType, true)))
val rddNull = spark.sparkContext.parallelize((1 to 10).map(i => Row(null)))
cachePrimitiveTest(spark.createDataFrame(rddNull, schemaNull), "Null")

testNonPrimitiveType(true)
}

test("non-primitive type with nullability:false") {
testNonPrimitiveType(false)
}

test("simple columnar query") {
val plan = spark.sessionState.executePlan(testData.logicalPlan).sparkPlan
val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)
@@ -58,6 +136,13 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
}.map(Row.fromTuple))
}

test("access only some column of the all of columns") {
val df = spark.range(1, 100).map(i => (i, (i + 1).toFloat)).toDF("i", "f")
df.cache()
df.count() // force the cache to be built
assert(df.filter("f <= 10.0").count() == 9)
}

test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") {
val plan = spark.sessionState.executePlan(testData.logicalPlan).sparkPlan
val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)
@@ -246,4 +331,63 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
assert(cached.batchStats.value === expectedAnswer.size * INT.defaultSize)
}

test("access primitive-type columns in CachedBatch without whole stage codegen") {
// whole-stage codegen is not applied to a row with more than WHOLESTAGE_MAX_NUM_FIELDS fields
withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "2") {
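// With the limit lowered to 2, the 11-field row below exceeds it, so the scan
// of the cached batches takes the non-codegen path.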
val data = Seq(null, true, 1.toByte, 3.toShort, 7, 15.toLong,
31.25.toFloat, 63.75, new Date(127), new Timestamp(255000000L), null)
val dataTypes = Seq(NullType, BooleanType, ByteType, ShortType, IntegerType, LongType,
FloatType, DoubleType, DateType, TimestampType, IntegerType)
val schemas = dataTypes.zipWithIndex.map { case (dataType, index) =>
StructField(s"col$index", dataType, true)
}
val rdd = sparkContext.makeRDD(Seq(Row.fromSeq(data)))
val df = spark.createDataFrame(rdd, StructType(schemas))
val row = df.persist().take(1).head
checkAnswer(df, row)
}
}

test("access decimal/string-type columns in CachedBatch without whole stage codegen") {
withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "2") {
Contributor: similarly can you split these into multiple smaller unit tests?

Member (Author): I see

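// A possible split, following the review comment above (illustrative test names only):
//   test("access decimal-type columns in CachedBatch without whole stage codegen") { ... }
//   test("access string-type columns in CachedBatch without whole stage codegen") { ... }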
val data = Seq(BigDecimal(Long.MaxValue.toString + ".12345"),
new java.math.BigDecimal("1234567890.12345"),
new java.math.BigDecimal("1.23456"),
"test123"
)
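// DecimalType(25, 5) exceeds the precision that fits in a compact Long
// representation, while DecimalType(15, 5) and DecimalType(6, 5) fit within it.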
val schemas = Seq(
StructField("col0", DecimalType(25, 5), true),
StructField("col1", DecimalType(15, 5), true),
StructField("col2", DecimalType(6, 5), true),
StructField("col3", StringType, true)
)
val rdd = sparkContext.makeRDD(Seq(Row.fromSeq(data)))
val df = spark.createDataFrame(rdd, StructType(schemas))
val row = df.persist().take(1).head
checkAnswer(df, row)
}
}

test("access non-primitive-type columns in CachedBatch without whole stage codegen") {
withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "2") {
val data = Seq((1 to 10).toArray,
Array(Array(10, 11), Array(100, 111, 123)),
Map("key1" -> 111, "key2" -> 222),
Row(1.25.toFloat, Seq(true, false, null))
)
val struct = StructType(StructField("f1", FloatType, false) ::
StructField("f2", ArrayType(BooleanType), true) :: Nil)
val schemas = Seq(
StructField("col0", ArrayType(IntegerType), true),
StructField("col1", ArrayType(ArrayType(IntegerType)), true),
StructField("col2", MapType(StringType, IntegerType), true),
StructField("col3", struct, true)
)
val rdd = sparkContext.makeRDD(Seq(Row.fromSeq(data)))
val df = spark.createDataFrame(rdd, StructType(schemas))
val row = df.persist().take(1).head
checkAnswer(df, row)
}
}

}