Skip to content

Commit f292cc5

Browse files
committed
Pull null handling into TypeCast, add tests
1 parent 330d615 commit f292cc5

File tree

3 files changed

+43
-25
lines changed

3 files changed

+43
-25
lines changed

src/main/scala/com/databricks/spark/csv/CsvRelation.scala

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -108,14 +108,8 @@ case class CsvRelation protected[spark] (
108108
try {
109109
index = 0
110110
while (index < schemaFields.length) {
111-
rowArray(index) = if (schemaFields(index).nullable && tokens(index) == ""){
112-
schemaFields(index).dataType match {
113-
case StringType => ""
114-
case _ => null
115-
}
116-
} else {
117-
TypeCast.castTo(tokens(index), schemaFields(index).dataType)
118-
}
111+
val field = schemaFields(index)
112+
rowArray(index) = TypeCast.castTo(tokens(index), field.dataType, field.nullable)
119113
index = index + 1
120114
}
121115
Some(Row.fromSeq(rowArray))

src/main/scala/com/databricks/spark/csv/util/TypeCast.scala

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,25 +29,33 @@ object TypeCast {
2929
* Casts given string datum to specified type.
3030
* Currently we do not support complex types (ArrayType, MapType, StructType).
3131
*
32+
* For string types, this is simply the datum. For other types.
33+
* For other nullable types, this is null if the string datum is empty.
34+
*
3235
* @param datum string value
3336
* @param castType SparkSQL type
3437
*/
35-
private[csv] def castTo(datum: String, castType: DataType): Any = {
36-
castType match {
37-
case _: ByteType => datum.toByte
38-
case _: ShortType => datum.toShort
39-
case _: IntegerType => datum.toInt
40-
case _: LongType => datum.toLong
41-
case _: FloatType => datum.toFloat
42-
case _: DoubleType => datum.toDouble
43-
case _: BooleanType => datum.toBoolean
44-
case _: DecimalType => new BigDecimal(datum.replaceAll(",", ""))
45-
// TODO(hossein): would be good to support other common timestamp formats
46-
case _: TimestampType => Timestamp.valueOf(datum)
47-
// TODO(hossein): would be good to support other common date formats
48-
case _: DateType => Date.valueOf(datum)
49-
case _: StringType => datum
50-
case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}")
38+
private[csv] def castTo(datum: String, castType: DataType, nullable: Boolean = true): Any = {
39+
if (castType.isInstanceOf[StringType]){
40+
datum
41+
} else if (nullable && datum == ""){
42+
null
43+
} else {
44+
castType match {
45+
case _: ByteType => datum.toByte
46+
case _: ShortType => datum.toShort
47+
case _: IntegerType => datum.toInt
48+
case _: LongType => datum.toLong
49+
case _: FloatType => datum.toFloat
50+
case _: DoubleType => datum.toDouble
51+
case _: BooleanType => datum.toBoolean
52+
case _: DecimalType => new BigDecimal(datum.replaceAll(",", ""))
53+
// TODO(hossein): would be good to support other common timestamp formats
54+
case _: TimestampType => Timestamp.valueOf(datum)
55+
// TODO(hossein): would be good to support other common date formats
56+
case _: DateType => Date.valueOf(datum)
57+
case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}")
58+
}
5159
}
5260
}
5361

src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import java.math.BigDecimal
1919

2020
import org.scalatest.FunSuite
2121

22-
import org.apache.spark.sql.types.DecimalType
22+
import org.apache.spark.sql.types.{StringType, IntegerType, DecimalType}
2323

2424
class TypeCastSuite extends FunSuite {
2525

@@ -56,4 +56,20 @@ class TypeCastSuite extends FunSuite {
5656
}
5757
assert(exception.getMessage.contains("Unsupported special character for delimiter"))
5858
}
59+
60+
test("Nullable types are handled"){
61+
assert(TypeCast.castTo("", IntegerType, nullable = true) == null)
62+
}
63+
64+
test("String type should always return the same as the input"){
65+
assert(TypeCast.castTo("", StringType, nullable = true) == "")
66+
assert(TypeCast.castTo("", StringType, nullable = false) == "")
67+
}
68+
69+
test("Throws exception for empty string with non null type"){
70+
val exception = intercept[NumberFormatException]{
71+
TypeCast.castTo("", IntegerType, nullable = false)
72+
}
73+
assert(exception.getMessage.contains("For input string: \"\""))
74+
}
5975
}

0 commit comments

Comments
 (0)