FileScan.scala
@@ -18,11 +18,11 @@ package org.apache.spark.sql.execution.datasources.v2

import org.apache.hadoop.fs.Path

import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.PartitionedFileUtil
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources.v2.reader.{Batch, InputPartition, Scan}
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

abstract class FileScan(
@@ -37,22 +37,6 @@ abstract class FileScan(
false
}

/**
* Returns whether this format supports the given [[DataType]] in write path.
* By default all data types are supported.
*/
def supportsDataType(dataType: DataType): Boolean = true

/**
* The string that represents the format that this data source provider uses. This is
* overridden by children to provide a nice alias for the data source. For example:
*
* {{{
* override def formatName(): String = "ORC"
* }}}
*/
def formatName: String

protected def partitions: Seq[FilePartition] = {
val selectedPartitions = fileIndex.listFiles(Seq.empty, Seq.empty)
val maxSplitBytes = FilePartition.maxSplitBytes(sparkSession, selectedPartitions)
@@ -76,13 +60,5 @@ abstract class FileScan(
partitions.toArray
}

override def toBatch: Batch = {
readSchema.foreach { field =>
if (!supportsDataType(field.dataType)) {
throw new AnalysisException(
s"$formatName data source does not support ${field.dataType.catalogString} data type.")
}
}
this
}
override def toBatch: Batch = this
}

FileTable.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources.v2.{SupportsRead, SupportsWrite, Table, TableCapability}
import org.apache.spark.sql.sources.v2.TableCapability._
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.sql.util.SchemaUtils

@@ -46,7 +46,11 @@ abstract class FileTable(
sparkSession, rootPathsSpecified, caseSensitiveMap, userSpecifiedSchema, fileStatusCache)
}

lazy val dataSchema: StructType = userSpecifiedSchema.orElse {
lazy val dataSchema: StructType = userSpecifiedSchema.map { schema =>
val partitionSchema = fileIndex.partitionSchema
val resolver = sparkSession.sessionState.conf.resolver
StructType(schema.filterNot(f => partitionSchema.exists(p => resolver(p.name, f.name))))
}.orElse {
inferSchema(fileIndex.allFiles())
}.getOrElse {
throw new AnalysisException(
@@ -57,6 +61,12 @@
val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
SchemaUtils.checkColumnNameDuplication(dataSchema.fieldNames,
"in the data schema", caseSensitive)
dataSchema.foreach { field =>
if (!supportsDataType(field.dataType)) {
throw new AnalysisException(
s"$formatName data source does not support ${field.dataType.catalogString} data type.")
}
}
val partitionSchema = fileIndex.partitionSchema
SchemaUtils.checkColumnNameDuplication(partitionSchema.fieldNames,
"in the partition schema", caseSensitive)
@@ -72,6 +82,22 @@
* Spark will require that user specify the schema manually.
*/
def inferSchema(files: Seq[FileStatus]): Option[StructType]

/**
* Returns whether this format supports the given [[DataType]] in read/write path.
* By default all data types are supported.
*/
def supportsDataType(dataType: DataType): Boolean = true

/**
* The string that represents the format that this data source provider uses. This is
* overridden by children to provide a nice alias for the data source. For example:
*
* {{{
* override def formatName(): String = "ORC"
* }}}
*/
def formatName: String
}

object FileTable {

FileWriteBuilder.scala
@@ -39,7 +39,11 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.sql.util.SchemaUtils
import org.apache.spark.util.SerializableConfiguration

abstract class FileWriteBuilder(options: CaseInsensitiveStringMap, paths: Seq[String])
abstract class FileWriteBuilder(
options: CaseInsensitiveStringMap,
paths: Seq[String],
formatName: String,
supportsDataType: DataType => Boolean)
extends WriteBuilder with SupportsSaveMode {
private var schema: StructType = _
private var queryId: String = _
@@ -108,22 +112,6 @@ abstract class FileWriteBuilder(options: CaseInsensitiveStringMap, paths: Seq[String])
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory

/**
* Returns whether this format supports the given [[DataType]] in write path.
* By default all data types are supported.
*/
def supportsDataType(dataType: DataType): Boolean = true

/**
* The string that represents the format that this data source provider uses. This is
* overridden by children to provide a nice alias for the data source. For example:
*
* {{{
* override def formatName(): String = "ORC"
* }}}
*/
def formatName: String

private def validateInputs(caseSensitiveAnalysis: Boolean): Unit = {
assert(schema != null, "Missing input data schema")
assert(queryId != null, "Missing query ID")

CSVDataSourceV2.scala
@@ -20,7 +20,7 @@ import org.apache.spark.sql.execution.datasources.FileFormat
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
import org.apache.spark.sql.sources.v2.Table
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

class CSVDataSourceV2 extends FileDataSourceV2 {
@@ -41,13 +41,3 @@ class CSVDataSourceV2 extends FileDataSourceV2 {
CSVTable(tableName, sparkSession, options, paths, Some(schema))
}
}

object CSVDataSourceV2 {
def supportsDataType(dataType: DataType): Boolean = dataType match {
case _: AtomicType => true

case udt: UserDefinedType[_] => supportsDataType(udt.sqlType)

case _ => false
}
}

CSVScan.scala
@@ -75,10 +75,4 @@ case class CSVScan(
CSVPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
dataSchema, fileIndex.partitionSchema, readSchema, parsedOptions)
}

override def supportsDataType(dataType: DataType): Boolean = {
CSVDataSourceV2.supportsDataType(dataType)
}

override def formatName: String = "CSV"
}

CSVTable.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.csv.CSVOptions
import org.apache.spark.sql.execution.datasources.csv.CSVDataSource
import org.apache.spark.sql.execution.datasources.v2.FileTable
import org.apache.spark.sql.sources.v2.writer.WriteBuilder
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{AtomicType, DataType, StructType, UserDefinedType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap

case class CSVTable(
@@ -48,5 +48,15 @@
}

override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder =
new CSVWriteBuilder(options, paths)
new CSVWriteBuilder(options, paths, formatName, supportsDataType)

override def supportsDataType(dataType: DataType): Boolean = dataType match {
case _: AtomicType => true

case udt: UserDefinedType[_] => supportsDataType(udt.sqlType)

case _ => false
}

override def formatName: String = "CSV"
}

CSVWriteBuilder.scala
@@ -27,8 +27,12 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap

class CSVWriteBuilder(options: CaseInsensitiveStringMap, paths: Seq[String])
extends FileWriteBuilder(options, paths) {
class CSVWriteBuilder(
options: CaseInsensitiveStringMap,
paths: Seq[String],
formatName: String,
supportsDataType: DataType => Boolean)
extends FileWriteBuilder(options, paths, formatName, supportsDataType) {
override def prepareWrite(
sqlConf: SQLConf,
job: Job,
@@ -56,10 +60,4 @@ class CSVWriteBuilder(options: CaseInsensitiveStringMap, paths: Seq[String])
}
}
}

override def supportsDataType(dataType: DataType): Boolean = {
CSVDataSourceV2.supportsDataType(dataType)
}

override def formatName: String = "CSV"
}

OrcDataSourceV2.scala
@@ -20,7 +20,7 @@ import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
import org.apache.spark.sql.execution.datasources.v2._
import org.apache.spark.sql.sources.v2.Table
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

class OrcDataSourceV2 extends FileDataSourceV2 {
@@ -42,19 +42,3 @@
}
}

object OrcDataSourceV2 {
def supportsDataType(dataType: DataType): Boolean = dataType match {
case _: AtomicType => true

case st: StructType => st.forall { f => supportsDataType(f.dataType) }

case ArrayType(elementType, _) => supportsDataType(elementType)

case MapType(keyType, valueType, _) =>
supportsDataType(keyType) && supportsDataType(valueType)

case udt: UserDefinedType[_] => supportsDataType(udt.sqlType)

case _ => false
}
}

OrcScan.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.execution.datasources.v2.FileScan
import org.apache.spark.sql.sources.v2.reader.PartitionReaderFactory
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.SerializableConfiguration

@@ -43,10 +43,4 @@
OrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
dataSchema, fileIndex.partitionSchema, readSchema)
}

override def supportsDataType(dataType: DataType): Boolean = {
OrcDataSourceV2.supportsDataType(dataType)
}

override def formatName: String = "ORC"
}

OrcTable.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.orc.OrcUtils
import org.apache.spark.sql.execution.datasources.v2.FileTable
import org.apache.spark.sql.sources.v2.writer.WriteBuilder
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.CaseInsensitiveStringMap

case class OrcTable(
@@ -40,5 +40,22 @@
OrcUtils.readSchema(sparkSession, files)

override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder =
new OrcWriteBuilder(options, paths)
new OrcWriteBuilder(options, paths, formatName, supportsDataType)

override def supportsDataType(dataType: DataType): Boolean = dataType match {
case _: AtomicType => true

case st: StructType => st.forall { f => supportsDataType(f.dataType) }

case ArrayType(elementType, _) => supportsDataType(elementType)

case MapType(keyType, valueType, _) =>
supportsDataType(keyType) && supportsDataType(valueType)

case udt: UserDefinedType[_] => supportsDataType(udt.sqlType)

case _ => false
}

override def formatName: String = "ORC"
}

OrcWriteBuilder.scala
@@ -28,8 +28,12 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.CaseInsensitiveStringMap

class OrcWriteBuilder(options: CaseInsensitiveStringMap, paths: Seq[String])
extends FileWriteBuilder(options, paths) {
class OrcWriteBuilder(
options: CaseInsensitiveStringMap,
paths: Seq[String],
formatName: String,
supportsDataType: DataType => Boolean)
extends FileWriteBuilder(options, paths, formatName, supportsDataType) {

override def prepareWrite(
sqlConf: SQLConf,
@@ -65,10 +69,4 @@ class OrcWriteBuilder(options: CaseInsensitiveStringMap, paths: Seq[String])
}
}
}

override def supportsDataType(dataType: DataType): Boolean = {
OrcDataSourceV2.supportsDataType(dataType)
}

override def formatName: String = "ORC"
}
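
Below is a minimal standalone sketch, not part of this diff, of the partition-column filtering that the new FileTable.dataSchema applies to a user-specified schema before the duplicate-name and supportsDataType checks run. The object name, column names, and the case-insensitive resolver stand-in are illustrative assumptions; only the filterNot expression mirrors the change above.

import org.apache.spark.sql.types._

object DataSchemaFilterSketch {
  // Stand-in for sparkSession.sessionState.conf.resolver when
  // spark.sql.caseSensitive is false (the default): case-insensitive match.
  private val resolver: (String, String) => Boolean = _.equalsIgnoreCase(_)

  def main(args: Array[String]): Unit = {
    // Hypothetical user-specified schema that also lists the partition column.
    val userSpecified = StructType(Seq(
      StructField("id", LongType),
      StructField("value", StringType),
      StructField("DT", StringType)))
    val partitionSchema = StructType(Seq(StructField("dt", StringType)))

    // Same shape as the new FileTable.dataSchema: drop fields that belong to
    // the partition schema, matching names with the session resolver.
    val dataSchema = StructType(
      userSpecified.filterNot(f => partitionSchema.exists(p => resolver(p.name, f.name))))

    println(dataSchema.fieldNames.mkString(", "))  // prints: id, value
  }
}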