
Commit 1d9338b

MaxGekk authored and gatorsmile committed
[SPARK-23786][SQL] Checking column names of csv headers
## What changes were proposed in this pull request?

Currently, column names in the headers of CSV files are not checked against the provided schema of the CSV data. This can cause errors like those shown in [SPARK-23786](https://issues.apache.org/jira/browse/SPARK-23786) and #20894 (comment).

This PR introduces a new CSV option, `enforceSchema`. If it is enabled (the default, `true`), Spark forcibly applies the provided or inferred schema to CSV files; in that case, CSV headers are ignored and not checked against the schema. If `enforceSchema` is set to `false`, additional checks are performed. For example, if a column in the CSV header and in the schema have different ordering, the following exception is thrown:

```
java.lang.IllegalArgumentException: CSV file header does not contain the expected fields
 Header: depth, temperature
 Schema: temperature, depth
CSV file: marina.csv
```

## How was this patch tested?

The changes were tested by the existing tests of CSVSuite and by 2 new tests.

Author: Maxim Gekk <[email protected]>
Author: Maxim Gekk <[email protected]>

Closes #20894 from MaxGekk/check-column-names.
1 parent 416cd1f commit 1d9338b
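
For illustration, here is a minimal PySpark sketch of the behaviour described above (the file path, column names, and data values are hypothetical; the error text is the one added by this commit in `checkHeaderColumnNames`):

```python
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType

spark = SparkSession.builder.getOrCreate()

# Write a small CSV whose header is "depth,temperature" (hypothetical data).
spark.createDataFrame([(10, 7)], ["depth", "temperature"]) \
    .write.mode("overwrite").option("header", "true").csv("/tmp/marina_csv")

# The schema deliberately lists the columns in the opposite order.
schema = StructType([
    StructField("temperature", IntegerType()),
    StructField("depth", IntegerType())])

# With enforceSchema=False the header is validated against the schema by
# position, so materializing the DataFrame raises IllegalArgumentException
# ("CSV header does not conform to the schema").
df = spark.read.schema(schema).csv("/tmp/marina_csv", header=True, enforceSchema=False)
df.collect()
```
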

File tree: 10 files changed, +411 −37 lines

python/pyspark/sql/readwriter.py

Lines changed: 13 additions & 2 deletions
@@ -346,7 +346,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
             negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
             maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None,
             columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None,
-            samplingRatio=None):
+            samplingRatio=None, enforceSchema=None):
         """Loads a CSV file and returns the result as a :class:`DataFrame`.

         This function will go through the input once to determine the input schema if
@@ -373,6 +373,16 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
                              default value, ``false``.
         :param inferSchema: infers the input schema automatically from data. It requires one extra
                             pass over the data. If None is set, it uses the default value, ``false``.
+        :param enforceSchema: If it is set to ``true``, the specified or inferred schema will be
+                              forcibly applied to datasource files, and headers in CSV files will be
+                              ignored. If the option is set to ``false``, the schema will be
+                              validated against all headers in CSV files or the first header in RDD
+                              if the ``header`` option is set to ``true``. Field names in the schema
+                              and column names in CSV headers are checked by their positions
+                              taking into account ``spark.sql.caseSensitive``. If None is set,
+                              ``true`` is used by default. Though the default value is ``true``,
+                              it is recommended to disable the ``enforceSchema`` option
+                              to avoid incorrect results.
         :param ignoreLeadingWhiteSpace: A flag indicating whether or not leading whitespaces from
                                         values being read should be skipped. If None is set, it
                                         uses the default value, ``false``.
@@ -449,7 +459,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
             maxCharsPerColumn=maxCharsPerColumn,
             maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode,
             columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine,
-            charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio)
+            charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio,
+            enforceSchema=enforceSchema)
         if isinstance(path, basestring):
             path = [path]
         if type(path) == list:
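
As a usage sketch of the docstring above (same hypothetical path and schema as in the earlier sketch): when `enforceSchema` is left unset it defaults to `true`, so a mismatched header is only logged as a warning on the executors and the schema is applied by position.

```python
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType

spark = SparkSession.builder.getOrCreate()
schema = StructType([
    StructField("temperature", IntegerType()),
    StructField("depth", IntegerType())])

# enforceSchema is not passed, so it defaults to true: a mismatched header in
# /tmp/marina_csv (hypothetical path) is only logged as a warning and the
# columns are bound to the schema purely by position.
df = spark.read.schema(schema).csv("/tmp/marina_csv", header=True)
df.show()
```
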

python/pyspark/sql/streaming.py

Lines changed: 13 additions & 2 deletions
@@ -564,7 +564,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
             ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None,
             negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
             maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None,
-            columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None):
+            columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None,
+            enforceSchema=None):
         """Loads a CSV file stream and returns the result as a :class:`DataFrame`.

         This function will go through the input once to determine the input schema if
@@ -592,6 +593,16 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
                              default value, ``false``.
         :param inferSchema: infers the input schema automatically from data. It requires one extra
                             pass over the data. If None is set, it uses the default value, ``false``.
+        :param enforceSchema: If it is set to ``true``, the specified or inferred schema will be
+                              forcibly applied to datasource files, and headers in CSV files will be
+                              ignored. If the option is set to ``false``, the schema will be
+                              validated against all headers in CSV files or the first header in RDD
+                              if the ``header`` option is set to ``true``. Field names in the schema
+                              and column names in CSV headers are checked by their positions
+                              taking into account ``spark.sql.caseSensitive``. If None is set,
+                              ``true`` is used by default. Though the default value is ``true``,
+                              it is recommended to disable the ``enforceSchema`` option
+                              to avoid incorrect results.
         :param ignoreLeadingWhiteSpace: a flag indicating whether or not leading whitespaces from
                                         values being read should be skipped. If None is set, it
                                         uses the default value, ``false``.
@@ -664,7 +675,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
             maxCharsPerColumn=maxCharsPerColumn,
             maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode,
             columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine,
-            charToEscapeQuoteEscaping=charToEscapeQuoteEscaping)
+            charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, enforceSchema=enforceSchema)
         if isinstance(path, basestring):
             return self._df(self._jreader.csv(path))
         else:
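
A corresponding structured-streaming sketch under the same assumptions (hypothetical directory; file streams require an explicit schema):

```python
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType

spark = SparkSession.builder.getOrCreate()
schema = StructType([
    StructField("temperature", IntegerType()),
    StructField("depth", IntegerType())])

# With enforceSchema=False, each CSV header seen by the stream is validated
# against the schema; a mismatch fails the streaming query instead of silently
# mis-assigning columns. The directory is hypothetical.
stream = spark.readStream.schema(schema) \
    .csv("/tmp/marina_stream", header=True, enforceSchema=False)
query = stream.writeStream.format("console").start()
query.awaitTermination(10)
```
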

python/pyspark/sql/tests.py

Lines changed: 18 additions & 0 deletions
@@ -3056,6 +3056,24 @@ def test_csv_sampling_ratio(self):
             .csv(rdd, samplingRatio=0.5).schema
         self.assertEquals(schema, StructType([StructField("_c0", IntegerType(), True)]))

+    def test_checking_csv_header(self):
+        path = tempfile.mkdtemp()
+        shutil.rmtree(path)
+        try:
+            self.spark.createDataFrame([[1, 1000], [2000, 2]])\
+                .toDF('f1', 'f2').write.option("header", "true").csv(path)
+            schema = StructType([
+                StructField('f2', IntegerType(), nullable=True),
+                StructField('f1', IntegerType(), nullable=True)])
+            df = self.spark.read.option('header', 'true').schema(schema)\
+                .csv(path, enforceSchema=False)
+            self.assertRaisesRegexp(
+                Exception,
+                "CSV header does not conform to the schema",
+                lambda: df.collect())
+        finally:
+            shutil.rmtree(path)
+

 class HiveSparkSubmitTests(SparkSubmitTests):

sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala

Lines changed: 19 additions & 0 deletions
@@ -22,6 +22,7 @@ import java.util.{Locale, Properties}
 import scala.collection.JavaConverters._

 import com.fasterxml.jackson.databind.ObjectMapper
+import com.univocity.parsers.csv.CsvParser

 import org.apache.spark.Partition
 import org.apache.spark.annotation.InterfaceStability
@@ -474,6 +475,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
   * it determines the columns as string types and it reads only the first line to determine the
   * names and the number of fields.
   *
+  * If the enforceSchema is set to `false`, only the CSV header in the first line is checked
+  * to conform specified or inferred schema.
+  *
   * @param csvDataset input Dataset with one CSV row per record
   * @since 2.2.0
   */
@@ -499,6 +503,13 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
       StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))

     val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine =>
+      CSVDataSource.checkHeader(
+        firstLine,
+        new CsvParser(parsedOptions.asParserSettings),
+        actualSchema,
+        csvDataset.getClass.getCanonicalName,
+        parsedOptions.enforceSchema,
+        sparkSession.sessionState.conf.caseSensitiveAnalysis)
       filteredLines.rdd.mapPartitions(CSVUtils.filterHeaderLine(_, firstLine, parsedOptions))
     }.getOrElse(filteredLines.rdd)

@@ -539,6 +550,13 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
   * <li>`comment` (default empty string): sets a single character used for skipping lines
   * beginning with this character. By default, it is disabled.</li>
   * <li>`header` (default `false`): uses the first line as names of columns.</li>
+  * <li>`enforceSchema` (default `true`): If it is set to `true`, the specified or inferred schema
+  * will be forcibly applied to datasource files, and headers in CSV files will be ignored.
+  * If the option is set to `false`, the schema will be validated against all headers in CSV files
+  * in the case when the `header` option is set to `true`. Field names in the schema
+  * and column names in CSV headers are checked by their positions taking into account
+  * `spark.sql.caseSensitive`. Though the default value is true, it is recommended to disable
+  * the `enforceSchema` option to avoid incorrect results.</li>
   * <li>`inferSchema` (default `false`): infers the input schema automatically from data. It
   * requires one extra pass over the data.</li>
   * <li>`samplingRatio` (default is 1.0): defines fraction of rows used for schema inferring.</li>
@@ -583,6 +601,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
   * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.</li>
   * <li>`multiLine` (default `false`): parse one record, which may span multiple lines.</li>
   * </ul>
+  *
   * @since 2.0.0
   */
  @scala.annotation.varargs
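
The option can also be supplied through the generic `option()` setter instead of the Python keyword argument, since it is resolved from the CSV options map; a minimal sketch reusing the `spark` session and `schema` from the earlier sketches, with a hypothetical path:

```python
# Equivalent to passing enforceSchema=False as a keyword argument: the value
# is picked up from the CSV options map and disables forced application of
# the schema.
df = (spark.read
      .format("csv")
      .option("header", "true")
      .option("enforceSchema", "false")
      .schema(schema)
      .load("/tmp/marina_csv"))
```
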

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala

Lines changed: 119 additions & 7 deletions
@@ -30,6 +30,7 @@ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat

 import org.apache.spark.TaskContext
 import org.apache.spark.input.{PortableDataStream, StreamInputFormat}
+import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.{BinaryFileRDD, RDD}
 import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
@@ -50,7 +51,10 @@ abstract class CSVDataSource extends Serializable {
       conf: Configuration,
       file: PartitionedFile,
       parser: UnivocityParser,
-      schema: StructType): Iterator[InternalRow]
+      requiredSchema: StructType,
+      // Actual schema of data in the csv file
+      dataSchema: StructType,
+      caseSensitive: Boolean): Iterator[InternalRow]

   /**
    * Infers the schema from `inputPaths` files.
@@ -110,14 +114,92 @@ abstract class CSVDataSource extends Serializable {
     }
   }

-object CSVDataSource {
+object CSVDataSource extends Logging {
   def apply(options: CSVOptions): CSVDataSource = {
     if (options.multiLine) {
       MultiLineCSVDataSource
     } else {
       TextInputCSVDataSource
     }
   }
+
+  /**
+   * Checks that column names in a CSV header and field names in the schema are the same
+   * by taking into account case sensitivity.
+   *
+   * @param schema - provided (or inferred) schema to which CSV must conform.
+   * @param columnNames - names of CSV columns that must be checked against to the schema.
+   * @param fileName - name of CSV file that are currently checked. It is used in error messages.
+   * @param enforceSchema - if it is `true`, column names are ignored otherwise the CSV column
+   *                        names are checked for conformance to the schema. In the case if
+   *                        the column name don't conform to the schema, an exception is thrown.
+   * @param caseSensitive - if it is set to `false`, comparison of column names and schema field
+   *                        names is not case sensitive.
+   */
+  def checkHeaderColumnNames(
+      schema: StructType,
+      columnNames: Array[String],
+      fileName: String,
+      enforceSchema: Boolean,
+      caseSensitive: Boolean): Unit = {
+    if (columnNames != null) {
+      val fieldNames = schema.map(_.name).toIndexedSeq
+      val (headerLen, schemaSize) = (columnNames.size, fieldNames.length)
+      var errorMessage: Option[String] = None
+
+      if (headerLen == schemaSize) {
+        var i = 0
+        while (errorMessage.isEmpty && i < headerLen) {
+          var (nameInSchema, nameInHeader) = (fieldNames(i), columnNames(i))
+          if (!caseSensitive) {
+            nameInSchema = nameInSchema.toLowerCase
+            nameInHeader = nameInHeader.toLowerCase
+          }
+          if (nameInHeader != nameInSchema) {
+            errorMessage = Some(
+              s"""|CSV header does not conform to the schema.
+                  | Header: ${columnNames.mkString(", ")}
+                  | Schema: ${fieldNames.mkString(", ")}
+                  |Expected: ${fieldNames(i)} but found: ${columnNames(i)}
+                  |CSV file: $fileName""".stripMargin)
+          }
+          i += 1
+        }
+      } else {
+        errorMessage = Some(
+          s"""|Number of column in CSV header is not equal to number of fields in the schema:
+              | Header length: $headerLen, schema size: $schemaSize
+              |CSV file: $fileName""".stripMargin)
+      }
+
+      errorMessage.foreach { msg =>
+        if (enforceSchema) {
+          logWarning(msg)
+        } else {
+          throw new IllegalArgumentException(msg)
+        }
+      }
+    }
+  }
+
+  /**
+   * Checks that CSV header contains the same column names as fields names in the given schema
+   * by taking into account case sensitivity.
+   */
+  def checkHeader(
+      header: String,
+      parser: CsvParser,
+      schema: StructType,
+      fileName: String,
+      enforceSchema: Boolean,
+      caseSensitive: Boolean): Unit = {
+    checkHeaderColumnNames(
+      schema,
+      parser.parseLine(header),
+      fileName,
+      enforceSchema,
+      caseSensitive)
+  }
 }

 object TextInputCSVDataSource extends CSVDataSource {
@@ -127,7 +209,9 @@ object TextInputCSVDataSource extends CSVDataSource {
       conf: Configuration,
       file: PartitionedFile,
       parser: UnivocityParser,
-      schema: StructType): Iterator[InternalRow] = {
+      requiredSchema: StructType,
+      dataSchema: StructType,
+      caseSensitive: Boolean): Iterator[InternalRow] = {
     val lines = {
       val linesReader = new HadoopFileLinesReader(file, conf)
       Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
@@ -136,8 +220,24 @@ object TextInputCSVDataSource extends CSVDataSource {
       }
     }

-    val shouldDropHeader = parser.options.headerFlag && file.start == 0
-    UnivocityParser.parseIterator(lines, shouldDropHeader, parser, schema)
+    val hasHeader = parser.options.headerFlag && file.start == 0
+    if (hasHeader) {
+      // Checking that column names in the header are matched to field names of the schema.
+      // The header will be removed from lines.
+      // Note: if there are only comments in the first block, the header would probably
+      // be not extracted.
+      CSVUtils.extractHeader(lines, parser.options).foreach { header =>
+        CSVDataSource.checkHeader(
+          header,
+          parser.tokenizer,
+          dataSchema,
+          file.filePath,
+          parser.options.enforceSchema,
+          caseSensitive)
+      }
+    }
+
+    UnivocityParser.parseIterator(lines, parser, requiredSchema)
   }

   override def infer(
@@ -206,12 +306,24 @@ object MultiLineCSVDataSource extends CSVDataSource {
       conf: Configuration,
       file: PartitionedFile,
       parser: UnivocityParser,
-      schema: StructType): Iterator[InternalRow] = {
+      requiredSchema: StructType,
+      dataSchema: StructType,
+      caseSensitive: Boolean): Iterator[InternalRow] = {
+    def checkHeader(header: Array[String]): Unit = {
+      CSVDataSource.checkHeaderColumnNames(
+        dataSchema,
+        header,
+        file.filePath,
+        parser.options.enforceSchema,
+        caseSensitive)
+    }
+
     UnivocityParser.parseStream(
       CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath))),
       parser.options.headerFlag,
       parser,
-      schema)
+      requiredSchema,
+      checkHeader)
   }

   override def infer(
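
To make the validation rule concrete, the following is a small Python analogue of `checkHeaderColumnNames` above; it is illustrative only, not the Spark implementation: names are compared strictly by position, lower-cased when the comparison is case-insensitive, and a mismatch is either raised or merely reported depending on the enforce flag.

```python
import warnings


def check_header_column_names(field_names, column_names, file_name,
                              enforce_schema=True, case_sensitive=False):
    """Positional comparison of a CSV header against schema field names."""
    error = None
    if len(column_names) != len(field_names):
        error = ("Number of columns in CSV header is not equal to number of "
                 "fields in the schema: header length %d, schema size %d, "
                 "CSV file: %s" % (len(column_names), len(field_names), file_name))
    else:
        for name_in_schema, name_in_header in zip(field_names, column_names):
            pair = ((name_in_schema, name_in_header) if case_sensitive
                    else (name_in_schema.lower(), name_in_header.lower()))
            if pair[0] != pair[1]:
                error = ("CSV header does not conform to the schema. "
                         "Expected: %s but found: %s, CSV file: %s"
                         % (name_in_schema, name_in_header, file_name))
                break
    if error is not None:
        if enforce_schema:
            warnings.warn(error)  # Spark logs a warning in this case
        else:
            raise ValueError(error)  # Spark throws IllegalArgumentException


# Raises ValueError because the header and the schema are ordered differently.
check_header_column_names(["temperature", "depth"], ["depth", "temperature"],
                          "marina.csv", enforce_schema=False)
```
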

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala

Lines changed: 8 additions & 1 deletion
@@ -130,14 +130,21 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
           "df.filter($\"_corrupt_record\".isNotNull).count()."
         )
     }
+    val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis

     (file: PartitionedFile) => {
       val conf = broadcastedHadoopConf.value.value
       val parser = new UnivocityParser(
         StructType(dataSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)),
         StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)),
         parsedOptions)
-      CSVDataSource(parsedOptions).readFile(conf, file, parser, requiredSchema)
+      CSVDataSource(parsedOptions).readFile(
+        conf,
+        file,
+        parser,
+        requiredSchema,
+        dataSchema,
+        caseSensitive)
     }
   }
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala

Lines changed: 6 additions & 0 deletions
@@ -156,6 +156,12 @@ class CSVOptions(
   val samplingRatio =
     parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)

+  /**
+   * Forcibly apply the specified or inferred schema to datasource files.
+   * If the option is enabled, headers of CSV files will be ignored.
+   */
+  val enforceSchema = getBool("enforceSchema", default = true)
+
   def asWriterSettings: CsvWriterSettings = {
     val writerSettings = new CsvWriterSettings()
     val format = writerSettings.getFormat