diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index f26c134c2f6e..0a0ef76672a7 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -28,6 +28,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} +import org.apache.spark.sql.sources.v2.DataFormat import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType @@ -53,7 +54,7 @@ class KafkaContinuousReader( metadataPath: String, initialOffsets: KafkaOffsetRangeLimit, failOnDataLoss: Boolean) - extends ContinuousReader with SupportsScanUnsafeRow with Logging { + extends ContinuousReader with Logging { private lazy val session = SparkSession.getActiveSession.get private lazy val sc = session.sparkContext @@ -86,7 +87,7 @@ class KafkaContinuousReader( KafkaSourceOffset(JsonUtils.partitionOffsets(json)) } - override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { + override def createDataReaderFactories(): ju.List[DataReaderFactory] = { import scala.collection.JavaConverters._ val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset) @@ -108,7 +109,7 @@ class KafkaContinuousReader( case (topicPartition, start) => KafkaContinuousDataReaderFactory( topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss) - .asInstanceOf[DataReaderFactory[UnsafeRow]] + .asInstanceOf[DataReaderFactory] }.asJava } @@ -161,17 +162,11 @@ case class KafkaContinuousDataReaderFactory( startOffset: Long, kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends ContinuousDataReaderFactory[UnsafeRow] { + failOnDataLoss: Boolean) extends ContinuousDataReaderFactory { - override def createDataReaderWithOffset(offset: PartitionOffset): DataReader[UnsafeRow] = { - val kafkaOffset = offset.asInstanceOf[KafkaSourcePartitionOffset] - require(kafkaOffset.topicPartition == topicPartition, - s"Expected topicPartition: $topicPartition, but got: ${kafkaOffset.topicPartition}") - new KafkaContinuousDataReader( - topicPartition, kafkaOffset.partitionOffset, kafkaParams, pollTimeoutMs, failOnDataLoss) - } + override def dataFormat(): DataFormat = DataFormat.UNSAFE_ROW - override def createDataReader(): KafkaContinuousDataReader = { + override def createUnsafeRowDataReader(): KafkaContinuousDataReader = { new KafkaContinuousDataReader( topicPartition, startOffset, kafkaParams, pollTimeoutMs, failOnDataLoss) } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala index 2ed49ba3f549..c0b10a7a2994 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala @@ -32,8 +32,8 @@ import 
org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset} import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} -import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow} +import org.apache.spark.sql.sources.v2.{DataFormat, DataSourceOptions} +import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} import org.apache.spark.sql.types.StructType import org.apache.spark.util.UninterruptibleThread @@ -61,7 +61,7 @@ private[kafka010] class KafkaMicroBatchReader( metadataPath: String, startingOffsets: KafkaOffsetRangeLimit, failOnDataLoss: Boolean) - extends MicroBatchReader with SupportsScanUnsafeRow with Logging { + extends MicroBatchReader with Logging { private var startPartitionOffsets: PartitionOffsetMap = _ private var endPartitionOffsets: PartitionOffsetMap = _ @@ -101,7 +101,7 @@ private[kafka010] class KafkaMicroBatchReader( } } - override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { + override def createDataReaderFactories(): ju.List[DataReaderFactory] = { // Find the new partitions, and get their earliest offsets val newPartitions = endPartitionOffsets.keySet.diff(startPartitionOffsets.keySet) val newPartitionInitialOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq) @@ -146,7 +146,7 @@ private[kafka010] class KafkaMicroBatchReader( new KafkaMicroBatchDataReaderFactory( range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer) } - factories.map(_.asInstanceOf[DataReaderFactory[UnsafeRow]]).asJava + factories.map(_.asInstanceOf[DataReaderFactory]).asJava } override def getStartOffset: Offset = { @@ -305,11 +305,13 @@ private[kafka010] case class KafkaMicroBatchDataReaderFactory( executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, failOnDataLoss: Boolean, - reuseKafkaConsumer: Boolean) extends DataReaderFactory[UnsafeRow] { + reuseKafkaConsumer: Boolean) extends DataReaderFactory { override def preferredLocations(): Array[String] = offsetRange.preferredLoc.toArray - override def createDataReader(): DataReader[UnsafeRow] = new KafkaMicroBatchDataReader( + override def dataFormat(): DataFormat = DataFormat.UNSAFE_ROW + + override def createUnsafeRowDataReader(): DataReader[UnsafeRow] = new KafkaMicroBatchDataReader( offsetRange, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer) } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index e017fd9b84d2..730959d0873f 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -36,7 +36,6 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext import org.apache.spark.sql.{Dataset, ForeachWriter, SparkSession} -import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Update import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import 
org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution @@ -673,7 +672,7 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 0L))), Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 100L))) ) - val factories = reader.createUnsafeRowReaderFactories().asScala + val factories = reader.createDataReaderFactories().asScala .map(_.asInstanceOf[KafkaMicroBatchDataReaderFactory]) withClue(s"minPartitions = $minPartitions generated factories $factories\n\t") { assert(factories.size == numPartitionsGenerated) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataFormat.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataFormat.java new file mode 100644 index 000000000000..b1bcac6a7a3b --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataFormat.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.sql.sources.v2.reader.DataReaderFactory; + +/** + * An enum returned by {@link DataReaderFactory#dataFormat()}, representing the output data format of + * a data source scan. + * + * + * + * TODO: add INTERNAL_ROW + */ +public enum DataFormat { + /** + * Refers to {@link org.apache.spark.sql.Row}, which is very stable and guaranteed to be backward + * compatible. Spark needs to convert data of row format to the internal format, data source developers + * should consider using other formats for better performance. + */ + ROW, + + /** + * Refers to {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow}, which is an unstable and + * internal API. It's already the internal format in Spark, so there is no extra conversion needed. + */ + UNSAFE_ROW, + + /** + * Refers to {@link org.apache.spark.sql.vectorized.ColumnarBatch}, which is a public but experimental + * API. It's already the internal format in Spark, so there is no extra conversion needed. This format + * is recommended over others as columnar format has other advantages like vectorization, to further + * speed up the data processing. 
+ */ + COLUMNAR_BATCH +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java index a61697649c43..0d7d053b61d2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java @@ -18,18 +18,48 @@ package org.apache.spark.sql.sources.v2.reader; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousDataReader; +import org.apache.spark.sql.vectorized.ColumnarBatch; /** * A mix-in interface for {@link DataReaderFactory}. Continuous data reader factories can * implement this interface to provide creating {@link DataReader} with particular offset. */ @InterfaceStability.Evolving -public interface ContinuousDataReaderFactory extends DataReaderFactory { +public interface ContinuousDataReaderFactory extends DataReaderFactory { + + /** + * Returns a row-formatted continuous data reader to do the actual reading work. + * + * If this method fails (by throwing an exception), the corresponding Spark task would fail and + * get retried until hitting the maximum retry times. + */ + default ContinuousDataReader createRowDataReader() { + throw new IllegalStateException( + "createRowDataReader must be implemented if the data format is ROW."); + } + + /** + * Returns a unsafe-row-formatted continuous data reader to do the actual reading work. + * + * If this method fails (by throwing an exception), the corresponding Spark task would fail and + * get retried until hitting the maximum retry times. + */ + default ContinuousDataReader createUnsafeRowDataReader() { + throw new IllegalStateException( + "createUnsafeRowDataReader must be implemented if the data format is UNSAFE_ROW."); + } + /** - * Create a DataReader with particular offset as its startOffset. + * Returns a columnar-batch-formatted continuous data reader to do the actual reading work. * - * @param offset offset want to set as the DataReader's startOffset. + * If this method fails (by throwing an exception), the corresponding Spark task would fail and + * get retried until hitting the maximum retry times. */ - DataReader createDataReaderWithOffset(PartitionOffset offset); + default ContinuousDataReader createColumnarBatchDataReader() { + throw new IllegalStateException( + "createColumnarBatchDataReader must be implemented if the data format is COLUMNAR_BATCH."); + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java index bb9790a1c819..04a1303c0bf5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java @@ -23,12 +23,13 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A data reader returned by {@link DataReaderFactory#createDataReader()} and is responsible for - * outputting data for a RDD partition. + * A data reader returned by the create data reader method in {@link DataReaderFactory} and is + * responsible for outputting data for a RDD partition. 
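(Illustrative sketch, not part of the patch: a minimal continuous factory written against the interfaces added above. The names `MyPartitionOffset` and `MyContinuousDataReaderFactory` are invented for the example, and `dataFormat()` is inherited from the reworked `DataReaderFactory` shown further down. Only the create method matching the declared format is overridden; the other `create*DataReader` defaults keep throwing.)

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.DataFormat
import org.apache.spark.sql.sources.v2.reader.ContinuousDataReaderFactory
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, PartitionOffset}

// Hypothetical offset type for the example.
case class MyPartitionOffset(partition: Int, value: Long) extends PartitionOffset

class MyContinuousDataReaderFactory(partition: Int, startValue: Long)
  extends ContinuousDataReaderFactory {

  // Tells Spark which create*DataReader method it should call.
  override def dataFormat(): DataFormat = DataFormat.ROW

  override def createRowDataReader(): ContinuousDataReader[Row] =
    new ContinuousDataReader[Row] {
      private var current = startValue - 1
      override def next(): Boolean = { current += 1; true }
      override def get(): Row = Row(partition, current)
      override def getOffset: PartitionOffset = MyPartitionOffset(partition, current)
      override def close(): Unit = {}
    }
}
```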
* - * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data - * source readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for data source - * readers that mix in {@link SupportsScanUnsafeRow}. + * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row}, + * or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow}, or + * {@link org.apache.spark.sql.vectorized.ColumnarBatch}, depending on the return type of + * {@link DataReaderFactory#dataFormat()}. */ @InterfaceStability.Evolving public interface DataReader extends Closeable { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java index 32e98e8f5d8b..86e0579318a8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java @@ -20,6 +20,10 @@ import java.io.Serializable; import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.sources.v2.DataFormat; +import org.apache.spark.sql.vectorized.ColumnarBatch; /** * A reader factory returned by {@link DataSourceReader#createDataReaderFactories()} and is @@ -32,7 +36,7 @@ * serializable and {@link DataReader} doesn't need to be. */ @InterfaceStability.Evolving -public interface DataReaderFactory extends Serializable { +public interface DataReaderFactory extends Serializable { /** * The preferred locations where the data reader returned by this reader factory can run faster, @@ -52,10 +56,46 @@ default String[] preferredLocations() { } /** - * Returns a data reader to do the actual reading work. + * The output data format of this factory's data reader. Spark will invoke the corresponding + * create data reader method w.r.t. the return value of this method: + *
+ * <ul>
+ *   <li>{@link DataFormat#ROW}: {@link #createRowDataReader()}</li>
+ *   <li>{@link DataFormat#UNSAFE_ROW}: {@link #createUnsafeRowDataReader()}</li>
+ *   <li>{@link DataFormat#COLUMNAR_BATCH}: {@link #createColumnarBatchDataReader()}</li>
+ * </ul>
+ */ + DataFormat dataFormat(); + + /** + * Returns a row-formatted data reader to do the actual reading work. + * + * If this method fails (by throwing an exception), the corresponding Spark task would fail and + * get retried until hitting the maximum retry times. + */ + default DataReader createRowDataReader() { + throw new IllegalStateException( + "createRowDataReader must be implemented if the data format is ROW."); + } + + /** + * Returns a unsafe-row-formatted data reader to do the actual reading work. * * If this method fails (by throwing an exception), the corresponding Spark task would fail and * get retried until hitting the maximum retry times. */ - DataReader createDataReader(); + default DataReader createUnsafeRowDataReader() { + throw new IllegalStateException( + "createUnsafeRowDataReader must be implemented if the data format is UNSAFE_ROW."); + } + + /** + * Returns a columnar-batch-formatted data reader to do the actual reading work. + * + * If this method fails (by throwing an exception), the corresponding Spark task would fail and + * get retried until hitting the maximum retry times. + */ + default DataReader createColumnarBatchDataReader() { + throw new IllegalStateException( + "createColumnarBatchDataReader must be implemented if the data format is COLUMNAR_BATCH."); + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java index a470bccc5aad..2049a96fe903 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java @@ -20,7 +20,6 @@ import java.util.List; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.Row; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.ReadSupportWithSchema; @@ -34,23 +33,18 @@ * logic is delegated to {@link DataReaderFactory}s that are returned by * {@link #createDataReaderFactories()}. * - * There are mainly 3 kinds of query optimizations: + * There are mainly 2 kinds of query optimizations: * 1. Operators push-down. E.g., filter push-down, required columns push-down(aka column * pruning), etc. Names of these interfaces start with `SupportsPushDown`. * 2. Information Reporting. E.g., statistics reporting, ordering reporting, etc. * Names of these interfaces start with `SupportsReporting`. - * 3. Special scans. E.g, columnar scan, unsafe row scan, etc. - * Names of these interfaces start with `SupportsScan`. Note that a reader should only - * implement at most one of the special scans, if more than one special scans are implemented, - * only one of them would be respected, according to the priority list from high to low: - * {@link SupportsScanColumnarBatch}, {@link SupportsScanUnsafeRow}. * * If an exception was throw when applying any of these query optimizations, the action would fail * and no Spark job was submitted. * * Spark first applies all operator push-down optimizations that this data source supports. Then * Spark collects information this data source reported for further optimizations. Finally Spark - * issues the scan request and does the actual data reading. + * issues the scan request, create the {@link DataReaderFactory} and does the actual data reading. 
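(As a worked example of the reworked contract — not part of the patch — the sketch below is a batch source that emits plain `Row`s, loosely modeled on the `SimpleDataReaderFactory` used in the test suite further down. The `SimpleRowReader`/`SimpleRowReaderFactory` names are invented. The reader no longer mixes in `SupportsScanUnsafeRow`/`SupportsScanColumnarBatch`; it returns untyped factories and each factory declares its own format.)

```scala
import java.util.{List => JList}

import scala.collection.JavaConverters._

import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.DataFormat
import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, DataSourceReader}
import org.apache.spark.sql.types.StructType

class SimpleRowReader extends DataSourceReader {
  override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int")

  override def createDataReaderFactories(): JList[DataReaderFactory] =
    Seq[DataReaderFactory](
      new SimpleRowReaderFactory(0, 5),
      new SimpleRowReaderFactory(5, 10)).asJava
}

class SimpleRowReaderFactory(start: Int, end: Int) extends DataReaderFactory {
  override def dataFormat(): DataFormat = DataFormat.ROW

  override def createRowDataReader(): DataReader[Row] = new DataReader[Row] {
    private var current = start - 1
    override def next(): Boolean = { current += 1; current < end }
    override def get(): Row = Row(current, -current)
    override def close(): Unit = {}
  }
}
```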
*/ @InterfaceStability.Evolving public interface DataSourceReader { @@ -76,5 +70,5 @@ public interface DataSourceReader { * If this method fails (by throwing an exception), the action would fail and no Spark job was * submitted. */ - List> createDataReaderFactories(); + List createDataReaderFactories(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java deleted file mode 100644 index 2e5cfa78511f..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.reader; - -import java.util.List; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.vectorized.ColumnarBatch; - -/** - * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this - * interface to output {@link ColumnarBatch} and make the scan faster. - */ -@InterfaceStability.Evolving -public interface SupportsScanColumnarBatch extends DataSourceReader { - @Override - default List> createDataReaderFactories() { - throw new IllegalStateException( - "createDataReaderFactories not supported by default within SupportsScanColumnarBatch."); - } - - /** - * Similar to {@link DataSourceReader#createDataReaderFactories()}, but returns columnar data - * in batches. - */ - List> createBatchDataReaderFactories(); - - /** - * Returns true if the concrete data source reader can read data in batch according to the scan - * properties like required columns, pushes filters, etc. It's possible that the implementation - * can only support some certain columns with certain types. Users can overwrite this method and - * {@link #createDataReaderFactories()} to fallback to normal read path under some conditions. - */ - default boolean enableBatchRead() { - return true; - } -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java deleted file mode 100644 index 9cd749e8e4ce..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.reader; - -import java.util.List; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.catalyst.expressions.UnsafeRow; - -/** - * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this - * interface to output {@link UnsafeRow} directly and avoid the row copy at Spark side. - * This is an experimental and unstable interface, as {@link UnsafeRow} is not public and may get - * changed in the future Spark versions. - */ -@InterfaceStability.Unstable -public interface SupportsScanUnsafeRow extends DataSourceReader { - - @Override - default List> createDataReaderFactories() { - throw new IllegalStateException( - "createDataReaderFactories not supported by default within SupportsScanUnsafeRow"); - } - - /** - * Similar to {@link DataSourceReader#createDataReaderFactories()}, - * but returns data in unsafe row format. - */ - List> createUnsafeRowReaderFactories(); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java index 7fe7f00ac2fa..d908a5034ead 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java @@ -17,12 +17,12 @@ package org.apache.spark.sql.sources.v2.reader.streaming; +import java.util.Optional; + import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.streaming.BaseStreamingSource; import org.apache.spark.sql.sources.v2.reader.DataSourceReader; -import java.util.Optional; - /** * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this * interface to allow reading in a continuous processing mode stream. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java index 67ebde30d61a..0c8a40594408 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java @@ -17,12 +17,12 @@ package org.apache.spark.sql.sources.v2.reader.streaming; +import java.util.Optional; + import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.reader.DataSourceReader; import org.apache.spark.sql.execution.streaming.BaseStreamingSource; -import java.util.Optional; - /** * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this * interface to indicate they allow micro-batch streaming reads. 
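(Another hedged sketch before the engine-side changes below: what an `UNSAFE_ROW` factory looks like on the source side, mirroring `MemoryStreamDataReaderFactory` and the `JavaUnsafeRowDataSourceV2` test later in this patch. The `UnsafeRowReaderFactory` name and the fixed two-int schema are assumptions of the example.)

```scala
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.sources.v2.DataFormat
import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory}

class UnsafeRowReaderFactory(start: Int, end: Int) extends DataReaderFactory {
  override def dataFormat(): DataFormat = DataFormat.UNSAFE_ROW

  override def createUnsafeRowDataReader(): DataReader[UnsafeRow] = new DataReader[UnsafeRow] {
    // 8 bytes of null-tracking bits plus two 8-byte fields.
    private val row = new UnsafeRow(2)
    row.pointTo(new Array[Byte](8 * 3), 8 * 3)
    private var current = start - 1

    override def next(): Boolean = { current += 1; current < end }

    override def get(): UnsafeRow = {
      row.setInt(0, current)
      row.setInt(1, -current)
      row
    }

    override def close(): Unit = {}
  }
}
```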
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index f85971be394b..a40eea1a9389 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -17,20 +17,24 @@ package org.apache.spark.sql.execution.datasources.v2 -import scala.collection.JavaConverters._ -import scala.reflect.ClassTag - import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.sources.v2.DataFormat +import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory} +import org.apache.spark.sql.types.StructType -class DataSourceRDDPartition[T : ClassTag](val index: Int, val readerFactory: DataReaderFactory[T]) +class DataSourceRDDPartition(val index: Int, val factory: DataReaderFactory) extends Partition with Serializable -class DataSourceRDD[T: ClassTag]( +class DataSourceRDD( sc: SparkContext, - @transient private val readerFactories: Seq[DataReaderFactory[T]]) - extends RDD[T](sc, Nil) { + @transient private val readerFactories: Seq[DataReaderFactory], + schema: StructType) + extends RDD[InternalRow](sc, Nil) { override protected def getPartitions: Array[Partition] = { readerFactories.zipWithIndex.map { @@ -38,31 +42,60 @@ class DataSourceRDD[T: ClassTag]( }.toArray } - override def compute(split: Partition, context: TaskContext): Iterator[T] = { - val reader = split.asInstanceOf[DataSourceRDDPartition[T]].readerFactory.createDataReader() - context.addTaskCompletionListener(_ => reader.close()) - val iter = new Iterator[T] { - private[this] var valuePrepared = false - - override def hasNext: Boolean = { - if (!valuePrepared) { - valuePrepared = reader.next() - } - valuePrepared - } - - override def next(): T = { - if (!hasNext) { - throw new java.util.NoSuchElementException("End of stream") - } - valuePrepared = false - reader.get() - } + override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { + val factory = split.asInstanceOf[DataSourceRDDPartition].factory + val iter: DataReaderIterator[UnsafeRow] = factory.dataFormat() match { + case DataFormat.ROW => + val reader = new RowToUnsafeDataReader( + factory.createRowDataReader(), RowEncoder.apply(schema).resolveAndBind()) + new DataReaderIterator(reader) + + case DataFormat.UNSAFE_ROW => + new DataReaderIterator(factory.createUnsafeRowDataReader()) + + case DataFormat.COLUMNAR_BATCH => + new DataReaderIterator(factory.createColumnarBatchDataReader()) + // TODO: remove this type erase hack. 
+ .asInstanceOf[DataReaderIterator[UnsafeRow]] } + context.addTaskCompletionListener(_ => iter.close()) new InterruptibleIterator(context, iter) } override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[DataSourceRDDPartition[T]].readerFactory.preferredLocations() + split.asInstanceOf[DataSourceRDDPartition].factory.preferredLocations() } } + +class RowToUnsafeDataReader(rowReader: DataReader[Row], encoder: ExpressionEncoder[Row]) + extends DataReader[UnsafeRow] { + + override def next: Boolean = rowReader.next + + override def get: UnsafeRow = encoder.toRow(rowReader.get).asInstanceOf[UnsafeRow] + + override def close(): Unit = rowReader.close() +} + +class DataReaderIterator[T](reader: DataReader[T]) extends Iterator[T] { + private[this] var valuePrepared = false + + override def hasNext: Boolean = { + if (!valuePrepared) { + valuePrepared = reader.next() + // no more data, close the reader. + if (!valuePrepared) close() + } + valuePrepared + } + + override def next(): T = { + if (!hasNext) { + throw new java.util.NoSuchElementException("End of stream") + } + valuePrepared = false + reader.get() + } + + def close(): Unit = reader.close() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 3a5e7bf89e14..33336cba2234 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -20,19 +20,15 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.JavaConverters._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.catalyst.plans.physical.SinglePartition import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.sources.v2.DataSourceV2 +import org.apache.spark.sql.sources.v2.{DataFormat, DataSourceV2} import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.vectorized.ColumnarBatch /** * Physical plan node for scanning data from a data source. 
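(The hunk below derives `supportsBatch` from the factories' declared formats instead of the old `SupportsScanColumnarBatch` mix-in, and requires every factory of a scan to report the same format. For reference, a rough `COLUMNAR_BATCH` factory sketch, loosely following the `JavaBatchDataSourceV2` test in this patch; the `ColumnarReaderFactory` name and batch size are illustrative.)

```scala
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.sources.v2.DataFormat
import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory}
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}

class ColumnarReaderFactory(start: Int, end: Int) extends DataReaderFactory {
  override def dataFormat(): DataFormat = DataFormat.COLUMNAR_BATCH

  override def createColumnarBatchDataReader(): DataReader[ColumnarBatch] =
    new DataReader[ColumnarBatch] {
      private val batchSize = 20
      private val i = new OnHeapColumnVector(batchSize, IntegerType)
      private val j = new OnHeapColumnVector(batchSize, IntegerType)
      private val batch = new ColumnarBatch(Array[ColumnVector](i, j))
      private var current = start

      override def next(): Boolean = {
        // Refill the two vectors for the next batch.
        i.reset()
        j.reset()
        var rowId = 0
        while (rowId < batchSize && current < end) {
          i.putInt(rowId, current)
          j.putInt(rowId, -current)
          rowId += 1
          current += 1
        }
        batch.setNumRows(rowId)
        rowId > 0
      }

      override def get(): ColumnarBatch = batch

      override def close(): Unit = batch.close()
    }
}
```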
@@ -57,36 +53,22 @@ case class DataSourceV2ScanExec( Seq(output, source, options).hashCode() } - override def outputPartitioning: physical.Partitioning = reader match { - case r: SupportsScanColumnarBatch if r.enableBatchRead() && batchReaderFactories.size == 1 => + override def outputPartitioning: physical.Partitioning = { + if (readerFactories.length == 1) { SinglePartition + } else { + reader match { + case s: SupportsReportPartitioning => + new DataSourcePartitioning( + s.outputPartitioning(), AttributeMap(output.map(a => a -> a.name))) - case r: SupportsScanColumnarBatch if !r.enableBatchRead() && readerFactories.size == 1 => - SinglePartition - - case r if !r.isInstanceOf[SupportsScanColumnarBatch] && readerFactories.size == 1 => - SinglePartition - - case s: SupportsReportPartitioning => - new DataSourcePartitioning( - s.outputPartitioning(), AttributeMap(output.map(a => a -> a.name))) - - case _ => super.outputPartitioning - } - - private lazy val readerFactories: Seq[DataReaderFactory[UnsafeRow]] = reader match { - case r: SupportsScanUnsafeRow => r.createUnsafeRowReaderFactories().asScala - case _ => - reader.createDataReaderFactories().asScala.map { - new RowToUnsafeRowDataReaderFactory(_, reader.readSchema()): DataReaderFactory[UnsafeRow] + case _ => super.outputPartitioning } + } } - private lazy val batchReaderFactories: Seq[DataReaderFactory[ColumnarBatch]] = reader match { - case r: SupportsScanColumnarBatch if r.enableBatchRead() => - assert(!reader.isInstanceOf[ContinuousReader], - "continuous stream reader does not support columnar read yet.") - r.createBatchDataReaderFactories().asScala + private lazy val readerFactories: Seq[DataReaderFactory] = { + reader.createDataReaderFactories().asScala } private lazy val inputRDD: RDD[InternalRow] = reader match { @@ -95,21 +77,29 @@ case class DataSourceV2ScanExec( sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), sparkContext.env) .askSync[Unit](SetReaderPartitions(readerFactories.size)) - new ContinuousDataSourceRDD(sparkContext, sqlContext, readerFactories) - .asInstanceOf[RDD[InternalRow]] - - case r: SupportsScanColumnarBatch if r.enableBatchRead() => - new DataSourceRDD(sparkContext, batchReaderFactories).asInstanceOf[RDD[InternalRow]] - + if (supportsBatch) { + throw new IllegalArgumentException( + "continuous stream reader does not support columnar read yet.") + } + new ContinuousDataSourceRDD( + sparkContext, + sqlContext, + readerFactories.map(_.asInstanceOf[ContinuousDataReaderFactory]), + reader.readSchema()) case _ => - new DataSourceRDD(sparkContext, readerFactories).asInstanceOf[RDD[InternalRow]] + new DataSourceRDD(sparkContext, readerFactories, reader.readSchema()) } override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD) - override val supportsBatch: Boolean = reader match { - case r: SupportsScanColumnarBatch if r.enableBatchRead() => true - case _ => false + override val supportsBatch: Boolean = { + val formats = readerFactories.map(_.dataFormat()).distinct + if (formats.length > 1) { + throw new IllegalArgumentException("Currently Spark requires all the data reader " + + "factories returned by the DataSourceReader output same data format.") + } + + formats.nonEmpty && formats.head == DataFormat.COLUMNAR_BATCH } override protected def needsUnsafeRowConversion: Boolean = false @@ -126,24 +116,3 @@ case class DataSourceV2ScanExec( } } } - -class RowToUnsafeRowDataReaderFactory(rowReaderFactory: DataReaderFactory[Row], schema: StructType) - extends DataReaderFactory[UnsafeRow] 
{ - - override def preferredLocations: Array[String] = rowReaderFactory.preferredLocations - - override def createDataReader: DataReader[UnsafeRow] = { - new RowToUnsafeDataReader( - rowReaderFactory.createDataReader, RowEncoder.apply(schema).resolveAndBind()) - } -} - -class RowToUnsafeDataReader(val rowReader: DataReader[Row], encoder: ExpressionEncoder[Row]) - extends DataReader[UnsafeRow] { - - override def next: Boolean = rowReader.next - - override def get: UnsafeRow = encoder.toRow(rowReader.get).asInstanceOf[UnsafeRow] - - override def close(): Unit = rowReader.close() -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index 06754f01657d..cfd1f888c36e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -20,41 +20,58 @@ package org.apache.spark.sql.execution.streaming.continuous import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue, TimeUnit} import java.util.concurrent.atomic.AtomicBoolean -import scala.collection.JavaConverters._ - import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeDataReader} +import org.apache.spark.sql.execution.datasources.v2.RowToUnsafeDataReader +import org.apache.spark.sql.sources.v2.DataFormat import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, PartitionOffset} +import org.apache.spark.sql.types.StructType import org.apache.spark.util.ThreadUtils +class ContinuousDataSourceRDDPartition( + val index: Int, + val factory: ContinuousDataReaderFactory) extends Partition with Serializable + class ContinuousDataSourceRDD( sc: SparkContext, sqlContext: SQLContext, - @transient private val readerFactories: Seq[DataReaderFactory[UnsafeRow]]) - extends RDD[UnsafeRow](sc, Nil) { + @transient private val readerFactories: Seq[ContinuousDataReaderFactory], + schema: StructType) + extends RDD[InternalRow](sc, Nil) { private val dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize private val epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs override protected def getPartitions: Array[Partition] = { readerFactories.zipWithIndex.map { - case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory) + case (readerFactory, index) => new ContinuousDataSourceRDDPartition(index, readerFactory) }.toArray } - override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { + override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { // If attempt number isn't 0, this is a task retry, which we don't support. 
if (context.attemptNumber() != 0) { throw new ContinuousTaskRetryException() } - val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]] - .readerFactory.createDataReader() + val factory = split.asInstanceOf[ContinuousDataSourceRDDPartition].factory + val reader = factory.dataFormat() match { + case DataFormat.ROW => + new RowToUnsafeContinuousDataReader( + factory.createRowDataReader().asInstanceOf[ContinuousDataReader[Row]], + RowEncoder.apply(schema).resolveAndBind()) + + case DataFormat.UNSAFE_ROW => factory.createUnsafeRowDataReader() + + case other => + throw new IllegalArgumentException(s"Illegal data format specified: $other") + } val coordinatorId = context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY) @@ -71,7 +88,7 @@ class ContinuousDataSourceRDD( epochPollRunnable, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS) // Important sequencing - we must get start offset before the data reader thread begins - val startOffset = ContinuousDataSourceRDD.getBaseReader(reader).getOffset + val startOffset = reader.getOffset val dataReaderFailed = new AtomicBoolean(false) val dataReaderThread = new DataReaderThread(reader, queue, context, dataReaderFailed) @@ -133,7 +150,7 @@ class ContinuousDataSourceRDD( } override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readerFactory.preferredLocations() + split.asInstanceOf[ContinuousDataSourceRDDPartition].factory.preferredLocations() } } @@ -168,7 +185,7 @@ class EpochPollRunnable( } class DataReaderThread( - reader: DataReader[UnsafeRow], + reader: ContinuousDataReader[UnsafeRow], queue: BlockingQueue[(UnsafeRow, PartitionOffset)], context: TaskContext, failedFlag: AtomicBoolean) @@ -179,7 +196,6 @@ class DataReaderThread( override def run(): Unit = { TaskContext.setTaskContext(context) - val baseReader = ContinuousDataSourceRDD.getBaseReader(reader) try { while (!context.isInterrupted && !context.isCompleted()) { if (!reader.next()) { @@ -192,7 +208,7 @@ class DataReaderThread( } } - queue.put((reader.get().copy(), baseReader.getOffset)) + queue.put((reader.get().copy(), reader.getOffset)) } } catch { case _: InterruptedException if context.isInterrupted() => @@ -209,14 +225,10 @@ class DataReaderThread( } } -object ContinuousDataSourceRDD { - private[continuous] def getBaseReader(reader: DataReader[UnsafeRow]): ContinuousDataReader[_] = { - reader match { - case r: ContinuousDataReader[UnsafeRow] => r - case wrapped: RowToUnsafeDataReader => - wrapped.rowReader.asInstanceOf[ContinuousDataReader[Row]] - case _ => - throw new IllegalStateException(s"Unknown continuous reader type ${reader.getClass}") - } - } +class RowToUnsafeContinuousDataReader( + rowReader: ContinuousDataReader[Row], + encoder: ExpressionEncoder[Row]) + extends RowToUnsafeDataReader(rowReader, encoder) with ContinuousDataReader[UnsafeRow] { + + override def getOffset: PartitionOffset = rowReader.getOffset } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 2f0de2612c15..6b1a561b866b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.Row import 
org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider -import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.sources.v2.{DataFormat, DataSourceOptions} import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType @@ -67,7 +67,7 @@ class RateStreamContinuousReader(options: DataSourceOptions) override def getStartOffset(): Offset = offset - override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { + override def createDataReaderFactories(): java.util.List[DataReaderFactory] = { val partitionStartMap = offset match { case off: RateStreamOffset => off.partitionToValueAndRunTimeMs case off => @@ -90,8 +90,7 @@ class RateStreamContinuousReader(options: DataSourceOptions) start.runTimeMs, i, numPartitions, - perPartitionRate) - .asInstanceOf[DataReaderFactory[Row]] + perPartitionRate).asInstanceOf[DataReaderFactory] }.asJava } @@ -119,21 +118,11 @@ case class RateStreamContinuousDataReaderFactory( partitionIndex: Int, increment: Long, rowsPerSecond: Double) - extends ContinuousDataReaderFactory[Row] { + extends ContinuousDataReaderFactory { - override def createDataReaderWithOffset(offset: PartitionOffset): DataReader[Row] = { - val rateStreamOffset = offset.asInstanceOf[RateStreamPartitionOffset] - require(rateStreamOffset.partition == partitionIndex, - s"Expected partitionIndex: $partitionIndex, but got: ${rateStreamOffset.partition}") - new RateStreamContinuousDataReader( - rateStreamOffset.currentValue, - rateStreamOffset.currentTimeMs, - partitionIndex, - increment, - rowsPerSecond) - } + override def dataFormat(): DataFormat = DataFormat.ROW - override def createDataReader(): DataReader[Row] = + override def createRowDataReader(): ContinuousDataReader[Row] = new RateStreamContinuousDataReader( startValue, startTimeMs, partitionIndex, increment, rowsPerSecond) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 628923d367ce..a8e9cff37e29 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -24,19 +24,18 @@ import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, ListBuffer} -import scala.reflect.ClassTag import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger -import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow} +import org.apache.spark.sql.sources.v2.DataFormat +import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory} import 
org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} -import org.apache.spark.sql.streaming.{OutputMode, Trigger} +import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -82,7 +81,7 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Bas */ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) extends MemoryStreamBase[A](sqlContext) - with MicroBatchReader with SupportsScanUnsafeRow with Logging { + with MicroBatchReader with Logging { protected val logicalPlan: LogicalPlan = StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession) @@ -141,7 +140,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) if (endOffset.offset == -1) null else endOffset } - override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { + override def createDataReaderFactories(): ju.List[DataReaderFactory] = { synchronized { // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal) val startOrdinal = startOffset.offset.toInt + 1 @@ -158,7 +157,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) logDebug(generateDebugString(newBlocks.flatten, startOrdinal, endOrdinal)) newBlocks.map { block => - new MemoryStreamDataReaderFactory(block).asInstanceOf[DataReaderFactory[UnsafeRow]] + new MemoryStreamDataReaderFactory(block).asInstanceOf[DataReaderFactory] }.asJava } } @@ -204,8 +203,11 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) class MemoryStreamDataReaderFactory(records: Array[UnsafeRow]) - extends DataReaderFactory[UnsafeRow] { - override def createDataReader(): DataReader[UnsafeRow] = { + extends DataReaderFactory { + + override def dataFormat(): DataFormat = DataFormat.UNSAFE_ROW + + override def createUnsafeRowDataReader(): DataReader[UnsafeRow] = { new DataReader[UnsafeRow] { private var currentIndex = -1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala index c28919b8b729..6ed6af558500 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -33,8 +33,8 @@ import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeR import org.apache.spark.sql.{Encoder, Row, SQLContext} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream.GetRecord -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions} -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataFormat, DataSourceOptions} +import org.apache.spark.sql.sources.v2.reader.{ContinuousDataReaderFactory, DataReaderFactory} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType import org.apache.spark.util.RpcUtils @@ -99,7 +99,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) ) } - override def createDataReaderFactories(): ju.List[DataReaderFactory[Row]] = { + override def createDataReaderFactories(): 
ju.List[DataReaderFactory] = { synchronized { val endpointName = s"ContinuousMemoryStreamRecordEndpoint-${java.util.UUID.randomUUID()}-$id" endpointRef = @@ -108,7 +108,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) startOffset.partitionNums.map { case (part, index) => new ContinuousMemoryStreamDataReaderFactory( - endpointName, part, index): DataReaderFactory[Row] + endpointName, part, index): DataReaderFactory }.toList.asJava } } @@ -160,8 +160,11 @@ object ContinuousMemoryStream { class ContinuousMemoryStreamDataReaderFactory( driverEndpointName: String, partition: Int, - startOffset: Int) extends DataReaderFactory[Row] { - override def createDataReader: ContinuousMemoryStreamDataReader = + startOffset: Int) extends ContinuousDataReaderFactory { + + override def dataFormat(): DataFormat = DataFormat.ROW + + override def createRowDataReader: ContinuousMemoryStreamDataReader = new ContinuousMemoryStreamDataReader(driverEndpointName, partition, startOffset) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala index 6cf8520fc544..6a0ee6b89b6c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala @@ -31,7 +31,7 @@ import org.apache.spark.network.util.JavaUtils import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.sources.v2.{DataFormat, DataSourceOptions} import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} import org.apache.spark.sql.types.StructType @@ -134,7 +134,7 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: LongOffset(json.toLong) } - override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { + override def createDataReaderFactories(): java.util.List[DataReaderFactory] = { val startSeconds = LongOffset.convert(start).map(_.offset).getOrElse(0L) val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") @@ -169,7 +169,7 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: (0 until numPartitions).map { p => new RateStreamMicroBatchDataReaderFactory( p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) - : DataReaderFactory[Row] + : DataReaderFactory }.toList.asJava } @@ -188,9 +188,11 @@ class RateStreamMicroBatchDataReaderFactory( rangeStart: Long, rangeEnd: Long, localStartTimeMs: Long, - relativeMsPerValue: Double) extends DataReaderFactory[Row] { + relativeMsPerValue: Double) extends DataReaderFactory { - override def createDataReader(): DataReader[Row] = new RateStreamMicroBatchDataReader( + override def dataFormat(): DataFormat = DataFormat.ROW; + + override def createRowDataReader(): DataReader[Row] = new RateStreamMicroBatchDataReader( partitionId, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala index 5aae46b46339..848612abba6d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala @@ -32,7 +32,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming.LongOffset import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.{DataFormat, DataSourceOptions, DataSourceV2, MicroBatchReadSupport} import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} @@ -140,7 +140,7 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR } } - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { + override def createDataReaderFactories(): JList[DataReaderFactory] = { assert(startOffset != null && endOffset != null, "start offset and end offset should already be set before create read tasks.") @@ -165,8 +165,10 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR (0 until numPartitions).map { i => val slice = slices(i) - new DataReaderFactory[Row] { - override def createDataReader(): DataReader[Row] = new DataReader[Row] { + new DataReaderFactory { + override def dataFormat(): DataFormat = DataFormat.ROW + + override def createRowDataReader(): DataReader[Row] = new DataReader[Row] { private var currentIdx = -1 override def next(): Boolean = { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index 172e5d5eebcb..69447ce31ab0 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.GenericRow; import org.apache.spark.sql.sources.Filter; import org.apache.spark.sql.sources.GreaterThan; +import org.apache.spark.sql.sources.v2.DataFormat; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.ReadSupport; @@ -79,8 +80,8 @@ public Filter[] pushedFilters() { } @Override - public List> createDataReaderFactories() { - List> res = new ArrayList<>(); + public List createDataReaderFactories() { + List res = new ArrayList<>(); Integer lowerBound = null; for (Filter filter : filters) { @@ -107,7 +108,7 @@ public List> createDataReaderFactories() { } } - static class JavaAdvancedDataReaderFactory implements DataReaderFactory, DataReader { + static class JavaAdvancedDataReaderFactory implements DataReaderFactory, DataReader { private int start; private int end; private StructType requiredSchema; @@ -119,7 +120,12 @@ static class JavaAdvancedDataReaderFactory implements DataReaderFactory, Da } @Override - public DataReader createDataReader() { + public DataFormat dataFormat() { + return DataFormat.ROW; + } + + @Override + public DataReader createRowDataReader() { return new 
JavaAdvancedDataReaderFactory(start - 1, end, requiredSchema); } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java index c55093768105..3d62128175da 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java @@ -21,6 +21,7 @@ import java.util.List; import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; +import org.apache.spark.sql.sources.v2.DataFormat; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.ReadSupport; @@ -33,7 +34,7 @@ public class JavaBatchDataSourceV2 implements DataSourceV2, ReadSupport { - class Reader implements DataSourceReader, SupportsScanColumnarBatch { + class Reader implements DataSourceReader { private final StructType schema = new StructType().add("i", "int").add("j", "int"); @Override @@ -42,14 +43,14 @@ public StructType readSchema() { } @Override - public List> createBatchDataReaderFactories() { + public List createDataReaderFactories() { return java.util.Arrays.asList( new JavaBatchDataReaderFactory(0, 50), new JavaBatchDataReaderFactory(50, 90)); } } static class JavaBatchDataReaderFactory - implements DataReaderFactory, DataReader { + implements DataReaderFactory, DataReader { private int start; private int end; @@ -65,7 +66,12 @@ static class JavaBatchDataReaderFactory } @Override - public DataReader createDataReader() { + public DataFormat dataFormat() { + return DataFormat.COLUMNAR_BATCH; + } + + @Override + public DataReader createColumnarBatchDataReader() { this.i = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); this.j = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); ColumnVector[] vectors = new ColumnVector[2]; diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java index 32fad59b97ff..f5e9adfe24ae 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java @@ -23,6 +23,7 @@ import org.apache.spark.sql.Row; import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.sources.v2.DataFormat; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.ReadSupport; @@ -43,7 +44,7 @@ public StructType readSchema() { } @Override - public List> createDataReaderFactories() { + public List createDataReaderFactories() { return java.util.Arrays.asList( new SpecificDataReaderFactory(new int[]{1, 1, 3}, new int[]{4, 4, 6}), new SpecificDataReaderFactory(new int[]{2, 4, 4}, new int[]{6, 2, 2})); @@ -73,7 +74,7 @@ public boolean satisfy(Distribution distribution) { } } - static class SpecificDataReaderFactory implements DataReaderFactory, DataReader { + static class SpecificDataReaderFactory implements DataReaderFactory, DataReader { private int[] i; private int[] j; private int current = -1; @@ -101,7 +102,12 @@ public void close() throws IOException { } @Override - public DataReader createDataReader() { + public DataFormat dataFormat() { + return 
DataFormat.ROW; + } + + @Override + public DataReader createRowDataReader() { return this; } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java index 048d078dfaac..87a3df3b2396 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java @@ -42,7 +42,7 @@ public StructType readSchema() { } @Override - public List> createDataReaderFactories() { + public List createDataReaderFactories() { return java.util.Collections.emptyList(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java index 96f55b8a7681..09edf8ad7941 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java @@ -22,6 +22,7 @@ import org.apache.spark.sql.Row; import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.sources.v2.DataFormat; import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.ReadSupport; @@ -41,14 +42,14 @@ public StructType readSchema() { } @Override - public List> createDataReaderFactories() { + public List createDataReaderFactories() { return java.util.Arrays.asList( new JavaSimpleDataReaderFactory(0, 5), new JavaSimpleDataReaderFactory(5, 10)); } } - static class JavaSimpleDataReaderFactory implements DataReaderFactory, DataReader { + static class JavaSimpleDataReaderFactory implements DataReaderFactory, DataReader { private int start; private int end; @@ -58,7 +59,12 @@ static class JavaSimpleDataReaderFactory implements DataReaderFactory, Data } @Override - public DataReader createDataReader() { + public DataFormat dataFormat() { + return DataFormat.ROW; + } + + @Override + public DataReader createRowDataReader() { return new JavaSimpleDataReaderFactory(start - 1, end); } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java index c3916e0b370b..211e9e5fefd4 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java @@ -21,6 +21,7 @@ import java.util.List; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.sources.v2.DataFormat; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.ReadSupport; @@ -29,7 +30,7 @@ public class JavaUnsafeRowDataSourceV2 implements DataSourceV2, ReadSupport { - class Reader implements DataSourceReader, SupportsScanUnsafeRow { + class Reader implements DataSourceReader { private final StructType schema = new StructType().add("i", "int").add("j", "int"); @Override @@ -38,7 +39,7 @@ public StructType readSchema() { } @Override - public List> createUnsafeRowReaderFactories() { + public List createDataReaderFactories() { return java.util.Arrays.asList( new 
JavaUnsafeRowDataReaderFactory(0, 5), new JavaUnsafeRowDataReaderFactory(5, 10)); @@ -46,7 +47,7 @@ public List> createUnsafeRowReaderFactories() { } static class JavaUnsafeRowDataReaderFactory - implements DataReaderFactory, DataReader { + implements DataReaderFactory, DataReader { private int start; private int end; private UnsafeRow row; @@ -59,7 +60,12 @@ static class JavaUnsafeRowDataReaderFactory } @Override - public DataReader createDataReader() { + public DataFormat dataFormat() { + return DataFormat.UNSAFE_ROW; + } + + @Override + public DataReader createUnsafeRowDataReader() { return new JavaUnsafeRowDataReaderFactory(start - 1, end); } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala index ff14ec38e66a..9d1a0b2f42ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -24,8 +24,7 @@ import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer -import org.apache.spark.sql.{AnalysisException, Row, SparkSession} -import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ @@ -144,7 +143,7 @@ class RateSourceSuite extends StreamTest { reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) val tasks = reader.createDataReaderFactories() assert(tasks.size == 1) - val dataReader = tasks.get(0).createDataReader() + val dataReader = tasks.get(0).createRowDataReader() val data = ArrayBuffer[Row]() while (dataReader.next()) { data.append(dataReader.get()) @@ -163,7 +162,7 @@ class RateSourceSuite extends StreamTest { assert(tasks.size == 11) val readData = tasks.asScala - .map(_.createDataReader()) + .map(_.createRowDataReader()) .flatMap { reader => val buf = scala.collection.mutable.ListBuffer[Row]() while (reader.next()) buf.append(reader.get()) @@ -314,7 +313,7 @@ class RateSourceSuite extends StreamTest { .asInstanceOf[RateStreamOffset] .partitionToValueAndRunTimeMs(t.partitionIndex) .runTimeMs - val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader] + val r = t.createRowDataReader().asInstanceOf[RateStreamContinuousDataReader] for (rowIndex <- 0 to 9) { r.next() data.append(r.get()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index e0a53272cd22..518a31cc1cd8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -346,7 +346,7 @@ class SimpleSinglePartitionSource extends DataSourceV2 with ReadSupport { class Reader extends DataSourceReader { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { + override def createDataReaderFactories(): JList[DataReaderFactory] = { java.util.Arrays.asList(new SimpleDataReaderFactory(0, 5)) } } @@ -359,7 +359,7 @@ class 
SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { class Reader extends DataSourceReader { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { + override def createDataReaderFactories(): JList[DataReaderFactory] = { java.util.Arrays.asList(new SimpleDataReaderFactory(0, 5), new SimpleDataReaderFactory(5, 10)) } } @@ -368,11 +368,13 @@ class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { } class SimpleDataReaderFactory(start: Int, end: Int) - extends DataReaderFactory[Row] + extends DataReaderFactory with DataReader[Row] { private var current = start - 1 - override def createDataReader(): DataReader[Row] = new SimpleDataReaderFactory(start, end) + override def dataFormat(): DataFormat = DataFormat.ROW + + override def createRowDataReader(): DataReader[Row] = new SimpleDataReaderFactory(start, end) override def next(): Boolean = { current += 1 @@ -413,12 +415,12 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { requiredSchema } - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { + override def createDataReaderFactories(): JList[DataReaderFactory] = { val lowerBound = filters.collect { case GreaterThan("i", v: Int) => v }.headOption - val res = new ArrayList[DataReaderFactory[Row]] + val res = new ArrayList[DataReaderFactory] if (lowerBound.isEmpty) { res.add(new AdvancedDataReaderFactory(0, 5, requiredSchema)) @@ -438,11 +440,13 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { } class AdvancedDataReaderFactory(start: Int, end: Int, requiredSchema: StructType) - extends DataReaderFactory[Row] with DataReader[Row] { + extends DataReaderFactory with DataReader[Row] { private var current = start - 1 - override def createDataReader(): DataReader[Row] = { + override def dataFormat(): DataFormat = DataFormat.ROW + + override def createRowDataReader(): DataReader[Row] = { new AdvancedDataReaderFactory(start, end, requiredSchema) } @@ -465,10 +469,10 @@ class AdvancedDataReaderFactory(start: Int, end: Int, requiredSchema: StructType class UnsafeRowDataSourceV2 extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceReader with SupportsScanUnsafeRow { + class Reader extends DataSourceReader { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def createUnsafeRowReaderFactories(): JList[DataReaderFactory[UnsafeRow]] = { + override def createDataReaderFactories(): JList[DataReaderFactory] = { java.util.Arrays.asList(new UnsafeRowDataReaderFactory(0, 5), new UnsafeRowDataReaderFactory(5, 10)) } @@ -478,14 +482,16 @@ class UnsafeRowDataSourceV2 extends DataSourceV2 with ReadSupport { } class UnsafeRowDataReaderFactory(start: Int, end: Int) - extends DataReaderFactory[UnsafeRow] with DataReader[UnsafeRow] { + extends DataReaderFactory with DataReader[UnsafeRow] { private val row = new UnsafeRow(2) row.pointTo(new Array[Byte](8 * 3), 8 * 3) private var current = start - 1 - override def createDataReader(): DataReader[UnsafeRow] = this + override def dataFormat(): DataFormat = DataFormat.UNSAFE_ROW + + override def createUnsafeRowDataReader(): DataReader[UnsafeRow] = this override def next(): Boolean = { current += 1 @@ -503,7 +509,7 @@ class UnsafeRowDataReaderFactory(start: Int, end: Int) class SchemaRequiredDataSource extends DataSourceV2 with ReadSupportWithSchema { class Reader(val readSchema: StructType) extends DataSourceReader { - 
override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = + override def createDataReaderFactories(): JList[DataReaderFactory] = java.util.Collections.emptyList() } @@ -513,10 +519,10 @@ class SchemaRequiredDataSource extends DataSourceV2 with ReadSupportWithSchema { class BatchDataSourceV2 extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceReader with SupportsScanColumnarBatch { + class Reader extends DataSourceReader { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def createBatchDataReaderFactories(): JList[DataReaderFactory[ColumnarBatch]] = { + override def createDataReaderFactories(): JList[DataReaderFactory] = { java.util.Arrays.asList(new BatchDataReaderFactory(0, 50), new BatchDataReaderFactory(50, 90)) } } @@ -525,7 +531,7 @@ class BatchDataSourceV2 extends DataSourceV2 with ReadSupport { } class BatchDataReaderFactory(start: Int, end: Int) - extends DataReaderFactory[ColumnarBatch] with DataReader[ColumnarBatch] { + extends DataReaderFactory with DataReader[ColumnarBatch] { private final val BATCH_SIZE = 20 private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType) @@ -534,7 +540,9 @@ class BatchDataReaderFactory(start: Int, end: Int) private var current = start - override def createDataReader(): DataReader[ColumnarBatch] = this + override def dataFormat(): DataFormat = DataFormat.COLUMNAR_BATCH + + override def createColumnarBatchDataReader(): DataReader[ColumnarBatch] = this override def next(): Boolean = { i.reset() @@ -568,7 +576,7 @@ class PartitionAwareDataSource extends DataSourceV2 with ReadSupport { class Reader extends DataSourceReader with SupportsReportPartitioning { override def readSchema(): StructType = new StructType().add("a", "int").add("b", "int") - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { + override def createDataReaderFactories(): JList[DataReaderFactory] = { // Note that we don't have same value of column `a` across partitions. 
java.util.Arrays.asList( new SpecificDataReaderFactory(Array(1, 1, 3), Array(4, 4, 6)), @@ -591,13 +599,15 @@ class PartitionAwareDataSource extends DataSourceV2 with ReadSupport { } class SpecificDataReaderFactory(i: Array[Int], j: Array[Int]) - extends DataReaderFactory[Row] + extends DataReaderFactory with DataReader[Row] { assert(i.length == j.length) private var current = -1 - override def createDataReader(): DataReader[Row] = this + override def dataFormat(): DataFormat = DataFormat.ROW + + override def createRowDataReader(): DataReader[Row] = this override def next(): Boolean = { current += 1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index a5007fa32135..9515d6370903 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -45,7 +45,7 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS class Reader(path: String, conf: Configuration) extends DataSourceReader { override def readSchema(): StructType = schema - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { + override def createDataReaderFactories(): JList[DataReaderFactory] = { val dataPath = new Path(path) val fs = dataPath.getFileSystem(conf) if (fs.exists(dataPath)) { @@ -56,7 +56,7 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS val serializableConf = new SerializableConfiguration(conf) new SimpleCSVDataReaderFactory( f.getPath.toUri.toString, - serializableConf): DataReaderFactory[Row] + serializableConf): DataReaderFactory }.toList.asJava } else { Collections.emptyList() @@ -157,13 +157,15 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS } class SimpleCSVDataReaderFactory(path: String, conf: SerializableConfiguration) - extends DataReaderFactory[Row] with DataReader[Row] { + extends DataReaderFactory with DataReader[Row] { @transient private var lines: Iterator[String] = _ @transient private var currentLine: String = _ @transient private var inputStream: FSDataInputStream = _ - override def createDataReader(): DataReader[Row] = { + override def dataFormat(): DataFormat = DataFormat.ROW + + override def createRowDataReader(): DataReader[Row] = { val filePath = new Path(path) val fs = filePath.getFileSystem(conf.value) inputStream = fs.open(filePath) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 20942ed93897..50887521de94 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -30,7 +30,6 @@ import org.scalatest.mockito.MockitoSugar import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset} -import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.TestForeachWriter import org.apache.spark.sql.functions._ @@ -227,10 +226,10 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } // getBatch should take 100 ms the first time it is called - override def 
createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { + override def createDataReaderFactories(): ju.List[DataReaderFactory] = { synchronized { clock.waitTillTime(1350) - super.createUnsafeRowReaderFactories() + super.createDataReaderFactories() } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index af4618bed545..47f416c4a3b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -44,7 +44,7 @@ case class FakeReader() extends MicroBatchReader with ContinuousReader { def mergeOffsets(offsets: Array[PartitionOffset]): Offset = RateStreamOffset(Map()) def setStartOffset(start: Optional[Offset]): Unit = {} - def createDataReaderFactories(): java.util.ArrayList[DataReaderFactory[Row]] = { + def createDataReaderFactories(): java.util.ArrayList[DataReaderFactory] = { throw new IllegalStateException("fake source - cannot actually read") } }
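
The sketches below are illustrative only and are not part of the patch. This first one pulls the recurring pattern of the diff into a minimal end-to-end ROW-format source: an untyped DataReaderFactory that advertises its output through dataFormat() and implements only the matching createRowDataReader(). Class names such as ExampleRowSource are hypothetical, and the ReadSupport.createReader signature is assumed to be unchanged from the existing API.

import java.util.{ArrayList => JArrayList, List => JList}

import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.{DataFormat, DataSourceOptions, DataSourceV2, ReadSupport}
import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, DataSourceReader}
import org.apache.spark.sql.types.StructType

// Hypothetical two-partition source producing rows (i, -i) for i in [0, 10).
class ExampleRowSource extends DataSourceV2 with ReadSupport {

  class Reader extends DataSourceReader {
    override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int")

    // The reader now returns untyped factories; each factory advertises its own format.
    override def createDataReaderFactories(): JList[DataReaderFactory] = {
      val factories = new JArrayList[DataReaderFactory]()
      factories.add(new ExampleRowReaderFactory(0, 5))
      factories.add(new ExampleRowReaderFactory(5, 10))
      factories
    }
  }

  override def createReader(options: DataSourceOptions): DataSourceReader = new Reader
}

// A ROW-format factory: dataFormat() tells the caller which create method applies,
// and only createRowDataReader() is overridden here.
class ExampleRowReaderFactory(start: Int, end: Int)
  extends DataReaderFactory with DataReader[Row] {

  private var current = start - 1

  override def dataFormat(): DataFormat = DataFormat.ROW

  override def createRowDataReader(): DataReader[Row] = new ExampleRowReaderFactory(start, end)

  override def next(): Boolean = {
    current += 1
    current < end
  }

  override def get(): Row = Row(current, -current)

  override def close(): Unit = {}
}

Compared with the old API, the factory no longer carries a type parameter: the scan format becomes a runtime property, which is what lets createDataReaderFactories() return a single JList[DataReaderFactory] for every flavour of scan in this patch.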
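
For comparison, a sketch of the UNSAFE_ROW path in the style of UnsafeRowDataReaderFactory above. Only the declared format and the overridden create method change; the setInt calls that fill the row are my assumption about the reader body, which this excerpt truncates.

import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.sources.v2.DataFormat
import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory}

// Hypothetical UNSAFE_ROW factory serving a fixed-size two-int row.
class ExampleUnsafeRowReaderFactory(start: Int, end: Int)
  extends DataReaderFactory with DataReader[UnsafeRow] {

  // 2 int fields: 8-byte null bitset plus two 8-byte words.
  private val row = new UnsafeRow(2)
  row.pointTo(new Array[Byte](8 * 3), 8 * 3)

  private var current = start - 1

  override def dataFormat(): DataFormat = DataFormat.UNSAFE_ROW

  override def createUnsafeRowDataReader(): DataReader[UnsafeRow] = this

  override def next(): Boolean = {
    current += 1
    current < end
  }

  override def get(): UnsafeRow = {
    row.setInt(0, current)
    row.setInt(1, -current)
    row
  }

  override def close(): Unit = {}
}

Only the create method matching the declared format is implemented; presumably the remaining create methods keep default (failing) implementations, since none of the classes touched by this patch override more than one of them.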
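
A sketch of the COLUMNAR_BATCH path in the spirit of BatchDataReaderFactory: the factory declares DataFormat.COLUMNAR_BATCH and fills two OnHeapColumnVectors per batch. The ColumnarBatch wiring (constructor, setNumRows, close) follows the existing vectorized API and is an assumption where this excerpt cuts the original class short.

import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.sources.v2.DataFormat
import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory}
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.vectorized.ColumnarBatch

// Hypothetical COLUMNAR_BATCH factory emitting batches of (i, -i) pairs.
class ExampleBatchReaderFactory(start: Int, end: Int)
  extends DataReaderFactory with DataReader[ColumnarBatch] {

  private final val BATCH_SIZE = 20
  private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType)
  private lazy val j = new OnHeapColumnVector(BATCH_SIZE, IntegerType)
  private lazy val batch = new ColumnarBatch(Array(i, j))

  private var current = start

  override def dataFormat(): DataFormat = DataFormat.COLUMNAR_BATCH

  override def createColumnarBatchDataReader(): DataReader[ColumnarBatch] = this

  override def next(): Boolean = {
    i.reset()
    j.reset()
    var count = 0
    while (current < end && count < BATCH_SIZE) {
      i.putInt(count, current)
      j.putInt(count, -current)
      current += 1
      count += 1
    }
    if (count == 0) {
      false
    } else {
      batch.setNumRows(count)
      true
    }
  }

  override def get(): ColumnarBatch = batch

  override def close(): Unit = batch.close()
}

The same ColumnarBatch instance is reused across next() calls, which is why each call starts by resetting the vectors, matching the reset() pattern visible in BatchDataReaderFactory above.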
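
Finally, a sketch of the consumer side in the spirit of the updated RateSourceSuite assertions: a caller first checks dataFormat() and then invokes the matching create method. The helper name and the fail-fast behaviour for non-ROW factories are assumptions.

import scala.collection.JavaConverters._

import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.DataFormat
import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory}

object RowScanTestHelper {

  // Drains every ROW-format factory in the list and returns the rows produced.
  def collectRows(factories: java.util.List[DataReaderFactory]): Seq[Row] = {
    factories.asScala.flatMap { factory =>
      require(factory.dataFormat() == DataFormat.ROW,
        s"expected a ROW factory but got ${factory.dataFormat()}")
      val reader: DataReader[Row] = factory.createRowDataReader()
      try {
        val rows = scala.collection.mutable.ListBuffer[Row]()
        while (reader.next()) rows += reader.get()
        rows
      } finally {
        reader.close()
      }
    }
  }
}

With a helper like this, the hand-rolled while loops in the rate source tests above could collapse to collectRows(reader.createDataReaderFactories()).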