DataFrameReader.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.csv._
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.datasources.jdbc._
import org.apache.spark.sql.execution.datasources.json.JsonInferSchema
import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String

@@ -360,17 +360,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
extraOptions.toMap,
sparkSession.sessionState.conf.sessionLocalTimeZone,
sparkSession.sessionState.conf.columnNameOfCorruptRecord)
val createParser = CreateJacksonParser.string _

val schema = userSpecifiedSchema.getOrElse {
JsonInferSchema.infer(
jsonDataset.rdd,
parsedOptions,
createParser)
TextInputJsonDataSource.inferFromDataset(jsonDataset, parsedOptions)
}

verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)

val createParser = CreateJacksonParser.string _
val parsed = jsonDataset.rdd.mapPartitions { iter =>
val parser = new JacksonParser(schema, parsedOptions)
iter.flatMap(parser.parse(_, createParser, UTF8String.fromString))
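For context, a minimal usage sketch of the Dataset[String] overload whose inference path this hunk rewires (values are illustrative; `spark` is an existing SparkSession):

import spark.implicits._

// An in-memory JSON dataset; its schema is now inferred via
// TextInputJsonDataSource.inferFromDataset rather than JsonInferSchema.infer over an RDD.
val jsonDataset = Seq("""{"a": 1, "b": "x"}""", """{"a": 2, "b": "y"}""").toDS()
val df = spark.read.json(jsonDataset)
df.printSchema()  // a: long, b: string (inferred)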
JsonDataSource.scala
@@ -17,32 +17,30 @@

package org.apache.spark.sql.execution.datasources.json

import scala.reflect.ClassTag

import com.fasterxml.jackson.core.{JsonFactory, JsonParser}
import com.google.common.io.ByteStreams
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.FileStatus
import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.io.Text
import org.apache.hadoop.mapreduce.Job
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, TextInputFormat}
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat

import org.apache.spark.TaskContext
import org.apache.spark.input.{PortableDataStream, StreamInputFormat}
import org.apache.spark.rdd.{BinaryFileRDD, RDD}
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
import org.apache.spark.sql.execution.datasources.{CodecStreams, HadoopFileLinesReader, PartitionedFile}
import org.apache.spark.sql.execution.datasources.{CodecStreams, DataSource, HadoopFileLinesReader, PartitionedFile}
import org.apache.spark.sql.execution.datasources.text.TextFileFormat
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils

/**
* Common functions for parsing JSON files
* @tparam T A datatype containing the unparsed JSON, such as [[Text]] or [[String]]
*/
abstract class JsonDataSource[T] extends Serializable {
abstract class JsonDataSource extends Serializable {
Comment from HyukjinKwon (Member Author): The changes in this file basically resemble CSVDataSource. (Note that this is almost identical if #17256 is merged.)

def isSplitable: Boolean

/**
@@ -53,35 +51,24 @@ abstract class JsonDataSource[T] extends Serializable {
file: PartitionedFile,
parser: JacksonParser): Iterator[InternalRow]

/**
* Create an [[RDD]] that handles the preliminary parsing of [[T]] records
*/
protected def createBaseRdd(
sparkSession: SparkSession,
inputPaths: Seq[FileStatus]): RDD[T]

/**
* A generic wrapper to invoke the correct [[JsonFactory]] method to allocate a [[JsonParser]]
* for an instance of [[T]]
*/
def createParser(jsonFactory: JsonFactory, value: T): JsonParser

final def infer(
final def inferSchema(
sparkSession: SparkSession,
inputPaths: Seq[FileStatus],
parsedOptions: JSONOptions): Option[StructType] = {
if (inputPaths.nonEmpty) {
val jsonSchema = JsonInferSchema.infer(
createBaseRdd(sparkSession, inputPaths),
parsedOptions,
createParser)
val jsonSchema = infer(sparkSession, inputPaths, parsedOptions)
checkConstraints(jsonSchema)
Some(jsonSchema)
} else {
None
}
}

protected def infer(
sparkSession: SparkSession,
inputPaths: Seq[FileStatus],
parsedOptions: JSONOptions): StructType

/** Constraints to be imposed on schema to be stored. */
private def checkConstraints(schema: StructType): Unit = {
if (schema.fieldNames.length != schema.fieldNames.distinct.length) {
@@ -95,53 +82,46 @@ }
}

object JsonDataSource {
def apply(options: JSONOptions): JsonDataSource[_] = {
def apply(options: JSONOptions): JsonDataSource = {
if (options.wholeFile) {
WholeFileJsonDataSource
} else {
TextInputJsonDataSource
}
}

/**
* Create a new [[RDD]] via the supplied callback if there is at least one file to process,
* otherwise an [[org.apache.spark.rdd.EmptyRDD]] will be returned.
*/
def createBaseRdd[T : ClassTag](
sparkSession: SparkSession,
inputPaths: Seq[FileStatus])(
fn: (Configuration, String) => RDD[T]): RDD[T] = {
val paths = inputPaths.map(_.getPath)

if (paths.nonEmpty) {
val job = Job.getInstance(sparkSession.sessionState.newHadoopConf())
FileInputFormat.setInputPaths(job, paths: _*)
fn(job.getConfiguration, paths.mkString(","))
} else {
sparkSession.sparkContext.emptyRDD[T]
}
}
}
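The dispatch above keys on options.wholeFile; a hedged sketch of the public read option that selects WholeFileJsonDataSource (the path is hypothetical):

// With wholeFile=true, each input file is parsed as a single JSON record
// (multi-line JSON), so the input is not splittable.
val multiLineDf = spark.read
  .option("wholeFile", true)
  .json("/path/to/multiline.json")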

object TextInputJsonDataSource extends JsonDataSource[Text] {
object TextInputJsonDataSource extends JsonDataSource {
override val isSplitable: Boolean = {
// splittable if the underlying source is
true
}

override protected def createBaseRdd(
override def infer(
sparkSession: SparkSession,
inputPaths: Seq[FileStatus]): RDD[Text] = {
JsonDataSource.createBaseRdd(sparkSession, inputPaths) {
case (conf, name) =>
sparkSession.sparkContext.newAPIHadoopRDD(
conf,
classOf[TextInputFormat],
classOf[LongWritable],
classOf[Text])
.setName(s"JsonLines: $name")
.values // get the text column
}
inputPaths: Seq[FileStatus],
parsedOptions: JSONOptions): StructType = {
val json: Dataset[String] = createBaseDataset(sparkSession, inputPaths)
inferFromDataset(json, parsedOptions)
}

def inferFromDataset(json: Dataset[String], parsedOptions: JSONOptions): StructType = {
val sampled: Dataset[String] = JsonUtils.sample(json, parsedOptions)
val rdd: RDD[UTF8String] = sampled.queryExecution.toRdd.map(_.getUTF8String(0))
JsonInferSchema.infer(rdd, parsedOptions, CreateJacksonParser.utf8String)
}

private def createBaseDataset(
sparkSession: SparkSession,
inputPaths: Seq[FileStatus]): Dataset[String] = {
val paths = inputPaths.map(_.getPath.toString)
sparkSession.baseRelationToDataFrame(
DataSource.apply(
sparkSession,
paths = paths,
className = classOf[TextFileFormat].getName
).resolveRelation(checkFilesExist = false))
.select("value").as(Encoders.STRING)
}

override def readFile(
@@ -150,41 +130,48 @@ object TextInputJsonDataSource extends JsonDataSource[Text] {
parser: JacksonParser): Iterator[InternalRow] = {
val linesReader = new HadoopFileLinesReader(file, conf)
Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
linesReader.flatMap(parser.parse(_, createParser, textToUTF8String))
linesReader.flatMap(parser.parse(_, CreateJacksonParser.text, textToUTF8String))
}

private def textToUTF8String(value: Text): UTF8String = {
UTF8String.fromBytes(value.getBytes, 0, value.getLength)
}

override def createParser(jsonFactory: JsonFactory, value: Text): JsonParser = {
CreateJacksonParser.text(jsonFactory, value)
}
}

object WholeFileJsonDataSource extends JsonDataSource[PortableDataStream] {
object WholeFileJsonDataSource extends JsonDataSource {
override val isSplitable: Boolean = {
false
}

override protected def createBaseRdd(
override def infer(
sparkSession: SparkSession,
inputPaths: Seq[FileStatus],
parsedOptions: JSONOptions): StructType = {
val json: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths)
val sampled: RDD[PortableDataStream] = JsonUtils.sample(json, parsedOptions)
JsonInferSchema.infer(sampled, parsedOptions, createParser)
}

private def createBaseRdd(
sparkSession: SparkSession,
inputPaths: Seq[FileStatus]): RDD[PortableDataStream] = {
JsonDataSource.createBaseRdd(sparkSession, inputPaths) {
case (conf, name) =>
new BinaryFileRDD(
sparkSession.sparkContext,
classOf[StreamInputFormat],
classOf[String],
classOf[PortableDataStream],
conf,
sparkSession.sparkContext.defaultMinPartitions)
.setName(s"JsonFile: $name")
.values
}
val paths = inputPaths.map(_.getPath)
val job = Job.getInstance(sparkSession.sessionState.newHadoopConf())
val conf = job.getConfiguration
val name = paths.mkString(",")
FileInputFormat.setInputPaths(job, paths: _*)
new BinaryFileRDD(
sparkSession.sparkContext,
classOf[StreamInputFormat],
classOf[String],
classOf[PortableDataStream],
conf,
sparkSession.sparkContext.defaultMinPartitions)
.setName(s"JsonFile: $name")
.values
}

override def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = {
private def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = {
CreateJacksonParser.inputStream(
jsonFactory,
CodecStreams.createInputStreamWithCloseResource(record.getConfiguration, record.getPath()))
JsonFileFormat.scala
@@ -54,7 +54,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
options,
sparkSession.sessionState.conf.sessionLocalTimeZone,
sparkSession.sessionState.conf.columnNameOfCorruptRecord)
JsonDataSource(parsedOptions).infer(
JsonDataSource(parsedOptions).inferSchema(
sparkSession, files, parsedOptions)
}
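For reference, a minimal sketch of a path-based read that reaches the inferSchema entry point above, which now delegates to JsonDataSource(parsedOptions).inferSchema (the path is hypothetical):

// Schema inference for a path-based JSON read goes through JsonFileFormat.inferSchema.
val eventsDf = spark.read.format("json").load("/data/events.json")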

JsonInferSchema.scala
@@ -40,18 +40,11 @@ private[sql] object JsonInferSchema {
json: RDD[T],
configOptions: JSONOptions,
createParser: (JsonFactory, T) => JsonParser): StructType = {
require(configOptions.samplingRatio > 0,
s"samplingRatio (${configOptions.samplingRatio}) should be greater than 0")
val shouldHandleCorruptRecord = configOptions.permissive
val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord
val schemaData = if (configOptions.samplingRatio > 0.99) {
json
} else {
json.sample(withReplacement = false, configOptions.samplingRatio, 1)
}

// perform schema inference on each row and merge afterwards
val rootType = schemaData.mapPartitions { iter =>
val rootType = json.mapPartitions { iter =>
val factory = new JsonFactory()
configOptions.setJacksonOptions(factory)
iter.flatMap { row =>
JsonUtils.scala
@@ -0,0 +1,51 @@
Comment from HyukjinKwon (Member Author), Mar 13, 2017: This might be too much; I am willing to revert it if anyone feels it is a bit odd. I made it to match CSVUtils, which contains variants of the same logical preprocessing for different data types (e.g. Iterator, RDD, Dataset).

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.datasources.json

import org.apache.spark.input.PortableDataStream
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.catalyst.json.JSONOptions

object JsonUtils {
/**
* Sample JSON dataset as configured by `samplingRatio`.
*/
def sample(json: Dataset[String], options: JSONOptions): Dataset[String] = {
require(options.samplingRatio > 0,
s"samplingRatio (${options.samplingRatio}) should be greater than 0")
if (options.samplingRatio > 0.99) {
json
} else {
json.sample(withReplacement = false, options.samplingRatio, 1)
}
}

/**
* Sample JSON RDD as configured by `samplingRatio`.
*/
def sample(json: RDD[PortableDataStream], options: JSONOptions): RDD[PortableDataStream] = {
require(options.samplingRatio > 0,
s"samplingRatio (${options.samplingRatio}) should be greater than 0")
if (options.samplingRatio > 0.99) {
json
} else {
json.sample(withReplacement = false, options.samplingRatio, 1)
}
}
}
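A hedged usage sketch of the option that feeds JsonUtils.sample during schema inference (the path is hypothetical):

// Infer the schema from roughly 10% of the input rows; with samplingRatio > 0.99
// the full dataset is used, as implemented above.
val sampledReadDf = spark.read
  .option("samplingRatio", 0.1)
  .json("/path/to/logs.json")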