diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
index 6feea500b2aa0..984979ac5e9b4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.json
 
 import java.io.{ByteArrayOutputStream, CharConversionException}
+import java.nio.charset.MalformedInputException
 
 import scala.collection.mutable.ArrayBuffer
 import scala.util.Try
@@ -402,7 +403,7 @@ class JacksonParser(
         }
       }
     } catch {
-      case e @ (_: RuntimeException | _: JsonProcessingException) =>
+      case e @ (_: RuntimeException | _: JsonProcessingException | _: MalformedInputException) =>
         // JSON parser currently doesn't support partial results for corrupted records.
         // For such records, all fields other than the field configured by
         // `columnNameOfCorruptRecord` are set to `null`.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 9bd113419ae4c..1b3a9fc91d198 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -450,7 +450,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
         input => rawParser.parse(input, createParser, UTF8String.fromString),
         parsedOptions.parseMode,
         schema,
-        parsedOptions.columnNameOfCorruptRecord)
+        parsedOptions.columnNameOfCorruptRecord,
+        parsedOptions.multiLine)
       iter.flatMap(parser.parse)
     }
     sparkSession.internalCreateDataFrame(parsed, schema, isStreaming = jsonDataset.isStreaming)
@@ -521,7 +522,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
         input => Seq(rawParser.parse(input)),
         parsedOptions.parseMode,
         schema,
-        parsedOptions.columnNameOfCorruptRecord)
+        parsedOptions.columnNameOfCorruptRecord,
+        parsedOptions.multiLine)
       iter.flatMap(parser.parse)
     }
     sparkSession.internalCreateDataFrame(parsed, schema, isStreaming = csvDataset.isStreaming)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala
index 43591a9ff524a..90e81661bae7a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala
@@ -21,6 +21,7 @@ import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
 import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -28,7 +29,8 @@ class FailureSafeParser[IN](
     rawParser: IN => Seq[InternalRow],
     mode: ParseMode,
     schema: StructType,
-    columnNameOfCorruptRecord: String) {
+    columnNameOfCorruptRecord: String,
+    isMultiLine: Boolean) {
 
   private val corruptFieldIndex = schema.getFieldIndex(columnNameOfCorruptRecord)
   private val actualSchema = StructType(schema.filterNot(_.name == columnNameOfCorruptRecord))
@@ -56,9 +58,15 @@ class FailureSafeParser[IN](
     }
   }
 
+  private val skipParsing = !isMultiLine && mode == PermissiveMode && schema.isEmpty
+
   def parse(input: IN): Iterator[InternalRow] = {
     try {
-      rawParser.apply(input).toIterator.map(row => toResultRow(Some(row), () => null))
+      if (skipParsing) {
+        Iterator.single(InternalRow.empty)
+      } else {
+        rawParser.apply(input).toIterator.map(row => toResultRow(Some(row), () => null))
+      }
     } catch {
       case e: BadRecordException => mode match {
         case PermissiveMode =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
index 79143cce4a380..e15af425b2649 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
@@ -203,19 +203,11 @@ class UnivocityParser(
     }
   }
 
-  private val doParse = if (requiredSchema.nonEmpty) {
-    (input: String) => convert(tokenizer.parseLine(input))
-  } else {
-    // If `columnPruning` enabled and partition attributes scanned only,
-    // `schema` gets empty.
-    (_: String) => InternalRow.empty
-  }
-
   /**
    * Parses a single CSV string and turns it into either one resulting row or no row (if the
    * the record is malformed).
    */
-  def parse(input: String): InternalRow = doParse(input)
+  def parse(input: String): InternalRow = convert(tokenizer.parseLine(input))
 
   private val getToken = if (options.columnPruning) {
     (tokens: Array[String], index: Int) => tokens(index)
@@ -293,7 +285,8 @@ private[csv] object UnivocityParser {
       input => Seq(parser.convert(input)),
       parser.options.parseMode,
       schema,
-      parser.options.columnNameOfCorruptRecord)
+      parser.options.columnNameOfCorruptRecord,
+      parser.options.multiLine)
     convertStream(inputStream, shouldDropHeader, tokenizer, checkHeader) { tokens =>
       safeParser.parse(tokens)
     }.flatten
@@ -341,7 +334,8 @@ private[csv] object UnivocityParser {
       input => Seq(parser.parse(input)),
       parser.options.parseMode,
       schema,
-      parser.options.columnNameOfCorruptRecord)
+      parser.options.columnNameOfCorruptRecord,
+      parser.options.multiLine)
     filteredLines.flatMap(safeParser.parse)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
index d6c588894d7f8..76f58371ae264 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
@@ -139,7 +139,8 @@ object TextInputJsonDataSource extends JsonDataSource {
       input => parser.parse(input, textParser, textToUTF8String),
       parser.options.parseMode,
       schema,
-      parser.options.columnNameOfCorruptRecord)
+      parser.options.columnNameOfCorruptRecord,
+      parser.options.multiLine)
     linesReader.flatMap(safeParser.parse)
   }
 
@@ -223,7 +224,8 @@ object MultiLineJsonDataSource extends JsonDataSource {
       input => parser.parse[InputStream](input, streamParser, partitionedFileString),
       parser.options.parseMode,
       schema,
-      parser.options.columnNameOfCorruptRecord)
+      parser.options.columnNameOfCorruptRecord,
+      parser.options.multiLine)
 
     safeParser.parse(
       CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath))))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala
index 1a3dacb8398e6..24f5f55d55485 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala
@@ -119,8 +119,47 @@ object CSVBenchmarks {
     }
   }
 
+  def countBenchmark(rowsNum: Int): Unit = {
+    val colsNum = 10
+    val benchmark = new Benchmark(s"Count a dataset with $colsNum columns", rowsNum)
+
+    withTempPath { path =>
+      val fields = Seq.tabulate(colsNum)(i => StructField(s"col$i", IntegerType))
+      val schema = StructType(fields)
+
+      spark.range(rowsNum)
+        .select(Seq.tabulate(colsNum)(i => lit(i).as(s"col$i")): _*)
+        .write
+        .csv(path.getAbsolutePath)
+
+      val ds = spark.read.schema(schema).csv(path.getAbsolutePath)
+
+      benchmark.addCase(s"Select $colsNum columns + count()", 3) { _ =>
+        ds.select("*").filter((_: Row) => true).count()
+      }
+      benchmark.addCase(s"Select 1 column + count()", 3) { _ =>
+        ds.select($"col1").filter((_: Row) => true).count()
+      }
+      benchmark.addCase(s"count()", 3) { _ =>
+        ds.count()
+      }
+
+      /*
+      Intel(R) Core(TM) i7-7700HQ CPU @ 2.80GHz
+
+      Count a dataset with 10 columns:     Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+      ---------------------------------------------------------------------------------------------
+      Select 10 columns + count()              12598 / 12740          0.8        1259.8       1.0X
+      Select 1 column + count()                  7960 / 8175          1.3         796.0       1.6X
+      count()                                    2332 / 2386          4.3         233.2       5.4X
+      */
+      benchmark.run()
+    }
+  }
+
   def main(args: Array[String]): Unit = {
     quotedValuesBenchmark(rowsNum = 50 * 1000, numIters = 3)
     multiColumnsBenchmark(rowsNum = 1000 * 1000)
+    countBenchmark(10 * 1000 * 1000)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 456b4535a0dcc..14840e59a1052 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -1641,4 +1641,30 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te
       }
     }
   }
+
+  test("count() for malformed input") {
+    def countForMalformedCSV(expected: Long, input: Seq[String]): Unit = {
+      val schema = new StructType().add("a", IntegerType)
+      val strings = spark.createDataset(input)
+      val df = spark.read.schema(schema).option("header", false).csv(strings)
+
+      assert(df.count() == expected)
+    }
+    def checkCount(expected: Long): Unit = {
+      val validRec = "1"
+      val inputs = Seq(
+        Seq("{-}", validRec),
+        Seq(validRec, "?"),
+        Seq("0xAC", validRec),
+        Seq(validRec, "0.314"),
+        Seq("\\\\\\", validRec)
+      )
+      inputs.foreach { input =>
+        countForMalformedCSV(expected, input)
+      }
+    }
+
+    checkCount(2)
+    countForMalformedCSV(0, Seq(""))
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala
index 85cf054e51f6b..a2b747eaab411 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala
@@ -19,8 +19,9 @@ package org.apache.spark.sql.execution.datasources.json
 import java.io.File
 
 import org.apache.spark.SparkConf
-import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.types.{LongType, StringType, StructType}
+import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.types._
 import org.apache.spark.util.{Benchmark, Utils}
 
 /**
@@ -171,9 +172,49 @@ object JSONBenchmarks {
     }
   }
 
+  def countBenchmark(rowsNum: Int): Unit = {
+    val colsNum = 10
+    val benchmark = new Benchmark(s"Count a dataset with $colsNum columns", rowsNum)
+
+    withTempPath { path =>
+      val fields = Seq.tabulate(colsNum)(i => StructField(s"col$i", IntegerType))
+      val schema = StructType(fields)
+      val columnNames = schema.fieldNames
+
+      spark.range(rowsNum)
+        .select(Seq.tabulate(colsNum)(i => lit(i).as(s"col$i")): _*)
+        .write
+        .json(path.getAbsolutePath)
+
+      val ds = spark.read.schema(schema).json(path.getAbsolutePath)
+
+      benchmark.addCase(s"Select $colsNum columns + count()", 3) { _ =>
+        ds.select("*").filter((_: Row) => true).count()
+      }
+      benchmark.addCase(s"Select 1 column + count()", 3) { _ =>
+        ds.select($"col1").filter((_: Row) => true).count()
+      }
+      benchmark.addCase(s"count()", 3) { _ =>
+        ds.count()
+      }
+
+      /*
+      Intel(R) Core(TM) i7-7700HQ CPU @ 2.80GHz
+
+      Count a dataset with 10 columns:     Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+      ---------------------------------------------------------------------------------------------
+      Select 10 columns + count()                9961 / 10006          1.0         996.1       1.0X
+      Select 1 column + count()                   8355 / 8470          1.2         835.5       1.2X
+      count()                                     2104 / 2156          4.8         210.4       4.7X
+      */
+      benchmark.run()
+    }
+  }
+
   def main(args: Array[String]): Unit = {
     schemaInferring(100 * 1000 * 1000)
     perlineParsing(100 * 1000 * 1000)
     perlineParsingOfWideColumn(10 * 1000 * 1000)
+    countBenchmark(10 * 1000 * 1000)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 655f40ad549e6..3e4cc8f166279 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -2223,7 +2223,6 @@
     checkAnswer(jsonDF, Seq(Row("Chris", "Baird")))
   }
 
-
   test("SPARK-23723: specified encoding is not matched to actual encoding") {
     val fileName = "test-data/utf16LE.json"
     val schema = new StructType().add("firstName", StringType).add("lastName", StringType)
@@ -2490,4 +2489,30 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
       assert(exception.getMessage.contains("encoding must not be included in the blacklist"))
     }
   }
+
+  test("count() for malformed input") {
+    def countForMalformedJSON(expected: Long, input: Seq[String]): Unit = {
+      val schema = new StructType().add("a", StringType)
+      val strings = spark.createDataset(input)
+      val df = spark.read.schema(schema).json(strings)
+
+      assert(df.count() == expected)
+    }
+    def checkCount(expected: Long): Unit = {
+      val validRec = """{"a":"b"}"""
+      val inputs = Seq(
+        Seq("{-}", validRec),
+        Seq(validRec, "?"),
+        Seq("}", validRec),
+        Seq(validRec, """{"a": [1, 2, 3]}"""),
+        Seq("""{"a": {"a": "b"}}""", validRec)
+      )
+      inputs.foreach { input =>
+        countForMalformedJSON(expected, input)
+      }
+    }
+
+    checkCount(2)
+    countForMalformedJSON(0, Seq(""))
+  }
 }
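
For illustration only (not part of the patch): a minimal, self-contained sketch of the behaviour the new FailureSafeParser fast path preserves, based on the "count() for malformed input" tests above. The object name, local SparkSession, and sample data are assumptions introduced here for the example.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{IntegerType, StructType}

object CountFastPathSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical local session just for this sketch.
    val spark = SparkSession.builder().master("local[*]").appName("count-fast-path").getOrCreate()
    import spark.implicits._

    val schema = new StructType().add("a", IntegerType)
    // One valid and one malformed record; in the default PERMISSIVE mode both rows
    // still contribute to count(), and count() does not need any column values.
    val input = Seq("1", "0xAC").toDS()
    val df = spark.read.schema(schema).csv(input)

    assert(df.count() == 2)
    spark.stop()
  }
}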