Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/main/scala/com/databricks/spark/csv/CsvRelation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ case class CsvRelation protected[spark] (
index = 0
while (index < schemaFields.length) {
val field = schemaFields(index)
rowArray(index) = TypeCast.castTo(tokens(index), field.dataType, field.nullable,
treatEmptyValuesAsNulls, nullValue, simpleDateFormatter)
rowArray(index) = TypeCast.castTo(tokens(index), field.name, field.dataType,
field.nullable, treatEmptyValuesAsNulls, nullValue, simpleDateFormatter)
index = index + 1
}
Some(Row.fromSeq(rowArray))
Expand Down Expand Up @@ -195,6 +195,7 @@ case class CsvRelation protected[spark] (
val field = schemaFields(index)
rowArray(subIndex) = TypeCast.castTo(
indexSafeTokens(index),
field.name,
field.dataType,
field.nullable,
treatEmptyValuesAsNulls,
Expand Down
32 changes: 27 additions & 5 deletions src/main/scala/com/databricks/spark/csv/util/TypeCast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,21 +34,43 @@ object TypeCast {
* Currently we do not support complex types (ArrayType, MapType, StructType).
*
* For string types, this is simply the datum. For other types.
* For other nullable types, this is null if the string datum is empty.
* For other nullable types, returns null if it is null or equals to the value specified
* in `nullValue` option. If `treatEmptyValuesAsNulls` is set, it also returns null for
* empty strings.
*
* @param datum string value
* @param castType SparkSQL type
* @param name field name in schema.
* @param castType data type to cast `datum` into.
* @param nullable nullability for the field.
* @param treatEmptyValuesAsNulls a flag to indicate if empty strings should be treated as null.
* @param nullValue the string value that represents null.
* @param dateFormatter date formatter that uses to parse dates and timestamps.
*/
private[csv] def castTo(
datum: String,
name: String,
castType: DataType,
nullable: Boolean = true,
treatEmptyValuesAsNulls: Boolean = false,
nullValue: String = "",
dateFormatter: SimpleDateFormat = null): Any = {
if (datum == nullValue &&
nullable ||
(treatEmptyValuesAsNulls && datum == "")){

// If the given column is not nullable, we simply fall back to normal string
// rather than returning null for backwards compatibility. Note that this case is
// different with Spark's internal CSV datasource which throws an exception in this case.
val isNullValueMatched = datum == nullValue && nullable

// If `treatEmptyValuesAsNulls` is enabled, treat empty strings as nulls.
val shouldTreatEmptyValuesAsNulls = treatEmptyValuesAsNulls && datum == ""

// `datum` can be null when some tokens were inserted when permissive modes via `PrunedScan`.
// In this case, they are treated as nulls.
val isNullDatum = datum == null
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, the actual cases added are,

  • datum == null && !nullable : returns null. (it was NEP or java.lang.NumberFormatException: null before).
  • datum == null && !nullable : throws an exception. (it was NEP or java.lang.NumberFormatException: null before).
  • treatEmptyValuesAsNulls && datum == && !nullable : throws an exception. (it was trying to set null to non-nullable field.

These should be non-behaviour changes.


if (isNullValueMatched || shouldTreatEmptyValuesAsNulls || isNullDatum) {
if (!nullable) {
throw new RuntimeException(s"null value found but field $name is not nullable.")
}
null
} else {
castType match {
Expand Down
41 changes: 41 additions & 0 deletions src/test/scala/com/databricks/spark/csv/CsvSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1002,6 +1002,47 @@ abstract class AbstractCsvSuite extends FunSuite with BeforeAndAfterAll {
assert(ages.schema.fields(2).dataType === DoubleType)
assert(ages.schema.fields(3).dataType === StringType)
}

test("Should read null properly when schema is lager than parsed tokens") {
val schema = StructType(
StructField("bool", BooleanType, true) ::
StructField("nullcol", IntegerType, true) ::
StructField("nullcol1", IntegerType, true) :: Nil)

// Selects only bool and nullcol to use `PrunedScan` interface. If we select
// all, it falls back to `TableScan`.
val results = new CsvParser()
.withSchema(schema)
.withUseHeader(true)
.withParserLib(parserLib)
.withParseMode(ParseModes.PERMISSIVE_MODE)
.csvFile(sqlContext, boolFile)
.select("bool", "nullcol")
.collect()

val expected = Seq(Row(true, null), Row(false, null), Row(false, null))
assert(results.length == expected.length)
assert(results.toSet == expected.toSet)

// Negative case
val nonNullableSchema = StructType(
StructField("bool", BooleanType, false) ::
StructField("nullcol", IntegerType, false) ::
StructField("nullcol1", IntegerType, false) :: Nil)

val exception = intercept[SparkException] {
new CsvParser()
.withSchema(nonNullableSchema)
.withUseHeader(true)
.withParserLib(parserLib)
.withParseMode(ParseModes.PERMISSIVE_MODE)
.csvFile(sqlContext, boolFile)
.select("bool", "nullcol")
.collect()
}

assert(exception.getMessage.contains("null value found but field nullcol"))
}
}

class CsvSuite extends AbstractCsvSuite {
Expand Down
76 changes: 43 additions & 33 deletions src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class TypeCastSuite extends FunSuite {
val decimalType = new DecimalType(None)

stringValues.zip(decimalValues).foreach { case (strVal, decimalVal) =>
assert(TypeCast.castTo(strVal, decimalType) === new BigDecimal(decimalVal.toString))
assert(TypeCast.castTo(strVal, "_c", decimalType) === new BigDecimal(decimalVal.toString))
}
}

Expand Down Expand Up @@ -61,62 +61,72 @@ class TypeCastSuite extends FunSuite {
}

test("Nullable types are handled") {
assert(TypeCast.castTo("-", ByteType, nullable = true, nullValue = "-") == null)
assert(TypeCast.castTo("-", ShortType, nullable = true, nullValue = "-") == null)
assert(TypeCast.castTo("-", IntegerType, nullable = true, nullValue = "-") == null)
assert(TypeCast.castTo("-", LongType, nullable = true, nullValue = "-") == null)
assert(TypeCast.castTo("-", FloatType, nullable = true, nullValue = "-") == null)
assert(TypeCast.castTo("-", DoubleType, nullable = true, nullValue = "-") == null)
assert(TypeCast.castTo("-", BooleanType, nullable = true, nullValue = "-") == null)
assert(TypeCast.castTo("-", TimestampType, nullable = true, nullValue = "-") == null)
assert(TypeCast.castTo("-", DateType, nullable = true, nullValue = "-") == null)
assert(TypeCast.castTo("-", StringType, nullable = true, nullValue = "-") == null)
assert(TypeCast.castTo("-", "_c", ByteType, nullable = true, nullValue = "-") == null)
assert(TypeCast.castTo("-", "_c", ShortType, nullable = true, nullValue = "-") == null)
assert(TypeCast.castTo("-", "_c", IntegerType, nullable = true, nullValue = "-") == null)
assert(TypeCast.castTo("-", "_c", LongType, nullable = true, nullValue = "-") == null)
assert(TypeCast.castTo("-", "_c", FloatType, nullable = true, nullValue = "-") == null)
assert(TypeCast.castTo("-", "_c", DoubleType, nullable = true, nullValue = "-") == null)
assert(TypeCast.castTo("-", "_c", BooleanType, nullable = true, nullValue = "-") == null)
assert(TypeCast.castTo("-", "_c", TimestampType, nullable = true, nullValue = "-") == null)
assert(TypeCast.castTo("-", "_c", DateType, nullable = true, nullValue = "-") == null)
assert(TypeCast.castTo("-", "_c", StringType, nullable = false, nullValue = "-") == "-")
}

test("Throws exception for empty string with non null type") {
val exception = intercept[NumberFormatException]{
TypeCast.castTo("", IntegerType, nullable = false)
val exception1 = intercept[NumberFormatException] {
TypeCast.castTo("", "_c", IntegerType, nullable = false)
}
assert(exception.getMessage.contains("For input string: \"\""))
assert(exception1.getMessage.contains("For input string: \"\""))

val exception2 = intercept[RuntimeException] {
TypeCast.castTo("", "_c", StringType, nullable = false, treatEmptyValuesAsNulls = true)
}
assert(exception2.getMessage.contains("null value found but field _c is not nullable"))

val exception3 = intercept[RuntimeException] {
TypeCast.castTo(null, "_c", StringType, nullable = false)
}
assert(exception3.getMessage.contains("null value found but field _c is not nullable"))
}

test("Types are cast correctly") {
assert(TypeCast.castTo("10", ByteType) == 10)
assert(TypeCast.castTo("10", ShortType) == 10)
assert(TypeCast.castTo("10", IntegerType) == 10)
assert(TypeCast.castTo("10", LongType) == 10)
assert(TypeCast.castTo("1.00", FloatType) == 1.0)
assert(TypeCast.castTo("1.00", DoubleType) == 1.0)
assert(TypeCast.castTo("true", BooleanType) == true)
assert(TypeCast.castTo("10", "_c", ByteType) == 10)
assert(TypeCast.castTo("10", "_c", ShortType) == 10)
assert(TypeCast.castTo("10", "_c", IntegerType) == 10)
assert(TypeCast.castTo("10", "_c", LongType) == 10)
assert(TypeCast.castTo("1.00", "_c", FloatType) == 1.0)
assert(TypeCast.castTo("1.00", "_c", DoubleType) == 1.0)
assert(TypeCast.castTo("true", "_c", BooleanType) == true)
val timestamp = "2015-01-01 00:00:00"
assert(TypeCast.castTo(timestamp, TimestampType) == Timestamp.valueOf(timestamp))
assert(TypeCast.castTo("2015-01-01", DateType) == Date.valueOf("2015-01-01"))
assert(TypeCast.castTo(timestamp, "_c", TimestampType) == Timestamp.valueOf(timestamp))
assert(TypeCast.castTo("2015-01-01", "_c", DateType) == Date.valueOf("2015-01-01"))

val dateFormatter = new SimpleDateFormat("dd/MM/yyyy hh:mm")
val customTimestamp = "31/01/2015 00:00"
// `SimpleDateFormat.parse` returns `java.util.Date`. This needs to be converted
// to `java.sql.Date`
val expectedDate = new Date(dateFormatter.parse("31/01/2015 00:00").getTime)
val expectedTimestamp = new Timestamp(expectedDate.getTime)
assert(TypeCast.castTo(customTimestamp, TimestampType, dateFormatter = dateFormatter)
assert(TypeCast.castTo(customTimestamp, "_c", TimestampType, dateFormatter = dateFormatter)
== expectedTimestamp)
assert(TypeCast.castTo(customTimestamp, DateType, dateFormatter = dateFormatter) ==
assert(TypeCast.castTo(customTimestamp, "_c", DateType, dateFormatter = dateFormatter) ==
expectedDate)
}

test("Float and Double Types are cast correctly with Locale") {
val locale : Locale = new Locale("fr", "FR")
Locale.setDefault(locale)
assert(TypeCast.castTo("1,00", FloatType) == 1.0)
assert(TypeCast.castTo("1,00", DoubleType) == 1.0)
assert(TypeCast.castTo("1,00", "_c", FloatType) == 1.0)
assert(TypeCast.castTo("1,00", "_c", DoubleType) == 1.0)
}

test("Can handle mapping user specified nullValues") {
assert(TypeCast.castTo("null", StringType, true, false, "null") == null)
assert(TypeCast.castTo("\\N", ByteType, true, false, "\\N") == null)
assert(TypeCast.castTo("", ShortType, true, false) == null)
assert(TypeCast.castTo("null", StringType, true, true, "null") == null)
assert(TypeCast.castTo("", StringType, true, false, "") == null)
assert(TypeCast.castTo("", StringType, true, true, "") == null)
assert(TypeCast.castTo("null", "_c", StringType, true, false, "null") == null)
assert(TypeCast.castTo("\\N", "_c", ByteType, true, false, "\\N") == null)
assert(TypeCast.castTo("", "_c", ShortType, true, false) == null)
assert(TypeCast.castTo("null", "_c", StringType, true, true, "null") == null)
assert(TypeCast.castTo("", "_c", StringType, true, false, "") == null)
assert(TypeCast.castTo("", "_c", StringType, true, true, "") == null)
}
}