Skip to content

Commit 64a97a2

Browse files
committed
Merge remote-tracking branch 'origin/master' into decimal-parsing-locale
# Conflicts: # sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala
2 parents 3125c23 + 0558d02 commit 64a97a2

File tree

20 files changed

+386
-341
lines changed

20 files changed

+386
-341
lines changed

python/pyspark/sql/readwriter.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
177177
allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None,
178178
mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None,
179179
multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None,
180-
dropFieldIfAllNull=None, encoding=None):
180+
dropFieldIfAllNull=None, encoding=None, locale=None):
181181
"""
182182
Loads JSON files and returns the results as a :class:`DataFrame`.
183183
@@ -249,6 +249,9 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
249249
:param dropFieldIfAllNull: whether to ignore column of all null values or empty
250250
array/struct during schema inference. If None is set, it
251251
uses the default value, ``false``.
252+
:param locale: sets a locale as language tag in IETF BCP 47 format. If None is set,
253+
it uses the default value, ``en-US``. For instance, ``locale`` is used while
254+
parsing dates and timestamps.
252255
253256
>>> df1 = spark.read.json('python/test_support/sql/people.json')
254257
>>> df1.dtypes
@@ -267,7 +270,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
267270
mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat,
268271
timestampFormat=timestampFormat, multiLine=multiLine,
269272
allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep,
270-
samplingRatio=samplingRatio, dropFieldIfAllNull=dropFieldIfAllNull, encoding=encoding)
273+
samplingRatio=samplingRatio, dropFieldIfAllNull=dropFieldIfAllNull, encoding=encoding,
274+
locale=locale)
271275
if isinstance(path, basestring):
272276
path = [path]
273277
if type(path) == list:
@@ -349,7 +353,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
349353
negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
350354
maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None,
351355
columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None,
352-
samplingRatio=None, enforceSchema=None, emptyValue=None):
356+
samplingRatio=None, enforceSchema=None, emptyValue=None, locale=None):
353357
r"""Loads a CSV file and returns the result as a :class:`DataFrame`.
354358
355359
This function will go through the input once to determine the input schema if
@@ -446,6 +450,9 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
446450
If None is set, it uses the default value, ``1.0``.
447451
:param emptyValue: sets the string representation of an empty value. If None is set, it uses
448452
the default value, empty string.
453+
:param locale: sets a locale as language tag in IETF BCP 47 format. If None is set,
454+
it uses the default value, ``en-US``. For instance, ``locale`` is used while
455+
parsing dates and timestamps.
449456
450457
>>> df = spark.read.csv('python/test_support/sql/ages.csv')
451458
>>> df.dtypes
@@ -465,7 +472,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
465472
maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode,
466473
columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine,
467474
charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio,
468-
enforceSchema=enforceSchema, emptyValue=emptyValue)
475+
enforceSchema=enforceSchema, emptyValue=emptyValue, locale=locale)
469476
if isinstance(path, basestring):
470477
path = [path]
471478
if type(path) == list:

python/pyspark/sql/streaming.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
404404
allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None,
405405
allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None,
406406
mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None,
407-
multiLine=None, allowUnquotedControlChars=None, lineSep=None):
407+
multiLine=None, allowUnquotedControlChars=None, lineSep=None, locale=None):
408408
"""
409409
Loads a JSON file stream and returns the results as a :class:`DataFrame`.
410410
@@ -469,6 +469,9 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
469469
including tab and line feed characters) or not.
470470
:param lineSep: defines the line separator that should be used for parsing. If None is
471471
set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``.
472+
:param locale: sets a locale as language tag in IETF BCP 47 format. If None is set,
473+
it uses the default value, ``en-US``. For instance, ``locale`` is used while
474+
parsing dates and timestamps.
472475
473476
>>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema = sdf_schema)
474477
>>> json_sdf.isStreaming
@@ -483,7 +486,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
483486
allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter,
484487
mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat,
485488
timestampFormat=timestampFormat, multiLine=multiLine,
486-
allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep)
489+
allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep, locale=locale)
487490
if isinstance(path, basestring):
488491
return self._df(self._jreader.json(path))
489492
else:
@@ -564,7 +567,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
564567
negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
565568
maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None,
566569
columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None,
567-
enforceSchema=None, emptyValue=None):
570+
enforceSchema=None, emptyValue=None, locale=None):
568571
r"""Loads a CSV file stream and returns the result as a :class:`DataFrame`.
569572
570573
This function will go through the input once to determine the input schema if
@@ -660,6 +663,9 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
660663
different, ``\0`` otherwise..
661664
:param emptyValue: sets the string representation of an empty value. If None is set, it uses
662665
the default value, empty string.
666+
:param locale: sets a locale as language tag in IETF BCP 47 format. If None is set,
667+
it uses the default value, ``en-US``. For instance, ``locale`` is used while
668+
parsing dates and timestamps.
663669
664670
>>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema = sdf_schema)
665671
>>> csv_sdf.isStreaming
@@ -677,7 +683,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
677683
maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode,
678684
columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine,
679685
charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, enforceSchema=enforceSchema,
680-
emptyValue=emptyValue)
686+
emptyValue=emptyValue, locale=locale)
681687
if isinstance(path, basestring):
682688
return self._df(self._jreader.csv(path))
683689
else:

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,16 +76,19 @@ private[sql] class JSONOptions(
7676
// Whether to ignore column of all null values or empty array/struct during schema inference
7777
val dropFieldIfAllNull = parameters.get("dropFieldIfAllNull").map(_.toBoolean).getOrElse(false)
7878

79+
// A language tag in IETF BCP 47 format
80+
val locale: Locale = parameters.get("locale").map(Locale.forLanguageTag).getOrElse(Locale.US)
81+
7982
val timeZone: TimeZone = DateTimeUtils.getTimeZone(
8083
parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId))
8184

8285
// Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe.
8386
val dateFormat: FastDateFormat =
84-
FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), Locale.US)
87+
FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), locale)
8588

8689
val timestampFormat: FastDateFormat =
8790
FastDateFormat.getInstance(
88-
parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, Locale.US)
91+
parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, locale)
8992

9093
val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false)
9194

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.sql.catalyst.expressions
1919

2020
import java.text.{DecimalFormat, DecimalFormatSymbols}
21+
import java.text.SimpleDateFormat
2122
import java.util.{Calendar, Locale}
2223

2324
import org.scalatest.exceptions.TestFailedException
@@ -211,6 +212,22 @@ class CsvExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with P
211212
)
212213
}
213214

215+
test("parse date with locale") {
216+
Seq("en-US", "ru-RU").foreach { langTag =>
217+
val locale = Locale.forLanguageTag(langTag)
218+
val date = new SimpleDateFormat("yyyy-MM-dd").parse("2018-11-05")
219+
val schema = new StructType().add("d", DateType)
220+
val dateFormat = "MMM yyyy"
221+
val sdf = new SimpleDateFormat(dateFormat, locale)
222+
val dateStr = sdf.format(date)
223+
val options = Map("dateFormat" -> dateFormat, "locale" -> langTag)
224+
225+
checkEvaluation(
226+
CsvToStructs(schema, options, Literal.create(dateStr), gmtId),
227+
InternalRow(17836)) // number of days from 1970-01-01
228+
}
229+
}
230+
214231
test("parse decimals using locale") {
215232
Seq("en-US", "ko-KR", "ru-RU", "de-DE").foreach { langTag =>
216233
val schema = new StructType().add("d", DecimalType(10, 5))

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20-
import java.util.Calendar
20+
import java.text.SimpleDateFormat
21+
import java.util.{Calendar, Locale}
2122

2223
import org.scalatest.exceptions.TestFailedException
2324

@@ -737,4 +738,20 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
737738
CreateMap(Seq(Literal.create("allowNumericLeadingZeros"), Literal.create("true")))),
738739
"struct<col:bigint>")
739740
}
741+
742+
test("parse date with locale") {
743+
Seq("en-US", "ru-RU").foreach { langTag =>
744+
val locale = Locale.forLanguageTag(langTag)
745+
val date = new SimpleDateFormat("yyyy-MM-dd").parse("2018-11-05")
746+
val schema = new StructType().add("d", DateType)
747+
val dateFormat = "MMM yyyy"
748+
val sdf = new SimpleDateFormat(dateFormat, locale)
749+
val dateStr = s"""{"d":"${sdf.format(date)}"}"""
750+
val options = Map("dateFormat" -> dateFormat, "locale" -> langTag)
751+
752+
checkEvaluation(
753+
JsonToStructs(schema, options, Literal.create(dateStr), gmtId),
754+
InternalRow(17836)) // number of days from 1970-01-01
755+
}
756+
}
740757
}

0 commit comments

Comments
 (0)