
Commit 7a83d71

MaxGekk authored and cloud-fan committed
[SPARK-26163][SQL] Parsing decimals from JSON using locale
## What changes were proposed in this pull request?

In this PR, I propose using the locale option to parse (and infer) decimals from JSON input. After the changes, `JacksonParser` converts an input string to `BigDecimal` and then to Spark's Decimal using `java.text.DecimalFormat`. The new behaviour can be switched off via the SQL config `spark.sql.legacy.decimalParsing.enabled`.

## How was this patch tested?

Added 2 tests to `JsonExpressionsSuite` for the `en-US`, `ko-KR`, `ru-RU` and `de-DE` locales:
- inferring the decimal type from JSON field values using a locale;
- converting JSON field values to the specified decimal type using the locales.

Closes #23132 from MaxGekk/json-decimal-parsing-locale.

Lead-authored-by: Maxim Gekk <[email protected]>
Co-authored-by: Maxim Gekk <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 8bfea86 commit 7a83d71
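The change hinges on `java.text.DecimalFormat`: the decimal separator and grouping characters differ between locales, so the same number is rendered, and has to be parsed back, differently per locale. A minimal, Spark-free sketch of that round trip (the locale tags are the ones exercised by the new tests):

```scala
import java.text.{DecimalFormat, DecimalFormatSymbols, ParsePosition}
import java.util.Locale

object LocaleDecimalRoundTrip extends App {
  val value = new java.math.BigDecimal("1000.001")
  Seq("en-US", "ru-RU", "de-DE").foreach { tag =>
    // Built the same way as the patch builds its parser in ExprUtils.getDecimalParser.
    val format = new DecimalFormat("", new DecimalFormatSymbols(Locale.forLanguageTag(tag)))
    format.setParseBigDecimal(true) // make parse() return java.math.BigDecimal

    val rendered = format.format(value) // the decimal separator is ',' under ru-RU and de-DE
    val parsed = format.parse(rendered, new ParsePosition(0))
    println(s"$tag: rendered '$rendered', parsed back as $parsed")
  }
}
```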

File tree

7 files changed: +132 -52 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala

Lines changed: 21 additions & 0 deletions

```diff
@@ -17,6 +17,9 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import java.text.{DecimalFormat, DecimalFormatSymbols, ParsePosition}
+import java.util.Locale
+
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
 import org.apache.spark.sql.types.{DataType, MapType, StringType, StructType}
@@ -83,4 +86,22 @@ object ExprUtils {
       }
     }
   }
+
+  def getDecimalParser(locale: Locale): String => java.math.BigDecimal = {
+    if (locale == Locale.US) { // Special handling the default locale for backward compatibility
+      (s: String) => new java.math.BigDecimal(s.replaceAll(",", ""))
+    } else {
+      val decimalFormat = new DecimalFormat("", new DecimalFormatSymbols(locale))
+      decimalFormat.setParseBigDecimal(true)
+      (s: String) => {
+        val pos = new ParsePosition(0)
+        val result = decimalFormat.parse(s, pos).asInstanceOf[java.math.BigDecimal]
+        if (pos.getIndex() != s.length() || pos.getErrorIndex() != -1) {
+          throw new IllegalArgumentException("Cannot parse any decimal");
+        } else {
+          result
+        }
+      }
+    }
+  }
 }
```
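Two details of `getDecimalParser` are worth calling out: the `Locale.US` fast path only strips grouping commas before handing the string to `BigDecimal`, and the `ParsePosition` check turns `DecimalFormat`'s lenient prefix parsing into an all-or-nothing parse. A hypothetical REPL session against the new helper (assumes Spark's catalyst module on the classpath):

```scala
import java.util.Locale
import org.apache.spark.sql.catalyst.expressions.ExprUtils

val parseUS = ExprUtils.getDecimalParser(Locale.US)
parseUS("1,000.001")    // 1000.001 -- the fast path just drops ',' grouping characters

val parseDE = ExprUtils.getDecimalParser(Locale.GERMANY)
parseDE("1000,001")     // 1000.001 -- ',' is the decimal separator under de-DE
parseDE("1000,001abc")  // throws IllegalArgumentException: the ParsePosition check
                        // rejects input that DecimalFormat consumed only partially
```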

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala

Lines changed: 4 additions & 3 deletions

```diff
@@ -23,12 +23,10 @@ import scala.util.parsing.combinator.RegexParsers
 
 import com.fasterxml.jackson.core._
 
-import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.json._
-import org.apache.spark.sql.catalyst.json.JsonInferSchema.inferField
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -775,6 +773,9 @@ case class SchemaOfJson(
     factory
   }
 
+  @transient
+  private lazy val jsonInferSchema = new JsonInferSchema(jsonOptions)
+
   @transient
   private lazy val json = child.eval().asInstanceOf[UTF8String]
 
@@ -787,7 +788,7 @@
   override def eval(v: InternalRow): Any = {
     val dt = Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, json)) { parser =>
       parser.nextToken()
-      inferField(parser, jsonOptions)
+      jsonInferSchema.inferField(parser)
     }
 
     UTF8String.fromString(dt.catalogString)
```
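The new `jsonInferSchema` field follows a common Spark expression pattern: `@transient lazy val` keeps helper state out of the serialized expression tree and rebuilds it on first access after deserialization. A generic sketch of the idiom (the class and names here are illustrative, not from the patch):

```scala
class LocaleAwareRenderer(langTag: String) extends Serializable {
  // Dropped during serialization and lazily re-created after deserialization,
  // so the DecimalFormat instance never has to travel over the wire.
  @transient private lazy val format =
    new java.text.DecimalFormat(
      "", new java.text.DecimalFormatSymbols(java.util.Locale.forLanguageTag(langTag)))

  def render(d: java.math.BigDecimal): String = format.format(d)
}
```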

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala

Lines changed: 6 additions & 0 deletions

```diff
@@ -29,6 +29,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
@@ -135,6 +136,8 @@
     }
   }
 
+  private val decimalParser = ExprUtils.getDecimalParser(options.locale)
+
   /**
    * Create a converter which converts the JSON documents held by the `JsonParser`
    * to a value according to a desired schema.
@@ -261,6 +264,9 @@
       (parser: JsonParser) => parseJsonToken[Decimal](parser, dataType) {
         case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) =>
           Decimal(parser.getDecimalValue, dt.precision, dt.scale)
+        case VALUE_STRING if parser.getTextLength >= 1 =>
+          val bigDecimal = decimalParser(parser.getText)
+          Decimal(bigDecimal, dt.precision, dt.scale)
       }
 
     case st: StructType =>
```
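With the new `VALUE_STRING` case, a quoted, locale-formatted number can land in a `DecimalType` column when the `locale` option matches the data. A sketch of how this might look through the public API (assumes a running `SparkSession` named `spark`):

```scala
import org.apache.spark.sql.functions.from_json
import org.apache.spark.sql.types._

import spark.implicits._

val schema = new StructType().add("d", DecimalType(10, 5))
// Under de-DE the decimal separator is ',', so the quoted "1000,001"
// should be parsed into d = 1000.00100.
Seq("""{"d": "1000,001"}""").toDF("json")
  .select(from_json($"json", schema, Map("locale" -> "de-DE")).as("parsed"))
  .show(false)
```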

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala

Lines changed: 51 additions & 38 deletions

```diff
@@ -19,18 +19,23 @@ package org.apache.spark.sql.catalyst.json
 
 import java.util.Comparator
 
+import scala.util.control.Exception.allCatch
+
 import com.fasterxml.jackson.core._
 
 import org.apache.spark.SparkException
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.analysis.TypeCoercion
+import org.apache.spark.sql.catalyst.expressions.ExprUtils
 import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil
 import org.apache.spark.sql.catalyst.util.{DropMalformedMode, FailFastMode, ParseMode, PermissiveMode}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
-private[sql] object JsonInferSchema {
+private[sql] class JsonInferSchema(options: JSONOptions) extends Serializable {
+
+  private val decimalParser = ExprUtils.getDecimalParser(options.locale)
 
   /**
    * Infer the type of a collection of json records in three stages:
@@ -40,21 +45,20 @@
    */
   def infer[T](
       json: RDD[T],
-      configOptions: JSONOptions,
       createParser: (JsonFactory, T) => JsonParser): StructType = {
-    val parseMode = configOptions.parseMode
-    val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord
+    val parseMode = options.parseMode
+    val columnNameOfCorruptRecord = options.columnNameOfCorruptRecord
 
     // In each RDD partition, perform schema inference on each row and merge afterwards.
-    val typeMerger = compatibleRootType(columnNameOfCorruptRecord, parseMode)
+    val typeMerger = JsonInferSchema.compatibleRootType(columnNameOfCorruptRecord, parseMode)
     val mergedTypesFromPartitions = json.mapPartitions { iter =>
       val factory = new JsonFactory()
-      configOptions.setJacksonOptions(factory)
+      options.setJacksonOptions(factory)
       iter.flatMap { row =>
         try {
           Utils.tryWithResource(createParser(factory, row)) { parser =>
             parser.nextToken()
-            Some(inferField(parser, configOptions))
+            Some(inferField(parser))
           }
         } catch {
           case e @ (_: RuntimeException | _: JsonProcessingException) => parseMode match {
@@ -82,42 +86,25 @@
     }
     json.sparkContext.runJob(mergedTypesFromPartitions, foldPartition, mergeResult)
 
-    canonicalizeType(rootType, configOptions) match {
+    canonicalizeType(rootType, options) match {
       case Some(st: StructType) => st
       case _ =>
         // canonicalizeType erases all empty structs, including the only one we want to keep
        StructType(Nil)
    }
  }

-  private[this] val structFieldComparator = new Comparator[StructField] {
-    override def compare(o1: StructField, o2: StructField): Int = {
-      o1.name.compareTo(o2.name)
-    }
-  }
-
-  private def isSorted(arr: Array[StructField]): Boolean = {
-    var i: Int = 0
-    while (i < arr.length - 1) {
-      if (structFieldComparator.compare(arr(i), arr(i + 1)) > 0) {
-        return false
-      }
-      i += 1
-    }
-    true
-  }
-
   /**
    * Infer the type of a json document from the parser's token stream
    */
-  def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = {
+  def inferField(parser: JsonParser): DataType = {
     import com.fasterxml.jackson.core.JsonToken._
     parser.getCurrentToken match {
       case null | VALUE_NULL => NullType
 
       case FIELD_NAME =>
         parser.nextToken()
-        inferField(parser, configOptions)
+        inferField(parser)
 
       case VALUE_STRING if parser.getTextLength < 1 =>
         // Zero length strings and nulls have special handling to deal
@@ -128,18 +115,25 @@
         // record fields' types have been combined.
         NullType
 
+      case VALUE_STRING if options.prefersDecimal =>
+        val decimalTry = allCatch opt {
+          val bigDecimal = decimalParser(parser.getText)
+          DecimalType(bigDecimal.precision, bigDecimal.scale)
+        }
+        decimalTry.getOrElse(StringType)
       case VALUE_STRING => StringType
+
       case START_OBJECT =>
         val builder = Array.newBuilder[StructField]
         while (nextUntil(parser, END_OBJECT)) {
           builder += StructField(
             parser.getCurrentName,
-            inferField(parser, configOptions),
+            inferField(parser),
             nullable = true)
         }
         val fields: Array[StructField] = builder.result()
         // Note: other code relies on this sorting for correctness, so don't remove it!
-        java.util.Arrays.sort(fields, structFieldComparator)
+        java.util.Arrays.sort(fields, JsonInferSchema.structFieldComparator)
         StructType(fields)
 
       case START_ARRAY =>
@@ -148,15 +142,15 @@
         // the type as we pass through all JSON objects.
         var elementType: DataType = NullType
         while (nextUntil(parser, END_ARRAY)) {
-          elementType = compatibleType(
-            elementType, inferField(parser, configOptions))
+          elementType = JsonInferSchema.compatibleType(
+            elementType, inferField(parser))
         }
 
         ArrayType(elementType)
 
-      case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) if configOptions.primitivesAsString => StringType
+      case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) if options.primitivesAsString => StringType
 
-      case (VALUE_TRUE | VALUE_FALSE) if configOptions.primitivesAsString => StringType
+      case (VALUE_TRUE | VALUE_FALSE) if options.primitivesAsString => StringType
 
       case VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT =>
         import JsonParser.NumberType._
@@ -172,7 +166,7 @@
           } else {
             DoubleType
           }
-        case FLOAT | DOUBLE if configOptions.prefersDecimal =>
+        case FLOAT | DOUBLE if options.prefersDecimal =>
           val v = parser.getDecimalValue
           if (Math.max(v.precision(), v.scale()) <= DecimalType.MAX_PRECISION) {
             DecimalType(Math.max(v.precision(), v.scale()), v.scale())
@@ -217,20 +211,39 @@
 
     case other => Some(other)
   }
+}
+
+object JsonInferSchema {
+  val structFieldComparator = new Comparator[StructField] {
+    override def compare(o1: StructField, o2: StructField): Int = {
+      o1.name.compareTo(o2.name)
+    }
+  }
+
+  def isSorted(arr: Array[StructField]): Boolean = {
+    var i: Int = 0
+    while (i < arr.length - 1) {
+      if (structFieldComparator.compare(arr(i), arr(i + 1)) > 0) {
+        return false
+      }
+      i += 1
+    }
+    true
+  }
 
-  private def withCorruptField(
+  def withCorruptField(
      struct: StructType,
      other: DataType,
      columnNameOfCorruptRecords: String,
-      parseMode: ParseMode) = parseMode match {
+      parseMode: ParseMode): StructType = parseMode match {
    case PermissiveMode =>
      // If we see any other data type at the root level, we get records that cannot be
      // parsed. So, we use the struct as the data type and add the corrupt field to the schema.
      if (!struct.fieldNames.contains(columnNameOfCorruptRecords)) {
        // If this given struct does not have a column used for corrupt records,
        // add this field.
        val newFields: Array[StructField] =
-            StructField(columnNameOfCorruptRecords, StringType, nullable = true) +: struct.fields
+          StructField(columnNameOfCorruptRecords, StringType, nullable = true) +: struct.fields
        // Note: other code relies on this sorting for correctness, so don't remove it!
        java.util.Arrays.sort(newFields, structFieldComparator)
        StructType(newFields)
@@ -253,7 +266,7 @@
   /**
    * Remove top-level ArrayType wrappers and merge the remaining schemas
    */
-  private def compatibleRootType(
+  def compatibleRootType(
      columnNameOfCorruptRecords: String,
      parseMode: ParseMode): (DataType, DataType) => DataType = {
    // Since we support array of json objects at the top level,
```
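The new `prefersDecimal` branch derives the inferred type directly from the parsed value: `java.math.BigDecimal` already carries the precision (total significant digits) and scale (digits after the separator) that `DecimalType` needs. This is also why the test suite below expects `decimal(7,3)` for the input 1000.001:

```scala
val bd = new java.math.BigDecimal("1000.001")
assert(bd.precision == 7) // seven significant digits: 1000001
assert(bd.scale == 3)     // three digits after the decimal separator
// inferField maps this to DecimalType(7, 3), rendered as decimal(7,3)
```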

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala

Lines changed: 41 additions & 1 deletion

```diff
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import java.text.SimpleDateFormat
+import java.text.{DecimalFormat, DecimalFormatSymbols, SimpleDateFormat}
 import java.util.{Calendar, Locale}
 
 import org.scalatest.exceptions.TestFailedException
@@ -765,4 +765,44 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
       timeZoneId = gmtId),
      expectedErrMsg = "The field for corrupt records must be string type and nullable")
  }
+
+  def decimalInput(langTag: String): (Decimal, String) = {
+    val decimalVal = new java.math.BigDecimal("1000.001")
+    val decimalType = new DecimalType(10, 5)
+    val expected = Decimal(decimalVal, decimalType.precision, decimalType.scale)
+    val decimalFormat = new DecimalFormat("",
+      new DecimalFormatSymbols(Locale.forLanguageTag(langTag)))
+    val input = s"""{"d": "${decimalFormat.format(expected.toBigDecimal)}"}"""
+
+    (expected, input)
+  }
+
+  test("parse decimals using locale") {
+    def checkDecimalParsing(langTag: String): Unit = {
+      val schema = new StructType().add("d", DecimalType(10, 5))
+      val options = Map("locale" -> langTag)
+      val (expected, input) = decimalInput(langTag)
+
+      checkEvaluation(
+        JsonToStructs(schema, options, Literal.create(input), gmtId),
+        InternalRow(expected))
+    }
+
+    Seq("en-US", "ko-KR", "ru-RU", "de-DE").foreach(checkDecimalParsing)
+  }
+
+  test("inferring the decimal type using locale") {
+    def checkDecimalInfer(langTag: String, expectedType: String): Unit = {
+      val options = Map("locale" -> langTag, "prefersDecimal" -> "true")
+      val (_, input) = decimalInput(langTag)
+
+      checkEvaluation(
+        SchemaOfJson(Literal.create(input), options),
+        expectedType)
+    }
+
+    Seq("en-US", "ko-KR", "ru-RU", "de-DE").foreach {
+      checkDecimalInfer(_, """struct<d:decimal(7,3)>""")
+    }
+  }
 }
```

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala

Lines changed: 2 additions & 2 deletions

```diff
@@ -107,7 +107,7 @@ object TextInputJsonDataSource extends JsonDataSource {
     }.getOrElse(CreateJacksonParser.internalRow(_: JsonFactory, _: InternalRow))
 
     SQLExecution.withSQLConfPropagated(json.sparkSession) {
-      JsonInferSchema.infer(rdd, parsedOptions, rowParser)
+      new JsonInferSchema(parsedOptions).infer(rdd, rowParser)
     }
   }
 
@@ -166,7 +166,7 @@ object MultiLineJsonDataSource extends JsonDataSource {
       .getOrElse(createParser(_: JsonFactory, _: PortableDataStream))
 
     SQLExecution.withSQLConfPropagated(sparkSession) {
-      JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser)
+      new JsonInferSchema(parsedOptions).infer[PortableDataStream](sampled, parser)
     }
   }
 
```
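Since both data-source paths now build a `JsonInferSchema` from the parsed options, the reader options flow into schema inference end to end. A sketch of what that enables (assumes a running `SparkSession` named `spark`):

```scala
import spark.implicits._

val jsonLines = Seq("""{"price": "1000,001"}""").toDS() // de-DE formatted decimal
val df = spark.read
  .option("locale", "de-DE")
  .option("prefersDecimal", "true")
  .json(jsonLines)

df.printSchema() // price should be inferred as decimal(7,3) via the locale-aware parser
```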
