Commit ee152f3

andy327 authored and falaki committed
Roundtrip null values of any type
This pull request adds the ability to write null values of any type to a CSV file and read them back out again as null. Two changes were made to enable this.

First, the `com.databricks.spark.csv` package previously hardcoded the null string to `"null"` when saving to a CSV file. It now reads the null token from the `"nullValue"` entry of the passed-in parameters map, so nulls can be written as empty strings (or any other token) by setting that option. The default remains `"null"`, preserving the library's previous behavior.

Second, the null check in the `castTo` method of `com.databricks.spark.csv.util.TypeCast` explicitly excluded `StringType`, so it was impossible to read string values from file as null. This pull request adds a `treatEmptyValuesAsNulls` setting that, when enabled, reads empty string values in nullable fields as null, as expected. The previous behavior remains the default, so nothing changes unless `treatEmptyValuesAsNulls` is explicitly set to true. `CsvParser` and `CsvRelation` were updated to carry the new setting through.

Additionally, a unit test has been added to `CsvSuite` that round-trips both string and non-string null values by writing nulls and reading them back out again as nulls.

Author: Andres Perez <[email protected]>

Closes #147 from andy327/feat-set-null-tokens.
1 parent ad11f75 commit ee152f3
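
Taken together, the two options give the round trip described above. A minimal sketch (assuming a `SQLContext` in scope as `sqlContext`, an existing DataFrame `df` with a nullable schema, and a placeholder output path; the full version is the new `CsvSuite` test below):

import com.databricks.spark.csv._

// Write nulls as empty strings rather than the literal token "null".
df.saveAsCsvFile("/tmp/null-roundtrip", Map("header" -> "true", "nullValue" -> ""))

// Read them back: with treatEmptyValuesAsNulls enabled, empty fields in
// nullable columns come back as null, for string and non-string types alike.
val roundTripped = new CsvParser()
  .withSchema(df.schema)
  .withUseHeader(true)
  .withTreatEmptyValuesAsNulls(true)
  .csvFile(sqlContext, "/tmp/null-roundtrip")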

File tree

7 files changed: +57 -5 lines changed

build.sbt

Lines changed: 1 addition & 0 deletions
@@ -95,6 +95,7 @@ mimaDefaultSettings ++ Seq(
     ProblemFilters.excludePackage("com.databricks.spark.csv.CsvRelation"),
     ProblemFilters.excludePackage("com.databricks.spark.csv.util.InferSchema"),
     ProblemFilters.excludePackage("com.databricks.spark.sql.readers"),
+    ProblemFilters.excludePackage("com.databricks.spark.csv.util.TypeCast"),
     // We allowed the private `CsvRelation` type to leak into the public method signature:
     ProblemFilters.exclude[IncompatibleResultTypeProblem](
       "com.databricks.spark.csv.DefaultSource.createRelation")

src/main/scala/com/databricks/spark/csv/CsvParser.scala

Lines changed: 8 additions & 0 deletions
@@ -35,6 +35,7 @@ class CsvParser extends Serializable {
   private var parseMode: String = ParseModes.DEFAULT
   private var ignoreLeadingWhiteSpace: Boolean = false
   private var ignoreTrailingWhiteSpace: Boolean = false
+  private var treatEmptyValuesAsNulls: Boolean = false
   private var parserLib: String = ParserLibs.DEFAULT
   private var charset: String = TextFile.DEFAULT_CHARSET.name()
   private var inferSchema: Boolean = false
@@ -84,6 +85,11 @@ class CsvParser extends Serializable {
     this
   }
 
+  def withTreatEmptyValuesAsNulls(treatAsNull: Boolean): CsvParser = {
+    this.treatEmptyValuesAsNulls = treatAsNull
+    this
+  }
+
   def withParserLib(parserLib: String): CsvParser = {
     this.parserLib = parserLib
     this
@@ -114,6 +120,7 @@ class CsvParser extends Serializable {
       parserLib,
       ignoreLeadingWhiteSpace,
       ignoreTrailingWhiteSpace,
+      treatEmptyValuesAsNulls,
       schema,
       inferSchema)(sqlContext)
     sqlContext.baseRelationToDataFrame(relation)
@@ -132,6 +139,7 @@ class CsvParser extends Serializable {
       parserLib,
       ignoreLeadingWhiteSpace,
       ignoreTrailingWhiteSpace,
+      treatEmptyValuesAsNulls,
       schema,
       inferSchema)(sqlContext)
     sqlContext.baseRelationToDataFrame(relation)

src/main/scala/com/databricks/spark/csv/CsvRelation.scala

Lines changed: 3 additions & 1 deletion
@@ -43,6 +43,7 @@ case class CsvRelation protected[spark] (
     parserLib: String,
     ignoreLeadingWhiteSpace: Boolean,
     ignoreTrailingWhiteSpace: Boolean,
+    treatEmptyValuesAsNulls: Boolean,
     userSchema: StructType = null,
     inferCsvSchema: Boolean)(@transient val sqlContext: SQLContext)
   extends BaseRelation with TableScan with InsertableRelation {
@@ -113,7 +114,8 @@ case class CsvRelation protected[spark] (
       index = 0
       while (index < schemaFields.length) {
         val field = schemaFields(index)
-        rowArray(index) = TypeCast.castTo(tokens(index), field.dataType, field.nullable)
+        rowArray(index) = TypeCast.castTo(tokens(index), field.dataType, field.nullable,
+          treatEmptyValuesAsNulls)
         index = index + 1
       }
       Some(Row.fromSeq(rowArray))

src/main/scala/com/databricks/spark/csv/DefaultSource.scala

Lines changed: 9 additions & 0 deletions
@@ -112,6 +112,14 @@ class DefaultSource
     } else {
       throw new Exception("Ignore white space flag can be true or false")
     }
+    val treatEmptyValuesAsNulls = parameters.getOrElse("treatEmptyValuesAsNulls", "false")
+    val treatEmptyValuesAsNullsFlag = if (treatEmptyValuesAsNulls == "false") {
+      false
+    } else if (treatEmptyValuesAsNulls == "true") {
+      true
+    } else {
+      throw new Exception("Treat empty values as null flag can be true or false")
+    }
 
     val charset = parameters.getOrElse("charset", TextFile.DEFAULT_CHARSET.name())
     // TODO validate charset?
@@ -137,6 +145,7 @@ class DefaultSource
       parserLib,
       ignoreLeadingWhiteSpaceFlag,
       ignoreTrailingWhiteSpaceFlag,
+      treatEmptyValuesAsNullsFlag,
       schema,
       inferSchemaFlag)(sqlContext)
   }
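
For callers using the Spark data source API rather than the DSL, the options land in the `parameters` map validated above. A hedged sketch, assuming a Spark 1.4+ `SQLContext` named `sqlContext` and a placeholder input path:

val df = sqlContext.read
  .format("com.databricks.spark.csv")
  .option("header", "true")
  .option("treatEmptyValuesAsNulls", "true")
  .load("/path/to/data.csv")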

src/main/scala/com/databricks/spark/csv/package.scala

Lines changed: 6 additions & 2 deletions
@@ -52,6 +52,7 @@ package object csv {
       parserLib = parserLib,
       ignoreLeadingWhiteSpace = ignoreLeadingWhiteSpace,
       ignoreTrailingWhiteSpace = ignoreTrailingWhiteSpace,
+      treatEmptyValuesAsNulls = false,
       inferCsvSchema = inferSchema)(sqlContext)
     sqlContext.baseRelationToDataFrame(csvRelation)
   }
@@ -76,6 +77,7 @@ package object csv {
       parserLib = parserLib,
       ignoreLeadingWhiteSpace = ignoreLeadingWhiteSpace,
       ignoreTrailingWhiteSpace = ignoreTrailingWhiteSpace,
+      treatEmptyValuesAsNulls = false,
       inferCsvSchema = inferSchema)(sqlContext)
     sqlContext.baseRelationToDataFrame(csvRelation)
   }
@@ -116,11 +118,13 @@ package object csv {
       case None => None
     }
 
+    val nullValue = parameters.getOrElse("nullValue", "null")
+
     val csvFormatBase = CSVFormat.DEFAULT
       .withDelimiter(delimiterChar)
       .withEscape(escapeChar)
       .withSkipHeaderRecord(false)
-      .withNullString("null")
+      .withNullString(nullValue)
 
     val csvFormat = quoteChar match {
       case Some(c) => csvFormatBase.withQuote(c)
@@ -139,7 +143,7 @@ package object csv {
       .withDelimiter(delimiterChar)
       .withEscape(escapeChar)
       .withSkipHeaderRecord(false)
-      .withNullString("null")
+      .withNullString(nullValue)
 
     val csvFormat = quoteChar match {
       case Some(c) => csvFormatBase.withQuote(c)
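
To illustrate what `.withNullString` changes on the write path, here is a small standalone sketch using Apache Commons CSV directly (illustrative values only, not library code):

import org.apache.commons.csv.{CSVFormat, CSVPrinter}

val out = new java.lang.StringBuilder
val printer = new CSVPrinter(out, CSVFormat.DEFAULT.withNullString(""))
printer.printRecord("alice", null) // null is rendered as the configured token
printer.flush()
// With the default CRLF record separator, out.toString is "alice,\r\n"
// rather than "alice,null\r\n".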

src/main/scala/com/databricks/spark/csv/util/TypeCast.scala

Lines changed: 6 additions & 2 deletions
@@ -35,8 +35,12 @@ object TypeCast {
    * @param datum string value
    * @param castType SparkSQL type
    */
-  private[csv] def castTo(datum: String, castType: DataType, nullable: Boolean = true): Any = {
-    if (datum == "" && nullable && !castType.isInstanceOf[StringType]){
+  private[csv] def castTo(
+      datum: String,
+      castType: DataType,
+      nullable: Boolean = true,
+      treatEmptyValuesAsNulls: Boolean = false): Any = {
+    if (datum == "" && nullable && (!castType.isInstanceOf[StringType] || treatEmptyValuesAsNulls)){
       null
     } else {
       castType match {
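
The resulting null rule, restated as a standalone predicate for illustration (a sketch mirroring the condition above; the library's actual `castTo` also performs the per-type conversions):

import org.apache.spark.sql.types._

def emptyBecomesNull(datum: String, castType: DataType, nullable: Boolean,
    treatEmptyValuesAsNulls: Boolean): Boolean =
  datum == "" && nullable && (!castType.isInstanceOf[StringType] || treatEmptyValuesAsNulls)

// Non-string columns behave as before: an empty token means null.
assert(emptyBecomesNull("", IntegerType, nullable = true, treatEmptyValuesAsNulls = false))
// String columns keep the old behavior by default...
assert(!emptyBecomesNull("", StringType, nullable = true, treatEmptyValuesAsNulls = false))
// ...and become null only when the new flag is set.
assert(emptyBecomesNull("", StringType, nullable = true, treatEmptyValuesAsNulls = true))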

src/test/scala/com/databricks/spark/csv/CsvSuite.scala

Lines changed: 24 additions & 0 deletions
@@ -167,6 +167,30 @@ abstract class AbstractCsvSuite extends FunSuite with BeforeAndAfterAll {
     assert(exception.getMessage.contains("Malformed line in FAILFAST mode: 2015,Chevy,Volt"))
   }
 
+  test("DSL test roundtrip nulls") {
+    // Create temp directory
+    TestUtils.deleteRecursively(new File(tempEmptyDir))
+    new File(tempEmptyDir).mkdirs()
+    val copyFilePath = tempEmptyDir + "null-numbers.csv"
+    val agesSchema = StructType(List(StructField("name", StringType, true),
+      StructField("age", IntegerType, true)))
+
+    val agesRows = Seq(Row("alice", 35), Row("bob", null), Row(null, 24))
+    val agesRdd = sqlContext.sparkContext.parallelize(agesRows)
+    val agesDf = sqlContext.createDataFrame(agesRdd, agesSchema)
+
+    agesDf.saveAsCsvFile(copyFilePath, Map("header" -> "true", "nullValue" -> ""))
+
+    val agesCopy = new CsvParser()
+      .withSchema(agesSchema)
+      .withUseHeader(true)
+      .withTreatEmptyValuesAsNulls(true)
+      .withParserLib(parserLib)
+      .csvFile(sqlContext, copyFilePath)
+
+    assert(agesCopy.count == agesRows.size)
+    assert(agesCopy.collect.toSet == agesRows.toSet)
+  }
 
   test("DSL test with alternative delimiter and quote") {
     val results = new CsvParser()
