diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
index 6ab5c4b269b9..e971fd762efe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
@@ -18,11 +18,11 @@ package org.apache.spark.sql.execution.datasources.v2
 
 import org.apache.hadoop.fs.Path
 
-import org.apache.spark.sql.{AnalysisException, SparkSession}
+import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.execution.PartitionedFileUtil
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.sources.v2.reader.{Batch, InputPartition, Scan}
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
 abstract class FileScan(
@@ -37,22 +37,6 @@ abstract class FileScan(
     false
   }
 
-  /**
-   * Returns whether this format supports the given [[DataType]] in write path.
-   * By default all data types are supported.
-   */
-  def supportsDataType(dataType: DataType): Boolean = true
-
-  /**
-   * The string that represents the format that this data source provider uses. This is
-   * overridden by children to provide a nice alias for the data source. For example:
-   *
-   * {{{
-   *   override def formatName(): String = "ORC"
-   * }}}
-   */
-  def formatName: String
-
   protected def partitions: Seq[FilePartition] = {
     val selectedPartitions = fileIndex.listFiles(Seq.empty, Seq.empty)
     val maxSplitBytes = FilePartition.maxSplitBytes(sparkSession, selectedPartitions)
@@ -76,13 +60,5 @@ abstract class FileScan(
     partitions.toArray
   }
 
-  override def toBatch: Batch = {
-    readSchema.foreach { field =>
-      if (!supportsDataType(field.dataType)) {
-        throw new AnalysisException(
-          s"$formatName data source does not support ${field.dataType.catalogString} data type.")
-      }
-    }
-    this
-  }
+  override def toBatch: Batch = this
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala
index 4b35df355b6e..188016c161a4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.{AnalysisException, SparkSession}
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.sources.v2.{SupportsRead, SupportsWrite, Table, TableCapability}
 import org.apache.spark.sql.sources.v2.TableCapability._
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.sql.util.SchemaUtils
 
@@ -46,7 +46,11 @@ abstract class FileTable(
       sparkSession, rootPathsSpecified, caseSensitiveMap, userSpecifiedSchema, fileStatusCache)
   }
 
-  lazy val dataSchema: StructType = userSpecifiedSchema.orElse {
+  lazy val dataSchema: StructType = userSpecifiedSchema.map { schema =>
+    val partitionSchema = fileIndex.partitionSchema
+    val resolver = sparkSession.sessionState.conf.resolver
+    StructType(schema.filterNot(f => partitionSchema.exists(p => resolver(p.name, f.name))))
+  }.orElse {
     inferSchema(fileIndex.allFiles())
   }.getOrElse {
     throw new AnalysisException(
@@ -57,6 +61,12 @@ abstract class FileTable(
     val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
     SchemaUtils.checkColumnNameDuplication(dataSchema.fieldNames,
       "in the data schema", caseSensitive)
+    dataSchema.foreach { field =>
+      if (!supportsDataType(field.dataType)) {
+        throw new AnalysisException(
+          s"$formatName data source does not support ${field.dataType.catalogString} data type.")
+      }
+    }
     val partitionSchema = fileIndex.partitionSchema
     SchemaUtils.checkColumnNameDuplication(partitionSchema.fieldNames,
       "in the partition schema", caseSensitive)
@@ -72,6 +82,22 @@ abstract class FileTable(
    * Spark will require that user specify the schema manually.
    */
   def inferSchema(files: Seq[FileStatus]): Option[StructType]
+
+  /**
+   * Returns whether this format supports the given [[DataType]] in read/write path.
+   * By default all data types are supported.
+   */
+  def supportsDataType(dataType: DataType): Boolean = true
+
+  /**
+   * The string that represents the format that this data source provider uses. This is
+   * overridden by children to provide a nice alias for the data source. For example:
+   *
+   * {{{
+   *   override def formatName(): String = "ORC"
+   * }}}
+   */
+  def formatName: String
 }
 
 object FileTable {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala
index bb4a428e4066..7ff5c4182d98 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala
@@ -39,7 +39,11 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.sql.util.SchemaUtils
 import org.apache.spark.util.SerializableConfiguration
 
-abstract class FileWriteBuilder(options: CaseInsensitiveStringMap, paths: Seq[String])
+abstract class FileWriteBuilder(
+    options: CaseInsensitiveStringMap,
+    paths: Seq[String],
+    formatName: String,
+    supportsDataType: DataType => Boolean)
   extends WriteBuilder with SupportsSaveMode {
   private var schema: StructType = _
   private var queryId: String = _
@@ -108,22 +112,6 @@ abstract class FileWriteBuilder(options: CaseInsensitiveStringMap, paths: Seq[St
       options: Map[String, String],
       dataSchema: StructType): OutputWriterFactory
 
-  /**
-   * Returns whether this format supports the given [[DataType]] in write path.
-   * By default all data types are supported.
-   */
-  def supportsDataType(dataType: DataType): Boolean = true
-
-  /**
-   * The string that represents the format that this data source provider uses. This is
-   * overridden by children to provide a nice alias for the data source. For example:
-   *
-   * {{{
-   *   override def formatName(): String = "ORC"
-   * }}}
-   */
-  def formatName: String
-
   private def validateInputs(caseSensitiveAnalysis: Boolean): Unit = {
     assert(schema != null, "Missing input data schema")
     assert(queryId != null, "Missing query ID")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala
index 4ecd9cdc32ac..55222c624d91 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala
@@ -20,7 +20,7 @@ import org.apache.spark.sql.execution.datasources.FileFormat
 import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
 import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
 import org.apache.spark.sql.sources.v2.Table
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
 class CSVDataSourceV2 extends FileDataSourceV2 {
@@ -41,13 +41,3 @@ class CSVDataSourceV2 extends FileDataSourceV2 {
     CSVTable(tableName, sparkSession, options, paths, Some(schema))
   }
 }
-
-object CSVDataSourceV2 {
-  def supportsDataType(dataType: DataType): Boolean = dataType match {
-    case _: AtomicType => true
-
-    case udt: UserDefinedType[_] => supportsDataType(udt.sqlType)
-
-    case _ => false
-  }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala
index 35c6a668f22a..8f2f8f256731 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala
@@ -75,10 +75,4 @@ case class CSVScan(
     CSVPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
       dataSchema, fileIndex.partitionSchema, readSchema, parsedOptions)
   }
-
-  override def supportsDataType(dataType: DataType): Boolean = {
-    CSVDataSourceV2.supportsDataType(dataType)
-  }
-
-  override def formatName: String = "CSV"
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala
index bf4b8ba868f2..852cbf07c350 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.csv.CSVOptions
 import org.apache.spark.sql.execution.datasources.csv.CSVDataSource
 import org.apache.spark.sql.execution.datasources.v2.FileTable
 import org.apache.spark.sql.sources.v2.writer.WriteBuilder
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{AtomicType, DataType, StructType, UserDefinedType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
 case class CSVTable(
@@ -48,5 +48,15 @@ case class CSVTable(
   }
 
   override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder =
-    new CSVWriteBuilder(options, paths)
+    new CSVWriteBuilder(options, paths, formatName, supportsDataType)
+
+  override def supportsDataType(dataType: DataType): Boolean = dataType match {
+    case _: AtomicType => true
+
+    case udt: UserDefinedType[_] => supportsDataType(udt.sqlType)
+
+    case _ => false
+  }
+
+  override def formatName: String = "CSV"
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVWriteBuilder.scala
index bb26d2f92d74..92b47e435480 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVWriteBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVWriteBuilder.scala
@@ -27,8 +27,12 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
-class CSVWriteBuilder(options: CaseInsensitiveStringMap, paths: Seq[String])
-  extends FileWriteBuilder(options, paths) {
+class CSVWriteBuilder(
+    options: CaseInsensitiveStringMap,
+    paths: Seq[String],
+    formatName: String,
+    supportsDataType: DataType => Boolean)
+  extends FileWriteBuilder(options, paths, formatName, supportsDataType) {
   override def prepareWrite(
       sqlConf: SQLConf,
       job: Job,
@@ -56,10 +60,4 @@ class CSVWriteBuilder(options: CaseInsensitiveStringMap, paths: Seq[String])
       }
     }
   }
-
-  override def supportsDataType(dataType: DataType): Boolean = {
-    CSVDataSourceV2.supportsDataType(dataType)
-  }
-
-  override def formatName: String = "CSV"
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala
index 36e7e12e41ce..e8b9e6c6f498 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala
@@ -20,7 +20,7 @@ import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
 import org.apache.spark.sql.execution.datasources.v2._
 import org.apache.spark.sql.sources.v2.Table
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
 class OrcDataSourceV2 extends FileDataSourceV2 {
@@ -42,19 +42,3 @@ class OrcDataSourceV2 extends FileDataSourceV2 {
     OrcTable(tableName, sparkSession, options, paths, Some(schema))
   }
 }
-
-object OrcDataSourceV2 {
-  def supportsDataType(dataType: DataType): Boolean = dataType match {
-    case _: AtomicType => true
-
-    case st: StructType => st.forall { f => supportsDataType(f.dataType) }
-
-    case ArrayType(elementType, _) => supportsDataType(elementType)
-
-    case MapType(keyType, valueType, _) =>
-      supportsDataType(keyType) && supportsDataType(valueType)
-
-    case udt: UserDefinedType[_] => supportsDataType(udt.sqlType)
-
-    case _ => false
-  }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
index 237eadb698b4..fc8a682b226c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
 import org.apache.spark.sql.execution.datasources.v2.FileScan
 import org.apache.spark.sql.sources.v2.reader.PartitionReaderFactory
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.util.SerializableConfiguration
 
@@ -43,10 +43,4 @@ case class OrcScan(
     OrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
       dataSchema, fileIndex.partitionSchema, readSchema)
   }
-
-  override def supportsDataType(dataType: DataType): Boolean = {
-    OrcDataSourceV2.supportsDataType(dataType)
-  }
-
-  override def formatName: String = "ORC"
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala
index aac38fb3fa1f..ace77b7c4d9a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.execution.datasources.orc.OrcUtils
 import org.apache.spark.sql.execution.datasources.v2.FileTable
 import org.apache.spark.sql.sources.v2.writer.WriteBuilder
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
 case class OrcTable(
@@ -40,5 +40,22 @@ case class OrcTable(
     OrcUtils.readSchema(sparkSession, files)
 
   override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder =
-    new OrcWriteBuilder(options, paths)
+    new OrcWriteBuilder(options, paths, formatName, supportsDataType)
+
+  override def supportsDataType(dataType: DataType): Boolean = dataType match {
+    case _: AtomicType => true
+
+    case st: StructType => st.forall { f => supportsDataType(f.dataType) }
+
+    case ArrayType(elementType, _) => supportsDataType(elementType)
+
+    case MapType(keyType, valueType, _) =>
+      supportsDataType(keyType) && supportsDataType(valueType)
+
+    case udt: UserDefinedType[_] => supportsDataType(udt.sqlType)
+
+    case _ => false
+  }
+
+  override def formatName: String = "ORC"
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala
index 829ab5fbe176..f5b06e11c8bd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala
@@ -28,8 +28,12 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
-class OrcWriteBuilder(options: CaseInsensitiveStringMap, paths: Seq[String])
-  extends FileWriteBuilder(options, paths) {
+class OrcWriteBuilder(
+    options: CaseInsensitiveStringMap,
+    paths: Seq[String],
+    formatName: String,
+    supportsDataType: DataType => Boolean)
+  extends FileWriteBuilder(options, paths, formatName, supportsDataType) {
 
   override def prepareWrite(
       sqlConf: SQLConf,
@@ -65,10 +69,4 @@ class OrcWriteBuilder(options: CaseInsensitiveStringMap, paths: Seq[String])
       }
     }
   }
-
-  override def supportsDataType(dataType: DataType): Boolean = {
-    OrcDataSourceV2.supportsDataType(dataType)
-  }
-
-  override def formatName: String = "ORC"
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/FileTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/FileTableSuite.scala
new file mode 100644
index 000000000000..3d4f5640723e
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/FileTableSuite.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2
+
+import scala.collection.JavaConverters._
+
+import org.apache.hadoop.fs.FileStatus
+
+import org.apache.spark.sql.{QueryTest, SparkSession}
+import org.apache.spark.sql.sources.v2.reader.ScanBuilder
+import org.apache.spark.sql.sources.v2.writer.WriteBuilder
+import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+class DummyFileTable(
+    sparkSession: SparkSession,
+    options: CaseInsensitiveStringMap,
+    paths: Seq[String],
+    expectedDataSchema: StructType,
+    userSpecifiedSchema: Option[StructType])
+  extends FileTable(sparkSession, options, paths, userSpecifiedSchema) {
+  override def inferSchema(files: Seq[FileStatus]): Option[StructType] = Some(expectedDataSchema)
+
+  override def name(): String = "Dummy"
+
+  override def formatName: String = "Dummy"
+
+  override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = null
+
+  override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = null
+
+  override def supportsDataType(dataType: DataType): Boolean = dataType == StringType
+}
+
+class FileTableSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
+
+  test("Data type validation should check data schema only") {
+    withTempPath { dir =>
+      val df = spark.createDataFrame(Seq(("a", 1), ("b", 2))).toDF("v", "p")
+      val pathName = dir.getCanonicalPath
+      df.write.partitionBy("p").text(pathName)
+      val options = new CaseInsensitiveStringMap(Map("path" -> pathName).asJava)
+      val expectedDataSchema = StructType(Seq(StructField("v", StringType, true)))
+      // DummyFileTable doesn't support Integer data type.
+      // However, the partition schema is handled by Spark, so it is allowed to contain
+      // Integer data type here.
+      val table = new DummyFileTable(spark, options, Seq(pathName), expectedDataSchema, None)
+      assert(table.dataSchema == expectedDataSchema)
+      val expectedPartitionSchema = StructType(Seq(StructField("p", IntegerType, true)))
+      assert(table.fileIndex.partitionSchema == expectedPartitionSchema)
+    }
+  }
+
+  test("Returns correct data schema when user specified schema contains partition schema") {
+    withTempPath { dir =>
+      val df = spark.createDataFrame(Seq(("a", 1), ("b", 2))).toDF("v", "p")
+      val pathName = dir.getCanonicalPath
+      df.write.partitionBy("p").text(pathName)
+      val options = new CaseInsensitiveStringMap(Map("path" -> pathName).asJava)
+      val userSpecifiedSchema = Some(StructType(Seq(
+        StructField("v", StringType, true),
+        StructField("p", IntegerType, true))))
+      val expectedDataSchema = StructType(Seq(StructField("v", StringType, true)))
+      val table =
+        new DummyFileTable(spark, options, Seq(pathName), expectedDataSchema, userSpecifiedSchema)
+      assert(table.dataSchema == expectedDataSchema)
+    }
+  }
+}