Commit b5e9866

added test suites
1 parent 5e9f32d commit b5e9866

1 file changed: +128 −2 lines

sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala

Lines changed: 128 additions & 2 deletions
@@ -20,18 +20,83 @@ package org.apache.spark.sql.execution.columnar
 import java.nio.charset.StandardCharsets
 import java.sql.{Date, Timestamp}

-import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.{DataFrame, QueryTest, Row}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.test.SQLTestData._
 import org.apache.spark.sql.types._
-import org.apache.spark.storage.StorageLevel.MEMORY_ONLY
+import org.apache.spark.storage.StorageLevel._

 class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
   import testImplicits._

   setupTestData()

+  def cachePrimitiveTest(data: DataFrame, dataType: String) {
+    data.createOrReplaceTempView(s"testData$dataType")
+    val storageLevel = MEMORY_ONLY
+    val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan
+    val inMemoryRelation = InMemoryRelation(useCompression = true, 5, storageLevel, plan, None)
+
+    assert(inMemoryRelation.cachedColumnBuffers.getStorageLevel == storageLevel)
+    inMemoryRelation.cachedColumnBuffers.collect().head match {
+      case _: CachedBatch => assert(true)
+      case other => fail(s"Unexpected cached batch type: ${other.getClass.getName}")
+    }
+    checkAnswer(inMemoryRelation, data.collect().toSeq)
+  }
+
+  test("all data type w && w/o nullability") {
+    // all primitives
+    Seq(true, false).map { nullability =>
+      val dataTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType,
+        FloatType, DoubleType, DateType, TimestampType, DecimalType(25, 5), DecimalType(6, 5))
+      val schema = StructType(dataTypes.zipWithIndex.map { case (dataType, index) =>
+        StructField(s"col$index", dataType, nullability)
+      })
+      val rdd = spark.sparkContext.parallelize((1 to 10).map(i => Row(
+        if (nullability && i % 3 == 0) null else if (i % 2 == 0) true else false,
+        if (nullability && i % 3 == 0) null else i.toByte,
+        if (nullability && i % 3 == 0) null else i.toShort,
+        if (nullability && i % 3 == 0) null else i.toInt,
+        if (nullability && i % 3 == 0) null else i.toLong,
+        if (nullability && i % 3 == 0) null else (i + 0.25).toFloat,
+        if (nullability && i % 3 == 0) null else (i + 0.75).toDouble,
+        if (nullability && i % 3 == 0) null else new Date(i),
+        if (nullability && i % 3 == 0) null else new Timestamp(i * 1000000L),
+        if (nullability && i % 3 == 0) null else BigDecimal(Long.MaxValue.toString + ".12345"),
+        if (nullability && i % 3 == 0) null
+        else new java.math.BigDecimal(s"${i % 9 + 1}" + ".23456")
+      )))
+      cachePrimitiveTest(spark.createDataFrame(rdd, schema), "primitivesDateTimeStamp")
+    }
+
+    val schemaNull = StructType(Seq(StructField("col", NullType, true)))
+    val rddNull = spark.sparkContext.parallelize((1 to 10).map(i => Row(null)))
+    cachePrimitiveTest(spark.createDataFrame(rddNull, schemaNull), "Null")
+
+    Seq(true, false).map { nullability =>
+      val struct = StructType(StructField("f1", FloatType, false) ::
+        StructField("f2", ArrayType(BooleanType), true) :: Nil)
+      val schema = StructType(Seq(
+        StructField("col0", StringType, nullability),
+        StructField("col1", ArrayType(IntegerType), nullability),
+        StructField("col2", ArrayType(ArrayType(IntegerType)), nullability),
+        StructField("col3", MapType(StringType, IntegerType), nullability),
+        StructField("col4", struct, nullability)
+      ))
+      val rdd = spark.sparkContext.parallelize((1 to 10).map(i => Row(
+        if (nullability && i % 3 == 0) null else s"str${i}: test cache.",
+        if (nullability && i % 3 == 0) null else (i * 100 to i * 100 + i).toArray,
+        if (nullability && i % 3 == 0) null
+        else Array(Array(i, i + 1), Array(i * 100 + 1, i * 100, i * 100 + 2)),
+        if (nullability && i % 3 == 0) null else (i to i + i).map(j => s"key$j" -> j).toMap,
+        if (nullability && i % 3 == 0) null else Row((i + 0.25).toFloat, Seq(true, false, null))
+      )))
+      cachePrimitiveTest(spark.createDataFrame(rdd, schema), "StringArrayMapStruct")
+    }
+  }
+
   test("simple columnar query") {
     val plan = spark.sessionState.executePlan(testData.logicalPlan).sparkPlan
     val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)
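
Note: the cachePrimitiveTest helper above builds an InMemoryRelation directly from a physical plan. For reference, the same in-memory columnar cache is what the public DataFrame API produces; a minimal sketch, not part of this commit:

  import org.apache.spark.storage.StorageLevel

  // Caching through the public API builds the same in-memory columnar
  // relation that cachePrimitiveTest constructs by hand from a plan.
  val df = spark.range(0, 10).toDF("id")
  df.persist(StorageLevel.MEMORY_ONLY)
  df.count() // materializes the cached column buffers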
@@ -58,6 +123,12 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
     }.map(Row.fromTuple))
   }

+  test("access only some column of the all of columns") {
+    val df = spark.range(1, 100).map(i => (i, (i + 1).toFloat)).toDF("i", "f").cache
+    df.count
+    assert(df.filter("f <= 10.0").count == 9)
+  }
+
   test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") {
     val plan = spark.sessionState.executePlan(testData.logicalPlan).sparkPlan
     val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)
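
Note on the expected count of 9 in the new test above: spark.range(1, 100) yields i = 1..99 (the end is exclusive) and f = i + 1, so f <= 10.0 holds exactly for i = 1..9. A quick standalone check, illustrative only:

  // i runs 1..99 (range end is exclusive); f = i + 1 ranges over 2..100.
  val hits = (1L until 100L).map(_ + 1.0f).count(_ <= 10.0f)
  assert(hits == 9)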
@@ -246,4 +317,59 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
     assert(cached.batchStats.value === expectedAnswer.size * INT.defaultSize)
   }

+  test("access columns in CachedBatch without whole stage codegen") {
+    // whole stage codegen is not applied to a row with more than WHOLESTAGE_MAX_NUM_FIELDS fields
+    withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "2") {
+      val data = Seq(null, true, 1.toByte, 3.toShort, 7, 15.toLong,
+        31.25.toFloat, 63.75, new Date(127), new Timestamp(255000000L), null)
+      val dataTypes = Seq(NullType, BooleanType, ByteType, ShortType, IntegerType, LongType,
+        FloatType, DoubleType, DateType, TimestampType, IntegerType)
+      val schemas = dataTypes.zipWithIndex.map { case (dataType, index) =>
+        StructField(s"col$index", dataType, true)
+      }
+      val rdd = sparkContext.makeRDD(Seq(Row.fromSeq(data)))
+      val df = spark.createDataFrame(rdd, StructType(schemas))
+      val row = df.persist.take(1).apply(0)
+      checkAnswer(df, row)
+    }
+
+    withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "2") {
+      val data = Seq(BigDecimal(Long.MaxValue.toString + ".12345"),
+        new java.math.BigDecimal("1234567890.12345"),
+        new java.math.BigDecimal("1.23456"),
+        "test123"
+      )
+      val schemas = Seq(
+        StructField("col0", DecimalType(25, 5), true),
+        StructField("col1", DecimalType(15, 5), true),
+        StructField("col2", DecimalType(6, 5), true),
+        StructField("col3", StringType, true)
+      )
+      val rdd = sparkContext.makeRDD(Seq(Row.fromSeq(data)))
+      val df = spark.createDataFrame(rdd, StructType(schemas))
+      val row = df.persist.take(1).apply(0)
+      checkAnswer(df, row)
+    }
+
+    withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "2") {
+      val data = Seq((1 to 10).toArray,
+        Array(Array(10, 11), Array(100, 111, 123)),
+        Map("key1" -> 111, "key2" -> 222),
+        Row(1.25.toFloat, Seq(true, false, null))
+      )
+      val struct = StructType(StructField("f1", FloatType, false) ::
+        StructField("f2", ArrayType(BooleanType), true) :: Nil)
+      val schemas = Seq(
+        StructField("col0", ArrayType(IntegerType), true),
+        StructField("col1", ArrayType(ArrayType(IntegerType)), true),
+        StructField("col2", MapType(StringType, IntegerType), true),
+        StructField("col3", struct, true)
+      )
+      val rdd = sparkContext.makeRDD(Seq(Row.fromSeq(data)))
+      val df = spark.createDataFrame(rdd, StructType(schemas))
+      val row = df.persist.take(1).apply(0)
+      checkAnswer(df, row)
+    }
+  }
+
 }
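
Note: all three blocks in the last hunk pin SQLConf.WHOLESTAGE_MAX_NUM_FIELDS to 2 so that reads from the cached batches take the interpreted, non-codegen path. Outside withSQLConf, the same switch can be set on a session; a sketch, assuming WHOLESTAGE_MAX_NUM_FIELDS maps to the key spark.sql.codegen.maxFields:

  // Force wide rows off the whole-stage codegen path, then read a cached frame.
  spark.conf.set("spark.sql.codegen.maxFields", "2")
  val wide = spark.range(0, 10).selectExpr("id", "id + 1 as a", "id + 2 as b").cache()
  wide.count()   // materializes the CachedBatch column buffers
  wide.collect() // read back through the interpreted columnar path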
