diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala
index f26c134c2f6e..0a0ef76672a7 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala
@@ -28,6 +28,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
+import org.apache.spark.sql.sources.v2.DataFormat
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset}
import org.apache.spark.sql.types.StructType
@@ -53,7 +54,7 @@ class KafkaContinuousReader(
metadataPath: String,
initialOffsets: KafkaOffsetRangeLimit,
failOnDataLoss: Boolean)
- extends ContinuousReader with SupportsScanUnsafeRow with Logging {
+ extends ContinuousReader with Logging {
private lazy val session = SparkSession.getActiveSession.get
private lazy val sc = session.sparkContext
@@ -86,7 +87,7 @@ class KafkaContinuousReader(
KafkaSourceOffset(JsonUtils.partitionOffsets(json))
}
- override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = {
+ override def createDataReaderFactories(): ju.List[DataReaderFactory] = {
import scala.collection.JavaConverters._
val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset)
@@ -108,7 +109,7 @@ class KafkaContinuousReader(
case (topicPartition, start) =>
KafkaContinuousDataReaderFactory(
topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss)
- .asInstanceOf[DataReaderFactory[UnsafeRow]]
+ .asInstanceOf[DataReaderFactory]
}.asJava
}
@@ -161,17 +162,11 @@ case class KafkaContinuousDataReaderFactory(
startOffset: Long,
kafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
- failOnDataLoss: Boolean) extends ContinuousDataReaderFactory[UnsafeRow] {
+ failOnDataLoss: Boolean) extends ContinuousDataReaderFactory {
- override def createDataReaderWithOffset(offset: PartitionOffset): DataReader[UnsafeRow] = {
- val kafkaOffset = offset.asInstanceOf[KafkaSourcePartitionOffset]
- require(kafkaOffset.topicPartition == topicPartition,
- s"Expected topicPartition: $topicPartition, but got: ${kafkaOffset.topicPartition}")
- new KafkaContinuousDataReader(
- topicPartition, kafkaOffset.partitionOffset, kafkaParams, pollTimeoutMs, failOnDataLoss)
- }
+ override def dataFormat(): DataFormat = DataFormat.UNSAFE_ROW
- override def createDataReader(): KafkaContinuousDataReader = {
+ override def createUnsafeRowDataReader(): KafkaContinuousDataReader = {
new KafkaContinuousDataReader(
topicPartition, startOffset, kafkaParams, pollTimeoutMs, failOnDataLoss)
}
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala
index 2ed49ba3f549..c0b10a7a2994 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala
@@ -32,8 +32,8 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset}
import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
-import org.apache.spark.sql.sources.v2.DataSourceOptions
-import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow}
+import org.apache.spark.sql.sources.v2.{DataFormat, DataSourceOptions}
+import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory}
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.UninterruptibleThread
@@ -61,7 +61,7 @@ private[kafka010] class KafkaMicroBatchReader(
metadataPath: String,
startingOffsets: KafkaOffsetRangeLimit,
failOnDataLoss: Boolean)
- extends MicroBatchReader with SupportsScanUnsafeRow with Logging {
+ extends MicroBatchReader with Logging {
private var startPartitionOffsets: PartitionOffsetMap = _
private var endPartitionOffsets: PartitionOffsetMap = _
@@ -101,7 +101,7 @@ private[kafka010] class KafkaMicroBatchReader(
}
}
- override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = {
+ override def createDataReaderFactories(): ju.List[DataReaderFactory] = {
// Find the new partitions, and get their earliest offsets
val newPartitions = endPartitionOffsets.keySet.diff(startPartitionOffsets.keySet)
val newPartitionInitialOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq)
@@ -146,7 +146,7 @@ private[kafka010] class KafkaMicroBatchReader(
new KafkaMicroBatchDataReaderFactory(
range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer)
}
- factories.map(_.asInstanceOf[DataReaderFactory[UnsafeRow]]).asJava
+ factories.map(_.asInstanceOf[DataReaderFactory]).asJava
}
override def getStartOffset: Offset = {
@@ -305,11 +305,13 @@ private[kafka010] case class KafkaMicroBatchDataReaderFactory(
executorKafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
failOnDataLoss: Boolean,
- reuseKafkaConsumer: Boolean) extends DataReaderFactory[UnsafeRow] {
+ reuseKafkaConsumer: Boolean) extends DataReaderFactory {
override def preferredLocations(): Array[String] = offsetRange.preferredLoc.toArray
- override def createDataReader(): DataReader[UnsafeRow] = new KafkaMicroBatchDataReader(
+ override def dataFormat(): DataFormat = DataFormat.UNSAFE_ROW
+
+ override def createUnsafeRowDataReader(): DataReader[UnsafeRow] = new KafkaMicroBatchDataReader(
offsetRange, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer)
}
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
index e017fd9b84d2..730959d0873f 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
@@ -36,7 +36,6 @@ import org.scalatest.time.SpanSugar._
import org.apache.spark.SparkContext
import org.apache.spark.sql.{Dataset, ForeachWriter, SparkSession}
-import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Update
import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
@@ -673,7 +672,7 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase {
Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 0L))),
Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 100L)))
)
- val factories = reader.createUnsafeRowReaderFactories().asScala
+ val factories = reader.createDataReaderFactories().asScala
.map(_.asInstanceOf[KafkaMicroBatchDataReaderFactory])
withClue(s"minPartitions = $minPartitions generated factories $factories\n\t") {
assert(factories.size == numPartitionsGenerated)
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataFormat.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataFormat.java
new file mode 100644
index 000000000000..b1bcac6a7a3b
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataFormat.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2;
+
+import org.apache.spark.sql.sources.v2.reader.DataReaderFactory;
+
+/**
+ * An enum returned by {@link DataReaderFactory#dataFormat()}, representing the output data format of
+ * a data source scan.
+ *
+ * <ul>
+ *   <li>{@link #ROW}</li>
+ *   <li>{@link #UNSAFE_ROW}</li>
+ *   <li>{@link #COLUMNAR_BATCH}</li>
+ * </ul>
+ *
+ * TODO: add INTERNAL_ROW
+ */
+public enum DataFormat {
+ /**
+ * Refers to {@link org.apache.spark.sql.Row}, which is very stable and guaranteed to be backward
+ * compatible. Spark needs to convert data in this format to the internal format, so data source
+ * developers should consider using other formats for better performance.
+ */
+ ROW,
+
+ /**
+ * Refers to {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow}, which is an unstable and
+ * internal API. It's already the internal format in Spark, so there is no extra conversion needed.
+ */
+ UNSAFE_ROW,
+
+ /**
+ * Refers to {@link org.apache.spark.sql.vectorized.ColumnarBatch}, which is a public but experimental
+ * API. It's already the internal format in Spark, so there is no extra conversion needed. This format
+ * is recommended over the others, as the columnar format enables additional optimizations such as
+ * vectorization to further speed up data processing.
+ */
+ COLUMNAR_BATCH
+}
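As a quick illustration of how a source is expected to pair this enum with the factory API introduced further down in this patch, here is a minimal sketch; the class name `SimpleRowReaderFactory` and its in-memory data are hypothetical:

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.DataFormat
import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory}

// Hypothetical factory that declares the ROW format; Spark will then call
// createRowDataReader() and never the unsafe-row or columnar-batch variants.
class SimpleRowReaderFactory(values: Seq[Int]) extends DataReaderFactory {

  override def dataFormat(): DataFormat = DataFormat.ROW

  override def createRowDataReader(): DataReader[Row] = new DataReader[Row] {
    private val data = values.iterator
    override def next(): Boolean = data.hasNext
    override def get(): Row = Row(data.next())
    override def close(): Unit = {}
  }
}
```

Because `dataFormat()` returns `ROW`, only `createRowDataReader()` is overridden; the other create methods keep their throwing defaults.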
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java
index a61697649c43..0d7d053b61d2 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java
@@ -18,18 +18,48 @@
package org.apache.spark.sql.sources.v2.reader;
import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousDataReader;
+import org.apache.spark.sql.vectorized.ColumnarBatch;
/**
* A mix-in interface for {@link DataReaderFactory}. Continuous data reader factories can
* implement this interface to provide creating {@link DataReader} with particular offset.
*/
@InterfaceStability.Evolving
-public interface ContinuousDataReaderFactory<T> extends DataReaderFactory<T> {
+public interface ContinuousDataReaderFactory extends DataReaderFactory {
+
+ /**
+ * Returns a row-formatted continuous data reader to do the actual reading work.
+ *
+ * If this method fails (by throwing an exception), the corresponding Spark task would fail and
+ * get retried until hitting the maximum retry times.
+ */
+ default ContinuousDataReader<Row> createRowDataReader() {
+ throw new IllegalStateException(
+ "createRowDataReader must be implemented if the data format is ROW.");
+ }
+
+ /**
+ * Returns an unsafe-row-formatted continuous data reader to do the actual reading work.
+ *
+ * If this method fails (by throwing an exception), the corresponding Spark task would fail and
+ * get retried until hitting the maximum retry times.
+ */
+ default ContinuousDataReader<UnsafeRow> createUnsafeRowDataReader() {
+ throw new IllegalStateException(
+ "createUnsafeRowDataReader must be implemented if the data format is UNSAFE_ROW.");
+ }
+
/**
- * Create a DataReader with particular offset as its startOffset.
+ * Returns a columnar-batch-formatted continuous data reader to do the actual reading work.
*
- * @param offset offset want to set as the DataReader's startOffset.
+ * If this method fails (by throwing an exception), the corresponding Spark task would fail and
+ * get retried until hitting the maximum retry times.
*/
- DataReader<T> createDataReaderWithOffset(PartitionOffset offset);
+ default ContinuousDataReader<ColumnarBatch> createColumnarBatchDataReader() {
+ throw new IllegalStateException(
+ "createColumnarBatchDataReader must be implemented if the data format is COLUMNAR_BATCH.");
+ }
}
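A continuous source follows the same pattern, additionally reporting its position through `getOffset`. A minimal sketch with hypothetical names (`CounterOffset`, `CounterContinuousReaderFactory`):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.DataFormat
import org.apache.spark.sql.sources.v2.reader.ContinuousDataReaderFactory
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, PartitionOffset}

// Hypothetical offset type for a single never-ending counter partition.
case class CounterOffset(value: Long) extends PartitionOffset

class CounterContinuousReaderFactory(start: Long) extends ContinuousDataReaderFactory {

  override def dataFormat(): DataFormat = DataFormat.ROW

  override def createRowDataReader(): ContinuousDataReader[Row] = new ContinuousDataReader[Row] {
    private var current = start - 1
    override def next(): Boolean = { current += 1; true }  // unbounded stream
    override def get(): Row = Row(current)
    override def getOffset: PartitionOffset = CounterOffset(current)
    override def close(): Unit = {}
  }
}
```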
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java
index bb9790a1c819..04a1303c0bf5 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java
@@ -23,12 +23,13 @@
import org.apache.spark.annotation.InterfaceStability;
/**
- * A data reader returned by {@link DataReaderFactory#createDataReader()} and is responsible for
- * outputting data for a RDD partition.
+ * A data reader returned by one of the create data reader methods in {@link DataReaderFactory}. It is
+ * responsible for outputting data for an RDD partition.
*
- * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data
- * source readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for data source
- * readers that mix in {@link SupportsScanUnsafeRow}.
+ * Note that currently the type `T` can only be {@link org.apache.spark.sql.Row},
+ * {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow}, or
+ * {@link org.apache.spark.sql.vectorized.ColumnarBatch}, depending on the return value of
+ * {@link DataReaderFactory#dataFormat()}.
*/
@InterfaceStability.Evolving
public interface DataReader<T> extends Closeable {
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java
index 32e98e8f5d8b..86e0579318a8 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java
@@ -20,6 +20,10 @@
import java.io.Serializable;
import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.sources.v2.DataFormat;
+import org.apache.spark.sql.vectorized.ColumnarBatch;
/**
* A reader factory returned by {@link DataSourceReader#createDataReaderFactories()} and is
@@ -32,7 +36,7 @@
* serializable and {@link DataReader} doesn't need to be.
*/
@InterfaceStability.Evolving
-public interface DataReaderFactory<T> extends Serializable {
+public interface DataReaderFactory extends Serializable {
/**
* The preferred locations where the data reader returned by this reader factory can run faster,
@@ -52,10 +56,46 @@ default String[] preferredLocations() {
}
/**
- * Returns a data reader to do the actual reading work.
+ * The output data format of this factory's data reader. Spark will invoke the corresponding
+ * create data reader method w.r.t. the return value of this method:
+ * <ul>
+ *   <li>{@link DataFormat#ROW}: {@link #createRowDataReader()}</li>
+ *   <li>{@link DataFormat#UNSAFE_ROW}: {@link #createUnsafeRowDataReader()}</li>
+ *   <li>{@link DataFormat#COLUMNAR_BATCH}: {@link #createColumnarBatchDataReader()}</li>
+ * </ul>
+ */
+ DataFormat dataFormat();
+
+ /**
+ * Returns a row-formatted data reader to do the actual reading work.
+ *
+ * If this method fails (by throwing an exception), the corresponding Spark task would fail and
+ * get retried until hitting the maximum retry times.
+ */
+ default DataReader<Row> createRowDataReader() {
+ throw new IllegalStateException(
+ "createRowDataReader must be implemented if the data format is ROW.");
+ }
+
+ /**
+ * Returns an unsafe-row-formatted data reader to do the actual reading work.
*
* If this method fails (by throwing an exception), the corresponding Spark task would fail and
* get retried until hitting the maximum retry times.
*/
- DataReader<T> createDataReader();
+ default DataReader<UnsafeRow> createUnsafeRowDataReader() {
+ throw new IllegalStateException(
+ "createUnsafeRowDataReader must be implemented if the data format is UNSAFE_ROW.");
+ }
+
+ /**
+ * Returns a columnar-batch-formatted data reader to do the actual reading work.
+ *
+ * If this method fails (by throwing an exception), the corresponding Spark task would fail and
+ * get retried until hitting the maximum retry times.
+ */
+ default DataReader<ColumnarBatch> createColumnarBatchDataReader() {
+ throw new IllegalStateException(
+ "createColumnarBatchDataReader must be implemented if the data format is COLUMNAR_BATCH.");
+ }
}
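To make the dispatch contract documented above concrete, here is a sketch of how a caller branches on `dataFormat()` and invokes only the matching create method (the engine-side equivalent lives in `DataSourceRDD` later in this patch); the helper name `openReader` is hypothetical:

```scala
import org.apache.spark.sql.sources.v2.DataFormat
import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory}

// Illustrative dispatch over the three formats; only the create method that
// matches dataFormat() is ever invoked, so the others may keep their defaults.
def openReader(factory: DataReaderFactory): DataReader[_] = factory.dataFormat() match {
  case DataFormat.ROW            => factory.createRowDataReader()
  case DataFormat.UNSAFE_ROW     => factory.createUnsafeRowDataReader()
  case DataFormat.COLUMNAR_BATCH => factory.createColumnarBatchDataReader()
}
```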
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java
index a470bccc5aad..2049a96fe903 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java
@@ -20,7 +20,6 @@
import java.util.List;
import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.Row;
import org.apache.spark.sql.sources.v2.DataSourceOptions;
import org.apache.spark.sql.sources.v2.ReadSupport;
import org.apache.spark.sql.sources.v2.ReadSupportWithSchema;
@@ -34,23 +33,18 @@
* logic is delegated to {@link DataReaderFactory}s that are returned by
* {@link #createDataReaderFactories()}.
*
- * There are mainly 3 kinds of query optimizations:
+ * There are mainly 2 kinds of query optimizations:
* 1. Operators push-down. E.g., filter push-down, required columns push-down(aka column
* pruning), etc. Names of these interfaces start with `SupportsPushDown`.
* 2. Information Reporting. E.g., statistics reporting, ordering reporting, etc.
* Names of these interfaces start with `SupportsReporting`.
- * 3. Special scans. E.g, columnar scan, unsafe row scan, etc.
- * Names of these interfaces start with `SupportsScan`. Note that a reader should only
- * implement at most one of the special scans, if more than one special scans are implemented,
- * only one of them would be respected, according to the priority list from high to low:
- * {@link SupportsScanColumnarBatch}, {@link SupportsScanUnsafeRow}.
*
* If an exception was throw when applying any of these query optimizations, the action would fail
* and no Spark job was submitted.
*
* Spark first applies all operator push-down optimizations that this data source supports. Then
* Spark collects information this data source reported for further optimizations. Finally Spark
- * issues the scan request and does the actual data reading.
+ * issues the scan request, creates the {@link DataReaderFactory}s and does the actual data reading.
*/
@InterfaceStability.Evolving
public interface DataSourceReader {
@@ -76,5 +70,5 @@ public interface DataSourceReader {
* If this method fails (by throwing an exception), the action would fail and no Spark job was
* submitted.
*/
- List<DataReaderFactory<Row>> createDataReaderFactories();
+ List<DataReaderFactory> createDataReaderFactories();
}
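For completeness, a hypothetical `DataSourceReader` wired to the new non-generic factory list, reusing the `SimpleRowReaderFactory` sketched earlier; names and data are illustrative only:

```scala
import java.util.{List => JList}

import scala.collection.JavaConverters._

import org.apache.spark.sql.sources.v2.reader.{DataReaderFactory, DataSourceReader}
import org.apache.spark.sql.types.StructType

// Hypothetical reader: one factory per partition of a small in-memory dataset.
class SimpleReader extends DataSourceReader {

  override def readSchema(): StructType = new StructType().add("i", "int")

  override def createDataReaderFactories(): JList[DataReaderFactory] = Seq[DataReaderFactory](
    new SimpleRowReaderFactory(Seq(0, 1, 2)),
    new SimpleRowReaderFactory(Seq(3, 4, 5))).asJava
}
```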
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java
deleted file mode 100644
index 2e5cfa78511f..000000000000
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.sources.v2.reader;
-
-import java.util.List;
-
-import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.vectorized.ColumnarBatch;
-
-/**
- * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this
- * interface to output {@link ColumnarBatch} and make the scan faster.
- */
-@InterfaceStability.Evolving
-public interface SupportsScanColumnarBatch extends DataSourceReader {
- @Override
- default List<DataReaderFactory<Row>> createDataReaderFactories() {
- throw new IllegalStateException(
- "createDataReaderFactories not supported by default within SupportsScanColumnarBatch.");
- }
-
- /**
- * Similar to {@link DataSourceReader#createDataReaderFactories()}, but returns columnar data
- * in batches.
- */
- List<DataReaderFactory<ColumnarBatch>> createBatchDataReaderFactories();
-
- /**
- * Returns true if the concrete data source reader can read data in batch according to the scan
- * properties like required columns, pushes filters, etc. It's possible that the implementation
- * can only support some certain columns with certain types. Users can overwrite this method and
- * {@link #createDataReaderFactories()} to fallback to normal read path under some conditions.
- */
- default boolean enableBatchRead() {
- return true;
- }
-}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java
deleted file mode 100644
index 9cd749e8e4ce..000000000000
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java
+++ /dev/null
@@ -1,46 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.sources.v2.reader;
-
-import java.util.List;
-
-import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
-
-/**
- * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this
- * interface to output {@link UnsafeRow} directly and avoid the row copy at Spark side.
- * This is an experimental and unstable interface, as {@link UnsafeRow} is not public and may get
- * changed in the future Spark versions.
- */
-@InterfaceStability.Unstable
-public interface SupportsScanUnsafeRow extends DataSourceReader {
-
- @Override
- default List<DataReaderFactory<Row>> createDataReaderFactories() {
- throw new IllegalStateException(
- "createDataReaderFactories not supported by default within SupportsScanUnsafeRow");
- }
-
- /**
- * Similar to {@link DataSourceReader#createDataReaderFactories()},
- * but returns data in unsafe row format.
- */
- List<DataReaderFactory<UnsafeRow>> createUnsafeRowReaderFactories();
-}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java
index 7fe7f00ac2fa..d908a5034ead 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java
@@ -17,12 +17,12 @@
package org.apache.spark.sql.sources.v2.reader.streaming;
+import java.util.Optional;
+
import org.apache.spark.annotation.InterfaceStability;
import org.apache.spark.sql.execution.streaming.BaseStreamingSource;
import org.apache.spark.sql.sources.v2.reader.DataSourceReader;
-import java.util.Optional;
-
/**
* A mix-in interface for {@link DataSourceReader}. Data source readers can implement this
* interface to allow reading in a continuous processing mode stream.
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java
index 67ebde30d61a..0c8a40594408 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java
@@ -17,12 +17,12 @@
package org.apache.spark.sql.sources.v2.reader.streaming;
+import java.util.Optional;
+
import org.apache.spark.annotation.InterfaceStability;
import org.apache.spark.sql.sources.v2.reader.DataSourceReader;
import org.apache.spark.sql.execution.streaming.BaseStreamingSource;
-import java.util.Optional;
-
/**
* A mix-in interface for {@link DataSourceReader}. Data source readers can implement this
* interface to indicate they allow micro-batch streaming reads.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
index f85971be394b..a40eea1a9389 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
@@ -17,20 +17,24 @@
package org.apache.spark.sql.execution.datasources.v2
-import scala.collection.JavaConverters._
-import scala.reflect.ClassTag
-
import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.sources.v2.reader.DataReaderFactory
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.sources.v2.DataFormat
+import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory}
+import org.apache.spark.sql.types.StructType
-class DataSourceRDDPartition[T : ClassTag](val index: Int, val readerFactory: DataReaderFactory[T])
+class DataSourceRDDPartition(val index: Int, val factory: DataReaderFactory)
extends Partition with Serializable
-class DataSourceRDD[T: ClassTag](
+class DataSourceRDD(
sc: SparkContext,
- @transient private val readerFactories: Seq[DataReaderFactory[T]])
- extends RDD[T](sc, Nil) {
+ @transient private val readerFactories: Seq[DataReaderFactory],
+ schema: StructType)
+ extends RDD[InternalRow](sc, Nil) {
override protected def getPartitions: Array[Partition] = {
readerFactories.zipWithIndex.map {
@@ -38,31 +42,60 @@ class DataSourceRDD[T: ClassTag](
}.toArray
}
- override def compute(split: Partition, context: TaskContext): Iterator[T] = {
- val reader = split.asInstanceOf[DataSourceRDDPartition[T]].readerFactory.createDataReader()
- context.addTaskCompletionListener(_ => reader.close())
- val iter = new Iterator[T] {
- private[this] var valuePrepared = false
-
- override def hasNext: Boolean = {
- if (!valuePrepared) {
- valuePrepared = reader.next()
- }
- valuePrepared
- }
-
- override def next(): T = {
- if (!hasNext) {
- throw new java.util.NoSuchElementException("End of stream")
- }
- valuePrepared = false
- reader.get()
- }
+ override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
+ val factory = split.asInstanceOf[DataSourceRDDPartition].factory
+ val iter: DataReaderIterator[UnsafeRow] = factory.dataFormat() match {
+ case DataFormat.ROW =>
+ val reader = new RowToUnsafeDataReader(
+ factory.createRowDataReader(), RowEncoder.apply(schema).resolveAndBind())
+ new DataReaderIterator(reader)
+
+ case DataFormat.UNSAFE_ROW =>
+ new DataReaderIterator(factory.createUnsafeRowDataReader())
+
+ case DataFormat.COLUMNAR_BATCH =>
+ new DataReaderIterator(factory.createColumnarBatchDataReader())
+ // TODO: remove this type erase hack.
+ .asInstanceOf[DataReaderIterator[UnsafeRow]]
}
+ context.addTaskCompletionListener(_ => iter.close())
new InterruptibleIterator(context, iter)
}
override def getPreferredLocations(split: Partition): Seq[String] = {
- split.asInstanceOf[DataSourceRDDPartition[T]].readerFactory.preferredLocations()
+ split.asInstanceOf[DataSourceRDDPartition].factory.preferredLocations()
}
}
+
+class RowToUnsafeDataReader(rowReader: DataReader[Row], encoder: ExpressionEncoder[Row])
+ extends DataReader[UnsafeRow] {
+
+ override def next: Boolean = rowReader.next
+
+ override def get: UnsafeRow = encoder.toRow(rowReader.get).asInstanceOf[UnsafeRow]
+
+ override def close(): Unit = rowReader.close()
+}
+
+class DataReaderIterator[T](reader: DataReader[T]) extends Iterator[T] {
+ private[this] var valuePrepared = false
+
+ override def hasNext: Boolean = {
+ if (!valuePrepared) {
+ valuePrepared = reader.next()
+ // no more data, close the reader.
+ if (!valuePrepared) close()
+ }
+ valuePrepared
+ }
+
+ override def next(): T = {
+ if (!hasNext) {
+ throw new java.util.NoSuchElementException("End of stream")
+ }
+ valuePrepared = false
+ reader.get()
+ }
+
+ def close(): Unit = reader.close()
+}
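A brief usage sketch of the new `DataReaderIterator`, with a hypothetical in-memory `DataReader`, showing that the iterator closes the underlying reader once it is exhausted:

```scala
import org.apache.spark.sql.execution.datasources.v2.DataReaderIterator
import org.apache.spark.sql.sources.v2.reader.DataReader

// Hypothetical in-memory reader; DataReaderIterator wraps it as a Scala Iterator.
val intsReader: DataReader[Int] = new DataReader[Int] {
  private val data = Iterator(1, 2, 3)
  override def next(): Boolean = data.hasNext
  override def get(): Int = data.next()
  override def close(): Unit = println("reader closed")
}

val iter = new DataReaderIterator(intsReader)
iter.foreach(println)  // prints 1, 2, 3; the final hasNext() call closes the reader
```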
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
index 3a5e7bf89e14..33336cba2234 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
@@ -20,19 +20,15 @@ package org.apache.spark.sql.execution.datasources.v2
import scala.collection.JavaConverters._
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical
import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec}
import org.apache.spark.sql.execution.streaming.continuous._
-import org.apache.spark.sql.sources.v2.DataSourceV2
+import org.apache.spark.sql.sources.v2.{DataFormat, DataSourceV2}
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader
-import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.vectorized.ColumnarBatch
/**
* Physical plan node for scanning data from a data source.
@@ -57,36 +53,22 @@ case class DataSourceV2ScanExec(
Seq(output, source, options).hashCode()
}
- override def outputPartitioning: physical.Partitioning = reader match {
- case r: SupportsScanColumnarBatch if r.enableBatchRead() && batchReaderFactories.size == 1 =>
+ override def outputPartitioning: physical.Partitioning = {
+ if (readerFactories.length == 1) {
SinglePartition
+ } else {
+ reader match {
+ case s: SupportsReportPartitioning =>
+ new DataSourcePartitioning(
+ s.outputPartitioning(), AttributeMap(output.map(a => a -> a.name)))
- case r: SupportsScanColumnarBatch if !r.enableBatchRead() && readerFactories.size == 1 =>
- SinglePartition
-
- case r if !r.isInstanceOf[SupportsScanColumnarBatch] && readerFactories.size == 1 =>
- SinglePartition
-
- case s: SupportsReportPartitioning =>
- new DataSourcePartitioning(
- s.outputPartitioning(), AttributeMap(output.map(a => a -> a.name)))
-
- case _ => super.outputPartitioning
- }
-
- private lazy val readerFactories: Seq[DataReaderFactory[UnsafeRow]] = reader match {
- case r: SupportsScanUnsafeRow => r.createUnsafeRowReaderFactories().asScala
- case _ =>
- reader.createDataReaderFactories().asScala.map {
- new RowToUnsafeRowDataReaderFactory(_, reader.readSchema()): DataReaderFactory[UnsafeRow]
+ case _ => super.outputPartitioning
}
+ }
}
- private lazy val batchReaderFactories: Seq[DataReaderFactory[ColumnarBatch]] = reader match {
- case r: SupportsScanColumnarBatch if r.enableBatchRead() =>
- assert(!reader.isInstanceOf[ContinuousReader],
- "continuous stream reader does not support columnar read yet.")
- r.createBatchDataReaderFactories().asScala
+ private lazy val readerFactories: Seq[DataReaderFactory] = {
+ reader.createDataReaderFactories().asScala
}
private lazy val inputRDD: RDD[InternalRow] = reader match {
@@ -95,21 +77,29 @@ case class DataSourceV2ScanExec(
sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
sparkContext.env)
.askSync[Unit](SetReaderPartitions(readerFactories.size))
- new ContinuousDataSourceRDD(sparkContext, sqlContext, readerFactories)
- .asInstanceOf[RDD[InternalRow]]
-
- case r: SupportsScanColumnarBatch if r.enableBatchRead() =>
- new DataSourceRDD(sparkContext, batchReaderFactories).asInstanceOf[RDD[InternalRow]]
-
+ if (supportsBatch) {
+ throw new IllegalArgumentException(
+ "continuous stream reader does not support columnar read yet.")
+ }
+ new ContinuousDataSourceRDD(
+ sparkContext,
+ sqlContext,
+ readerFactories.map(_.asInstanceOf[ContinuousDataReaderFactory]),
+ reader.readSchema())
case _ =>
- new DataSourceRDD(sparkContext, readerFactories).asInstanceOf[RDD[InternalRow]]
+ new DataSourceRDD(sparkContext, readerFactories, reader.readSchema())
}
override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD)
- override val supportsBatch: Boolean = reader match {
- case r: SupportsScanColumnarBatch if r.enableBatchRead() => true
- case _ => false
+ override val supportsBatch: Boolean = {
+ val formats = readerFactories.map(_.dataFormat()).distinct
+ if (formats.length > 1) {
+ throw new IllegalArgumentException("Currently Spark requires all the data reader " +
+ "factories returned by the DataSourceReader output same data format.")
+ }
+
+ formats.nonEmpty && formats.head == DataFormat.COLUMNAR_BATCH
}
override protected def needsUnsafeRowConversion: Boolean = false
@@ -126,24 +116,3 @@ case class DataSourceV2ScanExec(
}
}
}
-
-class RowToUnsafeRowDataReaderFactory(rowReaderFactory: DataReaderFactory[Row], schema: StructType)
- extends DataReaderFactory[UnsafeRow] {
-
- override def preferredLocations: Array[String] = rowReaderFactory.preferredLocations
-
- override def createDataReader: DataReader[UnsafeRow] = {
- new RowToUnsafeDataReader(
- rowReaderFactory.createDataReader, RowEncoder.apply(schema).resolveAndBind())
- }
-}
-
-class RowToUnsafeDataReader(val rowReader: DataReader[Row], encoder: ExpressionEncoder[Row])
- extends DataReader[UnsafeRow] {
-
- override def next: Boolean = rowReader.next
-
- override def get: UnsafeRow = encoder.toRow(rowReader.get).asInstanceOf[UnsafeRow]
-
- override def close(): Unit = rowReader.close()
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
index 06754f01657d..cfd1f888c36e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
@@ -20,41 +20,58 @@ package org.apache.spark.sql.execution.streaming.continuous
import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue, TimeUnit}
import java.util.concurrent.atomic.AtomicBoolean
-import scala.collection.JavaConverters._
-
import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeDataReader}
+import org.apache.spark.sql.execution.datasources.v2.RowToUnsafeDataReader
+import org.apache.spark.sql.sources.v2.DataFormat
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, PartitionOffset}
+import org.apache.spark.sql.types.StructType
import org.apache.spark.util.ThreadUtils
+class ContinuousDataSourceRDDPartition(
+ val index: Int,
+ val factory: ContinuousDataReaderFactory) extends Partition with Serializable
+
class ContinuousDataSourceRDD(
sc: SparkContext,
sqlContext: SQLContext,
- @transient private val readerFactories: Seq[DataReaderFactory[UnsafeRow]])
- extends RDD[UnsafeRow](sc, Nil) {
+ @transient private val readerFactories: Seq[ContinuousDataReaderFactory],
+ schema: StructType)
+ extends RDD[InternalRow](sc, Nil) {
private val dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize
private val epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs
override protected def getPartitions: Array[Partition] = {
readerFactories.zipWithIndex.map {
- case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory)
+ case (readerFactory, index) => new ContinuousDataSourceRDDPartition(index, readerFactory)
}.toArray
}
- override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = {
+ override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
// If attempt number isn't 0, this is a task retry, which we don't support.
if (context.attemptNumber() != 0) {
throw new ContinuousTaskRetryException()
}
- val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]]
- .readerFactory.createDataReader()
+ val factory = split.asInstanceOf[ContinuousDataSourceRDDPartition].factory
+ val reader = factory.dataFormat() match {
+ case DataFormat.ROW =>
+ new RowToUnsafeContinuousDataReader(
+ factory.createRowDataReader().asInstanceOf[ContinuousDataReader[Row]],
+ RowEncoder.apply(schema).resolveAndBind())
+
+ case DataFormat.UNSAFE_ROW => factory.createUnsafeRowDataReader()
+
+ case other =>
+ throw new IllegalArgumentException(s"Illegal data format specified: $other")
+ }
val coordinatorId = context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)
@@ -71,7 +88,7 @@ class ContinuousDataSourceRDD(
epochPollRunnable, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS)
// Important sequencing - we must get start offset before the data reader thread begins
- val startOffset = ContinuousDataSourceRDD.getBaseReader(reader).getOffset
+ val startOffset = reader.getOffset
val dataReaderFailed = new AtomicBoolean(false)
val dataReaderThread = new DataReaderThread(reader, queue, context, dataReaderFailed)
@@ -133,7 +150,7 @@ class ContinuousDataSourceRDD(
}
override def getPreferredLocations(split: Partition): Seq[String] = {
- split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readerFactory.preferredLocations()
+ split.asInstanceOf[ContinuousDataSourceRDDPartition].factory.preferredLocations()
}
}
@@ -168,7 +185,7 @@ class EpochPollRunnable(
}
class DataReaderThread(
- reader: DataReader[UnsafeRow],
+ reader: ContinuousDataReader[UnsafeRow],
queue: BlockingQueue[(UnsafeRow, PartitionOffset)],
context: TaskContext,
failedFlag: AtomicBoolean)
@@ -179,7 +196,6 @@ class DataReaderThread(
override def run(): Unit = {
TaskContext.setTaskContext(context)
- val baseReader = ContinuousDataSourceRDD.getBaseReader(reader)
try {
while (!context.isInterrupted && !context.isCompleted()) {
if (!reader.next()) {
@@ -192,7 +208,7 @@ class DataReaderThread(
}
}
- queue.put((reader.get().copy(), baseReader.getOffset))
+ queue.put((reader.get().copy(), reader.getOffset))
}
} catch {
case _: InterruptedException if context.isInterrupted() =>
@@ -209,14 +225,10 @@ class DataReaderThread(
}
}
-object ContinuousDataSourceRDD {
- private[continuous] def getBaseReader(reader: DataReader[UnsafeRow]): ContinuousDataReader[_] = {
- reader match {
- case r: ContinuousDataReader[UnsafeRow] => r
- case wrapped: RowToUnsafeDataReader =>
- wrapped.rowReader.asInstanceOf[ContinuousDataReader[Row]]
- case _ =>
- throw new IllegalStateException(s"Unknown continuous reader type ${reader.getClass}")
- }
- }
+class RowToUnsafeContinuousDataReader(
+ rowReader: ContinuousDataReader[Row],
+ encoder: ExpressionEncoder[Row])
+ extends RowToUnsafeDataReader(rowReader, encoder) with ContinuousDataReader[UnsafeRow] {
+
+ override def getOffset: PartitionOffset = rowReader.getOffset
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala
index 2f0de2612c15..6b1a561b866b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair}
import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider
-import org.apache.spark.sql.sources.v2.DataSourceOptions
+import org.apache.spark.sql.sources.v2.{DataFormat, DataSourceOptions}
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset}
import org.apache.spark.sql.types.StructType
@@ -67,7 +67,7 @@ class RateStreamContinuousReader(options: DataSourceOptions)
override def getStartOffset(): Offset = offset
- override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = {
+ override def createDataReaderFactories(): java.util.List[DataReaderFactory] = {
val partitionStartMap = offset match {
case off: RateStreamOffset => off.partitionToValueAndRunTimeMs
case off =>
@@ -90,8 +90,7 @@ class RateStreamContinuousReader(options: DataSourceOptions)
start.runTimeMs,
i,
numPartitions,
- perPartitionRate)
- .asInstanceOf[DataReaderFactory[Row]]
+ perPartitionRate).asInstanceOf[DataReaderFactory]
}.asJava
}
@@ -119,21 +118,11 @@ case class RateStreamContinuousDataReaderFactory(
partitionIndex: Int,
increment: Long,
rowsPerSecond: Double)
- extends ContinuousDataReaderFactory[Row] {
+ extends ContinuousDataReaderFactory {
- override def createDataReaderWithOffset(offset: PartitionOffset): DataReader[Row] = {
- val rateStreamOffset = offset.asInstanceOf[RateStreamPartitionOffset]
- require(rateStreamOffset.partition == partitionIndex,
- s"Expected partitionIndex: $partitionIndex, but got: ${rateStreamOffset.partition}")
- new RateStreamContinuousDataReader(
- rateStreamOffset.currentValue,
- rateStreamOffset.currentTimeMs,
- partitionIndex,
- increment,
- rowsPerSecond)
- }
+ override def dataFormat(): DataFormat = DataFormat.ROW
- override def createDataReader(): DataReader[Row] =
+ override def createRowDataReader(): ContinuousDataReader[Row] =
new RateStreamContinuousDataReader(
startValue, startTimeMs, partitionIndex, increment, rowsPerSecond)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index 628923d367ce..a8e9cff37e29 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -24,19 +24,18 @@ import javax.annotation.concurrent.GuardedBy
import scala.collection.JavaConverters._
import scala.collection.mutable.{ArrayBuffer, ListBuffer}
-import scala.reflect.ClassTag
import scala.util.control.NonFatal
import org.apache.spark.internal.Logging
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
+import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
-import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
-import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow}
+import org.apache.spark.sql.sources.v2.DataFormat
+import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory}
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2}
-import org.apache.spark.sql.streaming.{OutputMode, Trigger}
+import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils
@@ -82,7 +81,7 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Bas
*/
case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
extends MemoryStreamBase[A](sqlContext)
- with MicroBatchReader with SupportsScanUnsafeRow with Logging {
+ with MicroBatchReader with Logging {
protected val logicalPlan: LogicalPlan =
StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession)
@@ -141,7 +140,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
if (endOffset.offset == -1) null else endOffset
}
- override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = {
+ override def createDataReaderFactories(): ju.List[DataReaderFactory] = {
synchronized {
// Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal)
val startOrdinal = startOffset.offset.toInt + 1
@@ -158,7 +157,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
logDebug(generateDebugString(newBlocks.flatten, startOrdinal, endOrdinal))
newBlocks.map { block =>
- new MemoryStreamDataReaderFactory(block).asInstanceOf[DataReaderFactory[UnsafeRow]]
+ new MemoryStreamDataReaderFactory(block).asInstanceOf[DataReaderFactory]
}.asJava
}
}
@@ -204,8 +203,11 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
class MemoryStreamDataReaderFactory(records: Array[UnsafeRow])
- extends DataReaderFactory[UnsafeRow] {
- override def createDataReader(): DataReader[UnsafeRow] = {
+ extends DataReaderFactory {
+
+ override def dataFormat(): DataFormat = DataFormat.UNSAFE_ROW
+
+ override def createUnsafeRowDataReader(): DataReader[UnsafeRow] = {
new DataReader[UnsafeRow] {
private var currentIndex = -1
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
index c28919b8b729..6ed6af558500 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
@@ -33,8 +33,8 @@ import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeR
import org.apache.spark.sql.{Encoder, Row, SQLContext}
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream.GetRecord
-import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions}
-import org.apache.spark.sql.sources.v2.reader.DataReaderFactory
+import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataFormat, DataSourceOptions}
+import org.apache.spark.sql.sources.v2.reader.{ContinuousDataReaderFactory, DataReaderFactory}
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.RpcUtils
@@ -99,7 +99,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
)
}
- override def createDataReaderFactories(): ju.List[DataReaderFactory[Row]] = {
+ override def createDataReaderFactories(): ju.List[DataReaderFactory] = {
synchronized {
val endpointName = s"ContinuousMemoryStreamRecordEndpoint-${java.util.UUID.randomUUID()}-$id"
endpointRef =
@@ -108,7 +108,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
startOffset.partitionNums.map {
case (part, index) =>
new ContinuousMemoryStreamDataReaderFactory(
- endpointName, part, index): DataReaderFactory[Row]
+ endpointName, part, index): DataReaderFactory
}.toList.asJava
}
}
@@ -160,8 +160,11 @@ object ContinuousMemoryStream {
class ContinuousMemoryStreamDataReaderFactory(
driverEndpointName: String,
partition: Int,
- startOffset: Int) extends DataReaderFactory[Row] {
- override def createDataReader: ContinuousMemoryStreamDataReader =
+ startOffset: Int) extends ContinuousDataReaderFactory {
+
+ override def dataFormat(): DataFormat = DataFormat.ROW
+
+ override def createRowDataReader: ContinuousMemoryStreamDataReader =
new ContinuousMemoryStreamDataReader(driverEndpointName, partition, startOffset)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala
index 6cf8520fc544..6a0ee6b89b6c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala
@@ -31,7 +31,7 @@ import org.apache.spark.network.util.JavaUtils
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.streaming._
-import org.apache.spark.sql.sources.v2.DataSourceOptions
+import org.apache.spark.sql.sources.v2.{DataFormat, DataSourceOptions}
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset}
import org.apache.spark.sql.types.StructType
@@ -134,7 +134,7 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation:
LongOffset(json.toLong)
}
- override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = {
+ override def createDataReaderFactories(): java.util.List[DataReaderFactory] = {
val startSeconds = LongOffset.convert(start).map(_.offset).getOrElse(0L)
val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L)
assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)")
@@ -169,7 +169,7 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation:
(0 until numPartitions).map { p =>
new RateStreamMicroBatchDataReaderFactory(
p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue)
- : DataReaderFactory[Row]
+ : DataReaderFactory
}.toList.asJava
}
@@ -188,9 +188,11 @@ class RateStreamMicroBatchDataReaderFactory(
rangeStart: Long,
rangeEnd: Long,
localStartTimeMs: Long,
- relativeMsPerValue: Double) extends DataReaderFactory[Row] {
+ relativeMsPerValue: Double) extends DataReaderFactory {
- override def createDataReader(): DataReader[Row] = new RateStreamMicroBatchDataReader(
+ override def dataFormat(): DataFormat = DataFormat.ROW
+
+ override def createRowDataReader(): DataReader[Row] = new RateStreamMicroBatchDataReader(
partitionId, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala
index 5aae46b46339..848612abba6d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala
@@ -32,7 +32,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql._
import org.apache.spark.sql.execution.streaming.LongOffset
import org.apache.spark.sql.sources.DataSourceRegister
-import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport}
+import org.apache.spark.sql.sources.v2.{DataFormat, DataSourceOptions, DataSourceV2, MicroBatchReadSupport}
import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory}
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset}
import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType}
@@ -140,7 +140,7 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR
}
}
- override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = {
+ override def createDataReaderFactories(): JList[DataReaderFactory] = {
assert(startOffset != null && endOffset != null,
"start offset and end offset should already be set before create read tasks.")
@@ -165,8 +165,10 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR
(0 until numPartitions).map { i =>
val slice = slices(i)
- new DataReaderFactory[Row] {
- override def createDataReader(): DataReader[Row] = new DataReader[Row] {
+ new DataReaderFactory {
+ override def dataFormat(): DataFormat = DataFormat.ROW
+
+ override def createRowDataReader(): DataReader[Row] = new DataReader[Row] {
private var currentIdx = -1
override def next(): Boolean = {
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java
index 172e5d5eebcb..69447ce31ab0 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java
@@ -24,6 +24,7 @@
import org.apache.spark.sql.catalyst.expressions.GenericRow;
import org.apache.spark.sql.sources.Filter;
import org.apache.spark.sql.sources.GreaterThan;
+import org.apache.spark.sql.sources.v2.DataFormat;
import org.apache.spark.sql.sources.v2.DataSourceOptions;
import org.apache.spark.sql.sources.v2.DataSourceV2;
import org.apache.spark.sql.sources.v2.ReadSupport;
@@ -79,8 +80,8 @@ public Filter[] pushedFilters() {
}
@Override
- public List<DataReaderFactory<Row>> createDataReaderFactories() {
- List<DataReaderFactory<Row>> res = new ArrayList<>();
+ public List<DataReaderFactory> createDataReaderFactories() {
+ List<DataReaderFactory> res = new ArrayList<>();
Integer lowerBound = null;
for (Filter filter : filters) {
@@ -107,7 +108,7 @@ public List<DataReaderFactory<Row>> createDataReaderFactories() {
}
}
- static class JavaAdvancedDataReaderFactory implements DataReaderFactory<Row>, DataReader<Row> {
+ static class JavaAdvancedDataReaderFactory implements DataReaderFactory, DataReader<Row> {
private int start;
private int end;
private StructType requiredSchema;
@@ -119,7 +120,12 @@ static class JavaAdvancedDataReaderFactory implements DataReaderFactory, Da
}
@Override
- public DataReader<Row> createDataReader() {
+ public DataFormat dataFormat() {
+ return DataFormat.ROW;
+ }
+
+ @Override
+ public DataReader<Row> createRowDataReader() {
return new JavaAdvancedDataReaderFactory(start - 1, end, requiredSchema);
}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java
index c55093768105..3d62128175da 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java
@@ -21,6 +21,7 @@
import java.util.List;
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
+import org.apache.spark.sql.sources.v2.DataFormat;
import org.apache.spark.sql.sources.v2.DataSourceOptions;
import org.apache.spark.sql.sources.v2.DataSourceV2;
import org.apache.spark.sql.sources.v2.ReadSupport;
@@ -33,7 +34,7 @@
public class JavaBatchDataSourceV2 implements DataSourceV2, ReadSupport {
- class Reader implements DataSourceReader, SupportsScanColumnarBatch {
+ class Reader implements DataSourceReader {
private final StructType schema = new StructType().add("i", "int").add("j", "int");
@Override
@@ -42,14 +43,14 @@ public StructType readSchema() {
}
@Override
- public List<DataReaderFactory<ColumnarBatch>> createBatchDataReaderFactories() {
+ public List<DataReaderFactory> createDataReaderFactories() {
return java.util.Arrays.asList(
new JavaBatchDataReaderFactory(0, 50), new JavaBatchDataReaderFactory(50, 90));
}
}
static class JavaBatchDataReaderFactory
- implements DataReaderFactory<ColumnarBatch>, DataReader<ColumnarBatch> {
+ implements DataReaderFactory, DataReader<ColumnarBatch> {
private int start;
private int end;
@@ -65,7 +66,12 @@ static class JavaBatchDataReaderFactory
}
@Override
- public DataReader<ColumnarBatch> createDataReader() {
+ public DataFormat dataFormat() {
+ return DataFormat.COLUMNAR_BATCH;
+ }
+
+ @Override
+ public DataReader<ColumnarBatch> createColumnarBatchDataReader() {
this.i = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType);
this.j = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType);
ColumnVector[] vectors = new ColumnVector[2];
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java
index 32fad59b97ff..f5e9adfe24ae 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java
@@ -23,6 +23,7 @@
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.expressions.GenericRow;
+import org.apache.spark.sql.sources.v2.DataFormat;
import org.apache.spark.sql.sources.v2.DataSourceOptions;
import org.apache.spark.sql.sources.v2.DataSourceV2;
import org.apache.spark.sql.sources.v2.ReadSupport;
@@ -43,7 +44,7 @@ public StructType readSchema() {
}
@Override
- public List<DataReaderFactory<Row>> createDataReaderFactories() {
+ public List<DataReaderFactory> createDataReaderFactories() {
return java.util.Arrays.asList(
new SpecificDataReaderFactory(new int[]{1, 1, 3}, new int[]{4, 4, 6}),
new SpecificDataReaderFactory(new int[]{2, 4, 4}, new int[]{6, 2, 2}));
@@ -73,7 +74,7 @@ public boolean satisfy(Distribution distribution) {
}
}
- static class SpecificDataReaderFactory implements DataReaderFactory<Row>, DataReader<Row> {
+ static class SpecificDataReaderFactory implements DataReaderFactory, DataReader<Row> {
private int[] i;
private int[] j;
private int current = -1;
@@ -101,7 +102,12 @@ public void close() throws IOException {
}
@Override
- public DataReader<Row> createDataReader() {
+ public DataFormat dataFormat() {
+ return DataFormat.ROW;
+ }
+
+ @Override
+ public DataReader<Row> createRowDataReader() {
return this;
}
}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java
index 048d078dfaac..87a3df3b2396 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java
@@ -42,7 +42,7 @@ public StructType readSchema() {
}
@Override
- public List<DataReaderFactory<Row>> createDataReaderFactories() {
+ public List<DataReaderFactory> createDataReaderFactories() {
return java.util.Collections.emptyList();
}
}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java
index 96f55b8a7681..09edf8ad7941 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java
@@ -22,6 +22,7 @@
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.expressions.GenericRow;
+import org.apache.spark.sql.sources.v2.DataFormat;
import org.apache.spark.sql.sources.v2.DataSourceV2;
import org.apache.spark.sql.sources.v2.DataSourceOptions;
import org.apache.spark.sql.sources.v2.ReadSupport;
@@ -41,14 +42,14 @@ public StructType readSchema() {
}
@Override
- public List<DataReaderFactory<Row>> createDataReaderFactories() {
+ public List<DataReaderFactory> createDataReaderFactories() {
return java.util.Arrays.asList(
new JavaSimpleDataReaderFactory(0, 5),
new JavaSimpleDataReaderFactory(5, 10));
}
}
- static class JavaSimpleDataReaderFactory implements DataReaderFactory<Row>, DataReader<Row> {
+ static class JavaSimpleDataReaderFactory implements DataReaderFactory, DataReader<Row> {
private int start;
private int end;
@@ -58,7 +59,12 @@ static class JavaSimpleDataReaderFactory implements DataReaderFactory<Row>, Data
}
@Override
- public DataReader<Row> createDataReader() {
+ public DataFormat dataFormat() {
+ return DataFormat.ROW;
+ }
+
+ @Override
+ public DataReader<Row> createRowDataReader() {
return new JavaSimpleDataReaderFactory(start - 1, end);
}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java
index c3916e0b370b..211e9e5fefd4 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java
@@ -21,6 +21,7 @@
import java.util.List;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.sources.v2.DataFormat;
import org.apache.spark.sql.sources.v2.DataSourceOptions;
import org.apache.spark.sql.sources.v2.DataSourceV2;
import org.apache.spark.sql.sources.v2.ReadSupport;
@@ -29,7 +30,7 @@
public class JavaUnsafeRowDataSourceV2 implements DataSourceV2, ReadSupport {
- class Reader implements DataSourceReader, SupportsScanUnsafeRow {
+ class Reader implements DataSourceReader {
private final StructType schema = new StructType().add("i", "int").add("j", "int");
@Override
@@ -38,7 +39,7 @@ public StructType readSchema() {
}
@Override
- public List<DataReaderFactory<UnsafeRow>> createUnsafeRowReaderFactories() {
+ public List<DataReaderFactory> createDataReaderFactories() {
return java.util.Arrays.asList(
new JavaUnsafeRowDataReaderFactory(0, 5),
new JavaUnsafeRowDataReaderFactory(5, 10));
@@ -46,7 +47,7 @@ public List<DataReaderFactory<UnsafeRow>> createUnsafeRowReaderFactories() {
}
static class JavaUnsafeRowDataReaderFactory
- implements DataReaderFactory<UnsafeRow>, DataReader<UnsafeRow> {
+ implements DataReaderFactory, DataReader<UnsafeRow> {
private int start;
private int end;
private UnsafeRow row;
@@ -59,7 +60,12 @@ static class JavaUnsafeRowDataReaderFactory
}
@Override
- public DataReader<UnsafeRow> createDataReader() {
+ public DataFormat dataFormat() {
+ return DataFormat.UNSAFE_ROW;
+ }
+
+ @Override
+ public DataReader<UnsafeRow> createUnsafeRowDataReader() {
return new JavaUnsafeRowDataReaderFactory(start - 1, end);
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala
index ff14ec38e66a..9d1a0b2f42ef 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala
@@ -24,8 +24,7 @@ import java.util.concurrent.TimeUnit
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
-import org.apache.spark.sql.catalyst.errors.TreeNodeException
+import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.continuous._
@@ -144,7 +143,7 @@ class RateSourceSuite extends StreamTest {
reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset))
val tasks = reader.createDataReaderFactories()
assert(tasks.size == 1)
- val dataReader = tasks.get(0).createDataReader()
+ val dataReader = tasks.get(0).createRowDataReader()
val data = ArrayBuffer[Row]()
while (dataReader.next()) {
data.append(dataReader.get())
@@ -163,7 +162,7 @@ class RateSourceSuite extends StreamTest {
assert(tasks.size == 11)
val readData = tasks.asScala
- .map(_.createDataReader())
+ .map(_.createRowDataReader())
.flatMap { reader =>
val buf = scala.collection.mutable.ListBuffer[Row]()
while (reader.next()) buf.append(reader.get())
@@ -314,7 +313,7 @@ class RateSourceSuite extends StreamTest {
.asInstanceOf[RateStreamOffset]
.partitionToValueAndRunTimeMs(t.partitionIndex)
.runTimeMs
- val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader]
+ val r = t.createRowDataReader().asInstanceOf[RateStreamContinuousDataReader]
for (rowIndex <- 0 to 9) {
r.next()
data.append(r.get())
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
index e0a53272cd22..518a31cc1cd8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
@@ -346,7 +346,7 @@ class SimpleSinglePartitionSource extends DataSourceV2 with ReadSupport {
class Reader extends DataSourceReader {
override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int")
- override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = {
+ override def createDataReaderFactories(): JList[DataReaderFactory] = {
java.util.Arrays.asList(new SimpleDataReaderFactory(0, 5))
}
}
@@ -359,7 +359,7 @@ class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport {
class Reader extends DataSourceReader {
override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int")
- override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = {
+ override def createDataReaderFactories(): JList[DataReaderFactory] = {
java.util.Arrays.asList(new SimpleDataReaderFactory(0, 5), new SimpleDataReaderFactory(5, 10))
}
}
@@ -368,11 +368,13 @@ class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport {
}
class SimpleDataReaderFactory(start: Int, end: Int)
- extends DataReaderFactory[Row]
+ extends DataReaderFactory
with DataReader[Row] {
private var current = start - 1
- override def createDataReader(): DataReader[Row] = new SimpleDataReaderFactory(start, end)
+ override def dataFormat(): DataFormat = DataFormat.ROW
+
+ override def createRowDataReader(): DataReader[Row] = new SimpleDataReaderFactory(start, end)
override def next(): Boolean = {
current += 1
@@ -413,12 +415,12 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport {
requiredSchema
}
- override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = {
+ override def createDataReaderFactories(): JList[DataReaderFactory] = {
val lowerBound = filters.collect {
case GreaterThan("i", v: Int) => v
}.headOption
- val res = new ArrayList[DataReaderFactory[Row]]
+ val res = new ArrayList[DataReaderFactory]
if (lowerBound.isEmpty) {
res.add(new AdvancedDataReaderFactory(0, 5, requiredSchema))
@@ -438,11 +440,13 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport {
}
class AdvancedDataReaderFactory(start: Int, end: Int, requiredSchema: StructType)
- extends DataReaderFactory[Row] with DataReader[Row] {
+ extends DataReaderFactory with DataReader[Row] {
private var current = start - 1
- override def createDataReader(): DataReader[Row] = {
+ override def dataFormat(): DataFormat = DataFormat.ROW
+
+ override def createRowDataReader(): DataReader[Row] = {
new AdvancedDataReaderFactory(start, end, requiredSchema)
}
@@ -465,10 +469,10 @@ class AdvancedDataReaderFactory(start: Int, end: Int, requiredSchema: StructType
class UnsafeRowDataSourceV2 extends DataSourceV2 with ReadSupport {
- class Reader extends DataSourceReader with SupportsScanUnsafeRow {
+ class Reader extends DataSourceReader {
override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int")
- override def createUnsafeRowReaderFactories(): JList[DataReaderFactory[UnsafeRow]] = {
+ override def createDataReaderFactories(): JList[DataReaderFactory] = {
java.util.Arrays.asList(new UnsafeRowDataReaderFactory(0, 5),
new UnsafeRowDataReaderFactory(5, 10))
}
@@ -478,14 +482,16 @@ class UnsafeRowDataSourceV2 extends DataSourceV2 with ReadSupport {
}
class UnsafeRowDataReaderFactory(start: Int, end: Int)
- extends DataReaderFactory[UnsafeRow] with DataReader[UnsafeRow] {
+ extends DataReaderFactory with DataReader[UnsafeRow] {
private val row = new UnsafeRow(2)
row.pointTo(new Array[Byte](8 * 3), 8 * 3)
private var current = start - 1
- override def createDataReader(): DataReader[UnsafeRow] = this
+ override def dataFormat(): DataFormat = DataFormat.UNSAFE_ROW
+
+ override def createUnsafeRowDataReader(): DataReader[UnsafeRow] = this
override def next(): Boolean = {
current += 1
@@ -503,7 +509,7 @@ class UnsafeRowDataReaderFactory(start: Int, end: Int)
class SchemaRequiredDataSource extends DataSourceV2 with ReadSupportWithSchema {
class Reader(val readSchema: StructType) extends DataSourceReader {
- override def createDataReaderFactories(): JList[DataReaderFactory[Row]] =
+ override def createDataReaderFactories(): JList[DataReaderFactory] =
java.util.Collections.emptyList()
}
@@ -513,10 +519,10 @@ class SchemaRequiredDataSource extends DataSourceV2 with ReadSupportWithSchema {
class BatchDataSourceV2 extends DataSourceV2 with ReadSupport {
- class Reader extends DataSourceReader with SupportsScanColumnarBatch {
+ class Reader extends DataSourceReader {
override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int")
- override def createBatchDataReaderFactories(): JList[DataReaderFactory[ColumnarBatch]] = {
+ override def createDataReaderFactories(): JList[DataReaderFactory] = {
java.util.Arrays.asList(new BatchDataReaderFactory(0, 50), new BatchDataReaderFactory(50, 90))
}
}
@@ -525,7 +531,7 @@ class BatchDataSourceV2 extends DataSourceV2 with ReadSupport {
}
class BatchDataReaderFactory(start: Int, end: Int)
- extends DataReaderFactory[ColumnarBatch] with DataReader[ColumnarBatch] {
+ extends DataReaderFactory with DataReader[ColumnarBatch] {
private final val BATCH_SIZE = 20
private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType)
@@ -534,7 +540,9 @@ class BatchDataReaderFactory(start: Int, end: Int)
private var current = start
- override def createDataReader(): DataReader[ColumnarBatch] = this
+ override def dataFormat(): DataFormat = DataFormat.COLUMNAR_BATCH
+
+ override def createColumnarBatchDataReader(): DataReader[ColumnarBatch] = this
override def next(): Boolean = {
i.reset()
@@ -568,7 +576,7 @@ class PartitionAwareDataSource extends DataSourceV2 with ReadSupport {
class Reader extends DataSourceReader with SupportsReportPartitioning {
override def readSchema(): StructType = new StructType().add("a", "int").add("b", "int")
- override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = {
+ override def createDataReaderFactories(): JList[DataReaderFactory] = {
// Note that we don't have same value of column `a` across partitions.
java.util.Arrays.asList(
new SpecificDataReaderFactory(Array(1, 1, 3), Array(4, 4, 6)),
@@ -591,13 +599,15 @@ class PartitionAwareDataSource extends DataSourceV2 with ReadSupport {
}
class SpecificDataReaderFactory(i: Array[Int], j: Array[Int])
- extends DataReaderFactory[Row]
+ extends DataReaderFactory
with DataReader[Row] {
assert(i.length == j.length)
private var current = -1
- override def createDataReader(): DataReader[Row] = this
+ override def dataFormat(): DataFormat = DataFormat.ROW
+
+ override def createRowDataReader(): DataReader[Row] = this
override def next(): Boolean = {
current += 1
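
Editor's note: the migration pattern repeated throughout this suite reduces to the sketch below: a factory declares its output format via dataFormat() and implements only the matching create*DataReader() method, instead of carrying a type parameter. This is a minimal illustration built from the names introduced by this patch; the RangeReaderFactory class itself is hypothetical and not part of the change.

import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.DataFormat
import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory}

// Hypothetical row-based factory under the single, non-generic DataReaderFactory.
class RangeReaderFactory(start: Int, end: Int) extends DataReaderFactory {

  // The format is advertised explicitly rather than encoded in a type parameter.
  override def dataFormat(): DataFormat = DataFormat.ROW

  // Only the reader method matching dataFormat() needs a real implementation.
  override def createRowDataReader(): DataReader[Row] = new DataReader[Row] {
    private var current = start - 1
    override def next(): Boolean = { current += 1; current < end }
    override def get(): Row = Row(current, -current)
    override def close(): Unit = {}
  }
}
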
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala
index a5007fa32135..9515d6370903 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala
@@ -45,7 +45,7 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS
class Reader(path: String, conf: Configuration) extends DataSourceReader {
override def readSchema(): StructType = schema
- override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = {
+ override def createDataReaderFactories(): JList[DataReaderFactory] = {
val dataPath = new Path(path)
val fs = dataPath.getFileSystem(conf)
if (fs.exists(dataPath)) {
@@ -56,7 +56,7 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS
val serializableConf = new SerializableConfiguration(conf)
new SimpleCSVDataReaderFactory(
f.getPath.toUri.toString,
- serializableConf): DataReaderFactory[Row]
+ serializableConf): DataReaderFactory
}.toList.asJava
} else {
Collections.emptyList()
@@ -157,13 +157,15 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS
}
class SimpleCSVDataReaderFactory(path: String, conf: SerializableConfiguration)
- extends DataReaderFactory[Row] with DataReader[Row] {
+ extends DataReaderFactory with DataReader[Row] {
@transient private var lines: Iterator[String] = _
@transient private var currentLine: String = _
@transient private var inputStream: FSDataInputStream = _
- override def createDataReader(): DataReader[Row] = {
+ override def dataFormat(): DataFormat = DataFormat.ROW
+
+ override def createRowDataReader(): DataReader[Row] = {
val filePath = new Path(path)
val fs = filePath.getFileSystem(conf.value)
inputStream = fs.open(filePath)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
index 20942ed93897..50887521de94 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
@@ -30,7 +30,6 @@ import org.scalatest.mockito.MockitoSugar
import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, Dataset}
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.sources.TestForeachWriter
import org.apache.spark.sql.functions._
@@ -227,10 +226,10 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
}
// getBatch should take 100 ms the first time it is called
- override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = {
+ override def createDataReaderFactories(): ju.List[DataReaderFactory] = {
synchronized {
clock.waitTillTime(1350)
- super.createUnsafeRowReaderFactories()
+ super.createDataReaderFactories()
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala
index af4618bed545..47f416c4a3b9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala
@@ -44,7 +44,7 @@ case class FakeReader() extends MicroBatchReader with ContinuousReader {
def mergeOffsets(offsets: Array[PartitionOffset]): Offset = RateStreamOffset(Map())
def setStartOffset(start: Optional[Offset]): Unit = {}
- def createDataReaderFactories(): java.util.ArrayList[DataReaderFactory[Row]] = {
+ def createDataReaderFactories(): java.util.ArrayList[DataReaderFactory] = {
throw new IllegalStateException("fake source - cannot actually read")
}
}
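
Editor's note: on the consumer side, the new contract lets callers branch on the declared format rather than probing for the old SupportsScan* mixins. The helper below is purely illustrative and not added by this patch; it only assumes the three enum values and reader methods shown above.

import org.apache.spark.sql.sources.v2.DataFormat
import org.apache.spark.sql.sources.v2.reader.DataReaderFactory

// Hypothetical caller-side dispatch on the advertised data format.
object DataFormatDispatch {
  def describe(factory: DataReaderFactory): String = factory.dataFormat() match {
    case DataFormat.ROW            => "rows via createRowDataReader()"
    case DataFormat.UNSAFE_ROW     => "unsafe rows via createUnsafeRowDataReader()"
    case DataFormat.COLUMNAR_BATCH => "columnar batches via createColumnarBatchDataReader()"
  }
}
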