5 changes: 2 additions & 3 deletions R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -1370,9 +1370,8 @@ test_that("column functions", {
# passing option
df <- as.DataFrame(list(list("col" = "{\"date\":\"21/10/2014\"}")))
schema2 <- structType(structField("date", "date"))
expect_error(tryCatch(collect(select(df, from_json(df$col, schema2))),
error = function(e) { stop(e) }),
paste0(".*(java.lang.NumberFormatException: For input string:).*"))
s <- collect(select(df, from_json(df$col, schema2)))
expect_equal(s[[1]][[1]], NA)
Member

This also sounds like a behavior change. Could you add another test case here that triggers the exception?

Member

or a bug fix?

Contributor Author

yea it's a minor bug fix, see cloud-fan#4

I'm not sure if it's worth a ticket.

Member

uh, I see.

s <- collect(select(df, from_json(df$col, schema2, dateFormat = "dd/MM/yyyy")))
expect_is(s[[1]][[1]]$date, "Date")
expect_equal(as.character(s[[1]][[1]]$date), "2014-10-21")
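
To make the behavior under discussion concrete, here is a rough Scala sketch mirroring the R test above (not part of the diff; it assumes an active SparkSession named spark): with this change, a date value that cannot be parsed against the schema makes from_json return null under the default PERMISSIVE handling instead of failing with a NumberFormatException, while supplying a matching dateFormat parses it.

// Sketch only; assumes `val spark: SparkSession` is in scope.
import org.apache.spark.sql.functions.from_json
import org.apache.spark.sql.types.{DateType, StructType}
import spark.implicits._

val df = Seq("""{"date":"21/10/2014"}""").toDF("col")
val schema = new StructType().add("date", DateType)

// Before the fix this threw java.lang.NumberFormatException at collect time;
// now the unparseable value simply yields a null result.
df.select(from_json($"col", schema)).collect()

// With a matching dateFormat the value parses into a proper date.
df.select(from_json($"col", schema, Map("dateFormat" -> "dd/MM/yyyy"))).collect()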
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.json._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, ParseModes}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, BadRecordException, GenericArrayData, ParseModes}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
@@ -583,7 +583,7 @@ case class JsonToStructs(
CreateJacksonParser.utf8String,
identity[UTF8String]))
} catch {
case _: SparkSQLJsonProcessingException => null
case _: BadRecordException => null
}
}

@@ -65,7 +65,7 @@ private[sql] class JSONOptions(
val allowBackslashEscapingAnyCharacter =
parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false)
val compressionCodec = parameters.get("compression").map(CompressionCodecs.getCodecClassName)
private val parseMode = parameters.getOrElse("mode", "PERMISSIVE")
val parseMode = parameters.getOrElse("mode", "PERMISSIVE")
Member

How about creating an enum, like what we are doing for SaveMode?

Contributor Author

yea this can be a good follow-up

val columnNameOfCorruptRecord =
parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord)
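
The enum follow-up suggested above could take roughly this shape (purely illustrative, not part of this PR; the names are assumptions, loosely mirroring how SaveMode is modeled):

// Hypothetical follow-up sketch: replace raw mode strings with a sealed ADT.
sealed trait ParseMode
case object PermissiveMode extends ParseMode
case object DropMalformedMode extends ParseMode
case object FailFastMode extends ParseMode

object ParseMode {
  // Whether unknown strings fall back to PermissiveMode (today's behavior)
  // or fail loudly would be part of the follow-up discussion.
  def fromString(mode: String): ParseMode = mode.toUpperCase match {
    case "PERMISSIVE"    => PermissiveMode
    case "DROPMALFORMED" => DropMalformedMode
    case "FAILFAST"      => FailFastMode
    case _               => PermissiveMode
  }
}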

@@ -32,17 +32,14 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils

private[sql] class SparkSQLJsonProcessingException(msg: String) extends RuntimeException(msg)

/**
* Constructs a parser for a given schema that translates a json string to an [[InternalRow]].
*/
class JacksonParser(
schema: StructType,
options: JSONOptions) extends Logging {
val options: JSONOptions) extends Logging {

import JacksonUtils._
import ParseModes._
import com.fasterxml.jackson.core.JsonToken._

// A `ValueConverter` is responsible for converting a value from `JsonParser`
@@ -55,108 +52,6 @@ class JacksonParser(
private val factory = new JsonFactory()
options.setJacksonOptions(factory)

private val emptyRow: Seq[InternalRow] = Seq(new GenericInternalRow(schema.length))

private val corruptFieldIndex = schema.getFieldIndex(options.columnNameOfCorruptRecord)
corruptFieldIndex.foreach { corrFieldIndex =>
require(schema(corrFieldIndex).dataType == StringType)
require(schema(corrFieldIndex).nullable)
}
Member

The above check seems to be missing in the new code.

Contributor Author

This is just a sanity check; it's already done in DataFrameReader.csv/json and in JsonFileFormat/CSVFileFormat.
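
For reference, the check in question is the one DataFrameReader.verifyColumnNameOfCorruptRecord performs (it is called in the DataFrameReader diff below); its body amounts to something like the following sketch, though the exact implementation is an assumption:

// Sketch of the sanity check: if the user-supplied schema contains the
// corrupt-record column, it must be declared as a nullable string field.
import org.apache.spark.sql.types.{StringType, StructType}

def verifyColumnNameOfCorruptRecord(
    schema: StructType,
    columnNameOfCorruptRecord: String): Unit = {
  schema.getFieldIndex(columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
    val f = schema(corruptFieldIndex)
    require(f.dataType == StringType && f.nullable,
      s"The field for corrupt records must be a nullable string field, but got: $f")
  }
}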


@transient
private[this] var isWarningPrinted: Boolean = false

@transient
private def printWarningForMalformedRecord(record: () => UTF8String): Unit = {
def sampleRecord: String = {
if (options.wholeFile) {
""
} else {
s"Sample record: ${record()}\n"
}
}

def footer: String = {
s"""Code example to print all malformed records (scala):
|===================================================
|// The corrupted record exists in column ${options.columnNameOfCorruptRecord}.
|val parsedJson = spark.read.json("/path/to/json/file/test.json")
|
""".stripMargin
}

if (options.permissive) {
logWarning(
s"""Found at least one malformed record. The JSON reader will replace
|all malformed records with placeholder null in current $PERMISSIVE_MODE parser mode.
|To find out which corrupted records have been replaced with null, please use the
|default inferred schema instead of providing a custom schema.
|
|${sampleRecord ++ footer}
|
""".stripMargin)
} else if (options.dropMalformed) {
logWarning(
s"""Found at least one malformed record. The JSON reader will drop
|all malformed records in current $DROP_MALFORMED_MODE parser mode. To find out which
|corrupted records have been dropped, please switch the parser mode to $PERMISSIVE_MODE
|mode and use the default inferred schema.
|
|${sampleRecord ++ footer}
|
""".stripMargin)
}
}

@transient
private def printWarningIfWholeFile(): Unit = {
if (options.wholeFile && corruptFieldIndex.isDefined) {
logWarning(
s"""Enabling wholeFile mode and defining columnNameOfCorruptRecord may result
|in very large allocations or OutOfMemoryExceptions being raised.
|
""".stripMargin)
}
}

/**
* This function deals with the cases it fails to parse. This function will be called
* when exceptions are caught during converting. This functions also deals with `mode` option.
*/
private def failedRecord(record: () => UTF8String): Seq[InternalRow] = {
corruptFieldIndex match {
case _ if options.failFast =>
if (options.wholeFile) {
throw new SparkSQLJsonProcessingException("Malformed line in FAILFAST mode")
} else {
throw new SparkSQLJsonProcessingException(s"Malformed line in FAILFAST mode: ${record()}")
}

case _ if options.dropMalformed =>
if (!isWarningPrinted) {
printWarningForMalformedRecord(record)
isWarningPrinted = true
}
Nil

case None =>
if (!isWarningPrinted) {
printWarningForMalformedRecord(record)
isWarningPrinted = true
}
emptyRow

case Some(corruptIndex) =>
if (!isWarningPrinted) {
printWarningIfWholeFile()
isWarningPrinted = true
}
val row = new GenericInternalRow(schema.length)
row.update(corruptIndex, record())
Seq(row)
}
}

/**
* Create a converter which converts the JSON documents held by the `JsonParser`
* to a value according to a desired schema. This is a wrapper for the method
@@ -239,7 +134,7 @@
lowerCaseValue.equals("-inf")) {
value.toFloat
} else {
throw new SparkSQLJsonProcessingException(s"Cannot parse $value as FloatType.")
throw new RuntimeException(s"Cannot parse $value as FloatType.")
}
}

@@ -259,7 +154,7 @@
lowerCaseValue.equals("-inf")) {
value.toDouble
} else {
throw new SparkSQLJsonProcessingException(s"Cannot parse $value as DoubleType.")
throw new RuntimeException(s"Cannot parse $value as DoubleType.")
}
}

@@ -391,9 +286,8 @@

case token =>
// We cannot parse this token based on the given data type. So, we throw a
// SparkSQLJsonProcessingException and this exception will be caught by
// `parse` method.
throw new SparkSQLJsonProcessingException(
// RuntimeException and this exception will be caught by `parse` method.
throw new RuntimeException(
s"Failed to parse a value for data type $dataType (current token: $token).")
}

@@ -466,14 +360,14 @@
parser.nextToken() match {
case null => Nil
case _ => rootConverter.apply(parser) match {
case null => throw new SparkSQLJsonProcessingException("Root converter returned null")
case null => throw new RuntimeException("Root converter returned null")
case rows => rows
}
}
}
} catch {
case _: JsonProcessingException | _: SparkSQLJsonProcessingException =>
failedRecord(() => recordLiteral(record))
case e @ (_: RuntimeException | _: JsonProcessingException) =>
throw BadRecordException(() => recordLiteral(record), () => None, e)
}
}
}
@@ -0,0 +1,80 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.util

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.types.UTF8String

class FailureSafeParser[IN](
rawParser: IN => Seq[InternalRow],
mode: String,
schema: StructType,
columnNameOfCorruptRecord: String) {

private val corruptFieldIndex = schema.getFieldIndex(columnNameOfCorruptRecord)
private val actualSchema = StructType(schema.filterNot(_.name == columnNameOfCorruptRecord))
private val resultRow = new GenericInternalRow(schema.length)
private val nullResult = new GenericInternalRow(schema.length)

// This function takes 2 parameters: an optional partial result, and the bad record. If the given
// schema doesn't contain a field for corrupted record, we just return the partial result or a
// row with all fields null. If the given schema contains a field for corrupted record, we will
// set the bad record to this field, and set other fields according to the partial result or null.
private val toResultRow: (Option[InternalRow], () => UTF8String) => InternalRow = {
if (corruptFieldIndex.isDefined) {
(row, badRecord) => {
var i = 0
while (i < actualSchema.length) {
val from = actualSchema(i)
resultRow(schema.fieldIndex(from.name)) = row.map(_.get(i, from.dataType)).orNull
i += 1
}
resultRow(corruptFieldIndex.get) = badRecord()
resultRow
}
} else {
(row, _) => row.getOrElse(nullResult)
}
}

def parse(input: IN): Iterator[InternalRow] = {
try {
rawParser.apply(input).toIterator.map(row => toResultRow(Some(row), () => null))
} catch {
case e: BadRecordException if ParseModes.isPermissiveMode(mode) =>
Iterator(toResultRow(e.partialResult(), e.record))
case _: BadRecordException if ParseModes.isDropMalformedMode(mode) =>
Iterator.empty
case e: BadRecordException => throw e.cause
Member

This is FAIL_FAST_MODE, if my understanding is correct. Should we issue an error message that mentions FAILFAST, like we did before?

Is this also a behavior change? If users did not spell the mode string correctly, we treated it as PERMISSIVE mode. Now, we changed it to FAILFAST mode.

Member

We need a test case to cover it, right?

Member

I just did a re-check. This part in CSV is kind of messy; the code doesn't follow any consistent rule. We should have had test cases from the very beginning.

Contributor Author

ParseModes.isPermissiveMode returns true if the mode string is invalid

Member

uh... this is kind of tricky.
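
To make the fallback concrete, here is an approximate sketch of the existing ParseModes helper (reconstructed from memory, not the verbatim Spark source): any unrecognized mode string is treated as permissive, so a typo can never reach the FAILFAST branch.

// Approximate sketch of ParseModes; the real implementation may differ.
object ParseModes {
  val PERMISSIVE_MODE = "PERMISSIVE"
  val DROP_MALFORMED_MODE = "DROPMALFORMED"
  val FAIL_FAST_MODE = "FAILFAST"

  def isValidMode(mode: String): Boolean =
    Set(PERMISSIVE_MODE, DROP_MALFORMED_MODE, FAIL_FAST_MODE).contains(mode.toUpperCase)

  def isDropMalformedMode(mode: String): Boolean = mode.toUpperCase == DROP_MALFORMED_MODE
  def isFailFastMode(mode: String): Boolean = mode.toUpperCase == FAIL_FAST_MODE

  // Invalid mode strings fall back to permissive handling.
  def isPermissiveMode(mode: String): Boolean =
    if (isValidMode(mode)) mode.toUpperCase == PERMISSIVE_MODE else true
}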

Contributor Author

we should add tests in a follow-up
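
A rough idea of what such a follow-up test could look like (the test name, input data, and the exact exception type surfaced at collect time are assumptions):

// Hypothetical follow-up test sketch for FAILFAST mode, assuming a suite with
// a SparkSession `spark` and ScalaTest's intercept available.
import org.apache.spark.SparkException
import org.apache.spark.sql.types.{IntegerType, StructType}

test("FAILFAST mode fails on malformed JSON records") {
  import spark.implicits._
  val input = Seq("""{"a": 1}""", """{"a":""").toDS()   // the second record is malformed
  val schema = new StructType().add("a", IntegerType)

  intercept[SparkException] {
    spark.read.option("mode", "FAILFAST").schema(schema).json(input).collect()
  }
}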

}
}
}

/**
* Exception thrown when the underlying parser meet a bad record and can't parse it.
* @param record a function to return the record that cause the parser to fail
* @param partialResult a function that returns an optional row, which is the partial result of
* parsing this bad record.
* @param cause the actual exception about why the record is bad and can't be parsed.
*/
case class BadRecordException(
record: () => UTF8String,
partialResult: () => Option[InternalRow],
cause: Throwable) extends Exception(cause)
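
To tie the pieces together, a raw parser opts into this machinery by throwing BadRecordException when it cannot parse a record; below is a sketch of that contract (the surrounding parser routine is hypothetical, while the JacksonParser diff above shows the real usage with an empty partial result).

// Hypothetical raw-parser snippet showing the BadRecordException contract.
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.BadRecordException
import org.apache.spark.unsafe.types.UTF8String

def parseLine(line: String): Seq[InternalRow] = {
  try {
    doActualParsing(line)                           // assumed parsing routine
  } catch {
    case e: RuntimeException =>
      throw BadRecordException(
        record = () => UTF8String.fromString(line),  // the offending input, built lazily
        partialResult = () => None,                  // or Some(row) if some fields were parsed
        cause = e)
  }
}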
23 changes: 19 additions & 4 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -27,6 +27,7 @@ import org.apache.spark.Partition
import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
import org.apache.spark.sql.catalyst.util.FailureSafeParser
import org.apache.spark.sql.execution.LogicalRDD
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.csv._
@@ -382,11 +383,18 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
}

verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)
val actualSchema =
StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))

val createParser = CreateJacksonParser.string _
val parsed = jsonDataset.rdd.mapPartitions { iter =>
val parser = new JacksonParser(schema, parsedOptions)
iter.flatMap(parser.parse(_, createParser, UTF8String.fromString))
val rawParser = new JacksonParser(actualSchema, parsedOptions)
val parser = new FailureSafeParser[String](
input => rawParser.parse(input, createParser, UTF8String.fromString),
parsedOptions.parseMode,
schema,
parsedOptions.columnNameOfCorruptRecord)
iter.flatMap(parser.parse)
}

Dataset.ofRows(
@@ -435,14 +443,21 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
}

verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)
val actualSchema =
StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))

val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine =>
filteredLines.rdd.mapPartitions(CSVUtils.filterHeaderLine(_, firstLine, parsedOptions))
}.getOrElse(filteredLines.rdd)

val parsed = linesWithoutHeader.mapPartitions { iter =>
val parser = new UnivocityParser(schema, parsedOptions)
iter.flatMap(line => parser.parse(line))
val rawParser = new UnivocityParser(actualSchema, parsedOptions)
val parser = new FailureSafeParser[String](
input => Seq(rawParser.parse(input)),
parsedOptions.parseMode,
schema,
parsedOptions.columnNameOfCorruptRecord)
iter.flatMap(parser.parse)
}

Dataset.ofRows(
@@ -49,7 +49,7 @@ abstract class CSVDataSource extends Serializable {
conf: Configuration,
file: PartitionedFile,
parser: UnivocityParser,
parsedOptions: CSVOptions): Iterator[InternalRow]
schema: StructType): Iterator[InternalRow]

/**
* Infers the schema from `inputPaths` files.
@@ -115,17 +115,17 @@ object TextInputCSVDataSource extends CSVDataSource {
conf: Configuration,
file: PartitionedFile,
parser: UnivocityParser,
parsedOptions: CSVOptions): Iterator[InternalRow] = {
schema: StructType): Iterator[InternalRow] = {
val lines = {
val linesReader = new HadoopFileLinesReader(file, conf)
Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
linesReader.map { line =>
new String(line.getBytes, 0, line.getLength, parsedOptions.charset)
new String(line.getBytes, 0, line.getLength, parser.options.charset)
}
}

val shouldDropHeader = parsedOptions.headerFlag && file.start == 0
UnivocityParser.parseIterator(lines, shouldDropHeader, parser)
val shouldDropHeader = parser.options.headerFlag && file.start == 0
UnivocityParser.parseIterator(lines, shouldDropHeader, parser, schema)
}

override def infer(
@@ -192,11 +192,12 @@ object WholeFileCSVDataSource extends CSVDataSource {
conf: Configuration,
file: PartitionedFile,
parser: UnivocityParser,
parsedOptions: CSVOptions): Iterator[InternalRow] = {
schema: StructType): Iterator[InternalRow] = {
UnivocityParser.parseStream(
CodecStreams.createInputStreamWithCloseResource(conf, file.filePath),
parsedOptions.headerFlag,
parser)
parser.options.headerFlag,
parser,
schema)
}

override def infer(