diff --git a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala
index faf162e..12d1400 100755
--- a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala
+++ b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala
@@ -108,7 +108,8 @@ case class CsvRelation protected[spark] (
       try {
         index = 0
         while (index < schemaFields.length) {
-          rowArray(index) = TypeCast.castTo(tokens(index), schemaFields(index).dataType)
+          val field = schemaFields(index)
+          rowArray(index) = TypeCast.castTo(tokens(index), field.dataType, field.nullable)
           index = index + 1
         }
         Some(Row.fromSeq(rowArray))
diff --git a/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala b/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala
index 62c7b17..c3f4de2 100644
--- a/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala
+++ b/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala
@@ -29,25 +29,33 @@ object TypeCast {
    * Casts given string datum to specified type.
    * Currently we do not support complex types (ArrayType, MapType, StructType).
    *
+   * For string types, this is simply the datum. For other nullable types, this
+   * is null if the string datum is empty.
+   *
    * @param datum string value
    * @param castType SparkSQL type
+   * @param nullable whether an empty datum may be cast to null (ignored for string types)
    */
-  private[csv] def castTo(datum: String, castType: DataType): Any = {
-    castType match {
-      case _: ByteType => datum.toByte
-      case _: ShortType => datum.toShort
-      case _: IntegerType => datum.toInt
-      case _: LongType => datum.toLong
-      case _: FloatType => datum.toFloat
-      case _: DoubleType => datum.toDouble
-      case _: BooleanType => datum.toBoolean
-      case _: DecimalType => new BigDecimal(datum.replaceAll(",", ""))
-      // TODO(hossein): would be good to support other common timestamp formats
-      case _: TimestampType => Timestamp.valueOf(datum)
-      // TODO(hossein): would be good to support other common date formats
-      case _: DateType => Date.valueOf(datum)
-      case _: StringType => datum
-      case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}")
+  private[csv] def castTo(datum: String, castType: DataType, nullable: Boolean = true): Any = {
+    if (datum == "" && nullable && !castType.isInstanceOf[StringType]) {
+      null
+    } else {
+      castType match {
+        case _: ByteType => datum.toByte
+        case _: ShortType => datum.toShort
+        case _: IntegerType => datum.toInt
+        case _: LongType => datum.toLong
+        case _: FloatType => datum.toFloat
+        case _: DoubleType => datum.toDouble
+        case _: BooleanType => datum.toBoolean
+        case _: DecimalType => new BigDecimal(datum.replaceAll(",", ""))
+        // TODO(hossein): would be good to support other common timestamp formats
+        case _: TimestampType => Timestamp.valueOf(datum)
+        // TODO(hossein): would be good to support other common date formats
+        case _: DateType => Date.valueOf(datum)
+        case _: StringType => datum
+        case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}")
+      }
     }
   }
 
diff --git a/src/test/resources/null-numbers.csv b/src/test/resources/null-numbers.csv
new file mode 100644
index 0000000..310dcb7
--- /dev/null
+++ b/src/test/resources/null-numbers.csv
@@ -0,0 +1,4 @@
+name,age
+alice,35
+bob,
+,24
\ No newline at end of file
diff --git a/src/test/scala/com/databricks/spark/csv/CsvFastSuite.scala b/src/test/scala/com/databricks/spark/csv/CsvFastSuite.scala
index 4fef4ed..eeea378 100644
--- a/src/test/scala/com/databricks/spark/csv/CsvFastSuite.scala
+++ b/src/test/scala/com/databricks/spark/csv/CsvFastSuite.scala
@@ -32,6 +32,7 @@ class CsvFastSuite extends FunSuite {
   val carsFile8859 = "src/test/resources/cars_iso-8859-1.csv"
   val carsTsvFile = "src/test/resources/cars.tsv"
   val carsAltFile = "src/test/resources/cars-alternative.csv"
+  val nullNumbersFile = "src/test/resources/null-numbers.csv"
   val emptyFile = "src/test/resources/empty.csv"
   val escapeFile = "src/test/resources/escape.csv"
   val tempEmptyDir = "target/test/empty2/"
@@ -387,4 +388,19 @@ class CsvFastSuite extends FunSuite {
 
     assert(results.first().getInt(0) === 1997)
   }
+
+  test("DSL test nullable fields") {
+    val results = new CsvParser()
+      .withSchema(StructType(List(
+        StructField("name", StringType, false),
+        StructField("age", IntegerType, true))))
+      .withUseHeader(true)
+      .withParserLib("univocity")
+      .csvFile(TestSQLContext, nullNumbersFile)
+      .collect()
+
+    assert(results.head.toSeq == Seq("alice", 35))
+    assert(results(1).toSeq == Seq("bob", null))
+    assert(results(2).toSeq == Seq("", 24))
+  }
 }
\ No newline at end of file
diff --git a/src/test/scala/com/databricks/spark/csv/CsvSuite.scala b/src/test/scala/com/databricks/spark/csv/CsvSuite.scala
index 2986823..0acb654 100755
--- a/src/test/scala/com/databricks/spark/csv/CsvSuite.scala
+++ b/src/test/scala/com/databricks/spark/csv/CsvSuite.scala
@@ -32,6 +32,7 @@ class CsvSuite extends FunSuite {
   val carsFile8859 = "src/test/resources/cars_iso-8859-1.csv"
   val carsTsvFile = "src/test/resources/cars.tsv"
   val carsAltFile = "src/test/resources/cars-alternative.csv"
+  val nullNumbersFile = "src/test/resources/null-numbers.csv"
   val emptyFile = "src/test/resources/empty.csv"
   val escapeFile = "src/test/resources/escape.csv"
   val tempEmptyDir = "target/test/empty/"
@@ -392,4 +393,18 @@ class CsvSuite extends FunSuite {
 
     assert(results.first().getInt(0) === 1997)
   }
+
+  test("DSL test nullable fields") {
+    val results = new CsvParser()
+      .withSchema(StructType(List(
+        StructField("name", StringType, false),
+        StructField("age", IntegerType, true))))
+      .withUseHeader(true)
+      .csvFile(TestSQLContext, nullNumbersFile)
+      .collect()
+
+    assert(results.head.toSeq == Seq("alice", 35))
+    assert(results(1).toSeq == Seq("bob", null))
+    assert(results(2).toSeq == Seq("", 24))
+  }
 }
\ No newline at end of file
diff --git a/src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala b/src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala
index dca2abd..e17bed5 100644
--- a/src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala
+++ b/src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala
@@ -16,10 +16,11 @@
 package com.databricks.spark.csv.util
 
 import java.math.BigDecimal
+import java.sql.{Date, Timestamp}
 
 import org.scalatest.FunSuite
 
-import org.apache.spark.sql.types.DecimalType
+import org.apache.spark.sql.types._
 
 class TypeCastSuite extends FunSuite {
 
@@ -56,4 +57,33 @@ class TypeCastSuite extends FunSuite {
     }
     assert(exception.getMessage.contains("Unsupported special character for delimiter"))
   }
+
+  test("Nullable types are handled") {
+    assert(TypeCast.castTo("", IntegerType, nullable = true) == null)
+  }
+
+  test("String type should always return the same as the input") {
+    assert(TypeCast.castTo("", StringType, nullable = true) == "")
+    assert(TypeCast.castTo("", StringType, nullable = false) == "")
+  }
+
+  test("Throws exception for empty string with non null type") {
+    val exception = intercept[NumberFormatException] {
+      TypeCast.castTo("", IntegerType, nullable = false)
+    }
+    assert(exception.getMessage.contains("For input string: \"\""))
+  }
+
+  test("Types are cast correctly") {
+    assert(TypeCast.castTo("10", ByteType) == 10)
+    assert(TypeCast.castTo("10", ShortType) == 10)
+    assert(TypeCast.castTo("10", IntegerType) == 10)
+    assert(TypeCast.castTo("10", LongType) == 10)
+    assert(TypeCast.castTo("1.00", FloatType) == 1.0)
+    assert(TypeCast.castTo("1.00", DoubleType) == 1.0)
+    assert(TypeCast.castTo("true", BooleanType) == true)
+    val timestamp = "2015-01-01 00:00:00"
+    assert(TypeCast.castTo(timestamp, TimestampType) == Timestamp.valueOf(timestamp))
+    assert(TypeCast.castTo("2015-01-01", DateType) == Date.valueOf("2015-01-01"))
+  }
 }
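
For reviewers: a minimal, self-contained sketch of the semantics this patch gives `TypeCast.castTo`. This is not part of the patch; `NullCastDemo` is a hypothetical object, placed under `com.databricks.spark.csv` only because `castTo` is `private[csv]`, and it assumes the patched spark-csv is on the classpath.

```scala
package com.databricks.spark.csv

import org.apache.spark.sql.types.{IntegerType, StringType}
import com.databricks.spark.csv.util.TypeCast

object NullCastDemo extends App {
  // An empty token in a nullable non-string column becomes null...
  assert(TypeCast.castTo("", IntegerType, nullable = true) == null)
  // ...while string columns keep the empty string as-is...
  assert(TypeCast.castTo("", StringType, nullable = false) == "")
  // ...and non-empty tokens are cast exactly as before (nullable defaults to true).
  assert(TypeCast.castTo("24", IntegerType) == 24)
  // A non-nullable numeric column still fails fast on an empty token:
  // TypeCast.castTo("", IntegerType, nullable = false) // throws NumberFormatException
  println("null-handling semantics hold")
}
```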