[SPARK-18352][SQL] Support parsing multiline json files

Changes from all commits (23 commits):
04c8cde  [SPARK-18352][SQL] Support parsing multiline json files (Dec 22, 2016)
86e3cd2  JacksonParser.parseJsonToken should be explicit about nulls and boxing (Dec 22, 2016)
f48fc66  Increase type safety of makeRootConverter, remove runtime type tests (Dec 21, 2016)
361df08  Field converter lookups should be O(1) (Dec 23, 2016)
06169f7  Support inference of tables with no columns (Dec 27, 2016)
9b8a265  Improve failedRecord consistency with and without wholeFile mode enabled (Dec 29, 2016)
42bee5e  Use withTempPath instead of withTempDir (Feb 8, 2017)
b17809e  Simplify the corrupt document test (Feb 8, 2017)
422f9b0  Always return a ByteBuffer from getByteBuffer (Feb 8, 2017)
042d746  Add @Since annotations for all PortableDataStream public methods (Feb 9, 2017)
e1a620a  Very verbosely test to see if a warning has already been printed (Feb 10, 2017)
3e67473  Eagerly validate the corrupt column datatype (Feb 11, 2017)
8ed787f  Avoid broadcasting JsonDataSource references (Feb 11, 2017)
14d5f93  Remove name binding (Feb 11, 2017)
5f5214b  Repartition by value instead of luck (Feb 11, 2017)
7296f7e  Always provide a `T => UTF8String` conversion function (Feb 15, 2017)
691fa2a  Reorder documentation to match the function parameter order (Feb 16, 2017)
a629470  Fix build break in Python due to bad rebase (Feb 16, 2017)
463062a  Fix constructor invocation style (Feb 16, 2017)
24786d1  Check wholeFile roundtrip against all columns and the source RDD (Feb 17, 2017)
e323317  Add tests for FAILFAST mode (Feb 17, 2017)
58118f2  More style fixes (Feb 17, 2017)
b801ab0  Missed one Javadoc parameter reordering (Feb 17, 2017)
common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java

@@ -147,7 +147,13 @@ public void writeTo(ByteBuffer buffer) {
buffer.position(pos + numBytes);
}

- public void writeTo(OutputStream out) throws IOException {
+ /**
+  * Returns a {@link ByteBuffer} wrapping the base object if it is a byte array
+  * or a copy of the data if the base object is not a byte array.
+  *
+  * Unlike getBytes this will not create a copy of the array if this is a slice.
+  */
+ public @Nonnull ByteBuffer getByteBuffer() {
if (base instanceof byte[] && offset >= BYTE_ARRAY_OFFSET) {
final byte[] bytes = (byte[]) base;
@@ -160,12 +166,20 @@ public void writeTo(OutputStream out) throws IOException {
throw new ArrayIndexOutOfBoundsException();
}

- out.write(bytes, (int) arrayOffset, numBytes);
+ return ByteBuffer.wrap(bytes, (int) arrayOffset, numBytes);
} else {
- out.write(getBytes());
+ return ByteBuffer.wrap(getBytes());
}
}

+ public void writeTo(OutputStream out) throws IOException {
+   final ByteBuffer bb = this.getByteBuffer();
+   assert(bb.hasArray());
+
+   // similar to Utils.writeByteBuffer but without the spark-core dependency
+   out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining());
+ }

/**
* Returns the number of bytes for a code point with the first byte as `b`
* @param b The first byte of a code point
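For context, a minimal Scala sketch of the zero-copy path this change enables (fromString allocates a fresh backing array, so getByteBuffer can wrap it without copying):

import java.io.ByteArrayOutputStream
import org.apache.spark.unsafe.types.UTF8String

val s = UTF8String.fromString("hello")
val bb = s.getByteBuffer  // wraps the backing byte array; no copy for slices
assert(bb.hasArray)

val out = new ByteArrayOutputStream()
out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
assert(out.toString("UTF-8") == "hello")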
core/src/main/scala/org/apache/spark/input/PortableDataStream.scala

@@ -29,6 +29,7 @@ import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat, CombineFil

import org.apache.spark.internal.config
import org.apache.spark.SparkContext
+ import org.apache.spark.annotation.Since

/**
* A general format for reading whole files in as streams, byte arrays,
@@ -175,6 +176,7 @@ class PortableDataStream(
* Create a new DataInputStream from the split and context. The user of this method is responsible
* for closing the stream after usage.
*/
+ @Since("1.2.0")
def open(): DataInputStream = {
val pathp = split.getPath(index)
val fs = pathp.getFileSystem(conf)
@@ -184,6 +186,7 @@
/**
* Read the file as a byte array
*/
+ @Since("1.2.0")
def toArray(): Array[Byte] = {
val stream = open()
try {
@@ -193,6 +196,10 @@
}
}

+ @Since("1.2.0")
def getPath(): String = path

+ @Since("2.2.0")
+ def getConfiguration: Configuration = conf
Reviewer comment (Contributor): nit: we should rename it to getConf; getConfiguration is too verbose.
}
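As a usage note, PortableDataStream instances come from SparkContext.binaryFiles; a short Scala sketch (input path hypothetical) exercising the annotated methods:

// each element pairs a file name with a lazily opened stream
val streams = sc.binaryFiles("hdfs:///data/blobs")  // RDD[(String, PortableDataStream)]
streams.foreach { case (_, pds) =>
  val hadoopConf = pds.getConfiguration  // new in 2.2.0
  val in = pds.open()  // the caller is responsible for closing
  try {
    println(s"${pds.getPath()} starts with byte ${in.read()}")
  } finally {
    in.close()
  }
}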

13 changes: 8 additions & 5 deletions python/pyspark/sql/readwriter.py
@@ -159,11 +159,12 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None,
allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None,
mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None,
- timeZone=None):
+ timeZone=None, wholeFile=None):
"""
- Loads a JSON file (`JSON Lines text format or newline-delimited JSON
- <http://jsonlines.org/>`_) or an RDD of Strings storing JSON objects (one object per
- record) and returns the result as a :class`DataFrame`.
+ Loads a JSON file and returns the results as a :class:`DataFrame`.
+
+ Both JSON (one record per file) and `JSON Lines <http://jsonlines.org/>`_
+ (newline-delimited JSON) are supported and can be selected with the `wholeFile` parameter.

If the ``schema`` parameter is not specified, this function goes
through the input once to determine the input schema.
@@ -212,6 +213,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``.
:param timeZone: sets the string that indicates a timezone to be used to parse timestamps.
                 If None is set, it uses the default value, session local timezone.
+ :param wholeFile: parse one record, which may span multiple lines, per file. If None is
+                   set, it uses the default value, ``false``.

>>> df1 = spark.read.json('python/test_support/sql/people.json')
>>> df1.dtypes
@@ -228,7 +231,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero,
allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter,
mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat,
- timestampFormat=timestampFormat, timeZone=timeZone)
+ timestampFormat=timestampFormat, timeZone=timeZone, wholeFile=wholeFile)
if isinstance(path, basestring):
path = [path]
if type(path) == list:
14 changes: 9 additions & 5 deletions python/pyspark/sql/streaming.py
@@ -428,11 +428,13 @@ def load(self, path=None, format=None, schema=None, **options):
def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None,
allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None,
- mode=None, columnNameOfCorruptRecord=None, dateFormat=None,
- timestampFormat=None, timeZone=None):
+ mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None,
+ timeZone=None, wholeFile=None):
"""
- Loads a JSON file stream (`JSON Lines text format or newline-delimited JSON
- <http://jsonlines.org/>`_) and returns a :class`DataFrame`.
+ Loads a JSON file stream and returns the results as a :class:`DataFrame`.
+
+ Both JSON (one record per file) and `JSON Lines <http://jsonlines.org/>`_
+ (newline-delimited JSON) are supported and can be selected with the `wholeFile` parameter.

If the ``schema`` parameter is not specified, this function goes
through the input once to determine the input schema.
@@ -483,6 +485,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``.
:param timeZone: sets the string that indicates a timezone to be used to parse timestamps.
                 If None is set, it uses the default value, session local timezone.
+ :param wholeFile: parse one record, which may span multiple lines, per file. If None is
+                   set, it uses the default value, ``false``.

>>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema = sdf_schema)
>>> json_sdf.isStreaming
@@ -496,7 +500,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero,
allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter,
mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat,
- timestampFormat=timestampFormat, timeZone=timeZone)
+ timestampFormat=timestampFormat, timeZone=timeZone, wholeFile=wholeFile)
if isinstance(path, basestring):
return self._df(self._jreader.json(path))
else:
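The streaming reader accepts the same option; a Scala sketch (input directory hypothetical; streaming sources must be given a schema up front):

import org.apache.spark.sql.types.{LongType, StringType, StructType}

// streaming JSON sources cannot infer a schema, so declare one
val schema = new StructType().add("name", StringType).add("age", LongType)

val people = spark.readStream
  .schema(schema)
  .option("wholeFile", true)
  .json("/path/to/json/dir")

assert(people.isStreaming)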
7 changes: 7 additions & 0 deletions python/pyspark/sql/tests.py
@@ -439,6 +439,13 @@ def test_udf_with_order_by_and_limit(self):
res.explain(True)
self.assertEqual(res.collect(), [Row(id=0, copy=0)])

+ def test_wholefile_json(self):
+     from pyspark.sql.types import StringType
+     people1 = self.spark.read.json("python/test_support/sql/people.json")
+     people_array = self.spark.read.json("python/test_support/sql/people_array.json",
+                                         wholeFile=True)
+     self.assertEqual(people1.collect(), people_array.collect())
+
def test_udf_with_input_file_name(self):
from pyspark.sql.functions import udf, input_file_name
from pyspark.sql.types import StringType
13 changes: 13 additions & 0 deletions python/test_support/sql/people_array.json
@@ -0,0 +1,13 @@
[
{
"name": "Michael"
},
{
"name": "Andy",
"age": 30
},
{
"name": "Justin",
"age": 19
}
]
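For reference, the Scala-API equivalent of the new Python test, reading the array file above as a single multi-line record (expected output sketched from the file contents; inferred fields sort alphabetically):

val people = spark.read
  .option("wholeFile", true)
  .json("python/test_support/sql/people_array.json")

people.show()
// +----+-------+
// | age|   name|
// +----+-------+
// |null|Michael|
// |  30|   Andy|
// |  19| Justin|
// +----+-------+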
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala

@@ -497,16 +497,20 @@ case class JsonToStruct(
lazy val parser =
  new JacksonParser(
    schema,
-     "invalid", // Not used since we force fail fast. Invalid rows will be set to `null`.
-     new JSONOptions(options ++ Map("mode" -> ParseModes.FAIL_FAST_MODE), timeZoneId.get))
+     new JSONOptions(options + ("mode" -> ParseModes.FAIL_FAST_MODE), timeZoneId.get))

override def dataType: DataType = schema

override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Option(timeZoneId))

override def nullSafeEval(json: Any): Any = {
- try parser.parse(json.toString).headOption.orNull catch {
+ try {
+   parser.parse(
+     json.asInstanceOf[UTF8String],
+     CreateJacksonParser.utf8String,
+     identity[UTF8String]).headOption.orNull
+ } catch {
case _: SparkSQLJsonProcessingException => null
}
}
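JsonToStruct is the expression behind the from_json function; a minimal Scala usage sketch (column name assumed):

import org.apache.spark.sql.functions.from_json
import org.apache.spark.sql.types.{IntegerType, StructType}
import spark.implicits._  // assumes an active SparkSession named `spark`

val schema = new StructType().add("a", IntegerType)
val df = Seq("""{"a": 1}""", "not json").toDF("value")

// FAILFAST is forced internally; rows that fail to parse surface as null
df.select(from_json($"value", schema).alias("parsed")).show()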
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala (new file)

@@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.json

import java.io.InputStream

import com.fasterxml.jackson.core.{JsonFactory, JsonParser}
import org.apache.hadoop.io.Text

import org.apache.spark.unsafe.types.UTF8String

private[sql] object CreateJacksonParser extends Serializable {
def string(jsonFactory: JsonFactory, record: String): JsonParser = {
jsonFactory.createParser(record)
}

def utf8String(jsonFactory: JsonFactory, record: UTF8String): JsonParser = {
val bb = record.getByteBuffer
assert(bb.hasArray)

jsonFactory.createParser(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
}

def text(jsonFactory: JsonFactory, record: Text): JsonParser = {
jsonFactory.createParser(record.getBytes, 0, record.getLength)
}

def inputStream(jsonFactory: JsonFactory, record: InputStream): JsonParser = {
jsonFactory.createParser(record)
}
}
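To illustrate how these factories are consumed, a sketch pairing one with JacksonParser.parse (the three-argument form used in JsonToStruct above; the exact signature is inferred from that call site):

import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.unsafe.types.UTF8String

val schema = new StructType().add("a", IntegerType)
val options = new JSONOptions(Map.empty[String, String], "UTC")
val parser = new JacksonParser(schema, options)

// the factory matches the record type; the `T => UTF8String` function
// supplies the literal recorded for corrupt records
val record = UTF8String.fromString("""{"a": 1}""")
val rows = parser.parse(record, CreateJacksonParser.utf8String, identity[UTF8String])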
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala

@@ -31,11 +31,20 @@ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs
* Most of these map directly to Jackson's internal options, specified in [[JsonParser.Feature]].
*/
private[sql] class JSONOptions(
-   @transient private val parameters: CaseInsensitiveMap[String], defaultTimeZoneId: String)
+   @transient private val parameters: CaseInsensitiveMap[String],
+   defaultTimeZoneId: String,
+   defaultColumnNameOfCorruptRecord: String)
  extends Logging with Serializable {

- def this(parameters: Map[String, String], defaultTimeZoneId: String) =
-   this(CaseInsensitiveMap(parameters), defaultTimeZoneId)
+ def this(
+     parameters: Map[String, String],
+     defaultTimeZoneId: String,
+     defaultColumnNameOfCorruptRecord: String = "") = {
+   this(
+     CaseInsensitiveMap(parameters),
+     defaultTimeZoneId,
+     defaultColumnNameOfCorruptRecord)
+ }

val samplingRatio =
parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
@@ -57,7 +66,8 @@ private[sql] class JSONOptions(
parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false)
val compressionCodec = parameters.get("compression").map(CompressionCodecs.getCodecClassName)
private val parseMode = parameters.getOrElse("mode", "PERMISSIVE")
- val columnNameOfCorruptRecord = parameters.get("columnNameOfCorruptRecord")
+ val columnNameOfCorruptRecord =
+   parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord)

val timeZone: TimeZone = TimeZone.getTimeZone(parameters.getOrElse("timeZone", defaultTimeZoneId))

@@ -69,6 +79,8 @@
FastDateFormat.getInstance(
parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"), timeZone, Locale.US)

+ val wholeFile = parameters.get("wholeFile").map(_.toBoolean).getOrElse(false)

// Parse mode flags
if (!ParseModes.isValidMode(parseMode)) {
logWarning(s"$parseMode is not a valid parse mode. Using ${ParseModes.DEFAULT}.")
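A small sketch of constructing the options object directly, mirroring what the JSON data source now passes in (values illustrative):

import org.apache.spark.sql.catalyst.json.JSONOptions

val opts = new JSONOptions(
  Map("wholeFile" -> "true", "mode" -> "PERMISSIVE"),
  defaultTimeZoneId = "UTC",
  defaultColumnNameOfCorruptRecord = "_corrupt_record")

assert(opts.wholeFile)
assert(opts.columnNameOfCorruptRecord == "_corrupt_record")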