From 9d91da124e0723adee7744a64999ea1c07acfe66 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Sat, 11 Mar 2017 15:53:39 +0900
Subject: [PATCH 1/3] Defer input path validation into DataSource in CSV
 datasource

---
 .../datasources/csv/CSVDataSource.scala       | 29 +++++++++++++------
 .../datasources/csv/CSVFileFormat.scala       |  2 --
 .../execution/datasources/csv/CSVSuite.scala  |  8 +++++
 3 files changed, 28 insertions(+), 11 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
index 35ff924f27ce..c58327e1dc7c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -54,10 +54,21 @@ abstract class CSVDataSource extends Serializable {
   /**
    * Infers the schema from `inputPaths` files.
    */
-  def infer(
+  final def infer(
       sparkSession: SparkSession,
       inputPaths: Seq[FileStatus],
-      parsedOptions: CSVOptions): Option[StructType]
+      parsedOptions: CSVOptions): Option[StructType] = {
+    if (inputPaths.nonEmpty) {
+      Some(inferSchema(sparkSession, inputPaths, parsedOptions))
+    } else {
+      None
+    }
+  }
+
+  protected def inferSchema(
+      sparkSession: SparkSession,
+      inputPaths: Seq[FileStatus],
+      parsedOptions: CSVOptions): StructType
 
   /**
    * Generates a header from the given row which is null-safe and duplicate-safe.
@@ -128,13 +139,13 @@ object TextInputCSVDataSource extends CSVDataSource {
     UnivocityParser.parseIterator(lines, shouldDropHeader, parser)
   }
 
-  override def infer(
+  override def inferSchema(
       sparkSession: SparkSession,
       inputPaths: Seq[FileStatus],
-      parsedOptions: CSVOptions): Option[StructType] = {
+      parsedOptions: CSVOptions): StructType = {
     val csv = createBaseDataset(sparkSession, inputPaths, parsedOptions)
     val maybeFirstLine = CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption
-    Some(inferFromDataset(sparkSession, csv, maybeFirstLine, parsedOptions))
+    inferFromDataset(sparkSession, csv, maybeFirstLine, parsedOptions)
   }
 
   /**
@@ -199,10 +210,10 @@ object WholeFileCSVDataSource extends CSVDataSource {
       parser)
   }
 
-  override def infer(
+  override def inferSchema(
       sparkSession: SparkSession,
       inputPaths: Seq[FileStatus],
-      parsedOptions: CSVOptions): Option[StructType] = {
+      parsedOptions: CSVOptions): StructType = {
     val csv = createBaseRdd(sparkSession, inputPaths, parsedOptions)
     csv.flatMap { lines =>
       UnivocityParser.tokenizeStream(
@@ -221,10 +232,10 @@ object WholeFileCSVDataSource extends CSVDataSource {
             parsedOptions.headerFlag,
             new CsvParser(parsedOptions.asParserSettings))
         }
-        Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions))
+        CSVInferSchema.infer(tokenRDD, header, parsedOptions)
       case None =>
         // If the first row could not be read, just return the empty schema.
-        Some(StructType(Nil))
+        StructType(Nil)
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index 29c41455279e..f397a150c6b5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -51,8 +51,6 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
       sparkSession: SparkSession,
       options: Map[String, String],
       files: Seq[FileStatus]): Option[StructType] = {
-    require(files.nonEmpty, "Cannot infer schema from an empty set of files")
-
     val parsedOptions =
       new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone)
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 4435e4df38ef..a5fee4778aec 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -1115,4 +1115,12 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     assert(df2.schema === schema)
   }
 
+  test("Consistent exception message in schema inference when the path is an empty directory") {
+    withTempDir { dir =>
+      val message = intercept[AnalysisException] {
+        spark.read.csv(dir.getAbsolutePath)
+      }.getMessage
+      assert(message.contains("Unable to infer schema for CSV. It must be specified manually."))
+    }
+  }
 }
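The refactoring in PATCH 1/3 is a template-method split: the public `infer` becomes final and owns the empty-input guard (returning `None` when there are no files), while each concrete source implements only the non-empty case through the protected hook. A minimal, self-contained sketch of that shape, assuming nothing from Spark (the names `Source`, `Schema`, `inferNonEmpty`, and `HeaderSource` are hypothetical, for illustration only):

// Standalone sketch of the template-method shape used in PATCH 1/3.
// Source, Schema, inferNonEmpty, and HeaderSource are hypothetical names,
// not Spark APIs.
object TemplateMethodSketch {
  type Schema = List[String]

  abstract class Source {
    // The final public entry point keeps the empty-input check in one place,
    // so all callers uniformly get None rather than a subclass-specific error.
    final def infer(paths: Seq[String]): Option[Schema] =
      if (paths.nonEmpty) Some(inferNonEmpty(paths)) else None

    // Subclasses handle only the non-empty case and return a plain Schema.
    protected def inferNonEmpty(paths: Seq[String]): Schema
  }

  class HeaderSource extends Source {
    override protected def inferNonEmpty(paths: Seq[String]): Schema =
      paths.map(p => s"column_of_$p").toList
  }

  def main(args: Array[String]): Unit = {
    val source = new HeaderSource
    println(source.infer(Seq("a.csv")))  // Some(List(column_of_a.csv))
    println(source.infer(Seq.empty))     // None
  }
}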
From 04e620cbc2ed13525918af9e19ef1271d194c386 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Sun, 12 Mar 2017 20:28:34 +0900
Subject: [PATCH 2/3] Fix test

---
 .../spark/sql/execution/datasources/csv/CSVSuite.scala    | 8 --------
 .../spark/sql/test/DataFrameReaderWriterSuite.scala       | 6 ++++--
 2 files changed, 4 insertions(+), 10 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index a5fee4778aec..4435e4df38ef 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -1115,12 +1115,4 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     assert(df2.schema === schema)
   }
 
-  test("Consistent exception message in schema inference when the path is an empty directory") {
-    withTempDir { dir =>
-      val message = intercept[AnalysisException] {
-        spark.read.csv(dir.getAbsolutePath)
-      }.getMessage
-      assert(message.contains("Unable to infer schema for CSV. It must be specified manually."))
-    }
-  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
index 8a8ba0553452..8287776f8f55 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
@@ -370,9 +370,11 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
     val schema = df.schema
 
     // Reader, without user specified schema
-    intercept[IllegalArgumentException] {
+    val message = intercept[AnalysisException] {
       testRead(spark.read.csv(), Seq.empty, schema)
-    }
+    }.getMessage
+    assert(message.contains("Unable to infer schema for CSV. It must be specified manually."))
+
     testRead(spark.read.csv(dir), data, schema)
     testRead(spark.read.csv(dir, dir), data ++ data, schema)
     testRead(spark.read.csv(Seq(dir, dir): _*), data ++ data, schema)

From 87d3fc8ee6ca9012d056c46b7fdca4306fdfd76f Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Mon, 13 Mar 2017 15:19:41 +0900
Subject: [PATCH 3/3] Swap names (infer and inferSchema)

---
 .../sql/execution/datasources/csv/CSVDataSource.scala     | 10 +++++-----
 .../sql/execution/datasources/csv/CSVFileFormat.scala     |  2 +-
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
index c58327e1dc7c..b97bf3a2f0fd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -54,18 +54,18 @@ abstract class CSVDataSource extends Serializable {
   /**
    * Infers the schema from `inputPaths` files.
    */
-  final def infer(
+  final def inferSchema(
       sparkSession: SparkSession,
       inputPaths: Seq[FileStatus],
       parsedOptions: CSVOptions): Option[StructType] = {
     if (inputPaths.nonEmpty) {
-      Some(inferSchema(sparkSession, inputPaths, parsedOptions))
+      Some(infer(sparkSession, inputPaths, parsedOptions))
     } else {
       None
     }
   }
 
-  protected def inferSchema(
+  protected def infer(
       sparkSession: SparkSession,
       inputPaths: Seq[FileStatus],
       parsedOptions: CSVOptions): StructType
@@ -139,7 +139,7 @@ object TextInputCSVDataSource extends CSVDataSource {
     UnivocityParser.parseIterator(lines, shouldDropHeader, parser)
   }
 
-  override def inferSchema(
+  override def infer(
       sparkSession: SparkSession,
       inputPaths: Seq[FileStatus],
       parsedOptions: CSVOptions): StructType = {
@@ -210,7 +210,7 @@ object WholeFileCSVDataSource extends CSVDataSource {
       parser)
   }
 
-  override def inferSchema(
+  override def infer(
       sparkSession: SparkSession,
       inputPaths: Seq[FileStatus],
       parsedOptions: CSVOptions): StructType = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index f397a150c6b5..d3379766b2b5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -54,7 +54,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
     val parsedOptions =
       new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone)
 
-    CSVDataSource(parsedOptions).infer(sparkSession, files, parsedOptions)
+    CSVDataSource(parsedOptions).inferSchema(sparkSession, files, parsedOptions)
   }
 
   override def prepareWrite(
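End-to-end, the series removes the `require(files.nonEmpty, ...)` in `CSVFileFormat.inferSchema`, so an empty input surfaces as `None` from schema inference and the shared `DataSource` path raises the `AnalysisException` asserted in the tests. A rough sketch of the resulting behavior, assuming a local `SparkSession` (the temp directory here comes from plain `java.nio.file`, not Spark's test helpers):

// Sketch only: reading an empty directory without a user-specified schema
// should now fail with AnalysisException at analysis time, instead of the
// IllegalArgumentException thrown by the removed require(...).
import java.nio.file.Files
import org.apache.spark.sql.{AnalysisException, SparkSession}

object EmptyDirBehaviorSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
    val emptyDir = Files.createTempDirectory("csv-empty").toFile
    try {
      spark.read.csv(emptyDir.getAbsolutePath)
    } catch {
      case e: AnalysisException =>
        // Expected message, per the assertion added in DataFrameReaderWriterSuite:
        // "Unable to infer schema for CSV. It must be specified manually."
        println(e.getMessage)
    } finally {
      spark.stop()
    }
  }
}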