diff --git a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala index 5ed1e90..0c85658 100755 --- a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala +++ b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala @@ -118,8 +118,8 @@ case class CsvRelation protected[spark] ( index = 0 while (index < schemaFields.length) { val field = schemaFields(index) - rowArray(index) = TypeCast.castTo(tokens(index), field.dataType, field.nullable, - treatEmptyValuesAsNulls, nullValue, simpleDateFormatter) + rowArray(index) = TypeCast.castTo(tokens(index), field.name, field.dataType, + field.nullable, treatEmptyValuesAsNulls, nullValue, simpleDateFormatter) index = index + 1 } Some(Row.fromSeq(rowArray)) @@ -195,6 +195,7 @@ case class CsvRelation protected[spark] ( val field = schemaFields(index) rowArray(subIndex) = TypeCast.castTo( indexSafeTokens(index), + field.name, field.dataType, field.nullable, treatEmptyValuesAsNulls, diff --git a/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala b/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala index 8c3474b..4c925bb 100644 --- a/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala +++ b/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala @@ -34,21 +34,43 @@ object TypeCast { * Currently we do not support complex types (ArrayType, MapType, StructType). * * For string types, this is simply the datum. For other types. - * For other nullable types, this is null if the string datum is empty. + * For other nullable types, returns null if the datum is null or equal to the value specified + * by the `nullValue` option. If `treatEmptyValuesAsNulls` is set, it also returns null for + * empty strings. * * @param datum string value - * @param castType SparkSQL type + * @param name field name in schema. + * @param castType data type to cast `datum` into. + * @param nullable nullability for the field. 
+ * @param treatEmptyValuesAsNulls a flag to indicate if empty strings should be treated as null. + * @param nullValue the string value that represents null. + * @param dateFormatter date formatter used to parse dates and timestamps. */ private[csv] def castTo( datum: String, + name: String, castType: DataType, nullable: Boolean = true, treatEmptyValuesAsNulls: Boolean = false, nullValue: String = "", dateFormatter: SimpleDateFormat = null): Any = { - if (datum == nullValue && - nullable || - (treatEmptyValuesAsNulls && datum == "")){ + + // If the given column is not nullable, we simply fall back to a normal string + // rather than returning null for backwards compatibility. Note that this differs from + // Spark's internal CSV datasource, which throws an exception in this case. + val isNullValueMatched = datum == nullValue && nullable + + // If `treatEmptyValuesAsNulls` is enabled, treat empty strings as nulls. + val shouldTreatEmptyValuesAsNulls = treatEmptyValuesAsNulls && datum == "" + + // `datum` can be null when tokens were inserted in permissive mode via `PrunedScan`. + // In this case, they are treated as nulls. 
+ val isNullDatum = datum == null + + if (isNullValueMatched || shouldTreatEmptyValuesAsNulls || isNullDatum) { + if (!nullable) { + throw new RuntimeException(s"null value found but field $name is not nullable.") + } null } else { castType match { diff --git a/src/test/scala/com/databricks/spark/csv/CsvSuite.scala b/src/test/scala/com/databricks/spark/csv/CsvSuite.scala index b6d0681..30b96fc 100755 --- a/src/test/scala/com/databricks/spark/csv/CsvSuite.scala +++ b/src/test/scala/com/databricks/spark/csv/CsvSuite.scala @@ -1002,6 +1002,47 @@ abstract class AbstractCsvSuite extends FunSuite with BeforeAndAfterAll { assert(ages.schema.fields(2).dataType === DoubleType) assert(ages.schema.fields(3).dataType === StringType) } + + test("Should read null properly when schema is lager than parsed tokens") { + val schema = StructType( + StructField("bool", BooleanType, true) :: + StructField("nullcol", IntegerType, true) :: + StructField("nullcol1", IntegerType, true) :: Nil) + + // Selects only bool and nullcol to use `PrunedScan` interface. If we select + // all, it falls back to `TableScan`. 
+ val results = new CsvParser() + .withSchema(schema) + .withUseHeader(true) + .withParserLib(parserLib) + .withParseMode(ParseModes.PERMISSIVE_MODE) + .csvFile(sqlContext, boolFile) + .select("bool", "nullcol") + .collect() + + val expected = Seq(Row(true, null), Row(false, null), Row(false, null)) + assert(results.length == expected.length) + assert(results.toSet == expected.toSet) + + // Negative case + val nonNullableSchema = StructType( + StructField("bool", BooleanType, false) :: + StructField("nullcol", IntegerType, false) :: + StructField("nullcol1", IntegerType, false) :: Nil) + + val exception = intercept[SparkException] { + new CsvParser() + .withSchema(nonNullableSchema) + .withUseHeader(true) + .withParserLib(parserLib) + .withParseMode(ParseModes.PERMISSIVE_MODE) + .csvFile(sqlContext, boolFile) + .select("bool", "nullcol") + .collect() + } + + assert(exception.getMessage.contains("null value found but field nullcol")) + } } class CsvSuite extends AbstractCsvSuite { diff --git a/src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala b/src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala index 9f3df23..e875025 100644 --- a/src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala +++ b/src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala @@ -32,7 +32,7 @@ class TypeCastSuite extends FunSuite { val decimalType = new DecimalType(None) stringValues.zip(decimalValues).foreach { case (strVal, decimalVal) => - assert(TypeCast.castTo(strVal, decimalType) === new BigDecimal(decimalVal.toString)) + assert(TypeCast.castTo(strVal, "_c", decimalType) === new BigDecimal(decimalVal.toString)) } } @@ -61,36 +61,46 @@ class TypeCastSuite extends FunSuite { } test("Nullable types are handled") { - assert(TypeCast.castTo("-", ByteType, nullable = true, nullValue = "-") == null) - assert(TypeCast.castTo("-", ShortType, nullable = true, nullValue = "-") == null) - assert(TypeCast.castTo("-", IntegerType, nullable = true, nullValue 
= "-") == null) - assert(TypeCast.castTo("-", LongType, nullable = true, nullValue = "-") == null) - assert(TypeCast.castTo("-", FloatType, nullable = true, nullValue = "-") == null) - assert(TypeCast.castTo("-", DoubleType, nullable = true, nullValue = "-") == null) - assert(TypeCast.castTo("-", BooleanType, nullable = true, nullValue = "-") == null) - assert(TypeCast.castTo("-", TimestampType, nullable = true, nullValue = "-") == null) - assert(TypeCast.castTo("-", DateType, nullable = true, nullValue = "-") == null) - assert(TypeCast.castTo("-", StringType, nullable = true, nullValue = "-") == null) + assert(TypeCast.castTo("-", "_c", ByteType, nullable = true, nullValue = "-") == null) + assert(TypeCast.castTo("-", "_c", ShortType, nullable = true, nullValue = "-") == null) + assert(TypeCast.castTo("-", "_c", IntegerType, nullable = true, nullValue = "-") == null) + assert(TypeCast.castTo("-", "_c", LongType, nullable = true, nullValue = "-") == null) + assert(TypeCast.castTo("-", "_c", FloatType, nullable = true, nullValue = "-") == null) + assert(TypeCast.castTo("-", "_c", DoubleType, nullable = true, nullValue = "-") == null) + assert(TypeCast.castTo("-", "_c", BooleanType, nullable = true, nullValue = "-") == null) + assert(TypeCast.castTo("-", "_c", TimestampType, nullable = true, nullValue = "-") == null) + assert(TypeCast.castTo("-", "_c", DateType, nullable = true, nullValue = "-") == null) + assert(TypeCast.castTo("-", "_c", StringType, nullable = false, nullValue = "-") == "-") } test("Throws exception for empty string with non null type") { - val exception = intercept[NumberFormatException]{ - TypeCast.castTo("", IntegerType, nullable = false) + val exception1 = intercept[NumberFormatException] { + TypeCast.castTo("", "_c", IntegerType, nullable = false) } - assert(exception.getMessage.contains("For input string: \"\"")) + assert(exception1.getMessage.contains("For input string: \"\"")) + + val exception2 = intercept[RuntimeException] { + 
TypeCast.castTo("", "_c", StringType, nullable = false, treatEmptyValuesAsNulls = true) + } + assert(exception2.getMessage.contains("null value found but field _c is not nullable")) + + val exception3 = intercept[RuntimeException] { + TypeCast.castTo(null, "_c", StringType, nullable = false) + } + assert(exception3.getMessage.contains("null value found but field _c is not nullable")) } test("Types are cast correctly") { - assert(TypeCast.castTo("10", ByteType) == 10) - assert(TypeCast.castTo("10", ShortType) == 10) - assert(TypeCast.castTo("10", IntegerType) == 10) - assert(TypeCast.castTo("10", LongType) == 10) - assert(TypeCast.castTo("1.00", FloatType) == 1.0) - assert(TypeCast.castTo("1.00", DoubleType) == 1.0) - assert(TypeCast.castTo("true", BooleanType) == true) + assert(TypeCast.castTo("10", "_c", ByteType) == 10) + assert(TypeCast.castTo("10", "_c", ShortType) == 10) + assert(TypeCast.castTo("10", "_c", IntegerType) == 10) + assert(TypeCast.castTo("10", "_c", LongType) == 10) + assert(TypeCast.castTo("1.00", "_c", FloatType) == 1.0) + assert(TypeCast.castTo("1.00", "_c", DoubleType) == 1.0) + assert(TypeCast.castTo("true", "_c", BooleanType) == true) val timestamp = "2015-01-01 00:00:00" - assert(TypeCast.castTo(timestamp, TimestampType) == Timestamp.valueOf(timestamp)) - assert(TypeCast.castTo("2015-01-01", DateType) == Date.valueOf("2015-01-01")) + assert(TypeCast.castTo(timestamp, "_c", TimestampType) == Timestamp.valueOf(timestamp)) + assert(TypeCast.castTo("2015-01-01", "_c", DateType) == Date.valueOf("2015-01-01")) val dateFormatter = new SimpleDateFormat("dd/MM/yyyy hh:mm") val customTimestamp = "31/01/2015 00:00" @@ -98,25 +108,25 @@ class TypeCastSuite extends FunSuite { // to `java.sql.Date` val expectedDate = new Date(dateFormatter.parse("31/01/2015 00:00").getTime) val expectedTimestamp = new Timestamp(expectedDate.getTime) - assert(TypeCast.castTo(customTimestamp, TimestampType, dateFormatter = dateFormatter) + 
assert(TypeCast.castTo(customTimestamp, "_c", TimestampType, dateFormatter = dateFormatter) == expectedTimestamp) - assert(TypeCast.castTo(customTimestamp, DateType, dateFormatter = dateFormatter) == + assert(TypeCast.castTo(customTimestamp, "_c", DateType, dateFormatter = dateFormatter) == expectedDate) } test("Float and Double Types are cast correctly with Locale") { val locale : Locale = new Locale("fr", "FR") Locale.setDefault(locale) - assert(TypeCast.castTo("1,00", FloatType) == 1.0) - assert(TypeCast.castTo("1,00", DoubleType) == 1.0) + assert(TypeCast.castTo("1,00", "_c", FloatType) == 1.0) + assert(TypeCast.castTo("1,00", "_c", DoubleType) == 1.0) } test("Can handle mapping user specified nullValues") { - assert(TypeCast.castTo("null", StringType, true, false, "null") == null) - assert(TypeCast.castTo("\\N", ByteType, true, false, "\\N") == null) - assert(TypeCast.castTo("", ShortType, true, false) == null) - assert(TypeCast.castTo("null", StringType, true, true, "null") == null) - assert(TypeCast.castTo("", StringType, true, false, "") == null) - assert(TypeCast.castTo("", StringType, true, true, "") == null) + assert(TypeCast.castTo("null", "_c", StringType, true, false, "null") == null) + assert(TypeCast.castTo("\\N", "_c", ByteType, true, false, "\\N") == null) + assert(TypeCast.castTo("", "_c", ShortType, true, false) == null) + assert(TypeCast.castTo("null", "_c", StringType, true, true, "null") == null) + assert(TypeCast.castTo("", "_c", StringType, true, false, "") == null) + assert(TypeCast.castTo("", "_c", StringType, true, true, "") == null) } }