From 330d6154ef6199ae7ec757e15b5e8f092754c978 Mon Sep 17 00:00:00 2001 From: David Peacock Date: Thu, 16 Jul 2015 21:38:15 +0100 Subject: [PATCH 1/4] Support for nullable schema types --- .../com/databricks/spark/csv/CsvRelation.scala | 9 ++++++++- src/test/resources/null-numbers.csv | 4 ++++ .../com/databricks/spark/csv/CsvFastSuite.scala | 16 ++++++++++++++++ .../com/databricks/spark/csv/CsvSuite.scala | 15 +++++++++++++++ 4 files changed, 43 insertions(+), 1 deletion(-) create mode 100644 src/test/resources/null-numbers.csv diff --git a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala index faf162e..dcefd4d 100755 --- a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala +++ b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala @@ -108,7 +108,14 @@ case class CsvRelation protected[spark] ( try { index = 0 while (index < schemaFields.length) { - rowArray(index) = TypeCast.castTo(tokens(index), schemaFields(index).dataType) + rowArray(index) = if (schemaFields(index).nullable && tokens(index) == ""){ + schemaFields(index).dataType match { + case StringType => "" + case _ => null + } + } else { + TypeCast.castTo(tokens(index), schemaFields(index).dataType) + } index = index + 1 } Some(Row.fromSeq(rowArray)) diff --git a/src/test/resources/null-numbers.csv b/src/test/resources/null-numbers.csv new file mode 100644 index 0000000..310dcb7 --- /dev/null +++ b/src/test/resources/null-numbers.csv @@ -0,0 +1,4 @@ +name,age +alice,35 +bob, +,24 \ No newline at end of file diff --git a/src/test/scala/com/databricks/spark/csv/CsvFastSuite.scala b/src/test/scala/com/databricks/spark/csv/CsvFastSuite.scala index 4fef4ed..eeea378 100644 --- a/src/test/scala/com/databricks/spark/csv/CsvFastSuite.scala +++ b/src/test/scala/com/databricks/spark/csv/CsvFastSuite.scala @@ -32,6 +32,7 @@ class CsvFastSuite extends FunSuite { val carsFile8859 = "src/test/resources/cars_iso-8859-1.csv" val carsTsvFile = 
"src/test/resources/cars.tsv" val carsAltFile = "src/test/resources/cars-alternative.csv" + val nullNumbersFile = "src/test/resources/null-numbers.csv" val emptyFile = "src/test/resources/empty.csv" val escapeFile = "src/test/resources/escape.csv" val tempEmptyDir = "target/test/empty2/" @@ -387,4 +388,19 @@ class CsvFastSuite extends FunSuite { assert(results.first().getInt(0) === 1997) } + + test("DSL test nullable fields"){ + + val results = new CsvParser() + .withSchema(StructType(List(StructField("name", StringType, false), StructField("age", IntegerType, true)))) + .withUseHeader(true) + .withParserLib("univocity") + .csvFile(TestSQLContext, nullNumbersFile) + .collect() + + assert(results.head.toSeq == Seq("alice", 35)) + assert(results(1).toSeq == Seq("bob", null)) + assert(results(2).toSeq == Seq("", 24)) + + } } \ No newline at end of file diff --git a/src/test/scala/com/databricks/spark/csv/CsvSuite.scala b/src/test/scala/com/databricks/spark/csv/CsvSuite.scala index 2986823..0acb654 100755 --- a/src/test/scala/com/databricks/spark/csv/CsvSuite.scala +++ b/src/test/scala/com/databricks/spark/csv/CsvSuite.scala @@ -32,6 +32,7 @@ class CsvSuite extends FunSuite { val carsFile8859 = "src/test/resources/cars_iso-8859-1.csv" val carsTsvFile = "src/test/resources/cars.tsv" val carsAltFile = "src/test/resources/cars-alternative.csv" + val nullNumbersFile = "src/test/resources/null-numbers.csv" val emptyFile = "src/test/resources/empty.csv" val escapeFile = "src/test/resources/escape.csv" val tempEmptyDir = "target/test/empty/" @@ -392,4 +393,18 @@ class CsvSuite extends FunSuite { assert(results.first().getInt(0) === 1997) } + + test("DSL test nullable fields"){ + + val results = new CsvParser() + .withSchema(StructType(List(StructField("name", StringType, false), StructField("age", IntegerType, true)))) + .withUseHeader(true) + .csvFile(TestSQLContext, nullNumbersFile) + .collect() + + assert(results.head.toSeq == Seq("alice", 35)) + assert(results(1).toSeq == 
Seq("bob", null)) + assert(results(2).toSeq == Seq("", 24)) + + } } \ No newline at end of file From f292cc5250f7669f835b0e465b021416c261e6df Mon Sep 17 00:00:00 2001 From: David Peacock Date: Thu, 16 Jul 2015 23:02:31 +0100 Subject: [PATCH 2/4] Pull null handling into TypeCast, add tests --- .../databricks/spark/csv/CsvRelation.scala | 10 +---- .../databricks/spark/csv/util/TypeCast.scala | 40 +++++++++++-------- .../spark/csv/util/TypeCastSuite.scala | 18 ++++++++- 3 files changed, 43 insertions(+), 25 deletions(-) diff --git a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala index dcefd4d..12d1400 100755 --- a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala +++ b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala @@ -108,14 +108,8 @@ case class CsvRelation protected[spark] ( try { index = 0 while (index < schemaFields.length) { - rowArray(index) = if (schemaFields(index).nullable && tokens(index) == ""){ - schemaFields(index).dataType match { - case StringType => "" - case _ => null - } - } else { - TypeCast.castTo(tokens(index), schemaFields(index).dataType) - } + val field = schemaFields(index) + rowArray(index) = TypeCast.castTo(tokens(index), field.dataType, field.nullable) index = index + 1 } Some(Row.fromSeq(rowArray)) diff --git a/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala b/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala index 62c7b17..d9f0b66 100644 --- a/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala +++ b/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala @@ -29,25 +29,33 @@ object TypeCast { * Casts given string datum to specified type. * Currently we do not support complex types (ArrayType, MapType, StructType). * + * For string types, this is simply the datum. + * For other nullable types, this is null if the string datum is empty. 
+ * * @param datum string value * @param castType SparkSQL type */ - private[csv] def castTo(datum: String, castType: DataType): Any = { - castType match { - case _: ByteType => datum.toByte - case _: ShortType => datum.toShort - case _: IntegerType => datum.toInt - case _: LongType => datum.toLong - case _: FloatType => datum.toFloat - case _: DoubleType => datum.toDouble - case _: BooleanType => datum.toBoolean - case _: DecimalType => new BigDecimal(datum.replaceAll(",", "")) - // TODO(hossein): would be good to support other common timestamp formats - case _: TimestampType => Timestamp.valueOf(datum) - // TODO(hossein): would be good to support other common date formats - case _: DateType => Date.valueOf(datum) - case _: StringType => datum - case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}") + private[csv] def castTo(datum: String, castType: DataType, nullable: Boolean = true): Any = { + if (castType.isInstanceOf[StringType]){ + datum + } else if (nullable && datum == ""){ + null + } else { + castType match { + case _: ByteType => datum.toByte + case _: ShortType => datum.toShort + case _: IntegerType => datum.toInt + case _: LongType => datum.toLong + case _: FloatType => datum.toFloat + case _: DoubleType => datum.toDouble + case _: BooleanType => datum.toBoolean + case _: DecimalType => new BigDecimal(datum.replaceAll(",", "")) + // TODO(hossein): would be good to support other common timestamp formats + case _: TimestampType => Timestamp.valueOf(datum) + // TODO(hossein): would be good to support other common date formats + case _: DateType => Date.valueOf(datum) + case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}") + } } } diff --git a/src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala b/src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala index dca2abd..f464a2c 100644 --- a/src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala +++ 
b/src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala @@ -19,7 +19,7 @@ import java.math.BigDecimal import org.scalatest.FunSuite -import org.apache.spark.sql.types.DecimalType +import org.apache.spark.sql.types.{StringType, IntegerType, DecimalType} class TypeCastSuite extends FunSuite { @@ -56,4 +56,20 @@ class TypeCastSuite extends FunSuite { } assert(exception.getMessage.contains("Unsupported special character for delimiter")) } + + test("Nullable types are handled"){ + assert(TypeCast.castTo("", IntegerType, nullable = true) == null) + } + + test("String type should always return the same as the input"){ + assert(TypeCast.castTo("", StringType, nullable = true) == "") + assert(TypeCast.castTo("", StringType, nullable = false) == "") + } + + test("Throws exception for empty string with non null type"){ + val exception = intercept[NumberFormatException]{ + TypeCast.castTo("", IntegerType, nullable = false) + } + assert(exception.getMessage.contains("For input string: \"\"")) + } } From 8465ad4f64aca43a211aee2f1fa6ad65369ffa7a Mon Sep 17 00:00:00 2001 From: David Peacock Date: Thu, 16 Jul 2015 23:27:11 +0100 Subject: [PATCH 3/4] Add tests for casts --- .../spark/csv/util/TypeCastSuite.scala | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala b/src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala index f464a2c..e17bed5 100644 --- a/src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala +++ b/src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala @@ -16,10 +16,11 @@ package com.databricks.spark.csv.util import java.math.BigDecimal +import java.sql.{Date, Timestamp} import org.scalatest.FunSuite -import org.apache.spark.sql.types.{StringType, IntegerType, DecimalType} +import org.apache.spark.sql.types._ class TypeCastSuite extends FunSuite { @@ -72,4 +73,17 @@ class TypeCastSuite extends FunSuite { } 
assert(exception.getMessage.contains("For input string: \"\"")) } + + test("Types are cast correctly"){ + assert(TypeCast.castTo("10", ByteType) == 10) + assert(TypeCast.castTo("10", ShortType) == 10) + assert(TypeCast.castTo("10", IntegerType) == 10) + assert(TypeCast.castTo("10", LongType) == 10) + assert(TypeCast.castTo("1.00", FloatType) == 1.0) + assert(TypeCast.castTo("1.00", DoubleType) == 1.0) + assert(TypeCast.castTo("true", BooleanType) == true) + val timestamp = "2015-01-01 00:00:00" + assert(TypeCast.castTo(timestamp, TimestampType) == Timestamp.valueOf(timestamp)) + assert(TypeCast.castTo("2015-01-01", DateType) == Date.valueOf("2015-01-01")) + } } From 2d24a5260fefaab8c38cc5d350368f27496c5058 Mon Sep 17 00:00:00 2001 From: David Peacock Date: Fri, 17 Jul 2015 18:38:53 +0100 Subject: [PATCH 4/4] Simplify nullable logic --- src/main/scala/com/databricks/spark/csv/util/TypeCast.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala b/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala index d9f0b66..c3f4de2 100644 --- a/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala +++ b/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala @@ -36,9 +36,7 @@ object TypeCast { * @param castType SparkSQL type */ private[csv] def castTo(datum: String, castType: DataType, nullable: Boolean = true): Any = { - if (castType.isInstanceOf[StringType]){ - datum - } else if (nullable && datum == ""){ + if (datum == "" && nullable && !castType.isInstanceOf[StringType]){ null } else { castType match { @@ -54,6 +52,7 @@ object TypeCast { case _: TimestampType => Timestamp.valueOf(datum) // TODO(hossein): would be good to support other common date formats case _: DateType => Date.valueOf(datum) + case _: StringType => datum case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}") } }