@@ -20,18 +20,83 @@ package org.apache.spark.sql.execution.columnar
 import java.nio.charset.StandardCharsets
 import java.sql.{Date, Timestamp}
 
-import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.{DataFrame, QueryTest, Row}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.test.SQLTestData._
 import org.apache.spark.sql.types._
-import org.apache.spark.storage.StorageLevel.MEMORY_ONLY
+import org.apache.spark.storage.StorageLevel._
 
 class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
   import testImplicits._
 
   setupTestData()
 
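+  // Caches `data` as an InMemoryRelation and verifies that the cached buffers use
+  // the requested storage level, that each cached element is a CachedBatch, and
+  // that scanning the relation returns the original rows.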
+  def cachePrimitiveTest(data: DataFrame, dataType: String): Unit = {
+    data.createOrReplaceTempView(s"testData$dataType")
+    val storageLevel = MEMORY_ONLY
+    val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan
+    val inMemoryRelation = InMemoryRelation(useCompression = true, 5, storageLevel, plan, None)
+
+    assert(inMemoryRelation.cachedColumnBuffers.getStorageLevel == storageLevel)
+    inMemoryRelation.cachedColumnBuffers.collect().head match {
+      case _: CachedBatch => assert(true)
+      case other => fail(s"Unexpected cached batch type: ${other.getClass.getName}")
+    }
+    checkAnswer(inMemoryRelation, data.collect().toSeq)
+  }
+
49+ test(" all data type w && w/o nullability" ) {
50+ // all primitives
51+ Seq (true , false ).map { nullability =>
52+ val dataTypes = Seq (BooleanType , ByteType , ShortType , IntegerType , LongType ,
53+ FloatType , DoubleType , DateType , TimestampType , DecimalType (25 , 5 ), DecimalType (6 , 5 ))
54+ val schema = StructType (dataTypes.zipWithIndex.map { case (dataType, index) =>
55+ StructField (s " col $index" , dataType, nullability)
56+ })
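+      // with nullability enabled, every third row is null in every column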
+      val rdd = spark.sparkContext.parallelize((1 to 10).map(i => Row(
+        if (nullability && i % 3 == 0) null else if (i % 2 == 0) true else false,
+        if (nullability && i % 3 == 0) null else i.toByte,
+        if (nullability && i % 3 == 0) null else i.toShort,
+        if (nullability && i % 3 == 0) null else i.toInt,
+        if (nullability && i % 3 == 0) null else i.toLong,
+        if (nullability && i % 3 == 0) null else (i + 0.25).toFloat,
+        if (nullability && i % 3 == 0) null else (i + 0.75).toDouble,
+        if (nullability && i % 3 == 0) null else new Date(i),
+        if (nullability && i % 3 == 0) null else new Timestamp(i * 1000000L),
+        if (nullability && i % 3 == 0) null else BigDecimal(Long.MaxValue.toString + ".12345"),
+        if (nullability && i % 3 == 0) null
+        else new java.math.BigDecimal(s"${i % 9 + 1}.23456")
+      )))
+      cachePrimitiveTest(spark.createDataFrame(rdd, schema), "primitivesDateTimeStamp")
+    }
+
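+    // a single NullType column: every value in every row is null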
+    val schemaNull = StructType(Seq(StructField("col", NullType, true)))
+    val rddNull = spark.sparkContext.parallelize((1 to 10).map(i => Row(null)))
+    cachePrimitiveTest(spark.createDataFrame(rddNull, schemaNull), "Null")
+
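+    // complex types: string, array, nested array, map, and struct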
+    Seq(true, false).foreach { nullability =>
+      val struct = StructType(StructField("f1", FloatType, false) ::
+        StructField("f2", ArrayType(BooleanType), true) :: Nil)
+      val schema = StructType(Seq(
+        StructField("col0", StringType, nullability),
+        StructField("col1", ArrayType(IntegerType), nullability),
+        StructField("col2", ArrayType(ArrayType(IntegerType)), nullability),
+        StructField("col3", MapType(StringType, IntegerType), nullability),
+        StructField("col4", struct, nullability)
+      ))
+      val rdd = spark.sparkContext.parallelize((1 to 10).map(i => Row(
+        if (nullability && i % 3 == 0) null else s"str${i}: test cache.",
+        if (nullability && i % 3 == 0) null else (i * 100 to i * 100 + i).toArray,
+        if (nullability && i % 3 == 0) null
+        else Array(Array(i, i + 1), Array(i * 100 + 1, i * 100, i * 100 + 2)),
+        if (nullability && i % 3 == 0) null else (i to i + i).map(j => s"key$j" -> j).toMap,
+        if (nullability && i % 3 == 0) null else Row((i + 0.25).toFloat, Seq(true, false, null))
+      )))
+      cachePrimitiveTest(spark.createDataFrame(rdd, schema), "StringArrayMapStruct")
+    }
+  }
+
35100 test(" simple columnar query" ) {
36101 val plan = spark.sessionState.executePlan(testData.logicalPlan).sparkPlan
37102 val scan = InMemoryRelation (useCompression = true , 5 , MEMORY_ONLY , plan, None )
@@ -58,6 +123,12 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
     }.map(Row.fromTuple))
   }
 
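+  // both columns are cached, but the filter below reads back only "f";
+  // 9 of the 99 cached rows satisfy f <= 10.0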
126+ test(" access only some column of the all of columns" ) {
127+ val df = spark.range(1 , 100 ).map(i => (i, (i + 1 ).toFloat)).toDF(" i" , " f" ).cache
128+ df.count
129+ assert(df.filter(" f <= 10.0" ).count == 9 )
130+ }
131+
61132 test(" SPARK-1436 regression: in-memory columns must be able to be accessed multiple times" ) {
62133 val plan = spark.sessionState.executePlan(testData.logicalPlan).sparkPlan
63134 val scan = InMemoryRelation (useCompression = true , 5 , MEMORY_ONLY , plan, None )
@@ -246,4 +317,59 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
     assert(cached.batchStats.value === expectedAnswer.size * INT.defaultSize)
   }
 
320+ test(" access columns in CachedBatch without whole stage codegen" ) {
321+ // whole stage codegen is not applied to a row with more than WHOLESTAGE_MAX_NUM_FIELDS fields
322+ withSQLConf(SQLConf .WHOLESTAGE_MAX_NUM_FIELDS .key -> " 2" ) {
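+      // primitive, date, and timestamp columns, with nulls at both ends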
+      val data = Seq(null, true, 1.toByte, 3.toShort, 7, 15.toLong,
+        31.25.toFloat, 63.75, new Date(127), new Timestamp(255000000L), null)
+      val dataTypes = Seq(NullType, BooleanType, ByteType, ShortType, IntegerType, LongType,
+        FloatType, DoubleType, DateType, TimestampType, IntegerType)
+      val fields = dataTypes.zipWithIndex.map { case (dataType, index) =>
+        StructField(s"col$index", dataType, true)
+      }
+      val rdd = sparkContext.makeRDD(Seq(Row.fromSeq(data)))
+      val df = spark.createDataFrame(rdd, StructType(fields))
+      val row = df.persist().take(1).head
+      checkAnswer(df, row)
+    }
+
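+    // decimal columns at several precisions, plus a string column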
+    withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "2") {
+      val data = Seq(BigDecimal(Long.MaxValue.toString + ".12345"),
+        new java.math.BigDecimal("1234567890.12345"),
+        new java.math.BigDecimal("1.23456"),
+        "test123"
+      )
+      val fields = Seq(
+        StructField("col0", DecimalType(25, 5), true),
+        StructField("col1", DecimalType(15, 5), true),
+        StructField("col2", DecimalType(6, 5), true),
+        StructField("col3", StringType, true)
+      )
+      val rdd = sparkContext.makeRDD(Seq(Row.fromSeq(data)))
+      val df = spark.createDataFrame(rdd, StructType(fields))
+      val row = df.persist().take(1).head
+      checkAnswer(df, row)
+    }
+
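+    // complex columns: array, nested array, map, and struct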
+    withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "2") {
+      val data = Seq((1 to 10).toArray,
+        Array(Array(10, 11), Array(100, 111, 123)),
+        Map("key1" -> 111, "key2" -> 222),
+        Row(1.25.toFloat, Seq(true, false, null))
+      )
+      val struct = StructType(StructField("f1", FloatType, false) ::
+        StructField("f2", ArrayType(BooleanType), true) :: Nil)
+      val fields = Seq(
+        StructField("col0", ArrayType(IntegerType), true),
+        StructField("col1", ArrayType(ArrayType(IntegerType)), true),
+        StructField("col2", MapType(StringType, IntegerType), true),
+        StructField("col3", struct, true)
+      )
+      val rdd = sparkContext.makeRDD(Seq(Row.fromSeq(data)))
+      val df = spark.createDataFrame(rdd, StructType(fields))
+      val row = df.persist().take(1).head
+      checkAnswer(df, row)
+    }
+  }
+
 }