
Commit a8a1ac0

MaxGekk authored and gatorsmile committed
[SPARK-24959][SQL] Speed up count() for JSON and CSV
## What changes were proposed in this pull request?

In this PR, I propose to skip invoking the CSV/JSON parser for each line when the required schema is empty. The benchmarks added for `count()` show a performance improvement of up to **3.5 times**.

Before:

```
Count a dataset with 10 columns:     Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)
--------------------------------------------------------------------------------------
JSON count()                             7676 / 7715          1.3         767.6
CSV count()                              3309 / 3363          3.0         330.9
```

After:

```
Count a dataset with 10 columns:     Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)
--------------------------------------------------------------------------------------
JSON count()                             2104 / 2156          4.8         210.4
CSV count()                              2332 / 2386          4.3         233.2
```

## How was this patch tested?

It was tested by `CSVSuite` and `JsonSuite`, as well as by the added benchmarks.

Author: Maxim Gekk <[email protected]>
Author: Maxim Gekk <[email protected]>

Closes #21909 from MaxGekk/empty-schema-optimization.
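The heart of the change is one observation: `count()` needs only the number of records, so when the required schema is empty, every well-delimited line can become an empty row without ever touching the parser. A minimal, standalone sketch of that idea (hypothetical names, not Spark's actual internals):

```scala
// A sketch of the empty-schema shortcut, assuming nothing from Spark:
// when no columns are required, each record maps to an empty row and
// the (expensive) per-line parser is never invoked.
object EmptySchemaShortcut {
  final case class Row(values: Seq[Any])
  private val EmptyRow = Row(Nil)

  def parseAll(
      lines: Iterator[String],
      parseLine: String => Row,
      requiredColumns: Int): Iterator[Row] =
    if (requiredColumns == 0) lines.map(_ => EmptyRow) // skip parsing entirely
    else lines.map(parseLine)

  def main(args: Array[String]): Unit = {
    val lines = Iterator("1,2,3", "not,a,number", "4,5,6")
    // Counting rows never tokenizes a single line:
    println(parseAll(lines, s => Row(s.split(",").toSeq), requiredColumns = 0).size) // 3
  }
}
```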
1 parent 14d7c1c · commit a8a1ac0

9 files changed: +159 −21 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala

Lines changed: 2 additions & 1 deletion
```diff
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.json

 import java.io.{ByteArrayOutputStream, CharConversionException}
+import java.nio.charset.MalformedInputException

 import scala.collection.mutable.ArrayBuffer
 import scala.util.Try
@@ -402,7 +403,7 @@ class JacksonParser(
         }
       }
     } catch {
-      case e @ (_: RuntimeException | _: JsonProcessingException) =>
+      case e @ (_: RuntimeException | _: JsonProcessingException | _: MalformedInputException) =>
         // JSON parser currently doesn't support partial results for corrupted records.
         // For such records, all fields other than the field configured by
         // `columnNameOfCorruptRecord` are set to `null`.
```
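For context on the new catch case: `MalformedInputException` is thrown by `java.nio` charset decoders when they hit an invalid byte sequence, so this hunk lets a badly encoded record be treated like any other corrupt record instead of failing the task. A self-contained sketch of how the exception arises (plain JDK code, not Spark's):

```scala
import java.nio.ByteBuffer
import java.nio.charset.{CodingErrorAction, MalformedInputException, StandardCharsets}

object MalformedInputDemo {
  def main(args: Array[String]): Unit = {
    // A lone UTF-8 lead byte is an invalid sequence; with REPORT (the
    // decoder default) it raises MalformedInputException rather than
    // substituting a replacement character.
    val decoder = StandardCharsets.UTF_8.newDecoder()
      .onMalformedInput(CodingErrorAction.REPORT)
    try decoder.decode(ByteBuffer.wrap(Array[Byte](0xC3.toByte)))
    catch { case e: MalformedInputException => println(s"caught: $e") }
  }
}
```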

sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala

Lines changed: 4 additions & 2 deletions
```diff
@@ -450,7 +450,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
         input => rawParser.parse(input, createParser, UTF8String.fromString),
         parsedOptions.parseMode,
         schema,
-        parsedOptions.columnNameOfCorruptRecord)
+        parsedOptions.columnNameOfCorruptRecord,
+        parsedOptions.multiLine)
       iter.flatMap(parser.parse)
     }
     sparkSession.internalCreateDataFrame(parsed, schema, isStreaming = jsonDataset.isStreaming)
@@ -521,7 +522,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
         input => Seq(rawParser.parse(input)),
         parsedOptions.parseMode,
         schema,
-        parsedOptions.columnNameOfCorruptRecord)
+        parsedOptions.columnNameOfCorruptRecord,
+        parsedOptions.multiLine)
       iter.flatMap(parser.parse)
     }
     sparkSession.internalCreateDataFrame(parsed, schema, isStreaming = csvDataset.isStreaming)
```

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala

Lines changed: 10 additions & 2 deletions
```diff
@@ -21,14 +21,16 @@ import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
 import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.unsafe.types.UTF8String

 class FailureSafeParser[IN](
     rawParser: IN => Seq[InternalRow],
     mode: ParseMode,
     schema: StructType,
-    columnNameOfCorruptRecord: String) {
+    columnNameOfCorruptRecord: String,
+    isMultiLine: Boolean) {

   private val corruptFieldIndex = schema.getFieldIndex(columnNameOfCorruptRecord)
   private val actualSchema = StructType(schema.filterNot(_.name == columnNameOfCorruptRecord))
@@ -56,9 +58,15 @@ class FailureSafeParser[IN](
     }
   }

+  private val skipParsing = !isMultiLine && mode == PermissiveMode && schema.isEmpty
+
   def parse(input: IN): Iterator[InternalRow] = {
     try {
-      rawParser.apply(input).toIterator.map(row => toResultRow(Some(row), () => null))
+      if (skipParsing) {
+        Iterator.single(InternalRow.empty)
+      } else {
+        rawParser.apply(input).toIterator.map(row => toResultRow(Some(row), () => null))
+      }
     } catch {
       case e: BadRecordException => mode match {
         case PermissiveMode =>
```
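The three conditions in `skipParsing` are worth spelling out: the shortcut needs single-line input (record boundaries are known without parsing), PERMISSIVE mode (a malformed record still yields exactly one row), and an empty required schema (no field values are needed). A hedged usage sketch of when it does and does not fire (assumes a running `SparkSession` named `spark`; the path is illustrative):

```scala
import org.apache.spark.sql.types._

val schema = new StructType().add("a", IntegerType)

// Fires: single-line CSV, default PERMISSIVE mode, count() needs no columns.
spark.read.schema(schema).csv("/tmp/people.csv").count()

// Does not fire: multiLine input must be parsed just to delimit records.
spark.read.schema(schema).option("multiLine", true).csv("/tmp/people.csv").count()

// Does not fire: FAILFAST must still parse every record to detect bad ones.
spark.read.schema(schema).option("mode", "FAILFAST").csv("/tmp/people.csv").count()
```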

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala

Lines changed: 5 additions & 11 deletions
```diff
@@ -203,19 +203,11 @@ class UnivocityParser(
     }
   }

-  private val doParse = if (requiredSchema.nonEmpty) {
-    (input: String) => convert(tokenizer.parseLine(input))
-  } else {
-    // If `columnPruning` enabled and partition attributes scanned only,
-    // `schema` gets empty.
-    (_: String) => InternalRow.empty
-  }
-
   /**
    * Parses a single CSV string and turns it into either one resulting row or no row (if the
    * the record is malformed).
    */
-  def parse(input: String): InternalRow = doParse(input)
+  def parse(input: String): InternalRow = convert(tokenizer.parseLine(input))

   private val getToken = if (options.columnPruning) {
     (tokens: Array[String], index: Int) => tokens(index)
@@ -293,7 +285,8 @@ private[csv] object UnivocityParser {
       input => Seq(parser.convert(input)),
       parser.options.parseMode,
       schema,
-      parser.options.columnNameOfCorruptRecord)
+      parser.options.columnNameOfCorruptRecord,
+      parser.options.multiLine)
     convertStream(inputStream, shouldDropHeader, tokenizer, checkHeader) { tokens =>
       safeParser.parse(tokens)
     }.flatten
@@ -341,7 +334,8 @@ private[csv] object UnivocityParser {
       input => Seq(parser.parse(input)),
       parser.options.parseMode,
       schema,
-      parser.options.columnNameOfCorruptRecord)
+      parser.options.columnNameOfCorruptRecord,
+      parser.options.multiLine)
     filteredLines.flatMap(safeParser.parse)
   }
 }
```
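A note on direction: this hunk deletes the CSV-only empty-schema branch from `UnivocityParser`; the decision now lives once in `FailureSafeParser.skipParsing` (previous file), where JSON benefits as well and where parse mode and `multiLine` can be taken into account. A toy sketch of the resulting layering (hypothetical names, not the real signatures):

```scala
// Format parsers only convert; the shared failure-safe wrapper owns the
// decision to skip conversion entirely.
final case class Row(values: Seq[Any])

final class SafeParser[IN](rawParser: IN => Seq[Row], skipParsing: Boolean) {
  def parse(input: IN): Iterator[Row] =
    if (skipParsing) Iterator.single(Row(Nil)) // one empty row per record
    else rawParser(input).iterator
}

// CSV or JSON now plug in plain converters with no empty-schema branch:
val csvSafe = new SafeParser[String](l => Seq(Row(l.split(",").toSeq)), skipParsing = true)
```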

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala

Lines changed: 4 additions & 2 deletions
```diff
@@ -139,7 +139,8 @@ object TextInputJsonDataSource extends JsonDataSource {
       input => parser.parse(input, textParser, textToUTF8String),
       parser.options.parseMode,
       schema,
-      parser.options.columnNameOfCorruptRecord)
+      parser.options.columnNameOfCorruptRecord,
+      parser.options.multiLine)
     linesReader.flatMap(safeParser.parse)
   }

@@ -223,7 +224,8 @@ object MultiLineJsonDataSource extends JsonDataSource {
       input => parser.parse[InputStream](input, streamParser, partitionedFileString),
       parser.options.parseMode,
       schema,
-      parser.options.columnNameOfCorruptRecord)
+      parser.options.columnNameOfCorruptRecord,
+      parser.options.multiLine)

     safeParser.parse(
       CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath))))
```

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala

Lines changed: 39 additions & 0 deletions
```diff
@@ -119,8 +119,47 @@ object CSVBenchmarks {
     }
   }

+  def countBenchmark(rowsNum: Int): Unit = {
+    val colsNum = 10
+    val benchmark = new Benchmark(s"Count a dataset with $colsNum columns", rowsNum)
+
+    withTempPath { path =>
+      val fields = Seq.tabulate(colsNum)(i => StructField(s"col$i", IntegerType))
+      val schema = StructType(fields)
+
+      spark.range(rowsNum)
+        .select(Seq.tabulate(colsNum)(i => lit(i).as(s"col$i")): _*)
+        .write
+        .csv(path.getAbsolutePath)
+
+      val ds = spark.read.schema(schema).csv(path.getAbsolutePath)
+
+      benchmark.addCase(s"Select $colsNum columns + count()", 3) { _ =>
+        ds.select("*").filter((_: Row) => true).count()
+      }
+      benchmark.addCase(s"Select 1 column + count()", 3) { _ =>
+        ds.select($"col1").filter((_: Row) => true).count()
+      }
+      benchmark.addCase(s"count()", 3) { _ =>
+        ds.count()
+      }
+
+      /*
+      Intel(R) Core(TM) i7-7700HQ CPU @ 2.80GHz
+
+      Count a dataset with 10 columns:     Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+      ---------------------------------------------------------------------------------------------
+      Select 10 columns + count()              12598 / 12740          0.8        1259.8       1.0X
+      Select 1 column + count()                 7960 / 8175           1.3         796.0       1.6X
+      count()                                   2332 / 2386           4.3         233.2       5.4X
+      */
+      benchmark.run()
+    }
+  }
+
   def main(args: Array[String]): Unit = {
     quotedValuesBenchmark(rowsNum = 50 * 1000, numIters = 3)
     multiColumnsBenchmark(rowsNum = 1000 * 1000)
+    countBenchmark(10 * 1000 * 1000)
   }
 }
```
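The derived columns in the results block follow directly from the best time and the row count; a quick sanity check on the `count()` row above (plain Scala arithmetic over the published numbers):

```scala
// 10 million rows counted in a best time of 2332 ms:
val rowsNum = 10 * 1000 * 1000
val bestMs = 2332.0
println(f"${rowsNum / bestMs / 1000}%.1f M rows/s") // ≈ 4.3, the Rate column
println(f"${bestMs * 1e6 / rowsNum}%.1f ns/row")    // ≈ 233.2, the Per Row column
```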

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala

Lines changed: 26 additions & 0 deletions
```diff
@@ -1641,4 +1641,30 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te
       }
     }
   }
+
+  test("count() for malformed input") {
+    def countForMalformedCSV(expected: Long, input: Seq[String]): Unit = {
+      val schema = new StructType().add("a", IntegerType)
+      val strings = spark.createDataset(input)
+      val df = spark.read.schema(schema).option("header", false).csv(strings)
+
+      assert(df.count() == expected)
+    }
+    def checkCount(expected: Long): Unit = {
+      val validRec = "1"
+      val inputs = Seq(
+        Seq("{-}", validRec),
+        Seq(validRec, "?"),
+        Seq("0xAC", validRec),
+        Seq(validRec, "0.314"),
+        Seq("\\\\\\", validRec)
+      )
+      inputs.foreach { input =>
+        countForMalformedCSV(expected, input)
+      }
+    }
+
+    checkCount(2)
+    countForMalformedCSV(0, Seq(""))
+  }
 }
```
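The invariant these tests pin down: in PERMISSIVE mode a malformed line still produces one null-filled row, so counting raw lines (the shortcut) and counting parsed rows must agree. A hedged spark-shell sketch of that behaviour (assumes `spark` and its implicits are in scope):

```scala
import spark.implicits._
import org.apache.spark.sql.types._

val schema = new StructType().add("a", IntegerType)
val lines = spark.createDataset(Seq("0xAC", "1")) // one malformed, one valid
val df = spark.read.schema(schema).csv(lines)

df.count() // 2: the malformed line still contributes a row
df.show()  // expected: a null for "0xAC", then 1
```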

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala

Lines changed: 43 additions & 2 deletions
```diff
@@ -19,8 +19,9 @@ package org.apache.spark.sql.execution.datasources.json
 import java.io.File

 import org.apache.spark.SparkConf
-import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.types.{LongType, StringType, StructType}
+import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.types._
 import org.apache.spark.util.{Benchmark, Utils}

 /**
@@ -171,9 +172,49 @@ object JSONBenchmarks {
     }
   }

+  def countBenchmark(rowsNum: Int): Unit = {
+    val colsNum = 10
+    val benchmark = new Benchmark(s"Count a dataset with $colsNum columns", rowsNum)
+
+    withTempPath { path =>
+      val fields = Seq.tabulate(colsNum)(i => StructField(s"col$i", IntegerType))
+      val schema = StructType(fields)
+      val columnNames = schema.fieldNames
+
+      spark.range(rowsNum)
+        .select(Seq.tabulate(colsNum)(i => lit(i).as(s"col$i")): _*)
+        .write
+        .json(path.getAbsolutePath)
+
+      val ds = spark.read.schema(schema).json(path.getAbsolutePath)
+
+      benchmark.addCase(s"Select $colsNum columns + count()", 3) { _ =>
+        ds.select("*").filter((_: Row) => true).count()
+      }
+      benchmark.addCase(s"Select 1 column + count()", 3) { _ =>
+        ds.select($"col1").filter((_: Row) => true).count()
+      }
+      benchmark.addCase(s"count()", 3) { _ =>
+        ds.count()
+      }
+
+      /*
+      Intel(R) Core(TM) i7-7700HQ CPU @ 2.80GHz
+
+      Count a dataset with 10 columns:     Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+      ---------------------------------------------------------------------------------------------
+      Select 10 columns + count()               9961 / 10006          1.0         996.1       1.0X
+      Select 1 column + count()                 8355 / 8470           1.2         835.5       1.2X
+      count()                                   2104 / 2156           4.8         210.4       4.7X
+      */
+      benchmark.run()
+    }
+  }
+
   def main(args: Array[String]): Unit = {
     schemaInferring(100 * 1000 * 1000)
     perlineParsing(100 * 1000 * 1000)
     perlineParsingOfWideColumn(10 * 1000 * 1000)
+    countBenchmark(10 * 1000 * 1000)
   }
 }
```

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala

Lines changed: 26 additions & 1 deletion
```diff
@@ -2223,7 +2223,6 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
     checkAnswer(jsonDF, Seq(Row("Chris", "Baird")))
   }

-
   test("SPARK-23723: specified encoding is not matched to actual encoding") {
     val fileName = "test-data/utf16LE.json"
     val schema = new StructType().add("firstName", StringType).add("lastName", StringType)
@@ -2490,4 +2489,30 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
       assert(exception.getMessage.contains("encoding must not be included in the blacklist"))
     }
   }
+
+  test("count() for malformed input") {
+    def countForMalformedJSON(expected: Long, input: Seq[String]): Unit = {
+      val schema = new StructType().add("a", StringType)
+      val strings = spark.createDataset(input)
+      val df = spark.read.schema(schema).json(strings)
+
+      assert(df.count() == expected)
+    }
+    def checkCount(expected: Long): Unit = {
+      val validRec = """{"a":"b"}"""
+      val inputs = Seq(
+        Seq("{-}", validRec),
+        Seq(validRec, "?"),
+        Seq("}", validRec),
+        Seq(validRec, """{"a": [1, 2, 3]}"""),
+        Seq("""{"a": {"a": "b"}}""", validRec)
+      )
+      inputs.foreach { input =>
+        countForMalformedJSON(expected, input)
+      }
+    }
+
+    checkCount(2)
+    countForMalformedJSON(0, Seq(""))
+  }
 }
```
