diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 3800d53c02f4..87b9e8eb445a 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -147,7 +147,13 @@ public void writeTo(ByteBuffer buffer) {
buffer.position(pos + numBytes);
}
- public void writeTo(OutputStream out) throws IOException {
+ /**
+ * Returns a {@link ByteBuffer} wrapping the base object if it is a byte array
+ * or a copy of the data if the base object is not a byte array.
+ *
+ * Unlike getBytes this will not create a copy of the array if this is a slice.
+ */
+ public @Nonnull ByteBuffer getByteBuffer() {
if (base instanceof byte[] && offset >= BYTE_ARRAY_OFFSET) {
final byte[] bytes = (byte[]) base;
@@ -160,12 +166,20 @@ public void writeTo(OutputStream out) throws IOException {
throw new ArrayIndexOutOfBoundsException();
}
- out.write(bytes, (int) arrayOffset, numBytes);
+ return ByteBuffer.wrap(bytes, (int) arrayOffset, numBytes);
} else {
- out.write(getBytes());
+ return ByteBuffer.wrap(getBytes());
}
}
+ public void writeTo(OutputStream out) throws IOException {
+ final ByteBuffer bb = this.getByteBuffer();
+ assert(bb.hasArray());
+
+ // similar to Utils.writeByteBuffer but without the spark-core dependency
+ out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining());
+ }
+
/**
* Returns the number of bytes for a code point with the first byte as `b`
* @param b The first byte of a code point
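
Note on the new accessor: getByteBuffer wraps the existing byte[] (honoring the slice offset) rather than copying it, which is what lets CreateJacksonParser.utf8String later in this patch feed Jackson without an extra allocation. A minimal spark-shell style sketch, assuming a UTF8String built with fromString (always backed by an on-heap byte array); the JSON literal is illustrative only:

    import com.fasterxml.jackson.core.JsonFactory
    import org.apache.spark.unsafe.types.UTF8String

    val factory = new JsonFactory()
    val json = UTF8String.fromString("""{"name": "Andy"}""")

    val bb = json.getByteBuffer   // wraps the backing array; no copy for byte[]-based strings
    assert(bb.hasArray)

    // Hand the backing array straight to Jackson, mirroring CreateJacksonParser.utf8String
    val parser = factory.createParser(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
    assert(parser.nextToken() != null)
    parser.close()
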
diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala
index 59404e08895a..9606c4754314 100644
--- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala
+++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala
@@ -29,6 +29,7 @@ import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat, CombineFil
import org.apache.spark.internal.config
import org.apache.spark.SparkContext
+import org.apache.spark.annotation.Since
/**
* A general format for reading whole files in as streams, byte arrays,
@@ -175,6 +176,7 @@ class PortableDataStream(
* Create a new DataInputStream from the split and context. The user of this method is responsible
* for closing the stream after usage.
*/
+ @Since("1.2.0")
def open(): DataInputStream = {
val pathp = split.getPath(index)
val fs = pathp.getFileSystem(conf)
@@ -184,6 +186,7 @@ class PortableDataStream(
/**
* Read the file as a byte array
*/
+ @Since("1.2.0")
def toArray(): Array[Byte] = {
val stream = open()
try {
@@ -193,6 +196,10 @@ class PortableDataStream(
}
}
+ @Since("1.2.0")
def getPath(): String = path
+
+ @Since("2.2.0")
+ def getConfiguration: Configuration = conf
}
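
The @Since annotations document the existing open/toArray/getPath surface, while getConfiguration is the new 2.2.0 accessor that the whole-file JSON source later in this patch uses to reopen the split with the right Hadoop settings. A hedged usage sketch, assuming an active SparkSession named spark and an illustrative input directory:

    val sc = spark.sparkContext

    // binaryFiles exposes each input file as a (path, PortableDataStream) pair
    sc.binaryFiles("/path/to/json/dir").take(1).foreach { case (path, stream) =>
      val hadoopConf = stream.getConfiguration   // new in 2.2.0
      val bytes = stream.toArray()               // opens and closes the stream internally
      println(s"${stream.getPath()} has ${bytes.length} bytes (fs: ${hadoopConf.get("fs.defaultFS")})")
    }
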
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 167833488980..6bed390e60c9 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -159,11 +159,12 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None,
allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None,
mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None,
- timeZone=None):
+ timeZone=None, wholeFile=None):
"""
- Loads a JSON file (`JSON Lines text format or newline-delimited JSON
- <http://jsonlines.org/>`_) or an RDD of Strings storing JSON objects (one object per
- record) and returns the result as a :class`DataFrame`.
+ Loads a JSON file and returns the results as a :class:`DataFrame`.
+
+ Both JSON (one record per file) and `JSON Lines <http://jsonlines.org/>`_
+ (newline-delimited JSON) are supported and can be selected with the `wholeFile` parameter.
If the ``schema`` parameter is not specified, this function goes
through the input once to determine the input schema.
@@ -212,6 +213,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``.
:param timeZone: sets the string that indicates a timezone to be used to parse timestamps.
If None is set, it uses the default value, session local timezone.
+ :param wholeFile: parse one record, which may span multiple lines, per file. If None is
+ set, it uses the default value, ``false``.
>>> df1 = spark.read.json('python/test_support/sql/people.json')
>>> df1.dtypes
@@ -228,7 +231,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero,
allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter,
mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat,
- timestampFormat=timestampFormat, timeZone=timeZone)
+ timestampFormat=timestampFormat, timeZone=timeZone, wholeFile=wholeFile)
if isinstance(path, basestring):
path = [path]
if type(path) == list:
diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py
index d988e596a86d..965c8f6b269e 100644
--- a/python/pyspark/sql/streaming.py
+++ b/python/pyspark/sql/streaming.py
@@ -428,11 +428,13 @@ def load(self, path=None, format=None, schema=None, **options):
def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None,
allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None,
- mode=None, columnNameOfCorruptRecord=None, dateFormat=None,
- timestampFormat=None, timeZone=None):
+ mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None,
+ timeZone=None, wholeFile=None):
"""
- Loads a JSON file stream (`JSON Lines text format or newline-delimited JSON
- <http://jsonlines.org/>`_) and returns a :class`DataFrame`.
+ Loads a JSON file stream and returns the results as a :class:`DataFrame`.
+
+ Both JSON (one record per file) and `JSON Lines <http://jsonlines.org/>`_
+ (newline-delimited JSON) are supported and can be selected with the `wholeFile` parameter.
If the ``schema`` parameter is not specified, this function goes
through the input once to determine the input schema.
@@ -483,6 +485,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``.
:param timeZone: sets the string that indicates a timezone to be used to parse timestamps.
If None is set, it uses the default value, session local timezone.
+ :param wholeFile: parse one record, which may span multiple lines, per file. If None is
+ set, it uses the default value, ``false``.
>>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema = sdf_schema)
>>> json_sdf.isStreaming
@@ -496,7 +500,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero,
allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter,
mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat,
- timestampFormat=timestampFormat, timeZone=timeZone)
+ timestampFormat=timestampFormat, timeZone=timeZone, wholeFile=wholeFile)
if isinstance(path, basestring):
return self._df(self._jreader.json(path))
else:
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index d8b7b3137c1c..9058443285ac 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -439,6 +439,13 @@ def test_udf_with_order_by_and_limit(self):
res.explain(True)
self.assertEqual(res.collect(), [Row(id=0, copy=0)])
+ def test_wholefile_json(self):
+ from pyspark.sql.types import StringType
+ people1 = self.spark.read.json("python/test_support/sql/people.json")
+ people_array = self.spark.read.json("python/test_support/sql/people_array.json",
+ wholeFile=True)
+ self.assertEqual(people1.collect(), people_array.collect())
+
def test_udf_with_input_file_name(self):
from pyspark.sql.functions import udf, input_file_name
from pyspark.sql.types import StringType
diff --git a/python/test_support/sql/people_array.json b/python/test_support/sql/people_array.json
new file mode 100644
index 000000000000..c27c48fe343e
--- /dev/null
+++ b/python/test_support/sql/people_array.json
@@ -0,0 +1,13 @@
+[
+ {
+ "name": "Michael"
+ },
+ {
+ "name": "Andy",
+ "age": 30
+ },
+ {
+ "name": "Justin",
+ "age": 19
+ }
+]
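
people_array.json is a single JSON array spread over multiple lines; with wholeFile enabled it yields the same rows as the line-delimited people.json, which is exactly what the new test_wholefile_json test asserts. The equivalent read in Scala (a sketch; the relative path assumes Spark's source tree layout):

    val peopleArray = spark.read
      .option("wholeFile", true)
      .json("python/test_support/sql/people_array.json")

    peopleArray.show()
    // Expected output, roughly:
    // +----+-------+
    // | age|   name|
    // +----+-------+
    // |null|Michael|
    // |  30|   Andy|
    // |  19| Justin|
    // +----+-------+
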
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
index bd852a50fe71..1e690a446951 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -497,8 +497,7 @@ case class JsonToStruct(
lazy val parser =
new JacksonParser(
schema,
- "invalid", // Not used since we force fail fast. Invalid rows will be set to `null`.
- new JSONOptions(options ++ Map("mode" -> ParseModes.FAIL_FAST_MODE), timeZoneId.get))
+ new JSONOptions(options + ("mode" -> ParseModes.FAIL_FAST_MODE), timeZoneId.get))
override def dataType: DataType = schema
@@ -506,7 +505,12 @@ case class JsonToStruct(
copy(timeZoneId = Option(timeZoneId))
override def nullSafeEval(json: Any): Any = {
- try parser.parse(json.toString).headOption.orNull catch {
+ try {
+ parser.parse(
+ json.asInstanceOf[UTF8String],
+ CreateJacksonParser.utf8String,
+ identity[UTF8String]).headOption.orNull
+ } catch {
case _: SparkSQLJsonProcessingException => null
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala
new file mode 100644
index 000000000000..e0ed03a68981
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.json
+
+import java.io.InputStream
+
+import com.fasterxml.jackson.core.{JsonFactory, JsonParser}
+import org.apache.hadoop.io.Text
+
+import org.apache.spark.unsafe.types.UTF8String
+
+private[sql] object CreateJacksonParser extends Serializable {
+ def string(jsonFactory: JsonFactory, record: String): JsonParser = {
+ jsonFactory.createParser(record)
+ }
+
+ def utf8String(jsonFactory: JsonFactory, record: UTF8String): JsonParser = {
+ val bb = record.getByteBuffer
+ assert(bb.hasArray)
+
+ jsonFactory.createParser(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
+ }
+
+ def text(jsonFactory: JsonFactory, record: Text): JsonParser = {
+ jsonFactory.createParser(record.getBytes, 0, record.getLength)
+ }
+
+ def inputStream(jsonFactory: JsonFactory, record: InputStream): JsonParser = {
+ jsonFactory.createParser(record)
+ }
+}
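
Each factory method here pairs one input representation (String, UTF8String, Text, InputStream) with the matching JsonFactory.createParser overload, so JacksonParser.parse can stay generic over the record type. A sketch of the plumbing as JsonToStruct now uses it; since these helpers are private[sql], assume the snippet runs inside Spark's own sql packages (as the test changes below do), and the sample JSON is illustrative:

    import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
    import org.apache.spark.sql.types._
    import org.apache.spark.unsafe.types.UTF8String

    val schema = new StructType().add("a", IntegerType)
    val options = new JSONOptions(Map.empty[String, String], "GMT")
    val parser = new JacksonParser(schema, options)

    val rows = parser.parse(
      UTF8String.fromString("""{"a": 1}"""),
      CreateJacksonParser.utf8String,    // (JsonFactory, UTF8String) => JsonParser
      identity[UTF8String])              // recordLiteral: only materialized for malformed input
    assert(rows.head.getInt(0) == 1)
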
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
index 5307ce1cb711..5a91f9c1939a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
@@ -31,11 +31,20 @@ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs
* Most of these map directly to Jackson's internal options, specified in [[JsonParser.Feature]].
*/
private[sql] class JSONOptions(
- @transient private val parameters: CaseInsensitiveMap[String], defaultTimeZoneId: String)
+ @transient private val parameters: CaseInsensitiveMap[String],
+ defaultTimeZoneId: String,
+ defaultColumnNameOfCorruptRecord: String)
extends Logging with Serializable {
- def this(parameters: Map[String, String], defaultTimeZoneId: String) =
- this(CaseInsensitiveMap(parameters), defaultTimeZoneId)
+ def this(
+ parameters: Map[String, String],
+ defaultTimeZoneId: String,
+ defaultColumnNameOfCorruptRecord: String = "") = {
+ this(
+ CaseInsensitiveMap(parameters),
+ defaultTimeZoneId,
+ defaultColumnNameOfCorruptRecord)
+ }
val samplingRatio =
parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
@@ -57,7 +66,8 @@ private[sql] class JSONOptions(
parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false)
val compressionCodec = parameters.get("compression").map(CompressionCodecs.getCodecClassName)
private val parseMode = parameters.getOrElse("mode", "PERMISSIVE")
- val columnNameOfCorruptRecord = parameters.get("columnNameOfCorruptRecord")
+ val columnNameOfCorruptRecord =
+ parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord)
val timeZone: TimeZone = TimeZone.getTimeZone(parameters.getOrElse("timeZone", defaultTimeZoneId))
@@ -69,6 +79,8 @@ private[sql] class JSONOptions(
FastDateFormat.getInstance(
parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"), timeZone, Locale.US)
+ val wholeFile = parameters.get("wholeFile").map(_.toBoolean).getOrElse(false)
+
// Parse mode flags
if (!ParseModes.isValidMode(parseMode)) {
logWarning(s"$parseMode is not a valid parse mode. Using ${ParseModes.DEFAULT}.")
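
With the extra defaultColumnNameOfCorruptRecord constructor argument, the session-level default now travels inside JSONOptions instead of being threaded separately into JacksonParser; an explicit per-reader columnNameOfCorruptRecord option still wins. A small sketch of the resolution rules (again private[sql], so assume Spark-internal code; the values are illustrative):

    import org.apache.spark.sql.catalyst.json.JSONOptions

    val opts = new JSONOptions(
      Map("wholeFile" -> "true", "mode" -> "DROPMALFORMED"),
      "UTC",               // normally sessionState.conf.sessionLocalTimeZone
      "_corrupt_record")   // normally sessionState.conf.columnNameOfCorruptRecord

    assert(opts.wholeFile)                                      // new flag, defaults to false
    assert(opts.columnNameOfCorruptRecord == "_corrupt_record") // no explicit option, so the default applies
    assert(opts.dropMalformed)
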
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
index 03e27ba934fb..995095969d7a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
@@ -39,7 +39,6 @@ private[sql] class SparkSQLJsonProcessingException(msg: String) extends RuntimeE
*/
class JacksonParser(
schema: StructType,
- columnNameOfCorruptRecord: String,
options: JSONOptions) extends Logging {
import JacksonUtils._
@@ -48,69 +47,110 @@ class JacksonParser(
// A `ValueConverter` is responsible for converting a value from `JsonParser`
// to a value in a field for `InternalRow`.
- private type ValueConverter = (JsonParser) => Any
+ private type ValueConverter = JsonParser => AnyRef
// `ValueConverter`s for the root schema for all fields in the schema
- private val rootConverter: ValueConverter = makeRootConverter(schema)
+ private val rootConverter = makeRootConverter(schema)
private val factory = new JsonFactory()
options.setJacksonOptions(factory)
private val emptyRow: Seq[InternalRow] = Seq(new GenericInternalRow(schema.length))
+ private val corruptFieldIndex = schema.getFieldIndex(options.columnNameOfCorruptRecord)
+ corruptFieldIndex.foreach(idx => require(schema(idx).dataType == StringType))
+
+ @transient
+ private[this] var isWarningPrinted: Boolean = false
+
@transient
- private[this] var isWarningPrintedForMalformedRecord: Boolean = false
+ private def printWarningForMalformedRecord(record: () => UTF8String): Unit = {
+ def sampleRecord: String = {
+ if (options.wholeFile) {
+ ""
+ } else {
+ s"Sample record: ${record()}\n"
+ }
+ }
+
+ def footer: String = {
+ s"""Code example to print all malformed records (scala):
+ |===================================================
+ |// The corrupted record exists in column ${options.columnNameOfCorruptRecord}.
+ |val parsedJson = spark.read.json("/path/to/json/file/test.json")
+ |
+ """.stripMargin
+ }
+
+ if (options.permissive) {
+ logWarning(
+ s"""Found at least one malformed record. The JSON reader will replace
+ |all malformed records with placeholder null in current $PERMISSIVE_MODE parser mode.
+ |To find out which corrupted records have been replaced with null, please use the
+ |default inferred schema instead of providing a custom schema.
+ |
+ |${sampleRecord ++ footer}
+ |
+ """.stripMargin)
+ } else if (options.dropMalformed) {
+ logWarning(
+ s"""Found at least one malformed record. The JSON reader will drop
+ |all malformed records in current $DROP_MALFORMED_MODE parser mode. To find out which
+ |corrupted records have been dropped, please switch the parser mode to $PERMISSIVE_MODE
+ |mode and use the default inferred schema.
+ |
+ |${sampleRecord ++ footer}
+ |
+ """.stripMargin)
+ }
+ }
+
+ @transient
+ private def printWarningIfWholeFile(): Unit = {
+ if (options.wholeFile && corruptFieldIndex.isDefined) {
+ logWarning(
+ s"""Enabling wholeFile mode and defining columnNameOfCorruptRecord may result
+ |in very large allocations or OutOfMemoryExceptions being raised.
+ |
+ """.stripMargin)
+ }
+ }
/**
* This function deals with the cases it fails to parse. This function will be called
* when exceptions are caught during converting. This functions also deals with `mode` option.
*/
- private def failedRecord(record: String): Seq[InternalRow] = {
- // create a row even if no corrupt record column is present
- if (options.failFast) {
- throw new SparkSQLJsonProcessingException(s"Malformed line in FAILFAST mode: $record")
- }
- if (options.dropMalformed) {
- if (!isWarningPrintedForMalformedRecord) {
- logWarning(
- s"""Found at least one malformed records (sample: $record). The JSON reader will drop
- |all malformed records in current $DROP_MALFORMED_MODE parser mode. To find out which
- |corrupted records have been dropped, please switch the parser mode to $PERMISSIVE_MODE
- |mode and use the default inferred schema.
- |
- |Code example to print all malformed records (scala):
- |===================================================
- |// The corrupted record exists in column ${columnNameOfCorruptRecord}
- |val parsedJson = spark.read.json("/path/to/json/file/test.json")
- |
- """.stripMargin)
- isWarningPrintedForMalformedRecord = true
- }
- Nil
- } else if (schema.getFieldIndex(columnNameOfCorruptRecord).isEmpty) {
- if (!isWarningPrintedForMalformedRecord) {
- logWarning(
- s"""Found at least one malformed records (sample: $record). The JSON reader will replace
- |all malformed records with placeholder null in current $PERMISSIVE_MODE parser mode.
- |To find out which corrupted records have been replaced with null, please use the
- |default inferred schema instead of providing a custom schema.
- |
- |Code example to print all malformed records (scala):
- |===================================================
- |// The corrupted record exists in column ${columnNameOfCorruptRecord}.
- |val parsedJson = spark.read.json("/path/to/json/file/test.json")
- |
- """.stripMargin)
- isWarningPrintedForMalformedRecord = true
- }
- emptyRow
- } else {
- val row = new GenericInternalRow(schema.length)
- for (corruptIndex <- schema.getFieldIndex(columnNameOfCorruptRecord)) {
- require(schema(corruptIndex).dataType == StringType)
- row.update(corruptIndex, UTF8String.fromString(record))
- }
- Seq(row)
+ private def failedRecord(record: () => UTF8String): Seq[InternalRow] = {
+ corruptFieldIndex match {
+ case _ if options.failFast =>
+ if (options.wholeFile) {
+ throw new SparkSQLJsonProcessingException("Malformed line in FAILFAST mode")
+ } else {
+ throw new SparkSQLJsonProcessingException(s"Malformed line in FAILFAST mode: ${record()}")
+ }
+
+ case _ if options.dropMalformed =>
+ if (!isWarningPrinted) {
+ printWarningForMalformedRecord(record)
+ isWarningPrinted = true
+ }
+ Nil
+
+ case None =>
+ if (!isWarningPrinted) {
+ printWarningForMalformedRecord(record)
+ isWarningPrinted = true
+ }
+ emptyRow
+
+ case Some(corruptIndex) =>
+ if (!isWarningPrinted) {
+ printWarningIfWholeFile()
+ isWarningPrinted = true
+ }
+ val row = new GenericInternalRow(schema.length)
+ row.update(corruptIndex, record())
+ Seq(row)
}
}
@@ -119,11 +159,11 @@ class JacksonParser(
* to a value according to a desired schema. This is a wrapper for the method
* `makeConverter()` to handle a row wrapped with an array.
*/
- private def makeRootConverter(st: StructType): ValueConverter = {
+ private def makeRootConverter(st: StructType): JsonParser => Seq[InternalRow] = {
val elementConverter = makeConverter(st)
- val fieldConverters = st.map(_.dataType).map(makeConverter)
- (parser: JsonParser) => parseJsonToken(parser, st) {
- case START_OBJECT => convertObject(parser, st, fieldConverters)
+ val fieldConverters = st.map(_.dataType).map(makeConverter).toArray
+ (parser: JsonParser) => parseJsonToken[Seq[InternalRow]](parser, st) {
+ case START_OBJECT => convertObject(parser, st, fieldConverters) :: Nil
// SPARK-3308: support reading top level JSON arrays and take every element
// in such an array as a row
//
@@ -137,7 +177,15 @@ class JacksonParser(
// List([str_a_1,null])
// List([str_a_2,null], [null,str_b_3])
//
- case START_ARRAY => convertArray(parser, elementConverter)
+ case START_ARRAY =>
+ val array = convertArray(parser, elementConverter)
+ // Here, as we support reading top level JSON arrays and take every element
+ // in such an array as a row, this case is possible.
+ if (array.numElements() == 0) {
+ Nil
+ } else {
+ array.toArray[InternalRow](schema).toSeq
+ }
}
}
@@ -145,35 +193,35 @@ class JacksonParser(
* Create a converter which converts the JSON documents held by the `JsonParser`
* to a value according to a desired schema.
*/
- private[sql] def makeConverter(dataType: DataType): ValueConverter = dataType match {
+ def makeConverter(dataType: DataType): ValueConverter = dataType match {
case BooleanType =>
- (parser: JsonParser) => parseJsonToken(parser, dataType) {
+ (parser: JsonParser) => parseJsonToken[java.lang.Boolean](parser, dataType) {
case VALUE_TRUE => true
case VALUE_FALSE => false
}
case ByteType =>
- (parser: JsonParser) => parseJsonToken(parser, dataType) {
+ (parser: JsonParser) => parseJsonToken[java.lang.Byte](parser, dataType) {
case VALUE_NUMBER_INT => parser.getByteValue
}
case ShortType =>
- (parser: JsonParser) => parseJsonToken(parser, dataType) {
+ (parser: JsonParser) => parseJsonToken[java.lang.Short](parser, dataType) {
case VALUE_NUMBER_INT => parser.getShortValue
}
case IntegerType =>
- (parser: JsonParser) => parseJsonToken(parser, dataType) {
+ (parser: JsonParser) => parseJsonToken[java.lang.Integer](parser, dataType) {
case VALUE_NUMBER_INT => parser.getIntValue
}
case LongType =>
- (parser: JsonParser) => parseJsonToken(parser, dataType) {
+ (parser: JsonParser) => parseJsonToken[java.lang.Long](parser, dataType) {
case VALUE_NUMBER_INT => parser.getLongValue
}
case FloatType =>
- (parser: JsonParser) => parseJsonToken(parser, dataType) {
+ (parser: JsonParser) => parseJsonToken[java.lang.Float](parser, dataType) {
case VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT =>
parser.getFloatValue
@@ -193,7 +241,7 @@ class JacksonParser(
}
case DoubleType =>
- (parser: JsonParser) => parseJsonToken(parser, dataType) {
+ (parser: JsonParser) => parseJsonToken[java.lang.Double](parser, dataType) {
case VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT =>
parser.getDoubleValue
@@ -213,7 +261,7 @@ class JacksonParser(
}
case StringType =>
- (parser: JsonParser) => parseJsonToken(parser, dataType) {
+ (parser: JsonParser) => parseJsonToken[UTF8String](parser, dataType) {
case VALUE_STRING =>
UTF8String.fromString(parser.getText)
@@ -227,66 +275,71 @@ class JacksonParser(
}
case TimestampType =>
- (parser: JsonParser) => parseJsonToken(parser, dataType) {
+ (parser: JsonParser) => parseJsonToken[java.lang.Long](parser, dataType) {
case VALUE_STRING =>
+ val stringValue = parser.getText
// This one will lose microseconds parts.
// See https://issues.apache.org/jira/browse/SPARK-10681.
- Try(options.timestampFormat.parse(parser.getText).getTime * 1000L)
- .getOrElse {
- // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards
- // compatibility.
- DateTimeUtils.stringToTime(parser.getText).getTime * 1000L
- }
+ Long.box {
+ Try(options.timestampFormat.parse(stringValue).getTime * 1000L)
+ .getOrElse {
+ // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards
+ // compatibility.
+ DateTimeUtils.stringToTime(stringValue).getTime * 1000L
+ }
+ }
case VALUE_NUMBER_INT =>
parser.getLongValue * 1000000L
}
case DateType =>
- (parser: JsonParser) => parseJsonToken(parser, dataType) {
+ (parser: JsonParser) => parseJsonToken[java.lang.Integer](parser, dataType) {
case VALUE_STRING =>
val stringValue = parser.getText
// This one will lose microseconds parts.
// See https://issues.apache.org/jira/browse/SPARK-10681.x
- Try(DateTimeUtils.millisToDays(options.dateFormat.parse(parser.getText).getTime))
- .getOrElse {
- // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards
- // compatibility.
- Try(DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(parser.getText).getTime))
+ Int.box {
+ Try(DateTimeUtils.millisToDays(options.dateFormat.parse(stringValue).getTime))
+ .orElse {
+ // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards
+ // compatibility.
+ Try(DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(stringValue).getTime))
+ }
.getOrElse {
- // In Spark 1.5.0, we store the data as number of days since epoch in string.
- // So, we just convert it to Int.
- stringValue.toInt
- }
+ // In Spark 1.5.0, we store the data as number of days since epoch in string.
+ // So, we just convert it to Int.
+ stringValue.toInt
+ }
}
}
case BinaryType =>
- (parser: JsonParser) => parseJsonToken(parser, dataType) {
+ (parser: JsonParser) => parseJsonToken[Array[Byte]](parser, dataType) {
case VALUE_STRING => parser.getBinaryValue
}
case dt: DecimalType =>
- (parser: JsonParser) => parseJsonToken(parser, dataType) {
+ (parser: JsonParser) => parseJsonToken[Decimal](parser, dataType) {
case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) =>
Decimal(parser.getDecimalValue, dt.precision, dt.scale)
}
case st: StructType =>
- val fieldConverters = st.map(_.dataType).map(makeConverter)
- (parser: JsonParser) => parseJsonToken(parser, dataType) {
+ val fieldConverters = st.map(_.dataType).map(makeConverter).toArray
+ (parser: JsonParser) => parseJsonToken[InternalRow](parser, dataType) {
case START_OBJECT => convertObject(parser, st, fieldConverters)
}
case at: ArrayType =>
val elementConverter = makeConverter(at.elementType)
- (parser: JsonParser) => parseJsonToken(parser, dataType) {
+ (parser: JsonParser) => parseJsonToken[ArrayData](parser, dataType) {
case START_ARRAY => convertArray(parser, elementConverter)
}
case mt: MapType =>
val valueConverter = makeConverter(mt.valueType)
- (parser: JsonParser) => parseJsonToken(parser, dataType) {
+ (parser: JsonParser) => parseJsonToken[MapData](parser, dataType) {
case START_OBJECT => convertMap(parser, valueConverter)
}
@@ -298,7 +351,7 @@ class JacksonParser(
// Here, we pass empty `PartialFunction` so that this case can be
// handled as a failed conversion. It will throw an exception as
// long as the value is not null.
- parseJsonToken(parser, dataType)(PartialFunction.empty[JsonToken, Any])
+ parseJsonToken[AnyRef](parser, dataType)(PartialFunction.empty[JsonToken, AnyRef])
}
/**
@@ -306,14 +359,14 @@ class JacksonParser(
* to parse the JSON token using given function `f`. If the `f` failed to parse and convert the
* token, call `failedConversion` to handle the token.
*/
- private def parseJsonToken(
+ private def parseJsonToken[R >: Null](
parser: JsonParser,
- dataType: DataType)(f: PartialFunction[JsonToken, Any]): Any = {
+ dataType: DataType)(f: PartialFunction[JsonToken, R]): R = {
parser.getCurrentToken match {
case FIELD_NAME =>
// There are useless FIELD_NAMEs between START_OBJECT and END_OBJECT tokens
parser.nextToken()
- parseJsonToken(parser, dataType)(f)
+ parseJsonToken[R](parser, dataType)(f)
case null | VALUE_NULL => null
@@ -325,9 +378,9 @@ class JacksonParser(
* This function throws an exception for failed conversion, but returns null for empty string,
* to guard the non string types.
*/
- private def failedConversion(
+ private def failedConversion[R >: Null](
parser: JsonParser,
- dataType: DataType): PartialFunction[JsonToken, Any] = {
+ dataType: DataType): PartialFunction[JsonToken, R] = {
case VALUE_STRING if parser.getTextLength < 1 =>
// If conversion is failed, this produces `null` rather than throwing exception.
// This will protect the mismatch of types.
@@ -348,7 +401,7 @@ class JacksonParser(
private def convertObject(
parser: JsonParser,
schema: StructType,
- fieldConverters: Seq[ValueConverter]): InternalRow = {
+ fieldConverters: Array[ValueConverter]): InternalRow = {
val row = new GenericInternalRow(schema.length)
while (nextUntil(parser, JsonToken.END_OBJECT)) {
schema.getFieldIndex(parser.getCurrentName) match {
@@ -394,36 +447,30 @@ class JacksonParser(
}
/**
- * Parse the string JSON input to the set of [[InternalRow]]s.
+ * Parse the JSON input to the set of [[InternalRow]]s.
+ *
+ * @param recordLiteral an optional function that will be used to generate
+ * the corrupt record text instead of record.toString
*/
- def parse(input: String): Seq[InternalRow] = {
- if (input.trim.isEmpty) {
- Nil
- } else {
- try {
- Utils.tryWithResource(factory.createParser(input)) { parser =>
- parser.nextToken()
- rootConverter.apply(parser) match {
- case null => failedRecord(input)
- case row: InternalRow => row :: Nil
- case array: ArrayData =>
- // Here, as we support reading top level JSON arrays and take every element
- // in such an array as a row, this case is possible.
- if (array.numElements() == 0) {
- Nil
- } else {
- array.toArray[InternalRow](schema)
- }
- case _ =>
- failedRecord(input)
+ def parse[T](
+ record: T,
+ createParser: (JsonFactory, T) => JsonParser,
+ recordLiteral: T => UTF8String): Seq[InternalRow] = {
+ try {
+ Utils.tryWithResource(createParser(factory, record)) { parser =>
+ // a null first token is equivalent to testing for input.trim.isEmpty
+ // but it works on any token stream and not just strings
+ parser.nextToken() match {
+ case null => Nil
+ case _ => rootConverter.apply(parser) match {
+ case null => throw new SparkSQLJsonProcessingException("Root converter returned null")
+ case rows => rows
}
}
- } catch {
- case _: JsonProcessingException =>
- failedRecord(input)
- case _: SparkSQLJsonProcessingException =>
- failedRecord(input)
}
+ } catch {
+ case _: JsonProcessingException | _: SparkSQLJsonProcessingException =>
+ failedRecord(() => recordLiteral(record))
}
}
}
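
The rewritten failedRecord resolves corruptFieldIndex once up front and takes the record as a thunk (() => UTF8String), so a whole-file document is only materialized as a string when a corrupt-record column exists or a warning sample is needed. From the reader API the three modes behave as before; a hedged sketch assuming an active SparkSession named spark and an illustrative path:

    // PERMISSIVE (default): malformed documents land in the corrupt-record column, other fields are null
    spark.read.option("wholeFile", true).json("/path/to/json")

    // DROPMALFORMED: malformed documents are dropped (a warning is logged once per parser)
    spark.read.option("wholeFile", true).option("mode", "DROPMALFORMED").json("/path/to/json")

    // FAILFAST: the first malformed document fails the query with "Malformed line in FAILFAST mode"
    spark.read.option("wholeFile", true).option("mode", "FAILFAST").json("/path/to/json")
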
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 780fe51ac699..cb9493a57564 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -26,14 +26,14 @@ import org.apache.spark.internal.Logging
import org.apache.spark.Partition
import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.json.{JacksonParser, JSONOptions}
-import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
import org.apache.spark.sql.execution.LogicalRDD
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.datasources.jdbc._
import org.apache.spark.sql.execution.datasources.json.JsonInferSchema
import org.apache.spark.sql.types.StructType
+import org.apache.spark.unsafe.types.UTF8String
/**
* Interface used to load a [[Dataset]] from external storage systems (e.g. file systems,
@@ -261,8 +261,10 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
}
/**
- * Loads a JSON file (JSON Lines text format or
- * newline-delimited JSON) and returns the result as a `DataFrame`.
+ * Loads a JSON file and returns the results as a `DataFrame`.
+ *
+ * Both JSON (one record per file) and JSON Lines
+ * (newline-delimited JSON) are supported and can be selected with the `wholeFile` option.
*
* This function goes through the input once to determine the input schema. If you know the
* schema in advance, use the version that specifies the schema to avoid the extra scan.
@@ -301,6 +303,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* `java.text.SimpleDateFormat`. This applies to timestamp type.
*
`timeZone` (default session local timezone): sets the string that indicates a timezone
* to be used to parse timestamps.
+ * `wholeFile` (default `false`): parse one record, which may span multiple lines,
+ * per file
*
*
* @since 2.0.0
@@ -332,20 +336,22 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* @since 1.4.0
*/
def json(jsonRDD: RDD[String]): DataFrame = {
- val parsedOptions: JSONOptions =
- new JSONOptions(extraOptions.toMap, sparkSession.sessionState.conf.sessionLocalTimeZone)
- val columnNameOfCorruptRecord =
- parsedOptions.columnNameOfCorruptRecord
- .getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord)
+ val parsedOptions = new JSONOptions(
+ extraOptions.toMap,
+ sparkSession.sessionState.conf.sessionLocalTimeZone,
+ sparkSession.sessionState.conf.columnNameOfCorruptRecord)
+ val createParser = CreateJacksonParser.string _
+
val schema = userSpecifiedSchema.getOrElse {
JsonInferSchema.infer(
jsonRDD,
- columnNameOfCorruptRecord,
- parsedOptions)
+ parsedOptions,
+ createParser)
}
+
val parsed = jsonRDD.mapPartitions { iter =>
- val parser = new JacksonParser(schema, columnNameOfCorruptRecord, parsedOptions)
- iter.flatMap(parser.parse)
+ val parser = new JacksonParser(schema, parsedOptions)
+ iter.flatMap(parser.parse(_, createParser, UTF8String.fromString))
}
Dataset.ofRows(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala
index 900263aeb21d..0762d1b7daae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala
@@ -17,9 +17,10 @@
package org.apache.spark.sql.execution.datasources
-import java.io.{OutputStream, OutputStreamWriter}
+import java.io.{InputStream, OutputStream, OutputStreamWriter}
import java.nio.charset.{Charset, StandardCharsets}
+import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.compress._
import org.apache.hadoop.mapreduce.JobContext
@@ -27,6 +28,20 @@ import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
import org.apache.hadoop.util.ReflectionUtils
object CodecStreams {
+ private def getDecompressionCodec(config: Configuration, file: Path): Option[CompressionCodec] = {
+ val compressionCodecs = new CompressionCodecFactory(config)
+ Option(compressionCodecs.getCodec(file))
+ }
+
+ def createInputStream(config: Configuration, file: Path): InputStream = {
+ val fs = file.getFileSystem(config)
+ val inputStream: InputStream = fs.open(file)
+
+ getDecompressionCodec(config, file)
+ .map(codec => codec.createInputStream(inputStream))
+ .getOrElse(inputStream)
+ }
+
private def getCompressionCodec(
context: JobContext,
file: Option[Path] = None): Option[CompressionCodec] = {
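
createInputStream is the read-side counterpart of the existing output helpers: it opens the path and, when CompressionCodecFactory recognizes the file extension, wraps the stream in the matching decompressor. A minimal sketch, assuming a local Configuration and an illustrative gzip-compressed part file:

    import java.nio.charset.StandardCharsets

    import com.google.common.io.ByteStreams
    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.Path

    import org.apache.spark.sql.execution.datasources.CodecStreams

    val in = CodecStreams.createInputStream(new Configuration(), new Path("/tmp/part-00000.json.gz"))
    try {
      // Bytes arrive decompressed because the .gz suffix maps to GzipCodec
      println(new String(ByteStreams.toByteArray(in), StandardCharsets.UTF_8))
    } finally {
      in.close()
    }
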
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
new file mode 100644
index 000000000000..3e984effcb8d
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.json
+
+import java.io.InputStream
+
+import scala.reflect.ClassTag
+
+import com.fasterxml.jackson.core.{JsonFactory, JsonParser}
+import com.google.common.io.ByteStreams
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.hadoop.io.{LongWritable, Text}
+import org.apache.hadoop.mapreduce.Job
+import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, TextInputFormat}
+
+import org.apache.spark.TaskContext
+import org.apache.spark.input.{PortableDataStream, StreamInputFormat}
+import org.apache.spark.rdd.{BinaryFileRDD, RDD}
+import org.apache.spark.sql.{AnalysisException, SparkSession}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
+import org.apache.spark.sql.execution.datasources.{CodecStreams, HadoopFileLinesReader, PartitionedFile}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.Utils
+
+/**
+ * Common functions for parsing JSON files
+ * @tparam T A datatype containing the unparsed JSON, such as [[Text]] or [[String]]
+ */
+abstract class JsonDataSource[T] extends Serializable {
+ def isSplitable: Boolean
+
+ /**
+ * Parse a [[PartitionedFile]] into 0 or more [[InternalRow]] instances
+ */
+ def readFile(
+ conf: Configuration,
+ file: PartitionedFile,
+ parser: JacksonParser): Iterator[InternalRow]
+
+ /**
+ * Create an [[RDD]] that handles the preliminary parsing of [[T]] records
+ */
+ protected def createBaseRdd(
+ sparkSession: SparkSession,
+ inputPaths: Seq[FileStatus]): RDD[T]
+
+ /**
+ * A generic wrapper to invoke the correct [[JsonFactory]] method to allocate a [[JsonParser]]
+ * for an instance of [[T]]
+ */
+ def createParser(jsonFactory: JsonFactory, value: T): JsonParser
+
+ final def infer(
+ sparkSession: SparkSession,
+ inputPaths: Seq[FileStatus],
+ parsedOptions: JSONOptions): Option[StructType] = {
+ if (inputPaths.nonEmpty) {
+ val jsonSchema = JsonInferSchema.infer(
+ createBaseRdd(sparkSession, inputPaths),
+ parsedOptions,
+ createParser)
+ checkConstraints(jsonSchema)
+ Some(jsonSchema)
+ } else {
+ None
+ }
+ }
+
+ /** Constraints to be imposed on schema to be stored. */
+ private def checkConstraints(schema: StructType): Unit = {
+ if (schema.fieldNames.length != schema.fieldNames.distinct.length) {
+ val duplicateColumns = schema.fieldNames.groupBy(identity).collect {
+ case (x, ys) if ys.length > 1 => "\"" + x + "\""
+ }.mkString(", ")
+ throw new AnalysisException(s"Duplicate column(s) : $duplicateColumns found, " +
+ s"cannot save to JSON format")
+ }
+ }
+}
+
+object JsonDataSource {
+ def apply(options: JSONOptions): JsonDataSource[_] = {
+ if (options.wholeFile) {
+ WholeFileJsonDataSource
+ } else {
+ TextInputJsonDataSource
+ }
+ }
+
+ /**
+ * Create a new [[RDD]] via the supplied callback if there is at least one file to process,
+ * otherwise an [[org.apache.spark.rdd.EmptyRDD]] will be returned.
+ */
+ def createBaseRdd[T : ClassTag](
+ sparkSession: SparkSession,
+ inputPaths: Seq[FileStatus])(
+ fn: (Configuration, String) => RDD[T]): RDD[T] = {
+ val paths = inputPaths.map(_.getPath)
+
+ if (paths.nonEmpty) {
+ val job = Job.getInstance(sparkSession.sessionState.newHadoopConf())
+ FileInputFormat.setInputPaths(job, paths: _*)
+ fn(job.getConfiguration, paths.mkString(","))
+ } else {
+ sparkSession.sparkContext.emptyRDD[T]
+ }
+ }
+}
+
+object TextInputJsonDataSource extends JsonDataSource[Text] {
+ override val isSplitable: Boolean = {
+ // splittable if the underlying source is
+ true
+ }
+
+ override protected def createBaseRdd(
+ sparkSession: SparkSession,
+ inputPaths: Seq[FileStatus]): RDD[Text] = {
+ JsonDataSource.createBaseRdd(sparkSession, inputPaths) {
+ case (conf, name) =>
+ sparkSession.sparkContext.newAPIHadoopRDD(
+ conf,
+ classOf[TextInputFormat],
+ classOf[LongWritable],
+ classOf[Text])
+ .setName(s"JsonLines: $name")
+ .values // get the text column
+ }
+ }
+
+ override def readFile(
+ conf: Configuration,
+ file: PartitionedFile,
+ parser: JacksonParser): Iterator[InternalRow] = {
+ val linesReader = new HadoopFileLinesReader(file, conf)
+ Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
+ linesReader.flatMap(parser.parse(_, createParser, textToUTF8String))
+ }
+
+ private def textToUTF8String(value: Text): UTF8String = {
+ UTF8String.fromBytes(value.getBytes, 0, value.getLength)
+ }
+
+ override def createParser(jsonFactory: JsonFactory, value: Text): JsonParser = {
+ CreateJacksonParser.text(jsonFactory, value)
+ }
+}
+
+object WholeFileJsonDataSource extends JsonDataSource[PortableDataStream] {
+ override val isSplitable: Boolean = {
+ false
+ }
+
+ override protected def createBaseRdd(
+ sparkSession: SparkSession,
+ inputPaths: Seq[FileStatus]): RDD[PortableDataStream] = {
+ JsonDataSource.createBaseRdd(sparkSession, inputPaths) {
+ case (conf, name) =>
+ new BinaryFileRDD(
+ sparkSession.sparkContext,
+ classOf[StreamInputFormat],
+ classOf[String],
+ classOf[PortableDataStream],
+ conf,
+ sparkSession.sparkContext.defaultMinPartitions)
+ .setName(s"JsonFile: $name")
+ .values
+ }
+ }
+
+ private def createInputStream(config: Configuration, path: String): InputStream = {
+ val inputStream = CodecStreams.createInputStream(config, new Path(path))
+ Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => inputStream.close()))
+ inputStream
+ }
+
+ override def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = {
+ CreateJacksonParser.inputStream(
+ jsonFactory,
+ createInputStream(record.getConfiguration, record.getPath()))
+ }
+
+ override def readFile(
+ conf: Configuration,
+ file: PartitionedFile,
+ parser: JacksonParser): Iterator[InternalRow] = {
+ def partitionedFileString(ignored: Any): UTF8String = {
+ Utils.tryWithResource(createInputStream(conf, file.filePath)) { inputStream =>
+ UTF8String.fromBytes(ByteStreams.toByteArray(inputStream))
+ }
+ }
+
+ parser.parse(
+ createInputStream(conf, file.filePath),
+ CreateJacksonParser.inputStream,
+ partitionedFileString).toIterator
+ }
+}
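
JsonDataSource(parsedOptions) is the single dispatch point used by JsonFileFormat below: wholeFile selects the PortableDataStream-based source (one document per file, never split), otherwise the Text/line-based source is used and files stay splittable. From the user's side the switch is just the option; a sketch with illustrative paths and an active SparkSession named spark:

    // Default (TextInputJsonDataSource): line-delimited JSON, input files may be split
    val lines = spark.read.json("/data/events")

    // wholeFile (WholeFileJsonDataSource): each file is one JSON document (object or array)
    val whole = spark.read.option("wholeFile", true).json("/data/documents")
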
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
index b4a8ff2cf01a..2cbf4ea7beac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
@@ -19,15 +19,10 @@ package org.apache.spark.sql.execution.datasources.json
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
-import org.apache.hadoop.io.{LongWritable, Text}
-import org.apache.hadoop.mapred.{JobConf, TextInputFormat}
import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
-import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
-import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
+import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSONOptions}
import org.apache.spark.sql.catalyst.util.CompressionCodecs
@@ -37,29 +32,30 @@ import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration
class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
+ override val shortName: String = "json"
- override def shortName(): String = "json"
+ override def isSplitable(
+ sparkSession: SparkSession,
+ options: Map[String, String],
+ path: Path): Boolean = {
+ val parsedOptions = new JSONOptions(
+ options,
+ sparkSession.sessionState.conf.sessionLocalTimeZone,
+ sparkSession.sessionState.conf.columnNameOfCorruptRecord)
+ val jsonDataSource = JsonDataSource(parsedOptions)
+ jsonDataSource.isSplitable && super.isSplitable(sparkSession, options, path)
+ }
override def inferSchema(
sparkSession: SparkSession,
options: Map[String, String],
files: Seq[FileStatus]): Option[StructType] = {
- if (files.isEmpty) {
- None
- } else {
- val parsedOptions: JSONOptions =
- new JSONOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone)
- val columnNameOfCorruptRecord =
- parsedOptions.columnNameOfCorruptRecord
- .getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord)
- val jsonSchema = JsonInferSchema.infer(
- createBaseRdd(sparkSession, files),
- columnNameOfCorruptRecord,
- parsedOptions)
- checkConstraints(jsonSchema)
-
- Some(jsonSchema)
- }
+ val parsedOptions = new JSONOptions(
+ options,
+ sparkSession.sessionState.conf.sessionLocalTimeZone,
+ sparkSession.sessionState.conf.columnNameOfCorruptRecord)
+ JsonDataSource(parsedOptions).infer(
+ sparkSession, files, parsedOptions)
}
override def prepareWrite(
@@ -68,8 +64,10 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
val conf = job.getConfiguration
- val parsedOptions: JSONOptions =
- new JSONOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone)
+ val parsedOptions = new JSONOptions(
+ options,
+ sparkSession.sessionState.conf.sessionLocalTimeZone,
+ sparkSession.sessionState.conf.columnNameOfCorruptRecord)
parsedOptions.compressionCodec.foreach { codec =>
CompressionCodecs.setCodecConfiguration(conf, codec)
}
@@ -99,47 +97,17 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
val broadcastedHadoopConf =
sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
- val parsedOptions: JSONOptions =
- new JSONOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone)
- val columnNameOfCorruptRecord = parsedOptions.columnNameOfCorruptRecord
- .getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord)
+ val parsedOptions = new JSONOptions(
+ options,
+ sparkSession.sessionState.conf.sessionLocalTimeZone,
+ sparkSession.sessionState.conf.columnNameOfCorruptRecord)
(file: PartitionedFile) => {
- val linesReader = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value)
- Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
- val lines = linesReader.map(_.toString)
- val parser = new JacksonParser(requiredSchema, columnNameOfCorruptRecord, parsedOptions)
- lines.flatMap(parser.parse)
- }
- }
-
- private def createBaseRdd(
- sparkSession: SparkSession,
- inputPaths: Seq[FileStatus]): RDD[String] = {
- val job = Job.getInstance(sparkSession.sessionState.newHadoopConf())
- val conf = job.getConfiguration
-
- val paths = inputPaths.map(_.getPath)
-
- if (paths.nonEmpty) {
- FileInputFormat.setInputPaths(job, paths: _*)
- }
-
- sparkSession.sparkContext.hadoopRDD(
- conf.asInstanceOf[JobConf],
- classOf[TextInputFormat],
- classOf[LongWritable],
- classOf[Text]).map(_._2.toString) // get the text line
- }
-
- /** Constraints to be imposed on schema to be stored. */
- private def checkConstraints(schema: StructType): Unit = {
- if (schema.fieldNames.length != schema.fieldNames.distinct.length) {
- val duplicateColumns = schema.fieldNames.groupBy(identity).collect {
- case (x, ys) if ys.length > 1 => "\"" + x + "\""
- }.mkString(", ")
- throw new AnalysisException(s"Duplicate column(s) : $duplicateColumns found, " +
- s"cannot save to JSON format")
+ val parser = new JacksonParser(requiredSchema, parsedOptions)
+ JsonDataSource(parsedOptions).readFile(
+ broadcastedHadoopConf.value.value,
+ file,
+ parser)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
index f51c18d46f45..ab09358115c0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
@@ -36,13 +36,14 @@ private[sql] object JsonInferSchema {
* 2. Merge types by choosing the lowest type necessary to cover equal keys
* 3. Replace any remaining null fields with string, the top type
*/
- def infer(
- json: RDD[String],
- columnNameOfCorruptRecord: String,
- configOptions: JSONOptions): StructType = {
+ def infer[T](
+ json: RDD[T],
+ configOptions: JSONOptions,
+ createParser: (JsonFactory, T) => JsonParser): StructType = {
require(configOptions.samplingRatio > 0,
s"samplingRatio (${configOptions.samplingRatio}) should be greater than 0")
val shouldHandleCorruptRecord = configOptions.permissive
+ val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord
val schemaData = if (configOptions.samplingRatio > 0.99) {
json
} else {
@@ -55,7 +56,7 @@ private[sql] object JsonInferSchema {
configOptions.setJacksonOptions(factory)
iter.flatMap { row =>
try {
- Utils.tryWithResource(factory.createParser(row)) { parser =>
+ Utils.tryWithResource(createParser(factory, row)) { parser =>
parser.nextToken()
Some(inferField(parser, configOptions))
}
@@ -79,7 +80,7 @@ private[sql] object JsonInferSchema {
private[this] val structFieldComparator = new Comparator[StructField] {
override def compare(o1: StructField, o2: StructField): Int = {
- o1.name.compare(o2.name)
+ o1.name.compareTo(o2.name)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
index 4e706da184c0..99943944f3c6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -141,8 +141,10 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
}
/**
- * Loads a JSON file stream (JSON Lines text format or
- * newline-delimited JSON) and returns the result as a `DataFrame`.
+ * Loads a JSON file stream and returns the results as a `DataFrame`.
+ *
+ * Both JSON (one record per file) and JSON Lines
+ * (newline-delimited JSON) are supported and can be selected with the `wholeFile` option.
*
* This function goes through the input once to determine the input schema. If you know the
* schema in advance, use the version that specifies the schema to avoid the extra scan.
@@ -183,6 +185,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
* `java.text.SimpleDateFormat`. This applies to timestamp type.
* `timeZone` (default session local timezone): sets the string that indicates a timezone
* to be used to parse timestamps.
+ * `wholeFile` (default `false`): parse one record, which may span multiple lines,
+ * per file
*
*
* @since 2.0.0
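
The streaming reader exposes the same option; because streaming file sources require a user-supplied schema, a hedged sketch looks like this (schema and path are illustrative):

    import org.apache.spark.sql.types._

    val schema = new StructType().add("name", StringType).add("age", IntegerType)

    val streamingDF = spark.readStream
      .schema(schema)               // streaming file sources need an explicit schema
      .option("wholeFile", true)    // one (possibly multi-line) JSON document per file
      .json("/path/to/incoming/json")
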
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 9344aeda0017..05aa2ab2ce2d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -28,8 +28,8 @@ import org.apache.hadoop.io.compress.GzipCodec
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkException
-import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.json.{JacksonParser, JSONOptions}
+import org.apache.spark.sql.{functions => F, _}
+import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.datasources.json.JsonInferSchema.compatibleType
@@ -64,7 +64,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
val dummyOption = new JSONOptions(Map.empty[String, String], "GMT")
val dummySchema = StructType(Seq.empty)
- val parser = new JacksonParser(dummySchema, "", dummyOption)
+ val parser = new JacksonParser(dummySchema, dummyOption)
Utils.tryWithResource(factory.createParser(writer.toString)) { jsonParser =>
jsonParser.nextToken()
@@ -1367,7 +1367,9 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
test("SPARK-6245 JsonRDD.inferSchema on empty RDD") {
// This is really a test that it doesn't throw an exception
val emptySchema = JsonInferSchema.infer(
- empty, "", new JSONOptions(Map.empty[String, String], "GMT"))
+ empty,
+ new JSONOptions(Map.empty[String, String], "GMT"),
+ CreateJacksonParser.string)
assert(StructType(Seq()) === emptySchema)
}
@@ -1392,7 +1394,9 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
test("SPARK-8093 Erase empty structs") {
val emptySchema = JsonInferSchema.infer(
- emptyRecords, "", new JSONOptions(Map.empty[String, String], "GMT"))
+ emptyRecords,
+ new JSONOptions(Map.empty[String, String], "GMT"),
+ CreateJacksonParser.string)
assert(StructType(Seq()) === emptySchema)
}
@@ -1802,4 +1806,142 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
val df2 = spark.read.option("PREfersdecimaL", "true").json(records)
assert(df2.schema == schema)
}
+
+ test("SPARK-18352: Parse normal multi-line JSON files (compressed)") {
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+ primitiveFieldAndType
+ .toDF("value")
+ .write
+ .option("compression", "GzIp")
+ .text(path)
+
+ assert(new File(path).listFiles().exists(_.getName.endsWith(".gz")))
+
+ val jsonDF = spark.read.option("wholeFile", true).json(path)
+ val jsonDir = new File(dir, "json").getCanonicalPath
+ jsonDF.coalesce(1).write
+ .option("compression", "gZiP")
+ .json(jsonDir)
+
+ assert(new File(jsonDir).listFiles().exists(_.getName.endsWith(".json.gz")))
+
+ val originalData = spark.read.json(primitiveFieldAndType)
+ checkAnswer(jsonDF, originalData)
+ checkAnswer(spark.read.schema(originalData.schema).json(jsonDir), originalData)
+ }
+ }
+
+ test("SPARK-18352: Parse normal multi-line JSON files (uncompressed)") {
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+ primitiveFieldAndType
+ .toDF("value")
+ .write
+ .text(path)
+
+ val jsonDF = spark.read.option("wholeFile", true).json(path)
+ val jsonDir = new File(dir, "json").getCanonicalPath
+ jsonDF.coalesce(1).write.json(jsonDir)
+
+ val compressedFiles = new File(jsonDir).listFiles()
+ assert(compressedFiles.exists(_.getName.endsWith(".json")))
+
+ val originalData = spark.read.json(primitiveFieldAndType)
+ checkAnswer(jsonDF, originalData)
+ checkAnswer(spark.read.schema(originalData.schema).json(jsonDir), originalData)
+ }
+ }
+
+ test("SPARK-18352: Expect one JSON document per file") {
+ // the json parser terminates as soon as it sees a matching END_OBJECT or END_ARRAY token.
+ // this might not be the optimal behavior but this test verifies that only the first value
+ // is parsed and the rest are discarded.
+
+ // alternatively the parser could continue parsing following objects, which may further reduce
+ // allocations by skipping the line reader entirely
+
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+ spark
+ .createDataFrame(Seq(Tuple1("{}{invalid}")))
+ .coalesce(1)
+ .write
+ .text(path)
+
+ val jsonDF = spark.read.option("wholeFile", true).json(path)
+ // no corrupt record column should be created
+ assert(jsonDF.schema === StructType(Seq()))
+ // only the first object should be read
+ assert(jsonDF.count() === 1)
+ }
+ }
+
+ test("SPARK-18352: Handle multi-line corrupt documents (PERMISSIVE)") {
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+ val corruptRecordCount = additionalCorruptRecords.count().toInt
+ assert(corruptRecordCount === 5)
+
+ additionalCorruptRecords
+ .toDF("value")
+ // this is the minimum partition count that avoids hash collisions
+ .repartition(corruptRecordCount * 4, F.hash($"value"))
+ .write
+ .text(path)
+
+ val jsonDF = spark.read.option("wholeFile", true).option("mode", "PERMISSIVE").json(path)
+ assert(jsonDF.count() === corruptRecordCount)
+ assert(jsonDF.schema === new StructType()
+ .add("_corrupt_record", StringType)
+ .add("dummy", StringType))
+ val counts = jsonDF
+ .join(
+ additionalCorruptRecords.toDF("value"),
+ F.regexp_replace($"_corrupt_record", "(^\\s+|\\s+$)", "") === F.trim($"value"),
+ "outer")
+ .agg(
+ F.count($"dummy").as("valid"),
+ F.count($"_corrupt_record").as("corrupt"),
+ F.count("*").as("count"))
+ checkAnswer(counts, Row(1, 4, 6))
+ }
+ }
+
+ test("SPARK-18352: Handle multi-line corrupt documents (FAILFAST)") {
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+ val corruptRecordCount = additionalCorruptRecords.count().toInt
+ assert(corruptRecordCount === 5)
+
+ additionalCorruptRecords
+ .toDF("value")
+ // this is the minimum partition count that avoids hash collisions
+ .repartition(corruptRecordCount * 4, F.hash($"value"))
+ .write
+ .text(path)
+
+ val schema = new StructType().add("dummy", StringType)
+
+ // `FAILFAST` mode should throw an exception for corrupt records.
+ val exceptionOne = intercept[SparkException] {
+ spark.read
+ .option("wholeFile", true)
+ .option("mode", "FAILFAST")
+ .json(path)
+ .collect()
+ }
+ assert(exceptionOne.getMessage.contains("Malformed line in FAILFAST mode"))
+
+ val exceptionTwo = intercept[SparkException] {
+ spark.read
+ .option("wholeFile", true)
+ .option("mode", "FAILFAST")
+ .schema(schema)
+ .json(path)
+ .collect()
+ }
+ assert(exceptionTwo.getMessage.contains("Malformed line in FAILFAST mode"))
+ }
+ }
}