diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 3800d53c02f4..87b9e8eb445a 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -147,7 +147,13 @@ public void writeTo(ByteBuffer buffer) { buffer.position(pos + numBytes); } - public void writeTo(OutputStream out) throws IOException { + /** + * Returns a {@link ByteBuffer} wrapping the base object if it is a byte array + * or a copy of the data if the base object is not a byte array. + * + * Unlike getBytes this will not create a copy the array if this is a slice. + */ + public @Nonnull ByteBuffer getByteBuffer() { if (base instanceof byte[] && offset >= BYTE_ARRAY_OFFSET) { final byte[] bytes = (byte[]) base; @@ -160,12 +166,20 @@ public void writeTo(OutputStream out) throws IOException { throw new ArrayIndexOutOfBoundsException(); } - out.write(bytes, (int) arrayOffset, numBytes); + return ByteBuffer.wrap(bytes, (int) arrayOffset, numBytes); } else { - out.write(getBytes()); + return ByteBuffer.wrap(getBytes()); } } + public void writeTo(OutputStream out) throws IOException { + final ByteBuffer bb = this.getByteBuffer(); + assert(bb.hasArray()); + + // similar to Utils.writeByteBuffer but without the spark-core dependency + out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()); + } + /** * Returns the number of bytes for a code point with the first byte as `b` * @param b The first byte of a code point diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala index 59404e08895a..9606c4754314 100644 --- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala +++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala @@ -29,6 +29,7 @@ import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat, CombineFil import org.apache.spark.internal.config import org.apache.spark.SparkContext +import org.apache.spark.annotation.Since /** * A general format for reading whole files in as streams, byte arrays, @@ -175,6 +176,7 @@ class PortableDataStream( * Create a new DataInputStream from the split and context. The user of this method is responsible * for closing the stream after usage. 
*/ + @Since("1.2.0") def open(): DataInputStream = { val pathp = split.getPath(index) val fs = pathp.getFileSystem(conf) @@ -184,6 +186,7 @@ class PortableDataStream( /** * Read the file as a byte array */ + @Since("1.2.0") def toArray(): Array[Byte] = { val stream = open() try { @@ -193,6 +196,10 @@ class PortableDataStream( } } + @Since("1.2.0") def getPath(): String = path + + @Since("2.2.0") + def getConfiguration: Configuration = conf } diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 167833488980..6bed390e60c9 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -159,11 +159,12 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - timeZone=None): + timeZone=None, wholeFile=None): """ - Loads a JSON file (`JSON Lines text format or newline-delimited JSON - `_) or an RDD of Strings storing JSON objects (one object per - record) and returns the result as a :class`DataFrame`. + Loads a JSON file and returns the results as a :class:`DataFrame`. + + Both JSON (one record per file) and `JSON Lines `_ + (newline-delimited JSON) are supported and can be selected with the `wholeFile` parameter. If the ``schema`` parameter is not specified, this function goes through the input once to determine the input schema. @@ -212,6 +213,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. :param timeZone: sets the string that indicates a timezone to be used to parse timestamps. If None is set, it uses the default value, session local timezone. + :param wholeFile: parse one record, which may span multiple lines, per file. If None is + set, it uses the default value, ``false``. >>> df1 = spark.read.json('python/test_support/sql/people.json') >>> df1.dtypes @@ -228,7 +231,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, - timestampFormat=timestampFormat, timeZone=timeZone) + timestampFormat=timestampFormat, timeZone=timeZone, wholeFile=wholeFile) if isinstance(path, basestring): path = [path] if type(path) == list: diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index d988e596a86d..965c8f6b269e 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -428,11 +428,13 @@ def load(self, path=None, format=None, schema=None, **options): def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, - mode=None, columnNameOfCorruptRecord=None, dateFormat=None, - timestampFormat=None, timeZone=None): + mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, + timeZone=None, wholeFile=None): """ - Loads a JSON file stream (`JSON Lines text format or newline-delimited JSON - `_) and returns a :class`DataFrame`. + Loads a JSON file stream and returns the results as a :class:`DataFrame`. 
+ + Both JSON (one record per file) and `JSON Lines `_ + (newline-delimited JSON) are supported and can be selected with the `wholeFile` parameter. If the ``schema`` parameter is not specified, this function goes through the input once to determine the input schema. @@ -483,6 +485,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. :param timeZone: sets the string that indicates a timezone to be used to parse timestamps. If None is set, it uses the default value, session local timezone. + :param wholeFile: parse one record, which may span multiple lines, per file. If None is + set, it uses the default value, ``false``. >>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema = sdf_schema) >>> json_sdf.isStreaming @@ -496,7 +500,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, - timestampFormat=timestampFormat, timeZone=timeZone) + timestampFormat=timestampFormat, timeZone=timeZone, wholeFile=wholeFile) if isinstance(path, basestring): return self._df(self._jreader.json(path)) else: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index d8b7b3137c1c..9058443285ac 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -439,6 +439,13 @@ def test_udf_with_order_by_and_limit(self): res.explain(True) self.assertEqual(res.collect(), [Row(id=0, copy=0)]) + def test_wholefile_json(self): + from pyspark.sql.types import StringType + people1 = self.spark.read.json("python/test_support/sql/people.json") + people_array = self.spark.read.json("python/test_support/sql/people_array.json", + wholeFile=True) + self.assertEqual(people1.collect(), people_array.collect()) + def test_udf_with_input_file_name(self): from pyspark.sql.functions import udf, input_file_name from pyspark.sql.types import StringType diff --git a/python/test_support/sql/people_array.json b/python/test_support/sql/people_array.json new file mode 100644 index 000000000000..c27c48fe343e --- /dev/null +++ b/python/test_support/sql/people_array.json @@ -0,0 +1,13 @@ +[ + { + "name": "Michael" + }, + { + "name": "Andy", + "age": 30 + }, + { + "name": "Justin", + "age": 19 + } +] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index bd852a50fe71..1e690a446951 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -497,8 +497,7 @@ case class JsonToStruct( lazy val parser = new JacksonParser( schema, - "invalid", // Not used since we force fail fast. Invalid rows will be set to `null`. 
- new JSONOptions(options ++ Map("mode" -> ParseModes.FAIL_FAST_MODE), timeZoneId.get)) + new JSONOptions(options + ("mode" -> ParseModes.FAIL_FAST_MODE), timeZoneId.get)) override def dataType: DataType = schema @@ -506,7 +505,12 @@ case class JsonToStruct( copy(timeZoneId = Option(timeZoneId)) override def nullSafeEval(json: Any): Any = { - try parser.parse(json.toString).headOption.orNull catch { + try { + parser.parse( + json.asInstanceOf[UTF8String], + CreateJacksonParser.utf8String, + identity[UTF8String]).headOption.orNull + } catch { case _: SparkSQLJsonProcessingException => null } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala new file mode 100644 index 000000000000..e0ed03a68981 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.json + +import java.io.InputStream + +import com.fasterxml.jackson.core.{JsonFactory, JsonParser} +import org.apache.hadoop.io.Text + +import org.apache.spark.unsafe.types.UTF8String + +private[sql] object CreateJacksonParser extends Serializable { + def string(jsonFactory: JsonFactory, record: String): JsonParser = { + jsonFactory.createParser(record) + } + + def utf8String(jsonFactory: JsonFactory, record: UTF8String): JsonParser = { + val bb = record.getByteBuffer + assert(bb.hasArray) + + jsonFactory.createParser(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) + } + + def text(jsonFactory: JsonFactory, record: Text): JsonParser = { + jsonFactory.createParser(record.getBytes, 0, record.getLength) + } + + def inputStream(jsonFactory: JsonFactory, record: InputStream): JsonParser = { + jsonFactory.createParser(record) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 5307ce1cb711..5a91f9c1939a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -31,11 +31,20 @@ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs * Most of these map directly to Jackson's internal options, specified in [[JsonParser.Feature]]. 
*/ private[sql] class JSONOptions( - @transient private val parameters: CaseInsensitiveMap[String], defaultTimeZoneId: String) + @transient private val parameters: CaseInsensitiveMap[String], + defaultTimeZoneId: String, + defaultColumnNameOfCorruptRecord: String) extends Logging with Serializable { - def this(parameters: Map[String, String], defaultTimeZoneId: String) = - this(CaseInsensitiveMap(parameters), defaultTimeZoneId) + def this( + parameters: Map[String, String], + defaultTimeZoneId: String, + defaultColumnNameOfCorruptRecord: String = "") = { + this( + CaseInsensitiveMap(parameters), + defaultTimeZoneId, + defaultColumnNameOfCorruptRecord) + } val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) @@ -57,7 +66,8 @@ private[sql] class JSONOptions( parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false) val compressionCodec = parameters.get("compression").map(CompressionCodecs.getCodecClassName) private val parseMode = parameters.getOrElse("mode", "PERMISSIVE") - val columnNameOfCorruptRecord = parameters.get("columnNameOfCorruptRecord") + val columnNameOfCorruptRecord = + parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord) val timeZone: TimeZone = TimeZone.getTimeZone(parameters.getOrElse("timeZone", defaultTimeZoneId)) @@ -69,6 +79,8 @@ private[sql] class JSONOptions( FastDateFormat.getInstance( parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"), timeZone, Locale.US) + val wholeFile = parameters.get("wholeFile").map(_.toBoolean).getOrElse(false) + // Parse mode flags if (!ParseModes.isValidMode(parseMode)) { logWarning(s"$parseMode is not a valid parse mode. Using ${ParseModes.DEFAULT}.") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 03e27ba934fb..995095969d7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -39,7 +39,6 @@ private[sql] class SparkSQLJsonProcessingException(msg: String) extends RuntimeE */ class JacksonParser( schema: StructType, - columnNameOfCorruptRecord: String, options: JSONOptions) extends Logging { import JacksonUtils._ @@ -48,69 +47,110 @@ class JacksonParser( // A `ValueConverter` is responsible for converting a value from `JsonParser` // to a value in a field for `InternalRow`. 
- private type ValueConverter = (JsonParser) => Any + private type ValueConverter = JsonParser => AnyRef // `ValueConverter`s for the root schema for all fields in the schema - private val rootConverter: ValueConverter = makeRootConverter(schema) + private val rootConverter = makeRootConverter(schema) private val factory = new JsonFactory() options.setJacksonOptions(factory) private val emptyRow: Seq[InternalRow] = Seq(new GenericInternalRow(schema.length)) + private val corruptFieldIndex = schema.getFieldIndex(options.columnNameOfCorruptRecord) + corruptFieldIndex.foreach(idx => require(schema(idx).dataType == StringType)) + + @transient + private[this] var isWarningPrinted: Boolean = false + @transient - private[this] var isWarningPrintedForMalformedRecord: Boolean = false + private def printWarningForMalformedRecord(record: () => UTF8String): Unit = { + def sampleRecord: String = { + if (options.wholeFile) { + "" + } else { + s"Sample record: ${record()}\n" + } + } + + def footer: String = { + s"""Code example to print all malformed records (scala): + |=================================================== + |// The corrupted record exists in column ${options.columnNameOfCorruptRecord}. + |val parsedJson = spark.read.json("/path/to/json/file/test.json") + | + """.stripMargin + } + + if (options.permissive) { + logWarning( + s"""Found at least one malformed record. The JSON reader will replace + |all malformed records with placeholder null in current $PERMISSIVE_MODE parser mode. + |To find out which corrupted records have been replaced with null, please use the + |default inferred schema instead of providing a custom schema. + | + |${sampleRecord ++ footer} + | + """.stripMargin) + } else if (options.dropMalformed) { + logWarning( + s"""Found at least one malformed record. The JSON reader will drop + |all malformed records in current $DROP_MALFORMED_MODE parser mode. To find out which + |corrupted records have been dropped, please switch the parser mode to $PERMISSIVE_MODE + |mode and use the default inferred schema. + | + |${sampleRecord ++ footer} + | + """.stripMargin) + } + } + + @transient + private def printWarningIfWholeFile(): Unit = { + if (options.wholeFile && corruptFieldIndex.isDefined) { + logWarning( + s"""Enabling wholeFile mode and defining columnNameOfCorruptRecord may result + |in very large allocations or OutOfMemoryExceptions being raised. + | + """.stripMargin) + } + } /** * This function deals with the cases it fails to parse. This function will be called * when exceptions are caught during converting. This functions also deals with `mode` option. */ - private def failedRecord(record: String): Seq[InternalRow] = { - // create a row even if no corrupt record column is present - if (options.failFast) { - throw new SparkSQLJsonProcessingException(s"Malformed line in FAILFAST mode: $record") - } - if (options.dropMalformed) { - if (!isWarningPrintedForMalformedRecord) { - logWarning( - s"""Found at least one malformed records (sample: $record). The JSON reader will drop - |all malformed records in current $DROP_MALFORMED_MODE parser mode. To find out which - |corrupted records have been dropped, please switch the parser mode to $PERMISSIVE_MODE - |mode and use the default inferred schema. 
- | - |Code example to print all malformed records (scala): - |=================================================== - |// The corrupted record exists in column ${columnNameOfCorruptRecord} - |val parsedJson = spark.read.json("/path/to/json/file/test.json") - | - """.stripMargin) - isWarningPrintedForMalformedRecord = true - } - Nil - } else if (schema.getFieldIndex(columnNameOfCorruptRecord).isEmpty) { - if (!isWarningPrintedForMalformedRecord) { - logWarning( - s"""Found at least one malformed records (sample: $record). The JSON reader will replace - |all malformed records with placeholder null in current $PERMISSIVE_MODE parser mode. - |To find out which corrupted records have been replaced with null, please use the - |default inferred schema instead of providing a custom schema. - | - |Code example to print all malformed records (scala): - |=================================================== - |// The corrupted record exists in column ${columnNameOfCorruptRecord}. - |val parsedJson = spark.read.json("/path/to/json/file/test.json") - | - """.stripMargin) - isWarningPrintedForMalformedRecord = true - } - emptyRow - } else { - val row = new GenericInternalRow(schema.length) - for (corruptIndex <- schema.getFieldIndex(columnNameOfCorruptRecord)) { - require(schema(corruptIndex).dataType == StringType) - row.update(corruptIndex, UTF8String.fromString(record)) - } - Seq(row) + private def failedRecord(record: () => UTF8String): Seq[InternalRow] = { + corruptFieldIndex match { + case _ if options.failFast => + if (options.wholeFile) { + throw new SparkSQLJsonProcessingException("Malformed line in FAILFAST mode") + } else { + throw new SparkSQLJsonProcessingException(s"Malformed line in FAILFAST mode: ${record()}") + } + + case _ if options.dropMalformed => + if (!isWarningPrinted) { + printWarningForMalformedRecord(record) + isWarningPrinted = true + } + Nil + + case None => + if (!isWarningPrinted) { + printWarningForMalformedRecord(record) + isWarningPrinted = true + } + emptyRow + + case Some(corruptIndex) => + if (!isWarningPrinted) { + printWarningIfWholeFile() + isWarningPrinted = true + } + val row = new GenericInternalRow(schema.length) + row.update(corruptIndex, record()) + Seq(row) } } @@ -119,11 +159,11 @@ class JacksonParser( * to a value according to a desired schema. This is a wrapper for the method * `makeConverter()` to handle a row wrapped with an array. */ - private def makeRootConverter(st: StructType): ValueConverter = { + private def makeRootConverter(st: StructType): JsonParser => Seq[InternalRow] = { val elementConverter = makeConverter(st) - val fieldConverters = st.map(_.dataType).map(makeConverter) - (parser: JsonParser) => parseJsonToken(parser, st) { - case START_OBJECT => convertObject(parser, st, fieldConverters) + val fieldConverters = st.map(_.dataType).map(makeConverter).toArray + (parser: JsonParser) => parseJsonToken[Seq[InternalRow]](parser, st) { + case START_OBJECT => convertObject(parser, st, fieldConverters) :: Nil // SPARK-3308: support reading top level JSON arrays and take every element // in such an array as a row // @@ -137,7 +177,15 @@ class JacksonParser( // List([str_a_1,null]) // List([str_a_2,null], [null,str_b_3]) // - case START_ARRAY => convertArray(parser, elementConverter) + case START_ARRAY => + val array = convertArray(parser, elementConverter) + // Here, as we support reading top level JSON arrays and take every element + // in such an array as a row, this case is possible. 
+ if (array.numElements() == 0) { + Nil + } else { + array.toArray[InternalRow](schema).toSeq + } } } @@ -145,35 +193,35 @@ class JacksonParser( * Create a converter which converts the JSON documents held by the `JsonParser` * to a value according to a desired schema. */ - private[sql] def makeConverter(dataType: DataType): ValueConverter = dataType match { + def makeConverter(dataType: DataType): ValueConverter = dataType match { case BooleanType => - (parser: JsonParser) => parseJsonToken(parser, dataType) { + (parser: JsonParser) => parseJsonToken[java.lang.Boolean](parser, dataType) { case VALUE_TRUE => true case VALUE_FALSE => false } case ByteType => - (parser: JsonParser) => parseJsonToken(parser, dataType) { + (parser: JsonParser) => parseJsonToken[java.lang.Byte](parser, dataType) { case VALUE_NUMBER_INT => parser.getByteValue } case ShortType => - (parser: JsonParser) => parseJsonToken(parser, dataType) { + (parser: JsonParser) => parseJsonToken[java.lang.Short](parser, dataType) { case VALUE_NUMBER_INT => parser.getShortValue } case IntegerType => - (parser: JsonParser) => parseJsonToken(parser, dataType) { + (parser: JsonParser) => parseJsonToken[java.lang.Integer](parser, dataType) { case VALUE_NUMBER_INT => parser.getIntValue } case LongType => - (parser: JsonParser) => parseJsonToken(parser, dataType) { + (parser: JsonParser) => parseJsonToken[java.lang.Long](parser, dataType) { case VALUE_NUMBER_INT => parser.getLongValue } case FloatType => - (parser: JsonParser) => parseJsonToken(parser, dataType) { + (parser: JsonParser) => parseJsonToken[java.lang.Float](parser, dataType) { case VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT => parser.getFloatValue @@ -193,7 +241,7 @@ class JacksonParser( } case DoubleType => - (parser: JsonParser) => parseJsonToken(parser, dataType) { + (parser: JsonParser) => parseJsonToken[java.lang.Double](parser, dataType) { case VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT => parser.getDoubleValue @@ -213,7 +261,7 @@ class JacksonParser( } case StringType => - (parser: JsonParser) => parseJsonToken(parser, dataType) { + (parser: JsonParser) => parseJsonToken[UTF8String](parser, dataType) { case VALUE_STRING => UTF8String.fromString(parser.getText) @@ -227,66 +275,71 @@ class JacksonParser( } case TimestampType => - (parser: JsonParser) => parseJsonToken(parser, dataType) { + (parser: JsonParser) => parseJsonToken[java.lang.Long](parser, dataType) { case VALUE_STRING => + val stringValue = parser.getText // This one will lose microseconds parts. // See https://issues.apache.org/jira/browse/SPARK-10681. - Try(options.timestampFormat.parse(parser.getText).getTime * 1000L) - .getOrElse { - // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards - // compatibility. - DateTimeUtils.stringToTime(parser.getText).getTime * 1000L - } + Long.box { + Try(options.timestampFormat.parse(stringValue).getTime * 1000L) + .getOrElse { + // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards + // compatibility. + DateTimeUtils.stringToTime(stringValue).getTime * 1000L + } + } case VALUE_NUMBER_INT => parser.getLongValue * 1000000L } case DateType => - (parser: JsonParser) => parseJsonToken(parser, dataType) { + (parser: JsonParser) => parseJsonToken[java.lang.Integer](parser, dataType) { case VALUE_STRING => val stringValue = parser.getText // This one will lose microseconds parts. 
// See https://issues.apache.org/jira/browse/SPARK-10681.x - Try(DateTimeUtils.millisToDays(options.dateFormat.parse(parser.getText).getTime)) - .getOrElse { - // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards - // compatibility. - Try(DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(parser.getText).getTime)) + Int.box { + Try(DateTimeUtils.millisToDays(options.dateFormat.parse(stringValue).getTime)) + .orElse { + // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards + // compatibility. + Try(DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(stringValue).getTime)) + } .getOrElse { - // In Spark 1.5.0, we store the data as number of days since epoch in string. - // So, we just convert it to Int. - stringValue.toInt - } + // In Spark 1.5.0, we store the data as number of days since epoch in string. + // So, we just convert it to Int. + stringValue.toInt + } } } case BinaryType => - (parser: JsonParser) => parseJsonToken(parser, dataType) { + (parser: JsonParser) => parseJsonToken[Array[Byte]](parser, dataType) { case VALUE_STRING => parser.getBinaryValue } case dt: DecimalType => - (parser: JsonParser) => parseJsonToken(parser, dataType) { + (parser: JsonParser) => parseJsonToken[Decimal](parser, dataType) { case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) => Decimal(parser.getDecimalValue, dt.precision, dt.scale) } case st: StructType => - val fieldConverters = st.map(_.dataType).map(makeConverter) - (parser: JsonParser) => parseJsonToken(parser, dataType) { + val fieldConverters = st.map(_.dataType).map(makeConverter).toArray + (parser: JsonParser) => parseJsonToken[InternalRow](parser, dataType) { case START_OBJECT => convertObject(parser, st, fieldConverters) } case at: ArrayType => val elementConverter = makeConverter(at.elementType) - (parser: JsonParser) => parseJsonToken(parser, dataType) { + (parser: JsonParser) => parseJsonToken[ArrayData](parser, dataType) { case START_ARRAY => convertArray(parser, elementConverter) } case mt: MapType => val valueConverter = makeConverter(mt.valueType) - (parser: JsonParser) => parseJsonToken(parser, dataType) { + (parser: JsonParser) => parseJsonToken[MapData](parser, dataType) { case START_OBJECT => convertMap(parser, valueConverter) } @@ -298,7 +351,7 @@ class JacksonParser( // Here, we pass empty `PartialFunction` so that this case can be // handled as a failed conversion. It will throw an exception as // long as the value is not null. - parseJsonToken(parser, dataType)(PartialFunction.empty[JsonToken, Any]) + parseJsonToken[AnyRef](parser, dataType)(PartialFunction.empty[JsonToken, AnyRef]) } /** @@ -306,14 +359,14 @@ class JacksonParser( * to parse the JSON token using given function `f`. If the `f` failed to parse and convert the * token, call `failedConversion` to handle the token. */ - private def parseJsonToken( + private def parseJsonToken[R >: Null]( parser: JsonParser, - dataType: DataType)(f: PartialFunction[JsonToken, Any]): Any = { + dataType: DataType)(f: PartialFunction[JsonToken, R]): R = { parser.getCurrentToken match { case FIELD_NAME => // There are useless FIELD_NAMEs between START_OBJECT and END_OBJECT tokens parser.nextToken() - parseJsonToken(parser, dataType)(f) + parseJsonToken[R](parser, dataType)(f) case null | VALUE_NULL => null @@ -325,9 +378,9 @@ class JacksonParser( * This function throws an exception for failed conversion, but returns null for empty string, * to guard the non string types. 
*/ - private def failedConversion( + private def failedConversion[R >: Null]( parser: JsonParser, - dataType: DataType): PartialFunction[JsonToken, Any] = { + dataType: DataType): PartialFunction[JsonToken, R] = { case VALUE_STRING if parser.getTextLength < 1 => // If conversion is failed, this produces `null` rather than throwing exception. // This will protect the mismatch of types. @@ -348,7 +401,7 @@ class JacksonParser( private def convertObject( parser: JsonParser, schema: StructType, - fieldConverters: Seq[ValueConverter]): InternalRow = { + fieldConverters: Array[ValueConverter]): InternalRow = { val row = new GenericInternalRow(schema.length) while (nextUntil(parser, JsonToken.END_OBJECT)) { schema.getFieldIndex(parser.getCurrentName) match { @@ -394,36 +447,30 @@ class JacksonParser( } /** - * Parse the string JSON input to the set of [[InternalRow]]s. + * Parse the JSON input to the set of [[InternalRow]]s. + * + * @param recordLiteral an optional function that will be used to generate + * the corrupt record text instead of record.toString */ - def parse(input: String): Seq[InternalRow] = { - if (input.trim.isEmpty) { - Nil - } else { - try { - Utils.tryWithResource(factory.createParser(input)) { parser => - parser.nextToken() - rootConverter.apply(parser) match { - case null => failedRecord(input) - case row: InternalRow => row :: Nil - case array: ArrayData => - // Here, as we support reading top level JSON arrays and take every element - // in such an array as a row, this case is possible. - if (array.numElements() == 0) { - Nil - } else { - array.toArray[InternalRow](schema) - } - case _ => - failedRecord(input) + def parse[T]( + record: T, + createParser: (JsonFactory, T) => JsonParser, + recordLiteral: T => UTF8String): Seq[InternalRow] = { + try { + Utils.tryWithResource(createParser(factory, record)) { parser => + // a null first token is equivalent to testing for input.trim.isEmpty + // but it works on any token stream and not just strings + parser.nextToken() match { + case null => Nil + case _ => rootConverter.apply(parser) match { + case null => throw new SparkSQLJsonProcessingException("Root converter returned null") + case rows => rows } } - } catch { - case _: JsonProcessingException => - failedRecord(input) - case _: SparkSQLJsonProcessingException => - failedRecord(input) } + } catch { + case _: JsonProcessingException | _: SparkSQLJsonProcessingException => + failedRecord(() => recordLiteral(record)) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 780fe51ac699..cb9493a57564 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -26,14 +26,14 @@ import org.apache.spark.internal.Logging import org.apache.spark.Partition import org.apache.spark.annotation.InterfaceStability import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.json.{JacksonParser, JSONOptions} -import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} import org.apache.spark.sql.execution.LogicalRDD import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.JsonInferSchema import 
org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String /** * Interface used to load a [[Dataset]] from external storage systems (e.g. file systems, @@ -261,8 +261,10 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } /** - * Loads a JSON file (JSON Lines text format or - * newline-delimited JSON) and returns the result as a `DataFrame`. + * Loads a JSON file and returns the results as a `DataFrame`. + * + * Both JSON (one record per file) and JSON Lines + * (newline-delimited JSON) are supported and can be selected with the `wholeFile` option. * * This function goes through the input once to determine the input schema. If you know the * schema in advance, use the version that specifies the schema to avoid the extra scan. @@ -301,6 +303,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * `java.text.SimpleDateFormat`. This applies to timestamp type. *
* `timeZone` (default session local timezone): sets the string that indicates a timezone * to be used to parse timestamps. + * `wholeFile` (default `false`): parse one record, which may span multiple lines, + * per file
  • * * * @since 2.0.0 @@ -332,20 +336,22 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * @since 1.4.0 */ def json(jsonRDD: RDD[String]): DataFrame = { - val parsedOptions: JSONOptions = - new JSONOptions(extraOptions.toMap, sparkSession.sessionState.conf.sessionLocalTimeZone) - val columnNameOfCorruptRecord = - parsedOptions.columnNameOfCorruptRecord - .getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord) + val parsedOptions = new JSONOptions( + extraOptions.toMap, + sparkSession.sessionState.conf.sessionLocalTimeZone, + sparkSession.sessionState.conf.columnNameOfCorruptRecord) + val createParser = CreateJacksonParser.string _ + val schema = userSpecifiedSchema.getOrElse { JsonInferSchema.infer( jsonRDD, - columnNameOfCorruptRecord, - parsedOptions) + parsedOptions, + createParser) } + val parsed = jsonRDD.mapPartitions { iter => - val parser = new JacksonParser(schema, columnNameOfCorruptRecord, parsedOptions) - iter.flatMap(parser.parse) + val parser = new JacksonParser(schema, parsedOptions) + iter.flatMap(parser.parse(_, createParser, UTF8String.fromString)) } Dataset.ofRows( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala index 900263aeb21d..0762d1b7daae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.execution.datasources -import java.io.{OutputStream, OutputStreamWriter} +import java.io.{InputStream, OutputStream, OutputStreamWriter} import java.nio.charset.{Charset, StandardCharsets} +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.io.compress._ import org.apache.hadoop.mapreduce.JobContext @@ -27,6 +28,20 @@ import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat import org.apache.hadoop.util.ReflectionUtils object CodecStreams { + private def getDecompressionCodec(config: Configuration, file: Path): Option[CompressionCodec] = { + val compressionCodecs = new CompressionCodecFactory(config) + Option(compressionCodecs.getCodec(file)) + } + + def createInputStream(config: Configuration, file: Path): InputStream = { + val fs = file.getFileSystem(config) + val inputStream: InputStream = fs.open(file) + + getDecompressionCodec(config, file) + .map(codec => codec.createInputStream(inputStream)) + .getOrElse(inputStream) + } + private def getCompressionCodec( context: JobContext, file: Option[Path] = None): Option[CompressionCodec] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala new file mode 100644 index 000000000000..3e984effcb8d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.json + +import java.io.InputStream + +import scala.reflect.ClassTag + +import com.fasterxml.jackson.core.{JsonFactory, JsonParser} +import com.google.common.io.ByteStreams +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.io.{LongWritable, Text} +import org.apache.hadoop.mapreduce.Job +import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, TextInputFormat} + +import org.apache.spark.TaskContext +import org.apache.spark.input.{PortableDataStream, StreamInputFormat} +import org.apache.spark.rdd.{BinaryFileRDD, RDD} +import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} +import org.apache.spark.sql.execution.datasources.{CodecStreams, HadoopFileLinesReader, PartitionedFile} +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils + +/** + * Common functions for parsing JSON files + * @tparam T A datatype containing the unparsed JSON, such as [[Text]] or [[String]] + */ +abstract class JsonDataSource[T] extends Serializable { + def isSplitable: Boolean + + /** + * Parse a [[PartitionedFile]] into 0 or more [[InternalRow]] instances + */ + def readFile( + conf: Configuration, + file: PartitionedFile, + parser: JacksonParser): Iterator[InternalRow] + + /** + * Create an [[RDD]] that handles the preliminary parsing of [[T]] records + */ + protected def createBaseRdd( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus]): RDD[T] + + /** + * A generic wrapper to invoke the correct [[JsonFactory]] method to allocate a [[JsonParser]] + * for an instance of [[T]] + */ + def createParser(jsonFactory: JsonFactory, value: T): JsonParser + + final def infer( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus], + parsedOptions: JSONOptions): Option[StructType] = { + if (inputPaths.nonEmpty) { + val jsonSchema = JsonInferSchema.infer( + createBaseRdd(sparkSession, inputPaths), + parsedOptions, + createParser) + checkConstraints(jsonSchema) + Some(jsonSchema) + } else { + None + } + } + + /** Constraints to be imposed on schema to be stored. */ + private def checkConstraints(schema: StructType): Unit = { + if (schema.fieldNames.length != schema.fieldNames.distinct.length) { + val duplicateColumns = schema.fieldNames.groupBy(identity).collect { + case (x, ys) if ys.length > 1 => "\"" + x + "\"" + }.mkString(", ") + throw new AnalysisException(s"Duplicate column(s) : $duplicateColumns found, " + + s"cannot save to JSON format") + } + } +} + +object JsonDataSource { + def apply(options: JSONOptions): JsonDataSource[_] = { + if (options.wholeFile) { + WholeFileJsonDataSource + } else { + TextInputJsonDataSource + } + } + + /** + * Create a new [[RDD]] via the supplied callback if there is at least one file to process, + * otherwise an [[org.apache.spark.rdd.EmptyRDD]] will be returned. 
+ */ + def createBaseRdd[T : ClassTag]( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus])( + fn: (Configuration, String) => RDD[T]): RDD[T] = { + val paths = inputPaths.map(_.getPath) + + if (paths.nonEmpty) { + val job = Job.getInstance(sparkSession.sessionState.newHadoopConf()) + FileInputFormat.setInputPaths(job, paths: _*) + fn(job.getConfiguration, paths.mkString(",")) + } else { + sparkSession.sparkContext.emptyRDD[T] + } + } +} + +object TextInputJsonDataSource extends JsonDataSource[Text] { + override val isSplitable: Boolean = { + // splittable if the underlying source is + true + } + + override protected def createBaseRdd( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus]): RDD[Text] = { + JsonDataSource.createBaseRdd(sparkSession, inputPaths) { + case (conf, name) => + sparkSession.sparkContext.newAPIHadoopRDD( + conf, + classOf[TextInputFormat], + classOf[LongWritable], + classOf[Text]) + .setName(s"JsonLines: $name") + .values // get the text column + } + } + + override def readFile( + conf: Configuration, + file: PartitionedFile, + parser: JacksonParser): Iterator[InternalRow] = { + val linesReader = new HadoopFileLinesReader(file, conf) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) + linesReader.flatMap(parser.parse(_, createParser, textToUTF8String)) + } + + private def textToUTF8String(value: Text): UTF8String = { + UTF8String.fromBytes(value.getBytes, 0, value.getLength) + } + + override def createParser(jsonFactory: JsonFactory, value: Text): JsonParser = { + CreateJacksonParser.text(jsonFactory, value) + } +} + +object WholeFileJsonDataSource extends JsonDataSource[PortableDataStream] { + override val isSplitable: Boolean = { + false + } + + override protected def createBaseRdd( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus]): RDD[PortableDataStream] = { + JsonDataSource.createBaseRdd(sparkSession, inputPaths) { + case (conf, name) => + new BinaryFileRDD( + sparkSession.sparkContext, + classOf[StreamInputFormat], + classOf[String], + classOf[PortableDataStream], + conf, + sparkSession.sparkContext.defaultMinPartitions) + .setName(s"JsonFile: $name") + .values + } + } + + private def createInputStream(config: Configuration, path: String): InputStream = { + val inputStream = CodecStreams.createInputStream(config, new Path(path)) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => inputStream.close())) + inputStream + } + + override def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = { + CreateJacksonParser.inputStream( + jsonFactory, + createInputStream(record.getConfiguration, record.getPath())) + } + + override def readFile( + conf: Configuration, + file: PartitionedFile, + parser: JacksonParser): Iterator[InternalRow] = { + def partitionedFileString(ignored: Any): UTF8String = { + Utils.tryWithResource(createInputStream(conf, file.filePath)) { inputStream => + UTF8String.fromBytes(ByteStreams.toByteArray(inputStream)) + } + } + + parser.parse( + createInputStream(conf, file.filePath), + CreateJacksonParser.inputStream, + partitionedFileString).toIterator + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index b4a8ff2cf01a..2cbf4ea7beac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -19,15 +19,10 @@ package org.apache.spark.sql.execution.datasources.json import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.hadoop.io.{LongWritable, Text} -import org.apache.hadoop.mapred.{JobConf, TextInputFormat} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} -import org.apache.hadoop.mapreduce.lib.input.FileInputFormat -import org.apache.spark.TaskContext import org.apache.spark.internal.Logging -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{AnalysisException, Row, SparkSession} +import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSONOptions} import org.apache.spark.sql.catalyst.util.CompressionCodecs @@ -37,29 +32,30 @@ import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { + override val shortName: String = "json" - override def shortName(): String = "json" + override def isSplitable( + sparkSession: SparkSession, + options: Map[String, String], + path: Path): Boolean = { + val parsedOptions = new JSONOptions( + options, + sparkSession.sessionState.conf.sessionLocalTimeZone, + sparkSession.sessionState.conf.columnNameOfCorruptRecord) + val jsonDataSource = JsonDataSource(parsedOptions) + jsonDataSource.isSplitable && super.isSplitable(sparkSession, options, path) + } override def inferSchema( sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { - if (files.isEmpty) { - None - } else { - val parsedOptions: JSONOptions = - new JSONOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) - val columnNameOfCorruptRecord = - parsedOptions.columnNameOfCorruptRecord - .getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord) - val jsonSchema = JsonInferSchema.infer( - createBaseRdd(sparkSession, files), - columnNameOfCorruptRecord, - parsedOptions) - checkConstraints(jsonSchema) - - Some(jsonSchema) - } + val parsedOptions = new JSONOptions( + options, + sparkSession.sessionState.conf.sessionLocalTimeZone, + sparkSession.sessionState.conf.columnNameOfCorruptRecord) + JsonDataSource(parsedOptions).infer( + sparkSession, files, parsedOptions) } override def prepareWrite( @@ -68,8 +64,10 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { val conf = job.getConfiguration - val parsedOptions: JSONOptions = - new JSONOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) + val parsedOptions = new JSONOptions( + options, + sparkSession.sessionState.conf.sessionLocalTimeZone, + sparkSession.sessionState.conf.columnNameOfCorruptRecord) parsedOptions.compressionCodec.foreach { codec => CompressionCodecs.setCodecConfiguration(conf, codec) } @@ -99,47 +97,17 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) - val parsedOptions: JSONOptions = - new JSONOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) - val columnNameOfCorruptRecord = parsedOptions.columnNameOfCorruptRecord - 
.getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord) + val parsedOptions = new JSONOptions( + options, + sparkSession.sessionState.conf.sessionLocalTimeZone, + sparkSession.sessionState.conf.columnNameOfCorruptRecord) (file: PartitionedFile) => { - val linesReader = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) - val lines = linesReader.map(_.toString) - val parser = new JacksonParser(requiredSchema, columnNameOfCorruptRecord, parsedOptions) - lines.flatMap(parser.parse) - } - } - - private def createBaseRdd( - sparkSession: SparkSession, - inputPaths: Seq[FileStatus]): RDD[String] = { - val job = Job.getInstance(sparkSession.sessionState.newHadoopConf()) - val conf = job.getConfiguration - - val paths = inputPaths.map(_.getPath) - - if (paths.nonEmpty) { - FileInputFormat.setInputPaths(job, paths: _*) - } - - sparkSession.sparkContext.hadoopRDD( - conf.asInstanceOf[JobConf], - classOf[TextInputFormat], - classOf[LongWritable], - classOf[Text]).map(_._2.toString) // get the text line - } - - /** Constraints to be imposed on schema to be stored. */ - private def checkConstraints(schema: StructType): Unit = { - if (schema.fieldNames.length != schema.fieldNames.distinct.length) { - val duplicateColumns = schema.fieldNames.groupBy(identity).collect { - case (x, ys) if ys.length > 1 => "\"" + x + "\"" - }.mkString(", ") - throw new AnalysisException(s"Duplicate column(s) : $duplicateColumns found, " + - s"cannot save to JSON format") + val parser = new JacksonParser(requiredSchema, parsedOptions) + JsonDataSource(parsedOptions).readFile( + broadcastedHadoopConf.value.value, + file, + parser) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index f51c18d46f45..ab09358115c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -36,13 +36,14 @@ private[sql] object JsonInferSchema { * 2. Merge types by choosing the lowest type necessary to cover equal keys * 3. 
Replace any remaining null fields with string, the top type */ - def infer( - json: RDD[String], - columnNameOfCorruptRecord: String, - configOptions: JSONOptions): StructType = { + def infer[T]( + json: RDD[T], + configOptions: JSONOptions, + createParser: (JsonFactory, T) => JsonParser): StructType = { require(configOptions.samplingRatio > 0, s"samplingRatio (${configOptions.samplingRatio}) should be greater than 0") val shouldHandleCorruptRecord = configOptions.permissive + val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord val schemaData = if (configOptions.samplingRatio > 0.99) { json } else { @@ -55,7 +56,7 @@ private[sql] object JsonInferSchema { configOptions.setJacksonOptions(factory) iter.flatMap { row => try { - Utils.tryWithResource(factory.createParser(row)) { parser => + Utils.tryWithResource(createParser(factory, row)) { parser => parser.nextToken() Some(inferField(parser, configOptions)) } @@ -79,7 +80,7 @@ private[sql] object JsonInferSchema { private[this] val structFieldComparator = new Comparator[StructField] { override def compare(o1: StructField, o2: StructField): Int = { - o1.name.compare(o2.name) + o1.name.compareTo(o2.name) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 4e706da184c0..99943944f3c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -141,8 +141,10 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo } /** - * Loads a JSON file stream (JSON Lines text format or - * newline-delimited JSON) and returns the result as a `DataFrame`. + * Loads a JSON file stream and returns the results as a `DataFrame`. + * + * Both JSON (one record per file) and JSON Lines + * (newline-delimited JSON) are supported and can be selected with the `wholeFile` option. * * This function goes through the input once to determine the input schema. If you know the * schema in advance, use the version that specifies the schema to avoid the extra scan. @@ -183,6 +185,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * `java.text.SimpleDateFormat`. This applies to timestamp type. *
* `timeZone` (default session local timezone): sets the string that indicates a timezone * to be used to parse timestamps. + * `wholeFile` (default `false`): parse one record, which may span multiple lines, + * per file
  • * * * @since 2.0.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 9344aeda0017..05aa2ab2ce2d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -28,8 +28,8 @@ import org.apache.hadoop.io.compress.GzipCodec import org.apache.spark.rdd.RDD import org.apache.spark.SparkException -import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.json.{JacksonParser, JSONOptions} +import org.apache.spark.sql.{functions => F, _} +import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.json.JsonInferSchema.compatibleType @@ -64,7 +64,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val dummyOption = new JSONOptions(Map.empty[String, String], "GMT") val dummySchema = StructType(Seq.empty) - val parser = new JacksonParser(dummySchema, "", dummyOption) + val parser = new JacksonParser(dummySchema, dummyOption) Utils.tryWithResource(factory.createParser(writer.toString)) { jsonParser => jsonParser.nextToken() @@ -1367,7 +1367,9 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("SPARK-6245 JsonRDD.inferSchema on empty RDD") { // This is really a test that it doesn't throw an exception val emptySchema = JsonInferSchema.infer( - empty, "", new JSONOptions(Map.empty[String, String], "GMT")) + empty, + new JSONOptions(Map.empty[String, String], "GMT"), + CreateJacksonParser.string) assert(StructType(Seq()) === emptySchema) } @@ -1392,7 +1394,9 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("SPARK-8093 Erase empty structs") { val emptySchema = JsonInferSchema.infer( - emptyRecords, "", new JSONOptions(Map.empty[String, String], "GMT")) + emptyRecords, + new JSONOptions(Map.empty[String, String], "GMT"), + CreateJacksonParser.string) assert(StructType(Seq()) === emptySchema) } @@ -1802,4 +1806,142 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val df2 = spark.read.option("PREfersdecimaL", "true").json(records) assert(df2.schema == schema) } + + test("SPARK-18352: Parse normal multi-line JSON files (compressed)") { + withTempPath { dir => + val path = dir.getCanonicalPath + primitiveFieldAndType + .toDF("value") + .write + .option("compression", "GzIp") + .text(path) + + assert(new File(path).listFiles().exists(_.getName.endsWith(".gz"))) + + val jsonDF = spark.read.option("wholeFile", true).json(path) + val jsonDir = new File(dir, "json").getCanonicalPath + jsonDF.coalesce(1).write + .option("compression", "gZiP") + .json(jsonDir) + + assert(new File(jsonDir).listFiles().exists(_.getName.endsWith(".json.gz"))) + + val originalData = spark.read.json(primitiveFieldAndType) + checkAnswer(jsonDF, originalData) + checkAnswer(spark.read.schema(originalData.schema).json(jsonDir), originalData) + } + } + + test("SPARK-18352: Parse normal multi-line JSON files (uncompressed)") { + withTempPath { dir => + val path = dir.getCanonicalPath + primitiveFieldAndType + .toDF("value") + .write + .text(path) + + val jsonDF = spark.read.option("wholeFile", true).json(path) + val jsonDir = new 
File(dir, "json").getCanonicalPath + jsonDF.coalesce(1).write.json(jsonDir) + + val compressedFiles = new File(jsonDir).listFiles() + assert(compressedFiles.exists(_.getName.endsWith(".json"))) + + val originalData = spark.read.json(primitiveFieldAndType) + checkAnswer(jsonDF, originalData) + checkAnswer(spark.read.schema(originalData.schema).json(jsonDir), originalData) + } + } + + test("SPARK-18352: Expect one JSON document per file") { + // the json parser terminates as soon as it sees a matching END_OBJECT or END_ARRAY token. + // this might not be the optimal behavior but this test verifies that only the first value + // is parsed and the rest are discarded. + + // alternatively the parser could continue parsing following objects, which may further reduce + // allocations by skipping the line reader entirely + + withTempPath { dir => + val path = dir.getCanonicalPath + spark + .createDataFrame(Seq(Tuple1("{}{invalid}"))) + .coalesce(1) + .write + .text(path) + + val jsonDF = spark.read.option("wholeFile", true).json(path) + // no corrupt record column should be created + assert(jsonDF.schema === StructType(Seq())) + // only the first object should be read + assert(jsonDF.count() === 1) + } + } + + test("SPARK-18352: Handle multi-line corrupt documents (PERMISSIVE)") { + withTempPath { dir => + val path = dir.getCanonicalPath + val corruptRecordCount = additionalCorruptRecords.count().toInt + assert(corruptRecordCount === 5) + + additionalCorruptRecords + .toDF("value") + // this is the minimum partition count that avoids hash collisions + .repartition(corruptRecordCount * 4, F.hash($"value")) + .write + .text(path) + + val jsonDF = spark.read.option("wholeFile", true).option("mode", "PERMISSIVE").json(path) + assert(jsonDF.count() === corruptRecordCount) + assert(jsonDF.schema === new StructType() + .add("_corrupt_record", StringType) + .add("dummy", StringType)) + val counts = jsonDF + .join( + additionalCorruptRecords.toDF("value"), + F.regexp_replace($"_corrupt_record", "(^\\s+|\\s+$)", "") === F.trim($"value"), + "outer") + .agg( + F.count($"dummy").as("valid"), + F.count($"_corrupt_record").as("corrupt"), + F.count("*").as("count")) + checkAnswer(counts, Row(1, 4, 6)) + } + } + + test("SPARK-18352: Handle multi-line corrupt documents (FAILFAST)") { + withTempPath { dir => + val path = dir.getCanonicalPath + val corruptRecordCount = additionalCorruptRecords.count().toInt + assert(corruptRecordCount === 5) + + additionalCorruptRecords + .toDF("value") + // this is the minimum partition count that avoids hash collisions + .repartition(corruptRecordCount * 4, F.hash($"value")) + .write + .text(path) + + val schema = new StructType().add("dummy", StringType) + + // `FAILFAST` mode should throw an exception for corrupt records. + val exceptionOne = intercept[SparkException] { + spark.read + .option("wholeFile", true) + .option("mode", "FAILFAST") + .json(path) + .collect() + } + assert(exceptionOne.getMessage.contains("Malformed line in FAILFAST mode")) + + val exceptionTwo = intercept[SparkException] { + spark.read + .option("wholeFile", true) + .option("mode", "FAILFAST") + .schema(schema) + .json(path) + .collect() + } + assert(exceptionTwo.getMessage.contains("Malformed line in FAILFAST mode")) + } + } }