diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index facc16bc53108..64f5507730a87 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -176,7 +176,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None,
allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None,
mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None,
- multiLine=None, allowUnquotedControlChars=None):
+ multiLine=None, allowUnquotedControlChars=None, charset=None):
"""
Loads JSON files and returns the results as a :class:`DataFrame`.
@@ -237,6 +237,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
:param allowUnquotedControlChars: allows JSON Strings to contain unquoted control
characters (ASCII characters with value less than 32,
including tab and line feed characters) or not.
+        :param charset: standard charset name, for example UTF-8, UTF-16 or UTF-32. If None is
+                        set, the charset of the input JSON will be detected automatically.
>>> df1 = spark.read.json('python/test_support/sql/people.json')
>>> df1.dtypes
@@ -254,7 +256,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter,
mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat,
timestampFormat=timestampFormat, multiLine=multiLine,
- allowUnquotedControlChars=allowUnquotedControlChars)
+ allowUnquotedControlChars=allowUnquotedControlChars, charset=charset)
if isinstance(path, basestring):
path = [path]
if type(path) == list:
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 480815d27333f..fc019f2d1ebeb 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -654,6 +654,13 @@ def test_multiLine_json(self):
multiLine=True)
self.assertEqual(people1.collect(), people_array.collect())
+ def test_charset_json(self):
+ people_array = self.spark.read\
+ .json("python/test_support/sql/people_array_utf16le.json",
+ multiLine=True, charset="UTF-16LE")
+ expected = [Row(age=30, name=u'Andy'), Row(age=19, name=u'Justin')]
+ self.assertEqual(people_array.collect(), expected)
+
def test_multiline_csv(self):
ages_newlines = self.spark.read.csv(
"python/test_support/sql/ages_newlines.csv", multiLine=True)
diff --git a/python/test_support/sql/people_array_utf16le.json b/python/test_support/sql/people_array_utf16le.json
new file mode 100644
index 0000000000000..9c657fa30ac9c
Binary files /dev/null and b/python/test_support/sql/people_array_utf16le.json differ
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala
index 025a388aacaa5..df393906557f3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala
@@ -39,11 +39,25 @@ private[sql] object CreateJacksonParser extends Serializable {
jsonFactory.createParser(new InputStreamReader(bain, "UTF-8"))
}
- def text(jsonFactory: JsonFactory, record: Text): JsonParser = {
- jsonFactory.createParser(record.getBytes, 0, record.getLength)
+ def text(jsonFactory: JsonFactory, record: Text, charset: Option[String] = None): JsonParser = {
+ charset match {
+ case Some(cs) =>
+ val bain = new ByteArrayInputStream(record.getBytes, 0, record.getLength)
+ jsonFactory.createParser(new InputStreamReader(bain, cs))
+ case _ =>
+ jsonFactory.createParser(record.getBytes, 0, record.getLength)
+ }
}
- def inputStream(jsonFactory: JsonFactory, record: InputStream): JsonParser = {
- jsonFactory.createParser(record)
+ def inputStream(
+ jsonFactory: JsonFactory,
+ is: InputStream,
+ charset: Option[String] = None): JsonParser = {
+ charset match {
+ case Some(cs) =>
+ jsonFactory.createParser(new InputStreamReader(is, cs))
+ case _ =>
+ jsonFactory.createParser(is)
+ }
}
}
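
For reference (not part of the patch): the explicit-charset branch above can be exercised directly against jackson-core. A minimal, self-contained sketch; the object name is illustrative and the assertions only demonstrate the expected token stream.

```scala
import java.io.{ByteArrayInputStream, InputStreamReader}

import com.fasterxml.jackson.core.{JsonFactory, JsonToken}

object ExplicitCharsetParsing {
  def main(args: Array[String]): Unit = {
    // A single JSON record encoded in UTF-16LE.
    val bytes = """{"name": "Andy", "age": 30}""".getBytes("UTF-16LE")

    // Wrapping the bytes in an InputStreamReader with an explicit charset mirrors
    // what CreateJacksonParser.text does when the charset option is set; without
    // it, Jackson falls back to its own encoding detection.
    val factory = new JsonFactory()
    val parser = factory.createParser(
      new InputStreamReader(new ByteArrayInputStream(bytes), "UTF-16LE"))

    assert(parser.nextToken() == JsonToken.START_OBJECT)
    assert(parser.nextToken() == JsonToken.FIELD_NAME)
    assert(parser.getCurrentName == "name")
    parser.close()
  }
}
```
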
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
index 652412b34478a..c261778421c12 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
@@ -85,6 +85,12 @@ private[sql] class JSONOptions(
val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false)
+ /**
+   * Standard charset name, for example UTF-8, UTF-16 or UTF-32.
+   * If the charset is not specified (None), it will be detected automatically.
+ */
+ val charset: Option[String] = parameters.get("charset")
+
/** Sets config options on a Jackson [[JsonFactory]]. */
def setJacksonOptions(factory: JsonFactory): Unit = {
factory.configure(JsonParser.Feature.ALLOW_COMMENTS, allowComments)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
index 7f6956994f31f..8ff165a1032dd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.json
-import java.io.ByteArrayOutputStream
+import java.io.{ByteArrayOutputStream, CharConversionException}
import scala.collection.mutable.ArrayBuffer
import scala.util.Try
@@ -361,6 +361,15 @@ class JacksonParser(
// For such records, all fields other than the field configured by
// `columnNameOfCorruptRecord` are set to `null`.
throw BadRecordException(() => recordLiteral(record), () => None, e)
+ case e: CharConversionException if options.charset.isEmpty =>
+ val msg =
+            """Failed to parse a character. The charset was detected automatically.
+ |You might want to set it explicitly via the charset option like:
+ | .option("charset", "UTF-8")
+              |Examples of supported charsets:
+ | UTF-8, UTF-16, UTF-16BE, UTF-16LE, UTF-32, UTF-32BE, UTF-32LE
+ |""".stripMargin + e.getMessage
+ throw new CharConversionException(msg)
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 0139913aaa4e2..cd271bbdf6183 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -366,6 +366,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* `java.text.SimpleDateFormat`. This applies to timestamp type.
*
   * `multiLine` (default `false`): parse one record, which may span multiple lines,
* per file
+   * `charset` (by default it is not set): forcibly sets one of the standard basic or extended
+   * charsets for the input JSON files, for example UTF-8, UTF-16BE or UTF-32. If the charset
+   * is not specified (the default), it is detected automatically.
*
*
* @since 2.0.0
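
For illustration (not part of the patch), a read-side sketch of the new option as it is expected to be used from spark-shell; the path and schema below are made up:

```scala
import org.apache.spark.sql.types.{StringType, StructType}

// Read a multi-line JSON file that is known to be encoded in UTF-16LE.
// Setting the charset explicitly avoids relying on automatic detection,
// which may not work for charsets without a BOM.
val schema = new StructType()
  .add("firstName", StringType)
  .add("lastName", StringType)

val people = spark.read
  .schema(schema)
  .option("multiLine", "true")
  .option("charset", "UTF-16LE")          // option introduced by this change
  .json("/path/to/people_utf16le.json")   // illustrative path
```
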
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
index 77e7edc8e7a20..913b15c09b09b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
@@ -122,8 +122,10 @@ object TextInputJsonDataSource extends JsonDataSource {
schema: StructType): Iterator[InternalRow] = {
val linesReader = new HadoopFileLinesReader(file, conf)
Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
+ val charset = parser.options.charset
+
val safeParser = new FailureSafeParser[Text](
- input => parser.parse(input, CreateJacksonParser.text, textToUTF8String),
+ input => parser.parse[Text](input, CreateJacksonParser.text(_, _, charset), textToUTF8String),
parser.options.parseMode,
schema,
parser.options.columnNameOfCorruptRecord)
@@ -146,7 +148,12 @@ object MultiLineJsonDataSource extends JsonDataSource {
parsedOptions: JSONOptions): StructType = {
val json: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths)
val sampled: RDD[PortableDataStream] = JsonUtils.sample(json, parsedOptions)
- JsonInferSchema.infer(sampled, parsedOptions, createParser)
+
+ JsonInferSchema.infer[PortableDataStream](
+ sampled,
+ parsedOptions,
+ createParser(_, _, parsedOptions.charset)
+ )
}
private def createBaseRdd(
@@ -168,11 +175,16 @@ object MultiLineJsonDataSource extends JsonDataSource {
.values
}
- private def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = {
+ private def createParser(
+ jsonFactory: JsonFactory,
+ record: PortableDataStream,
+ charset: Option[String] = None): JsonParser = {
val path = new Path(record.getPath())
CreateJacksonParser.inputStream(
jsonFactory,
- CodecStreams.createInputStreamWithCloseResource(record.getConfiguration, path))
+ CodecStreams.createInputStreamWithCloseResource(record.getConfiguration, path),
+ charset
+ )
}
override def readFile(
@@ -180,21 +192,26 @@ object MultiLineJsonDataSource extends JsonDataSource {
file: PartitionedFile,
parser: JacksonParser,
schema: StructType): Iterator[InternalRow] = {
+ def createInputStream() = {
+ CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath)))
+ }
def partitionedFileString(ignored: Any): UTF8String = {
- Utils.tryWithResource {
- CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath)))
- } { inputStream =>
- UTF8String.fromBytes(ByteStreams.toByteArray(inputStream))
+ Utils.tryWithResource(createInputStream()) { is =>
+ UTF8String.fromBytes(ByteStreams.toByteArray(is))
}
}
+ val charset = parser.options.charset
val safeParser = new FailureSafeParser[InputStream](
- input => parser.parse(input, CreateJacksonParser.inputStream, partitionedFileString),
+ input => parser.parse[InputStream](
+ input,
+ CreateJacksonParser.inputStream(_, _, charset),
+ partitionedFileString
+ ),
parser.options.parseMode,
schema,
parser.options.columnNameOfCorruptRecord)
- safeParser.parse(
- CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath))))
+ safeParser.parse(createInputStream())
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
index 0862c746fffad..8d422dd95bfff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution.datasources.json
+import java.nio.charset.{Charset, StandardCharsets}
+
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
@@ -151,7 +153,16 @@ private[json] class JsonOutputWriter(
context: TaskAttemptContext)
extends OutputWriter with Logging {
- private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path))
+ private val charset = options.charset match {
+ case Some(charsetName) => Charset.forName(charsetName)
+ case _ => StandardCharsets.UTF_8
+ }
+
+ private val writer = CodecStreams.createOutputStreamWriter(
+ context,
+ new Path(path),
+ charset
+ )
// create the Generator without separator inserted between 2 records
private[this] val gen = new JacksonGenerator(dataSchema, writer, options)
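
On the write side the same option now selects the output encoding, falling back to UTF-8 when it is not set. A minimal sketch for spark-shell (not part of the patch, output path is illustrative):

```scala
import spark.implicits._

// Write a small Dataset as JSON encoded in UTF-16BE instead of the default UTF-8.
val ds = Seq(("Dog", 42)).toDS()
ds.write
  .option("charset", "UTF-16BE")
  .mode("overwrite")
  .json("/path/to/output")  // illustrative path
```
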
diff --git a/sql/core/src/test/resources/json-tests/utf16LE.json b/sql/core/src/test/resources/json-tests/utf16LE.json
new file mode 100644
index 0000000000000..ce4117fd299df
Binary files /dev/null and b/sql/core/src/test/resources/json-tests/utf16LE.json differ
diff --git a/sql/core/src/test/resources/json-tests/utf16WithBOM.json b/sql/core/src/test/resources/json-tests/utf16WithBOM.json
new file mode 100644
index 0000000000000..65e7e2f729481
Binary files /dev/null and b/sql/core/src/test/resources/json-tests/utf16WithBOM.json differ
diff --git a/sql/core/src/test/resources/json-tests/utf32BEWithBOM.json b/sql/core/src/test/resources/json-tests/utf32BEWithBOM.json
new file mode 100644
index 0000000000000..6c7733c577872
Binary files /dev/null and b/sql/core/src/test/resources/json-tests/utf32BEWithBOM.json differ
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 8c8d41ebf115a..0b18a6948035d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -2063,4 +2063,178 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
)
}
}
+
+ def testFile(fileName: String): String = {
+ Thread.currentThread().getContextClassLoader.getResource(fileName).toString
+ }
+
+ test("json in UTF-16 with BOM") {
+ val fileName = "json-tests/utf16WithBOM.json"
+ val schema = new StructType().add("firstName", StringType).add("lastName", StringType)
+ val jsonDF = spark.read.schema(schema)
+      // The mode filters out null rows that are produced because the UTF-8
+      // newline delimiter is used by default to split the input into lines.
+ .option("mode", "DROPMALFORMED")
+ .json(testFile(fileName))
+
+ checkAnswer(jsonDF, Seq(
+ Row("Chris", "Baird"), Row("Doug", "Rood")
+ ))
+ }
+
+ test("multi-line json in UTF-32BE with BOM") {
+ val fileName = "json-tests/utf32BEWithBOM.json"
+ val schema = new StructType().add("firstName", StringType).add("lastName", StringType)
+ val jsonDF = spark.read.schema(schema)
+ .option("multiline", "true")
+ .json(testFile(fileName))
+
+ checkAnswer(jsonDF, Seq(Row("Chris", "Baird")))
+ }
+
+ test("Use user's charset in reading of multi-line json in UTF-16LE") {
+ val fileName = "json-tests/utf16LE.json"
+ val schema = new StructType().add("firstName", StringType).add("lastName", StringType)
+ val jsonDF = spark.read.schema(schema)
+ .option("multiline", "true")
+ .option("charset", "UTF-16LE")
+ .json(testFile(fileName))
+
+ checkAnswer(jsonDF, Seq(Row("Chris", "Baird")))
+ }
+
+ test("Unsupported charset name") {
+ val invalidCharset = "UTF-128"
+ val exception = intercept[SparkException] {
+ spark.read
+ .option("charset", invalidCharset)
+ .json(testFile("json-tests/utf16LE.json"))
+ .count()
+ }
+ val causedBy = exception.getCause
+
+ assert(causedBy.isInstanceOf[java.io.UnsupportedEncodingException])
+ assert(causedBy.getMessage.contains(invalidCharset))
+ }
+
+ test("checking that the charset option is case agnostic") {
+ val fileName = "json-tests/utf16LE.json"
+ val schema = new StructType().add("firstName", StringType).add("lastName", StringType)
+ val jsonDF = spark.read.schema(schema)
+ .option("multiline", "true")
+ .option("charset", "uTf-16lE")
+ .json(testFile(fileName))
+
+ checkAnswer(jsonDF, Seq(Row("Chris", "Baird")))
+ }
+
+ test("specified charset is not matched to actual charset") {
+ val fileName = "json-tests/utf16LE.json"
+ val schema = new StructType().add("firstName", StringType).add("lastName", StringType)
+ val exception = intercept[SparkException] {
+ spark.read.schema(schema)
+ .option("mode", "FAILFAST")
+ .option("multiline", "true")
+ .option("charset", "UTF-16BE")
+ .json(testFile(fileName))
+ .count()
+ }
+ val errMsg = exception.getMessage
+
+ assert(errMsg.contains("Malformed records are detected in record parsing"))
+ }
+
+ def checkCharset(
+ expectedCharset: String,
+ pathToJsonFiles: String,
+ expectedContent: String
+ ): Unit = {
+ val jsonFiles = new File(pathToJsonFiles)
+ .listFiles()
+ .filter(_.isFile)
+ .filter(_.getName.endsWith("json"))
+ val jsonContent = jsonFiles.map { file =>
+ scala.io.Source.fromFile(file, expectedCharset).mkString
+ }
+ val cleanedContent = jsonContent
+ .mkString
+ .trim
+ .replaceAll(" ", "")
+
+ assert(cleanedContent == expectedContent)
+ }
+
+ test("save json in UTF-32BE") {
+ val charset = "UTF-32BE"
+ withTempPath { path =>
+ val df = spark.createDataset(Seq(("Dog", 42)))
+ df.write
+ .option("charset", charset)
+ .format("json").mode("overwrite")
+ .save(path.getCanonicalPath)
+
+ checkCharset(
+ expectedCharset = charset,
+ pathToJsonFiles = path.getCanonicalPath,
+ expectedContent = """{"_1":"Dog","_2":42}"""
+ )
+ }
+ }
+
+ test("save json in default charset - UTF-8") {
+ withTempPath { path =>
+ val df = spark.createDataset(Seq(("Dog", 42)))
+ df.write
+ .format("json").mode("overwrite")
+ .save(path.getCanonicalPath)
+
+ checkCharset(
+ expectedCharset = "UTF-8",
+ pathToJsonFiles = path.getCanonicalPath,
+ expectedContent = """{"_1":"Dog","_2":42}"""
+ )
+ }
+ }
+
+ test("wrong output charset") {
+ val charset = "UTF-128"
+ val exception = intercept[SparkException] {
+ withTempPath { path =>
+ val df = spark.createDataset(Seq((0)))
+ df.write
+ .option("charset", charset)
+ .format("json").mode("overwrite")
+ .save(path.getCanonicalPath)
+ }
+ }
+ val causedBy = exception.getCause.getCause.getCause
+
+ assert(causedBy.isInstanceOf[java.nio.charset.UnsupportedCharsetException])
+ assert(causedBy.getMessage == charset)
+ }
+
+ test("read written json in UTF-16") {
+ val charset = "UTF-16"
+ case class Rec(f1: String, f2: Int)
+ withTempPath { path =>
+ val ds = spark.createDataset(Seq(
+ ("a", 1), ("b", 2), ("c", 3))
+ ).repartition(2)
+ ds.write
+ .option("charset", charset)
+ .format("json").mode("overwrite")
+ .save(path.getCanonicalPath)
+ val savedDf = spark
+ .read
+ .schema(ds.schema)
+ .option("charset", charset)
+        // Malformed (null) rows are produced because the UTF-8 newline
+        // delimiter is used by default to split the input into lines.
+ .option("mode", "DROPMALFORMED")
+ .json(path.getCanonicalPath)
+
+ checkAnswer(savedDf.toDF(), ds.toDF())
+ }
+ }
}