Commit 02bf554

Authored by HyukjinKwon and committed by cloud-fan
[SPARK-18772][SQL] Avoid unnecessary conversion try for special floats in JSON
## What changes were proposed in this pull request?

This PR is based on #16199 and extracts the valid change from #9759 to resolve SPARK-18772.

It avoids the extra conversion attempt via `toFloat` and `toDouble`. The examples below show the behavior change:

**Before**

```scala
scala> import org.apache.spark.sql.types._
import org.apache.spark.sql.types._

scala> spark.read.schema(StructType(Seq(StructField("a", DoubleType)))).option("mode", "FAILFAST").json(Seq("""{"a": "nan"}""").toDS).show()
17/05/12 11:30:41 ERROR Executor: Exception in task 0.0 in stage 2.0 (TID 2)
java.lang.NumberFormatException: For input string: "nan"
...
```

**After**

```scala
scala> import org.apache.spark.sql.types._
import org.apache.spark.sql.types._

scala> spark.read.schema(StructType(Seq(StructField("a", DoubleType)))).option("mode", "FAILFAST").json(Seq("""{"a": "nan"}""").toDS).show()
17/05/12 11:44:30 ERROR Executor: Exception in task 0.0 in stage 0.0 (TID 0)
java.lang.RuntimeException: Cannot parse nan as DoubleType.
...
```

## How was this patch tested?

Unit tests added in `JsonSuite`.

Closes #16199

Author: hyukjinkwon <[email protected]>
Author: Nathan Howell <[email protected]>

Closes #17956 from HyukjinKwon/SPARK-18772.

(cherry picked from commit 3f98375)
Signed-off-by: Wenchen Fan <[email protected]>
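The root cause is worth spelling out: the old check lowercased the string before comparing, but the subsequent `toFloat`/`toDouble` call delegates to `java.lang.Float.parseFloat`/`java.lang.Double.parseDouble`, which accept only the exact spellings `NaN`, `Infinity`, and `-Infinity`. A short REPL sketch of the mismatch (outputs abbreviated):

```scala
scala> "NaN".toFloat        // exact spelling: accepted by parseFloat
res0: Float = NaN

scala> "Infinity".toDouble  // exact spelling: accepted by parseDouble
res1: Double = Infinity

scala> "nan".toFloat        // passed the old lowercase check, then blew up
java.lang.NumberFormatException: For input string: "nan"

scala> "inf".toDouble       // likewise accepted by the old check, rejected here
java.lang.NumberFormatException: For input string: "inf"
```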
1 parent d99165b commit 02bf554

File tree: 2 files changed (+50, -21 lines)


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala

Lines changed: 10 additions & 21 deletions
```diff
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.catalyst.json
 
 import java.io.ByteArrayOutputStream
-import java.util.Locale
 
 import scala.collection.mutable.ArrayBuffer
 import scala.util.Try
@@ -126,16 +125,11 @@ class JacksonParser(
 
         case VALUE_STRING =>
           // Special case handling for NaN and Infinity.
-          val value = parser.getText
-          val lowerCaseValue = value.toLowerCase(Locale.ROOT)
-          if (lowerCaseValue.equals("nan") ||
-            lowerCaseValue.equals("infinity") ||
-            lowerCaseValue.equals("-infinity") ||
-            lowerCaseValue.equals("inf") ||
-            lowerCaseValue.equals("-inf")) {
-            value.toFloat
-          } else {
-            throw new RuntimeException(s"Cannot parse $value as FloatType.")
+          parser.getText match {
+            case "NaN" => Float.NaN
+            case "Infinity" => Float.PositiveInfinity
+            case "-Infinity" => Float.NegativeInfinity
+            case other => throw new RuntimeException(s"Cannot parse $other as FloatType.")
           }
       }
 
@@ -146,16 +140,11 @@ class JacksonParser(
 
         case VALUE_STRING =>
           // Special case handling for NaN and Infinity.
-          val value = parser.getText
-          val lowerCaseValue = value.toLowerCase(Locale.ROOT)
-          if (lowerCaseValue.equals("nan") ||
-            lowerCaseValue.equals("infinity") ||
-            lowerCaseValue.equals("-infinity") ||
-            lowerCaseValue.equals("inf") ||
-            lowerCaseValue.equals("-inf")) {
-            value.toDouble
-          } else {
-            throw new RuntimeException(s"Cannot parse $value as DoubleType.")
+          parser.getText match {
+            case "NaN" => Double.NaN
+            case "Infinity" => Double.PositiveInfinity
+            case "-Infinity" => Double.NegativeInfinity
+            case other => throw new RuntimeException(s"Cannot parse $other as DoubleType.")
           }
       }
 
```
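Read in isolation, the new logic is a small total function over the parser's raw text. A minimal standalone sketch (the helper name `parseSpecialFloat` is illustrative, not part of the patch):

```scala
// Illustrative sketch of the pattern above, outside the parser: the match is
// case-sensitive and fails fast, so no speculative toFloat conversion is
// attempted on arbitrary strings.
def parseSpecialFloat(text: String): Float = text match {
  case "NaN" => Float.NaN
  case "Infinity" => Float.PositiveInfinity
  case "-Infinity" => Float.NegativeInfinity
  case other => throw new RuntimeException(s"Cannot parse $other as FloatType.")
}

parseSpecialFloat("NaN").isNaN   // true
parseSpecialFloat("-Infinity")   // Float.NegativeInfinity
// parseSpecialFloat("inf")      // RuntimeException: Cannot parse inf as FloatType.
```

Dropping the case-insensitive `inf`/`-inf` spellings also brings the JSON source in line with what `parseFloat`/`parseDouble` themselves accept.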

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala

Lines changed: 40 additions & 0 deletions
```diff
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.json
 import java.io.{File, StringWriter}
 import java.nio.charset.StandardCharsets
 import java.sql.{Date, Timestamp}
+import java.util.Locale
 
 import com.fasterxml.jackson.core.JsonFactory
 import org.apache.hadoop.fs.{Path, PathFilter}
@@ -1978,4 +1979,43 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
       assert(errMsg.startsWith("The field for corrupt records must be string type and nullable"))
     }
   }
+
+  test("SPARK-18772: Parse special floats correctly") {
+    val jsons = Seq(
+      """{"a": "NaN"}""",
+      """{"a": "Infinity"}""",
+      """{"a": "-Infinity"}""")
+
+    // positive cases
+    val checks: Seq[Double => Boolean] = Seq(
+      _.isNaN,
+      _.isPosInfinity,
+      _.isNegInfinity)
+
+    Seq(FloatType, DoubleType).foreach { dt =>
+      jsons.zip(checks).foreach { case (json, check) =>
+        val ds = spark.read
+          .schema(StructType(Seq(StructField("a", dt))))
+          .json(Seq(json).toDS())
+          .select($"a".cast(DoubleType)).as[Double]
+        assert(check(ds.first()))
+      }
+    }
+
+    // negative cases
+    Seq(FloatType, DoubleType).foreach { dt =>
+      val lowerCasedJsons = jsons.map(_.toLowerCase(Locale.ROOT))
+      // The special floats are case-sensitive so these cases below throw exceptions.
+      lowerCasedJsons.foreach { lowerCasedJson =>
+        val e = intercept[SparkException] {
+          spark.read
+            .option("mode", "FAILFAST")
+            .schema(StructType(Seq(StructField("a", dt))))
+            .json(Seq(lowerCasedJson).toDS())
+            .collect()
+        }
+        assert(e.getMessage.contains("Cannot parse"))
+      }
+    }
+  }
 }
```
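For completeness, the user-facing behavior this test pins down can also be exercised from `spark-shell`, in the same style as the examples in the commit message; a sketch, with approximate output:

```scala
scala> import org.apache.spark.sql.types._
import org.apache.spark.sql.types._

// Case-sensitive spellings parse to the special double values and display
// via their standard toString forms.
scala> val schema = StructType(Seq(StructField("a", DoubleType)))

scala> spark.read.schema(schema).json(Seq("""{"a": "NaN"}""", """{"a": "-Infinity"}""").toDS).show()
+---------+
|        a|
+---------+
|      NaN|
|-Infinity|
+---------+
```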
