From 5edaa2ac11c4f7ad4d3a6358bccc7eae0e817660 Mon Sep 17 00:00:00 2001
From: Rahul Tanwani
Date: Sun, 14 Feb 2016 00:26:16 +0530
Subject: [PATCH 1/2] [SPARK-13309][SQL] Fix type inference issue with CSV data

---
 .../datasources/csv/CSVInferSchema.scala          | 16 ++++++++--------
 sql/core/src/test/resources/simple_sparse.csv     |  5 +++++
 .../datasources/csv/CSVInferSchemaSuite.scala     |  6 ++++++
 .../sql/execution/datasources/csv/CSVSuite.scala  | 14 +++++++++++++-
 4 files changed, 32 insertions(+), 9 deletions(-)
 create mode 100644 sql/core/src/test/resources/simple_sparse.csv

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
index ace8cd7ad864e..fe80085555eb5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
@@ -29,7 +29,6 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
 import org.apache.spark.sql.types._
 
-
 private[csv] object CSVInferSchema {
 
   /**
@@ -48,7 +47,11 @@ private[csv] object CSVInferSchema {
       tokenRdd.aggregate(startType)(inferRowType(nullValue), mergeRowTypes)
 
     val structFields = header.zip(rootTypes).map { case (thisHeader, rootType) =>
-      StructField(thisHeader, rootType, nullable = true)
+      val dType = rootType match {
+        case z: NullType => StringType
+        case other => other
+      }
+      StructField(thisHeader, dType, nullable = true)
     }
 
     StructType(structFields)
@@ -66,11 +69,7 @@ private[csv] object CSVInferSchema {
 
   def mergeRowTypes(first: Array[DataType], second: Array[DataType]): Array[DataType] = {
     first.zipAll(second, NullType, NullType).map { case ((a, b)) =>
-      val tpe = findTightestCommonType(a, b).getOrElse(StringType)
-      tpe match {
-        case _: NullType => StringType
-        case other => other
-      }
+      findTightestCommonType(a, b).getOrElse(NullType)
     }
   }
 
@@ -140,6 +139,8 @@ private[csv] object CSVInferSchema {
     case (t1, t2) if t1 == t2 => Some(t1)
     case (NullType, t1) => Some(t1)
    case (t1, NullType) => Some(t1)
+    case (StringType, t2) => Some(StringType)
+    case (t1, StringType) => Some(StringType)
 
     // Promote numeric types to the highest of the two and all numeric types to unlimited decimal
     case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) =>
@@ -150,7 +151,6 @@ private[csv] object CSVInferSchema {
   }
 }
 
-
 private[csv] object CSVTypeCast {
 
   /**
diff --git a/sql/core/src/test/resources/simple_sparse.csv b/sql/core/src/test/resources/simple_sparse.csv
new file mode 100644
index 0000000000000..02d29cabf95f2
--- /dev/null
+++ b/sql/core/src/test/resources/simple_sparse.csv
@@ -0,0 +1,5 @@
+A,B,C,D
+1,,,
+,1,,
+,,1,
+,,,1
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
index a1796f1326007..f30a7e8290011 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
@@ -68,4 +68,10 @@ class InferSchemaSuite extends SparkFunSuite {
     assert(CSVInferSchema.inferField(DoubleType, "\\N", "\\N") == DoubleType)
     assert(CSVInferSchema.inferField(TimestampType, "\\N", "\\N") == TimestampType)
   }
+
+  test("Merging NullTypes should yield NullType.") {
+    assert(
+      CSVInferSchema.mergeRowTypes(Array(NullType),
+        Array(NullType)).deep == Array(NullType).deep)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 9d1f4569ad5e9..7d76e917607fc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -37,6 +37,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   private val emptyFile = "empty.csv"
   private val commentsFile = "comments.csv"
   private val disableCommentsFile = "disable_comments.csv"
+  private val simpleSparseFile = "simple_sparse.csv"
 
   private def testFile(fileName: String): String = {
     Thread.currentThread().getContextClassLoader.getResource(fileName).toString
@@ -224,7 +225,6 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     assert(result.schema.fieldNames.size === 1)
   }
 
-
   test("DDL test with empty file") {
     sqlContext.sql(s"""
       |CREATE TEMPORARY TABLE carsTable
@@ -387,4 +387,16 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
       verifyCars(carsCopy, withHeader = true)
     }
   }
+
+  test("Schema inference correctly identifies the datatype when data is sparse.") {
+    val df = sqlContext.read
+      .format("csv")
+      .option("header", "true")
+      .option("inferSchema", "true")
+      .load(testFile(simpleSparseFile))
+
+    assert(
+      df.schema.fields.map{field => field.dataType}.deep ==
+        Array(IntegerType, IntegerType, IntegerType, IntegerType).deep)
+  }
 }

From 60cd75cc0990a11b2c24f4896afd188028fd8257 Mon Sep 17 00:00:00 2001
From: Rahul Tanwani
Date: Sun, 14 Feb 2016 14:25:12 +0000
Subject: [PATCH 2/2] Fix review comments

---
 .../spark/sql/execution/datasources/csv/CSVInferSchema.scala | 4 ++--
 .../sql/execution/datasources/csv/CSVInferSchemaSuite.scala  | 5 ++---
 .../spark/sql/execution/datasources/csv/CSVSuite.scala       | 2 +-
 3 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
index fe80085555eb5..7f1ed28046b1d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
@@ -48,7 +48,7 @@ private[csv] object CSVInferSchema {
 
     val structFields = header.zip(rootTypes).map { case (thisHeader, rootType) =>
       val dType = rootType match {
-        case z: NullType => StringType
+        case _: NullType => StringType
         case other => other
       }
       StructField(thisHeader, dType, nullable = true)
@@ -68,7 +68,7 @@ private[csv] object CSVInferSchema {
   }
 
   def mergeRowTypes(first: Array[DataType], second: Array[DataType]): Array[DataType] = {
-    first.zipAll(second, NullType, NullType).map { case ((a, b)) =>
+    first.zipAll(second, NullType, NullType).map { case (a, b) =>
       findTightestCommonType(a, b).getOrElse(NullType)
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
index f30a7e8290011..412f1b89beee7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
@@ -70,8 +70,7 @@ class InferSchemaSuite extends SparkFunSuite {
   }
 
   test("Merging NullTypes should yield NullType.") {
-    assert(
-      CSVInferSchema.mergeRowTypes(Array(NullType),
-        Array(NullType)).deep == Array(NullType).deep)
+    val mergedNullTypes = CSVInferSchema.mergeRowTypes(Array(NullType), Array(NullType))
+    assert(mergedNullTypes.deep == Array(NullType).deep)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 7d76e917607fc..516375b50081e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -396,7 +396,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
       .load(testFile(simpleSparseFile))
 
     assert(
-      df.schema.fields.map{field => field.dataType}.deep ==
+      df.schema.fields.map(field => field.dataType).deep ==
         Array(IntegerType, IntegerType, IntegerType, IntegerType).deep)
   }
 }
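
For reviewers who want to poke at the behavior outside a Spark build, here is a small self-contained Scala sketch of the logic the two patches converge on. Everything in it is a simplified stand-in rather than Spark code: the sealed DataType objects, the MergeSketch name, and the two-step numeric precedence ladder are illustrative assumptions (the real findTightestCommonType takes its numeric precedence from HiveTypeCoercion). What it preserves is the shape of the fix: mergeRowTypes now falls back to NullType instead of StringType, so a column that is empty in every row seen so far stays undetermined while rows merge, and the NullType-to-StringType fallback happens exactly once, where the StructFields are built.

// Standalone sketch of the merge logic these patches put in place.
// The trait and case objects below are pared-down stand-ins for Spark's
// DataType hierarchy, so this runs without Spark on the classpath.
sealed trait DataType
case object NullType extends DataType
case object IntegerType extends DataType
case object DoubleType extends DataType
case object StringType extends DataType

object MergeSketch {
  // Simplified precedence ladder; the real code covers all numeric types.
  private val numericPrecedence: IndexedSeq[DataType] = IndexedSeq(IntegerType, DoubleType)

  // Equal types unify, NullType defers to the other side, StringType
  // absorbs everything (the two cases added by patch 1), and numeric
  // types promote to the higher of the two.
  def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = (t1, t2) match {
    case (a, b) if a == b => Some(a)
    case (NullType, t) => Some(t)
    case (t, NullType) => Some(t)
    case (StringType, _) => Some(StringType)
    case (_, StringType) => Some(StringType)
    case (a, b) if Seq(a, b).forall(numericPrecedence.contains) =>
      Some(numericPrecedence(Seq(a, b).map(numericPrecedence.indexOf).max))
    case _ => None
  }

  // The fix: fall back to NullType rather than StringType, so a column
  // that has been empty in every row seen so far stays undetermined.
  def mergeRowTypes(first: Array[DataType], second: Array[DataType]): Array[DataType] =
    first.zipAll(second, NullType, NullType).map { case (a, b) =>
      findTightestCommonType(a, b).getOrElse(NullType)
    }

  def main(args: Array[String]): Unit = {
    // Two sparse rows: column 1 has a value only in the first row,
    // column 2 only in the second, column 3 in neither.
    val row1: Array[DataType] = Array(IntegerType, NullType, NullType)
    val row2: Array[DataType] = Array(NullType, IntegerType, NullType)

    val merged = mergeRowTypes(row1, row2)
    println(merged.mkString(", "))     // IntegerType, IntegerType, NullType

    // Only when the final schema is built does NullType become StringType,
    // mirroring the rootType match patch 1 adds where StructFields are created.
    val fieldTypes = merged.map {
      case NullType => StringType
      case other    => other
    }
    println(fieldTypes.mkString(", ")) // IntegerType, IntegerType, StringType
  }
}

Run as a script, this prints IntegerType, IntegerType, NullType after the merge and IntegerType, IntegerType, StringType for the final field types. In the simple_sparse.csv fixture above, every column eventually sees a 1, which is why the new CSVSuite test expects all four columns to infer as IntegerType rather than degrading to StringType.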