Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/main/scala/com/databricks/spark/csv/CsvRelation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ case class CsvRelation protected[spark] (
try {
index = 0
while (index < schemaFields.length) {
rowArray(index) = TypeCast.castTo(tokens(index), schemaFields(index).dataType)
val field = schemaFields(index)
rowArray(index) = TypeCast.castTo(tokens(index), field.dataType, field.nullable)
index = index + 1
}
Some(Row.fromSeq(rowArray))
Expand Down
39 changes: 23 additions & 16 deletions src/main/scala/com/databricks/spark/csv/util/TypeCast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,25 +29,32 @@ object TypeCast {
* Casts given string datum to specified type.
* Currently we do not support complex types (ArrayType, MapType, StructType).
*
* For string types, this is simply the datum.
* For other nullable types, this is null if the string datum is empty.
*
* @param datum string value
* @param castType SparkSQL type
*/
private[csv] def castTo(datum: String, castType: DataType): Any = {
castType match {
case _: ByteType => datum.toByte
case _: ShortType => datum.toShort
case _: IntegerType => datum.toInt
case _: LongType => datum.toLong
case _: FloatType => datum.toFloat
case _: DoubleType => datum.toDouble
case _: BooleanType => datum.toBoolean
case _: DecimalType => new BigDecimal(datum.replaceAll(",", ""))
// TODO(hossein): would be good to support other common timestamp formats
case _: TimestampType => Timestamp.valueOf(datum)
// TODO(hossein): would be good to support other common date formats
case _: DateType => Date.valueOf(datum)
case _: StringType => datum
case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}")
/**
 * Casts the given string datum to the requested SparkSQL type.
 *
 * An empty datum is converted to `null` for every nullable non-string
 * type; `StringType` always returns the datum unchanged.
 *
 * @param datum    raw string value read from the CSV input
 * @param castType target SparkSQL type
 * @param nullable whether an empty datum may become null (default true)
 */
private[csv] def castTo(datum: String, castType: DataType, nullable: Boolean = true): Any = {
  // Empty cells map to null unless the column is a string or non-nullable.
  val emptyAsNull = nullable && datum == "" && !castType.isInstanceOf[StringType]
  if (emptyAsNull) {
    null
  } else {
    castType match {
      case _: ByteType      => datum.toByte
      case _: ShortType     => datum.toShort
      case _: IntegerType   => datum.toInt
      case _: LongType      => datum.toLong
      case _: FloatType     => datum.toFloat
      case _: DoubleType    => datum.toDouble
      case _: BooleanType   => datum.toBoolean
      // Thousands separators are stripped before the decimal is built.
      case _: DecimalType   => new BigDecimal(datum.replaceAll(",", ""))
      // TODO(hossein): would be good to support other common timestamp formats
      case _: TimestampType => Timestamp.valueOf(datum)
      // TODO(hossein): would be good to support other common date formats
      case _: DateType      => Date.valueOf(datum)
      case _: StringType    => datum
      case other            => throw new RuntimeException(s"Unsupported type: ${other.typeName}")
    }
  }
}

Expand Down
4 changes: 4 additions & 0 deletions src/test/resources/null-numbers.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
name,age
alice,35
bob,
,24
16 changes: 16 additions & 0 deletions src/test/scala/com/databricks/spark/csv/CsvFastSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class CsvFastSuite extends FunSuite {
val carsFile8859 = "src/test/resources/cars_iso-8859-1.csv"
val carsTsvFile = "src/test/resources/cars.tsv"
val carsAltFile = "src/test/resources/cars-alternative.csv"
val nullNumbersFile = "src/test/resources/null-numbers.csv"
val emptyFile = "src/test/resources/empty.csv"
val escapeFile = "src/test/resources/escape.csv"
val tempEmptyDir = "target/test/empty2/"
Expand Down Expand Up @@ -387,4 +388,19 @@ class CsvFastSuite extends FunSuite {
assert(results.first().getInt(0) === 1997)

}

test("DSL test nullable fields") {
  // With an explicit schema, empty cells become null only for the
  // nullable IntegerType column; the non-nullable name stays "".
  val schema = StructType(List(
    StructField("name", StringType, false),
    StructField("age", IntegerType, true)))

  val rows = new CsvParser()
    .withSchema(schema)
    .withUseHeader(true)
    .withParserLib("univocity")
    .csvFile(TestSQLContext, nullNumbersFile)
    .collect()

  assert(rows.head.toSeq == Seq("alice", 35))
  assert(rows(1).toSeq == Seq("bob", null))
  assert(rows(2).toSeq == Seq("", 24))
}
}
15 changes: 15 additions & 0 deletions src/test/scala/com/databricks/spark/csv/CsvSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class CsvSuite extends FunSuite {
val carsFile8859 = "src/test/resources/cars_iso-8859-1.csv"
val carsTsvFile = "src/test/resources/cars.tsv"
val carsAltFile = "src/test/resources/cars-alternative.csv"
val nullNumbersFile = "src/test/resources/null-numbers.csv"
val emptyFile = "src/test/resources/empty.csv"
val escapeFile = "src/test/resources/escape.csv"
val tempEmptyDir = "target/test/empty/"
Expand Down Expand Up @@ -392,4 +393,18 @@ class CsvSuite extends FunSuite {
assert(results.first().getInt(0) === 1997)

}

test("DSL test nullable fields") {
  // With an explicit schema, empty cells become null only for the
  // nullable IntegerType column; the non-nullable name stays "".
  val schema = StructType(List(
    StructField("name", StringType, false),
    StructField("age", IntegerType, true)))

  val rows = new CsvParser()
    .withSchema(schema)
    .withUseHeader(true)
    .csvFile(TestSQLContext, nullNumbersFile)
    .collect()

  assert(rows.head.toSeq == Seq("alice", 35))
  assert(rows(1).toSeq == Seq("bob", null))
  assert(rows(2).toSeq == Seq("", 24))
}
}
32 changes: 31 additions & 1 deletion src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
package com.databricks.spark.csv.util

import java.math.BigDecimal
import java.sql.{Date, Timestamp}

import org.scalatest.FunSuite

import org.apache.spark.sql.types.DecimalType
import org.apache.spark.sql.types._

class TypeCastSuite extends FunSuite {

Expand Down Expand Up @@ -56,4 +57,33 @@ class TypeCastSuite extends FunSuite {
}
assert(exception.getMessage.contains("Unsupported special character for delimiter"))
}

test("Nullable types are handled") {
  // An empty string for a nullable non-string type must cast to null.
  assert(TypeCast.castTo("", IntegerType, nullable = true) == null)
}

test("String type should always return the same as the input") {
  // StringType never maps "" to null, whatever the nullability flag.
  for (isNullable <- Seq(true, false)) {
    assert(TypeCast.castTo("", StringType, nullable = isNullable) == "")
  }
}

test("Throws exception for empty string with non null type") {
  // A non-nullable numeric column must fail loudly on an empty cell.
  val thrown = intercept[NumberFormatException] {
    TypeCast.castTo("", IntegerType, nullable = false)
  }
  assert(thrown.getMessage.contains("For input string: \"\""))
}

test("Types are cast correctly") {
  // One representative literal per supported primitive type.
  // (Scrape defect fixed: GitHub review-comment UI text had been
  // interleaved into this method body, breaking the code.)
  assert(TypeCast.castTo("10", ByteType) == 10)
  assert(TypeCast.castTo("10", ShortType) == 10)
  assert(TypeCast.castTo("10", IntegerType) == 10)
  assert(TypeCast.castTo("10", LongType) == 10)
  assert(TypeCast.castTo("1.00", FloatType) == 1.0)
  assert(TypeCast.castTo("1.00", DoubleType) == 1.0)
  assert(TypeCast.castTo("true", BooleanType) == true)
  val timestamp = "2015-01-01 00:00:00"
  assert(TypeCast.castTo(timestamp, TimestampType) == Timestamp.valueOf(timestamp))
  assert(TypeCast.castTo("2015-01-01", DateType) == Date.valueOf("2015-01-01"))
}
}