diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index facc16bc53108..64f5507730a87 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -176,7 +176,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - multiLine=None, allowUnquotedControlChars=None): + multiLine=None, allowUnquotedControlChars=None, charset=None): """ Loads JSON files and returns the results as a :class:`DataFrame`. @@ -237,6 +237,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param allowUnquotedControlChars: allows JSON Strings to contain unquoted control characters (ASCII characters with value less than 32, including tab and line feed characters) or not. + :param charset: standard charset name, for example UTF-8, UTF-16 and UTF-32. If None is + set, the charset of input json will be detected automatically. >>> df1 = spark.read.json('python/test_support/sql/people.json') >>> df1.dtypes @@ -254,7 +256,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, timestampFormat=timestampFormat, multiLine=multiLine, - allowUnquotedControlChars=allowUnquotedControlChars) + allowUnquotedControlChars=allowUnquotedControlChars, charset=charset) if isinstance(path, basestring): path = [path] if type(path) == list: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 480815d27333f..fc019f2d1ebeb 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -654,6 +654,13 @@ def test_multiLine_json(self): multiLine=True) self.assertEqual(people1.collect(), people_array.collect()) + def test_charset_json(self): + people_array = self.spark.read\ + .json("python/test_support/sql/people_array_utf16le.json", + multiLine=True, charset="UTF-16LE") + expected = [Row(age=30, name=u'Andy'), Row(age=19, name=u'Justin')] + self.assertEqual(people_array.collect(), expected) + def test_multiline_csv(self): ages_newlines = self.spark.read.csv( "python/test_support/sql/ages_newlines.csv", multiLine=True) diff --git a/python/test_support/sql/people_array_utf16le.json b/python/test_support/sql/people_array_utf16le.json new file mode 100644 index 0000000000000..9c657fa30ac9c Binary files /dev/null and b/python/test_support/sql/people_array_utf16le.json differ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala index 025a388aacaa5..df393906557f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala @@ -39,11 +39,25 @@ private[sql] object CreateJacksonParser extends Serializable { jsonFactory.createParser(new InputStreamReader(bain, "UTF-8")) } - def text(jsonFactory: JsonFactory, record: Text): JsonParser = { - jsonFactory.createParser(record.getBytes, 0, record.getLength) + def text(jsonFactory: JsonFactory, record: Text, charset: Option[String] = None): JsonParser = { + charset match { + case 
Some(cs) => + val bain = new ByteArrayInputStream(record.getBytes, 0, record.getLength) + jsonFactory.createParser(new InputStreamReader(bain, cs)) + case _ => + jsonFactory.createParser(record.getBytes, 0, record.getLength) + } } - def inputStream(jsonFactory: JsonFactory, record: InputStream): JsonParser = { - jsonFactory.createParser(record) + def inputStream( + jsonFactory: JsonFactory, + is: InputStream, + charset: Option[String] = None): JsonParser = { + charset match { + case Some(cs) => + jsonFactory.createParser(new InputStreamReader(is, cs)) + case _ => + jsonFactory.createParser(is) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 652412b34478a..c261778421c12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -85,6 +85,12 @@ private[sql] class JSONOptions( val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false) + /** + * Standard charset name. For example UTF-8, UTF-16 and UTF-32. + * If charset is not specified (None), it will be detected automatically. + */ + val charset: Option[String] = parameters.get("charset") + /** Sets config options on a Jackson [[JsonFactory]]. */ def setJacksonOptions(factory: JsonFactory): Unit = { factory.configure(JsonParser.Feature.ALLOW_COMMENTS, allowComments) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 7f6956994f31f..8ff165a1032dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.json -import java.io.ByteArrayOutputStream +import java.io.{ByteArrayOutputStream, CharConversionException} import scala.collection.mutable.ArrayBuffer import scala.util.Try @@ -361,6 +361,15 @@ class JacksonParser( // For such records, all fields other than the field configured by // `columnNameOfCorruptRecord` are set to `null`. throw BadRecordException(() => recordLiteral(record), () => None, e) + case e: CharConversionException if options.charset.isEmpty => + val msg = + """Failed to parse a character. Charset was detected automatically. + |You might want to set it explicitly via the charset option like: + | .option("charset", "UTF-8") + |Example of supported charsets: + | UTF-8, UTF-16, UTF-16BE, UTF-16LE, UTF-32, UTF-32BE, UTF-32LE + |""".stripMargin + e.getMessage + throw new CharConversionException(msg) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 0139913aaa4e2..cd271bbdf6183 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -366,6 +366,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * `java.text.SimpleDateFormat`. This applies to timestamp type. *
  • `multiLine` (default `false`): parse one record, which may span multiple lines,
  * per file
+ • `charset` (by default it is not set): allows forcibly setting one of the standard basic
+ * or extended charsets for the input JSON files, for example UTF-8, UTF-16BE or UTF-32. If the
+ * charset is not specified (the default), it is detected automatically.
  • * * * @since 2.0.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 77e7edc8e7a20..913b15c09b09b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -122,8 +122,10 @@ object TextInputJsonDataSource extends JsonDataSource { schema: StructType): Iterator[InternalRow] = { val linesReader = new HadoopFileLinesReader(file, conf) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) + val charset = parser.options.charset + val safeParser = new FailureSafeParser[Text]( - input => parser.parse(input, CreateJacksonParser.text, textToUTF8String), + input => parser.parse[Text](input, CreateJacksonParser.text(_, _, charset), textToUTF8String), parser.options.parseMode, schema, parser.options.columnNameOfCorruptRecord) @@ -146,7 +148,12 @@ object MultiLineJsonDataSource extends JsonDataSource { parsedOptions: JSONOptions): StructType = { val json: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths) val sampled: RDD[PortableDataStream] = JsonUtils.sample(json, parsedOptions) - JsonInferSchema.infer(sampled, parsedOptions, createParser) + + JsonInferSchema.infer[PortableDataStream]( + sampled, + parsedOptions, + createParser(_, _, parsedOptions.charset) + ) } private def createBaseRdd( @@ -168,11 +175,16 @@ object MultiLineJsonDataSource extends JsonDataSource { .values } - private def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = { + private def createParser( + jsonFactory: JsonFactory, + record: PortableDataStream, + charset: Option[String] = None): JsonParser = { val path = new Path(record.getPath()) CreateJacksonParser.inputStream( jsonFactory, - CodecStreams.createInputStreamWithCloseResource(record.getConfiguration, path)) + CodecStreams.createInputStreamWithCloseResource(record.getConfiguration, path), + charset + ) } override def readFile( @@ -180,21 +192,26 @@ object MultiLineJsonDataSource extends JsonDataSource { file: PartitionedFile, parser: JacksonParser, schema: StructType): Iterator[InternalRow] = { + def createInputStream() = { + CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath))) + } def partitionedFileString(ignored: Any): UTF8String = { - Utils.tryWithResource { - CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath))) - } { inputStream => - UTF8String.fromBytes(ByteStreams.toByteArray(inputStream)) + Utils.tryWithResource(createInputStream()) { is => + UTF8String.fromBytes(ByteStreams.toByteArray(is)) } } + val charset = parser.options.charset val safeParser = new FailureSafeParser[InputStream]( - input => parser.parse(input, CreateJacksonParser.inputStream, partitionedFileString), + input => parser.parse[InputStream]( + input, + CreateJacksonParser.inputStream(_, _, charset), + partitionedFileString + ), parser.options.parseMode, schema, parser.options.columnNameOfCorruptRecord) - safeParser.parse( - CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath)))) + safeParser.parse(createInputStream()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 0862c746fffad..8d422dd95bfff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.json +import java.nio.charset.{Charset, StandardCharsets} + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} @@ -151,7 +153,16 @@ private[json] class JsonOutputWriter( context: TaskAttemptContext) extends OutputWriter with Logging { - private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path)) + private val charset = options.charset match { + case Some(charsetName) => Charset.forName(charsetName) + case _ => StandardCharsets.UTF_8 + } + + private val writer = CodecStreams.createOutputStreamWriter( + context, + new Path(path), + charset + ) // create the Generator without separator inserted between 2 records private[this] val gen = new JacksonGenerator(dataSchema, writer, options) diff --git a/sql/core/src/test/resources/json-tests/utf16LE.json b/sql/core/src/test/resources/json-tests/utf16LE.json new file mode 100644 index 0000000000000..ce4117fd299df Binary files /dev/null and b/sql/core/src/test/resources/json-tests/utf16LE.json differ diff --git a/sql/core/src/test/resources/json-tests/utf16WithBOM.json b/sql/core/src/test/resources/json-tests/utf16WithBOM.json new file mode 100644 index 0000000000000..65e7e2f729481 Binary files /dev/null and b/sql/core/src/test/resources/json-tests/utf16WithBOM.json differ diff --git a/sql/core/src/test/resources/json-tests/utf32BEWithBOM.json b/sql/core/src/test/resources/json-tests/utf32BEWithBOM.json new file mode 100644 index 0000000000000..6c7733c577872 Binary files /dev/null and b/sql/core/src/test/resources/json-tests/utf32BEWithBOM.json differ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 8c8d41ebf115a..0b18a6948035d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -2063,4 +2063,178 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) } } + + def testFile(fileName: String): String = { + Thread.currentThread().getContextClassLoader.getResource(fileName).toString + } + + test("json in UTF-16 with BOM") { + val fileName = "json-tests/utf16WithBOM.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val jsonDF = spark.read.schema(schema) + // The mode filters null rows produced because new line delimiter + // for UTF-8 is used by default. 
+ .option("mode", "DROPMALFORMED") + .json(testFile(fileName)) + + checkAnswer(jsonDF, Seq( + Row("Chris", "Baird"), Row("Doug", "Rood") + )) + } + + test("multi-line json in UTF-32BE with BOM") { + val fileName = "json-tests/utf32BEWithBOM.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val jsonDF = spark.read.schema(schema) + .option("multiline", "true") + .json(testFile(fileName)) + + checkAnswer(jsonDF, Seq(Row("Chris", "Baird"))) + } + + test("Use user's charset in reading of multi-line json in UTF-16LE") { + val fileName = "json-tests/utf16LE.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val jsonDF = spark.read.schema(schema) + .option("multiline", "true") + .option("charset", "UTF-16LE") + .json(testFile(fileName)) + + checkAnswer(jsonDF, Seq(Row("Chris", "Baird"))) + } + + test("Unsupported charset name") { + val invalidCharset = "UTF-128" + val exception = intercept[SparkException] { + spark.read + .option("charset", invalidCharset) + .json(testFile("json-tests/utf16LE.json")) + .count() + } + val causedBy = exception.getCause + + assert(causedBy.isInstanceOf[java.io.UnsupportedEncodingException]) + assert(causedBy.getMessage.contains(invalidCharset)) + } + + test("checking that the charset option is case agnostic") { + val fileName = "json-tests/utf16LE.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val jsonDF = spark.read.schema(schema) + .option("multiline", "true") + .option("charset", "uTf-16lE") + .json(testFile(fileName)) + + checkAnswer(jsonDF, Seq(Row("Chris", "Baird"))) + } + + + test("specified charset is not matched to actual charset") { + val fileName = "json-tests/utf16LE.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val exception = intercept[SparkException] { + spark.read.schema(schema) + .option("mode", "FAILFAST") + .option("multiline", "true") + .option("charset", "UTF-16BE") + .json(testFile(fileName)) + .count() + } + val errMsg = exception.getMessage + + assert(errMsg.contains("Malformed records are detected in record parsing")) + } + + def checkCharset( + expectedCharset: String, + pathToJsonFiles: String, + expectedContent: String + ): Unit = { + val jsonFiles = new File(pathToJsonFiles) + .listFiles() + .filter(_.isFile) + .filter(_.getName.endsWith("json")) + val jsonContent = jsonFiles.map { file => + scala.io.Source.fromFile(file, expectedCharset).mkString + } + val cleanedContent = jsonContent + .mkString + .trim + .replaceAll(" ", "") + + assert(cleanedContent == expectedContent) + } + + test("save json in UTF-32BE") { + val charset = "UTF-32BE" + withTempPath { path => + val df = spark.createDataset(Seq(("Dog", 42))) + df.write + .option("charset", charset) + .format("json").mode("overwrite") + .save(path.getCanonicalPath) + + checkCharset( + expectedCharset = charset, + pathToJsonFiles = path.getCanonicalPath, + expectedContent = """{"_1":"Dog","_2":42}""" + ) + } + } + + test("save json in default charset - UTF-8") { + withTempPath { path => + val df = spark.createDataset(Seq(("Dog", 42))) + df.write + .format("json").mode("overwrite") + .save(path.getCanonicalPath) + + checkCharset( + expectedCharset = "UTF-8", + pathToJsonFiles = path.getCanonicalPath, + expectedContent = """{"_1":"Dog","_2":42}""" + ) + } + } + + test("wrong output charset") { + val charset = "UTF-128" + val exception = intercept[SparkException] { + withTempPath { path => + val df = 
spark.createDataset(Seq((0))) + df.write + .option("charset", charset) + .format("json").mode("overwrite") + .save(path.getCanonicalPath) + } + } + val causedBy = exception.getCause.getCause.getCause + + assert(causedBy.isInstanceOf[java.nio.charset.UnsupportedCharsetException]) + assert(causedBy.getMessage == charset) + } + + test("read written json in UTF-16") { + val charset = "UTF-16" + case class Rec(f1: String, f2: Int) + withTempPath { path => + val ds = spark.createDataset(Seq( + ("a", 1), ("b", 2), ("c", 3)) + ).repartition(2) + ds.write + .option("charset", charset) + .format("json").mode("overwrite") + .save(path.getCanonicalPath) + val savedDf = spark + .read + .schema(ds.schema) + .option("charset", charset) + // Wrong (nulls) rows are produced because new line delimiter + // for UTF-8 is used by default. + .option("mode", "DROPMALFORMED") + .json(path.getCanonicalPath) + + checkAnswer(savedDf.toDF(), ds.toDF()) + } + } }
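A minimal usage sketch of the `charset` option added by this patch, mirroring the tests above; the file paths and the `schema` value are placeholders, not files shipped with the patch:

import org.apache.spark.sql.types.{StringType, StructType}

// Placeholder schema; any schema matching the input records works.
val schema = new StructType().add("firstName", StringType).add("lastName", StringType)

// Read a multi-line JSON file encoded in UTF-16LE, forcing the charset
// instead of relying on auto-detection.
val df = spark.read
  .schema(schema)
  .option("multiline", "true")
  .option("charset", "UTF-16LE")
  .json("/path/to/utf16LE.json")   // placeholder path

// Write JSON in an explicit charset; without the option the writer stays on UTF-8.
df.write
  .option("charset", "UTF-32BE")
  .format("json")
  .mode("overwrite")
  .save("/path/to/output")         // placeholder path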