Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -221,18 +221,27 @@ private[csv] object CSVTypeCast {
* Currently we do not support complex types (ArrayType, MapType, StructType).
*
* For string types, this is simply the datum.
* For other nullable types, this is null if the string datum is empty.
* For other nullable types, returns null if the datum is null or equal to the value specified
* in the `nullValue` option.
*
* @param datum string value
* @param castType SparkSQL type
* @param name field name in schema.
* @param castType data type to cast `datum` into.
* @param nullable nullability for the field.
* @param options CSV options.
*/
def castTo(
datum: String,
name: String,
castType: DataType,
nullable: Boolean = true,
options: CSVOptions = CSVOptions()): Any = {

if (nullable && datum == options.nullValue) {
// datum can be null if the number of fields found is less than the length of the schema
if (datum == options.nullValue || datum == null) {
if (!nullable) {
throw new RuntimeException(s"null value found but field $name is not nullable.")
}
null
} else {
castType match {
Expand Down Expand Up @@ -281,7 +290,7 @@ private[csv] object CSVTypeCast {
DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime)
}
case _: StringType => UTF8String.fromString(datum)
case udt: UserDefinedType[_] => castTo(datum, udt.sqlType, nullable, options)
case udt: UserDefinedType[_] => castTo(datum, name, udt.sqlType, nullable, options)
case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ object CSVRelation extends Logging {
// value is not stored in the row.
val value = CSVTypeCast.castTo(
indexSafeTokens(index),
field.name,
field.dataType,
field.nullable,
params)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -890,4 +890,19 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
}
}
}

// Verifies that a line with fewer tokens than the schema has fields still loads,
// and that the column with no corresponding token is read back as null.
test("load null when the schema is larger than parsed tokens ") {
withTempPath { path =>
// Write a text file whose only line is the single value "1".
Seq("1").toDF().write.text(path.getAbsolutePath)
// Declare two nullable integer columns: one more than the line provides.
val schema = StructType(
StructField("a", IntegerType, true) ::
StructField("b", IntegerType, true) :: Nil)
// Read the file back as CSV with the explicit schema and no header row.
val df = spark.read
.schema(schema)
.option("header", "false")
.csv(path.getAbsolutePath)

// Column "a" holds the parsed value; column "b" has no token and is expected to be null.
checkAnswer(df, Row(1, null))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class CSVTypeCastSuite extends SparkFunSuite {

stringValues.zip(decimalValues).foreach { case (strVal, decimalVal) =>
val decimalValue = new BigDecimal(decimalVal.toString)
assert(CSVTypeCast.castTo(strVal, decimalType) ===
assert(CSVTypeCast.castTo(strVal, "_1", decimalType) ===
Decimal(decimalValue, decimalType.precision, decimalType.scale))
}
}
Expand Down Expand Up @@ -67,97 +67,108 @@ class CSVTypeCastSuite extends SparkFunSuite {

test("Nullable types are handled") {
assertNull(
CSVTypeCast.castTo("-", ByteType, nullable = true, CSVOptions("nullValue", "-")))
CSVTypeCast.castTo("-", "_1", ByteType, nullable = true, CSVOptions("nullValue", "-")))
assertNull(
CSVTypeCast.castTo("-", ShortType, nullable = true, CSVOptions("nullValue", "-")))
CSVTypeCast.castTo("-", "_1", ShortType, nullable = true, CSVOptions("nullValue", "-")))
assertNull(
CSVTypeCast.castTo("-", IntegerType, nullable = true, CSVOptions("nullValue", "-")))
CSVTypeCast.castTo("-", "_1", IntegerType, nullable = true, CSVOptions("nullValue", "-")))
assertNull(
CSVTypeCast.castTo("-", LongType, nullable = true, CSVOptions("nullValue", "-")))
CSVTypeCast.castTo("-", "_1", LongType, nullable = true, CSVOptions("nullValue", "-")))
assertNull(
CSVTypeCast.castTo("-", FloatType, nullable = true, CSVOptions("nullValue", "-")))
CSVTypeCast.castTo("-", "_1", FloatType, nullable = true, CSVOptions("nullValue", "-")))
assertNull(
CSVTypeCast.castTo("-", DoubleType, nullable = true, CSVOptions("nullValue", "-")))
CSVTypeCast.castTo("-", "_1", DoubleType, nullable = true, CSVOptions("nullValue", "-")))
assertNull(
CSVTypeCast.castTo("-", BooleanType, nullable = true, CSVOptions("nullValue", "-")))
CSVTypeCast.castTo("-", "_1", BooleanType, nullable = true, CSVOptions("nullValue", "-")))
assertNull(
CSVTypeCast.castTo("-", DecimalType.DoubleDecimal, true, CSVOptions("nullValue", "-")))
CSVTypeCast.castTo("-", "_1", DecimalType.DoubleDecimal, true, CSVOptions("nullValue", "-")))
assertNull(
CSVTypeCast.castTo("-", TimestampType, nullable = true, CSVOptions("nullValue", "-")))
CSVTypeCast.castTo("-", "_1", TimestampType, nullable = true, CSVOptions("nullValue", "-")))
assertNull(
CSVTypeCast.castTo("-", DateType, nullable = true, CSVOptions("nullValue", "-")))
CSVTypeCast.castTo("-", "_1", DateType, nullable = true, CSVOptions("nullValue", "-")))
assertNull(
CSVTypeCast.castTo("-", StringType, nullable = true, CSVOptions("nullValue", "-")))
CSVTypeCast.castTo("-", "_1", StringType, nullable = true, CSVOptions("nullValue", "-")))
assertNull(
CSVTypeCast.castTo(null, "_1", IntegerType, nullable = true, CSVOptions("nullValue", "-")))

// Casting a null to a non-nullable field should throw an exception.
var message = intercept[RuntimeException] {
CSVTypeCast.castTo(null, "_1", IntegerType, nullable = false, CSVOptions("nullValue", "-"))
}.getMessage
assert(message.contains("null value found but field _1 is not nullable."))

message = intercept[RuntimeException] {
CSVTypeCast.castTo("-", "_1", StringType, nullable = false, CSVOptions("nullValue", "-"))
}.getMessage
assert(message.contains("null value found but field _1 is not nullable."))
}

test("String type should also respect `nullValue`") {
assertNull(
CSVTypeCast.castTo("", StringType, nullable = true, CSVOptions()))
assert(
CSVTypeCast.castTo("", StringType, nullable = false, CSVOptions()) ==
UTF8String.fromString(""))
CSVTypeCast.castTo("", "_1", StringType, nullable = true, CSVOptions()))

assert(
CSVTypeCast.castTo("", StringType, nullable = true, CSVOptions("nullValue", "null")) ==
CSVTypeCast.castTo("", "_1", StringType, nullable = true, CSVOptions("nullValue", "null")) ==
UTF8String.fromString(""))
assert(
CSVTypeCast.castTo("", StringType, nullable = false, CSVOptions("nullValue", "null")) ==
CSVTypeCast.castTo("", "_1", StringType, nullable = false, CSVOptions("nullValue", "null")) ==
UTF8String.fromString(""))

assertNull(
CSVTypeCast.castTo(null, StringType, nullable = true, CSVOptions("nullValue", "null")))
CSVTypeCast.castTo(null, "_1", StringType, nullable = true, CSVOptions("nullValue", "null")))
}

test("Throws exception for empty string with non null type") {
val exception = intercept[NumberFormatException]{
CSVTypeCast.castTo("", IntegerType, nullable = false, CSVOptions())
val exception = intercept[RuntimeException]{
CSVTypeCast.castTo("", "_1", IntegerType, nullable = false, CSVOptions())
}
assert(exception.getMessage.contains("For input string: \"\""))
assert(exception.getMessage.contains("null value found but field _1 is not nullable."))
}

test("Types are cast correctly") {
assert(CSVTypeCast.castTo("10", ByteType) == 10)
assert(CSVTypeCast.castTo("10", ShortType) == 10)
assert(CSVTypeCast.castTo("10", IntegerType) == 10)
assert(CSVTypeCast.castTo("10", LongType) == 10)
assert(CSVTypeCast.castTo("1.00", FloatType) == 1.0)
assert(CSVTypeCast.castTo("1.00", DoubleType) == 1.0)
assert(CSVTypeCast.castTo("true", BooleanType) == true)
assert(CSVTypeCast.castTo("10", "_1", ByteType) == 10)
assert(CSVTypeCast.castTo("10", "_1", ShortType) == 10)
assert(CSVTypeCast.castTo("10", "_1", IntegerType) == 10)
assert(CSVTypeCast.castTo("10", "_1", LongType) == 10)
assert(CSVTypeCast.castTo("1.00", "_1", FloatType) == 1.0)
assert(CSVTypeCast.castTo("1.00", "_1", DoubleType) == 1.0)
assert(CSVTypeCast.castTo("true", "_1", BooleanType) == true)

val timestampsOptions = CSVOptions("timestampFormat", "dd/MM/yyyy hh:mm")
val customTimestamp = "31/01/2015 00:00"
val expectedTime = timestampsOptions.timestampFormat.parse(customTimestamp).getTime
val castedTimestamp =
CSVTypeCast.castTo(customTimestamp, TimestampType, nullable = true, timestampsOptions)
CSVTypeCast.castTo(customTimestamp, "_1", TimestampType, nullable = true, timestampsOptions)
assert(castedTimestamp == expectedTime * 1000L)

val customDate = "31/01/2015"
val dateOptions = CSVOptions("dateFormat", "dd/MM/yyyy")
val expectedDate = dateOptions.dateFormat.parse(customDate).getTime
val castedDate = CSVTypeCast.castTo(customTimestamp, DateType, nullable = true, dateOptions)
val castedDate =
CSVTypeCast.castTo(customTimestamp, "_1", DateType, nullable = true, dateOptions)
assert(castedDate == DateTimeUtils.millisToDays(expectedDate))

val timestamp = "2015-01-01 00:00:00"
assert(CSVTypeCast.castTo(timestamp, TimestampType) ==
assert(CSVTypeCast.castTo(timestamp, "_1", TimestampType) ==
DateTimeUtils.stringToTime(timestamp).getTime * 1000L)
assert(CSVTypeCast.castTo("2015-01-01", DateType) ==
assert(CSVTypeCast.castTo("2015-01-01", "_1", DateType) ==
DateTimeUtils.millisToDays(DateTimeUtils.stringToTime("2015-01-01").getTime))
}

test("Float and Double Types are cast without respect to platform default Locale") {
val originalLocale = Locale.getDefault
try {
Locale.setDefault(new Locale("fr", "FR"))
assert(CSVTypeCast.castTo("1,00", FloatType) == 100.0) // Would parse as 1.0 in fr-FR
assert(CSVTypeCast.castTo("1,00", DoubleType) == 100.0)
assert(CSVTypeCast.castTo("1,00", "_1", FloatType) == 100.0) // Would parse as 1.0 in fr-FR
assert(CSVTypeCast.castTo("1,00", "_1", DoubleType) == 100.0)
} finally {
Locale.setDefault(originalLocale)
}
}

test("Float NaN values are parsed correctly") {
val floatVal: Float = CSVTypeCast.castTo(
"nn", FloatType, nullable = true, CSVOptions("nanValue", "nn")).asInstanceOf[Float]
"nn", "_1", FloatType, nullable = true, CSVOptions("nanValue", "nn")).asInstanceOf[Float]

// Java implements the IEEE-754 floating point standard which guarantees that any comparison
// against NaN will return false (except != which returns true)
Expand All @@ -166,32 +177,32 @@ class CSVTypeCastSuite extends SparkFunSuite {

test("Double NaN values are parsed correctly") {
val doubleVal: Double = CSVTypeCast.castTo(
"-", DoubleType, nullable = true, CSVOptions("nanValue", "-")).asInstanceOf[Double]
"-", "_1", DoubleType, nullable = true, CSVOptions("nanValue", "-")).asInstanceOf[Double]

assert(doubleVal.isNaN)
}

test("Float infinite values can be parsed") {
val floatVal1 = CSVTypeCast.castTo(
"max", FloatType, nullable = true, CSVOptions("negativeInf", "max")).asInstanceOf[Float]
"max", "_1", FloatType, nullable = true, CSVOptions("negativeInf", "max")).asInstanceOf[Float]

assert(floatVal1 == Float.NegativeInfinity)

val floatVal2 = CSVTypeCast.castTo(
"max", FloatType, nullable = true, CSVOptions("positiveInf", "max")).asInstanceOf[Float]
"max", "_1", FloatType, nullable = true, CSVOptions("positiveInf", "max")).asInstanceOf[Float]

assert(floatVal2 == Float.PositiveInfinity)
}

test("Double infinite values can be parsed") {
val doubleVal1 = CSVTypeCast.castTo(
"max", DoubleType, nullable = true, CSVOptions("negativeInf", "max")
"max", "_1", DoubleType, nullable = true, CSVOptions("negativeInf", "max")
).asInstanceOf[Double]

assert(doubleVal1 == Double.NegativeInfinity)

val doubleVal2 = CSVTypeCast.castTo(
"max", DoubleType, nullable = true, CSVOptions("positiveInf", "max")
"max", "_1", DoubleType, nullable = true, CSVOptions("positiveInf", "max")
).asInstanceOf[Double]

assert(doubleVal2 == Double.PositiveInfinity)
Expand Down