diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReadSupport.scala similarity index 74% rename from external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala rename to external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReadSupport.scala index be7ce3b3ed75..4a18839e6a77 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReadSupport.scala @@ -25,16 +25,15 @@ import org.apache.kafka.common.TopicPartition import org.apache.spark.TaskContext import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming._ import org.apache.spark.sql.types.StructType /** - * A [[ContinuousReader]] for data from kafka. + * A [[ContinuousReadSupport]] for data from kafka. * * @param offsetReader a reader used to get kafka offsets. Note that the actual data will be * read by per-task consumers generated later. @@ -47,70 +46,49 @@ import org.apache.spark.sql.types.StructType * scenarios, where some offsets after the specified initial ones can't be * properly read. */ -class KafkaContinuousReader( +class KafkaContinuousReadSupport( offsetReader: KafkaOffsetReader, kafkaParams: ju.Map[String, Object], sourceOptions: Map[String, String], metadataPath: String, initialOffsets: KafkaOffsetRangeLimit, failOnDataLoss: Boolean) - extends ContinuousReader with Logging { - - private lazy val session = SparkSession.getActiveSession.get - private lazy val sc = session.sparkContext + extends ContinuousReadSupport with Logging { private val pollTimeoutMs = sourceOptions.getOrElse("kafkaConsumer.pollTimeoutMs", "512").toLong - // Initialized when creating reader factories. If this diverges from the partitions at the latest - // offsets, we need to reconfigure. - // Exposed outside this object only for unit tests. 
- @volatile private[sql] var knownPartitions: Set[TopicPartition] = _ - - override def readSchema: StructType = KafkaOffsetReader.kafkaSchema - - private var offset: Offset = _ - override def setStartOffset(start: ju.Optional[Offset]): Unit = { - offset = start.orElse { - val offsets = initialOffsets match { - case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets()) - case LatestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchLatestOffsets()) - case SpecificOffsetRangeLimit(p) => offsetReader.fetchSpecificOffsets(p, reportDataLoss) - } - logInfo(s"Initial offsets: $offsets") - offsets + override def initialOffset(): Offset = { + val offsets = initialOffsets match { + case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets()) + case LatestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchLatestOffsets()) + case SpecificOffsetRangeLimit(p) => offsetReader.fetchSpecificOffsets(p, reportDataLoss) } + logInfo(s"Initial offsets: $offsets") + offsets } - override def getStartOffset(): Offset = offset + override def fullSchema(): StructType = KafkaOffsetReader.kafkaSchema + + override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = { + new KafkaContinuousScanConfigBuilder(fullSchema(), start, offsetReader, reportDataLoss) + } override def deserializeOffset(json: String): Offset = { KafkaSourceOffset(JsonUtils.partitionOffsets(json)) } - override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = { - import scala.collection.JavaConverters._ - - val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset) - - val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet - val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet) - val newPartitionOffsets = offsetReader.fetchEarliestOffsets(newPartitions.toSeq) - - val deletedPartitions = oldStartPartitionOffsets.keySet.diff(currentPartitionSet) - if (deletedPartitions.nonEmpty) { - reportDataLoss(s"Some partitions were deleted: $deletedPartitions") - } - - val startOffsets = newPartitionOffsets ++ - oldStartPartitionOffsets.filterKeys(!deletedPartitions.contains(_)) - knownPartitions = startOffsets.keySet - + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val startOffsets = config.asInstanceOf[KafkaContinuousScanConfig].startOffsets startOffsets.toSeq.map { case (topicPartition, start) => KafkaContinuousInputPartition( - topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss - ): InputPartition[InternalRow] - }.asJava + topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss) + }.toArray + } + + override def createContinuousReaderFactory( + config: ScanConfig): ContinuousPartitionReaderFactory = { + KafkaContinuousReaderFactory } /** Stop this source and free any resources it has allocated. 
*/ @@ -127,8 +105,9 @@ class KafkaContinuousReader( KafkaSourceOffset(mergedMap) } - override def needsReconfiguration(): Boolean = { - knownPartitions != null && offsetReader.fetchLatestOffsets().keySet != knownPartitions + override def needsReconfiguration(config: ScanConfig): Boolean = { + val knownPartitions = config.asInstanceOf[KafkaContinuousScanConfig].knownPartitions + offsetReader.fetchLatestOffsets().keySet != knownPartitions } override def toString(): String = s"KafkaSource[$offsetReader]" @@ -162,23 +141,51 @@ case class KafkaContinuousInputPartition( startOffset: Long, kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends ContinuousInputPartition[InternalRow] { - - override def createContinuousReader( - offset: PartitionOffset): InputPartitionReader[InternalRow] = { - val kafkaOffset = offset.asInstanceOf[KafkaSourcePartitionOffset] - require(kafkaOffset.topicPartition == topicPartition, - s"Expected topicPartition: $topicPartition, but got: ${kafkaOffset.topicPartition}") - new KafkaContinuousInputPartitionReader( - topicPartition, kafkaOffset.partitionOffset, kafkaParams, pollTimeoutMs, failOnDataLoss) + failOnDataLoss: Boolean) extends InputPartition + +object KafkaContinuousReaderFactory extends ContinuousPartitionReaderFactory { + override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = { + val p = partition.asInstanceOf[KafkaContinuousInputPartition] + new KafkaContinuousPartitionReader( + p.topicPartition, p.startOffset, p.kafkaParams, p.pollTimeoutMs, p.failOnDataLoss) } +} + +class KafkaContinuousScanConfigBuilder( + schema: StructType, + startOffset: Offset, + offsetReader: KafkaOffsetReader, + reportDataLoss: String => Unit) + extends ScanConfigBuilder { + + override def build(): ScanConfig = { + val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(startOffset) + + val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet + val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet) + val newPartitionOffsets = offsetReader.fetchEarliestOffsets(newPartitions.toSeq) - override def createPartitionReader(): KafkaContinuousInputPartitionReader = { - new KafkaContinuousInputPartitionReader( - topicPartition, startOffset, kafkaParams, pollTimeoutMs, failOnDataLoss) + val deletedPartitions = oldStartPartitionOffsets.keySet.diff(currentPartitionSet) + if (deletedPartitions.nonEmpty) { + reportDataLoss(s"Some partitions were deleted: $deletedPartitions") + } + + val startOffsets = newPartitionOffsets ++ + oldStartPartitionOffsets.filterKeys(!deletedPartitions.contains(_)) + KafkaContinuousScanConfig(schema, startOffsets) } } +case class KafkaContinuousScanConfig( + readSchema: StructType, + startOffsets: Map[TopicPartition, Long]) + extends ScanConfig { + + // Created when building the scan config builder. If this diverges from the partitions at the + // latest offsets, we need to reconfigure the kafka read support. + def knownPartitions: Set[TopicPartition] = startOffsets.keySet +} + /** * A per-task data reader for continuous Kafka processing. * @@ -189,12 +196,12 @@ case class KafkaContinuousInputPartition( * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets * are skipped. 
*/ -class KafkaContinuousInputPartitionReader( +class KafkaContinuousPartitionReader( topicPartition: TopicPartition, startOffset: Long, kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends ContinuousInputPartitionReader[InternalRow] { + failOnDataLoss: Boolean) extends ContinuousPartitionReader[InternalRow] { private val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams, useCache = false) private val converter = new KafkaRecordToUnsafeRowConverter diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala similarity index 83% rename from external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala rename to external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala index 900c9f4e7fbf..c31af60b8a1c 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala @@ -21,8 +21,6 @@ import java.{util => ju} import java.io._ import java.nio.charset.StandardCharsets -import scala.collection.JavaConverters._ - import org.apache.commons.io.IOUtils import org.apache.kafka.common.TopicPartition @@ -32,16 +30,17 @@ import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset} +import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset, SimpleStreamingScanConfig, SimpleStreamingScanConfigBuilder} +import org.apache.spark.sql.execution.streaming.sources.RateControlMicroBatchReadSupport import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.sources.v2.{CustomMetrics, DataSourceOptions} -import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset, SupportsCustomReaderMetrics} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset, SupportsCustomReaderMetrics} import org.apache.spark.sql.types.StructType import org.apache.spark.util.UninterruptibleThread /** - * A [[MicroBatchReader]] that reads data from Kafka. + * A [[MicroBatchReadSupport]] that reads data from Kafka. * * The [[KafkaSourceOffset]] is the custom [[Offset]] defined for this source that contains * a map of TopicPartition -> offset. Note that this offset is 1 + (available offset). For @@ -56,17 +55,14 @@ import org.apache.spark.util.UninterruptibleThread * To avoid this issue, you should make sure stopping the query before stopping the Kafka brokers * and not use wrong broker addresses. 
*/ -private[kafka010] class KafkaMicroBatchReader( +private[kafka010] class KafkaMicroBatchReadSupport( kafkaOffsetReader: KafkaOffsetReader, executorKafkaParams: ju.Map[String, Object], options: DataSourceOptions, metadataPath: String, startingOffsets: KafkaOffsetRangeLimit, failOnDataLoss: Boolean) - extends MicroBatchReader with SupportsCustomReaderMetrics with Logging { - - private var startPartitionOffsets: PartitionOffsetMap = _ - private var endPartitionOffsets: PartitionOffsetMap = _ + extends RateControlMicroBatchReadSupport with SupportsCustomReaderMetrics with Logging { private val pollTimeoutMs = options.getLong( "kafkaConsumer.pollTimeoutMs", @@ -76,34 +72,40 @@ private[kafka010] class KafkaMicroBatchReader( Option(options.get("maxOffsetsPerTrigger").orElse(null)).map(_.toLong) private val rangeCalculator = KafkaOffsetRangeCalculator(options) + + private var endPartitionOffsets: KafkaSourceOffset = _ + /** * Lazily initialize `initialPartitionOffsets` to make sure that `KafkaConsumer.poll` is only * called in StreamExecutionThread. Otherwise, interrupting a thread while running * `KafkaConsumer.poll` may hang forever (KAFKA-1894). */ - private lazy val initialPartitionOffsets = getOrCreateInitialPartitionOffsets() - - override def setOffsetRange(start: ju.Optional[Offset], end: ju.Optional[Offset]): Unit = { - // Make sure initialPartitionOffsets is initialized - initialPartitionOffsets - - startPartitionOffsets = Option(start.orElse(null)) - .map(_.asInstanceOf[KafkaSourceOffset].partitionToOffsets) - .getOrElse(initialPartitionOffsets) - - endPartitionOffsets = Option(end.orElse(null)) - .map(_.asInstanceOf[KafkaSourceOffset].partitionToOffsets) - .getOrElse { - val latestPartitionOffsets = kafkaOffsetReader.fetchLatestOffsets() - maxOffsetsPerTrigger.map { maxOffsets => - rateLimit(maxOffsets, startPartitionOffsets, latestPartitionOffsets) - }.getOrElse { - latestPartitionOffsets - } - } + override def initialOffset(): Offset = { + KafkaSourceOffset(getOrCreateInitialPartitionOffsets()) + } + + override def latestOffset(start: Offset): Offset = { + val startPartitionOffsets = start.asInstanceOf[KafkaSourceOffset].partitionToOffsets + val latestPartitionOffsets = kafkaOffsetReader.fetchLatestOffsets() + endPartitionOffsets = KafkaSourceOffset(maxOffsetsPerTrigger.map { maxOffsets => + rateLimit(maxOffsets, startPartitionOffsets, latestPartitionOffsets) + }.getOrElse { + latestPartitionOffsets + }) + endPartitionOffsets + } + + override def fullSchema(): StructType = KafkaOffsetReader.kafkaSchema + + override def newScanConfigBuilder(start: Offset, end: Offset): ScanConfigBuilder = { + new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end)) } - override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = { + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val sc = config.asInstanceOf[SimpleStreamingScanConfig] + val startPartitionOffsets = sc.start.asInstanceOf[KafkaSourceOffset].partitionToOffsets + val endPartitionOffsets = sc.end.get.asInstanceOf[KafkaSourceOffset].partitionToOffsets + // Find the new partitions, and get their earliest offsets val newPartitions = endPartitionOffsets.keySet.diff(startPartitionOffsets.keySet) val newPartitionInitialOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq) @@ -145,30 +147,26 @@ private[kafka010] class KafkaMicroBatchReader( // Generate factories based on the offset ranges offsetRanges.map { range => - new KafkaMicroBatchInputPartition( - range, 
executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer - ): InputPartition[InternalRow] - }.asJava - } - - override def getStartOffset: Offset = { - KafkaSourceOffset(startPartitionOffsets) + KafkaMicroBatchInputPartition( + range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer) + }.toArray } - override def getEndOffset: Offset = { - KafkaSourceOffset(endPartitionOffsets) + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + KafkaMicroBatchReaderFactory } - override def getCustomMetrics: CustomMetrics = { - KafkaCustomMetrics(kafkaOffsetReader.fetchLatestOffsets(), endPartitionOffsets) + // TODO: figure out the life cycle of custom metrics, and make this method take `ScanConfig` as + // a parameter. + override def getCustomMetrics(): CustomMetrics = { + KafkaCustomMetrics( + kafkaOffsetReader.fetchLatestOffsets(), endPartitionOffsets.partitionToOffsets) } override def deserializeOffset(json: String): Offset = { KafkaSourceOffset(JsonUtils.partitionOffsets(json)) } - override def readSchema(): StructType = KafkaOffsetReader.kafkaSchema - override def commit(end: Offset): Unit = {} override def stop(): Unit = { @@ -311,22 +309,23 @@ private[kafka010] case class KafkaMicroBatchInputPartition( executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, failOnDataLoss: Boolean, - reuseKafkaConsumer: Boolean) extends InputPartition[InternalRow] { + reuseKafkaConsumer: Boolean) extends InputPartition - override def preferredLocations(): Array[String] = offsetRange.preferredLoc.toArray - - override def createPartitionReader(): InputPartitionReader[InternalRow] = - new KafkaMicroBatchInputPartitionReader(offsetRange, executorKafkaParams, pollTimeoutMs, - failOnDataLoss, reuseKafkaConsumer) +private[kafka010] object KafkaMicroBatchReaderFactory extends PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val p = partition.asInstanceOf[KafkaMicroBatchInputPartition] + KafkaMicroBatchPartitionReader(p.offsetRange, p.executorKafkaParams, p.pollTimeoutMs, + p.failOnDataLoss, p.reuseKafkaConsumer) + } } -/** A [[InputPartitionReader]] for reading Kafka data in a micro-batch streaming query. */ -private[kafka010] case class KafkaMicroBatchInputPartitionReader( +/** A [[PartitionReader]] for reading Kafka data in a micro-batch streaming query. 
*/ +private[kafka010] case class KafkaMicroBatchPartitionReader( offsetRange: KafkaOffsetRange, executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, failOnDataLoss: Boolean, - reuseKafkaConsumer: Boolean) extends InputPartitionReader[InternalRow] with Logging { + reuseKafkaConsumer: Boolean) extends PartitionReader[InternalRow] with Logging { private val consumer = KafkaDataConsumer.acquire( offsetRange.topicPartition, executorKafkaParams, reuseKafkaConsumer) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index d225c1ea6b7f..28c9853bfea9 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -30,9 +30,8 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousInputPartitionReader -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -46,9 +45,9 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with StreamSinkProvider with RelationProvider with CreatableRelationProvider - with StreamWriteSupport - with ContinuousReadSupport - with MicroBatchReadSupport + with StreamingWriteSupportProvider + with ContinuousReadSupportProvider + with MicroBatchReadSupportProvider with Logging { import KafkaSourceProvider._ @@ -108,13 +107,12 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } /** - * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader]] to read batches - * of Kafka data in a micro-batch streaming query. + * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReadSupport]] to read + * batches of Kafka data in a micro-batch streaming query. */ - override def createMicroBatchReader( - schema: Optional[StructType], + override def createMicroBatchReadSupport( metadataPath: String, - options: DataSourceOptions): KafkaMicroBatchReader = { + options: DataSourceOptions): KafkaMicroBatchReadSupport = { val parameters = options.asMap().asScala.toMap validateStreamOptions(parameters) @@ -140,7 +138,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister parameters, driverGroupIdPrefix = s"$uniqueGroupId-driver") - new KafkaMicroBatchReader( + new KafkaMicroBatchReadSupport( kafkaOffsetReader, kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId), options, @@ -150,13 +148,12 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } /** - * Creates a [[ContinuousInputPartitionReader]] to read + * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReadSupport]] to read * Kafka data in a continuous streaming query. 
*/ - override def createContinuousReader( - schema: Optional[StructType], + override def createContinuousReadSupport( metadataPath: String, - options: DataSourceOptions): KafkaContinuousReader = { + options: DataSourceOptions): KafkaContinuousReadSupport = { val parameters = options.asMap().asScala.toMap validateStreamOptions(parameters) // Each running query should use its own group id. Otherwise, the query may be only assigned @@ -181,7 +178,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister parameters, driverGroupIdPrefix = s"$uniqueGroupId-driver") - new KafkaContinuousReader( + new KafkaContinuousReadSupport( kafkaOffsetReader, kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId), parameters, @@ -270,11 +267,11 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } } - override def createStreamWriter( + override def createStreamingWriteSupport( queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceOptions): StreamWriter = { + options: DataSourceOptions): StreamingWriteSupport = { import scala.collection.JavaConverters._ val spark = SparkSession.getActiveSession.get @@ -285,7 +282,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister KafkaWriter.validateQuery( schema.toAttributes, new java.util.HashMap[String, Object](producerParams.asJava), topic) - new KafkaStreamWriter(topic, producerParams, schema) + new KafkaStreamingWriteSupport(topic, producerParams, schema) } private def strategy(caseInsensitiveParams: Map[String, String]) = diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWriteSupport.scala similarity index 91% rename from external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala rename to external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWriteSupport.scala index 5f0802b46603..dc19312f79a2 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWriteSupport.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} import org.apache.spark.sql.types.StructType /** @@ -33,20 +33,20 @@ import org.apache.spark.sql.types.StructType case object KafkaWriterCommitMessage extends WriterCommitMessage /** - * A [[StreamWriter]] for Kafka writing. Responsible for generating the writer factory. + * A [[StreamingWriteSupport]] for Kafka writing. Responsible for generating the writer factory. * * @param topic The topic this writer is responsible for. If None, topic will be inferred from * a `topic` field in the incoming data. * @param producerParams Parameters for Kafka producers in each task. * @param schema The schema of the input data. 
*/ -class KafkaStreamWriter( +class KafkaStreamingWriteSupport( topic: Option[String], producerParams: Map[String, String], schema: StructType) - extends StreamWriter { + extends StreamingWriteSupport { validateQuery(schema.toAttributes, producerParams.toMap[String, Object].asJava, topic) - override def createWriterFactory(): KafkaStreamWriterFactory = + override def createStreamingWriterFactory(): KafkaStreamWriterFactory = KafkaStreamWriterFactory(topic, producerParams, schema) override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} @@ -63,9 +63,9 @@ class KafkaStreamWriter( */ case class KafkaStreamWriterFactory( topic: Option[String], producerParams: Map[String, String], schema: StructType) - extends DataWriterFactory[InternalRow] { + extends StreamingDataWriterFactory { - override def createDataWriter( + override def createWriter( partitionId: Int, taskId: Long, epochId: Long): DataWriter[InternalRow] = { diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala index ea2a2a84d22c..321665042b8e 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -61,10 +61,12 @@ class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.logical.collectFirst { - case StreamingDataSourceV2Relation(_, _, _, r: KafkaContinuousReader) => r - }.exists { r => + case r: StreamingDataSourceV2Relation + if r.readSupport.isInstanceOf[KafkaContinuousReadSupport] => + r.scanConfigBuilder.build().asInstanceOf[KafkaContinuousScanConfig] + }.exists { config => // Ensure the new topic is present and the old topic is gone. 
- r.knownPartitions.exists(_.topic == topic2) + config.knownPartitions.exists(_.topic == topic2) }, s"query never reconfigured to new topic $topic2") } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala index fa1468a3943c..fa6bdc20bd4f 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala @@ -21,7 +21,7 @@ import java.util.concurrent.atomic.AtomicInteger import org.apache.spark.SparkContext import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd, SparkListenerTaskStart} -import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec import org.apache.spark.sql.execution.streaming.StreamExecution import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.streaming.Trigger @@ -46,8 +46,10 @@ trait KafkaContinuousTest extends KafkaSourceTest { testUtils.addPartitions(topic, newCount) eventually(timeout(streamingTimeout)) { assert( - query.lastExecution.logical.collectFirst { - case StreamingDataSourceV2Relation(_, _, _, r: KafkaContinuousReader) => r + query.lastExecution.executedPlan.collectFirst { + case scan: DataSourceV2ScanExec + if scan.readSupport.isInstanceOf[KafkaContinuousReadSupport] => + scan.scanConfig.asInstanceOf[KafkaContinuousScanConfig] }.exists(_.knownPartitions.size == newCount), s"query never reconfigured to $newCount partitions") } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index c7b74f305eed..946b636710f0 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.kafka010 import java.io._ import java.nio.charset.StandardCharsets.UTF_8 import java.nio.file.{Files, Paths} -import java.util.{Locale, Optional, Properties} +import java.util.{Locale, Properties} import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicInteger @@ -44,11 +44,9 @@ import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.kafka010.KafkaSourceProvider._ import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2} import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} -import org.apache.spark.sql.types.StructType abstract class KafkaSourceTest extends StreamTest with SharedSQLContext with KafkaTest { @@ -118,14 +116,16 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext with Kaf query.nonEmpty, "Cannot add data when there is no query for finding the active kafka source") - val sources = { + val sources: Seq[BaseStreamingSource] = { query.get.logicalPlan.collect { case 
StreamingExecutionRelation(source: KafkaSource, _) => source - case StreamingExecutionRelation(source: KafkaMicroBatchReader, _) => source + case StreamingExecutionRelation(source: KafkaMicroBatchReadSupport, _) => source } ++ (query.get.lastExecution match { case null => Seq() case e => e.logical.collect { - case StreamingDataSourceV2Relation(_, _, _, reader: KafkaContinuousReader) => reader + case r: StreamingDataSourceV2Relation + if r.readSupport.isInstanceOf[KafkaContinuousReadSupport] => + r.readSupport.asInstanceOf[KafkaContinuousReadSupport] } }) }.distinct @@ -650,7 +650,7 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { makeSureGetOffsetCalled, AssertOnQuery { query => query.logicalPlan.collect { - case StreamingExecutionRelation(_: KafkaMicroBatchReader, _) => true + case StreamingExecutionRelation(_: KafkaMicroBatchReadSupport, _) => true }.nonEmpty } ) @@ -675,17 +675,16 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { "kafka.bootstrap.servers" -> testUtils.brokerAddress, "subscribe" -> topic ) ++ Option(minPartitions).map { p => "minPartitions" -> p} - val reader = provider.createMicroBatchReader( - Optional.empty[StructType], dir.getAbsolutePath, new DataSourceOptions(options.asJava)) - reader.setOffsetRange( - Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 0L))), - Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 100L))) - ) - val factories = reader.planInputPartitions().asScala + val readSupport = provider.createMicroBatchReadSupport( + dir.getAbsolutePath, new DataSourceOptions(options.asJava)) + val config = readSupport.newScanConfigBuilder( + KafkaSourceOffset(Map(tp -> 0L)), + KafkaSourceOffset(Map(tp -> 100L))).build() + val inputPartitions = readSupport.planInputPartitions(config) .map(_.asInstanceOf[KafkaMicroBatchInputPartition]) - withClue(s"minPartitions = $minPartitions generated factories $factories\n\t") { - assert(factories.size == numPartitionsGenerated) - factories.foreach { f => assert(f.reuseKafkaConsumer == reusesConsumers) } + withClue(s"minPartitions = $minPartitions generated factories $inputPartitions\n\t") { + assert(inputPartitions.size == numPartitionsGenerated) + inputPartitions.foreach { f => assert(f.reuseKafkaConsumer == reusesConsumers) } } } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchReadSupportProvider.java similarity index 59% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchReadSupportProvider.java index 80ac08ee5ff5..f403dc619e86 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchReadSupportProvider.java @@ -18,48 +18,44 @@ package org.apache.spark.sql.sources.v2; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.DataSourceRegister; -import org.apache.spark.sql.sources.v2.reader.DataSourceReader; +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils; +import org.apache.spark.sql.sources.v2.reader.BatchReadSupport; import org.apache.spark.sql.types.StructType; /** * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * provide data reading ability and scan the data from the data source. + * provide data reading ability for batch processing. 
+ * + * This interface is used to create {@link BatchReadSupport} instances when end users run + * {@code SparkSession.read.format(...).option(...).load()}. */ @InterfaceStability.Evolving -public interface ReadSupport extends DataSourceV2 { +public interface BatchReadSupportProvider extends DataSourceV2 { /** - * Creates a {@link DataSourceReader} to scan the data from this data source. + * Creates a {@link BatchReadSupport} instance to load the data from this data source with a user + * specified schema, which is called by Spark at the beginning of each batch query. + * + * Spark will call this method at the beginning of each batch query to create a + * {@link BatchReadSupport} instance. * - * If this method fails (by throwing an exception), the action will fail and no Spark job will be - * submitted. + * By default this method throws {@link UnsupportedOperationException}, implementations should + * override this method to handle user specified schema. * * @param schema the user specified schema. * @param options the options for the returned data source reader, which is an immutable * case-insensitive string-to-string map. - * - * By default this method throws {@link UnsupportedOperationException}, implementations should - * override this method to handle user specified schema. */ - default DataSourceReader createReader(StructType schema, DataSourceOptions options) { - String name; - if (this instanceof DataSourceRegister) { - name = ((DataSourceRegister) this).shortName(); - } else { - name = this.getClass().getName(); - } - throw new UnsupportedOperationException(name + " does not support user specified schema"); + default BatchReadSupport createBatchReadSupport(StructType schema, DataSourceOptions options) { + return DataSourceV2Utils.failForUserSpecifiedSchema(this); } /** - * Creates a {@link DataSourceReader} to scan the data from this data source. - * - * If this method fails (by throwing an exception), the action will fail and no Spark job will be - * submitted. + * Creates a {@link BatchReadSupport} instance to scan the data from this data source, which is + * called by Spark at the beginning of each batch query. * * @param options the options for the returned data source reader, which is an immutable * case-insensitive string-to-string map. */ - DataSourceReader createReader(DataSourceOptions options); + BatchReadSupport createBatchReadSupport(DataSourceOptions options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchWriteSupportProvider.java similarity index 58% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchWriteSupportProvider.java index 048787a7a0a0..bd10c3353bf1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchWriteSupportProvider.java @@ -21,33 +21,39 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.SaveMode; -import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; +import org.apache.spark.sql.sources.v2.writer.BatchWriteSupport; import org.apache.spark.sql.types.StructType; /** * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * provide data writing ability and save the data to the data source. + * provide data writing ability for batch processing. 
+ * + * This interface is used to create {@link BatchWriteSupport} instances when end users run + * {@code Dataset.write.format(...).option(...).save()}. */ @InterfaceStability.Evolving -public interface WriteSupport extends DataSourceV2 { +public interface BatchWriteSupportProvider extends DataSourceV2 { /** - * Creates an optional {@link DataSourceWriter} to save the data to this data source. Data - * sources can return None if there is no writing needed to be done according to the save mode. + * Creates an optional {@link BatchWriteSupport} instance to save the data to this data source, + * which is called by Spark at the beginning of each batch query. * - * If this method fails (by throwing an exception), the action will fail and no Spark job will be - * submitted. + * Data sources can return None if there is no writing needed to be done according to the save + * mode. * - * @param writeUUID A unique string for the writing job. It's possible that there are many writing - * jobs running at the same time, and the returned {@link DataSourceWriter} can - * use this job id to distinguish itself from other jobs. + * @param queryId A unique string for the writing query. It's possible that there are many + * writing queries running at the same time, and the returned + * {@link BatchWriteSupport} can use this id to distinguish itself from others. * @param schema the schema of the data to be written. * @param mode the save mode which determines what to do when the data are already in this data * source, please refer to {@link SaveMode} for more details. * @param options the options for the returned data source writer, which is an immutable * case-insensitive string-to-string map. - * @return a writer to append data to this data source + * @return a write support to write data to this data source. */ - Optional createWriter( - String writeUUID, StructType schema, SaveMode mode, DataSourceOptions options); + Optional createBatchWriteSupport( + String queryId, + StructType schema, + SaveMode mode, + DataSourceOptions options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java deleted file mode 100644 index 7df5a451ae5f..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2; - -import java.util.Optional; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader; -import org.apache.spark.sql.types.StructType; - -/** - * A mix-in interface for {@link DataSourceV2}. 
Data sources can implement this interface to - * provide data reading ability for continuous stream processing. - */ -@InterfaceStability.Evolving -public interface ContinuousReadSupport extends DataSourceV2 { - /** - * Creates a {@link ContinuousReader} to scan the data from this data source. - * - * @param schema the user provided schema, or empty() if none was provided - * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure - * recovery. Readers for the same logical source in the same query - * will be given the same checkpointLocation. - * @param options the options for the returned data source reader, which is an immutable - * case-insensitive string-to-string map. - */ - ContinuousReader createContinuousReader( - Optional schema, - String checkpointLocation, - DataSourceOptions options); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupportProvider.java new file mode 100644 index 000000000000..824c290518ac --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupportProvider.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils; +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReadSupport; +import org.apache.spark.sql.types.StructType; + +/** + * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to + * provide data reading ability for continuous stream processing. + * + * This interface is used to create {@link ContinuousReadSupport} instances when end users run + * {@code SparkSession.readStream.format(...).option(...).load()} with a continuous trigger. + */ +@InterfaceStability.Evolving +public interface ContinuousReadSupportProvider extends DataSourceV2 { + + /** + * Creates a {@link ContinuousReadSupport} instance to scan the data from this streaming data + * source with a user specified schema, which is called by Spark at the beginning of each + * continuous streaming query. + * + * By default this method throws {@link UnsupportedOperationException}, implementations should + * override this method to handle user specified schema. + * + * @param schema the user provided schema. + * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure + * recovery. Readers for the same logical source in the same query + * will be given the same checkpointLocation. 
+ * @param options the options for the returned data source reader, which is an immutable + * case-insensitive string-to-string map. + */ + default ContinuousReadSupport createContinuousReadSupport( + StructType schema, + String checkpointLocation, + DataSourceOptions options) { + return DataSourceV2Utils.failForUserSpecifiedSchema(this); + } + + /** + * Creates a {@link ContinuousReadSupport} instance to scan the data from this streaming data + * source, which is called by Spark at the beginning of each continuous streaming query. + * + * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure + * recovery. Readers for the same logical source in the same query + * will be given the same checkpointLocation. + * @param options the options for the returned data source reader, which is an immutable + * case-insensitive string-to-string map. + */ + ContinuousReadSupport createContinuousReadSupport( + String checkpointLocation, + DataSourceOptions options); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java index 6234071320dc..6e31e84bf6c7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java @@ -22,9 +22,13 @@ /** * The base interface for data source v2. Implementations must have a public, 0-arg constructor. * - * Note that this is an empty interface. Data source implementations should mix-in at least one of - * the plug-in interfaces like {@link ReadSupport} and {@link WriteSupport}. Otherwise it's just - * a dummy data source which is un-readable/writable. + * Note that this is an empty interface. Data source implementations must mix in interfaces such as + * {@link BatchReadSupportProvider} or {@link BatchWriteSupportProvider}, which can provide + * batch or streaming read/write support instances. Otherwise it's just a dummy data source which + * is un-readable/writable. + * + * If Spark fails to execute any methods in the implementations of this interface (by throwing an + * exception), the read action will fail and no Spark job will be submitted. */ @InterfaceStability.Evolving public interface DataSourceV2 {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java deleted file mode 100644 index 7f4a2c9593c7..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.sources.v2; - -import java.util.Optional; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader; -import org.apache.spark.sql.types.StructType; - -/** - * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * provide streaming micro-batch data reading ability. - */ -@InterfaceStability.Evolving -public interface MicroBatchReadSupport extends DataSourceV2 { - /** - * Creates a {@link MicroBatchReader} to read batches of data from this data source in a - * streaming query. - * - * The execution engine will create a micro-batch reader at the start of a streaming query, - * alternate calls to setOffsetRange and planInputPartitions for each batch to process, and - * then call stop() when the execution is complete. Note that a single query may have multiple - * executions due to restart or failure recovery. - * - * @param schema the user provided schema, or empty() if none was provided - * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure - * recovery. Readers for the same logical source in the same query - * will be given the same checkpointLocation. - * @param options the options for the returned data source reader, which is an immutable - * case-insensitive string-to-string map. - */ - MicroBatchReader createMicroBatchReader( - Optional schema, - String checkpointLocation, - DataSourceOptions options); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java new file mode 100644 index 000000000000..61c08e7fa89d --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils; +import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReadSupport; +import org.apache.spark.sql.types.StructType; + +/** + * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to + * provide data reading ability for micro-batch stream processing. + * + * This interface is used to create {@link MicroBatchReadSupport} instances when end users run + * {@code SparkSession.readStream.format(...).option(...).load()} with a micro-batch trigger. 
+ */ +@InterfaceStability.Evolving +public interface MicroBatchReadSupportProvider extends DataSourceV2 { + + /** + * Creates a {@link MicroBatchReadSupport} instance to scan the data from this streaming data + * source with a user specified schema, which is called by Spark at the beginning of each + * micro-batch streaming query. + * + * By default this method throws {@link UnsupportedOperationException}, implementations should + * override this method to handle user specified schema. + * + * @param schema the user provided schema. + * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure + * recovery. Readers for the same logical source in the same query + * will be given the same checkpointLocation. + * @param options the options for the returned data source reader, which is an immutable + * case-insensitive string-to-string map. + */ + default MicroBatchReadSupport createMicroBatchReadSupport( + StructType schema, + String checkpointLocation, + DataSourceOptions options) { + return DataSourceV2Utils.failForUserSpecifiedSchema(this); + } + + /** + * Creates a {@link MicroBatchReadSupport} instance to scan the data from this streaming data + * source, which is called by Spark at the beginning of each micro-batch streaming query. + * + * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure + * recovery. Readers for the same logical source in the same query + * will be given the same checkpointLocation. + * @param options the options for the returned data source reader, which is an immutable + * case-insensitive string-to-string map. + */ + MicroBatchReadSupport createMicroBatchReadSupport( + String checkpointLocation, + DataSourceOptions options); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java index 9d66805d79b9..bbe430e29926 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java @@ -27,10 +27,11 @@ @InterfaceStability.Evolving public interface SessionConfigSupport extends DataSourceV2 { - /** - * Key prefix of the session configs to propagate. Spark will extract all session configs that - * starts with `spark.datasource.$keyPrefix`, turn `spark.datasource.$keyPrefix.xxx -> yyy` - * into `xxx -> yyy`, and propagate them to all data source operations in this session. - */ - String keyPrefix(); + /** + * Key prefix of the session configs to propagate, which is usually the data source name. Spark + * will extract all session configs that starts with `spark.datasource.$keyPrefix`, turn + * `spark.datasource.$keyPrefix.xxx -> yyy` into `xxx -> yyy`, and propagate them to all + * data source operations in this session. + */ + String keyPrefix(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamWriteSupport.java deleted file mode 100644 index a77b01497269..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamWriteSupport.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.execution.streaming.BaseStreamingSink; -import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter; -import org.apache.spark.sql.streaming.OutputMode; -import org.apache.spark.sql.types.StructType; - -/** - * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * provide data writing ability for structured streaming. - */ -@InterfaceStability.Evolving -public interface StreamWriteSupport extends DataSourceV2, BaseStreamingSink { - - /** - * Creates an optional {@link StreamWriter} to save the data to this data source. Data - * sources can return None if there is no writing needed to be done. - * - * @param queryId A unique string for the writing query. It's possible that there are many - * writing queries running at the same time, and the returned - * {@link DataSourceWriter} can use this id to distinguish itself from others. - * @param schema the schema of the data to be written. - * @param mode the output mode which determines what successive epoch output means to this - * sink, please refer to {@link OutputMode} for more details. - * @param options the options for the returned data source writer, which is an immutable - * case-insensitive string-to-string map. - */ - StreamWriter createStreamWriter( - String queryId, - StructType schema, - OutputMode mode, - DataSourceOptions options); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java new file mode 100644 index 000000000000..f9ca85d8089b --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.execution.streaming.BaseStreamingSink; +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport; +import org.apache.spark.sql.streaming.OutputMode; +import org.apache.spark.sql.types.StructType; + +/** + * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to + * provide data writing ability for structured streaming. + * + * This interface is used to create {@link StreamingWriteSupport} instances when end users run + * {@code Dataset.writeStream.format(...).option(...).start()}. + */ +@InterfaceStability.Evolving +public interface StreamingWriteSupportProvider extends DataSourceV2, BaseStreamingSink { + + /** + * Creates a {@link StreamingWriteSupport} instance to save the data to this data source, which is + * called by Spark at the beginning of each streaming query. + * + * @param queryId A unique string for the writing query. It's possible that there are many + * writing queries running at the same time, and the returned + * {@link StreamingWriteSupport} can use this id to distinguish itself from others. + * @param schema the schema of the data to be written. + * @param mode the output mode which determines what successive epoch output means to this + * sink, please refer to {@link OutputMode} for more details. + * @param options the options for the returned data source writer, which is an immutable + * case-insensitive string-to-string map. + */ + StreamingWriteSupport createStreamingWriteSupport( + String queryId, + StructType schema, + OutputMode mode, + DataSourceOptions options); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchReadSupport.java new file mode 100644 index 000000000000..452ee86675b4 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchReadSupport.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * An interface that defines how to load the data from data source for batch processing. + * + * The execution engine will get an instance of this interface from a data source provider + * (e.g. {@link org.apache.spark.sql.sources.v2.BatchReadSupportProvider}) at the start of a batch + * query, then call {@link #newScanConfigBuilder()} and create an instance of {@link ScanConfig}. + * The {@link ScanConfigBuilder} can apply operator pushdown and keep the pushdown result in + * {@link ScanConfig}. 
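// Illustrative sketch (hypothetical names, not part of this patch): a StreamingWriteSupportProvider
// that hands a pre-built StreamingWriteSupport to Spark. How the StreamingWriteSupport actually
// writes data is out of scope here; it is injected through the constructor to keep the sketch minimal.
import org.apache.spark.sql.sources.v2.DataSourceOptions;
import org.apache.spark.sql.sources.v2.StreamingWriteSupportProvider;
import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport;
import org.apache.spark.sql.streaming.OutputMode;
import org.apache.spark.sql.types.StructType;

public class MyStreamingSink implements StreamingWriteSupportProvider {

  private final StreamingWriteSupport writeSupport;

  public MyStreamingSink(StreamingWriteSupport writeSupport) {
    this.writeSupport = writeSupport;
  }

  @Override
  public StreamingWriteSupport createStreamingWriteSupport(
      String queryId,
      StructType schema,
      OutputMode mode,
      DataSourceOptions options) {
    // A real sink would typically create a fresh instance per query, keyed by queryId,
    // and honor the requested OutputMode; this sketch simply returns the injected one.
    return writeSupport;
  }
}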
The {@link ScanConfig} will be used to create input partitions and reader + * factory to scan data from the data source with a Spark job. + */ +@InterfaceStability.Evolving +public interface BatchReadSupport extends ReadSupport { + + /** + * Returns a builder of {@link ScanConfig}. Spark will call this method and create a + * {@link ScanConfig} for each data scanning job. + * + * The builder can take some query specific information to do operators pushdown, and keep these + * information in the created {@link ScanConfig}. + * + * This is the first step of the data scan. All other methods in {@link BatchReadSupport} needs + * to take {@link ScanConfig} as an input. + */ + ScanConfigBuilder newScanConfigBuilder(); + + /** + * Returns a factory, which produces one {@link PartitionReader} for one {@link InputPartition}. + */ + PartitionReaderFactory createReaderFactory(ScanConfig config); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java deleted file mode 100644 index da98fab1284e..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.reader; - -import java.util.List; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.ReadSupport; -import org.apache.spark.sql.types.StructType; - -/** - * A data source reader that is returned by - * {@link ReadSupport#createReader(DataSourceOptions)} or - * {@link ReadSupport#createReader(StructType, DataSourceOptions)}. - * It can mix in various query optimization interfaces to speed up the data scan. The actual scan - * logic is delegated to {@link InputPartition}s, which are returned by - * {@link #planInputPartitions()}. - * - * There are mainly 3 kinds of query optimizations: - * 1. Operators push-down. E.g., filter push-down, required columns push-down(aka column - * pruning), etc. Names of these interfaces start with `SupportsPushDown`. - * 2. Information Reporting. E.g., statistics reporting, ordering reporting, etc. - * Names of these interfaces start with `SupportsReporting`. - * 3. Columnar scan if implements {@link SupportsScanColumnarBatch}. - * - * If an exception was throw when applying any of these query optimizations, the action will fail - * and no Spark job will be submitted. - * - * Spark first applies all operator push-down optimizations that this data source supports. 
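// Illustrative end-to-end sketch (hypothetical names, not part of this patch) of the
// BatchReadSupport contract above: a scan over a fixed integer range. The ScanConfig, the input
// partitions, the reader factory and the readers are all trivial; pushdown is ignored, so the
// ScanConfig only carries the schema.
import java.io.IOException;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.sources.v2.reader.BatchReadSupport;
import org.apache.spark.sql.sources.v2.reader.InputPartition;
import org.apache.spark.sql.sources.v2.reader.PartitionReader;
import org.apache.spark.sql.sources.v2.reader.PartitionReaderFactory;
import org.apache.spark.sql.sources.v2.reader.ScanConfig;
import org.apache.spark.sql.sources.v2.reader.ScanConfigBuilder;
import org.apache.spark.sql.types.StructType;

public class RangeBatchReadSupport implements BatchReadSupport {

  private static final StructType SCHEMA = new StructType().add("i", "int");

  @Override
  public StructType fullSchema() {
    return SCHEMA;
  }

  @Override
  public ScanConfigBuilder newScanConfigBuilder() {
    // No pushdown in this sketch: build() simply freezes the full schema into a ScanConfig.
    ScanConfig config = () -> SCHEMA;   // ScanConfig has a single method, readSchema()
    return () -> config;                // ScanConfigBuilder has a single method, build()
  }

  @Override
  public InputPartition[] planInputPartitions(ScanConfig config) {
    // Two splits: [0, 5) and [5, 10). Each becomes one Spark task.
    return new InputPartition[] {new RangePartition(0, 5), new RangePartition(5, 10)};
  }

  @Override
  public PartitionReaderFactory createReaderFactory(ScanConfig config) {
    return new RangeReaderFactory();
  }

  // Serializable description of one split; it is shipped to executors.
  static class RangePartition implements InputPartition {
    final int start;
    final int end;
    RangePartition(int start, int end) {
      this.start = start;
      this.end = end;
    }
  }

  // Created on the driver, serialized, and used on executors to open one reader per partition.
  static class RangeReaderFactory implements PartitionReaderFactory {
    @Override
    public PartitionReader<InternalRow> createReader(InputPartition partition) {
      RangePartition p = (RangePartition) partition;
      return new PartitionReader<InternalRow>() {
        private int current = p.start - 1;

        @Override
        public boolean next() throws IOException {
          current += 1;
          return current < p.end;
        }

        @Override
        public InternalRow get() {
          return new GenericInternalRow(new Object[] {current});
        }

        @Override
        public void close() throws IOException {}
      };
    }
  }
}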
Then - * Spark collects information this data source reported for further optimizations. Finally Spark - * issues the scan request and does the actual data reading. - */ -@InterfaceStability.Evolving -public interface DataSourceReader { - - /** - * Returns the actual schema of this data source reader, which may be different from the physical - * schema of the underlying storage, as column pruning or other optimizations may happen. - * - * If this method fails (by throwing an exception), the action will fail and no Spark job will be - * submitted. - */ - StructType readSchema(); - - /** - * Returns a list of {@link InputPartition}s. Each {@link InputPartition} is responsible for - * creating a data reader to output data of one RDD partition. The number of input partitions - * returned here is the same as the number of RDD partitions this scan outputs. - * - * Note that, this may not be a full scan if the data source reader mixes in other optimization - * interfaces like column pruning, filter push-down, etc. These optimizations are applied before - * Spark issues the scan request. - * - * If this method fails (by throwing an exception), the action will fail and no Spark job will be - * submitted. - */ - List> planInputPartitions(); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java index f2038d0de3ff..95c30de907e4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java @@ -22,18 +22,18 @@ import org.apache.spark.annotation.InterfaceStability; /** - * An input partition returned by {@link DataSourceReader#planInputPartitions()} and is - * responsible for creating the actual data reader of one RDD partition. - * The relationship between {@link InputPartition} and {@link InputPartitionReader} - * is similar to the relationship between {@link Iterable} and {@link java.util.Iterator}. + * A serializable representation of an input partition returned by + * {@link ReadSupport#planInputPartitions(ScanConfig)}. * - * Note that {@link InputPartition}s will be serialized and sent to executors, then - * {@link InputPartitionReader}s will be created on executors to do the actual reading. So - * {@link InputPartition} must be serializable while {@link InputPartitionReader} doesn't need to - * be. + * Note that {@link InputPartition} will be serialized and sent to executors, then + * {@link PartitionReader} will be created by + * {@link PartitionReaderFactory#createReader(InputPartition)} or + * {@link PartitionReaderFactory#createColumnarReader(InputPartition)} on executors to do + * the actual reading. So {@link InputPartition} must be serializable while {@link PartitionReader} + * doesn't need to be. */ @InterfaceStability.Evolving -public interface InputPartition extends Serializable { +public interface InputPartition extends Serializable { /** * The preferred locations where the input partition reader returned by this partition can run @@ -51,12 +51,4 @@ public interface InputPartition extends Serializable { default String[] preferredLocations() { return new String[0]; } - - /** - * Returns an input partition reader to do the actual reading work. - * - * If this method fails (by throwing an exception), the corresponding Spark task would fail and - * get retried until hitting the maximum retry times. 
- */ - InputPartitionReader createPartitionReader(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java similarity index 67% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java index f3ff7f5cc0f2..04ff8d0a19fc 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java @@ -23,31 +23,27 @@ import org.apache.spark.annotation.InterfaceStability; /** - * An input partition reader returned by {@link InputPartition#createPartitionReader()} and is - * responsible for outputting data for a RDD partition. + * A partition reader returned by {@link PartitionReaderFactory#createReader(InputPartition)} or + * {@link PartitionReaderFactory#createColumnarReader(InputPartition)}. It's responsible for + * outputting data for a RDD partition. * * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.catalyst.InternalRow} - * for normal data source readers, {@link org.apache.spark.sql.vectorized.ColumnarBatch} for data - * source readers that mix in {@link SupportsScanColumnarBatch}. + * for normal data sources, or {@link org.apache.spark.sql.vectorized.ColumnarBatch} for columnar + * data sources(whose {@link PartitionReaderFactory#supportColumnarReads(InputPartition)} + * returns true). */ @InterfaceStability.Evolving -public interface InputPartitionReader extends Closeable { +public interface PartitionReader extends Closeable { /** * Proceed to next record, returns false if there is no more records. * - * If this method fails (by throwing an exception), the corresponding Spark task would fail and - * get retried until hitting the maximum retry times. - * * @throws IOException if failure happens during disk/network IO like reading files. */ boolean next() throws IOException; /** * Return the current record. This method should return same value until `next` is called. - * - * If this method fails (by throwing an exception), the corresponding Spark task would fail and - * get retried until hitting the maximum retry times. */ T get(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java new file mode 100644 index 000000000000..f35de9310eee --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
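// Illustrative sketch (hypothetical names, not part of this patch): an InputPartition that carries
// only a serializable description of its split plus locality hints. The heavyweight,
// non-serializable reading logic belongs in the PartitionReader created later on the executor.
import org.apache.spark.sql.sources.v2.reader.InputPartition;

public class FileChunkPartition implements InputPartition {

  private final String path;     // which file to read
  private final long offset;     // where the chunk starts
  private final long length;     // how many bytes belong to this chunk
  private final String[] hosts;  // hosts that store the chunk locally, if known

  public FileChunkPartition(String path, long offset, long length, String[] hosts) {
    this.path = path;
    this.offset = offset;
    this.length = length;
    this.hosts = hosts;
  }

  // Spark uses these hints for task scheduling; the default implementation returns an empty
  // array, meaning "no preference".
  @Override
  public String[] preferredLocations() {
    return hosts;
  }

  public String path() { return path; }
  public long offset() { return offset; }
  public long length() { return length; }
}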
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import java.io.Serializable; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.vectorized.ColumnarBatch; + +/** + * A factory used to create {@link PartitionReader} instances. + * + * If Spark fails to execute any methods in the implementations of this interface or in the returned + * {@link PartitionReader} (by throwing an exception), corresponding Spark task would fail and + * get retried until hitting the maximum retry times. + */ +@InterfaceStability.Evolving +public interface PartitionReaderFactory extends Serializable { + + /** + * Returns a row-based partition reader to read data from the given {@link InputPartition}. + * + * Implementations probably need to cast the input partition to the concrete + * {@link InputPartition} class defined for the data source. + */ + PartitionReader createReader(InputPartition partition); + + /** + * Returns a columnar partition reader to read data from the given {@link InputPartition}. + * + * Implementations probably need to cast the input partition to the concrete + * {@link InputPartition} class defined for the data source. + */ + default PartitionReader createColumnarReader(InputPartition partition) { + throw new UnsupportedOperationException("Cannot create columnar reader."); + } + + /** + * Returns true if the given {@link InputPartition} should be read by Spark in a columnar way. + * This means, implementations must also implement {@link #createColumnarReader(InputPartition)} + * for the input partitions that this method returns true. + * + * As of Spark 2.4, Spark can only read all input partition in a columnar way, or none of them. + * Data source can't mix columnar and row-based partitions. This may be relaxed in future + * versions. + */ + default boolean supportColumnarReads(InputPartition partition) { + return false; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadSupport.java new file mode 100644 index 000000000000..a58ddb288f1e --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadSupport.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.types.StructType; + +/** + * The base interface for all the batch and streaming read supports. Data sources should implement + * concrete read support interfaces like {@link BatchReadSupport}. 
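// Illustrative sketch (hypothetical names, not part of this patch): a PartitionReaderFactory that
// opts into columnar scans. It assumes the OnHeapColumnVector / ColumnarBatch helpers are available
// as in Spark 2.4; a real source would fill the vectors from its own storage format.
import java.io.IOException;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
import org.apache.spark.sql.sources.v2.reader.InputPartition;
import org.apache.spark.sql.sources.v2.reader.PartitionReader;
import org.apache.spark.sql.sources.v2.reader.PartitionReaderFactory;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.vectorized.ColumnarBatch;

public class SingleBatchColumnarReaderFactory implements PartitionReaderFactory {

  private final StructType schema = new StructType().add("i", "int");

  // Every partition is read in a columnar way; as of Spark 2.4 a scan cannot mix columnar and
  // row-based partitions.
  @Override
  public boolean supportColumnarReads(InputPartition partition) {
    return true;
  }

  // The row-based path is never used here, but the method is still required by the interface.
  @Override
  public PartitionReader<InternalRow> createReader(InputPartition partition) {
    throw new UnsupportedOperationException("This factory only produces columnar readers.");
  }

  @Override
  public PartitionReader<ColumnarBatch> createColumnarReader(InputPartition partition) {
    return new PartitionReader<ColumnarBatch>() {
      private boolean consumed = false;

      @Override
      public boolean next() throws IOException {
        // Emit exactly one batch of 3 rows for illustration.
        if (consumed) return false;
        consumed = true;
        return true;
      }

      @Override
      public ColumnarBatch get() {
        OnHeapColumnVector[] vectors = OnHeapColumnVector.allocateColumns(3, schema);
        for (int i = 0; i < 3; i++) {
          vectors[0].putInt(i, i);
        }
        ColumnarBatch batch = new ColumnarBatch(vectors);
        batch.setNumRows(3);
        return batch;
      }

      @Override
      public void close() throws IOException {}
    };
  }
}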
+ * + * If Spark fails to execute any methods in the implementations of this interface (by throwing an + * exception), the read action will fail and no Spark job will be submitted. + */ +@InterfaceStability.Evolving +public interface ReadSupport { + + /** + * Returns the full schema of this data source, which is usually the physical schema of the + * underlying storage. This full schema should not be affected by column pruning or other + * optimizations. + */ + StructType fullSchema(); + + /** + * Returns a list of {@link InputPartition input partitions}. Each {@link InputPartition} + * represents a data split that can be processed by one Spark task. The number of input + * partitions returned here is the same as the number of RDD partitions this scan outputs. + * + * Note that, this may not be a full scan if the data source supports optimization like filter + * push-down. Implementations should check the input {@link ScanConfig} and adjust the resulting + * {@link InputPartition input partitions}. + */ + InputPartition[] planInputPartitions(ScanConfig config); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java new file mode 100644 index 000000000000..7462ce282058 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.types.StructType; + +/** + * An interface that carries query specific information for the data scanning job, like operator + * pushdown information and streaming query offsets. This is defined as an empty interface, and data + * sources should define their own {@link ScanConfig} classes. + * + * For APIs that take a {@link ScanConfig} as input, like + * {@link ReadSupport#planInputPartitions(ScanConfig)}, + * {@link BatchReadSupport#createReaderFactory(ScanConfig)} and + * {@link SupportsReportStatistics#estimateStatistics(ScanConfig)}, implementations mostly need to + * cast the input {@link ScanConfig} to the concrete {@link ScanConfig} class of the data source. + */ +@InterfaceStability.Evolving +public interface ScanConfig { + + /** + * Returns the actual schema of this data source reader, which may be different from the physical + * schema of the underlying storage, as column pruning or other optimizations may happen. + * + * If this method fails (by throwing an exception), the action will fail and no Spark job will be + * submitted. 
+ */ + StructType readSchema(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfigBuilder.java similarity index 61% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfigBuilder.java index dcb87715d0b6..4c0eedfddfe2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfigBuilder.java @@ -18,18 +18,13 @@ package org.apache.spark.sql.sources.v2.reader; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset; /** - * A mix-in interface for {@link InputPartition}. Continuous input partitions can - * implement this interface to provide creating {@link InputPartitionReader} with particular offset. + * An interface for building the {@link ScanConfig}. Implementations can mixin those + * SupportsPushDownXYZ interfaces to do operator pushdown, and keep the operator pushdown result in + * the returned {@link ScanConfig}. */ @InterfaceStability.Evolving -public interface ContinuousInputPartition extends InputPartition { - /** - * Create an input partition reader with particular offset as its startOffset. - * - * @param offset offset want to set as the input partition reader's startOffset. - */ - InputPartitionReader createContinuousReader(PartitionOffset offset); +public interface ScanConfigBuilder { + ScanConfig build(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java index e8cd7adbca07..44799c7d4913 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java @@ -23,7 +23,7 @@ /** * An interface to represent statistics for a data source, which is returned by - * {@link SupportsReportStatistics#getStatistics()}. + * {@link SupportsReportStatistics#estimateStatistics(ScanConfig)}. */ @InterfaceStability.Evolving public interface Statistics { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java index 4543c143a9ac..9d79a18d14bc 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression; /** - * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this + * A mix-in interface for {@link ScanConfigBuilder}. Data source readers can implement this * interface to push down arbitrary expressions as predicates to the data source. * This is an experimental and unstable interface as {@link Expression} is not public and may get * changed in the future Spark versions. @@ -31,7 +31,7 @@ * process this interface. 
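// Illustrative sketch (hypothetical names, not part of this patch): a concrete ScanConfig that
// carries query-specific state (here, an id range) next to the read schema, plus the builder that
// produces it. Methods like planInputPartitions(ScanConfig) would cast the incoming ScanConfig
// back to this class to get at that state.
import org.apache.spark.sql.sources.v2.reader.ScanConfig;
import org.apache.spark.sql.sources.v2.reader.ScanConfigBuilder;
import org.apache.spark.sql.types.StructType;

public class IdRangeScanConfig implements ScanConfig {

  private final StructType readSchema;
  private final long startId;
  private final long endId;

  public IdRangeScanConfig(StructType readSchema, long startId, long endId) {
    this.readSchema = readSchema;
    this.startId = startId;
    this.endId = endId;
  }

  @Override
  public StructType readSchema() {
    return readSchema;
  }

  public long startId() { return startId; }
  public long endId() { return endId; }

  // The builder is where per-scan information is collected before being frozen into the config.
  public static class Builder implements ScanConfigBuilder {
    private final StructType schema;
    private final long startId;
    private final long endId;

    public Builder(StructType schema, long startId, long endId) {
      this.schema = schema;
      this.startId = startId;
      this.endId = endId;
    }

    @Override
    public ScanConfig build() {
      return new IdRangeScanConfig(schema, startId, endId);
    }
  }
}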
*/ @InterfaceStability.Unstable -public interface SupportsPushDownCatalystFilters extends DataSourceReader { +public interface SupportsPushDownCatalystFilters extends ScanConfigBuilder { /** * Pushes down filters, and returns filters that need to be evaluated after scanning. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java index b6a90a3d0b68..5d32a8ac60f7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java @@ -21,15 +21,15 @@ import org.apache.spark.sql.sources.Filter; /** - * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this - * interface to push down filters to the data source and reduce the size of the data to be read. + * A mix-in interface for {@link ScanConfigBuilder}. Data sources can implement this interface to + * push down filters to the data source and reduce the size of the data to be read. * * Note that, if data source readers implement both this interface and * {@link SupportsPushDownCatalystFilters}, Spark will ignore this interface and only process * {@link SupportsPushDownCatalystFilters}. */ @InterfaceStability.Evolving -public interface SupportsPushDownFilters extends DataSourceReader { +public interface SupportsPushDownFilters extends ScanConfigBuilder { /** * Pushes down filters, and returns filters that need to be evaluated after scanning. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java index 427b4d00a112..edb164937d6e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java @@ -21,12 +21,12 @@ import org.apache.spark.sql.types.StructType; /** - * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this + * A mix-in interface for {@link ScanConfigBuilder}. Data sources can implement this * interface to push down required columns to the data source and only read these columns during * scan to reduce the size of the data to be read. */ @InterfaceStability.Evolving -public interface SupportsPushDownRequiredColumns extends DataSourceReader { +public interface SupportsPushDownRequiredColumns extends ScanConfigBuilder { /** * Applies column pruning w.r.t. the given requiredSchema. @@ -35,8 +35,8 @@ public interface SupportsPushDownRequiredColumns extends DataSourceReader { * also OK to do the pruning partially, e.g., a data source may not be able to prune nested * fields, and only prune top-level columns. * - * Note that, data source readers should update {@link DataSourceReader#readSchema()} after - * applying column pruning. + * Note that, {@link ScanConfig#readSchema()} implementation should take care of the column + * pruning applied here. 
*/ void pruneColumns(StructType requiredSchema); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java index 6b60da7c4dc1..db62cd451536 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java @@ -21,17 +21,17 @@ import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning; /** - * A mix in interface for {@link DataSourceReader}. Data source readers can implement this - * interface to report data partitioning and try to avoid shuffle at Spark side. + * A mix in interface for {@link BatchReadSupport}. Data sources can implement this interface to + * report data partitioning and try to avoid shuffle at Spark side. * - * Note that, when the reader creates exactly one {@link InputPartition}, Spark may avoid - * adding a shuffle even if the reader does not implement this interface. + * Note that, when a {@link ReadSupport} implementation creates exactly one {@link InputPartition}, + * Spark may avoid adding a shuffle even if the reader does not implement this interface. */ @InterfaceStability.Evolving -public interface SupportsReportPartitioning extends DataSourceReader { +public interface SupportsReportPartitioning extends ReadSupport { /** * Returns the output data partitioning that this reader guarantees. */ - Partitioning outputPartitioning(); + Partitioning outputPartitioning(ScanConfig config); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java index 926396414816..1831488ba096 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java @@ -20,18 +20,18 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A mix in interface for {@link DataSourceReader}. Data source readers can implement this - * interface to report statistics to Spark. + * A mix in interface for {@link BatchReadSupport}. Data sources can implement this interface to + * report statistics to Spark. * - * Statistics are reported to the optimizer before any operator is pushed to the DataSourceReader. - * Implementations that return more accurate statistics based on pushed operators will not improve - * query performance until the planner can push operators before getting stats. + * As of Spark 2.4, statistics are reported to the optimizer before any operator is pushed to the + * data source. Implementations that return more accurate statistics based on pushed operators will + * not improve query performance until the planner can push operators before getting stats. */ @InterfaceStability.Evolving -public interface SupportsReportStatistics extends DataSourceReader { +public interface SupportsReportStatistics extends ReadSupport { /** - * Returns the basic statistics of this data source. + * Returns the estimated statistics of this data source scan. 
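// Illustrative sketch (hypothetical names, not part of this patch): a ScanConfigBuilder that
// supports column pruning. The pruned schema is remembered by the builder and becomes the
// readSchema() of the ScanConfig it builds, as required by SupportsPushDownRequiredColumns.
import org.apache.spark.sql.sources.v2.reader.ScanConfig;
import org.apache.spark.sql.sources.v2.reader.SupportsPushDownRequiredColumns;
import org.apache.spark.sql.types.StructType;

public class PruningScanConfigBuilder implements SupportsPushDownRequiredColumns {

  private StructType requiredSchema;

  public PruningScanConfigBuilder(StructType fullSchema) {
    // Until Spark asks for pruning, read everything.
    this.requiredSchema = fullSchema;
  }

  @Override
  public void pruneColumns(StructType requiredSchema) {
    // Spark passes the subset of columns the query actually needs. Pruning only partially
    // (for example, only top-level columns) would also be acceptable.
    this.requiredSchema = requiredSchema;
  }

  @Override
  public ScanConfig build() {
    StructType prunedSchema = this.requiredSchema;
    return () -> prunedSchema;  // ScanConfig.readSchema() reflects the pruning applied above
  }
}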
*/ - Statistics getStatistics(); + Statistics estimateStatistics(ScanConfig config); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java deleted file mode 100644 index f4da686740d1..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.reader; - -import java.util.List; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.vectorized.ColumnarBatch; - -/** - * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this - * interface to output {@link ColumnarBatch} and make the scan faster. - */ -@InterfaceStability.Evolving -public interface SupportsScanColumnarBatch extends DataSourceReader { - @Override - default List> planInputPartitions() { - throw new IllegalStateException( - "planInputPartitions not supported by default within SupportsScanColumnarBatch."); - } - - /** - * Similar to {@link DataSourceReader#planInputPartitions()}, but returns columnar data - * in batches. - */ - List> planBatchInputPartitions(); - - /** - * Returns true if the concrete data source reader can read data in batch according to the scan - * properties like required columns, pushes filters, etc. It's possible that the implementation - * can only support some certain columns with certain types. Users can overwrite this method and - * {@link #planInputPartitions()} to fallback to normal read path under some conditions. - */ - default boolean enableBatchRead() { - return true; - } -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java index 38ca5fc6387b..6764d4b7665c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java @@ -18,12 +18,12 @@ package org.apache.spark.sql.sources.v2.reader.partitioning; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; +import org.apache.spark.sql.sources.v2.reader.PartitionReader; /** * A concrete implementation of {@link Distribution}. 
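// Illustrative sketch (hypothetical names, not part of this patch): a ReadSupport that reports
// estimated statistics for the scan described by the given ScanConfig. The sizeInBytes()/numRows()
// accessors are assumed to be the OptionalLong-returning methods of the v2 Statistics interface.
import java.util.OptionalLong;
import org.apache.spark.sql.sources.v2.reader.InputPartition;
import org.apache.spark.sql.sources.v2.reader.ScanConfig;
import org.apache.spark.sql.sources.v2.reader.Statistics;
import org.apache.spark.sql.sources.v2.reader.SupportsReportStatistics;
import org.apache.spark.sql.types.StructType;

public class StatsReportingReadSupport implements SupportsReportStatistics {

  private static final StructType SCHEMA = new StructType().add("i", "int");
  private static final long NUM_ROWS = 1_000_000L;

  @Override
  public StructType fullSchema() {
    return SCHEMA;
  }

  @Override
  public InputPartition[] planInputPartitions(ScanConfig config) {
    return new InputPartition[0];  // the actual scan is omitted in this sketch
  }

  @Override
  public Statistics estimateStatistics(ScanConfig config) {
    // As of Spark 2.4 this is called before operator pushdown, so the estimate cannot yet take
    // pushed filters into account.
    return new Statistics() {
      @Override
      public OptionalLong sizeInBytes() {
        return OptionalLong.of(NUM_ROWS * 4);  // one int column
      }

      @Override
      public OptionalLong numRows() {
        return OptionalLong.of(NUM_ROWS);
      }
    };
  }
}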
Represents a distribution where records that * share the same values for the {@link #clusteredColumns} will be produced by the same - * {@link InputPartitionReader}. + * {@link PartitionReader}. */ @InterfaceStability.Evolving public class ClusteredDistribution implements Distribution { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java index 5e32ba6952e1..364a3f553923 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java @@ -18,14 +18,14 @@ package org.apache.spark.sql.sources.v2.reader.partitioning; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; +import org.apache.spark.sql.sources.v2.reader.PartitionReader; /** * An interface to represent data distribution requirement, which specifies how the records should - * be distributed among the data partitions (one {@link InputPartitionReader} outputs data for one + * be distributed among the data partitions (one {@link PartitionReader} outputs data for one * partition). * Note that this interface has nothing to do with the data ordering inside one - * partition(the output records of a single {@link InputPartitionReader}). + * partition(the output records of a single {@link PartitionReader}). * * The instance of this interface is created and provided by Spark, then consumed by * {@link Partitioning#satisfy(Distribution)}. This means data source developers don't need to diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java index f460f6bfe3bb..fb0b6f1df43b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java @@ -19,12 +19,13 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.reader.InputPartition; +import org.apache.spark.sql.sources.v2.reader.ScanConfig; import org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning; /** * An interface to represent the output data partitioning for a data source, which is returned by - * {@link SupportsReportPartitioning#outputPartitioning()}. Note that this should work like a - * snapshot. Once created, it should be deterministic and always report the same number of + * {@link SupportsReportPartitioning#outputPartitioning(ScanConfig)}. Note that this should work + * like a snapshot. Once created, it should be deterministic and always report the same number of * partitions and the same "satisfy" result for a certain distribution. 
*/ @InterfaceStability.Evolving diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousInputPartitionReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java similarity index 60% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousInputPartitionReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java index 7b0ba0bbdda9..9101c8a44d34 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousInputPartitionReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java @@ -18,19 +18,20 @@ package org.apache.spark.sql.sources.v2.reader.streaming; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; +import org.apache.spark.sql.sources.v2.reader.PartitionReader; /** - * A variation on {@link InputPartitionReader} for use with streaming in continuous processing mode. + * A variation on {@link PartitionReader} for use with continuous streaming processing. */ @InterfaceStability.Evolving -public interface ContinuousInputPartitionReader extends InputPartitionReader { - /** - * Get the offset of the current record, or the start offset if no records have been read. - * - * The execution engine will call this method along with get() to keep track of the current - * offset. When an epoch ends, the offset of the previous record in each partition will be saved - * as a restart checkpoint. - */ - PartitionOffset getOffset(); +public interface ContinuousPartitionReader extends PartitionReader { + + /** + * Get the offset of the current record, or the start offset if no records have been read. + * + * The execution engine will call this method along with get() to keep track of the current + * offset. When an epoch ends, the offset of the previous record in each partition will be saved + * as a restart checkpoint. + */ + PartitionOffset getOffset(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsDeprecatedScanRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java similarity index 51% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsDeprecatedScanRow.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java index 595943cf4d8a..2d9f1ca1686a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsDeprecatedScanRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java @@ -15,25 +15,26 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.reader; +package org.apache.spark.sql.sources.v2.reader.streaming; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.Row; import org.apache.spark.sql.catalyst.InternalRow; - -import java.util.List; +import org.apache.spark.sql.sources.v2.reader.InputPartition; +import org.apache.spark.sql.sources.v2.reader.PartitionReaderFactory; +import org.apache.spark.sql.vectorized.ColumnarBatch; /** - * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this - * interface to output {@link Row} instead of {@link InternalRow}. 
- * This is an experimental and unstable interface. + * A variation on {@link PartitionReaderFactory} that returns {@link ContinuousPartitionReader} + * instead of {@link org.apache.spark.sql.sources.v2.reader.PartitionReader}. It's used for + * continuous streaming processing. */ -@InterfaceStability.Unstable -public interface SupportsDeprecatedScanRow extends DataSourceReader { - default List> planInputPartitions() { - throw new IllegalStateException( - "planInputPartitions not supported by default within SupportsDeprecatedScanRow"); - } +@InterfaceStability.Evolving +public interface ContinuousPartitionReaderFactory extends PartitionReaderFactory { + @Override + ContinuousPartitionReader createReader(InputPartition partition); - List> planRowInputPartitions(); + @Override + default ContinuousPartitionReader createColumnarReader(InputPartition partition) { + throw new UnsupportedOperationException("Cannot create columnar reader."); + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReadSupport.java new file mode 100644 index 000000000000..9a3ad2eb8a80 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReadSupport.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader.streaming; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.execution.streaming.BaseStreamingSource; +import org.apache.spark.sql.sources.v2.reader.InputPartition; +import org.apache.spark.sql.sources.v2.reader.ScanConfig; +import org.apache.spark.sql.sources.v2.reader.ScanConfigBuilder; + +/** + * An interface that defines how to load the data from data source for continuous streaming + * processing. + * + * The execution engine will get an instance of this interface from a data source provider + * (e.g. {@link org.apache.spark.sql.sources.v2.ContinuousReadSupportProvider}) at the start of a + * streaming query, then call {@link #newScanConfigBuilder(Offset)} and create an instance of + * {@link ScanConfig} for the duration of the streaming query or until + * {@link #needsReconfiguration(ScanConfig)} is true. The {@link ScanConfig} will be used to create + * input partitions and reader factory to scan data with a Spark job for its duration. At the end + * {@link #stop()} will be called when the streaming execution is completed. Note that a single + * query may have multiple executions due to restart or failure recovery. 
+ */ +@InterfaceStability.Evolving +public interface ContinuousReadSupport extends StreamingReadSupport, BaseStreamingSource { + + /** + * Returns a builder of {@link ScanConfig}. Spark will call this method and create a + * {@link ScanConfig} for each data scanning job. + * + * The builder can take some query specific information to do operators pushdown, store streaming + * offsets, etc., and keep these information in the created {@link ScanConfig}. + * + * This is the first step of the data scan. All other methods in {@link ContinuousReadSupport} + * needs to take {@link ScanConfig} as an input. + */ + ScanConfigBuilder newScanConfigBuilder(Offset start); + + /** + * Returns a factory, which produces one {@link ContinuousPartitionReader} for one + * {@link InputPartition}. + */ + ContinuousPartitionReaderFactory createContinuousReaderFactory(ScanConfig config); + + /** + * Merge partitioned offsets coming from {@link ContinuousPartitionReader} instances + * for each partition to a single global offset. + */ + Offset mergeOffsets(PartitionOffset[] offsets); + + /** + * The execution engine will call this method in every epoch to determine if new input + * partitions need to be generated, which may be required if for example the underlying + * source system has had partitions added or removed. + * + * If true, the query will be shut down and restarted with a new {@link ContinuousReadSupport} + * instance. + */ + default boolean needsReconfiguration(ScanConfig config) { + return false; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java deleted file mode 100644 index 6e960bedf802..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.reader.streaming; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.execution.streaming.BaseStreamingSource; -import org.apache.spark.sql.sources.v2.reader.DataSourceReader; - -import java.util.Optional; - -/** - * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this - * interface to allow reading in a continuous processing mode stream. - * - * Implementations must ensure each partition reader is a {@link ContinuousInputPartitionReader}. - * - * Note: This class currently extends {@link BaseStreamingSource} to maintain compatibility with - * DataSource V1 APIs. This extension will be removed once we get rid of V1 completely. 
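// Illustrative sketch (hypothetical names, not part of this patch): a ContinuousPartitionReader
// that keeps emitting increasing longs, together with the factory that creates it. The reader's
// getOffset() is what the engine checkpoints at each epoch boundary.
import java.io.IOException;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.sources.v2.reader.InputPartition;
import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousPartitionReader;
import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousPartitionReaderFactory;
import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset;

public class CounterContinuousReaderFactory implements ContinuousPartitionReaderFactory {

  // The per-partition offset type: just the last value emitted by this partition.
  public static class CounterPartitionOffset implements PartitionOffset {
    public final long lastValue;
    public CounterPartitionOffset(long lastValue) {
      this.lastValue = lastValue;
    }
  }

  @Override
  public ContinuousPartitionReader<InternalRow> createReader(InputPartition partition) {
    return new ContinuousPartitionReader<InternalRow>() {
      private long current = -1;

      @Override
      public boolean next() throws IOException {
        // A continuous reader normally blocks until more data arrives; a plain counter never
        // runs out, so it simply advances.
        current += 1;
        return true;
      }

      @Override
      public InternalRow get() {
        return new GenericInternalRow(new Object[] {current});
      }

      @Override
      public PartitionOffset getOffset() {
        return new CounterPartitionOffset(current);
      }

      @Override
      public void close() throws IOException {}
    };
  }
}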
- */ -@InterfaceStability.Evolving -public interface ContinuousReader extends BaseStreamingSource, DataSourceReader { - /** - * Merge partitioned offsets coming from {@link ContinuousInputPartitionReader} instances - * for each partition to a single global offset. - */ - Offset mergeOffsets(PartitionOffset[] offsets); - - /** - * Deserialize a JSON string into an Offset of the implementation-defined offset type. - * @throws IllegalArgumentException if the JSON does not encode a valid offset for this reader - */ - Offset deserializeOffset(String json); - - /** - * Set the desired start offset for partitions created from this reader. The scan will - * start from the first record after the provided offset, or from an implementation-defined - * inferred starting point if no offset is provided. - */ - void setStartOffset(Optional start); - - /** - * Return the specified or inferred start offset for this reader. - * - * @throws IllegalStateException if setStartOffset has not been called - */ - Offset getStartOffset(); - - /** - * The execution engine will call this method in every epoch to determine if new input - * partitions need to be generated, which may be required if for example the underlying - * source system has had partitions added or removed. - * - * If true, the query will be shut down and restarted with a new reader. - */ - default boolean needsReconfiguration() { - return false; - } - - /** - * Informs the source that Spark has completed processing all data for offsets less than or - * equal to `end` and will only request offsets greater than `end` in the future. - */ - void commit(Offset end); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java new file mode 100644 index 000000000000..edb0db11bff2 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader.streaming; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.execution.streaming.BaseStreamingSource; +import org.apache.spark.sql.sources.v2.reader.*; + +/** + * An interface that defines how to scan the data from data source for micro-batch streaming + * processing. + * + * The execution engine will get an instance of this interface from a data source provider + * (e.g. {@link org.apache.spark.sql.sources.v2.MicroBatchReadSupportProvider}) at the start of a + * streaming query, then call {@link #newScanConfigBuilder(Offset, Offset)} and create an instance + * of {@link ScanConfig} for each micro-batch. 
The {@link ScanConfig} will be used to create input + * partitions and reader factory to scan a micro-batch with a Spark job. At the end {@link #stop()} + * will be called when the streaming execution is completed. Note that a single query may have + * multiple executions due to restart or failure recovery. + */ +@InterfaceStability.Evolving +public interface MicroBatchReadSupport extends StreamingReadSupport, BaseStreamingSource { + + /** + * Returns a builder of {@link ScanConfig}. Spark will call this method and create a + * {@link ScanConfig} for each data scanning job. + * + * The builder can take some query specific information to do operators pushdown, store streaming + * offsets, etc., and keep these information in the created {@link ScanConfig}. + * + * This is the first step of the data scan. All other methods in {@link MicroBatchReadSupport} + * needs to take {@link ScanConfig} as an input. + */ + ScanConfigBuilder newScanConfigBuilder(Offset start, Offset end); + + /** + * Returns a factory, which produces one {@link PartitionReader} for one {@link InputPartition}. + */ + PartitionReaderFactory createReaderFactory(ScanConfig config); + + /** + * Returns the most recent offset available. + */ + Offset latestOffset(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java deleted file mode 100644 index 0159c731762d..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.reader.streaming; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.DataSourceReader; -import org.apache.spark.sql.execution.streaming.BaseStreamingSource; - -import java.util.Optional; - -/** - * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this - * interface to indicate they allow micro-batch streaming reads. - * - * Note: This class currently extends {@link BaseStreamingSource} to maintain compatibility with - * DataSource V1 APIs. This extension will be removed once we get rid of V1 completely. - */ -@InterfaceStability.Evolving -public interface MicroBatchReader extends DataSourceReader, BaseStreamingSource { - /** - * Set the desired offset range for input partitions created from this reader. Partition readers - * will generate only data within (`start`, `end`]; that is, from the first record after `start` - * to the record with offset `end`. - * - * @param start The initial offset to scan from. 
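// Illustrative sketch (hypothetical names, not part of this patch): the offset bookkeeping side of
// a MicroBatchReadSupport for a simple counter source. The actual scan (planInputPartitions and
// createReaderFactory) is omitted and just throws, to keep the focus on how Spark drives offsets:
// initialOffset/latestOffset bound each micro-batch, newScanConfigBuilder(start, end) freezes that
// range, and commit(end) tells the source the range has been fully processed.
import java.util.concurrent.atomic.AtomicLong;
import org.apache.spark.sql.sources.v2.reader.InputPartition;
import org.apache.spark.sql.sources.v2.reader.PartitionReaderFactory;
import org.apache.spark.sql.sources.v2.reader.ScanConfig;
import org.apache.spark.sql.sources.v2.reader.ScanConfigBuilder;
import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReadSupport;
import org.apache.spark.sql.sources.v2.reader.streaming.Offset;
import org.apache.spark.sql.types.StructType;

public class CounterMicroBatchReadSupport implements MicroBatchReadSupport {

  // A minimal Offset: the highest counter value produced so far, serialized as a plain number.
  public static class CounterOffset extends Offset {
    final long value;
    CounterOffset(long value) { this.value = value; }
    @Override
    public String json() { return Long.toString(value); }
  }

  private final AtomicLong highestAvailable = new AtomicLong(0);

  @Override
  public StructType fullSchema() {
    return new StructType().add("value", "long");
  }

  @Override
  public Offset initialOffset() {
    // Only used for brand-new queries; restarted queries resume from the checkpointed offset.
    return new CounterOffset(0);
  }

  @Override
  public Offset latestOffset() {
    return new CounterOffset(highestAvailable.incrementAndGet());
  }

  @Override
  public Offset deserializeOffset(String json) {
    return new CounterOffset(Long.parseLong(json));
  }

  @Override
  public ScanConfigBuilder newScanConfigBuilder(Offset start, Offset end) {
    // Freeze the (start, end] range of this micro-batch into the ScanConfig.
    long startValue = ((CounterOffset) start).value;
    long endValue = ((CounterOffset) end).value;
    return () -> new CounterScanConfig(fullSchema(), startValue, endValue);
  }

  @Override
  public InputPartition[] planInputPartitions(ScanConfig config) {
    throw new UnsupportedOperationException("Data scan omitted in this sketch.");
  }

  @Override
  public PartitionReaderFactory createReaderFactory(ScanConfig config) {
    throw new UnsupportedOperationException("Data scan omitted in this sketch.");
  }

  @Override
  public void commit(Offset end) {
    // Everything up to and including `end` has been processed; nothing to clean up here.
  }

  @Override
  public void stop() {
    // Nothing to release for an in-memory counter.
  }

  static class CounterScanConfig implements ScanConfig {
    final StructType schema;
    final long start;
    final long end;
    CounterScanConfig(StructType schema, long start, long end) {
      this.schema = schema;
      this.start = start;
      this.end = end;
    }
    @Override
    public StructType readSchema() { return schema; }
  }
}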
If not specified, scan from an - * implementation-specified start point, such as the earliest available record. - * @param end The last offset to include in the scan. If not specified, scan up to an - * implementation-defined endpoint, such as the last available offset - * or the start offset plus a target batch size. - */ - void setOffsetRange(Optional start, Optional end); - - /** - * Returns the specified (if explicitly set through setOffsetRange) or inferred start offset - * for this reader. - * - * @throws IllegalStateException if setOffsetRange has not been called - */ - Offset getStartOffset(); - - /** - * Return the specified (if explicitly set through setOffsetRange) or inferred end offset - * for this reader. - * - * @throws IllegalStateException if setOffsetRange has not been called - */ - Offset getEndOffset(); - - /** - * Deserialize a JSON string into an Offset of the implementation-defined offset type. - * @throws IllegalArgumentException if the JSON does not encode a valid offset for this reader - */ - Offset deserializeOffset(String json); - - /** - * Informs the source that Spark has completed processing all data for offsets less than or - * equal to `end` and will only request offsets greater than `end` in the future. - */ - void commit(Offset end); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java index e41c0351edc8..6cf27734867c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java @@ -20,8 +20,8 @@ import org.apache.spark.annotation.InterfaceStability; /** - * An abstract representation of progress through a {@link MicroBatchReader} or - * {@link ContinuousReader}. + * An abstract representation of progress through a {@link MicroBatchReadSupport} or + * {@link ContinuousReadSupport}. * During execution, offsets provided by the data source implementation will be logged and used as * restart checkpoints. Each source should provide an offset implementation which the source can use * to reconstruct a position in the stream up to which data has been seen/processed. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/StreamingReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/StreamingReadSupport.java new file mode 100644 index 000000000000..84872d1ebc26 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/StreamingReadSupport.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.sources.v2.reader.streaming; + +import org.apache.spark.sql.sources.v2.reader.ReadSupport; + +/** + * A base interface for streaming read support. This is package private and is invisible to data + * sources. Data sources should implement concrete streaming read support interfaces: + * {@link MicroBatchReadSupport} or {@link ContinuousReadSupport}. + */ +interface StreamingReadSupport extends ReadSupport { + + /** + * Returns the initial offset for a streaming query to start reading from. Note that the + * streaming data source should not assume that it will start reading from its initial offset: + * if Spark is restarting an existing query, it will restart from the check-pointed offset rather + * than the initial one. + */ + Offset initialOffset(); + + /** + * Deserialize a JSON string into an Offset of the implementation-defined offset type. + * + * @throws IllegalArgumentException if the JSON does not encode a valid offset for this reader + */ + Offset deserializeOffset(String json); + + /** + * Informs the source that Spark has completed processing all data for offsets less than or + * equal to `end` and will only request offsets greater than `end` in the future. + */ + void commit(Offset end); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SupportsCustomReaderMetrics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SupportsCustomReaderMetrics.java index 3b293d925c91..8693154cb704 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SupportsCustomReaderMetrics.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SupportsCustomReaderMetrics.java @@ -14,20 +14,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.sql.sources.v2.reader.streaming; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.CustomMetrics; -import org.apache.spark.sql.sources.v2.reader.DataSourceReader; /** - * A mix in interface for {@link DataSourceReader}. Data source readers can implement this - * interface to report custom metrics that gets reported under the + * A mix in interface for {@link StreamingReadSupport}. Data sources can implement this interface + * to report custom metrics that gets reported under the * {@link org.apache.spark.sql.streaming.SourceProgress} - * */ @InterfaceStability.Evolving -public interface SupportsCustomReaderMetrics extends DataSourceReader { +public interface SupportsCustomReaderMetrics extends StreamingReadSupport { + /** * Returns custom metrics specific to this data source. 
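Editor's note: to show how the new StreamingReadSupport and MicroBatchReadSupport pieces fit together, here is a minimal sketch of a toy micro-batch source. It reuses the hypothetical CounterOffset above; CounterScanConfig and every other name are ours, ScanConfig is assumed to only require readSchema(), and the reader factory is a stub because this toy plans no partitions.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.reader.{InputPartition, PartitionReader, PartitionReaderFactory, ScanConfig, ScanConfigBuilder}
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset}
import org.apache.spark.sql.types.{LongType, StructType}

// Hypothetical ScanConfig that just remembers the resolved offset range for one micro-batch.
case class CounterScanConfig(start: Long, end: Long) extends ScanConfig {
  override def readSchema(): StructType = new StructType().add("value", LongType)
}

class CounterMicroBatchReadSupport extends MicroBatchReadSupport {
  @volatile private var highestAvailable = 0L   // pretend new data keeps arriving

  override def fullSchema(): StructType = new StructType().add("value", LongType)

  override def initialOffset(): Offset = CounterOffset(0L)
  override def latestOffset(): Offset = CounterOffset(highestAvailable)
  override def deserializeOffset(json: String): Offset = CounterOffset(json.toLong)
  override def commit(end: Offset): Unit = ()   // nothing to garbage-collect in this toy source

  // Spark resolves (start, end] before planning; we freeze the range into the ScanConfig.
  override def newScanConfigBuilder(start: Offset, end: Offset): ScanConfigBuilder =
    new ScanConfigBuilder {
      override def build(): ScanConfig = CounterScanConfig(
        start.asInstanceOf[CounterOffset].value, end.asInstanceOf[CounterOffset].value)
    }

  // A real source would split the offset range into InputPartitions here.
  override def planInputPartitions(config: ScanConfig): Array[InputPartition] = Array.empty

  override def createReaderFactory(config: ScanConfig): PartitionReaderFactory =
    new PartitionReaderFactory {
      // Never invoked in this sketch because no partitions are planned.
      override def createReader(p: InputPartition): PartitionReader[InternalRow] =
        throw new UnsupportedOperationException("toy source plans no partitions")
    }

  override def stop(): Unit = ()
}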
*/ diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWriteSupport.java similarity index 79% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWriteSupport.java index 385fc294fea8..0ec9e05d6a02 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWriteSupport.java @@ -18,28 +18,13 @@ package org.apache.spark.sql.sources.v2.writer; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.SaveMode; -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.StreamWriteSupport; -import org.apache.spark.sql.sources.v2.WriteSupport; -import org.apache.spark.sql.streaming.OutputMode; -import org.apache.spark.sql.types.StructType; /** - * A data source writer that is returned by - * {@link WriteSupport#createWriter(String, StructType, SaveMode, DataSourceOptions)}/ - * {@link StreamWriteSupport#createStreamWriter( - * String, StructType, OutputMode, DataSourceOptions)}. - * It can mix in various writing optimization interfaces to speed up the data saving. The actual - * writing logic is delegated to {@link DataWriter}. - * - * If an exception was throw when applying any of these writing optimizations, the action will fail - * and no Spark job will be submitted. + * An interface that defines how to write the data to data source for batch processing. * * The writing procedure is: - * 1. Create a writer factory by {@link #createWriterFactory()}, serialize and send it to all the - * partitions of the input data(RDD). + * 1. Create a writer factory by {@link #createBatchWriterFactory()}, serialize and send it to all + * the partitions of the input data(RDD). * 2. For each partition, create the data writer, and write the data of the partition with this * writer. If all the data are written successfully, call {@link DataWriter#commit()}. If * exception happens during the writing, call {@link DataWriter#abort()}. @@ -53,7 +38,7 @@ * Please refer to the documentation of commit/abort methods for detailed specifications. */ @InterfaceStability.Evolving -public interface DataSourceWriter { +public interface BatchWriteSupport { /** * Creates a writer factory which will be serialized and sent to executors. @@ -61,7 +46,7 @@ public interface DataSourceWriter { * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. 
*/ - DataWriterFactory createWriterFactory(); + DataWriterFactory createBatchWriterFactory(); /** * Returns whether Spark should use the commit coordinator to ensure that at most one task for diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java index 27dc5ea224fe..5fb067966ee6 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java @@ -22,7 +22,7 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A data writer returned by {@link DataWriterFactory#createDataWriter(int, long, long)} and is + * A data writer returned by {@link DataWriterFactory#createWriter(int, long)} and is * responsible for writing data for an input RDD partition. * * One Spark task has one exclusive data writer, so there is no thread-safe concern. @@ -36,11 +36,11 @@ * * If this data writer succeeds(all records are successfully written and {@link #commit()} * succeeds), a {@link WriterCommitMessage} will be sent to the driver side and pass to - * {@link DataSourceWriter#commit(WriterCommitMessage[])} with commit messages from other data + * {@link BatchWriteSupport#commit(WriterCommitMessage[])} with commit messages from other data * writers. If this data writer fails(one record fails to write or {@link #commit()} fails), an * exception will be sent to the driver side, and Spark may retry this writing task a few times. - * In each retry, {@link DataWriterFactory#createDataWriter(int, long, long)} will receive a - * different `taskId`. Spark will call {@link DataSourceWriter#abort(WriterCommitMessage[])} + * In each retry, {@link DataWriterFactory#createWriter(int, long)} will receive a + * different `taskId`. Spark will call {@link BatchWriteSupport#abort(WriterCommitMessage[])} * when the configured number of retries is exhausted. * * Besides the retry mechanism, Spark may launch speculative tasks if the existing writing task @@ -71,11 +71,11 @@ public interface DataWriter { /** * Commits this writer after all records are written successfully, returns a commit message which * will be sent back to driver side and passed to - * {@link DataSourceWriter#commit(WriterCommitMessage[])}. + * {@link BatchWriteSupport#commit(WriterCommitMessage[])}. * * The written data should only be visible to data source readers after - * {@link DataSourceWriter#commit(WriterCommitMessage[])} succeeds, which means this method - * should still "hide" the written data and ask the {@link DataSourceWriter} at driver side to + * {@link BatchWriteSupport#commit(WriterCommitMessage[])} succeeds, which means this method + * should still "hide" the written data and ask the {@link BatchWriteSupport} at driver side to * do the final commit via {@link WriterCommitMessage}. * * If this method fails (by throwing an exception), {@link #abort()} will be called and this @@ -93,7 +93,7 @@ public interface DataWriter { * failed. * * If this method fails(by throwing an exception), the underlying data source may have garbage - * that need to be cleaned by {@link DataSourceWriter#abort(WriterCommitMessage[])} or manually, + * that need to be cleaned by {@link BatchWriteSupport#abort(WriterCommitMessage[])} or manually, * but these garbage should not be visible to data source readers. * * @throws IOException if failure happens during disk/network IO like writing files. 
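Editor's note: to make the batch write contract above concrete, a minimal sketch of a BatchWriteSupport whose writers merely count rows. All class names are hypothetical, DataWriter is assumed to consume InternalRow as elsewhere in this patch, and the interface's default commit-coordinator hooks are left untouched.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.writer._

// Hypothetical commit message: each task reports how many rows it wrote.
case class RowCountMessage(partitionId: Int, rows: Long) extends WriterCommitMessage

class CountingWriteSupport extends BatchWriteSupport {
  override def createBatchWriterFactory(): DataWriterFactory = new CountingWriterFactory

  // Driver-side final commit: only now would a real sink make the staged data visible.
  override def commit(messages: Array[WriterCommitMessage]): Unit = {
    val total = messages.collect { case RowCountMessage(_, n) => n }.sum
    println(s"committed $total rows from ${messages.length} tasks")
  }

  override def abort(messages: Array[WriterCommitMessage]): Unit = {
    // A real sink would delete temporary/staged output here.
  }
}

// Serializable factory shipped to executors; one DataWriter per (partition, task attempt).
class CountingWriterFactory extends DataWriterFactory {
  override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] =
    new DataWriter[InternalRow] {
      private var count = 0L
      override def write(record: InternalRow): Unit = count += 1   // a real writer would stage the row
      override def commit(): WriterCommitMessage = RowCountMessage(partitionId, count)
      override def abort(): Unit = { count = 0L }
    }
}

The RowCountMessage instances are what an exec node like WriteToDataSourceV2Exec collects from the tasks and hands back to commit()/abort() on the driver.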
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java index 3d337b6e0bdf..19a36dd23245 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java @@ -19,18 +19,20 @@ import java.io.Serializable; +import org.apache.spark.TaskContext; import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.catalyst.InternalRow; /** - * A factory of {@link DataWriter} returned by {@link DataSourceWriter#createWriterFactory()}, + * A factory of {@link DataWriter} returned by {@link BatchWriteSupport#createBatchWriterFactory()}, * which is responsible for creating and initializing the actual data writer at executor side. * * Note that, the writer factory will be serialized and sent to executors, then the data writer - * will be created on executors and do the actual writing. So {@link DataWriterFactory} must be + * will be created on executors and do the actual writing. So this interface must be * serializable and {@link DataWriter} doesn't need to be. */ @InterfaceStability.Evolving -public interface DataWriterFactory extends Serializable { +public interface DataWriterFactory extends Serializable { /** * Returns a data writer to do the actual writing work. Note that, Spark will reuse the same data @@ -38,19 +40,16 @@ public interface DataWriterFactory extends Serializable { * are responsible for defensive copies if necessary, e.g. copy the data before buffer it in a * list. * - * If this method fails (by throwing an exception), the action will fail and no Spark job will be - * submitted. + * If this method fails (by throwing an exception), the corresponding Spark write task would fail + * and get retried until hitting the maximum retry times. * * @param partitionId A unique id of the RDD partition that the returned writer will process. * Usually Spark processes many RDD partitions at the same time, * implementations should use the partition id to distinguish writers for * different partitions. - * @param taskId A unique identifier for a task that is performing the write of the partition - * data. Spark may run multiple tasks for the same partition (due to speculation - * or task failures, for example). - * @param epochId A monotonically increasing id for streaming queries that are split in to - * discrete periods of execution. For non-streaming queries, - * this ID will always be 0. + * @param taskId The task id returned by {@link TaskContext#taskAttemptId()}. Spark may run + * multiple tasks for the same partition (due to speculation or task failures, + * for example). 
*/ - DataWriter createDataWriter(int partitionId, long taskId, long epochId); + DataWriter createWriter(int partitionId, long taskId); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java index 9e38836c0edf..123335c414e9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java @@ -19,15 +19,16 @@ import java.io.Serializable; +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport; import org.apache.spark.annotation.InterfaceStability; /** * A commit message returned by {@link DataWriter#commit()} and will be sent back to the driver side - * as the input parameter of {@link DataSourceWriter#commit(WriterCommitMessage[])}. + * as the input parameter of {@link BatchWriteSupport#commit(WriterCommitMessage[])} or + * {@link StreamingWriteSupport#commit(long, WriterCommitMessage[])}. * - * This is an empty interface, data sources should define their own message class and use it in - * their {@link DataWriter#commit()} and {@link DataSourceWriter#commit(WriterCommitMessage[])} - * implementations. + * This is an empty interface, data sources should define their own message class and use it when + * generating messages at executor side and handling the messages at driver side. */ @InterfaceStability.Evolving public interface WriterCommitMessage extends Serializable {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java new file mode 100644 index 000000000000..a4da24fc5ae6 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.writer.streaming; + +import java.io.Serializable; + +import org.apache.spark.TaskContext; +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.sources.v2.writer.DataWriter; + +/** + * A factory of {@link DataWriter} returned by + * {@link StreamingWriteSupport#createStreamingWriterFactory()}, which is responsible for creating + * and initializing the actual data writer at executor side. + * + * Note that, the writer factory will be serialized and sent to executors, then the data writer + * will be created on executors and do the actual writing. 
So this interface must be + * serializable and {@link DataWriter} doesn't need to be. + */ +@InterfaceStability.Evolving +public interface StreamingDataWriterFactory extends Serializable { + + /** + * Returns a data writer to do the actual writing work. Note that, Spark will reuse the same data + * object instance when sending data to the data writer, for better performance. Data writers + * are responsible for defensive copies if necessary, e.g. copy the data before buffer it in a + * list. + * + * If this method fails (by throwing an exception), the corresponding Spark write task would fail + * and get retried until hitting the maximum retry times. + * + * @param partitionId A unique id of the RDD partition that the returned writer will process. + * Usually Spark processes many RDD partitions at the same time, + * implementations should use the partition id to distinguish writers for + * different partitions. + * @param taskId The task id returned by {@link TaskContext#taskAttemptId()}. Spark may run + * multiple tasks for the same partition (due to speculation or task failures, + * for example). + * @param epochId A monotonically increasing id for streaming queries that are split in to + * discrete periods of execution. + */ + DataWriter createWriter(int partitionId, long taskId, long epochId); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java similarity index 78% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java index a316b2a4c1d8..3fdfac5e1c84 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java @@ -18,27 +18,36 @@ package org.apache.spark.sql.sources.v2.writer.streaming; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; import org.apache.spark.sql.sources.v2.writer.DataWriter; import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage; /** - * A {@link DataSourceWriter} for use with structured streaming. + * An interface that defines how to write the data to data source for streaming processing. * * Streaming queries are divided into intervals of data called epochs, with a monotonically * increasing numeric ID. This writer handles commits and aborts for each successive epoch. */ @InterfaceStability.Evolving -public interface StreamWriter extends DataSourceWriter { +public interface StreamingWriteSupport { + + /** + * Creates a writer factory which will be serialized and sent to executors. + * + * If this method fails (by throwing an exception), the action will fail and no Spark job will be + * submitted. + */ + StreamingDataWriterFactory createStreamingWriterFactory(); + /** * Commits this writing job for the specified epoch with a list of commit messages. The commit * messages are collected from successful data writers and are produced by * {@link DataWriter#commit()}. * * If this method fails (by throwing an exception), this writing job is considered to have been - * failed, and the execution engine will attempt to call {@link #abort(WriterCommitMessage[])}. 
+ * failed, and the execution engine will attempt to call + * {@link #abort(long, WriterCommitMessage[])}. * - * The execution engine may call commit() multiple times for the same epoch in some circumstances. + * The execution engine may call `commit` multiple times for the same epoch in some circumstances. * To support exactly-once data semantics, implementations must ensure that multiple commits for * the same epoch are idempotent. */ @@ -46,7 +55,8 @@ public interface StreamWriter extends DataSourceWriter { /** * Aborts this writing job because some data writers are failed and keep failing when retried, or - * the Spark job fails with some unknown reasons, or {@link #commit(WriterCommitMessage[])} fails. + * the Spark job fails with some unknown reasons, or {@link #commit(long, WriterCommitMessage[])} + * fails. * * If this method fails (by throwing an exception), the underlying data source may require manual * cleanup. @@ -58,14 +68,4 @@ public interface StreamWriter extends DataSourceWriter { * clean up the data left by data writers. */ void abort(long epochId, WriterCommitMessage[] messages); - - default void commit(WriterCommitMessage[] messages) { - throw new UnsupportedOperationException( - "Commit without epoch should not be called with StreamWriter"); - } - - default void abort(WriterCommitMessage[] messages) { - throw new UnsupportedOperationException( - "Abort without epoch should not be called with StreamWriter"); - } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/SupportsCustomWriterMetrics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/SupportsCustomWriterMetrics.java index 0cd36501320f..2b018c7d123b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/SupportsCustomWriterMetrics.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/SupportsCustomWriterMetrics.java @@ -14,20 +14,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.sql.sources.v2.writer.streaming; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.CustomMetrics; -import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; /** - * A mix in interface for {@link DataSourceWriter}. Data source writers can implement this - * interface to report custom metrics that gets reported under the + * A mix in interface for {@link StreamingWriteSupport}. Data sources can implement this interface + * to report custom metrics that gets reported under the * {@link org.apache.spark.sql.streaming.SinkProgress} - * */ @InterfaceStability.Evolving -public interface SupportsCustomWriterMetrics extends DataSourceWriter { +public interface SupportsCustomWriterMetrics extends StreamingWriteSupport { + /** * Returns custom metrics specific to this data source. 
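Editor's note: a minimal sketch of the epoch-based streaming write contract described above, assuming StreamingDataWriterFactory.createWriter returns a DataWriter over InternalRow; the class names and the "ignore duplicate epoch commits" bookkeeping are illustrative only.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport}

// Hypothetical streaming sink that tracks the last committed epoch so that the
// "commit may be called more than once per epoch" requirement is handled idempotently.
class EpochLoggingWriteSupport extends StreamingWriteSupport {
  @volatile private var lastCommittedEpoch = -1L

  override def createStreamingWriterFactory(): StreamingDataWriterFactory =
    new StreamingDataWriterFactory {
      override def createWriter(
          partitionId: Int, taskId: Long, epochId: Long): DataWriter[InternalRow] =
        new DataWriter[InternalRow] {
          override def write(record: InternalRow): Unit = ()  // a real sink would buffer/stage here
          override def commit(): WriterCommitMessage = new WriterCommitMessage {}
          override def abort(): Unit = ()
        }
    }

  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
    if (epochId > lastCommittedEpoch) {   // skip duplicate commits of an already-finished epoch
      lastCommittedEpoch = epochId
      println(s"epoch $epochId committed with ${messages.length} messages")
    }
  }

  override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = ()
}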
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 5b3b5c2451aa..0cfcc45fb3d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport} +import org.apache.spark.sql.sources.v2.{BatchReadSupportProvider, DataSourceOptions, DataSourceV2} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.unsafe.types.UTF8String @@ -194,7 +194,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { val ds = cls.newInstance().asInstanceOf[DataSourceV2] - if (ds.isInstanceOf[ReadSupport]) { + if (ds.isInstanceOf[BatchReadSupportProvider]) { val sessionOptions = DataSourceV2Utils.extractSessionConfigs( ds = ds, conf = sparkSession.sessionState.conf) val pathsOption = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 650c91790a75..eca2d5b97190 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -240,7 +240,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { if (classOf[DataSourceV2].isAssignableFrom(cls)) { val source = cls.newInstance().asInstanceOf[DataSourceV2] source match { - case ws: WriteSupport => + case provider: BatchWriteSupportProvider => val options = extraOptions ++ DataSourceV2Utils.extractSessionConfigs(source, df.sparkSession.sessionState.conf) @@ -251,8 +251,10 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { } } else { - val writer = ws.createWriter( - UUID.randomUUID.toString, df.logicalPlan.output.toStructType, mode, + val writer = provider.createBatchWriteSupport( + UUID.randomUUID().toString, + df.logicalPlan.output.toStructType, + mode, new DataSourceOptions(options.asJava)) if (writer.isPresent) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index 782829887c44..f62f7349d1da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -17,19 +17,22 @@ package org.apache.spark.sql.execution.datasources.v2 -import scala.reflect.ClassTag - -import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext} +import org.apache.spark._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.sources.v2.reader.InputPartition +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.sources.v2.reader.{InputPartition, PartitionReader, PartitionReaderFactory} -class DataSourceRDDPartition[T : ClassTag](val index: Int, val inputPartition: InputPartition[T]) +class 
DataSourceRDDPartition(val index: Int, val inputPartition: InputPartition) extends Partition with Serializable -class DataSourceRDD[T: ClassTag]( +// TODO: we should have 2 RDDs: an RDD[InternalRow] for row-based scan, an `RDD[ColumnarBatch]` for +// columnar scan. +class DataSourceRDD( sc: SparkContext, - @transient private val inputPartitions: Seq[InputPartition[T]]) - extends RDD[T](sc, Nil) { + @transient private val inputPartitions: Seq[InputPartition], + partitionReaderFactory: PartitionReaderFactory, + columnarReads: Boolean) + extends RDD[InternalRow](sc, Nil) { override protected def getPartitions: Array[Partition] = { inputPartitions.zipWithIndex.map { @@ -37,11 +40,21 @@ class DataSourceRDD[T: ClassTag]( }.toArray } - override def compute(split: Partition, context: TaskContext): Iterator[T] = { - val reader = split.asInstanceOf[DataSourceRDDPartition[T]].inputPartition - .createPartitionReader() + private def castPartition(split: Partition): DataSourceRDDPartition = split match { + case p: DataSourceRDDPartition => p + case _ => throw new SparkException(s"[BUG] Not a DataSourceRDDPartition: $split") + } + + override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { + val inputPartition = castPartition(split).inputPartition + val reader: PartitionReader[_] = if (columnarReads) { + partitionReaderFactory.createColumnarReader(inputPartition) + } else { + partitionReaderFactory.createReader(inputPartition) + } + context.addTaskCompletionListener[Unit](_ => reader.close()) - val iter = new Iterator[T] { + val iter = new Iterator[Any] { private[this] var valuePrepared = false override def hasNext: Boolean = { @@ -51,7 +64,7 @@ class DataSourceRDD[T: ClassTag]( valuePrepared } - override def next(): T = { + override def next(): Any = { if (!hasNext) { throw new java.util.NoSuchElementException("End of stream") } @@ -59,10 +72,11 @@ class DataSourceRDD[T: ClassTag]( reader.get() } } - new InterruptibleIterator(context, iter) + // TODO: SPARK-25083 remove the type erasure hack in data source scan + new InterruptibleIterator(context, iter.asInstanceOf[Iterator[InternalRow]]) } override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[DataSourceRDDPartition[T]].inputPartition.preferredLocations() + castPartition(split).inputPartition.preferredLocations() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index a4bfc861cc9a..f7e29593a635 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -27,21 +27,21 @@ import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelat import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, WriteSupport} -import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsReportStatistics} -import org.apache.spark.sql.sources.v2.writer.DataSourceWriter +import org.apache.spark.sql.sources.v2.{BatchReadSupportProvider, BatchWriteSupportProvider, DataSourceOptions, DataSourceV2} +import 
org.apache.spark.sql.sources.v2.reader.{BatchReadSupport, ReadSupport, ScanConfigBuilder, SupportsReportStatistics} +import org.apache.spark.sql.sources.v2.writer.BatchWriteSupport import org.apache.spark.sql.types.StructType /** * A logical plan representing a data source v2 scan. * * @param source An instance of a [[DataSourceV2]] implementation. - * @param options The options for this scan. Used to create fresh [[DataSourceReader]]. - * @param userSpecifiedSchema The user-specified schema for this scan. Used to create fresh - * [[DataSourceReader]]. + * @param options The options for this scan. Used to create fresh [[BatchWriteSupport]]. + * @param userSpecifiedSchema The user-specified schema for this scan. */ case class DataSourceV2Relation( source: DataSourceV2, + readSupport: BatchReadSupport, output: Seq[AttributeReference], options: Map[String, String], tableIdent: Option[TableIdentifier] = None, @@ -58,13 +58,12 @@ case class DataSourceV2Relation( override def simpleString: String = "RelationV2 " + metadataString - def newReader(): DataSourceReader = source.createReader(options, userSpecifiedSchema) + def newWriteSupport(): BatchWriteSupport = source.createWriteSupport(options, schema) - def newWriter(): DataSourceWriter = source.createWriter(options, schema) - - override def computeStats(): Statistics = newReader match { + override def computeStats(): Statistics = readSupport match { case r: SupportsReportStatistics => - Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) + val statistics = r.estimateStatistics(readSupport.newScanConfigBuilder().build()) + Statistics(sizeInBytes = statistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) case _ => Statistics(sizeInBytes = conf.defaultSizeInBytes) } @@ -85,7 +84,8 @@ case class StreamingDataSourceV2Relation( output: Seq[AttributeReference], source: DataSourceV2, options: Map[String, String], - reader: DataSourceReader) + readSupport: ReadSupport, + scanConfigBuilder: ScanConfigBuilder) extends LeafNode with MultiInstanceRelation with DataSourceV2StringFormat { override def isStreaming: Boolean = true @@ -99,7 +99,8 @@ case class StreamingDataSourceV2Relation( // TODO: unify the equal/hashCode implementation for all data source v2 query plans. 
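Editor's note: to illustrate how computeStats above consumes estimateStatistics, a hypothetical mix-in that reports a fixed size. It assumes Statistics still exposes sizeInBytes()/numRows() as OptionalLong and that SupportsReportStatistics now hangs off the read support; the trait name and the 10 MB figure are ours.

import java.util.OptionalLong

import org.apache.spark.sql.sources.v2.reader.{ScanConfig, Statistics, SupportsReportStatistics}

// A read support mixing this in would have its reported size picked up by computeStats()
// instead of falling back to spark.sql.defaultSizeInBytes.
trait FixedSizeStatistics extends SupportsReportStatistics {
  override def estimateStatistics(config: ScanConfig): Statistics = new Statistics {
    override def sizeInBytes(): OptionalLong = OptionalLong.of(10L * 1024 * 1024)
    override def numRows(): OptionalLong = OptionalLong.empty()
  }
}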
override def equals(other: Any): Boolean = other match { case other: StreamingDataSourceV2Relation => - output == other.output && reader.getClass == other.reader.getClass && options == other.options + output == other.output && readSupport.getClass == other.readSupport.getClass && + options == other.options case _ => false } @@ -107,9 +108,10 @@ case class StreamingDataSourceV2Relation( Seq(output, source, options).hashCode() } - override def computeStats(): Statistics = reader match { + override def computeStats(): Statistics = readSupport match { case r: SupportsReportStatistics => - Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) + val statistics = r.estimateStatistics(scanConfigBuilder.build()) + Statistics(sizeInBytes = statistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) case _ => Statistics(sizeInBytes = conf.defaultSizeInBytes) } @@ -117,19 +119,19 @@ case class StreamingDataSourceV2Relation( object DataSourceV2Relation { private implicit class SourceHelpers(source: DataSourceV2) { - def asReadSupport: ReadSupport = { + def asReadSupportProvider: BatchReadSupportProvider = { source match { - case support: ReadSupport => - support + case provider: BatchReadSupportProvider => + provider case _ => throw new AnalysisException(s"Data source is not readable: $name") } } - def asWriteSupport: WriteSupport = { + def asWriteSupportProvider: BatchWriteSupportProvider = { source match { - case support: WriteSupport => - support + case provider: BatchWriteSupportProvider => + provider case _ => throw new AnalysisException(s"Data source is not writable: $name") } @@ -144,23 +146,26 @@ object DataSourceV2Relation { } } - def createReader( + def createReadSupport( options: Map[String, String], - userSpecifiedSchema: Option[StructType]): DataSourceReader = { + userSpecifiedSchema: Option[StructType]): BatchReadSupport = { val v2Options = new DataSourceOptions(options.asJava) userSpecifiedSchema match { case Some(s) => - asReadSupport.createReader(s, v2Options) + asReadSupportProvider.createBatchReadSupport(s, v2Options) case _ => - asReadSupport.createReader(v2Options) + asReadSupportProvider.createBatchReadSupport(v2Options) } } - def createWriter( + def createWriteSupport( options: Map[String, String], - schema: StructType): DataSourceWriter = { - val v2Options = new DataSourceOptions(options.asJava) - asWriteSupport.createWriter(UUID.randomUUID.toString, schema, SaveMode.Append, v2Options).get + schema: StructType): BatchWriteSupport = { + asWriteSupportProvider.createBatchWriteSupport( + UUID.randomUUID().toString, + schema, + SaveMode.Append, + new DataSourceOptions(options.asJava)).get } } @@ -169,15 +174,16 @@ object DataSourceV2Relation { options: Map[String, String], tableIdent: Option[TableIdentifier] = None, userSpecifiedSchema: Option[StructType] = None): DataSourceV2Relation = { - val reader = source.createReader(options, userSpecifiedSchema) + val readSupport = source.createReadSupport(options, userSpecifiedSchema) + val output = readSupport.fullSchema().toAttributes val ident = tableIdent.orElse(tableFromOptions(options)) DataSourceV2Relation( - source, reader.readSchema().toAttributes, options, ident, userSpecifiedSchema) + source, readSupport, output, options, ident, userSpecifiedSchema) } private def tableFromOptions(options: Map[String, String]): Option[TableIdentifier] = { options - .get(DataSourceOptions.TABLE_KEY) - .map(TableIdentifier(_, options.get(DataSourceOptions.DATABASE_KEY))) + .get(DataSourceOptions.TABLE_KEY) + 
.map(TableIdentifier(_, options.get(DataSourceOptions.DATABASE_KEY))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index c8494f97f176..04a97735d024 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution.datasources.v2 -import scala.collection.JavaConverters._ - import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -28,8 +26,7 @@ import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeSta import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.DataSourceV2 import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader -import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousPartitionReaderFactory, ContinuousReadSupport, MicroBatchReadSupport} /** * Physical plan node for scanning data from a data source. @@ -39,7 +36,8 @@ case class DataSourceV2ScanExec( @transient source: DataSourceV2, @transient options: Map[String, String], @transient pushedFilters: Seq[Expression], - @transient reader: DataSourceReader) + @transient readSupport: ReadSupport, + @transient scanConfig: ScanConfig) extends LeafExecNode with DataSourceV2StringFormat with ColumnarBatchScan { override def simpleString: String = "ScanV2 " + metadataString @@ -47,7 +45,8 @@ case class DataSourceV2ScanExec( // TODO: unify the equal/hashCode implementation for all data source v2 query plans. 
override def equals(other: Any): Boolean = other match { case other: DataSourceV2ScanExec => - output == other.output && reader.getClass == other.reader.getClass && options == other.options + output == other.output && readSupport.getClass == other.readSupport.getClass && + options == other.options case _ => false } @@ -55,36 +54,39 @@ case class DataSourceV2ScanExec( Seq(output, source, options).hashCode() } - override def outputPartitioning: physical.Partitioning = reader match { - case r: SupportsScanColumnarBatch if r.enableBatchRead() && batchPartitions.size == 1 => - SinglePartition - - case r: SupportsScanColumnarBatch if !r.enableBatchRead() && partitions.size == 1 => - SinglePartition - - case r if !r.isInstanceOf[SupportsScanColumnarBatch] && partitions.size == 1 => + override def outputPartitioning: physical.Partitioning = readSupport match { + case _ if partitions.length == 1 => SinglePartition case s: SupportsReportPartitioning => new DataSourcePartitioning( - s.outputPartitioning(), AttributeMap(output.map(a => a -> a.name))) + s.outputPartitioning(scanConfig), AttributeMap(output.map(a => a -> a.name))) case _ => super.outputPartitioning } - private lazy val partitions: Seq[InputPartition[InternalRow]] = { - reader.planInputPartitions().asScala + private lazy val partitions: Seq[InputPartition] = readSupport.planInputPartitions(scanConfig) + + private lazy val readerFactory = readSupport match { + case r: BatchReadSupport => r.createReaderFactory(scanConfig) + case r: MicroBatchReadSupport => r.createReaderFactory(scanConfig) + case r: ContinuousReadSupport => r.createContinuousReaderFactory(scanConfig) + case _ => throw new IllegalStateException("unknown read support: " + readSupport) } - private lazy val batchPartitions: Seq[InputPartition[ColumnarBatch]] = reader match { - case r: SupportsScanColumnarBatch if r.enableBatchRead() => - assert(!reader.isInstanceOf[ContinuousReader], - "continuous stream reader does not support columnar read yet.") - r.planBatchInputPartitions().asScala + // TODO: clean this up when we have dedicated scan plan for continuous streaming. 
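Editor's note: a minimal, hypothetical PartitionReaderFactory to illustrate the row-based path selected above. supportColumnarReads is assumed to default to false, so supportsBatch would be false for every partition produced with this factory.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.sources.v2.reader.{InputPartition, PartitionReader, PartitionReaderFactory}

// Row-based factory: each partition yields exactly one row containing the long value 1.
class SingleRowReaderFactory extends PartitionReaderFactory {
  override def createReader(partition: InputPartition): PartitionReader[InternalRow] =
    new PartitionReader[InternalRow] {
      private var consumed = false
      override def next(): Boolean = if (consumed) false else { consumed = true; true }
      override def get(): InternalRow = new GenericInternalRow(Array[Any](1L))
      override def close(): Unit = ()
    }
}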
+ override val supportsBatch: Boolean = { + require(partitions.forall(readerFactory.supportColumnarReads) || + !partitions.exists(readerFactory.supportColumnarReads), + "Cannot mix row-based and columnar input partitions.") + + partitions.exists(readerFactory.supportColumnarReads) } - private lazy val inputRDD: RDD[InternalRow] = reader match { - case _: ContinuousReader => + private lazy val inputRDD: RDD[InternalRow] = readSupport match { + case _: ContinuousReadSupport => + assert(!supportsBatch, + "continuous stream reader does not support columnar read yet.") EpochCoordinatorRef.get( sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), sparkContext.env) @@ -93,22 +95,17 @@ case class DataSourceV2ScanExec( sparkContext, sqlContext.conf.continuousStreamingExecutorQueueSize, sqlContext.conf.continuousStreamingExecutorPollIntervalMs, - partitions).asInstanceOf[RDD[InternalRow]] - - case r: SupportsScanColumnarBatch if r.enableBatchRead() => - new DataSourceRDD(sparkContext, batchPartitions).asInstanceOf[RDD[InternalRow]] + partitions, + schema, + readerFactory.asInstanceOf[ContinuousPartitionReaderFactory]) case _ => - new DataSourceRDD(sparkContext, partitions).asInstanceOf[RDD[InternalRow]] + new DataSourceRDD( + sparkContext, partitions, readerFactory.asInstanceOf[PartitionReaderFactory], supportsBatch) } override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD) - override val supportsBatch: Boolean = reader match { - case r: SupportsScanColumnarBatch if r.enableBatchRead() => true - case _ => false - } - override protected def needsUnsafeRowConversion: Boolean = false override protected def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 6daaa4c65c33..fe713ff6c785 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -26,8 +26,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, Rep import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec} -import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownCatalystFilters, SupportsPushDownFilters, SupportsPushDownRequiredColumns} -import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReadSupport object DataSourceV2Strategy extends Strategy { @@ -37,9 +37,9 @@ object DataSourceV2Strategy extends Strategy { * @return pushed filter and post-scan filters. */ private def pushFilters( - reader: DataSourceReader, + configBuilder: ScanConfigBuilder, filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { - reader match { + configBuilder match { case r: SupportsPushDownCatalystFilters => val postScanFilters = r.pushCatalystFilters(filters.toArray) val pushedFilters = r.pushedCatalystFilters() @@ -76,41 +76,43 @@ object DataSourceV2Strategy extends Strategy { /** * Applies column pruning to the data source, w.r.t. 
the references of the given expressions. * - * @return new output attributes after column pruning. + * @return the created `ScanConfig`(since column pruning is the last step of operator pushdown), + * and new output attributes after column pruning. */ // TODO: nested column pruning. private def pruneColumns( - reader: DataSourceReader, + configBuilder: ScanConfigBuilder, relation: DataSourceV2Relation, - exprs: Seq[Expression]): Seq[AttributeReference] = { - reader match { + exprs: Seq[Expression]): (ScanConfig, Seq[AttributeReference]) = { + configBuilder match { case r: SupportsPushDownRequiredColumns => val requiredColumns = AttributeSet(exprs.flatMap(_.references)) val neededOutput = relation.output.filter(requiredColumns.contains) if (neededOutput != relation.output) { r.pruneColumns(neededOutput.toStructType) + val config = r.build() val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap - r.readSchema().toAttributes.map { + config -> config.readSchema().toAttributes.map { // We have to keep the attribute id during transformation. a => a.withExprId(nameToAttr(a.name).exprId) } } else { - relation.output + r.build() -> relation.output } - case _ => relation.output + case _ => configBuilder.build() -> relation.output } } override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(project, filters, relation: DataSourceV2Relation) => - val reader = relation.newReader() + val configBuilder = relation.readSupport.newScanConfigBuilder() // `pushedFilters` will be pushed down and evaluated in the underlying data sources. // `postScanFilters` need to be evaluated after the scan. // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter. - val (pushedFilters, postScanFilters) = pushFilters(reader, filters) - val output = pruneColumns(reader, relation, project ++ postScanFilters) + val (pushedFilters, postScanFilters) = pushFilters(configBuilder, filters) + val (config, output) = pruneColumns(configBuilder, relation, project ++ postScanFilters) logInfo( s""" |Pushing operators to ${relation.source.getClass} @@ -120,7 +122,12 @@ object DataSourceV2Strategy extends Strategy { """.stripMargin) val scan = DataSourceV2ScanExec( - output, relation.source, relation.options, pushedFilters, reader) + output, + relation.source, + relation.options, + pushedFilters, + relation.readSupport, + config) val filterCondition = postScanFilters.reduceLeftOption(And) val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan) @@ -129,22 +136,26 @@ object DataSourceV2Strategy extends Strategy { ProjectExec(project, withFilter) :: Nil case r: StreamingDataSourceV2Relation => + // TODO: support operator pushdown for streaming data sources. 
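Editor's note: to make the new pushdown flow concrete, a hypothetical ScanConfigBuilder that accepts filter pushdown and column pruning and freezes both in build(). It assumes the pushdown mix-ins now attach to ScanConfigBuilder (as the strategy above implies), that SupportsPushDownFilters keeps its Filter[]-based signature from the old reader API, and that ScanConfig only requires readSchema().

import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.types.StructType

class SimpleScanConfigBuilder(fullSchema: StructType)
  extends ScanConfigBuilder
  with SupportsPushDownFilters
  with SupportsPushDownRequiredColumns {

  private var prunedSchema: StructType = fullSchema
  private var pushed: Array[Filter] = Array.empty

  override def pushFilters(filters: Array[Filter]): Array[Filter] = {
    pushed = filters   // pretend the source can evaluate everything
    Array.empty        // nothing left for Spark to re-evaluate after the scan
  }
  override def pushedFilters(): Array[Filter] = pushed

  override def pruneColumns(requiredSchema: StructType): Unit = prunedSchema = requiredSchema

  override def build(): ScanConfig = new ScanConfig {
    override def readSchema(): StructType = prunedSchema
  }
}

Column pruning is the last pushdown step in the strategy above, so build() is only called once the builder has seen both the pushed filters and the required columns.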
+ val scanConfig = r.scanConfigBuilder.build() // ensure there is a projection, which will produce unsafe rows required by some operators ProjectExec(r.output, - DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader)) :: Nil + DataSourceV2ScanExec( + r.output, r.source, r.options, r.pushedFilters, r.readSupport, scanConfig)) :: Nil case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil case AppendData(r: DataSourceV2Relation, query, _) => - WriteToDataSourceV2Exec(r.newWriter(), planLater(query)) :: Nil + WriteToDataSourceV2Exec(r.newWriteSupport(), planLater(query)) :: Nil case WriteToContinuousDataSource(writer, query) => WriteToContinuousDataSourceExec(writer, planLater(query)) :: Nil case Repartition(1, false, child) => - val isContinuous = child.collectFirst { - case StreamingDataSourceV2Relation(_, _, _, r: ContinuousReader) => r + val isContinuous = child.find { + case s: StreamingDataSourceV2Relation => s.readSupport.isInstanceOf[ContinuousReadSupport] + case _ => false }.isDefined if (isContinuous) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala index 5267f5f1580c..e9cc3991155c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala @@ -21,6 +21,7 @@ import java.util.regex.Pattern import org.apache.spark.internal.Logging import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2.{DataSourceV2, SessionConfigSupport} private[sql] object DataSourceV2Utils extends Logging { @@ -55,4 +56,12 @@ private[sql] object DataSourceV2Utils extends Logging { case _ => Map.empty } + + def failForUserSpecifiedSchema[T](ds: DataSourceV2): T = { + val name = ds match { + case register: DataSourceRegister => register.shortName() + case _ => ds.getClass.getName + } + throw new UnsupportedOperationException(name + " source does not support user-specified schema") + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 59ebb9bc5431..c3f7b690ef63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -23,15 +23,11 @@ import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.executor.CommitDeniedException import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.streaming.MicroBatchExecution import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils /** @@ -39,7 +35,8 @@ import org.apache.spark.util.Utils * specific logical plans, like 
[[org.apache.spark.sql.catalyst.plans.logical.AppendData]]. */ @deprecated("Use specific logical plans like AppendData instead", "2.4.0") -case class WriteToDataSourceV2(writer: DataSourceWriter, query: LogicalPlan) extends LogicalPlan { +case class WriteToDataSourceV2(writeSupport: BatchWriteSupport, query: LogicalPlan) + extends LogicalPlan { override def children: Seq[LogicalPlan] = Seq(query) override def output: Seq[Attribute] = Nil } @@ -47,46 +44,48 @@ case class WriteToDataSourceV2(writer: DataSourceWriter, query: LogicalPlan) ext /** * The physical plan for writing data into data source v2. */ -case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) extends SparkPlan { +case class WriteToDataSourceV2Exec(writeSupport: BatchWriteSupport, query: SparkPlan) + extends SparkPlan { + override def children: Seq[SparkPlan] = Seq(query) override def output: Seq[Attribute] = Nil override protected def doExecute(): RDD[InternalRow] = { - val writeTask = writer.createWriterFactory() - val useCommitCoordinator = writer.useCommitCoordinator + val writerFactory = writeSupport.createBatchWriterFactory() + val useCommitCoordinator = writeSupport.useCommitCoordinator val rdd = query.execute() val messages = new Array[WriterCommitMessage](rdd.partitions.length) - logInfo(s"Start processing data source writer: $writer. " + + logInfo(s"Start processing data source write support: $writeSupport. " + s"The input RDD has ${messages.length} partitions.") try { sparkContext.runJob( rdd, (context: TaskContext, iter: Iterator[InternalRow]) => - DataWritingSparkTask.run(writeTask, context, iter, useCommitCoordinator), + DataWritingSparkTask.run(writerFactory, context, iter, useCommitCoordinator), rdd.partitions.indices, (index, message: WriterCommitMessage) => { messages(index) = message - writer.onDataWriterCommit(message) + writeSupport.onDataWriterCommit(message) } ) - logInfo(s"Data source writer $writer is committing.") - writer.commit(messages) - logInfo(s"Data source writer $writer committed.") + logInfo(s"Data source write support $writeSupport is committing.") + writeSupport.commit(messages) + logInfo(s"Data source write support $writeSupport committed.") } catch { case cause: Throwable => - logError(s"Data source writer $writer is aborting.") + logError(s"Data source write support $writeSupport is aborting.") try { - writer.abort(messages) + writeSupport.abort(messages) } catch { case t: Throwable => - logError(s"Data source writer $writer failed to abort.") + logError(s"Data source write support $writeSupport failed to abort.") cause.addSuppressed(t) throw new SparkException("Writing job failed.", cause) } - logError(s"Data source writer $writer aborted.") + logError(s"Data source write support $writeSupport aborted.") cause match { // Only wrap non fatal exceptions. 
case NonFatal(e) => throw new SparkException("Writing job aborted.", e) @@ -100,7 +99,7 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e object DataWritingSparkTask extends Logging { def run( - writeTask: DataWriterFactory[InternalRow], + writerFactory: DataWriterFactory, context: TaskContext, iter: Iterator[InternalRow], useCommitCoordinator: Boolean): WriterCommitMessage = { @@ -109,8 +108,7 @@ object DataWritingSparkTask extends Logging { val partId = context.partitionId() val taskId = context.taskAttemptId() val attemptId = context.attemptNumber() - val epochId = Option(context.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY)).getOrElse("0") - val dataWriter = writeTask.createDataWriter(partId, taskId, epochId.toLong) + val dataWriter = writerFactory.createWriter(partId, taskId) // write the data and commit this writer. Utils.tryWithSafeFinallyAndFailureCallbacks(block = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index b1c91ac94b26..cf83ba7436d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution.streaming -import java.util.Optional - import scala.collection.JavaConverters._ import scala.collection.mutable.{Map => MutableMap} @@ -28,9 +26,9 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentBatchTimestamp, import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} -import org.apache.spark.sql.execution.streaming.sources.MicroBatchWriter -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} +import org.apache.spark.sql.execution.streaming.sources.{MicroBatchWritSupport, RateControlMicroBatchReadSupport} +import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset => OffsetV2} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.{Clock, Utils} @@ -51,8 +49,8 @@ class MicroBatchExecution( @volatile protected var sources: Seq[BaseStreamingSource] = Seq.empty - private val readerToDataSourceMap = - MutableMap.empty[MicroBatchReader, (DataSourceV2, Map[String, String])] + private val readSupportToDataSourceMap = + MutableMap.empty[MicroBatchReadSupport, (DataSourceV2, Map[String, String])] private val triggerExecutor = trigger match { case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock) @@ -91,20 +89,19 @@ class MicroBatchExecution( StreamingExecutionRelation(source, output)(sparkSession) }) case s @ StreamingRelationV2( - dataSourceV2: MicroBatchReadSupport, sourceName, options, output, _) if + dataSourceV2: MicroBatchReadSupportProvider, sourceName, options, output, _) if !disabledSources.contains(dataSourceV2.getClass.getCanonicalName) => v2ToExecutionRelationMap.getOrElseUpdate(s, { // Materialize source to avoid creating it in every batch val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" - 
val reader = dataSourceV2.createMicroBatchReader( - Optional.empty(), // user specified schema + val readSupport = dataSourceV2.createMicroBatchReadSupport( metadataPath, new DataSourceOptions(options.asJava)) nextSourceId += 1 - readerToDataSourceMap(reader) = dataSourceV2 -> options - logInfo(s"Using MicroBatchReader [$reader] from " + + readSupportToDataSourceMap(readSupport) = dataSourceV2 -> options + logInfo(s"Using MicroBatchReadSupport [$readSupport] from " + s"DataSourceV2 named '$sourceName' [$dataSourceV2]") - StreamingExecutionRelation(reader, output)(sparkSession) + StreamingExecutionRelation(readSupport, output)(sparkSession) }) case s @ StreamingRelationV2(dataSourceV2, sourceName, _, output, v1Relation) => v2ToExecutionRelationMap.getOrElseUpdate(s, { @@ -340,19 +337,19 @@ class MicroBatchExecution( reportTimeTaken("getOffset") { (s, s.getOffset) } - case s: MicroBatchReader => + case s: RateControlMicroBatchReadSupport => updateStatusMessage(s"Getting offsets from $s") - reportTimeTaken("setOffsetRange") { - // Once v1 streaming source execution is gone, we can refactor this away. - // For now, we set the range here to get the source to infer the available end offset, - // get that offset, and then set the range again when we later execute. - s.setOffsetRange( - toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))), - Optional.empty()) + reportTimeTaken("latestOffset") { + val startOffset = availableOffsets + .get(s).map(off => s.deserializeOffset(off.json)) + .getOrElse(s.initialOffset()) + (s, Option(s.latestOffset(startOffset))) + } + case s: MicroBatchReadSupport => + updateStatusMessage(s"Getting offsets from $s") + reportTimeTaken("latestOffset") { + (s, Option(s.latestOffset())) } - - val currentOffset = reportTimeTaken("getEndOffset") { s.getEndOffset() } - (s, Option(currentOffset)) }.toMap availableOffsets ++= latestOffsets.filter { case (_, o) => o.nonEmpty }.mapValues(_.get) @@ -392,8 +389,8 @@ class MicroBatchExecution( if (prevBatchOff.isDefined) { prevBatchOff.get.toStreamProgress(sources).foreach { case (src: Source, off) => src.commit(off) - case (reader: MicroBatchReader, off) => - reader.commit(reader.deserializeOffset(off.json)) + case (readSupport: MicroBatchReadSupport, off) => + readSupport.commit(readSupport.deserializeOffset(off.json)) case (src, _) => throw new IllegalArgumentException( s"Unknown source is found at constructNextBatch: $src") @@ -437,30 +434,34 @@ class MicroBatchExecution( s"${batch.queryExecution.logical}") logDebug(s"Retrieving data from $source: $current -> $available") Some(source -> batch.logicalPlan) - case (reader: MicroBatchReader, available) - if committedOffsets.get(reader).map(_ != available).getOrElse(true) => - val current = committedOffsets.get(reader).map(off => reader.deserializeOffset(off.json)) - val availableV2: OffsetV2 = available match { - case v1: SerializedOffset => reader.deserializeOffset(v1.json) + + // TODO(cloud-fan): for data source v2, the new batch is just a new `ScanConfigBuilder`, but + // to be compatible with streaming source v1, we return a logical plan as a new batch here. 
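Editor's note: condensing the offset handling above, the per-trigger handshake with a MicroBatchReadSupport looks roughly like the sketch below (runOneBatch and its arguments are ours; checkpointing, rate control and error handling are omitted).

import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset}

def runOneBatch(readSupport: MicroBatchReadSupport, lastCheckpointedJson: Option[String]): Unit = {
  val start: Offset = lastCheckpointedJson
    .map(json => readSupport.deserializeOffset(json))  // restart: resume from the checkpointed offset
    .getOrElse(readSupport.initialOffset())             // first run: ask the source where to begin
  val end: Offset = readSupport.latestOffset()           // how far this micro-batch should read

  val scanConfig = readSupport.newScanConfigBuilder(start, end).build()
  val partitions = readSupport.planInputPartitions(scanConfig)
  // ... run the Spark job over `partitions` using readSupport.createReaderFactory(scanConfig) ...

  readSupport.commit(end)   // tell the source that data up to `end` has been processed
}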
+ case (readSupport: MicroBatchReadSupport, available) + if committedOffsets.get(readSupport).map(_ != available).getOrElse(true) => + val current = committedOffsets.get(readSupport).map { + off => readSupport.deserializeOffset(off.json) + } + val endOffset: OffsetV2 = available match { + case v1: SerializedOffset => readSupport.deserializeOffset(v1.json) case v2: OffsetV2 => v2 } - reader.setOffsetRange( - toJava(current), - Optional.of(availableV2)) - logDebug(s"Retrieving data from $reader: $current -> $availableV2") + val startOffset = current.getOrElse(readSupport.initialOffset) + val scanConfigBuilder = readSupport.newScanConfigBuilder(startOffset, endOffset) + logDebug(s"Retrieving data from $readSupport: $current -> $endOffset") - val (source, options) = reader match { + val (source, options) = readSupport match { // `MemoryStream` is special. It's for test only and doesn't have a `DataSourceV2` // implementation. We provide a fake one here for explain. case _: MemoryStream[_] => MemoryStreamDataSource -> Map.empty[String, String] // Provide a fake value here just in case something went wrong, e.g. the reader gives // a wrong `equals` implementation. - case _ => readerToDataSourceMap.getOrElse(reader, { + case _ => readSupportToDataSourceMap.getOrElse(readSupport, { FakeDataSourceV2 -> Map.empty[String, String] }) } - Some(reader -> StreamingDataSourceV2Relation( - reader.readSchema().toAttributes, source, options, reader)) + Some(readSupport -> StreamingDataSourceV2Relation( + readSupport.fullSchema().toAttributes, source, options, readSupport, scanConfigBuilder)) case _ => None } } @@ -494,13 +495,13 @@ class MicroBatchExecution( val triggerLogicalPlan = sink match { case _: Sink => newAttributePlan - case s: StreamWriteSupport => - val writer = s.createStreamWriter( + case s: StreamingWriteSupportProvider => + val writer = s.createStreamingWriteSupport( s"$runId", newAttributePlan.schema, outputMode, new DataSourceOptions(extraOptions.asJava)) - WriteToDataSourceV2(new MicroBatchWriter(currentBatchId, writer), newAttributePlan) + WriteToDataSourceV2(new MicroBatchWritSupport(currentBatchId, writer), newAttributePlan) case _ => throw new IllegalArgumentException(s"unknown sink type for $sink") } @@ -526,7 +527,7 @@ class MicroBatchExecution( SQLExecution.withNewExecutionId(sparkSessionToRunBatch, lastExecution) { sink match { case s: Sink => s.addBatch(currentBatchId, nextBatch) - case _: StreamWriteSupport => + case _: StreamingWriteSupportProvider => // This doesn't accumulate any data - it just forces execution of the microbatch writer. 
nextBatch.collect() } @@ -551,10 +552,6 @@ class MicroBatchExecution( awaitProgressLock.unlock() } } - - private def toJava(scalaOption: Option[OffsetV2]): Optional[OffsetV2] = { - Optional.ofNullable(scalaOption.orNull) - } } object MicroBatchExecution { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index ae1bfa2e499b..417b6b39366a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -24,7 +24,6 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.control.NonFatal -import org.json4s.JsonAST.JValue import org.json4s.jackson.JsonMethods.parse import org.apache.spark.internal.Logging @@ -33,11 +32,10 @@ import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalP import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, WriteToDataSourceV2Exec} -import org.apache.spark.sql.execution.streaming.sources.MicroBatchWriter +import org.apache.spark.sql.execution.streaming.sources.MicroBatchWritSupport import org.apache.spark.sql.sources.v2.CustomMetrics -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, SupportsCustomReaderMetrics} -import org.apache.spark.sql.sources.v2.writer.DataSourceWriter -import org.apache.spark.sql.sources.v2.writer.streaming.SupportsCustomWriterMetrics +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, SupportsCustomReaderMetrics} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingWriteSupport, SupportsCustomWriterMetrics} import org.apache.spark.sql.streaming._ import org.apache.spark.sql.streaming.StreamingQueryListener.QueryProgressEvent import org.apache.spark.util.Clock @@ -201,7 +199,7 @@ trait ProgressReporter extends Logging { ) } - val customWriterMetrics = dataSourceWriter match { + val customWriterMetrics = extractWriteSupport() match { case Some(s: SupportsCustomWriterMetrics) => extractMetrics(() => Option(s.getCustomMetrics), s.onInvalidMetrics) @@ -238,13 +236,13 @@ trait ProgressReporter extends Logging { } /** Extract writer from the executed query plan. 
*/ - private def dataSourceWriter: Option[DataSourceWriter] = { + private def extractWriteSupport(): Option[StreamingWriteSupport] = { if (lastExecution == null) return None lastExecution.executedPlan.collect { case p if p.isInstanceOf[WriteToDataSourceV2Exec] => - p.asInstanceOf[WriteToDataSourceV2Exec].writer + p.asInstanceOf[WriteToDataSourceV2Exec].writeSupport }.headOption match { - case Some(w: MicroBatchWriter) => Some(w.writer) + case Some(w: MicroBatchWritSupport) => Some(w.writeSupport) case _ => None } } @@ -303,7 +301,7 @@ trait ProgressReporter extends Logging { // Check whether the streaming query's logical plan has only V2 data sources val allStreamingLeaves = logicalPlan.collect { case s: StreamingExecutionRelation => s } - allStreamingLeaves.forall { _.source.isInstanceOf[MicroBatchReader] } + allStreamingLeaves.forall { _.source.isInstanceOf[MicroBatchReadSupport] } } if (onlyDataSourceV2Sources) { @@ -330,7 +328,7 @@ trait ProgressReporter extends Logging { new IdentityHashMap[DataSourceV2ScanExec, DataSourceV2ScanExec]() lastExecution.executedPlan.collectLeaves().foreach { - case s: DataSourceV2ScanExec if s.reader.isInstanceOf[BaseStreamingSource] => + case s: DataSourceV2ScanExec if s.readSupport.isInstanceOf[BaseStreamingSource] => uniqueStreamingExecLeavesMap.put(s, s) case _ => } @@ -338,7 +336,7 @@ trait ProgressReporter extends Logging { val sourceToInputRowsTuples = uniqueStreamingExecLeavesMap.values.asScala.map { execLeaf => val numRows = execLeaf.metrics.get("numOutputRows").map(_.value).getOrElse(0L) - val source = execLeaf.reader.asInstanceOf[BaseStreamingSource] + val source = execLeaf.readSupport.asInstanceOf[BaseStreamingSource] source -> numRows }.toSeq logDebug("Source -> # input rows\n\t" + sourceToInputRowsTuples.mkString("\n\t")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SimpleStreamingScanConfigBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SimpleStreamingScanConfigBuilder.scala new file mode 100644 index 000000000000..1be071614d92 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SimpleStreamingScanConfigBuilder.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.sql.sources.v2.reader.{ScanConfig, ScanConfigBuilder} +import org.apache.spark.sql.types.StructType + +/** + * A very simple [[ScanConfigBuilder]] implementation that creates a simple [[ScanConfig]] to + * carry schema and offsets for streaming data sources. 
+ */ +class SimpleStreamingScanConfigBuilder( + schema: StructType, + start: Offset, + end: Option[Offset] = None) + extends ScanConfigBuilder { + + override def build(): ScanConfig = SimpleStreamingScanConfig(schema, start, end) +} + +case class SimpleStreamingScanConfig( + readSchema: StructType, + start: Offset, + end: Option[Offset]) + extends ScanConfig diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index 24195b5657e8..4b696dfa5735 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceV2} object StreamingRelation { def apply(dataSource: DataSource): StreamingRelation = { @@ -83,7 +83,7 @@ case class StreamingExecutionRelation( // We have to pack in the V1 data source as a shim, for the case when a source implements // continuous processing (which is always V2) but only has V1 microbatch support. We don't -// know at read time whether the query is conntinuous or not, so we need to be able to +// know at read time whether the query is continuous or not, so we need to be able to // swap a V1 relation back in. /** * Used to link a [[DataSourceV2]] into a streaming @@ -113,7 +113,7 @@ case class StreamingRelationV2( * Used to link a [[DataSourceV2]] into a continuous processing execution. 
*/ case class ContinuousExecutionRelation( - source: ContinuousReadSupport, + source: ContinuousReadSupportProvider, extraOptions: Map[String, String], output: Seq[Attribute])(session: SparkSession) extends LeafNode with MultiInstanceRelation { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index cfba1001c6de..9c5c16f4f5d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql._ -import org.apache.spark.sql.execution.streaming.sources.ConsoleWriter +import org.apache.spark.sql.execution.streaming.sources.ConsoleWriteSupport import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamingWriteSupportProvider} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -31,16 +31,16 @@ case class ConsoleRelation(override val sqlContext: SQLContext, data: DataFrame) } class ConsoleSinkProvider extends DataSourceV2 - with StreamWriteSupport + with StreamingWriteSupportProvider with DataSourceRegister with CreatableRelationProvider { - override def createStreamWriter( + override def createStreamingWriteSupport( queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceOptions): StreamWriter = { - new ConsoleWriter(schema, options) + options: DataSourceOptions): StreamingWriteSupport = { + new ConsoleWriteSupport(schema, options) } def createRelation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala index 554a0b0573f4..b68f67e0b22d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala @@ -21,12 +21,13 @@ import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousInputPartitionReader +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousPartitionReaderFactory +import org.apache.spark.sql.types.StructType import org.apache.spark.util.NextIterator class ContinuousDataSourceRDDPartition( val index: Int, - val inputPartition: InputPartition[InternalRow]) + val inputPartition: InputPartition) extends Partition with Serializable { // This is semantically a lazy val - it's initialized once the first time a call to @@ -49,15 +50,22 @@ class ContinuousDataSourceRDD( sc: SparkContext, dataQueueSize: Int, epochPollIntervalMs: Long, - private val readerInputPartitions: Seq[InputPartition[InternalRow]]) + private val inputPartitions: Seq[InputPartition], + schema: StructType, + partitionReaderFactory: ContinuousPartitionReaderFactory) extends 
RDD[InternalRow](sc, Nil) { override protected def getPartitions: Array[Partition] = { - readerInputPartitions.zipWithIndex.map { + inputPartitions.zipWithIndex.map { case (inputPartition, index) => new ContinuousDataSourceRDDPartition(index, inputPartition) }.toArray } + private def castPartition(split: Partition): ContinuousDataSourceRDDPartition = split match { + case p: ContinuousDataSourceRDDPartition => p + case _ => throw new SparkException(s"[BUG] Not a ContinuousDataSourceRDDPartition: $split") + } + /** * Initialize the shared reader for this partition if needed, then read rows from it until * it returns null to signal the end of the epoch. @@ -69,10 +77,12 @@ class ContinuousDataSourceRDD( } val readerForPartition = { - val partition = split.asInstanceOf[ContinuousDataSourceRDDPartition] + val partition = castPartition(split) if (partition.queueReader == null) { - partition.queueReader = - new ContinuousQueuedDataReader(partition, context, dataQueueSize, epochPollIntervalMs) + val partitionReader = partitionReaderFactory.createReader( + partition.inputPartition) + partition.queueReader = new ContinuousQueuedDataReader( + partition.index, partitionReader, schema, context, dataQueueSize, epochPollIntervalMs) } partition.queueReader @@ -93,17 +103,6 @@ class ContinuousDataSourceRDD( } override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[ContinuousDataSourceRDDPartition].inputPartition.preferredLocations() - } -} - -object ContinuousDataSourceRDD { - private[continuous] def getContinuousReader( - reader: InputPartitionReader[InternalRow]): ContinuousInputPartitionReader[_] = { - reader match { - case r: ContinuousInputPartitionReader[InternalRow] => r - case _ => - throw new IllegalStateException(s"Unknown continuous reader type ${reader.getClass}") - } + castPartition(split).inputPartition.preferredLocations() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 140cec64fffb..4ddebb33b79d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -29,13 +29,12 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, StreamingDataSourceV2Relation} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} import org.apache.spark.sql.sources.v2 -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, StreamingWriteSupportProvider} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, PartitionOffset} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} -import org.apache.spark.sql.types.StructType import 
org.apache.spark.util.{Clock, Utils} class ContinuousExecution( @@ -43,7 +42,7 @@ class ContinuousExecution( name: String, checkpointRoot: String, analyzedPlan: LogicalPlan, - sink: StreamWriteSupport, + sink: StreamingWriteSupportProvider, trigger: Trigger, triggerClock: Clock, outputMode: OutputMode, @@ -53,7 +52,7 @@ class ContinuousExecution( sparkSession, name, checkpointRoot, analyzedPlan, sink, trigger, triggerClock, outputMode, deleteCheckpointOnStop) { - @volatile protected var continuousSources: Seq[ContinuousReader] = Seq() + @volatile protected var continuousSources: Seq[ContinuousReadSupport] = Seq() override protected def sources: Seq[BaseStreamingSource] = continuousSources // For use only in test harnesses. @@ -63,7 +62,8 @@ class ContinuousExecution( val toExecutionRelationMap = MutableMap[StreamingRelationV2, ContinuousExecutionRelation]() analyzedPlan.transform { case r @ StreamingRelationV2( - source: ContinuousReadSupport, _, extraReaderOptions, output, _) => + source: ContinuousReadSupportProvider, _, extraReaderOptions, output, _) => + // TODO: shall we create `ContinuousReadSupport` here instead of each reconfiguration? toExecutionRelationMap.getOrElseUpdate(r, { ContinuousExecutionRelation(source, extraReaderOptions, output)(sparkSession) }) @@ -148,8 +148,7 @@ class ContinuousExecution( val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" nextSourceId += 1 - dataSource.createContinuousReader( - java.util.Optional.empty[StructType](), + dataSource.createContinuousReadSupport( metadataPath, new DataSourceOptions(extraReaderOptions.asJava)) } @@ -160,9 +159,9 @@ class ContinuousExecution( var insertedSourceId = 0 val withNewSources = logicalPlan transform { case ContinuousExecutionRelation(source, options, output) => - val reader = continuousSources(insertedSourceId) + val readSupport = continuousSources(insertedSourceId) insertedSourceId += 1 - val newOutput = reader.readSchema().toAttributes + val newOutput = readSupport.fullSchema().toAttributes assert(output.size == newOutput.size, s"Invalid reader: ${Utils.truncatedString(output, ",")} != " + @@ -170,9 +169,10 @@ class ContinuousExecution( replacements ++= output.zip(newOutput) val loggedOffset = offsets.offsets(0) - val realOffset = loggedOffset.map(off => reader.deserializeOffset(off.json)) - reader.setStartOffset(java.util.Optional.ofNullable(realOffset.orNull)) - StreamingDataSourceV2Relation(newOutput, source, options, reader) + val realOffset = loggedOffset.map(off => readSupport.deserializeOffset(off.json)) + val startOffset = realOffset.getOrElse(readSupport.initialOffset) + val scanConfigBuilder = readSupport.newScanConfigBuilder(startOffset) + StreamingDataSourceV2Relation(newOutput, source, options, readSupport, scanConfigBuilder) } // Rewire the plan to use the new attributes that were returned by the source. 
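(Illustrative sketch, editor's note; not part of the patch.) Pulled together from the read-side hunks in this file and from the source implementations further below, a continuous source under the refactored API has roughly the following shape. The trait and method signatures are copied from the implementations this patch touches; MyOffset, MyPartitionOffset, MyInputPartition and MyReaderFactory are hypothetical placeholders, and the exact signature of needsReconfiguration is inferred from its call site above, so treat this as a sketch rather than code from the patch.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.streaming.{SimpleStreamingScanConfig, SimpleStreamingScanConfigBuilder}
import org.apache.spark.sql.sources.v2.reader.{InputPartition, ScanConfig, ScanConfigBuilder}
import org.apache.spark.sql.sources.v2.reader.streaming._
import org.apache.spark.sql.types.StructType

// Hypothetical offset and partition types used only by this sketch.
case class MyOffset(value: Long) extends Offset {
  override def json(): String = value.toString
}
case class MyPartitionOffset(value: Long) extends PartitionOffset
case class MyInputPartition(start: Long) extends InputPartition

class MyContinuousReadSupport(schema: StructType) extends ContinuousReadSupport {

  // Replaces setStartOffset/getStartOffset: the engine asks the source for a starting
  // point only when there is no checkpointed offset to restore.
  override def initialOffset(): Offset = MyOffset(0L)

  // Replaces readSchema().
  override def fullSchema(): StructType = schema

  override def deserializeOffset(json: String): Offset = MyOffset(json.toLong)

  // Per-scan state (schema plus start offset) now travels in a ScanConfig instead of
  // mutable fields on the reader; SimpleStreamingScanConfigBuilder is the helper this
  // patch adds for exactly that purpose.
  override def newScanConfigBuilder(start: Offset): ScanConfigBuilder =
    new SimpleStreamingScanConfigBuilder(fullSchema(), start)

  // Partition planning reads the start offset back out of the ScanConfig.
  override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
    val start = config.asInstanceOf[SimpleStreamingScanConfig].start.asInstanceOf[MyOffset]
    Array[InputPartition](MyInputPartition(start.value))
  }

  // Reader creation moved from InputPartition.createPartitionReader() into a shared factory.
  override def createContinuousReaderFactory(config: ScanConfig): ContinuousPartitionReaderFactory =
    MyReaderFactory

  // Called by the epoch coordinator with one PartitionOffset per partition.
  override def mergeOffsets(offsets: Array[PartitionOffset]): Offset =
    MyOffset(offsets.collect { case MyPartitionOffset(v) => v }.max)

  // Polled by ContinuousExecution (see the hunk above) to decide whether to replan the query.
  override def needsReconfiguration(config: ScanConfig): Boolean = false

  override def commit(end: Offset): Unit = {}
  override def stop(): Unit = {}
}

object MyReaderFactory extends ContinuousPartitionReaderFactory {
  override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = {
    // A real source would return a ContinuousPartitionReader that emits rows and reports
    // MyPartitionOffset via getOffset; elided here.
    throw new UnsupportedOperationException("sketch only")
  }
}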
@@ -185,17 +185,13 @@ class ContinuousExecution( "CurrentTimestamp and CurrentDate not yet supported for continuous processing") } - val writer = sink.createStreamWriter( + val writer = sink.createStreamingWriteSupport( s"$runId", triggerLogicalPlan.schema, outputMode, new DataSourceOptions(extraOptions.asJava)) val withSink = WriteToContinuousDataSource(writer, triggerLogicalPlan) - val reader = withSink.collect { - case StreamingDataSourceV2Relation(_, _, _, r: ContinuousReader) => r - }.head - reportTimeTaken("queryPlanning") { lastExecution = new IncrementalExecution( sparkSessionForQuery, @@ -208,6 +204,11 @@ class ContinuousExecution( lastExecution.executedPlan // Force the lazy generation of execution plan } + val (readSupport, scanConfig) = lastExecution.executedPlan.collect { + case scan: DataSourceV2ScanExec if scan.readSupport.isInstanceOf[ContinuousReadSupport] => + scan.readSupport.asInstanceOf[ContinuousReadSupport] -> scan.scanConfig + }.head + sparkSessionForQuery.sparkContext.setLocalProperty( ContinuousExecution.START_EPOCH_KEY, currentBatchId.toString) // Add another random ID on top of the run ID, to distinguish epoch coordinators across @@ -223,14 +224,16 @@ class ContinuousExecution( // Use the parent Spark session for the endpoint since it's where this query ID is registered. val epochEndpoint = EpochCoordinatorRef.create( - writer, reader, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get) + writer, readSupport, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get) val epochUpdateThread = new Thread(new Runnable { override def run: Unit = { try { triggerExecutor.execute(() => { startTrigger() - if (reader.needsReconfiguration() && state.compareAndSet(ACTIVE, RECONFIGURING)) { + val shouldReconfigure = readSupport.needsReconfiguration(scanConfig) && + state.compareAndSet(ACTIVE, RECONFIGURING) + if (shouldReconfigure) { if (queryExecutionThread.isAlive) { queryExecutionThread.interrupt() } @@ -280,10 +283,12 @@ class ContinuousExecution( * Report ending partition offsets for the given reader at the given epoch. 
*/ def addOffset( - epoch: Long, reader: ContinuousReader, partitionOffsets: Seq[PartitionOffset]): Unit = { + epoch: Long, + readSupport: ContinuousReadSupport, + partitionOffsets: Seq[PartitionOffset]): Unit = { assert(continuousSources.length == 1, "only one continuous source supported currently") - val globalOffset = reader.mergeOffsets(partitionOffsets.toArray) + val globalOffset = readSupport.mergeOffsets(partitionOffsets.toArray) val oldOffset = synchronized { offsetLog.add(epoch, OffsetSeq.fill(globalOffset)) offsetLog.get(epoch - 1) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala index ec1dabd7da3e..65c5fc63c2f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala @@ -25,8 +25,9 @@ import scala.util.control.NonFatal import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} -import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousPartitionReader, PartitionOffset} +import org.apache.spark.sql.types.StructType import org.apache.spark.util.ThreadUtils /** @@ -37,15 +38,14 @@ import org.apache.spark.util.ThreadUtils * offsets across epochs. Each compute() should call the next() method here until null is returned. */ class ContinuousQueuedDataReader( - partition: ContinuousDataSourceRDDPartition, + partitionIndex: Int, + reader: ContinuousPartitionReader[InternalRow], + schema: StructType, context: TaskContext, dataQueueSize: Int, epochPollIntervalMs: Long) extends Closeable { - private val reader = partition.inputPartition.createPartitionReader() - // Important sequencing - we must get our starting point before the provider threads start running - private var currentOffset: PartitionOffset = - ContinuousDataSourceRDD.getContinuousReader(reader).getOffset + private var currentOffset: PartitionOffset = reader.getOffset /** * The record types in the read buffer. @@ -66,7 +66,7 @@ class ContinuousQueuedDataReader( epochMarkerExecutor.scheduleWithFixedDelay( epochMarkerGenerator, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS) - private val dataReaderThread = new DataReaderThread + private val dataReaderThread = new DataReaderThread(schema) dataReaderThread.setDaemon(true) dataReaderThread.start() @@ -113,7 +113,7 @@ class ContinuousQueuedDataReader( currentEntry match { case EpochMarker => epochCoordEndpoint.send(ReportPartitionOffset( - partition.index, EpochTracker.getCurrentEpoch.get, currentOffset)) + partitionIndex, EpochTracker.getCurrentEpoch.get, currentOffset)) null case ContinuousRow(row, offset) => currentOffset = offset @@ -128,16 +128,16 @@ class ContinuousQueuedDataReader( /** * The data component of [[ContinuousQueuedDataReader]]. Pushes (row, offset) to the queue when - * a new row arrives to the [[InputPartitionReader]]. + * a new row arrives to the [[ContinuousPartitionReader]]. 
*/ - class DataReaderThread extends Thread( + class DataReaderThread(schema: StructType) extends Thread( s"continuous-reader--${context.partitionId()}--" + s"${context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)}") with Logging { @volatile private[continuous] var failureReason: Throwable = _ + private val toUnsafe = UnsafeProjection.create(schema) override def run(): Unit = { TaskContext.setTaskContext(context) - val baseReader = ContinuousDataSourceRDD.getContinuousReader(reader) try { while (!shouldStop()) { if (!reader.next()) { @@ -149,8 +149,9 @@ class ContinuousQueuedDataReader( return } } - - queue.put(ContinuousRow(reader.get().copy(), baseReader.getOffset)) + // `InternalRow#copy` may not be properly implemented, for safety we convert to unsafe row + // before copy here. + queue.put(ContinuousRow(toUnsafe(reader.get()).copy(), reader.getOffset)) } } catch { case _: InterruptedException => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 551e07c3db86..a6cde2b8a710 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -17,24 +17,22 @@ package org.apache.spark.sql.execution.streaming.continuous -import scala.collection.JavaConverters._ - import org.json4s.DefaultFormats import org.json4s.jackson.Serialization import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} +import org.apache.spark.sql.execution.streaming.{RateStreamOffset, SimpleStreamingScanConfig, SimpleStreamingScanConfigBuilder, ValueRunTimeMsPair} import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming._ import org.apache.spark.sql.types.StructType case class RateStreamPartitionOffset( partition: Int, currentValue: Long, currentTimeMs: Long) extends PartitionOffset -class RateStreamContinuousReader(options: DataSourceOptions) extends ContinuousReader { +class RateStreamContinuousReadSupport(options: DataSourceOptions) extends ContinuousReadSupport { implicit val defaultFormats: DefaultFormats = DefaultFormats val creationTime = System.currentTimeMillis() @@ -56,18 +54,18 @@ class RateStreamContinuousReader(options: DataSourceOptions) extends ContinuousR RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) } - override def readSchema(): StructType = RateStreamProvider.SCHEMA - - private var offset: Offset = _ + override def fullSchema(): StructType = RateStreamProvider.SCHEMA - override def setStartOffset(offset: java.util.Optional[Offset]): Unit = { - this.offset = offset.orElse(createInitialOffset(numPartitions, creationTime)) + override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = { + new SimpleStreamingScanConfigBuilder(fullSchema(), start) } - override def getStartOffset(): Offset = offset + override def initialOffset: Offset = 
createInitialOffset(numPartitions, creationTime) - override def planInputPartitions(): java.util.List[InputPartition[InternalRow]] = { - val partitionStartMap = offset match { + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val startOffset = config.asInstanceOf[SimpleStreamingScanConfig].start + + val partitionStartMap = startOffset match { case off: RateStreamOffset => off.partitionToValueAndRunTimeMs case off => throw new IllegalArgumentException( @@ -90,8 +88,12 @@ class RateStreamContinuousReader(options: DataSourceOptions) extends ContinuousR i, numPartitions, perPartitionRate) - .asInstanceOf[InputPartition[InternalRow]] - }.asJava + }.toArray + } + + override def createContinuousReaderFactory( + config: ScanConfig): ContinuousPartitionReaderFactory = { + RateStreamContinuousReaderFactory } override def commit(end: Offset): Unit = {} @@ -118,33 +120,23 @@ case class RateStreamContinuousInputPartition( partitionIndex: Int, increment: Long, rowsPerSecond: Double) - extends ContinuousInputPartition[InternalRow] { - - override def createContinuousReader( - offset: PartitionOffset): InputPartitionReader[InternalRow] = { - val rateStreamOffset = offset.asInstanceOf[RateStreamPartitionOffset] - require(rateStreamOffset.partition == partitionIndex, - s"Expected partitionIndex: $partitionIndex, but got: ${rateStreamOffset.partition}") - new RateStreamContinuousInputPartitionReader( - rateStreamOffset.currentValue, - rateStreamOffset.currentTimeMs, - partitionIndex, - increment, - rowsPerSecond) - } + extends InputPartition - override def createPartitionReader(): InputPartitionReader[InternalRow] = - new RateStreamContinuousInputPartitionReader( - startValue, startTimeMs, partitionIndex, increment, rowsPerSecond) +object RateStreamContinuousReaderFactory extends ContinuousPartitionReaderFactory { + override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = { + val p = partition.asInstanceOf[RateStreamContinuousInputPartition] + new RateStreamContinuousPartitionReader( + p.startValue, p.startTimeMs, p.partitionIndex, p.increment, p.rowsPerSecond) + } } -class RateStreamContinuousInputPartitionReader( +class RateStreamContinuousPartitionReader( startValue: Long, startTimeMs: Long, partitionIndex: Int, increment: Long, rowsPerSecond: Double) - extends ContinuousInputPartitionReader[InternalRow] { + extends ContinuousPartitionReader[InternalRow] { private var nextReadTime: Long = startTimeMs private val readTimeIncrement: Long = (1000 / rowsPerSecond).toLong diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala index 1dbdfd558de4..28ab2448a663 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala @@ -20,10 +20,9 @@ package org.apache.spark.sql.execution.streaming.continuous import java.io.{BufferedReader, InputStreamReader, IOException} import java.net.Socket import java.sql.Timestamp -import java.util.{Calendar, List => JList} +import java.util.Calendar import javax.annotation.concurrent.GuardedBy -import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer import org.json4s.{DefaultFormats, NoTypeHints} @@ -34,24 +33,26 @@ import 
org.apache.spark.internal.Logging import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.streaming.{ContinuousRecordEndpoint, ContinuousRecordPartitionOffset, GetRecord} +import org.apache.spark.sql.execution.streaming.{Offset => _, _} import org.apache.spark.sql.execution.streaming.sources.TextSocketReader import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsDeprecatedScanRow} -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset} -import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming._ +import org.apache.spark.sql.types.StructType import org.apache.spark.util.RpcUtils /** - * A ContinuousReader that reads text lines through a TCP socket, designed only for tutorials and - * debugging. This ContinuousReader will *not* work in production applications due to multiple - * reasons, including no support for fault recovery. + * A ContinuousReadSupport that reads text lines through a TCP socket, designed only for tutorials + * and debugging. This ContinuousReadSupport will *not* work in production applications due to + * multiple reasons, including no support for fault recovery. * * The driver maintains a socket connection to the host-port, keeps the received messages in * buckets and serves the messages to the executors via a RPC endpoint. */ -class TextSocketContinuousReader(options: DataSourceOptions) extends ContinuousReader with Logging { +class TextSocketContinuousReadSupport(options: DataSourceOptions) + extends ContinuousReadSupport with Logging { + implicit val defaultFormats: DefaultFormats = DefaultFormats private val host: String = options.get("host").get() @@ -73,7 +74,8 @@ class TextSocketContinuousReader(options: DataSourceOptions) extends ContinuousR @GuardedBy("this") private var currentOffset: Int = -1 - private var startOffset: TextSocketOffset = _ + // Exposed for tests. 
+ private[spark] var startOffset: TextSocketOffset = _ private val recordEndpoint = new ContinuousRecordEndpoint(buckets, this) @volatile private var endpointRef: RpcEndpointRef = _ @@ -94,16 +96,16 @@ class TextSocketContinuousReader(options: DataSourceOptions) extends ContinuousR TextSocketOffset(Serialization.read[List[Int]](json)) } - override def setStartOffset(offset: java.util.Optional[Offset]): Unit = { - this.startOffset = offset - .orElse(TextSocketOffset(List.fill(numPartitions)(0))) - .asInstanceOf[TextSocketOffset] - recordEndpoint.setStartOffsets(startOffset.offsets) + override def initialOffset(): Offset = { + startOffset = TextSocketOffset(List.fill(numPartitions)(0)) + startOffset } - override def getStartOffset: Offset = startOffset + override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = { + new SimpleStreamingScanConfigBuilder(fullSchema(), start) + } - override def readSchema(): StructType = { + override def fullSchema(): StructType = { if (includeTimestamp) { TextSocketReader.SCHEMA_TIMESTAMP } else { @@ -111,8 +113,10 @@ class TextSocketContinuousReader(options: DataSourceOptions) extends ContinuousR } } - override def planInputPartitions(): JList[InputPartition[InternalRow]] = { - + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val startOffset = config.asInstanceOf[SimpleStreamingScanConfig] + .start.asInstanceOf[TextSocketOffset] + recordEndpoint.setStartOffsets(startOffset.offsets) val endpointName = s"TextSocketContinuousReaderEndpoint-${java.util.UUID.randomUUID()}" endpointRef = recordEndpoint.rpcEnv.setupEndpoint(endpointName, recordEndpoint) @@ -132,10 +136,13 @@ class TextSocketContinuousReader(options: DataSourceOptions) extends ContinuousR startOffset.offsets.zipWithIndex.map { case (offset, i) => - TextSocketContinuousInputPartition( - endpointName, i, offset, includeTimestamp): InputPartition[InternalRow] - }.asJava + TextSocketContinuousInputPartition(endpointName, i, offset, includeTimestamp) + }.toArray + } + override def createContinuousReaderFactory( + config: ScanConfig): ContinuousPartitionReaderFactory = { + TextSocketReaderFactory } override def commit(end: Offset): Unit = synchronized { @@ -190,7 +197,7 @@ class TextSocketContinuousReader(options: DataSourceOptions) extends ContinuousR logWarning(s"Stream closed by $host:$port") return } - TextSocketContinuousReader.this.synchronized { + TextSocketContinuousReadSupport.this.synchronized { currentOffset += 1 val newData = (line, Timestamp.valueOf( @@ -221,25 +228,30 @@ case class TextSocketContinuousInputPartition( driverEndpointName: String, partitionId: Int, startOffset: Int, - includeTimestamp: Boolean) -extends InputPartition[InternalRow] { + includeTimestamp: Boolean) extends InputPartition + + +object TextSocketReaderFactory extends ContinuousPartitionReaderFactory { - override def createPartitionReader(): InputPartitionReader[InternalRow] = - new TextSocketContinuousInputPartitionReader(driverEndpointName, partitionId, startOffset, - includeTimestamp) + override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = { + val p = partition.asInstanceOf[TextSocketContinuousInputPartition] + new TextSocketContinuousPartitionReader( + p.driverEndpointName, p.partitionId, p.startOffset, p.includeTimestamp) + } } + /** * Continuous text socket input partition reader. * * Polls the driver endpoint for new records. 
*/ -class TextSocketContinuousInputPartitionReader( +class TextSocketContinuousPartitionReader( driverEndpointName: String, partitionId: Int, startOffset: Int, includeTimestamp: Boolean) - extends ContinuousInputPartitionReader[InternalRow] { + extends ContinuousPartitionReader[InternalRow] { private val endpoint = RpcUtils.makeDriverRef( driverEndpointName, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala index 967dbe24a370..a08411d746ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.execution.streaming.continuous import org.apache.spark.{Partition, SparkEnv, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory} +import org.apache.spark.sql.sources.v2.writer.DataWriter +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingDataWriterFactory import org.apache.spark.util.Utils /** @@ -31,7 +32,7 @@ import org.apache.spark.util.Utils * * We keep repeating prev.compute() and writing new epochs until the query is shut down. */ -class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactory[InternalRow]) +class ContinuousWriteRDD(var prev: RDD[InternalRow], writerFactory: StreamingDataWriterFactory) extends RDD[Unit](prev) { override val partitioner = prev.partitioner @@ -50,7 +51,7 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactor Utils.tryWithSafeFinallyAndFailureCallbacks(block = { try { val dataIterator = prev.compute(split, context) - dataWriter = writeTask.createDataWriter( + dataWriter = writerFactory.createWriter( context.partitionId(), context.taskAttemptId(), EpochTracker.getCurrentEpoch.get) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index 8877ebeb2673..2238ce26e7b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -23,9 +23,9 @@ import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport import org.apache.spark.util.RpcUtils private[continuous] sealed trait EpochCoordinatorMessage extends Serializable @@ -82,15 +82,15 @@ private[sql] object EpochCoordinatorRef extends Logging { * Create a reference to a new [[EpochCoordinator]]. 
*/ def create( - writer: StreamWriter, - reader: ContinuousReader, + writeSupport: StreamingWriteSupport, + readSupport: ContinuousReadSupport, query: ContinuousExecution, epochCoordinatorId: String, startEpoch: Long, session: SparkSession, env: SparkEnv): RpcEndpointRef = synchronized { val coordinator = new EpochCoordinator( - writer, reader, query, startEpoch, session, env.rpcEnv) + writeSupport, readSupport, query, startEpoch, session, env.rpcEnv) val ref = env.rpcEnv.setupEndpoint(endpointName(epochCoordinatorId), coordinator) logInfo("Registered EpochCoordinator endpoint") ref @@ -115,8 +115,8 @@ private[sql] object EpochCoordinatorRef extends Logging { * have both committed and reported an end offset for a given epoch. */ private[continuous] class EpochCoordinator( - writer: StreamWriter, - reader: ContinuousReader, + writeSupport: StreamingWriteSupport, + readSupport: ContinuousReadSupport, query: ContinuousExecution, startEpoch: Long, session: SparkSession, @@ -198,7 +198,7 @@ private[continuous] class EpochCoordinator( s"and is ready to be committed. Committing epoch $epoch.") // Sequencing is important here. We must commit to the writer before recording the commit // in the query, or we will end up dropping the commit if we restart in the middle. - writer.commit(epoch, messages.toArray) + writeSupport.commit(epoch, messages.toArray) query.commit(epoch) } @@ -220,7 +220,7 @@ private[continuous] class EpochCoordinator( partitionOffsets.collect { case ((e, _), o) if e == epoch => o } if (thisEpochOffsets.size == numReaderPartitions) { logDebug(s"Epoch $epoch has offsets reported from all partitions: $thisEpochOffsets") - query.addOffset(epoch, reader, thisEpochOffsets.toSeq) + query.addOffset(epoch, readSupport, thisEpochOffsets.toSeq) resolveCommitsAtEpoch(epoch) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala index 943c731a7052..7ad21cc304e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala @@ -19,13 +19,13 @@ package org.apache.spark.sql.execution.streaming.continuous import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport /** * The logical plan for writing data in a continuous stream. 
*/ case class WriteToContinuousDataSource( - writer: StreamWriter, query: LogicalPlan) extends LogicalPlan { + writeSupport: StreamingWriteSupport, query: LogicalPlan) extends LogicalPlan { override def children: Seq[LogicalPlan] = Seq(query) override def output: Seq[Attribute] = Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala index 927d3a84e296..c216b6138385 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala @@ -26,21 +26,21 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.streaming.StreamExecution -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport /** - * The physical plan for writing data into a continuous processing [[StreamWriter]]. + * The physical plan for writing data into a continuous processing [[StreamingWriteSupport]]. */ -case class WriteToContinuousDataSourceExec(writer: StreamWriter, query: SparkPlan) +case class WriteToContinuousDataSourceExec(writeSupport: StreamingWriteSupport, query: SparkPlan) extends SparkPlan with Logging { override def children: Seq[SparkPlan] = Seq(query) override def output: Seq[Attribute] = Nil override protected def doExecute(): RDD[InternalRow] = { - val writerFactory = writer.createWriterFactory() + val writerFactory = writeSupport.createStreamingWriterFactory() val rdd = new ContinuousWriteRDD(query.execute(), writerFactory) - logInfo(s"Start processing data source writer: $writer. " + + logInfo(s"Start processing data source write support: $writeSupport. 
" + s"The input RDD has ${rdd.partitions.length} partitions.") EpochCoordinatorRef.get( sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index f81abdcc3711..adf52aba21a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -17,12 +17,9 @@ package org.apache.spark.sql.execution.streaming -import java.{util => ju} -import java.util.Optional import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy -import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, ListBuffer} import scala.util.control.NonFatal @@ -34,8 +31,8 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset => OffsetV2} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -67,7 +64,7 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Bas addData(data.toTraversable) } - def readSchema(): StructType = encoder.schema + def fullSchema(): StructType = encoder.schema protected def logicalPlan: LogicalPlan @@ -80,7 +77,7 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Bas * available. 
*/ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) - extends MemoryStreamBase[A](sqlContext) with MicroBatchReader with Logging { + extends MemoryStreamBase[A](sqlContext) with MicroBatchReadSupport with Logging { protected val logicalPlan: LogicalPlan = StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession) @@ -122,24 +119,22 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) override def toString: String = s"MemoryStream[${Utils.truncatedString(output, ",")}]" - override def setOffsetRange(start: Optional[OffsetV2], end: Optional[OffsetV2]): Unit = { - synchronized { - startOffset = start.orElse(LongOffset(-1)).asInstanceOf[LongOffset] - endOffset = end.orElse(currentOffset).asInstanceOf[LongOffset] - } - } - override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong) - override def getStartOffset: OffsetV2 = synchronized { - if (startOffset.offset == -1) null else startOffset + override def initialOffset: OffsetV2 = LongOffset(-1) + + override def latestOffset(): OffsetV2 = { + if (currentOffset.offset == -1) null else currentOffset } - override def getEndOffset: OffsetV2 = synchronized { - if (endOffset.offset == -1) null else endOffset + override def newScanConfigBuilder(start: OffsetV2, end: OffsetV2): ScanConfigBuilder = { + new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end)) } - override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = { + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val sc = config.asInstanceOf[SimpleStreamingScanConfig] + val startOffset = sc.start.asInstanceOf[LongOffset] + val endOffset = sc.end.get.asInstanceOf[LongOffset] synchronized { // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal) val startOrdinal = startOffset.offset.toInt + 1 @@ -156,11 +151,15 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) logDebug(generateDebugString(newBlocks.flatten, startOrdinal, endOrdinal)) newBlocks.map { block => - new MemoryStreamInputPartition(block): InputPartition[InternalRow] - }.asJava + new MemoryStreamInputPartition(block) + }.toArray } } + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + MemoryStreamReaderFactory + } + private def generateDebugString( rows: Seq[UnsafeRow], startOrdinal: Int, @@ -201,10 +200,12 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) } -class MemoryStreamInputPartition(records: Array[UnsafeRow]) - extends InputPartition[InternalRow] { - override def createPartitionReader(): InputPartitionReader[InternalRow] = { - new InputPartitionReader[InternalRow] { +class MemoryStreamInputPartition(val records: Array[UnsafeRow]) extends InputPartition + +object MemoryStreamReaderFactory extends PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val records = partition.asInstanceOf[MemoryStreamInputPartition].records + new PartitionReader[InternalRow] { private var currentIndex = -1 override def next(): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala similarity index 86% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala rename to 
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala index fd45ba509091..833e62f35ede 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala @@ -19,16 +19,15 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, SparkSession} -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.writer.{DataWriterFactory, WriterCommitMessage} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} import org.apache.spark.sql.types.StructType /** Common methods used to create writes for the the console sink */ -class ConsoleWriter(schema: StructType, options: DataSourceOptions) - extends StreamWriter with Logging { +class ConsoleWriteSupport(schema: StructType, options: DataSourceOptions) + extends StreamingWriteSupport with Logging { // Number of rows to display, by default 20 rows protected val numRowsToShow = options.getInt("numRows", 20) @@ -39,7 +38,7 @@ class ConsoleWriter(schema: StructType, options: DataSourceOptions) assert(SparkSession.getActiveSession.isDefined) protected val spark = SparkSession.getActiveSession.get - def createWriterFactory(): DataWriterFactory[InternalRow] = PackedRowWriterFactory + def createStreamingWriterFactory(): StreamingDataWriterFactory = PackedRowWriterFactory override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { // We have to print a "Batch" label for the epoch for compatibility with the pre-data source V2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala index 4a32217f149b..dbcc4483e577 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -17,26 +17,22 @@ package org.apache.spark.sql.execution.streaming.sources -import java.{util => ju} -import java.util.Optional import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy -import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer import org.json4s.NoTypeHints import org.json4s.jackson.Serialization import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.{Encoder, SQLContext} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions} -import org.apache.spark.sql.sources.v2.reader.InputPartition -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset} -import org.apache.spark.sql.types.StructType +import 
org.apache.spark.sql.execution.streaming.{Offset => _, _} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions} +import org.apache.spark.sql.sources.v2.reader.{InputPartition, ScanConfig, ScanConfigBuilder} +import org.apache.spark.sql.sources.v2.reader.streaming._ import org.apache.spark.util.RpcUtils /** @@ -48,7 +44,9 @@ import org.apache.spark.util.RpcUtils * the specified offset within the list, or null if that offset doesn't yet have a record. */ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPartitions: Int = 2) - extends MemoryStreamBase[A](sqlContext) with ContinuousReader with ContinuousReadSupport { + extends MemoryStreamBase[A](sqlContext) + with ContinuousReadSupportProvider with ContinuousReadSupport { + private implicit val formats = Serialization.formats(NoTypeHints) protected val logicalPlan = @@ -59,9 +57,6 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa @GuardedBy("this") private val records = Seq.fill(numPartitions)(new ListBuffer[A]) - @GuardedBy("this") - private var startOffset: ContinuousMemoryStreamOffset = _ - private val recordEndpoint = new ContinuousRecordEndpoint(records, this) @volatile private var endpointRef: RpcEndpointRef = _ @@ -75,15 +70,8 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa ContinuousMemoryStreamOffset((0 until numPartitions).map(i => (i, records(i).size)).toMap) } - override def setStartOffset(start: Optional[Offset]): Unit = synchronized { - // Inferred initial offset is position 0 in each partition. - startOffset = start.orElse { - ContinuousMemoryStreamOffset((0 until numPartitions).map(i => (i, 0)).toMap) - }.asInstanceOf[ContinuousMemoryStreamOffset] - } - - override def getStartOffset: Offset = synchronized { - startOffset + override def initialOffset(): Offset = { + ContinuousMemoryStreamOffset((0 until numPartitions).map(i => (i, 0)).toMap) } override def deserializeOffset(json: String): ContinuousMemoryStreamOffset = { @@ -98,34 +86,40 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa ) } - override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = { + override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = { + new SimpleStreamingScanConfigBuilder(fullSchema(), start) + } + + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val startOffset = config.asInstanceOf[SimpleStreamingScanConfig] + .start.asInstanceOf[ContinuousMemoryStreamOffset] synchronized { val endpointName = s"ContinuousMemoryStreamRecordEndpoint-${java.util.UUID.randomUUID()}-$id" endpointRef = recordEndpoint.rpcEnv.setupEndpoint(endpointName, recordEndpoint) startOffset.partitionNums.map { - case (part, index) => - new ContinuousMemoryStreamInputPartition( - endpointName, part, index): InputPartition[InternalRow] - }.toList.asJava + case (part, index) => ContinuousMemoryStreamInputPartition(endpointName, part, index) + }.toArray } } + override def createContinuousReaderFactory( + config: ScanConfig): ContinuousPartitionReaderFactory = { + ContinuousMemoryStreamReaderFactory + } + override def stop(): Unit = { if (endpointRef != null) recordEndpoint.rpcEnv.stop(endpointRef) } override def commit(end: Offset): Unit = {} - // ContinuousReadSupport implementation + // ContinuousReadSupportProvider implementation // This is necessary because of how StreamTest finds the source for AddDataMemory steps. 
- def createContinuousReader( - schema: Optional[StructType], + override def createContinuousReadSupport( checkpointLocation: String, - options: DataSourceOptions): ContinuousReader = { - this - } + options: DataSourceOptions): ContinuousReadSupport = this } object ContinuousMemoryStream { @@ -141,12 +135,16 @@ object ContinuousMemoryStream { /** * An input partition for continuous memory stream. */ -class ContinuousMemoryStreamInputPartition( +case class ContinuousMemoryStreamInputPartition( driverEndpointName: String, partition: Int, - startOffset: Int) extends InputPartition[InternalRow] { - override def createPartitionReader: ContinuousMemoryStreamInputPartitionReader = - new ContinuousMemoryStreamInputPartitionReader(driverEndpointName, partition, startOffset) + startOffset: Int) extends InputPartition + +object ContinuousMemoryStreamReaderFactory extends ContinuousPartitionReaderFactory { + override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = { + val p = partition.asInstanceOf[ContinuousMemoryStreamInputPartition] + new ContinuousMemoryStreamPartitionReader(p.driverEndpointName, p.partition, p.startOffset) + } } /** @@ -154,10 +152,10 @@ class ContinuousMemoryStreamInputPartition( * * Polls the driver endpoint for new records. */ -class ContinuousMemoryStreamInputPartitionReader( +class ContinuousMemoryStreamPartitionReader( driverEndpointName: String, partition: Int, - startOffset: Int) extends ContinuousInputPartitionReader[InternalRow] { + startOffset: Int) extends ContinuousPartitionReader[InternalRow] { private val endpoint = RpcUtils.makeDriverRef( driverEndpointName, SparkEnv.get.conf, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala similarity index 82% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala index e8ce21cc1204..4218fd51ad20 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala @@ -22,9 +22,9 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.python.PythonForeachWriter -import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamingWriteSupportProvider} +import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -37,20 +37,21 @@ import org.apache.spark.sql.types.StructType * a [[ExpressionEncoder]] or a direct converter function. * @tparam T The expected type of the sink. 
*/ -case class ForeachWriterProvider[T]( +case class ForeachWriteSupportProvider[T]( writer: ForeachWriter[T], - converter: Either[ExpressionEncoder[T], InternalRow => T]) extends StreamWriteSupport { + converter: Either[ExpressionEncoder[T], InternalRow => T]) + extends StreamingWriteSupportProvider { - override def createStreamWriter( + override def createStreamingWriteSupport( queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceOptions): StreamWriter = { - new StreamWriter { + options: DataSourceOptions): StreamingWriteSupport = { + new StreamingWriteSupport { override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} - override def createWriterFactory(): DataWriterFactory[InternalRow] = { + override def createStreamingWriterFactory(): StreamingDataWriterFactory = { val rowConverter: InternalRow => T = converter match { case Left(enc) => val boundEnc = enc.resolveAndBind( @@ -68,16 +69,16 @@ case class ForeachWriterProvider[T]( } } -object ForeachWriterProvider { +object ForeachWriteSupportProvider { def apply[T]( writer: ForeachWriter[T], - encoder: ExpressionEncoder[T]): ForeachWriterProvider[_] = { + encoder: ExpressionEncoder[T]): ForeachWriteSupportProvider[_] = { writer match { case pythonWriter: PythonForeachWriter => - new ForeachWriterProvider[UnsafeRow]( + new ForeachWriteSupportProvider[UnsafeRow]( pythonWriter, Right((x: InternalRow) => x.asInstanceOf[UnsafeRow])) case _ => - new ForeachWriterProvider[T](writer, Left(encoder)) + new ForeachWriteSupportProvider[T](writer, Left(encoder)) } } } @@ -85,8 +86,8 @@ object ForeachWriterProvider { case class ForeachWriterFactory[T]( writer: ForeachWriter[T], rowConverter: InternalRow => T) - extends DataWriterFactory[InternalRow] { - override def createDataWriter( + extends StreamingDataWriterFactory { + override def createWriter( partitionId: Int, taskId: Long, epochId: Long): ForeachDataWriter[T] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWritSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWritSupport.scala new file mode 100644 index 000000000000..9f88416871f8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWritSupport.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.sources + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.sources.v2.writer.{BatchWriteSupport, DataWriter, DataWriterFactory, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} + +/** + * A [[BatchWriteSupport]] used to hook V2 stream writers into a microbatch plan. It implements + * the non-streaming interface, forwarding the epoch ID determined at construction to a wrapped + * streaming write support. + */ +class MicroBatchWritSupport(epochId: Long, val writeSupport: StreamingWriteSupport) + extends BatchWriteSupport { + + override def commit(messages: Array[WriterCommitMessage]): Unit = { + writeSupport.commit(epochId, messages) + } + + override def abort(messages: Array[WriterCommitMessage]): Unit = { + writeSupport.abort(epochId, messages) + } + + override def createBatchWriterFactory(): DataWriterFactory = { + new MicroBatchWriterFactory(epochId, writeSupport.createStreamingWriterFactory()) + } +} + +class MicroBatchWriterFactory(epochId: Long, streamingWriterFactory: StreamingDataWriterFactory) + extends DataWriterFactory { + + override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = { + streamingWriterFactory.createWriter(partitionId, taskId, epochId) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala index f26e11d842b2..ac3c71cc222b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala @@ -21,17 +21,18 @@ import scala.collection.mutable import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, DataWriterFactory, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.{BatchWriteSupport, DataWriter, DataWriterFactory, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingDataWriterFactory /** * A simple [[DataWriterFactory]] whose tasks just pack rows into the commit message for delivery - * to a [[DataSourceWriter]] on the driver. + * to a [[BatchWriteSupport]] on the driver. * * Note that, because it sends all rows to the driver, this factory will generally be unsuitable * for production-quality sinks. It's intended for use in tests. 
*/ -case object PackedRowWriterFactory extends DataWriterFactory[InternalRow] { - override def createDataWriter( +case object PackedRowWriterFactory extends StreamingDataWriterFactory { + override def createWriter( partitionId: Int, taskId: Long, epochId: Long): DataWriter[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchReadSupport.scala similarity index 50% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchReadSupport.scala index 2d43a7bb7787..90680ea38fbd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchReadSupport.scala @@ -17,21 +17,15 @@ package org.apache.spark.sql.execution.streaming.sources -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriterFactory, WriterCommitMessage} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset} -/** - * A [[DataSourceWriter]] used to hook V2 stream writers into a microbatch plan. It implements - * the non-streaming interface, forwarding the batch ID determined at construction to a wrapped - * streaming writer. - */ -class MicroBatchWriter(batchId: Long, val writer: StreamWriter) extends DataSourceWriter { - override def commit(messages: Array[WriterCommitMessage]): Unit = { - writer.commit(batchId, messages) - } +// A special `MicroBatchReadSupport` that can get latestOffset with a start offset. 
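(Editorial sketch, not part of the patch.) The trait defined in the next hunk lets rate-controlled sources compute the next end offset relative to where the previous batch stopped. A minimal caller-side illustration, with illustrative names:

import org.apache.spark.sql.execution.streaming.sources.RateControlMicroBatchReadSupport
import org.apache.spark.sql.sources.v2.reader.streaming.Offset

object RateControlSketch {
  // The engine supplies the previous batch's end offset; the parameterless latestOffset()
  // is not meant to be called on such sources (see the trait body that follows).
  def nextEndOffset(readSupport: RateControlMicroBatchReadSupport, lastEnd: Offset): Offset =
    readSupport.latestOffset(lastEnd)
}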
+trait RateControlMicroBatchReadSupport extends MicroBatchReadSupport { - override def abort(messages: Array[WriterCommitMessage]): Unit = writer.abort(batchId, messages) + override def latestOffset(): Offset = { + throw new IllegalAccessException( + "latestOffset should not be called for RateControlMicroBatchReadSupport") + } - override def createWriterFactory(): DataWriterFactory[InternalRow] = writer.createWriterFactory() + def latestOffset(start: Offset): Offset } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReadSupport.scala similarity index 78% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReadSupport.scala index 9e0d95493216..f5364047adff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReadSupport.scala @@ -19,27 +19,24 @@ package org.apache.spark.sql.execution.streaming.sources import java.io._ import java.nio.charset.StandardCharsets -import java.util.Optional import java.util.concurrent.TimeUnit -import scala.collection.JavaConverters._ - import org.apache.commons.io.IOUtils import org.apache.spark.internal.Logging import org.apache.spark.network.util.JavaUtils -import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset} import org.apache.spark.sql.types.StructType import org.apache.spark.util.{ManualClock, SystemClock} -class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: String) - extends MicroBatchReader with Logging { +class RateStreamMicroBatchReadSupport(options: DataSourceOptions, checkpointLocation: String) + extends MicroBatchReadSupport with Logging { import RateStreamProvider._ private[sources] val clock = { @@ -106,38 +103,30 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: @volatile private var lastTimeMs: Long = creationTimeMs - private var start: LongOffset = _ - private var end: LongOffset = _ - - override def readSchema(): StructType = SCHEMA + override def initialOffset(): Offset = LongOffset(0L) - override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = { - this.start = start.orElse(LongOffset(0L)).asInstanceOf[LongOffset] - this.end = end.orElse { - val now = clock.getTimeMillis() - if (lastTimeMs < now) { - lastTimeMs = now - } - LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - creationTimeMs)) - }.asInstanceOf[LongOffset] - } - - override def getStartOffset(): Offset = { - if (start == null) throw new IllegalStateException("start offset not set") - start - } - override def getEndOffset(): Offset = { - if (end == null) throw new IllegalStateException("end offset not set") - end + override def 
latestOffset(): Offset = { + val now = clock.getTimeMillis() + if (lastTimeMs < now) { + lastTimeMs = now + } + LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - creationTimeMs)) } override def deserializeOffset(json: String): Offset = { LongOffset(json.toLong) } - override def planInputPartitions(): java.util.List[InputPartition[InternalRow]] = { - val startSeconds = LongOffset.convert(start).map(_.offset).getOrElse(0L) - val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) + override def fullSchema(): StructType = SCHEMA + + override def newScanConfigBuilder(start: Offset, end: Offset): ScanConfigBuilder = { + new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end)) + } + + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val sc = config.asInstanceOf[SimpleStreamingScanConfig] + val startSeconds = sc.start.asInstanceOf[LongOffset].offset + val endSeconds = sc.end.get.asInstanceOf[LongOffset].offset assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") if (endSeconds > maxSeconds) { throw new ArithmeticException("Integer overflow. Max offset with " + @@ -153,7 +142,7 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") if (rangeStart == rangeEnd) { - return List.empty.asJava + return Array.empty } val localStartTimeMs = creationTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) @@ -170,8 +159,11 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: (0 until numPartitions).map { p => new RateStreamMicroBatchInputPartition( p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) - : InputPartition[InternalRow] - }.toList.asJava + }.toArray + } + + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + RateStreamMicroBatchReaderFactory } override def commit(end: Offset): Unit = {} @@ -183,26 +175,29 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}" } -class RateStreamMicroBatchInputPartition( +case class RateStreamMicroBatchInputPartition( partitionId: Int, numPartitions: Int, rangeStart: Long, rangeEnd: Long, localStartTimeMs: Long, - relativeMsPerValue: Double) extends InputPartition[InternalRow] { + relativeMsPerValue: Double) extends InputPartition - override def createPartitionReader(): InputPartitionReader[InternalRow] = - new RateStreamMicroBatchInputPartitionReader(partitionId, numPartitions, rangeStart, rangeEnd, - localStartTimeMs, relativeMsPerValue) +object RateStreamMicroBatchReaderFactory extends PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val p = partition.asInstanceOf[RateStreamMicroBatchInputPartition] + new RateStreamMicroBatchPartitionReader(p.partitionId, p.numPartitions, p.rangeStart, + p.rangeEnd, p.localStartTimeMs, p.relativeMsPerValue) + } } -class RateStreamMicroBatchInputPartitionReader( +class RateStreamMicroBatchPartitionReader( partitionId: Int, numPartitions: Int, rangeStart: Long, rangeEnd: Long, localStartTimeMs: Long, - relativeMsPerValue: Double) extends InputPartitionReader[InternalRow] { + relativeMsPerValue: Double) extends PartitionReader[InternalRow] { private var count: Long = 0 override def next(): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala index 6bdd492f0cb3..6942dfbfe0ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala @@ -17,14 +17,11 @@ package org.apache.spark.sql.execution.streaming.sources -import java.util.Optional - import org.apache.spark.network.util.JavaUtils -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader +import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReadSupport import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, MicroBatchReadSupport} import org.apache.spark.sql.types._ /** @@ -42,13 +39,12 @@ import org.apache.spark.sql.types._ * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. */ class RateStreamProvider extends DataSourceV2 - with MicroBatchReadSupport with ContinuousReadSupport with DataSourceRegister { + with MicroBatchReadSupportProvider with ContinuousReadSupportProvider with DataSourceRegister { import RateStreamProvider._ - override def createMicroBatchReader( - schema: Optional[StructType], + override def createMicroBatchReadSupport( checkpointLocation: String, - options: DataSourceOptions): MicroBatchReader = { + options: DataSourceOptions): MicroBatchReadSupport = { if (options.get(ROWS_PER_SECOND).isPresent) { val rowsPerSecond = options.get(ROWS_PER_SECOND).get().toLong if (rowsPerSecond <= 0) { @@ -74,17 +70,14 @@ class RateStreamProvider extends DataSourceV2 } } - if (schema.isPresent) { - throw new AnalysisException("The rate source does not support a user-specified schema.") - } - - new RateStreamMicroBatchReader(options, checkpointLocation) + new RateStreamMicroBatchReadSupport(options, checkpointLocation) } - override def createContinuousReader( - schema: Optional[StructType], + override def createContinuousReadSupport( checkpointLocation: String, - options: DataSourceOptions): ContinuousReader = new RateStreamContinuousReader(options) + options: DataSourceOptions): ContinuousReadSupport = { + new RateStreamContinuousReadSupport(options) + } override def shortName(): String = "rate" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 2a5d21f33054..2509450f0da9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -35,9 +35,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink} -import org.apache.spark.sql.sources.v2.{CustomMetrics, DataSourceOptions, DataSourceV2, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.{CustomMetrics, DataSourceOptions, DataSourceV2, StreamingWriteSupportProvider} 
import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamWriter, SupportsCustomWriterMetrics} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport, SupportsCustomWriterMetrics} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -45,13 +45,15 @@ import org.apache.spark.sql.types.StructType * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. */ -class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkBase with Logging { - override def createStreamWriter( +class MemorySinkV2 extends DataSourceV2 with StreamingWriteSupportProvider + with MemorySinkBase with Logging { + + override def createStreamingWriteSupport( queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceOptions): StreamWriter = { - new MemoryStreamWriter(this, mode, schema) + options: DataSourceOptions): StreamingWriteSupport = { + new MemoryStreamingWriteSupport(this, mode, schema) } private case class AddedData(batchId: Long, data: Array[Row]) @@ -132,35 +134,15 @@ class MemoryV2CustomMetrics(sink: MemorySinkV2) extends CustomMetrics { override def json(): String = Serialization.write(Map("numRows" -> sink.numRows)) } -class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode, schema: StructType) - extends DataSourceWriter with SupportsCustomWriterMetrics with Logging { - - private val memoryV2CustomMetrics = new MemoryV2CustomMetrics(sink) - - override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode, schema) - - def commit(messages: Array[WriterCommitMessage]): Unit = { - val newRows = messages.flatMap { - case message: MemoryWriterCommitMessage => message.data - } - sink.write(batchId, outputMode, newRows) - } - - override def abort(messages: Array[WriterCommitMessage]): Unit = { - // Don't accept any of the new input. - } - - override def getCustomMetrics: CustomMetrics = { - memoryV2CustomMetrics - } -} - -class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode, schema: StructType) - extends StreamWriter with SupportsCustomWriterMetrics { +class MemoryStreamingWriteSupport( + val sink: MemorySinkV2, outputMode: OutputMode, schema: StructType) + extends StreamingWriteSupport with SupportsCustomWriterMetrics { private val customMemoryV2Metrics = new MemoryV2CustomMetrics(sink) - override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode, schema) + override def createStreamingWriterFactory: MemoryWriterFactory = { + MemoryWriterFactory(outputMode, schema) + } override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { val newRows = messages.flatMap { @@ -173,19 +155,23 @@ class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode, schema: // Don't accept any of the new input. 
} - override def getCustomMetrics: CustomMetrics = { - customMemoryV2Metrics - } + override def getCustomMetrics: CustomMetrics = customMemoryV2Metrics } case class MemoryWriterFactory(outputMode: OutputMode, schema: StructType) - extends DataWriterFactory[InternalRow] { + extends DataWriterFactory with StreamingDataWriterFactory { - override def createDataWriter( + override def createWriter( + partitionId: Int, + taskId: Long): DataWriter[InternalRow] = { + new MemoryDataWriter(partitionId, outputMode, schema) + } + + override def createWriter( partitionId: Int, taskId: Long, epochId: Long): DataWriter[InternalRow] = { - new MemoryDataWriter(partitionId, outputMode, schema) + createWriter(partitionId, taskId) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala index 874c479db95d..b2a573eae504 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala @@ -20,11 +20,10 @@ package org.apache.spark.sql.execution.streaming.sources import java.io.{BufferedReader, InputStreamReader, IOException} import java.net.Socket import java.text.SimpleDateFormat -import java.util.{Calendar, List => JList, Locale, Optional} +import java.util.{Calendar, Locale} import java.util.concurrent.atomic.AtomicBoolean import javax.annotation.concurrent.GuardedBy -import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer import scala.util.{Failure, Success, Try} @@ -32,16 +31,15 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming.LongOffset -import org.apache.spark.sql.execution.streaming.continuous.TextSocketContinuousReader +import org.apache.spark.sql.execution.streaming.{LongOffset, SimpleStreamingScanConfig, SimpleStreamingScanConfigBuilder} +import org.apache.spark.sql.execution.streaming.continuous.TextSocketContinuousReadSupport import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, DataSourceV2, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader, Offset} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, DataSourceV2, MicroBatchReadSupportProvider} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, MicroBatchReadSupport, Offset} import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} import org.apache.spark.unsafe.types.UTF8String -// Shared object for micro-batch and continuous reader object TextSocketReader { val SCHEMA_REGULAR = StructType(StructField("value", StringType) :: Nil) val SCHEMA_TIMESTAMP = StructType(StructField("value", StringType) :: @@ -50,14 +48,12 @@ object TextSocketReader { } /** - * A MicroBatchReader that reads text lines through a TCP socket, designed only for tutorials and - * debugging. This MicroBatchReader will *not* work in production applications due to multiple - * reasons, including no support for fault recovery. 
+ * A MicroBatchReadSupport that reads text lines through a TCP socket, designed only for tutorials + * and debugging. This MicroBatchReadSupport will *not* work in production applications due to + * multiple reasons, including no support for fault recovery. */ -class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchReader with Logging { - - private var startOffset: Offset = _ - private var endOffset: Offset = _ +class TextSocketMicroBatchReadSupport(options: DataSourceOptions) + extends MicroBatchReadSupport with Logging { private val host: String = options.get("host").get() private val port: Int = options.get("port").get().toInt @@ -103,7 +99,7 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR logWarning(s"Stream closed by $host:$port") return } - TextSocketMicroBatchReader.this.synchronized { + TextSocketMicroBatchReadSupport.this.synchronized { val newData = ( UTF8String.fromString(line), DateTimeUtils.fromMillis(Calendar.getInstance().getTimeInMillis) @@ -120,24 +116,15 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR readThread.start() } - override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = synchronized { - startOffset = start.orElse(LongOffset(-1L)) - endOffset = end.orElse(currentOffset) - } + override def initialOffset(): Offset = LongOffset(-1L) - override def getStartOffset(): Offset = { - Option(startOffset).getOrElse(throw new IllegalStateException("start offset not set")) - } - - override def getEndOffset(): Offset = { - Option(endOffset).getOrElse(throw new IllegalStateException("end offset not set")) - } + override def latestOffset(): Offset = currentOffset override def deserializeOffset(json: String): Offset = { LongOffset(json.toLong) } - override def readSchema(): StructType = { + override def fullSchema(): StructType = { if (options.getBoolean("includeTimestamp", false)) { TextSocketReader.SCHEMA_TIMESTAMP } else { @@ -145,12 +132,14 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR } } - override def planInputPartitions(): JList[InputPartition[InternalRow]] = { - assert(startOffset != null && endOffset != null, - "start offset and end offset should already be set before create read tasks.") + override def newScanConfigBuilder(start: Offset, end: Offset): ScanConfigBuilder = { + new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end)) + } - val startOrdinal = LongOffset.convert(startOffset).get.offset.toInt + 1 - val endOrdinal = LongOffset.convert(endOffset).get.offset.toInt + 1 + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val sc = config.asInstanceOf[SimpleStreamingScanConfig] + val startOrdinal = sc.start.asInstanceOf[LongOffset].offset.toInt + 1 + val endOrdinal = sc.end.get.asInstanceOf[LongOffset].offset.toInt + 1 // Internal buffer only holds the batches after lastOffsetCommitted val rawList = synchronized { @@ -172,26 +161,29 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR slices(idx % numPartitions).append(r) } - (0 until numPartitions).map { i => - val slice = slices(i) - new InputPartition[InternalRow] { - override def createPartitionReader(): InputPartitionReader[InternalRow] = - new InputPartitionReader[InternalRow] { - private var currentIdx = -1 + slices.map(TextSocketInputPartition) + } - override def next(): Boolean = { - currentIdx += 1 - currentIdx < slice.size - } + override def createReaderFactory(config: 
ScanConfig): PartitionReaderFactory = { + new PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val slice = partition.asInstanceOf[TextSocketInputPartition].slice + new PartitionReader[InternalRow] { + private var currentIdx = -1 - override def get(): InternalRow = { - InternalRow(slice(currentIdx)._1, slice(currentIdx)._2) - } + override def next(): Boolean = { + currentIdx += 1 + currentIdx < slice.size + } - override def close(): Unit = {} + override def get(): InternalRow = { + InternalRow(slice(currentIdx)._1, slice(currentIdx)._2) } + + override def close(): Unit = {} + } } - }.toList.asJava + } } override def commit(end: Offset): Unit = synchronized { @@ -227,8 +219,11 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR override def toString: String = s"TextSocketV2[host: $host, port: $port]" } +case class TextSocketInputPartition(slice: ListBuffer[(UTF8String, Long)]) extends InputPartition + class TextSocketSourceProvider extends DataSourceV2 - with MicroBatchReadSupport with ContinuousReadSupport with DataSourceRegister with Logging { + with MicroBatchReadSupportProvider with ContinuousReadSupportProvider + with DataSourceRegister with Logging { private def checkParameters(params: DataSourceOptions): Unit = { logWarning("The socket source should not be used for production applications! " + @@ -248,27 +243,18 @@ class TextSocketSourceProvider extends DataSourceV2 } } - override def createMicroBatchReader( - schema: Optional[StructType], + override def createMicroBatchReadSupport( checkpointLocation: String, - options: DataSourceOptions): MicroBatchReader = { + options: DataSourceOptions): MicroBatchReadSupport = { checkParameters(options) - if (schema.isPresent) { - throw new AnalysisException("The socket source does not support a user-specified schema.") - } - - new TextSocketMicroBatchReader(options) + new TextSocketMicroBatchReadSupport(options) } - override def createContinuousReader( - schema: Optional[StructType], + override def createContinuousReadSupport( checkpointLocation: String, - options: DataSourceOptions): ContinuousReader = { + options: DataSourceOptions): ContinuousReadSupport = { checkParameters(options) - if (schema.isPresent) { - throw new AnalysisException("The socket source does not support a user-specified schema.") - } - new TextSocketContinuousReader(options) + new TextSocketContinuousReadSupport(options) } /** String that represents the format that this data source provider uses. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index ef8dc3a325a3..39e9e1ad426b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.streaming -import java.util.{Locale, Optional} +import java.util.Locale import scala.collection.JavaConverters._ @@ -28,8 +28,8 @@ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2} import org.apache.spark.sql.sources.StreamSourceProvider -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader +import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, MicroBatchReadSupportProvider} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, MicroBatchReadSupport} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -172,19 +172,21 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo case _ => None } ds match { - case s: MicroBatchReadSupport => - var tempReader: MicroBatchReader = null + case s: MicroBatchReadSupportProvider => + var tempReadSupport: MicroBatchReadSupport = null val schema = try { - tempReader = s.createMicroBatchReader( - Optional.ofNullable(userSpecifiedSchema.orNull), - Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath, - options) - tempReader.readSchema() + val tmpCheckpointPath = Utils.createTempDir(namePrefix = s"tempCP").getCanonicalPath + tempReadSupport = if (userSpecifiedSchema.isDefined) { + s.createMicroBatchReadSupport(userSpecifiedSchema.get, tmpCheckpointPath, options) + } else { + s.createMicroBatchReadSupport(tmpCheckpointPath, options) + } + tempReadSupport.fullSchema() } finally { // Stop tempReader to avoid side-effect thing - if (tempReader != null) { - tempReader.stop() - tempReader = null + if (tempReadSupport != null) { + tempReadSupport.stop() + tempReadSupport = null } } Dataset.ofRows( @@ -192,16 +194,28 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo StreamingRelationV2( s, source, extraOptions.toMap, schema.toAttributes, v1Relation)(sparkSession)) - case s: ContinuousReadSupport => - val tempReader = s.createContinuousReader( - Optional.ofNullable(userSpecifiedSchema.orNull), - Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath, - options) + case s: ContinuousReadSupportProvider => + var tempReadSupport: ContinuousReadSupport = null + val schema = try { + val tmpCheckpointPath = Utils.createTempDir(namePrefix = s"tempCP").getCanonicalPath + tempReadSupport = if (userSpecifiedSchema.isDefined) { + s.createContinuousReadSupport(userSpecifiedSchema.get, tmpCheckpointPath, options) + } else { + s.createContinuousReadSupport(tmpCheckpointPath, options) + } + tempReadSupport.fullSchema() + } finally { + // Stop tempReader to avoid side-effect thing + if (tempReadSupport != null) { + tempReadSupport.stop() + tempReadSupport = null + } + } Dataset.ofRows( sparkSession, StreamingRelationV2( s, source, extraOptions.toMap, - tempReader.readSchema().toAttributes, 
v1Relation)(sparkSession)) + schema.toAttributes, v1Relation)(sparkSession)) case _ => // Code path for data source v1. Dataset.ofRows(sparkSession, StreamingRelation(v1DataSource)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 3b9a56ffdde4..7866e4f70f14 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -21,7 +21,7 @@ import java.util.Locale import scala.collection.JavaConverters._ -import org.apache.spark.annotation.{InterfaceStability, Since} +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.api.java.function.VoidFunction2 import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources._ -import org.apache.spark.sql.sources.v2.StreamWriteSupport +import org.apache.spark.sql.sources.v2.StreamingWriteSupportProvider /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, @@ -270,7 +270,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { query } else if (source == "foreach") { assertNotPartitioned("foreach") - val sink = ForeachWriterProvider[T](foreachWriter, ds.exprEnc) + val sink = ForeachWriteSupportProvider[T](foreachWriter, ds.exprEnc) df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), @@ -299,7 +299,8 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) val disabledSources = df.sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",") val sink = ds.newInstance() match { - case w: StreamWriteSupport if !disabledSources.contains(w.getClass.getCanonicalName) => w + case w: StreamingWriteSupportProvider + if !disabledSources.contains(w.getClass.getCanonicalName) => w case _ => val ds = DataSource( df.sparkSession, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 25bb05212d66..cd52d991d55c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.STREAMING_QUERY_LISTENERS -import org.apache.spark.sql.sources.v2.StreamWriteSupport +import org.apache.spark.sql.sources.v2.StreamingWriteSupportProvider import org.apache.spark.util.{Clock, SystemClock, Utils} /** @@ -256,7 +256,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo } (sink, trigger) match { - case (v2Sink: StreamWriteSupport, trigger: ContinuousTrigger) => + case (v2Sink: StreamingWriteSupportProvider, trigger: ContinuousTrigger) => if 
(sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled) { UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode) } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index e4cead9df429..5602310219a7 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -24,29 +24,71 @@ import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.sources.Filter; import org.apache.spark.sql.sources.GreaterThan; +import org.apache.spark.sql.sources.v2.BatchReadSupportProvider; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.StructType; -public class JavaAdvancedDataSourceV2 implements DataSourceV2, ReadSupport { +public class JavaAdvancedDataSourceV2 implements DataSourceV2, BatchReadSupportProvider { - public class Reader implements DataSourceReader, SupportsPushDownRequiredColumns, - SupportsPushDownFilters { + public class ReadSupport extends JavaSimpleReadSupport { + @Override + public ScanConfigBuilder newScanConfigBuilder() { + return new AdvancedScanConfigBuilder(); + } + + @Override + public InputPartition[] planInputPartitions(ScanConfig config) { + Filter[] filters = ((AdvancedScanConfigBuilder) config).filters; + List res = new ArrayList<>(); + + Integer lowerBound = null; + for (Filter filter : filters) { + if (filter instanceof GreaterThan) { + GreaterThan f = (GreaterThan) filter; + if ("i".equals(f.attribute()) && f.value() instanceof Integer) { + lowerBound = (Integer) f.value(); + break; + } + } + } + + if (lowerBound == null) { + res.add(new JavaRangeInputPartition(0, 5)); + res.add(new JavaRangeInputPartition(5, 10)); + } else if (lowerBound < 4) { + res.add(new JavaRangeInputPartition(lowerBound + 1, 5)); + res.add(new JavaRangeInputPartition(5, 10)); + } else if (lowerBound < 9) { + res.add(new JavaRangeInputPartition(lowerBound + 1, 10)); + } + + return res.stream().toArray(InputPartition[]::new); + } + + @Override + public PartitionReaderFactory createReaderFactory(ScanConfig config) { + StructType requiredSchema = ((AdvancedScanConfigBuilder) config).requiredSchema; + return new AdvancedReaderFactory(requiredSchema); + } + } + + public static class AdvancedScanConfigBuilder implements ScanConfigBuilder, ScanConfig, + SupportsPushDownFilters, SupportsPushDownRequiredColumns { // Exposed for testing. 
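(Editorial sketch, not part of the patch; the fields that follow are the state this builder accumulates.) The push-down handshake a planner performs against a ScanConfigBuilder like the one below, written in Scala with illustrative names; pushFilters returns the filters that still need to be evaluated after the scan:

import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.sources.v2.reader.{ScanConfig, ScanConfigBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
import org.apache.spark.sql.types.StructType

object PushDownSketch {
  def buildScan(
      builder: ScanConfigBuilder,
      filters: Array[Filter],
      prunedSchema: StructType): (ScanConfig, Array[Filter]) = {
    // Push filters first; whatever comes back is the post-scan (residual) filter set.
    val postScanFilters = builder match {
      case b: SupportsPushDownFilters => b.pushFilters(filters)
      case _ => filters
    }
    // Then prune columns down to what the query actually needs.
    builder match {
      case b: SupportsPushDownRequiredColumns => b.pruneColumns(prunedSchema)
      case _ =>
    }
    (builder.build(), postScanFilters)
  }
}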
public StructType requiredSchema = new StructType().add("i", "int").add("j", "int"); public Filter[] filters = new Filter[0]; @Override - public StructType readSchema() { - return requiredSchema; + public void pruneColumns(StructType requiredSchema) { + this.requiredSchema = requiredSchema; } @Override - public void pruneColumns(StructType requiredSchema) { - this.requiredSchema = requiredSchema; + public StructType readSchema() { + return requiredSchema; } @Override @@ -79,79 +121,54 @@ public Filter[] pushedFilters() { } @Override - public List> planInputPartitions() { - List> res = new ArrayList<>(); - - Integer lowerBound = null; - for (Filter filter : filters) { - if (filter instanceof GreaterThan) { - GreaterThan f = (GreaterThan) filter; - if ("i".equals(f.attribute()) && f.value() instanceof Integer) { - lowerBound = (Integer) f.value(); - break; - } - } - } - - if (lowerBound == null) { - res.add(new JavaAdvancedInputPartition(0, 5, requiredSchema)); - res.add(new JavaAdvancedInputPartition(5, 10, requiredSchema)); - } else if (lowerBound < 4) { - res.add(new JavaAdvancedInputPartition(lowerBound + 1, 5, requiredSchema)); - res.add(new JavaAdvancedInputPartition(5, 10, requiredSchema)); - } else if (lowerBound < 9) { - res.add(new JavaAdvancedInputPartition(lowerBound + 1, 10, requiredSchema)); - } - - return res; + public ScanConfig build() { + return this; } } - static class JavaAdvancedInputPartition implements InputPartition, - InputPartitionReader { - private int start; - private int end; - private StructType requiredSchema; + static class AdvancedReaderFactory implements PartitionReaderFactory { + StructType requiredSchema; - JavaAdvancedInputPartition(int start, int end, StructType requiredSchema) { - this.start = start; - this.end = end; + AdvancedReaderFactory(StructType requiredSchema) { this.requiredSchema = requiredSchema; } @Override - public InputPartitionReader createPartitionReader() { - return new JavaAdvancedInputPartition(start - 1, end, requiredSchema); - } - - @Override - public boolean next() { - start += 1; - return start < end; - } + public PartitionReader createReader(InputPartition partition) { + JavaRangeInputPartition p = (JavaRangeInputPartition) partition; + return new PartitionReader() { + private int current = p.start - 1; + + @Override + public boolean next() throws IOException { + current += 1; + return current < p.end; + } - @Override - public InternalRow get() { - Object[] values = new Object[requiredSchema.size()]; - for (int i = 0; i < values.length; i++) { - if ("i".equals(requiredSchema.apply(i).name())) { - values[i] = start; - } else if ("j".equals(requiredSchema.apply(i).name())) { - values[i] = -start; + @Override + public InternalRow get() { + Object[] values = new Object[requiredSchema.size()]; + for (int i = 0; i < values.length; i++) { + if ("i".equals(requiredSchema.apply(i).name())) { + values[i] = current; + } else if ("j".equals(requiredSchema.apply(i).name())) { + values[i] = -current; + } + } + return new GenericInternalRow(values); } - } - return new GenericInternalRow(values); - } - @Override - public void close() throws IOException { + @Override + public void close() throws IOException { + } + }; } } @Override - public DataSourceReader createReader(DataSourceOptions options) { - return new Reader(); + public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { + return new ReadSupport(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java 
b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java deleted file mode 100644 index 97d6176d0255..000000000000 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package test.org.apache.spark.sql.sources.v2; - -import java.io.IOException; -import java.util.List; - -import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; -import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.ReadSupport; -import org.apache.spark.sql.sources.v2.reader.*; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.StructType; -import org.apache.spark.sql.vectorized.ColumnVector; -import org.apache.spark.sql.vectorized.ColumnarBatch; - - -public class JavaBatchDataSourceV2 implements DataSourceV2, ReadSupport { - - class Reader implements DataSourceReader, SupportsScanColumnarBatch { - private final StructType schema = new StructType().add("i", "int").add("j", "int"); - - @Override - public StructType readSchema() { - return schema; - } - - @Override - public List> planBatchInputPartitions() { - return java.util.Arrays.asList( - new JavaBatchInputPartition(0, 50), new JavaBatchInputPartition(50, 90)); - } - } - - static class JavaBatchInputPartition - implements InputPartition, InputPartitionReader { - private int start; - private int end; - - private static final int BATCH_SIZE = 20; - - private OnHeapColumnVector i; - private OnHeapColumnVector j; - private ColumnarBatch batch; - - JavaBatchInputPartition(int start, int end) { - this.start = start; - this.end = end; - } - - @Override - public InputPartitionReader createPartitionReader() { - this.i = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); - this.j = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); - ColumnVector[] vectors = new ColumnVector[2]; - vectors[0] = i; - vectors[1] = j; - this.batch = new ColumnarBatch(vectors); - return this; - } - - @Override - public boolean next() { - i.reset(); - j.reset(); - int count = 0; - while (start < end && count < BATCH_SIZE) { - i.putInt(count, start); - j.putInt(count, -start); - start += 1; - count += 1; - } - - if (count == 0) { - return false; - } else { - batch.setNumRows(count); - return true; - } - } - - @Override - public ColumnarBatch get() { - return batch; - } - - @Override - public void close() throws IOException { - batch.close(); - } - } - - - @Override - public DataSourceReader createReader(DataSourceOptions options) { - return new Reader(); - } -} diff --git 
a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java new file mode 100644 index 000000000000..28a933039831 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.sources.v2; + +import java.io.IOException; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; +import org.apache.spark.sql.sources.v2.BatchReadSupportProvider; +import org.apache.spark.sql.sources.v2.DataSourceOptions; +import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.reader.*; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarBatch; + + +public class JavaColumnarDataSourceV2 implements DataSourceV2, BatchReadSupportProvider { + + class ReadSupport extends JavaSimpleReadSupport { + + @Override + public InputPartition[] planInputPartitions(ScanConfig config) { + InputPartition[] partitions = new InputPartition[2]; + partitions[0] = new JavaRangeInputPartition(0, 50); + partitions[1] = new JavaRangeInputPartition(50, 90); + return partitions; + } + + @Override + public PartitionReaderFactory createReaderFactory(ScanConfig config) { + return new ColumnarReaderFactory(); + } + } + + static class ColumnarReaderFactory implements PartitionReaderFactory { + private static final int BATCH_SIZE = 20; + + @Override + public boolean supportColumnarReads(InputPartition partition) { + return true; + } + + @Override + public PartitionReader createReader(InputPartition partition) { + throw new UnsupportedOperationException(""); + } + + @Override + public PartitionReader createColumnarReader(InputPartition partition) { + JavaRangeInputPartition p = (JavaRangeInputPartition) partition; + OnHeapColumnVector i = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); + OnHeapColumnVector j = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); + ColumnVector[] vectors = new ColumnVector[2]; + vectors[0] = i; + vectors[1] = j; + ColumnarBatch batch = new ColumnarBatch(vectors); + + return new PartitionReader() { + private int current = p.start; + + @Override + public boolean next() throws IOException { + i.reset(); + j.reset(); + int count = 0; + while (current < p.end && count < BATCH_SIZE) { + i.putInt(count, current); + j.putInt(count, -current); + current += 1; + count += 1; + } + + if (count == 0) { + return false; + } else { + batch.setNumRows(count); + return true; + } + } + + @Override + 
public ColumnarBatch get() { + return batch; + } + + @Override + public void close() throws IOException { + batch.close(); + } + }; + } + } + + @Override + public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { + return new ReadSupport(); + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java index 2d21324f5ece..18a11dde8219 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java @@ -19,38 +19,34 @@ import java.io.IOException; import java.util.Arrays; -import java.util.List; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; -import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.ReadSupport; +import org.apache.spark.sql.sources.v2.*; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.sources.v2.reader.partitioning.ClusteredDistribution; import org.apache.spark.sql.sources.v2.reader.partitioning.Distribution; import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning; -import org.apache.spark.sql.types.StructType; -public class JavaPartitionAwareDataSource implements DataSourceV2, ReadSupport { +public class JavaPartitionAwareDataSource implements DataSourceV2, BatchReadSupportProvider { - class Reader implements DataSourceReader, SupportsReportPartitioning { - private final StructType schema = new StructType().add("a", "int").add("b", "int"); + class ReadSupport extends JavaSimpleReadSupport implements SupportsReportPartitioning { @Override - public StructType readSchema() { - return schema; + public InputPartition[] planInputPartitions(ScanConfig config) { + InputPartition[] partitions = new InputPartition[2]; + partitions[0] = new SpecificInputPartition(new int[]{1, 1, 3}, new int[]{4, 4, 6}); + partitions[1] = new SpecificInputPartition(new int[]{2, 4, 4}, new int[]{6, 2, 2}); + return partitions; } @Override - public List> planInputPartitions() { - return java.util.Arrays.asList( - new SpecificInputPartition(new int[]{1, 1, 3}, new int[]{4, 4, 6}), - new SpecificInputPartition(new int[]{2, 4, 4}, new int[]{6, 2, 2})); + public PartitionReaderFactory createReaderFactory(ScanConfig config) { + return new SpecificReaderFactory(); } @Override - public Partitioning outputPartitioning() { + public Partitioning outputPartitioning(ScanConfig config) { return new MyPartitioning(); } } @@ -66,50 +62,53 @@ public int numPartitions() { public boolean satisfy(Distribution distribution) { if (distribution instanceof ClusteredDistribution) { String[] clusteredCols = ((ClusteredDistribution) distribution).clusteredColumns; - return Arrays.asList(clusteredCols).contains("a"); + return Arrays.asList(clusteredCols).contains("i"); } return false; } } - static class SpecificInputPartition implements InputPartition, - InputPartitionReader { - - private int[] i; - private int[] j; - private int current = -1; + static class SpecificInputPartition implements InputPartition { + int[] i; + int[] j; SpecificInputPartition(int[] i, int[] j) { assert i.length == j.length; this.i = i; this.j = j; } + } - @Override - public boolean next() throws IOException { - current += 1; - return current < 
i.length; - } - - @Override - public InternalRow get() { - return new GenericInternalRow(new Object[] {i[current], j[current]}); - } - - @Override - public void close() throws IOException { - - } + static class SpecificReaderFactory implements PartitionReaderFactory { @Override - public InputPartitionReader createPartitionReader() { - return this; + public PartitionReader createReader(InputPartition partition) { + SpecificInputPartition p = (SpecificInputPartition) partition; + return new PartitionReader() { + private int current = -1; + + @Override + public boolean next() throws IOException { + current += 1; + return current < p.i.length; + } + + @Override + public InternalRow get() { + return new GenericInternalRow(new Object[] {p.i[current], p.j[current]}); + } + + @Override + public void close() throws IOException { + + } + }; } } @Override - public DataSourceReader createReader(DataSourceOptions options) { - return new Reader(); + public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { + return new ReadSupport(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java index 6fd6a44d2c4d..cc9ac04a0dad 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java @@ -17,43 +17,39 @@ package test.org.apache.spark.sql.sources.v2; -import java.util.List; - -import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.sources.v2.BatchReadSupportProvider; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.ReadSupport; -import org.apache.spark.sql.sources.v2.reader.DataSourceReader; -import org.apache.spark.sql.sources.v2.reader.InputPartition; +import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.StructType; -public class JavaSchemaRequiredDataSource implements DataSourceV2, ReadSupport { +public class JavaSchemaRequiredDataSource implements DataSourceV2, BatchReadSupportProvider { - class Reader implements DataSourceReader { + class ReadSupport extends JavaSimpleReadSupport { private final StructType schema; - Reader(StructType schema) { + ReadSupport(StructType schema) { this.schema = schema; } @Override - public StructType readSchema() { + public StructType fullSchema() { return schema; } @Override - public List> planInputPartitions() { - return java.util.Collections.emptyList(); + public InputPartition[] planInputPartitions(ScanConfig config) { + return new InputPartition[0]; } } @Override - public DataSourceReader createReader(DataSourceOptions options) { + public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { throw new IllegalArgumentException("requires a user-supplied schema"); } @Override - public DataSourceReader createReader(StructType schema, DataSourceOptions options) { - return new Reader(schema); + public BatchReadSupport createBatchReadSupport(StructType schema, DataSourceOptions options) { + return new ReadSupport(schema); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java index 274dc3745bcf..2cdbba84ec4a 100644 --- 
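JavaSchemaRequiredDataSource above only implements the schema-taking overload. A hedged Scala sketch of how a caller chooses between the two createBatchReadSupport overloads (loadReadSupport is an illustrative name):

import org.apache.spark.sql.sources.v2.{BatchReadSupportProvider, DataSourceOptions}
import org.apache.spark.sql.types.StructType

// Illustrative: pick the right overload depending on whether the user supplied a schema.
def loadReadSupport(
    provider: BatchReadSupportProvider,
    userSchema: Option[StructType],
    options: DataSourceOptions) = {
  userSchema match {
    // Sources such as JavaSchemaRequiredDataSource only support this path and
    // throw from the no-schema overload.
    case Some(schema) => provider.createBatchReadSupport(schema, options)
    case None         => provider.createBatchReadSupport(options)
  }
}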
a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java @@ -17,72 +17,26 @@ package test.org.apache.spark.sql.sources.v2; -import java.io.IOException; -import java.util.List; - -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.sources.v2.BatchReadSupportProvider; import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.ReadSupport; -import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; -import org.apache.spark.sql.sources.v2.reader.InputPartition; -import org.apache.spark.sql.sources.v2.reader.DataSourceReader; -import org.apache.spark.sql.types.StructType; - -public class JavaSimpleDataSourceV2 implements DataSourceV2, ReadSupport { - - class Reader implements DataSourceReader { - private final StructType schema = new StructType().add("i", "int").add("j", "int"); - - @Override - public StructType readSchema() { - return schema; - } - - @Override - public List> planInputPartitions() { - return java.util.Arrays.asList( - new JavaSimpleInputPartition(0, 5), - new JavaSimpleInputPartition(5, 10)); - } - } - - static class JavaSimpleInputPartition implements InputPartition, - InputPartitionReader { +import org.apache.spark.sql.sources.v2.reader.*; - private int start; - private int end; +public class JavaSimpleDataSourceV2 implements DataSourceV2, BatchReadSupportProvider { - JavaSimpleInputPartition(int start, int end) { - this.start = start; - this.end = end; - } - - @Override - public InputPartitionReader createPartitionReader() { - return new JavaSimpleInputPartition(start - 1, end); - } + class ReadSupport extends JavaSimpleReadSupport { @Override - public boolean next() { - start += 1; - return start < end; - } - - @Override - public InternalRow get() { - return new GenericInternalRow(new Object[] {start, -start}); - } - - @Override - public void close() throws IOException { - + public InputPartition[] planInputPartitions(ScanConfig config) { + InputPartition[] partitions = new InputPartition[2]; + partitions[0] = new JavaRangeInputPartition(0, 5); + partitions[1] = new JavaRangeInputPartition(5, 10); + return partitions; } } @Override - public DataSourceReader createReader(DataSourceOptions options) { - return new Reader(); + public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { + return new ReadSupport(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java new file mode 100644 index 000000000000..685f9b9747e8 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.sources.v2; + +import java.io.IOException; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.sources.v2.reader.*; +import org.apache.spark.sql.types.StructType; + +abstract class JavaSimpleReadSupport implements BatchReadSupport { + + @Override + public StructType fullSchema() { + return new StructType().add("i", "int").add("j", "int"); + } + + @Override + public ScanConfigBuilder newScanConfigBuilder() { + return new JavaNoopScanConfigBuilder(fullSchema()); + } + + @Override + public PartitionReaderFactory createReaderFactory(ScanConfig config) { + return new JavaSimpleReaderFactory(); + } +} + +class JavaNoopScanConfigBuilder implements ScanConfigBuilder, ScanConfig { + + private StructType schema; + + JavaNoopScanConfigBuilder(StructType schema) { + this.schema = schema; + } + + @Override + public ScanConfig build() { + return this; + } + + @Override + public StructType readSchema() { + return schema; + } +} + +class JavaSimpleReaderFactory implements PartitionReaderFactory { + + @Override + public PartitionReader createReader(InputPartition partition) { + JavaRangeInputPartition p = (JavaRangeInputPartition) partition; + return new PartitionReader() { + private int current = p.start - 1; + + @Override + public boolean next() throws IOException { + current += 1; + return current < p.end; + } + + @Override + public InternalRow get() { + return new GenericInternalRow(new Object[] {current, -current}); + } + + @Override + public void close() throws IOException { + + } + }; + } +} + +class JavaRangeInputPartition implements InputPartition { + int start; + int end; + + JavaRangeInputPartition(int start, int end) { + this.start = start; + this.end = end; + } +} diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 46b38bed1c0f..a36b0cfa6ff1 100644 --- a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -9,6 +9,6 @@ org.apache.spark.sql.streaming.sources.FakeReadMicroBatchOnly org.apache.spark.sql.streaming.sources.FakeReadContinuousOnly org.apache.spark.sql.streaming.sources.FakeReadBothModes org.apache.spark.sql.streaming.sources.FakeReadNeitherMode -org.apache.spark.sql.streaming.sources.FakeWrite +org.apache.spark.sql.streaming.sources.FakeWriteSupportProvider org.apache.spark.sql.streaming.sources.FakeNoWrite -org.apache.spark.sql.streaming.sources.FakeWriteV1Fallback +org.apache.spark.sql.streaming.sources.FakeWriteSupportProviderV1Fallback diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala index 1efaead0845d..50f13bee251e 100644 --- 
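JavaSimpleReadSupport, JavaNoopScanConfigBuilder and JavaSimpleReaderFactory above together encode the whole batch read lifecycle of the refactored API. A minimal Scala sketch of that call sequence, using only the interfaces introduced in this patch (collectRows is an illustrative name):

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.reader.BatchReadSupport

// Illustrative: how a batch scan flows through the new abstractions.
def collectRows(readSupport: BatchReadSupport): Seq[InternalRow] = {
  val config = readSupport.newScanConfigBuilder().build()   // 1. negotiate the scan
  val partitions = readSupport.planInputPartitions(config)  // 2. plan partitions
  val factory = readSupport.createReaderFactory(config)     // 3. one serializable factory
  partitions.flatMap { p =>                                 // 4. per-partition readers
    val reader = factory.createReader(p)
    val buf = scala.collection.mutable.ArrayBuffer.empty[InternalRow]
    try {
      while (reader.next()) buf += reader.get()
    } finally {
      reader.close()
    }
    buf
  }.toSeq
}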
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala @@ -41,10 +41,11 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { assert(writer.commit().data.isEmpty) } - test("continuous writer") { + test("streaming writer") { val sink = new MemorySinkV2 - val writer = new MemoryStreamWriter(sink, OutputMode.Append(), new StructType().add("i", "int")) - writer.commit(0, + val writeSupport = new MemoryStreamingWriteSupport( + sink, OutputMode.Append(), new StructType().add("i", "int")) + writeSupport.commit(0, Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), @@ -52,29 +53,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { )) assert(sink.latestBatchId.contains(0)) assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) - writer.commit(19, - Array( - MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), - MemoryWriterCommitMessage(0, Seq(Row(33))) - )) - assert(sink.latestBatchId.contains(19)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11, 22, 33)) - - assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11, 22, 33)) - } - - test("microbatch writer") { - val sink = new MemorySinkV2 - val schema = new StructType().add("i", "int") - new MemoryWriter(sink, 0, OutputMode.Append(), schema).commit( - Array( - MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), - MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), - MemoryWriterCommitMessage(2, Seq(Row(6), Row(7))) - )) - assert(sink.latestBatchId.contains(0)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) - new MemoryWriter(sink, 19, OutputMode.Append(), schema).commit( + writeSupport.commit(19, Array( MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), MemoryWriterCommitMessage(0, Seq(Row(33))) @@ -88,22 +67,21 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("writer metrics") { val sink = new MemorySinkV2 val schema = new StructType().add("i", "int") + val writeSupport = new MemoryStreamingWriteSupport( + sink, OutputMode.Append(), schema) // batch 0 - var writer = new MemoryWriter(sink, 0, OutputMode.Append(), schema) - writer.commit( + writeSupport.commit(0, Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), MemoryWriterCommitMessage(2, Seq(Row(5), Row(6))) )) - assert(writer.getCustomMetrics.json() == "{\"numRows\":6}") + assert(writeSupport.getCustomMetrics.json() == "{\"numRows\":6}") // batch 1 - writer = new MemoryWriter(sink, 1, OutputMode.Append(), schema - ) - writer.commit( + writeSupport.commit(1, Array( MemoryWriterCommitMessage(0, Seq(Row(7), Row(8))) )) - assert(writer.getCustomMetrics.json() == "{\"numRows\":8}") + assert(writeSupport.getCustomMetrics.json() == "{\"numRows\":8}") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupportSuite.scala similarity index 98% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupportSuite.scala index 55acf2ba28d2..5884380271f0 100644 --- 
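As the rewritten MemorySinkV2Suite above suggests, one MemoryStreamingWriteSupport instance now serves every epoch and its custom metrics accumulate across commits. A hedged usage sketch along the same lines (import paths and the final metric value are assumed from the tests above):

import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.streaming.sources.{MemorySinkV2, MemoryStreamingWriteSupport, MemoryWriterCommitMessage}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType

val sink = new MemorySinkV2
val writeSupport = new MemoryStreamingWriteSupport(
  sink, OutputMode.Append(), new StructType().add("i", "int"))

// Epoch 0: two partitions worth of commit messages.
writeSupport.commit(0, Array(
  MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))),
  MemoryWriterCommitMessage(1, Seq(Row(3)))))
// Epoch 1: the same instance is reused; metrics keep growing.
writeSupport.commit(1, Array(
  MemoryWriterCommitMessage(0, Seq(Row(4)))))

assert(sink.latestBatchId.contains(1))
assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4))
assert(writeSupport.getCustomMetrics.json() == "{\"numRows\":4}")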
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupportSuite.scala @@ -19,12 +19,10 @@ package org.apache.spark.sql.execution.streaming.sources import java.io.ByteArrayOutputStream -import org.scalatest.time.SpanSugar._ - import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.streaming.{StreamTest, Trigger} -class ConsoleWriterSuite extends StreamTest { +class ConsoleWriteSupportSuite extends StreamTest { import testImplicits._ test("microbatch - default") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala index 7e53da1f312c..9c1756d68ccc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -17,19 +17,17 @@ package org.apache.spark.sql.execution.streaming.sources -import java.util.Optional import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, MicroBatchReadSupportProvider} import org.apache.spark.sql.sources.v2.reader.streaming.Offset import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.util.ManualClock @@ -42,7 +40,7 @@ class RateSourceSuite extends StreamTest { override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { assert(query.nonEmpty) val rateSource = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source + case StreamingExecutionRelation(source: RateStreamMicroBatchReadSupport, _) => source }.head rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) @@ -55,10 +53,10 @@ class RateSourceSuite extends StreamTest { test("microbatch in registry") { withTempDir { temp => DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupport => - val reader = ds.createMicroBatchReader( - Optional.empty(), temp.getCanonicalPath, DataSourceOptions.empty()) - assert(reader.isInstanceOf[RateStreamMicroBatchReader]) + case ds: MicroBatchReadSupportProvider => + val readSupport = ds.createMicroBatchReadSupport( + temp.getCanonicalPath, DataSourceOptions.empty()) + assert(readSupport.isInstanceOf[RateStreamMicroBatchReadSupport]) case _ => throw new IllegalStateException("Could not find read support for rate") } @@ -68,7 +66,7 @@ class RateSourceSuite extends StreamTest { test("compatible with old path in registry") { DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.RateSourceProvider", spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupport => + case ds: MicroBatchReadSupportProvider => 
assert(ds.isInstanceOf[RateStreamProvider]) case _ => throw new IllegalStateException("Could not find read support for rate") @@ -109,30 +107,19 @@ class RateSourceSuite extends StreamTest { ) } - test("microbatch - set offset") { - withTempDir { temp => - val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty(), temp.getCanonicalPath) - val startOffset = LongOffset(0L) - val endOffset = LongOffset(1L) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - assert(reader.getStartOffset() == startOffset) - assert(reader.getEndOffset() == endOffset) - } - } - test("microbatch - infer offsets") { withTempDir { temp => - val reader = new RateStreamMicroBatchReader( + val readSupport = new RateStreamMicroBatchReadSupport( new DataSourceOptions( Map("numPartitions" -> "1", "rowsPerSecond" -> "100", "useManualClock" -> "true").asJava), temp.getCanonicalPath) - reader.clock.asInstanceOf[ManualClock].advance(100000) - reader.setOffsetRange(Optional.empty(), Optional.empty()) - reader.getStartOffset() match { + readSupport.clock.asInstanceOf[ManualClock].advance(100000) + val startOffset = readSupport.initialOffset() + startOffset match { case r: LongOffset => assert(r.offset === 0L) case _ => throw new IllegalStateException("unexpected offset type") } - reader.getEndOffset() match { + readSupport.latestOffset() match { case r: LongOffset => assert(r.offset >= 100) case _ => throw new IllegalStateException("unexpected offset type") } @@ -141,15 +128,16 @@ class RateSourceSuite extends StreamTest { test("microbatch - predetermined batch size") { withTempDir { temp => - val reader = new RateStreamMicroBatchReader( + val readSupport = new RateStreamMicroBatchReadSupport( new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava), temp.getCanonicalPath) val startOffset = LongOffset(0L) val endOffset = LongOffset(1L) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.planInputPartitions() + val config = readSupport.newScanConfigBuilder(startOffset, endOffset).build() + val tasks = readSupport.planInputPartitions(config) + val readerFactory = readSupport.createReaderFactory(config) assert(tasks.size == 1) - val dataReader = tasks.get(0).createPartitionReader() + val dataReader = readerFactory.createReader(tasks(0)) val data = ArrayBuffer[InternalRow]() while (dataReader.next()) { data.append(dataReader.get()) @@ -160,24 +148,25 @@ class RateSourceSuite extends StreamTest { test("microbatch - data read") { withTempDir { temp => - val reader = new RateStreamMicroBatchReader( + val readSupport = new RateStreamMicroBatchReadSupport( new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava), temp.getCanonicalPath) val startOffset = LongOffset(0L) val endOffset = LongOffset(1L) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.planInputPartitions() + val config = readSupport.newScanConfigBuilder(startOffset, endOffset).build() + val tasks = readSupport.planInputPartitions(config) + val readerFactory = readSupport.createReaderFactory(config) assert(tasks.size == 11) - val readData = tasks.asScala - .map(_.createPartitionReader()) + val readData = tasks + .map(readerFactory.createReader) .flatMap { reader => val buf = scala.collection.mutable.ListBuffer[InternalRow]() while (reader.next()) buf.append(reader.get()) buf } - assert(readData.map(_.getLong(1)).sorted == Range(0, 33)) + assert(readData.map(_.getLong(1)).sorted === 
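The rate-source tests above follow the new micro-batch contract: initialOffset()/latestOffset() replace setOffsetRange, and partition planning goes through a ScanConfig. A condensed, hedged sketch of reading one micro-batch (the checkpoint path is a placeholder and the import path is assumed):

import scala.collection.JavaConverters._

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.streaming.sources.RateStreamMicroBatchReadSupport
import org.apache.spark.sql.sources.v2.DataSourceOptions

val readSupport = new RateStreamMicroBatchReadSupport(
  new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "10").asJava),
  "/tmp/rate-checkpoint")  // placeholder checkpoint path

val start = readSupport.initialOffset()   // replaces setOffsetRange(empty, empty)
val end = readSupport.latestOffset()      // end of this batch as of now
val config = readSupport.newScanConfigBuilder(start, end).build()

val factory = readSupport.createReaderFactory(config)
val rows = readSupport.planInputPartitions(config).flatMap { p =>
  val reader = factory.createReader(p)
  val buf = scala.collection.mutable.ArrayBuffer.empty[InternalRow]
  while (reader.next()) buf += reader.get()
  reader.close()
  buf
}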
0.until(33).toArray) } } @@ -288,41 +277,44 @@ class RateSourceSuite extends StreamTest { } test("user-specified schema given") { - val exception = intercept[AnalysisException] { + val exception = intercept[UnsupportedOperationException] { spark.readStream .format("rate") .schema(spark.range(1).schema) .load() } assert(exception.getMessage.contains( - "rate source does not support a user-specified schema")) + "rate source does not support user-specified schema")) } test("continuous in registry") { DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { - case ds: ContinuousReadSupport => - val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty()) - assert(reader.isInstanceOf[RateStreamContinuousReader]) + case ds: ContinuousReadSupportProvider => + val readSupport = ds.createContinuousReadSupport( + "", DataSourceOptions.empty()) + assert(readSupport.isInstanceOf[RateStreamContinuousReadSupport]) case _ => throw new IllegalStateException("Could not find read support for continuous rate") } } test("continuous data") { - val reader = new RateStreamContinuousReader( + val readSupport = new RateStreamContinuousReadSupport( new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) - reader.setStartOffset(Optional.empty()) - val tasks = reader.planInputPartitions() + val config = readSupport.newScanConfigBuilder(readSupport.initialOffset).build() + val tasks = readSupport.planInputPartitions(config) + val readerFactory = readSupport.createContinuousReaderFactory(config) assert(tasks.size == 2) val data = scala.collection.mutable.ListBuffer[InternalRow]() - tasks.asScala.foreach { + tasks.foreach { case t: RateStreamContinuousInputPartition => - val startTimeMs = reader.getStartOffset() + val startTimeMs = readSupport.initialOffset() .asInstanceOf[RateStreamOffset] .partitionToValueAndRunTimeMs(t.partitionIndex) .runTimeMs - val r = t.createPartitionReader().asInstanceOf[RateStreamContinuousInputPartitionReader] + val r = readerFactory.createReader(t) + .asInstanceOf[RateStreamContinuousPartitionReader] for (rowIndex <- 0 to 9) { r.next() data.append(r.get()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala index 48e5cf75bf8b..409156e5ebc7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala @@ -21,7 +21,6 @@ import java.net.{InetSocketAddress, SocketException} import java.nio.ByteBuffer import java.nio.channels.ServerSocketChannel import java.sql.Timestamp -import java.util.Optional import java.util.concurrent.LinkedBlockingQueue import scala.collection.JavaConverters._ @@ -34,8 +33,8 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupportProvider} +import org.apache.spark.sql.sources.v2.reader.streaming.Offset import org.apache.spark.sql.streaming.{StreamingQueryException, 
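The continuous path has the same shape, except the scan is configured from a start offset only and readers come from createContinuousReaderFactory. A hedged sketch mirroring the continuous rate test above (import path assumed from the suite's imports):

import scala.collection.JavaConverters._

import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReadSupport
import org.apache.spark.sql.sources.v2.DataSourceOptions

val readSupport = new RateStreamContinuousReadSupport(
  new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava))

// Continuous scans are configured from a start offset only; there is no end offset.
val config = readSupport.newScanConfigBuilder(readSupport.initialOffset()).build()
val partitions = readSupport.planInputPartitions(config)
val factory = readSupport.createContinuousReaderFactory(config)

partitions.foreach { p =>
  val reader = factory.createReader(p)
  // A continuous reader is long-running; next()/get() would normally be driven
  // by the continuous execution engine rather than called a fixed number of times.
  reader.next()
  val firstRow = reader.get()
  reader.close()
}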
StreamTest} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -49,14 +48,9 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before serverThread.join() serverThread = null } - if (batchReader != null) { - batchReader.stop() - batchReader = null - } } private var serverThread: ServerThread = null - private var batchReader: MicroBatchReader = null case class AddSocketData(data: String*) extends AddData { override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { @@ -65,7 +59,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before "Cannot add data when there is no query for finding the active socket source") val sources = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source: TextSocketMicroBatchReader, _) => source + case StreamingExecutionRelation(source: TextSocketMicroBatchReadSupport, _) => source } if (sources.isEmpty) { throw new Exception( @@ -91,7 +85,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before test("backward compatibility with old path") { DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.TextSocketSourceProvider", spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupport => + case ds: MicroBatchReadSupportProvider => assert(ds.isInstanceOf[TextSocketSourceProvider]) case _ => throw new IllegalStateException("Could not find socket source") @@ -181,16 +175,16 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before test("params not given") { val provider = new TextSocketSourceProvider intercept[AnalysisException] { - provider.createMicroBatchReader(Optional.empty(), "", - new DataSourceOptions(Map.empty[String, String].asJava)) + provider.createMicroBatchReadSupport( + "", new DataSourceOptions(Map.empty[String, String].asJava)) } intercept[AnalysisException] { - provider.createMicroBatchReader(Optional.empty(), "", - new DataSourceOptions(Map("host" -> "localhost").asJava)) + provider.createMicroBatchReadSupport( + "", new DataSourceOptions(Map("host" -> "localhost").asJava)) } intercept[AnalysisException] { - provider.createMicroBatchReader(Optional.empty(), "", - new DataSourceOptions(Map("port" -> "1234").asJava)) + provider.createMicroBatchReadSupport( + "", new DataSourceOptions(Map("port" -> "1234").asJava)) } } @@ -199,7 +193,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before val params = Map("host" -> "localhost", "port" -> "1234", "includeTimestamp" -> "fasle") intercept[AnalysisException] { val a = new DataSourceOptions(params.asJava) - provider.createMicroBatchReader(Optional.empty(), "", a) + provider.createMicroBatchReadSupport("", a) } } @@ -209,12 +203,12 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before StructField("name", StringType) :: StructField("area", StringType) :: Nil) val params = Map("host" -> "localhost", "port" -> "1234") - val exception = intercept[AnalysisException] { - provider.createMicroBatchReader( - Optional.of(userSpecifiedSchema), "", new DataSourceOptions(params.asJava)) + val exception = intercept[UnsupportedOperationException] { + provider.createMicroBatchReadSupport( + userSpecifiedSchema, "", new DataSourceOptions(params.asJava)) } assert(exception.getMessage.contains( - "socket source does not support a user-specified schema")) + "socket source does not support user-specified schema")) } test("input row metrics") { @@ 
-305,25 +299,27 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before serverThread = new ServerThread() serverThread.start() - val reader = new TextSocketContinuousReader( + val readSupport = new TextSocketContinuousReadSupport( new DataSourceOptions(Map("numPartitions" -> "2", "host" -> "localhost", "port" -> serverThread.port.toString).asJava)) - reader.setStartOffset(Optional.empty()) - val tasks = reader.planInputPartitions() + + val scanConfig = readSupport.newScanConfigBuilder(readSupport.initialOffset()).build() + val tasks = readSupport.planInputPartitions(scanConfig) assert(tasks.size == 2) val numRecords = 10 val data = scala.collection.mutable.ListBuffer[Int]() val offsets = scala.collection.mutable.ListBuffer[Int]() + val readerFactory = readSupport.createContinuousReaderFactory(scanConfig) import org.scalatest.time.SpanSugar._ failAfter(5 seconds) { // inject rows, read and check the data and offsets for (i <- 0 until numRecords) { serverThread.enqueue(i.toString) } - tasks.asScala.foreach { + tasks.foreach { case t: TextSocketContinuousInputPartition => - val r = t.createPartitionReader().asInstanceOf[TextSocketContinuousInputPartitionReader] + val r = readerFactory.createReader(t).asInstanceOf[TextSocketContinuousPartitionReader] for (i <- 0 until numRecords / 2) { r.next() offsets.append(r.getOffset().asInstanceOf[ContinuousRecordPartitionOffset].offset) @@ -339,16 +335,15 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before data.clear() case _ => throw new IllegalStateException("Unexpected task type") } - assert(reader.getStartOffset.asInstanceOf[TextSocketOffset].offsets == List(3, 3)) - reader.commit(TextSocketOffset(List(5, 5))) - assert(reader.getStartOffset.asInstanceOf[TextSocketOffset].offsets == List(5, 5)) + assert(readSupport.startOffset.offsets == List(3, 3)) + readSupport.commit(TextSocketOffset(List(5, 5))) + assert(readSupport.startOffset.offsets == List(5, 5)) } def commitOffset(partition: Int, offset: Int): Unit = { - val offsetsToCommit = reader.getStartOffset.asInstanceOf[TextSocketOffset] - .offsets.updated(partition, offset) - reader.commit(TextSocketOffset(offsetsToCommit)) - assert(reader.getStartOffset.asInstanceOf[TextSocketOffset].offsets == offsetsToCommit) + val offsetsToCommit = readSupport.startOffset.offsets.updated(partition, offset) + readSupport.commit(TextSocketOffset(offsetsToCommit)) + assert(readSupport.startOffset.offsets == offsetsToCommit) } } @@ -356,14 +351,13 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before serverThread = new ServerThread() serverThread.start() - val reader = new TextSocketContinuousReader( + val readSupport = new TextSocketContinuousReadSupport( new DataSourceOptions(Map("numPartitions" -> "2", "host" -> "localhost", "port" -> serverThread.port.toString).asJava)) - reader.setStartOffset(Optional.of(TextSocketOffset(List(5, 5)))) - // ok to commit same offset - reader.setStartOffset(Optional.of(TextSocketOffset(List(5, 5)))) + + readSupport.startOffset = TextSocketOffset(List(5, 5)) assertThrows[IllegalStateException] { - reader.commit(TextSocketOffset(List(6, 6))) + readSupport.commit(TextSocketOffset(List(6, 6))) } } @@ -371,12 +365,12 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before serverThread = new ServerThread() serverThread.start() - val reader = new TextSocketContinuousReader( + val readSupport = new TextSocketContinuousReadSupport( new DataSourceOptions(Map("numPartitions" -> 
"2", "host" -> "localhost", "includeTimestamp" -> "true", "port" -> serverThread.port.toString).asJava)) - reader.setStartOffset(Optional.empty()) - val tasks = reader.planInputPartitions() + val scanConfig = readSupport.newScanConfigBuilder(readSupport.initialOffset()).build() + val tasks = readSupport.planInputPartitions(scanConfig) assert(tasks.size == 2) val numRecords = 4 @@ -384,9 +378,10 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before for (i <- 0 until numRecords) { serverThread.enqueue(i.toString) } - tasks.asScala.foreach { + val readerFactory = readSupport.createContinuousReaderFactory(scanConfig) + tasks.foreach { case t: TextSocketContinuousInputPartition => - val r = t.createPartitionReader().asInstanceOf[TextSocketContinuousInputPartitionReader] + val r = readerFactory.createReader(t).asInstanceOf[TextSocketContinuousPartitionReader] for (i <- 0 until numRecords / 2) { r.next() assert(r.get().get(0, TextSocketReader.SCHEMA_TIMESTAMP) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index aa5f723365d5..5edeff553eb1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.sources.v2 -import java.util.{ArrayList, List => JList} - import test.org.apache.spark.sql.sources.v2._ import org.apache.spark.SparkException @@ -38,6 +36,21 @@ import org.apache.spark.sql.vectorized.ColumnarBatch class DataSourceV2Suite extends QueryTest with SharedSQLContext { import testImplicits._ + private def getScanConfig(query: DataFrame): AdvancedScanConfigBuilder = { + query.queryExecution.executedPlan.collect { + case d: DataSourceV2ScanExec => + d.scanConfig.asInstanceOf[AdvancedScanConfigBuilder] + }.head + } + + private def getJavaScanConfig( + query: DataFrame): JavaAdvancedDataSourceV2.AdvancedScanConfigBuilder = { + query.queryExecution.executedPlan.collect { + case d: DataSourceV2ScanExec => + d.scanConfig.asInstanceOf[JavaAdvancedDataSourceV2.AdvancedScanConfigBuilder] + }.head + } + test("simplest implementation") { Seq(classOf[SimpleDataSourceV2], classOf[JavaSimpleDataSourceV2]).foreach { cls => withClue(cls.getName) { @@ -50,18 +63,6 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } test("advanced implementation") { - def getReader(query: DataFrame): AdvancedDataSourceV2#Reader = { - query.queryExecution.executedPlan.collect { - case d: DataSourceV2ScanExec => d.reader.asInstanceOf[AdvancedDataSourceV2#Reader] - }.head - } - - def getJavaReader(query: DataFrame): JavaAdvancedDataSourceV2#Reader = { - query.queryExecution.executedPlan.collect { - case d: DataSourceV2ScanExec => d.reader.asInstanceOf[JavaAdvancedDataSourceV2#Reader] - }.head - } - Seq(classOf[AdvancedDataSourceV2], classOf[JavaAdvancedDataSourceV2]).foreach { cls => withClue(cls.getName) { val df = spark.read.format(cls.getName).load() @@ -70,58 +71,58 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { val q1 = df.select('j) checkAnswer(q1, (0 until 10).map(i => Row(-i))) if (cls == classOf[AdvancedDataSourceV2]) { - val reader = getReader(q1) - assert(reader.filters.isEmpty) - assert(reader.requiredSchema.fieldNames === Seq("j")) + val config = getScanConfig(q1) + assert(config.filters.isEmpty) + assert(config.requiredSchema.fieldNames === Seq("j")) } else { 
- val reader = getJavaReader(q1) - assert(reader.filters.isEmpty) - assert(reader.requiredSchema.fieldNames === Seq("j")) + val config = getJavaScanConfig(q1) + assert(config.filters.isEmpty) + assert(config.requiredSchema.fieldNames === Seq("j")) } val q2 = df.filter('i > 3) checkAnswer(q2, (4 until 10).map(i => Row(i, -i))) if (cls == classOf[AdvancedDataSourceV2]) { - val reader = getReader(q2) - assert(reader.filters.flatMap(_.references).toSet == Set("i")) - assert(reader.requiredSchema.fieldNames === Seq("i", "j")) + val config = getScanConfig(q2) + assert(config.filters.flatMap(_.references).toSet == Set("i")) + assert(config.requiredSchema.fieldNames === Seq("i", "j")) } else { - val reader = getJavaReader(q2) - assert(reader.filters.flatMap(_.references).toSet == Set("i")) - assert(reader.requiredSchema.fieldNames === Seq("i", "j")) + val config = getJavaScanConfig(q2) + assert(config.filters.flatMap(_.references).toSet == Set("i")) + assert(config.requiredSchema.fieldNames === Seq("i", "j")) } val q3 = df.select('i).filter('i > 6) checkAnswer(q3, (7 until 10).map(i => Row(i))) if (cls == classOf[AdvancedDataSourceV2]) { - val reader = getReader(q3) - assert(reader.filters.flatMap(_.references).toSet == Set("i")) - assert(reader.requiredSchema.fieldNames === Seq("i")) + val config = getScanConfig(q3) + assert(config.filters.flatMap(_.references).toSet == Set("i")) + assert(config.requiredSchema.fieldNames === Seq("i")) } else { - val reader = getJavaReader(q3) - assert(reader.filters.flatMap(_.references).toSet == Set("i")) - assert(reader.requiredSchema.fieldNames === Seq("i")) + val config = getJavaScanConfig(q3) + assert(config.filters.flatMap(_.references).toSet == Set("i")) + assert(config.requiredSchema.fieldNames === Seq("i")) } val q4 = df.select('j).filter('j < -10) checkAnswer(q4, Nil) if (cls == classOf[AdvancedDataSourceV2]) { - val reader = getReader(q4) + val config = getScanConfig(q4) // 'j < 10 is not supported by the testing data source. - assert(reader.filters.isEmpty) - assert(reader.requiredSchema.fieldNames === Seq("j")) + assert(config.filters.isEmpty) + assert(config.requiredSchema.fieldNames === Seq("j")) } else { - val reader = getJavaReader(q4) + val config = getJavaScanConfig(q4) // 'j < 10 is not supported by the testing data source. 
- assert(reader.filters.isEmpty) - assert(reader.requiredSchema.fieldNames === Seq("j")) + assert(config.filters.isEmpty) + assert(config.requiredSchema.fieldNames === Seq("j")) } } } } test("columnar batch scan implementation") { - Seq(classOf[BatchDataSourceV2], classOf[JavaBatchDataSourceV2]).foreach { cls => + Seq(classOf[ColumnarDataSourceV2], classOf[JavaColumnarDataSourceV2]).foreach { cls => withClue(cls.getName) { val df = spark.read.format(cls.getName).load() checkAnswer(df, (0 until 90).map(i => Row(i, -i))) @@ -153,25 +154,25 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { val df = spark.read.format(cls.getName).load() checkAnswer(df, Seq(Row(1, 4), Row(1, 4), Row(3, 6), Row(2, 6), Row(4, 2), Row(4, 2))) - val groupByColA = df.groupBy('a).agg(sum('b)) + val groupByColA = df.groupBy('i).agg(sum('j)) checkAnswer(groupByColA, Seq(Row(1, 8), Row(2, 6), Row(3, 6), Row(4, 4))) assert(groupByColA.queryExecution.executedPlan.collectFirst { case e: ShuffleExchangeExec => e }.isEmpty) - val groupByColAB = df.groupBy('a, 'b).agg(count("*")) + val groupByColAB = df.groupBy('i, 'j).agg(count("*")) checkAnswer(groupByColAB, Seq(Row(1, 4, 2), Row(2, 6, 1), Row(3, 6, 1), Row(4, 2, 2))) assert(groupByColAB.queryExecution.executedPlan.collectFirst { case e: ShuffleExchangeExec => e }.isEmpty) - val groupByColB = df.groupBy('b).agg(sum('a)) + val groupByColB = df.groupBy('j).agg(sum('i)) checkAnswer(groupByColB, Seq(Row(2, 8), Row(4, 2), Row(6, 5))) assert(groupByColB.queryExecution.executedPlan.collectFirst { case e: ShuffleExchangeExec => e }.isDefined) - val groupByAPlusB = df.groupBy('a + 'b).agg(count("*")) + val groupByAPlusB = df.groupBy('i + 'j).agg(count("*")) checkAnswer(groupByAPlusB, Seq(Row(5, 2), Row(6, 2), Row(8, 1), Row(9, 1))) assert(groupByAPlusB.queryExecution.executedPlan.collectFirst { case e: ShuffleExchangeExec => e @@ -272,36 +273,30 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } test("SPARK-23301: column pruning with arbitrary expressions") { - def getReader(query: DataFrame): AdvancedDataSourceV2#Reader = { - query.queryExecution.executedPlan.collect { - case d: DataSourceV2ScanExec => d.reader.asInstanceOf[AdvancedDataSourceV2#Reader] - }.head - } - val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load() val q1 = df.select('i + 1) checkAnswer(q1, (1 until 11).map(i => Row(i))) - val reader1 = getReader(q1) - assert(reader1.requiredSchema.fieldNames === Seq("i")) + val config1 = getScanConfig(q1) + assert(config1.requiredSchema.fieldNames === Seq("i")) val q2 = df.select(lit(1)) checkAnswer(q2, (0 until 10).map(i => Row(1))) - val reader2 = getReader(q2) - assert(reader2.requiredSchema.isEmpty) + val config2 = getScanConfig(q2) + assert(config2.requiredSchema.isEmpty) // 'j === 1 can't be pushed down, but we should still be able do column pruning val q3 = df.filter('j === -1).select('j * 2) checkAnswer(q3, Row(-2)) - val reader3 = getReader(q3) - assert(reader3.filters.isEmpty) - assert(reader3.requiredSchema.fieldNames === Seq("j")) + val config3 = getScanConfig(q3) + assert(config3.filters.isEmpty) + assert(config3.requiredSchema.fieldNames === Seq("j")) // column pruning should work with other operators. 
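Because pushdown state now lives on the ScanConfig rather than on a reader, the suite extracts it from DataSourceV2ScanExec via getScanConfig. A hedged sketch of the same pattern as a standalone helper, assuming a SparkSession named spark and the suite's AdvancedDataSourceV2/AdvancedScanConfigBuilder in scope (pushedState is an illustrative name):

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec
import org.apache.spark.sql.functions.col

// Illustrative: pull pushdown state out of the physical plan, as getScanConfig does above.
def pushedState(query: DataFrame): (Set[String], Seq[String]) = {
  val config = query.queryExecution.executedPlan.collect {
    case scan: DataSourceV2ScanExec => scan.scanConfig.asInstanceOf[AdvancedScanConfigBuilder]
  }.head
  (config.filters.flatMap(_.references).toSet, config.requiredSchema.fieldNames.toSeq)
}

val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load()
val (filterCols, readCols) = pushedState(df.filter(col("i") > 3))
// Only GreaterThan("i", _) is pushed by this testing source, and no columns are pruned
// away here, so filterCols == Set("i") and readCols == Seq("i", "j").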
val q4 = df.sort('i).limit(1).select('i + 1) checkAnswer(q4, Row(1)) - val reader4 = getReader(q4) - assert(reader4.requiredSchema.fieldNames === Seq("i")) + val config4 = getScanConfig(q4) + assert(config4.requiredSchema.fieldNames === Seq("i")) } test("SPARK-23315: get output from canonicalized data source v2 related plans") { @@ -324,240 +319,290 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } -class SimpleSinglePartitionSource extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceReader { - override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") +case class RangeInputPartition(start: Int, end: Int) extends InputPartition - override def planInputPartitions(): JList[InputPartition[InternalRow]] = { - java.util.Arrays.asList(new SimpleInputPartition(0, 5)) - } - } - - override def createReader(options: DataSourceOptions): DataSourceReader = new Reader +case class NoopScanConfigBuilder(readSchema: StructType) extends ScanConfigBuilder with ScanConfig { + override def build(): ScanConfig = this } -class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { +object SimpleReaderFactory extends PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val RangeInputPartition(start, end) = partition + new PartitionReader[InternalRow] { + private var current = start - 1 + + override def next(): Boolean = { + current += 1 + current < end + } - class Reader extends DataSourceReader { - override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") + override def get(): InternalRow = InternalRow(current, -current) - override def planInputPartitions(): JList[InputPartition[InternalRow]] = { - java.util.Arrays.asList(new SimpleInputPartition(0, 5), new SimpleInputPartition(5, 10)) + override def close(): Unit = {} } } - - override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } -class SimpleInputPartition(start: Int, end: Int) - extends InputPartition[InternalRow] - with InputPartitionReader[InternalRow] { - private var current = start - 1 - - override def createPartitionReader(): InputPartitionReader[InternalRow] = - new SimpleInputPartition(start, end) +abstract class SimpleReadSupport extends BatchReadSupport { + override def fullSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def next(): Boolean = { - current += 1 - current < end + override def newScanConfigBuilder(): ScanConfigBuilder = { + NoopScanConfigBuilder(fullSchema()) } - override def get(): InternalRow = InternalRow(current, -current) - - override def close(): Unit = {} + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + SimpleReaderFactory + } } +class SimpleSinglePartitionSource extends DataSourceV2 with BatchReadSupportProvider { -class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { + class ReadSupport extends SimpleReadSupport { + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + Array(RangeInputPartition(0, 5)) + } + } + + override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { + new ReadSupport + } +} - class Reader extends DataSourceReader - with SupportsPushDownRequiredColumns with SupportsPushDownFilters { - var requiredSchema = new StructType().add("i", "int").add("j", "int") - var filters = Array.empty[Filter] +class SimpleDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider { - 
override def pruneColumns(requiredSchema: StructType): Unit = { - this.requiredSchema = requiredSchema + class ReadSupport extends SimpleReadSupport { + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + Array(RangeInputPartition(0, 5), RangeInputPartition(5, 10)) } + } - override def pushFilters(filters: Array[Filter]): Array[Filter] = { - val (supported, unsupported) = filters.partition { - case GreaterThan("i", _: Int) => true - case _ => false - } - this.filters = supported - unsupported - } + override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { + new ReadSupport + } +} - override def pushedFilters(): Array[Filter] = filters - override def readSchema(): StructType = { - requiredSchema - } +class AdvancedDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider { + + class ReadSupport extends SimpleReadSupport { + override def newScanConfigBuilder(): ScanConfigBuilder = new AdvancedScanConfigBuilder() + + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val filters = config.asInstanceOf[AdvancedScanConfigBuilder].filters - override def planInputPartitions(): JList[InputPartition[InternalRow]] = { val lowerBound = filters.collectFirst { case GreaterThan("i", v: Int) => v } - val res = new ArrayList[InputPartition[InternalRow]] + val res = scala.collection.mutable.ArrayBuffer.empty[InputPartition] if (lowerBound.isEmpty) { - res.add(new AdvancedInputPartition(0, 5, requiredSchema)) - res.add(new AdvancedInputPartition(5, 10, requiredSchema)) + res.append(RangeInputPartition(0, 5)) + res.append(RangeInputPartition(5, 10)) } else if (lowerBound.get < 4) { - res.add(new AdvancedInputPartition(lowerBound.get + 1, 5, requiredSchema)) - res.add(new AdvancedInputPartition(5, 10, requiredSchema)) + res.append(RangeInputPartition(lowerBound.get + 1, 5)) + res.append(RangeInputPartition(5, 10)) } else if (lowerBound.get < 9) { - res.add(new AdvancedInputPartition(lowerBound.get + 1, 10, requiredSchema)) + res.append(RangeInputPartition(lowerBound.get + 1, 10)) } - res + res.toArray + } + + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + val requiredSchema = config.asInstanceOf[AdvancedScanConfigBuilder].requiredSchema + new AdvancedReaderFactory(requiredSchema) } } - override def createReader(options: DataSourceOptions): DataSourceReader = new Reader + override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { + new ReadSupport + } } -class AdvancedInputPartition(start: Int, end: Int, requiredSchema: StructType) - extends InputPartition[InternalRow] with InputPartitionReader[InternalRow] { +class AdvancedScanConfigBuilder extends ScanConfigBuilder with ScanConfig + with SupportsPushDownRequiredColumns with SupportsPushDownFilters { - private var current = start - 1 + var requiredSchema = new StructType().add("i", "int").add("j", "int") + var filters = Array.empty[Filter] - override def createPartitionReader(): InputPartitionReader[InternalRow] = { - new AdvancedInputPartition(start, end, requiredSchema) + override def pruneColumns(requiredSchema: StructType): Unit = { + this.requiredSchema = requiredSchema } - override def close(): Unit = {} + override def readSchema(): StructType = requiredSchema - override def next(): Boolean = { - current += 1 - current < end + override def pushFilters(filters: Array[Filter]): Array[Filter] = { + val (supported, unsupported) = filters.partition { + case GreaterThan("i", _: Int) => true + case _ => 
false + } + this.filters = supported + unsupported } - override def get(): InternalRow = { - val values = requiredSchema.map(_.name).map { - case "i" => current - case "j" => -current + override def pushedFilters(): Array[Filter] = filters + + override def build(): ScanConfig = this +} + +class AdvancedReaderFactory(requiredSchema: StructType) extends PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val RangeInputPartition(start, end) = partition + new PartitionReader[InternalRow] { + private var current = start - 1 + + override def next(): Boolean = { + current += 1 + current < end + } + + override def get(): InternalRow = { + val values = requiredSchema.map(_.name).map { + case "i" => current + case "j" => -current + } + InternalRow.fromSeq(values) + } + + override def close(): Unit = {} } - InternalRow.fromSeq(values) } } -class SchemaRequiredDataSource extends DataSourceV2 with ReadSupport { +class SchemaRequiredDataSource extends DataSourceV2 with BatchReadSupportProvider { - class Reader(val readSchema: StructType) extends DataSourceReader { - override def planInputPartitions(): JList[InputPartition[InternalRow]] = - java.util.Collections.emptyList() + class ReadSupport(val schema: StructType) extends SimpleReadSupport { + override def fullSchema(): StructType = schema + + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = + Array.empty } - override def createReader(options: DataSourceOptions): DataSourceReader = { + override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { throw new IllegalArgumentException("requires a user-supplied schema") } - override def createReader(schema: StructType, options: DataSourceOptions): DataSourceReader = { - new Reader(schema) + override def createBatchReadSupport( + schema: StructType, options: DataSourceOptions): BatchReadSupport = { + new ReadSupport(schema) } } -class BatchDataSourceV2 extends DataSourceV2 with ReadSupport { +class ColumnarDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider { - class Reader extends DataSourceReader with SupportsScanColumnarBatch { - override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") + class ReadSupport extends SimpleReadSupport { + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + Array(RangeInputPartition(0, 50), RangeInputPartition(50, 90)) + } - override def planBatchInputPartitions(): JList[InputPartition[ColumnarBatch]] = { - java.util.Arrays.asList( - new BatchInputPartitionReader(0, 50), new BatchInputPartitionReader(50, 90)) + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + ColumnarReaderFactory } } - override def createReader(options: DataSourceOptions): DataSourceReader = new Reader + override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { + new ReadSupport + } } -class BatchInputPartitionReader(start: Int, end: Int) - extends InputPartition[ColumnarBatch] with InputPartitionReader[ColumnarBatch] { - +object ColumnarReaderFactory extends PartitionReaderFactory { private final val BATCH_SIZE = 20 - private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType) - private lazy val j = new OnHeapColumnVector(BATCH_SIZE, IntegerType) - private lazy val batch = new ColumnarBatch(Array(i, j)) - private var current = start + override def supportColumnarReads(partition: InputPartition): Boolean = true - override def createPartitionReader(): 
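AdvancedScanConfigBuilder doubles as the ScanConfig it builds, so pushdown is a sequence of mutations followed by build(). A hedged sketch of the order in which the optimizer is expected to drive it, constructing the builder directly for illustration:

import org.apache.spark.sql.sources.{Filter, GreaterThan, LessThan}
import org.apache.spark.sql.types.StructType

val builder = new AdvancedScanConfigBuilder()

// 1. Offer filters; the builder keeps what it can handle and returns the rest.
val toPush: Array[Filter] = Array(GreaterThan("i", 3), LessThan("j", 0))
val postScanFilters = builder.pushFilters(toPush)
// Only GreaterThan("i", _) is supported by this testing source.
assert(builder.pushedFilters().sameElements(Array(GreaterThan("i", 3))))
assert(postScanFilters.sameElements(Array(LessThan("j", 0))))

// 2. Prune columns down to what the query actually needs.
builder.pruneColumns(new StructType().add("i", "int"))

// 3. Freeze everything into the ScanConfig handed to planInputPartitions.
val config = builder.build()
assert(config.readSchema().fieldNames.sameElements(Array("i")))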
InputPartitionReader[ColumnarBatch] = this + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + throw new UnsupportedOperationException + } - override def next(): Boolean = { - i.reset() - j.reset() + override def createColumnarReader(partition: InputPartition): PartitionReader[ColumnarBatch] = { + val RangeInputPartition(start, end) = partition + new PartitionReader[ColumnarBatch] { + private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType) + private lazy val j = new OnHeapColumnVector(BATCH_SIZE, IntegerType) + private lazy val batch = new ColumnarBatch(Array(i, j)) + + private var current = start + + override def next(): Boolean = { + i.reset() + j.reset() + + var count = 0 + while (current < end && count < BATCH_SIZE) { + i.putInt(count, current) + j.putInt(count, -current) + current += 1 + count += 1 + } - var count = 0 - while (current < end && count < BATCH_SIZE) { - i.putInt(count, current) - j.putInt(count, -current) - current += 1 - count += 1 - } + if (count == 0) { + false + } else { + batch.setNumRows(count) + true + } + } - if (count == 0) { - false - } else { - batch.setNumRows(count) - true - } - } + override def get(): ColumnarBatch = batch - override def get(): ColumnarBatch = { - batch + override def close(): Unit = batch.close() + } } - - override def close(): Unit = batch.close() } -class PartitionAwareDataSource extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceReader with SupportsReportPartitioning { - override def readSchema(): StructType = new StructType().add("a", "int").add("b", "int") +class PartitionAwareDataSource extends DataSourceV2 with BatchReadSupportProvider { - override def planInputPartitions(): JList[InputPartition[InternalRow]] = { + class ReadSupport extends SimpleReadSupport with SupportsReportPartitioning { + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { // Note that we don't have same value of column `a` across partitions. 
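PartitionAwareDataSource reports its layout through outputPartitioning(config); when the returned Partitioning satisfies the ClusteredDistribution required by an aggregate, Spark can skip the shuffle, which is what the groupBy assertions above check. A hedged sketch of that satisfy check in isolation (ClusteredByI is an illustrative name):

import org.apache.spark.sql.sources.v2.reader.partitioning.{ClusteredDistribution, Distribution, Partitioning}

// Illustrative: the same shape as MyPartitioning in the sources above.
class ClusteredByI extends Partitioning {
  override def numPartitions(): Int = 2

  override def satisfy(distribution: Distribution): Boolean = distribution match {
    // Data is co-located by column "i", so any clustering that includes "i" is satisfied.
    case c: ClusteredDistribution => c.clusteredColumns.contains("i")
    case _ => false
  }
}

val p = new ClusteredByI
assert(p.satisfy(new ClusteredDistribution(Array("i"))))   // groupBy('i): no shuffle needed
assert(!p.satisfy(new ClusteredDistribution(Array("j"))))  // groupBy('j): shuffle required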
- java.util.Arrays.asList( - new SpecificInputPartitionReader(Array(1, 1, 3), Array(4, 4, 6)), - new SpecificInputPartitionReader(Array(2, 4, 4), Array(6, 2, 2))) + Array( + SpecificInputPartition(Array(1, 1, 3), Array(4, 4, 6)), + SpecificInputPartition(Array(2, 4, 4), Array(6, 2, 2))) + } + + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + SpecificReaderFactory } - override def outputPartitioning(): Partitioning = new MyPartitioning + override def outputPartitioning(config: ScanConfig): Partitioning = new MyPartitioning } class MyPartitioning extends Partitioning { override def numPartitions(): Int = 2 override def satisfy(distribution: Distribution): Boolean = distribution match { - case c: ClusteredDistribution => c.clusteredColumns.contains("a") + case c: ClusteredDistribution => c.clusteredColumns.contains("i") case _ => false } } - override def createReader(options: DataSourceOptions): DataSourceReader = new Reader + override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { + new ReadSupport + } } -class SpecificInputPartitionReader(i: Array[Int], j: Array[Int]) - extends InputPartition[InternalRow] - with InputPartitionReader[InternalRow] { - assert(i.length == j.length) - - private var current = -1 +case class SpecificInputPartition(i: Array[Int], j: Array[Int]) extends InputPartition - override def createPartitionReader(): InputPartitionReader[InternalRow] = this +object SpecificReaderFactory extends PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val p = partition.asInstanceOf[SpecificInputPartition] + new PartitionReader[InternalRow] { + private var current = -1 - override def next(): Boolean = { - current += 1 - current < i.length - } + override def next(): Boolean = { + current += 1 + current < p.i.length + } - override def get(): InternalRow = InternalRow(i(current), j(current)) + override def get(): InternalRow = InternalRow(p.i(current), p.j(current)) - override def close(): Unit = {} + override def close(): Unit = {} + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index e1b8e9c44d72..952241b0b6be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -18,34 +18,36 @@ package org.apache.spark.sql.sources.v2 import java.io.{BufferedReader, InputStreamReader, IOException} -import java.util.{Collections, List => JList, Optional} +import java.util.Optional import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path} +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.SparkContext import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, InputPartitionReader} +import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.SerializableConfiguration /** * A HDFS based transactional writable data source. - * Each task writes data to `target/_temporary/jobId/$jobId-$partitionId-$attemptNumber`. 
- * Each job moves files from `target/_temporary/jobId/` to `target`. + * Each task writes data to `target/_temporary/queryId/$jobId-$partitionId-$attemptNumber`. + * Each job moves files from `target/_temporary/queryId/` to `target`. */ -class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteSupport { +class SimpleWritableDataSource extends DataSourceV2 + with BatchReadSupportProvider with BatchWriteSupportProvider { private val schema = new StructType().add("i", "long").add("j", "long") - class Reader(path: String, conf: Configuration) extends DataSourceReader { - override def readSchema(): StructType = schema + class ReadSupport(path: String, conf: Configuration) extends SimpleReadSupport { - override def planInputPartitions(): JList[InputPartition[InternalRow]] = { + override def fullSchema(): StructType = schema + + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { val dataPath = new Path(path) val fs = dataPath.getFileSystem(conf) if (fs.exists(dataPath)) { @@ -53,21 +55,23 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS val name = status.getPath.getName name.startsWith("_") || name.startsWith(".") }.map { f => - val serializableConf = new SerializableConfiguration(conf) - new SimpleCSVInputPartitionReader( - f.getPath.toUri.toString, - serializableConf): InputPartition[InternalRow] - }.toList.asJava + CSVInputPartitionReader(f.getPath.toUri.toString) + }.toArray } else { - Collections.emptyList() + Array.empty } } + + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + val serializableConf = new SerializableConfiguration(conf) + new CSVReaderFactory(serializableConf) + } } - class Writer(jobId: String, path: String, conf: Configuration) extends DataSourceWriter { - override def createWriterFactory(): DataWriterFactory[InternalRow] = { + class WriteSupport(queryId: String, path: String, conf: Configuration) extends BatchWriteSupport { + override def createBatchWriterFactory(): DataWriterFactory = { SimpleCounter.resetCounter - new CSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf)) + new CSVDataWriterFactory(path, queryId, new SerializableConfiguration(conf)) } override def onDataWriterCommit(message: WriterCommitMessage): Unit = { @@ -76,7 +80,7 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS override def commit(messages: Array[WriterCommitMessage]): Unit = { val finalPath = new Path(path) - val jobPath = new Path(new Path(finalPath, "_temporary"), jobId) + val jobPath = new Path(new Path(finalPath, "_temporary"), queryId) val fs = jobPath.getFileSystem(conf) try { for (file <- fs.listStatus(jobPath).map(_.getPath)) { @@ -91,23 +95,23 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS } override def abort(messages: Array[WriterCommitMessage]): Unit = { - val jobPath = new Path(new Path(path, "_temporary"), jobId) + val jobPath = new Path(new Path(path, "_temporary"), queryId) val fs = jobPath.getFileSystem(conf) fs.delete(jobPath, true) } } - override def createReader(options: DataSourceOptions): DataSourceReader = { + override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { val path = new Path(options.get("path").get()) val conf = SparkContext.getActive.get.hadoopConfiguration - new Reader(path.toUri.toString, conf) + new ReadSupport(path.toUri.toString, conf) } - override def createWriter( - jobId: String, + override def createBatchWriteSupport( +
queryId: String, schema: StructType, mode: SaveMode, - options: DataSourceOptions): Optional[DataSourceWriter] = { + options: DataSourceOptions): Optional[BatchWriteSupport] = { assert(DataType.equalsStructurally(schema.asNullable, this.schema.asNullable)) assert(!SparkContext.getActive.get.conf.getBoolean("spark.speculation", false)) @@ -130,39 +134,42 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS } val pathStr = path.toUri.toString - Optional.of(new Writer(jobId, pathStr, conf)) + Optional.of(new WriteSupport(queryId, pathStr, conf)) } } -class SimpleCSVInputPartitionReader(path: String, conf: SerializableConfiguration) - extends InputPartition[InternalRow] with InputPartitionReader[InternalRow] { +case class CSVInputPartitionReader(path: String) extends InputPartition - @transient private var lines: Iterator[String] = _ - @transient private var currentLine: String = _ - @transient private var inputStream: FSDataInputStream = _ +class CSVReaderFactory(conf: SerializableConfiguration) + extends PartitionReaderFactory { - override def createPartitionReader(): InputPartitionReader[InternalRow] = { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val path = partition.asInstanceOf[CSVInputPartitionReader].path val filePath = new Path(path) val fs = filePath.getFileSystem(conf.value) - inputStream = fs.open(filePath) - lines = new BufferedReader(new InputStreamReader(inputStream)) - .lines().iterator().asScala - this - } - override def next(): Boolean = { - if (lines.hasNext) { - currentLine = lines.next() - true - } else { - false - } - } + new PartitionReader[InternalRow] { + private val inputStream = fs.open(filePath) + private val lines = new BufferedReader(new InputStreamReader(inputStream)) + .lines().iterator().asScala - override def get(): InternalRow = InternalRow(currentLine.split(",").map(_.trim.toLong): _*) + private var currentLine: String = _ - override def close(): Unit = { - inputStream.close() + override def next(): Boolean = { + if (lines.hasNext) { + currentLine = lines.next() + true + } else { + false + } + } + + override def get(): InternalRow = InternalRow(currentLine.split(",").map(_.trim.toLong): _*) + + override def close(): Unit = { + inputStream.close() + } + } } } @@ -183,12 +190,11 @@ private[v2] object SimpleCounter { } class CSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) - extends DataWriterFactory[InternalRow] { + extends DataWriterFactory { - override def createDataWriter( + override def createWriter( partitionId: Int, - taskId: Long, - epochId: Long): DataWriter[InternalRow] = { + taskId: Long): DataWriter[InternalRow] = { val jobPath = new Path(new Path(path, "_temporary"), jobId) val filePath = new Path(jobPath, s"$jobId-$partitionId-$taskId") val fs = filePath.getFileSystem(conf.value) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index df22bc1315b7..b52800629517 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -686,7 +686,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be plan .collect { case r: StreamingExecutionRelation => r.source - case r: StreamingDataSourceV2Relation => r.reader + case r: StreamingDataSourceV2Relation => r.readSupport } .zipWithIndex .find(_._1 == source) diff
--git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index b96f2bcbdd64..6e1c5a544092 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -299,9 +299,9 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { try { val input = new MemoryStream[Int](0, sqlContext) { @volatile var numTriggers = 0 - override def getEndOffset: OffsetV2 = { + override def latestOffset(): OffsetV2 = { numTriggers += 1 - super.getEndOffset + super.latestOffset() } } val clock = new StreamManualClock() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 268ed58315fd..73592526fb0f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -17,15 +17,12 @@ package org.apache.spark.sql.streaming -import java.{util => ju} -import java.util.Optional import java.util.concurrent.CountDownLatch import scala.collection.mutable import org.apache.commons.lang3.RandomStringUtils import org.json4s.NoTypeHints -import org.json4s.jackson.JsonMethods._ import org.json4s.jackson.Serialization import org.scalactic.TolerantNumerics import org.scalatest.BeforeAndAfter @@ -35,13 +32,12 @@ import org.scalatest.mockito.MockitoSugar import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Literal, Rand, Randn, Shuffle, Uuid} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.TestForeachWriter import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.reader.InputPartition +import org.apache.spark.sql.sources.v2.reader.{InputPartition, ScanConfig} import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2} import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider, StreamManualClock} import org.apache.spark.sql.types.StructType @@ -218,25 +214,17 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi private def dataAdded: Boolean = currentOffset.offset != -1 - // setOffsetRange should take 50 ms the first time it is called after data is added - override def setOffsetRange(start: Optional[OffsetV2], end: Optional[OffsetV2]): Unit = { - synchronized { - if (dataAdded) clock.waitTillTime(1050) - super.setOffsetRange(start, end) - } - } - - // getEndOffset should take 100 ms the first time it is called after data is added - override def getEndOffset(): OffsetV2 = synchronized { - if (dataAdded) clock.waitTillTime(1150) - super.getEndOffset() + // latestOffset should take 50 ms the first time it is called after data is added + override def latestOffset(): OffsetV2 = synchronized { + if (dataAdded) clock.waitTillTime(1050) + super.latestOffset() } // getBatch should take 100 ms the first time it is called - override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = { + override def planInputPartitions(config: 
ScanConfig): Array[InputPartition] = { synchronized { - clock.waitTillTime(1350) - super.planInputPartitions() + clock.waitTillTime(1150) + super.planInputPartitions(config) } } } @@ -277,34 +265,26 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi AssertOnQuery(_.status.message === "Waiting for next trigger"), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - // Test status and progress when setOffsetRange is being called + // Test status and progress when `latestOffset` is being called AddData(inputData, 1, 2), - AdvanceManualClock(1000), // time = 1000 to start new trigger, will block on setOffsetRange + AdvanceManualClock(1000), // time = 1000 to start new trigger, will block on `latestOffset` AssertStreamExecThreadIsWaitingForTime(1050), AssertOnQuery(_.status.isDataAvailable === false), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message.startsWith("Getting offsets from")), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - AdvanceManualClock(50), // time = 1050 to unblock setOffsetRange + AdvanceManualClock(50), // time = 1050 to unblock `latestOffset` AssertClockTime(1050), - AssertStreamExecThreadIsWaitingForTime(1150), // will block on getEndOffset that needs 1150 - AssertOnQuery(_.status.isDataAvailable === false), - AssertOnQuery(_.status.isTriggerActive === true), - AssertOnQuery(_.status.message.startsWith("Getting offsets from")), - AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - - AdvanceManualClock(100), // time = 1150 to unblock getEndOffset - AssertClockTime(1150), - // will block on planInputPartitions that needs 1350 - AssertStreamExecThreadIsWaitingForTime(1350), + // will block on `planInputPartitions` that needs 1150 + AssertStreamExecThreadIsWaitingForTime(1150), AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message === "Processing new data"), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - AdvanceManualClock(200), // time = 1350 to unblock planInputPartitions - AssertClockTime(1350), + AdvanceManualClock(100), // time = 1150 to unblock `planInputPartitions` + AssertClockTime(1150), AssertStreamExecThreadIsWaitingForTime(1500), // will block on map task that needs 1500 AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === true), @@ -312,7 +292,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), // Test status and progress while batch processing has completed - AdvanceManualClock(150), // time = 1500 to unblock map task + AdvanceManualClock(350), // time = 1500 to unblock map task AssertClockTime(1500), CheckAnswer(2), AssertStreamExecThreadIsWaitingForTime(2000), // will block until the next trigger @@ -332,11 +312,10 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.numInputRows === 2) assert(progress.processedRowsPerSecond === 4.0) - assert(progress.durationMs.get("setOffsetRange") === 50) - assert(progress.durationMs.get("getEndOffset") === 100) - assert(progress.durationMs.get("queryPlanning") === 200) + assert(progress.durationMs.get("latestOffset") === 50) + assert(progress.durationMs.get("queryPlanning") === 100) assert(progress.durationMs.get("walCommit") === 0) - assert(progress.durationMs.get("addBatch") === 150) + assert(progress.durationMs.get("addBatch") ===
350) assert(progress.durationMs.get("triggerExecution") === 500) assert(progress.sources.length === 1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala index 4f198819b58d..d6819eacd07c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala @@ -22,16 +22,15 @@ import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue} import org.mockito.Mockito._ import org.scalatest.mockito.MockitoSugar -import org.apache.spark.{SparkEnv, SparkFunSuite, TaskContext} -import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv} +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.sources.v2.reader.InputPartition -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, PartitionOffset} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousPartitionReader, ContinuousReadSupport, PartitionOffset} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport import org.apache.spark.sql.streaming.StreamTest -import org.apache.spark.sql.types.{DataType, IntegerType} +import org.apache.spark.sql.types.{DataType, IntegerType, StructType} class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { case class LongPartitionOffset(offset: Long) extends PartitionOffset @@ -44,8 +43,8 @@ class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { override def beforeEach(): Unit = { super.beforeEach() epochEndpoint = EpochCoordinatorRef.create( - mock[StreamWriter], - mock[ContinuousReader], + mock[StreamingWriteSupport], + mock[ContinuousReadSupport], mock[ContinuousExecution], coordinatorId, startEpoch, @@ -73,26 +72,26 @@ class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { */ private def setup(): (BlockingQueue[UnsafeRow], ContinuousQueuedDataReader) = { val queue = new ArrayBlockingQueue[UnsafeRow](1024) - val factory = new InputPartition[InternalRow] { - override def createPartitionReader() = new ContinuousInputPartitionReader[InternalRow] { - var index = -1 - var curr: UnsafeRow = _ - - override def next() = { - curr = queue.take() - index += 1 - true - } + val partitionReader = new ContinuousPartitionReader[InternalRow] { + var index = -1 + var curr: UnsafeRow = _ + + override def next() = { + curr = queue.take() + index += 1 + true + } - override def get = curr + override def get = curr - override def getOffset = LongPartitionOffset(index) + override def getOffset = LongPartitionOffset(index) - override def close() = {} - } + override def close() = {} } val reader = new ContinuousQueuedDataReader( - new ContinuousDataSourceRDDPartition(0, factory), + 0, + partitionReader, + new StructType().add("i", "int"), mockContext, dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize, epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs) diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index 4980b0cd41f8..3d21bc63e0cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -41,7 +41,7 @@ class ContinuousSuiteBase extends StreamTest { case s: ContinuousExecution => assert(numTriggers >= 2, "must wait for at least 2 triggers to ensure query is initialized") val reader = s.lastExecution.executedPlan.collectFirst { - case DataSourceV2ScanExec(_, _, _, _, r: RateStreamContinuousReader) => r + case DataSourceV2ScanExec(_, _, _, _, r: RateStreamContinuousReadSupport, _) => r }.get val deltaMs = numTriggers * 1000 + 300 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala index 82836dced9df..3c973d8ebc70 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala @@ -27,9 +27,9 @@ import org.apache.spark._ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.LocalSparkSession import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport import org.apache.spark.sql.test.TestSparkSession class EpochCoordinatorSuite @@ -40,20 +40,20 @@ class EpochCoordinatorSuite private var epochCoordinator: RpcEndpointRef = _ - private var writer: StreamWriter = _ + private var writeSupport: StreamingWriteSupport = _ private var query: ContinuousExecution = _ private var orderVerifier: InOrder = _ override def beforeEach(): Unit = { - val reader = mock[ContinuousReader] - writer = mock[StreamWriter] + val reader = mock[ContinuousReadSupport] + writeSupport = mock[StreamingWriteSupport] query = mock[ContinuousExecution] - orderVerifier = inOrder(writer, query) + orderVerifier = inOrder(writeSupport, query) spark = new TestSparkSession() epochCoordinator - = EpochCoordinatorRef.create(writer, reader, query, "test", 1, spark, SparkEnv.get) + = EpochCoordinatorRef.create(writeSupport, reader, query, "test", 1, spark, SparkEnv.get) } test("single epoch") { @@ -209,12 +209,12 @@ class EpochCoordinatorSuite } private def verifyCommit(epoch: Long): Unit = { - orderVerifier.verify(writer).commit(eqTo(epoch), any()) + orderVerifier.verify(writeSupport).commit(eqTo(epoch), any()) orderVerifier.verify(query).commit(epoch) } private def verifyNoCommitFor(epoch: Long): Unit = { - verify(writer, never()).commit(eqTo(epoch), any()) + verify(writeSupport, never()).commit(eqTo(epoch), any()) verify(query, never()).commit(epoch) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index 52b833a19c23..aeef4c8fe933 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -17,73 +17,74 @@ package org.apache.spark.sql.streaming.sources -import java.util.Optional - import org.apache.spark.sql.{DataFrame, SQLContext} -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{RateStreamOffset, Sink, StreamingQueryWrapper} import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.reader.InputPartition -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.reader.{InputPartition, PartitionReaderFactory, ScanConfig, ScanConfigBuilder} +import org.apache.spark.sql.sources.v2.reader.streaming._ +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport import org.apache.spark.sql.streaming.{OutputMode, StreamTest, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils -case class FakeReader() extends MicroBatchReader with ContinuousReader { - def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = {} - def getStartOffset: Offset = RateStreamOffset(Map()) - def getEndOffset: Offset = RateStreamOffset(Map()) - def deserializeOffset(json: String): Offset = RateStreamOffset(Map()) - def commit(end: Offset): Unit = {} - def readSchema(): StructType = StructType(Seq()) - def stop(): Unit = {} - def mergeOffsets(offsets: Array[PartitionOffset]): Offset = RateStreamOffset(Map()) - def setStartOffset(start: Optional[Offset]): Unit = {} - - def planInputPartitions(): java.util.ArrayList[InputPartition[InternalRow]] = { +case class FakeReadSupport() extends MicroBatchReadSupport with ContinuousReadSupport { + override def deserializeOffset(json: String): Offset = RateStreamOffset(Map()) + override def commit(end: Offset): Unit = {} + override def stop(): Unit = {} + override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = RateStreamOffset(Map()) + override def fullSchema(): StructType = StructType(Seq()) + override def newScanConfigBuilder(start: Offset, end: Offset): ScanConfigBuilder = null + override def initialOffset(): Offset = RateStreamOffset(Map()) + override def latestOffset(): Offset = RateStreamOffset(Map()) + override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = null + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + throw new IllegalStateException("fake source - cannot actually read") + } + override def createContinuousReaderFactory( + config: ScanConfig): ContinuousPartitionReaderFactory = { + throw new IllegalStateException("fake source - cannot actually read") + } + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { throw new IllegalStateException("fake source - cannot actually read") } } -trait FakeMicroBatchReadSupport extends MicroBatchReadSupport { - override def 
createMicroBatchReader( - schema: Optional[StructType], +trait FakeMicroBatchReadSupportProvider extends MicroBatchReadSupportProvider { + override def createMicroBatchReadSupport( checkpointLocation: String, - options: DataSourceOptions): MicroBatchReader = FakeReader() + options: DataSourceOptions): MicroBatchReadSupport = FakeReadSupport() } -trait FakeContinuousReadSupport extends ContinuousReadSupport { - override def createContinuousReader( - schema: Optional[StructType], +trait FakeContinuousReadSupportProvider extends ContinuousReadSupportProvider { + override def createContinuousReadSupport( checkpointLocation: String, - options: DataSourceOptions): ContinuousReader = FakeReader() + options: DataSourceOptions): ContinuousReadSupport = FakeReadSupport() } -trait FakeStreamWriteSupport extends StreamWriteSupport { - override def createStreamWriter( +trait FakeStreamingWriteSupportProvider extends StreamingWriteSupportProvider { + override def createStreamingWriteSupport( queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceOptions): StreamWriter = { + options: DataSourceOptions): StreamingWriteSupport = { throw new IllegalStateException("fake sink - cannot actually write") } } -class FakeReadMicroBatchOnly extends DataSourceRegister with FakeMicroBatchReadSupport { +class FakeReadMicroBatchOnly extends DataSourceRegister with FakeMicroBatchReadSupportProvider { override def shortName(): String = "fake-read-microbatch-only" } -class FakeReadContinuousOnly extends DataSourceRegister with FakeContinuousReadSupport { +class FakeReadContinuousOnly extends DataSourceRegister with FakeContinuousReadSupportProvider { override def shortName(): String = "fake-read-continuous-only" } class FakeReadBothModes extends DataSourceRegister - with FakeMicroBatchReadSupport with FakeContinuousReadSupport { + with FakeMicroBatchReadSupportProvider with FakeContinuousReadSupportProvider { override def shortName(): String = "fake-read-microbatch-continuous" } @@ -91,7 +92,7 @@ class FakeReadNeitherMode extends DataSourceRegister { override def shortName(): String = "fake-read-neither-mode" } -class FakeWrite extends DataSourceRegister with FakeStreamWriteSupport { +class FakeWriteSupportProvider extends DataSourceRegister with FakeStreamingWriteSupportProvider { override def shortName(): String = "fake-write-microbatch-continuous" } @@ -106,8 +107,8 @@ class FakeSink extends Sink { override def addBatch(batchId: Long, data: DataFrame): Unit = {} } -class FakeWriteV1Fallback extends DataSourceRegister - with FakeStreamWriteSupport with StreamSinkProvider { +class FakeWriteSupportProviderV1Fallback extends DataSourceRegister + with FakeStreamingWriteSupportProvider with StreamSinkProvider { override def createSink( sqlContext: SQLContext, @@ -190,11 +191,11 @@ class StreamingDataSourceV2Suite extends StreamTest { val v2Query = testPositiveCase( "fake-read-microbatch-continuous", "fake-write-v1-fallback", Trigger.Once()) assert(v2Query.asInstanceOf[StreamingQueryWrapper].streamingQuery.sink - .isInstanceOf[FakeWriteV1Fallback]) + .isInstanceOf[FakeWriteSupportProviderV1Fallback]) // Ensure we create a V1 sink with the config. Note the config is a comma separated // list, including other fake entries. 
- val fullSinkName = "org.apache.spark.sql.streaming.sources.FakeWriteV1Fallback" + val fullSinkName = classOf[FakeWriteSupportProviderV1Fallback].getName withSQLConf(SQLConf.DISABLED_V2_STREAMING_WRITERS.key -> s"a,b,c,test,$fullSinkName,d,e") { val v1Query = testPositiveCase( "fake-read-microbatch-continuous", "fake-write-v1-fallback", Trigger.Once()) @@ -218,35 +219,37 @@ class StreamingDataSourceV2Suite extends StreamTest { val writeSource = DataSource.lookupDataSource(write, spark.sqlContext.conf).newInstance() (readSource, writeSource, trigger) match { // Valid microbatch queries. - case (_: MicroBatchReadSupport, _: StreamWriteSupport, t) + case (_: MicroBatchReadSupportProvider, _: StreamingWriteSupportProvider, t) if !t.isInstanceOf[ContinuousTrigger] => testPositiveCase(read, write, trigger) // Valid continuous queries. - case (_: ContinuousReadSupport, _: StreamWriteSupport, _: ContinuousTrigger) => + case (_: ContinuousReadSupportProvider, _: StreamingWriteSupportProvider, + _: ContinuousTrigger) => testPositiveCase(read, write, trigger) // Invalid - can't read at all case (r, _, _) - if !r.isInstanceOf[MicroBatchReadSupport] - && !r.isInstanceOf[ContinuousReadSupport] => + if !r.isInstanceOf[MicroBatchReadSupportProvider] + && !r.isInstanceOf[ContinuousReadSupportProvider] => testNegativeCase(read, write, trigger, s"Data source $read does not support streamed reading") // Invalid - can't write - case (_, w, _) if !w.isInstanceOf[StreamWriteSupport] => + case (_, w, _) if !w.isInstanceOf[StreamingWriteSupportProvider] => testNegativeCase(read, write, trigger, s"Data source $write does not support streamed writing") // Invalid - trigger is continuous but reader is not - case (r, _: StreamWriteSupport, _: ContinuousTrigger) - if !r.isInstanceOf[ContinuousReadSupport] => + case (r, _: StreamingWriteSupportProvider, _: ContinuousTrigger) + if !r.isInstanceOf[ContinuousReadSupportProvider] => testNegativeCase(read, write, trigger, s"Data source $read does not support continuous processing") // Invalid - trigger is microbatch but reader is not case (r, _, t) - if !r.isInstanceOf[MicroBatchReadSupport] && !t.isInstanceOf[ContinuousTrigger] => + if !r.isInstanceOf[MicroBatchReadSupportProvider] && + !t.isInstanceOf[ContinuousTrigger] => testPostCreationNegativeCase(read, write, trigger, s"Data source $read does not support microbatch processing") }