Skip to content

Commit 451cceb

Browse files
committed
Merge pull request #102 from dtpeacock/nullable_type_support
Support for nullable schema types
2 parents 63ebf1a + 2d24a52 commit 451cceb

File tree

6 files changed

+91
-18
lines changed

6 files changed

+91
-18
lines changed

src/main/scala/com/databricks/spark/csv/CsvRelation.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,8 @@ case class CsvRelation protected[spark] (
108108
try {
109109
index = 0
110110
while (index < schemaFields.length) {
111-
rowArray(index) = TypeCast.castTo(tokens(index), schemaFields(index).dataType)
111+
val field = schemaFields(index)
112+
rowArray(index) = TypeCast.castTo(tokens(index), field.dataType, field.nullable)
112113
index = index + 1
113114
}
114115
Some(Row.fromSeq(rowArray))

src/main/scala/com/databricks/spark/csv/util/TypeCast.scala

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,25 +29,32 @@ object TypeCast {
2929
* Casts given string datum to specified type.
3030
* Currently we do not support complex types (ArrayType, MapType, StructType).
3131
*
32+
* For string types, this is simply the datum.
33+
* For other nullable types, this is null if the string datum is empty.
34+
*
3235
* @param datum string value
3336
* @param castType SparkSQL type
3437
*/
35-
private[csv] def castTo(datum: String, castType: DataType): Any = {
36-
castType match {
37-
case _: ByteType => datum.toByte
38-
case _: ShortType => datum.toShort
39-
case _: IntegerType => datum.toInt
40-
case _: LongType => datum.toLong
41-
case _: FloatType => datum.toFloat
42-
case _: DoubleType => datum.toDouble
43-
case _: BooleanType => datum.toBoolean
44-
case _: DecimalType => new BigDecimal(datum.replaceAll(",", ""))
45-
// TODO(hossein): would be good to support other common timestamp formats
46-
case _: TimestampType => Timestamp.valueOf(datum)
47-
// TODO(hossein): would be good to support other common date formats
48-
case _: DateType => Date.valueOf(datum)
49-
case _: StringType => datum
50-
case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}")
38+
private[csv] def castTo(datum: String, castType: DataType, nullable: Boolean = true): Any = {
39+
if (datum == "" && nullable && !castType.isInstanceOf[StringType]){
40+
null
41+
} else {
42+
castType match {
43+
case _: ByteType => datum.toByte
44+
case _: ShortType => datum.toShort
45+
case _: IntegerType => datum.toInt
46+
case _: LongType => datum.toLong
47+
case _: FloatType => datum.toFloat
48+
case _: DoubleType => datum.toDouble
49+
case _: BooleanType => datum.toBoolean
50+
case _: DecimalType => new BigDecimal(datum.replaceAll(",", ""))
51+
// TODO(hossein): would be good to support other common timestamp formats
52+
case _: TimestampType => Timestamp.valueOf(datum)
53+
// TODO(hossein): would be good to support other common date formats
54+
case _: DateType => Date.valueOf(datum)
55+
case _: StringType => datum
56+
case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}")
57+
}
5158
}
5259
}
5360

src/test/resources/null-numbers.csv

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
name,age
2+
alice,35
3+
bob,
4+
,24

src/test/scala/com/databricks/spark/csv/CsvFastSuite.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class CsvFastSuite extends FunSuite {
3232
val carsFile8859 = "src/test/resources/cars_iso-8859-1.csv"
3333
val carsTsvFile = "src/test/resources/cars.tsv"
3434
val carsAltFile = "src/test/resources/cars-alternative.csv"
35+
val nullNumbersFile = "src/test/resources/null-numbers.csv"
3536
val emptyFile = "src/test/resources/empty.csv"
3637
val escapeFile = "src/test/resources/escape.csv"
3738
val tempEmptyDir = "target/test/empty2/"
@@ -387,4 +388,19 @@ class CsvFastSuite extends FunSuite {
387388
assert(results.first().getInt(0) === 1997)
388389

389390
}
391+
392+
test("DSL test nullable fields"){
393+
394+
val results = new CsvParser()
395+
.withSchema(StructType(List(StructField("name", StringType, false), StructField("age", IntegerType, true))))
396+
.withUseHeader(true)
397+
.withParserLib("univocity")
398+
.csvFile(TestSQLContext, nullNumbersFile)
399+
.collect()
400+
401+
assert(results.head.toSeq == Seq("alice", 35))
402+
assert(results(1).toSeq == Seq("bob", null))
403+
assert(results(2).toSeq == Seq("", 24))
404+
405+
}
390406
}

src/test/scala/com/databricks/spark/csv/CsvSuite.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class CsvSuite extends FunSuite {
3232
val carsFile8859 = "src/test/resources/cars_iso-8859-1.csv"
3333
val carsTsvFile = "src/test/resources/cars.tsv"
3434
val carsAltFile = "src/test/resources/cars-alternative.csv"
35+
val nullNumbersFile = "src/test/resources/null-numbers.csv"
3536
val emptyFile = "src/test/resources/empty.csv"
3637
val escapeFile = "src/test/resources/escape.csv"
3738
val tempEmptyDir = "target/test/empty/"
@@ -392,4 +393,18 @@ class CsvSuite extends FunSuite {
392393
assert(results.first().getInt(0) === 1997)
393394

394395
}
396+
397+
test("DSL test nullable fields"){
398+
399+
val results = new CsvParser()
400+
.withSchema(StructType(List(StructField("name", StringType, false), StructField("age", IntegerType, true))))
401+
.withUseHeader(true)
402+
.csvFile(TestSQLContext, nullNumbersFile)
403+
.collect()
404+
405+
assert(results.head.toSeq == Seq("alice", 35))
406+
assert(results(1).toSeq == Seq("bob", null))
407+
assert(results(2).toSeq == Seq("", 24))
408+
409+
}
395410
}

src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@
1616
package com.databricks.spark.csv.util
1717

1818
import java.math.BigDecimal
19+
import java.sql.{Date, Timestamp}
1920

2021
import org.scalatest.FunSuite
2122

22-
import org.apache.spark.sql.types.DecimalType
23+
import org.apache.spark.sql.types._
2324

2425
class TypeCastSuite extends FunSuite {
2526

@@ -56,4 +57,33 @@ class TypeCastSuite extends FunSuite {
5657
}
5758
assert(exception.getMessage.contains("Unsupported special character for delimiter"))
5859
}
60+
61+
test("Nullable types are handled"){
62+
assert(TypeCast.castTo("", IntegerType, nullable = true) == null)
63+
}
64+
65+
test("String type should always return the same as the input"){
66+
assert(TypeCast.castTo("", StringType, nullable = true) == "")
67+
assert(TypeCast.castTo("", StringType, nullable = false) == "")
68+
}
69+
70+
test("Throws exception for empty string with non null type"){
71+
val exception = intercept[NumberFormatException]{
72+
TypeCast.castTo("", IntegerType, nullable = false)
73+
}
74+
assert(exception.getMessage.contains("For input string: \"\""))
75+
}
76+
77+
test("Types are cast correctly"){
78+
assert(TypeCast.castTo("10", ByteType) == 10)
79+
assert(TypeCast.castTo("10", ShortType) == 10)
80+
assert(TypeCast.castTo("10", IntegerType) == 10)
81+
assert(TypeCast.castTo("10", LongType) == 10)
82+
assert(TypeCast.castTo("1.00", FloatType) == 1.0)
83+
assert(TypeCast.castTo("1.00", DoubleType) == 1.0)
84+
assert(TypeCast.castTo("true", BooleanType) == true)
85+
val timestamp = "2015-01-01 00:00:00"
86+
assert(TypeCast.castTo(timestamp, TimestampType) == Timestamp.valueOf(timestamp))
87+
assert(TypeCast.castTo("2015-01-01", DateType) == Date.valueOf("2015-01-01"))
88+
}
5989
}

0 commit comments

Comments
 (0)