@@ -1205,6 +1205,20 @@ object SQLConf {
.timeConf(TimeUnit.MILLISECONDS)
.createWithDefault(100)

val DISABLED_V2_FILE_DATA_SOURCE_READERS = buildConf("spark.sql.disabledV2FileDataSourceReaders")
.internal()
.doc("A comma-separated list of file data source short names for which DataSourceReader" +
" is disabled. Reads from these sources will fall back to the V1 sources")
.stringConf
.createWithDefault("")

val DISABLED_V2_FILE_DATA_SOURCE_WRITERS = buildConf("spark.sql.disabledV2FileDataSourceWriters")
.internal()
.doc("A comma-separated list of file data source short names for which DataSourceWriter" +
" is disabled. Writes to these sources will fall back to the V1 FileFormat")
.stringConf
.createWithDefault("")

val DISABLED_V2_STREAMING_WRITERS = buildConf("spark.sql.streaming.disabledV2Writers")
.internal()
.doc("A comma-separated list of fully qualified data source register class names for which" +
@@ -1606,6 +1620,10 @@ class SQLConf extends Serializable with Logging {
def continuousStreamingExecutorPollIntervalMs: Long =
getConf(CONTINUOUS_STREAMING_EXECUTOR_POLL_INTERVAL_MS)

def disabledV2FileDataSourceReader: String = getConf(DISABLED_V2_FILE_DATA_SOURCE_READERS)

def disabledV2FileDataSourceWriter: String = getConf(DISABLED_V2_FILE_DATA_SOURCE_WRITERS)

def disabledV2StreamingWriters: String = getConf(DISABLED_V2_STREAMING_WRITERS)

def disabledV2StreamingMicroBatchReaders: String =
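For context on how these two flags are meant to be used: both are ordinary string SQL configs, so they can be toggled per session. A minimal sketch, assuming the ORC source registers the short name "orc" and that the config keys keep the names defined above:

```scala
// Sketch only: assumes a running SparkSession and that "orc" is the short name
// registered by the file-based V2 data source being disabled.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[*]")
  .appName("v2-fallback-demo")
  .getOrCreate()

// Disable the V2 reader for ORC: reads resolve through the V1 path instead.
spark.conf.set("spark.sql.disabledV2FileDataSourceReaders", "orc")

// Disable the V2 writer for ORC: writes fall back to the V1 FileFormat.
spark.conf.set("spark.sql.disabledV2FileDataSourceWriters", "orc")

// Both configs accept a comma-separated list of short names, e.g. "orc,parquet".
```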
@@ -58,10 +58,16 @@ public class OrcColumnarBatchReader extends RecordReader<Void, ColumnarBatch> {

/**
* The column IDs of the physical ORC file schema which are required by this reader.
* -1 means this required column doesn't exist in the ORC file.
-1 means this required column is a partition column, or it doesn't exist in the ORC file.
*/
private int[] requestedColIds;

/**
* The column IDs of the ORC file partition schema which are required by this reader.
* -1 means this required column doesn't exist in the ORC partition columns.
*/
private int[] requestedPartitionColIds;

// Record reader from ORC row batch.
private org.apache.orc.RecordReader recordReader;

@@ -143,75 +149,76 @@ public void initialize(
/**
* Initialize columnar batch by setting required schema and partition information.
* With this information, this creates ColumnarBatch with the full schema.
*
* @param orcSchema Schema from ORC file reader.
* @param requiredFields All the fields that are required to return, including partition fields.
* @param requestedColIds Requested column ids from orcSchema. -1 if the column does not exist.
* @param requestedPartitionColIds Requested column ids from partition schema. -1 if the column does not exist.
* @param partitionValues Values of partition columns.
*/
public void initBatch(
TypeDescription orcSchema,
int[] requestedColIds,
StructField[] requiredFields,
StructType partitionSchema,
int[] requestedColIds,
int[] requestedPartitionColIds,
Contributor: param doc for them.
InternalRow partitionValues) {
batch = orcSchema.createRowBatch(capacity);
assert(!batch.selectedInUse); // `selectedInUse` should be initialized with `false`.

assert(requiredFields.length == requestedColIds.length);
assert(requiredFields.length == requestedPartitionColIds.length);
// If a required column is also a partition column, use the partition value and don't read it from the file.
for (int i = 0; i < requiredFields.length; i++) {
if (requestedPartitionColIds[i] != -1) {
requestedColIds[i] = -1;
}
}
this.requestedPartitionColIds = requestedPartitionColIds;
this.requiredFields = requiredFields;
this.requestedColIds = requestedColIds;
assert(requiredFields.length == requestedColIds.length);

StructType resultSchema = new StructType(requiredFields);
for (StructField f : partitionSchema.fields()) {
resultSchema = resultSchema.add(f);
}

if (copyToSpark) {
if (MEMORY_MODE == MemoryMode.OFF_HEAP) {
columnVectors = OffHeapColumnVector.allocateColumns(capacity, resultSchema);
} else {
columnVectors = OnHeapColumnVector.allocateColumns(capacity, resultSchema);
}

// Initialize the missing columns once.
// Initialize the missing columns and partition columns once.
for (int i = 0; i < requiredFields.length; i++) {
if (requestedColIds[i] == -1) {
if (requestedPartitionColIds[i] != -1) {
ColumnVectorUtils.populate(columnVectors[i],
partitionValues, requestedPartitionColIds[i]);
columnVectors[i].setIsConstant();
} else if (requestedColIds[i] == -1) {
columnVectors[i].putNulls(0, capacity);
columnVectors[i].setIsConstant();
}
}

if (partitionValues.numFields() > 0) {
int partitionIdx = requiredFields.length;
for (int i = 0; i < partitionValues.numFields(); i++) {
ColumnVectorUtils.populate(columnVectors[i + partitionIdx], partitionValues, i);
columnVectors[i + partitionIdx].setIsConstant();
}
}

columnarBatch = new ColumnarBatch(columnVectors);
} else {
// Just wrap the ORC column vector instead of copying it to Spark column vector.
orcVectorWrappers = new org.apache.spark.sql.vectorized.ColumnVector[resultSchema.length()];

for (int i = 0; i < requiredFields.length; i++) {
DataType dt = requiredFields[i].dataType();
int colId = requestedColIds[i];
// Initialize the missing columns once.
if (colId == -1) {
OnHeapColumnVector missingCol = new OnHeapColumnVector(capacity, dt);
missingCol.putNulls(0, capacity);
missingCol.setIsConstant();
orcVectorWrappers[i] = missingCol;
} else {
orcVectorWrappers[i] = new OrcColumnVector(dt, batch.cols[colId]);
}
}

if (partitionValues.numFields() > 0) {
int partitionIdx = requiredFields.length;
for (int i = 0; i < partitionValues.numFields(); i++) {
DataType dt = partitionSchema.fields()[i].dataType();
if (requestedPartitionColIds[i] != -1) {
OnHeapColumnVector partitionCol = new OnHeapColumnVector(capacity, dt);
ColumnVectorUtils.populate(partitionCol, partitionValues, i);
ColumnVectorUtils.populate(partitionCol, partitionValues, requestedPartitionColIds[i]);
partitionCol.setIsConstant();
orcVectorWrappers[partitionIdx + i] = partitionCol;
orcVectorWrappers[i] = partitionCol;
} else {
int colId = requestedColIds[i];
// Initialize the missing columns once.
if (colId == -1) {
OnHeapColumnVector missingCol = new OnHeapColumnVector(capacity, dt);
missingCol.putNulls(0, capacity);
missingCol.setIsConstant();
orcVectorWrappers[i] = missingCol;
} else {
orcVectorWrappers[i] = new OrcColumnVector(dt, batch.cols[colId]);
}
}
}

@@ -233,6 +240,7 @@ private boolean nextBatch() throws IOException {

if (!copyToSpark) {
for (int i = 0; i < requiredFields.length; i++) {
// It is possible that a required column is missing from the file or is a partition column;
// those wrappers are constant vectors, so only the ones backed by ORC column vectors need the batch size set.
if (requestedColIds[i] != -1) {
((OrcColumnVector) orcVectorWrappers[i]).setBatchSize(batchSize);
}
@@ -248,7 +256,7 @@ private boolean nextBatch() throws IOException {
StructField field = requiredFields[i];
WritableColumnVector toColumn = columnVectors[i];

if (requestedColIds[i] >= 0) {
if (requestedColIds[i] != -1) {
ColumnVector fromColumn = batch.cols[requestedColIds[i]];

if (fromColumn.isRepeating) {
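The heart of the initBatch change above is how the two id arrays interact: whenever a required field also appears in the partition schema, its file column id is forced to -1, so the value comes from partitionValues as a constant vector rather than from the ORC batch, while columns missing from both stay -1 and are filled with nulls. A small Scala sketch of that merging rule, using the same -1 convention (illustrative only, not the reader's actual API):

```scala
// Illustrative re-statement of the id-merging rule used by initBatch above.
// -1 means "not present"; a partition column id always wins over a file column id.
def mergeColumnIds(
    requestedColIds: Array[Int],
    requestedPartitionColIds: Array[Int]): (Array[Int], Array[Int]) = {
  require(requestedColIds.length == requestedPartitionColIds.length)
  val fileIds = requestedColIds.clone()
  for (i <- fileIds.indices) {
    // A required column that is also a partition column is never read from the file.
    if (requestedPartitionColIds(i) != -1) {
      fileIds(i) = -1
    }
  }
  (fileIds, requestedPartitionColIds)
}

// Example: field 0 comes from the file, field 1 is a partition column,
// field 2 exists in neither and will be filled with nulls as a constant vector.
val (fileIds, partitionIds) = mergeColumnIds(Array(3, 5, -1), Array(-1, 0, -1))
// fileIds = Array(3, -1, -1), partitionIds = Array(-1, 0, -1)
```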
30 changes: 22 additions & 8 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -29,13 +29,13 @@ import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.{DataSource, FailureSafeParser}
import org.apache.spark.sql.execution.datasources.csv._
import org.apache.spark.sql.execution.datasources.jdbc._
import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2Utils, FileDataSourceV2}
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, ReadSupportWithSchema}
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String
@@ -190,35 +190,49 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
"read files of Hive data source directly.")
}

val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf)
val allPaths = (CaseInsensitiveMap(extraOptions.toMap).get("path") ++ paths).toSeq
val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf, allPaths)
if (classOf[DataSourceV2].isAssignableFrom(cls)) {
val ds = cls.newInstance().asInstanceOf[DataSourceV2]
if (ds.isInstanceOf[ReadSupport] || ds.isInstanceOf[ReadSupportWithSchema]) {

val (needToFallBackFileDataSourceV2, fallBackFileFormat) = ds match {
case f: FileDataSourceV2 =>
val disabledV2Readers =
sparkSession.sessionState.conf.disabledV2FileDataSourceReader.split(",")
(disabledV2Readers.contains(f.shortName), f.fallBackFileFormat.getCanonicalName)
case _ => (false, source)
}
val supportsRead = ds.isInstanceOf[ReadSupport] || ds.isInstanceOf[ReadSupportWithSchema]
if (supportsRead && !needToFallBackFileDataSourceV2) {
val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
ds = ds, conf = sparkSession.sessionState.conf)
val pathsOption = {
val objectMapper = new ObjectMapper()
DataSourceOptions.PATHS_KEY -> objectMapper.writeValueAsString(paths.toArray)
}

Dataset.ofRows(sparkSession, DataSourceV2Relation.create(
ds, extraOptions.toMap ++ sessionOptions + pathsOption,
userSpecifiedSchema = userSpecifiedSchema))
} else {
loadV1Source(paths: _*)
// In the following cases, we fall back to loading with V1:
// 1. The data source implements v2, but has no v2 implementation for the read path.
// 2. The v2 reader of the data source is disabled by configuration.
loadV1Source(fallBackFileFormat, paths: _*)
}
} else {
loadV1Source(paths: _*)
loadV1Source(source, paths: _*)
}
}

private def loadV1Source(paths: String*) = {
private def loadV1Source(className: String, paths: String*) = {
// Code path for data source v1.
sparkSession.baseRelationToDataFrame(
DataSource.apply(
sparkSession,
paths = paths,
userSpecifiedSchema = userSpecifiedSchema,
className = source,
className = className,
options = extraOptions.toMap).resolveRelation())
}

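The load() change gates the V2 read path on two conditions: the source must expose ReadSupport (or ReadSupportWithSchema), and, if it is a FileDataSourceV2, its short name must not appear in spark.sql.disabledV2FileDataSourceReaders; otherwise loadV1Source is called with the fall-back FileFormat class name. A hedged Scala sketch of that decision, where FileSourceV2Like is a hypothetical stand-in for the FileDataSourceV2 trait (its exact definition is not part of this excerpt, only the shortName and fallBackFileFormat members used in the diff):

```scala
// Sketch of the fall-back decision implied by the DataFrameReader change.
// FileSourceV2Like is a hypothetical stand-in for FileDataSourceV2; its members
// mirror how f.shortName and f.fallBackFileFormat are used in the diff.
trait FileSourceV2Like {
  def shortName: String
  def fallBackFileFormat: Class[_] // the V1 FileFormat class to fall back to
}

def readFallback(
    source: String,
    ds: AnyRef,
    disabledV2FileDataSourceReaders: String): (Boolean, String) = ds match {
  case f: FileSourceV2Like =>
    val disabled = disabledV2FileDataSourceReaders.split(",").map(_.trim)
    // Fall back when this source's short name is explicitly disabled.
    (disabled.contains(f.shortName), f.fallBackFileFormat.getCanonicalName)
  case _ =>
    // Not a file-based V2 source: never forced to fall back; keep the original name.
    (false, source)
}
```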
57 changes: 32 additions & 25 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -30,8 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{AnalysisBarrier, InsertIntoT
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils
import org.apache.spark.sql.execution.datasources.v2.WriteToDataSourceV2
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Utils, FileDataSourceV2, WriteToDataSourceV2}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.sources.v2._
import org.apache.spark.sql.types.StructType
@@ -241,39 +240,47 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
val cls = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf)
if (classOf[DataSourceV2].isAssignableFrom(cls)) {
val ds = cls.newInstance()
ds match {
case ws: WriteSupport =>
val options = new DataSourceOptions((extraOptions ++
DataSourceV2Utils.extractSessionConfigs(
ds = ds.asInstanceOf[DataSourceV2],
conf = df.sparkSession.sessionState.conf)).asJava)
// Using a timestamp and a random UUID to distinguish different writing jobs. This is good
// enough as there won't be tons of writing jobs created at the same second.
val jobId = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US)
.format(new Date()) + "-" + UUID.randomUUID()
val writer = ws.createWriter(jobId, df.logicalPlan.schema, mode, options)
if (writer.isPresent) {
runCommand(df.sparkSession, "save") {
WriteToDataSourceV2(writer.get(), df.logicalPlan)
}
}
val (needToFallBackFileDataSourceV2, fallBackFileFormat) = ds match {
case f: FileDataSourceV2 =>
val disabledV2Writers =
df.sparkSession.sessionState.conf.disabledV2FileDataSourceWriter.split(",")
(disabledV2Writers.contains(f.shortName), f.fallBackFileFormat.getCanonicalName)
case _ => (false, source)
}

// Streaming also uses the data source V2 API. So it may be that the data source implements
// v2, but has no v2 implementation for batch writes. In that case, we fall back to saving
// as though it's a V1 source.
case _ => saveToV1Source()
if (ds.isInstanceOf[WriteSupport] && !needToFallBackFileDataSourceV2) {
val options = new DataSourceOptions((extraOptions ++
DataSourceV2Utils.extractSessionConfigs(
ds = ds.asInstanceOf[DataSourceV2],
conf = df.sparkSession.sessionState.conf)).asJava)
// Using a timestamp and a random UUID to distinguish different writing jobs. This is good
// enough as there won't be tons of writing jobs created at the same second.
val jobId = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US)
.format(new Date()) + "-" + UUID.randomUUID()
val writer = ds.asInstanceOf[WriteSupport]
.createWriter(jobId, df.logicalPlan.schema, mode, options)
Contributor: I am not sure I understand this: why do we use .createWriter here, but we do not use .createReader in DataFrameReader? It seems "unsymmetrical" to me.

Contributor: It is. We're still evolving the v2 API and its integration with Spark. This problem is addressed in PR #21305, which is the first of a series of changes to standardize the logical plans and fix problems like this one. There's also an open proposal for those changes.

if (writer.isPresent) {
runCommand(df.sparkSession, "save") {
WriteToDataSourceV2(writer.get(), df.logicalPlan)
}
}
} else {
// In the following cases, we fall back to saving with V1:
// 1. The data source implements v2, but has no v2 implementation for the write path.
// 2. The v2 writer of the data source is disabled by configuration.
saveToV1Source(fallBackFileFormat)
}
} else {
saveToV1Source()
saveToV1Source(source)
}
}

private def saveToV1Source(): Unit = {
private def saveToV1Source(className: String): Unit = {
// Code path for data source v1.
runCommand(df.sparkSession, "save") {
DataSource(
sparkSession = df.sparkSession,
className = source,
className = className,
partitionColumns = partitioningColumns.getOrElse(Nil),
options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan))
}
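On the write side, the effect of the new flag is that save() skips createWriter entirely and takes the saveToV1Source(fallBackFileFormat) branch. A short usage sketch, again assuming "orc" is the registered short name and using a throwaway local path:

```scala
// Sketch: force a V2 file source to write through its V1 FileFormat fall-back.
// Assumes a local SparkSession named `spark`; "orc" as the short name is an assumption.
import spark.implicits._

spark.conf.set("spark.sql.disabledV2FileDataSourceWriters", "orc")

val df = Seq((1, "a"), (2, "b")).toDF("id", "value")

// Because the short name is in the disabled list, this resolves to the V1 FileFormat
// (the fallBackFileFormat) instead of creating a DataSourceWriter.
df.write.mode("overwrite").format("orc").save("/tmp/orc-v1-fallback-demo")
```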