From 9f63721677cea627f43f7d536bb32b588cee30a3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 27 Aug 2018 23:20:08 +0800 Subject: [PATCH] data source V2 read side API refactoring --- ...scala => KafkaContinuousInputStream.scala} | 106 +++++----- ...scala => KafkaMicroBatchInputStream.scala} | 181 ++++++++++-------- .../sql/kafka010/KafkaSourceProvider.scala | 166 ++++++++-------- .../kafka010/KafkaContinuousSourceSuite.scala | 14 +- .../sql/kafka010/KafkaContinuousTest.scala | 10 +- .../kafka010/KafkaMicroBatchSourceSuite.scala | 25 ++- .../sources/v2/BatchReadSupportProvider.java | 61 ------ .../v2/ContinuousReadSupportProvider.java | 70 ------- .../spark/sql/sources/v2/DataSourceV2.java | 2 +- .../apache/spark/sql/sources/v2/Format.java | 60 ++++++ .../v2/MicroBatchReadSupportProvider.java | 70 ------- .../sql/sources/v2/SupportsBatchRead.java | 40 ++++ .../sources/v2/SupportsContinuousRead.java | 47 +++++ .../sources/v2/SupportsMicroBatchRead.java | 47 +++++ .../BatchReadSupport.java => Table.java} | 42 ++-- .../sql/sources/v2/reader/BatchScan.java | 43 +++++ .../sql/sources/v2/reader/InputPartition.java | 2 +- .../v2/reader/{ReadSupport.java => Scan.java} | 23 +-- .../sql/sources/v2/reader/ScanConfig.java | 15 +- .../sql/sources/v2/reader/Statistics.java | 2 +- .../v2/reader/SupportsReportPartitioning.java | 10 +- .../v2/reader/SupportsReportStatistics.java | 8 +- .../v2/reader/partitioning/Partitioning.java | 5 +- .../streaming/ContinuousInputStream.java | 53 +++++ .../streaming/ContinuousReadSupport.java | 77 -------- .../v2/reader/streaming/ContinuousScan.java | 53 +++++ ...amingReadSupport.java => InputStream.java} | 14 +- .../streaming/MicroBatchInputStream.java | 38 ++++ .../streaming/MicroBatchReadSupport.java | 60 ------ .../v2/reader/streaming/MicroBatchScan.java | 48 +++++ .../sources/v2/reader/streaming/Offset.java | 4 +- .../apache/spark/sql/DataFrameReader.scala | 27 ++- .../apache/spark/sql/DataFrameWriter.scala | 4 +- .../datasources/v2/DataSourceV2Relation.scala | 87 +++++---- .../datasources/v2/DataSourceV2ScanExec.scala | 36 ++-- .../datasources/v2/DataSourceV2Strategy.scala | 77 +++++++- .../v2/NoopScanConfigBuilder.scala} | 23 +-- .../streaming/MicroBatchExecution.scala | 158 +++++++++------ .../streaming/ProgressReporter.scala | 18 +- .../streaming/StreamingRelation.scala | 78 +++++--- .../continuous/ContinuousExecution.scala | 172 +++++++++-------- .../ContinuousRateStreamSource.scala | 63 +++--- .../ContinuousTextSocketSource.scala | 53 +++-- .../continuous/EpochCoordinator.scala | 10 +- .../sql/execution/streaming/memory.scala | 81 +++++--- .../sources/ContinuousMemoryStream.scala | 60 +++--- ...=> RateControlMicroBatchInputStream.scala} | 6 +- ... => RateStreamMicroBatchInputStream.scala} | 80 ++++---- .../sources/RateStreamProvider.scala | 45 +++-- ... 
=> TextSocketMicroBatchInputStream.scala} | 132 ++++--------- .../sources/TextSocketSourceProvider.scala | 95 +++++++++ .../sql/streaming/DataStreamReader.scala | 69 ++----- .../sources/v2/JavaAdvancedDataSourceV2.java | 46 +++-- .../sources/v2/JavaColumnarDataSourceV2.java | 17 +- .../v2/JavaPartitionAwareDataSource.java | 22 +-- .../v2/JavaSchemaRequiredDataSource.java | 22 +-- ...ort.java => JavaSimpleBatchReadTable.java} | 33 +--- .../sources/v2/JavaSimpleDataSourceV2.java | 15 +- .../sources/RateStreamProviderSuite.scala | 97 ++++------ .../sources/TextSocketStreamSuite.scala | 85 ++++---- .../sql/sources/v2/DataSourceV2Suite.scala | 159 +++++++-------- .../sources/v2/SimpleWritableDataSource.scala | 20 +- .../spark/sql/streaming/StreamSuite.scala | 154 +++++++++++---- .../spark/sql/streaming/StreamTest.scala | 8 +- .../StreamingQueryManagerSuite.scala | 9 +- .../sql/streaming/StreamingQuerySuite.scala | 11 +- .../ContinuousQueuedDataReaderSuite.scala | 4 +- .../continuous/ContinuousSuite.scala | 12 +- .../continuous/EpochCoordinatorSuite.scala | 6 +- .../sources/StreamingDataSourceV2Suite.scala | 86 +++++---- 70 files changed, 1966 insertions(+), 1610 deletions(-) rename external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/{KafkaContinuousReadSupport.scala => KafkaContinuousInputStream.scala} (83%) rename external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/{KafkaMicroBatchReadSupport.scala => KafkaMicroBatchInputStream.scala} (92%) delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchReadSupportProvider.java delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupportProvider.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/Format.java delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsContinuousRead.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsMicroBatchRead.java rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{reader/BatchReadSupport.java => Table.java} (52%) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchScan.java rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/{ReadSupport.java => Scan.java} (68%) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousInputStream.java delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReadSupport.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousScan.java rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/{StreamingReadSupport.java => InputStream.java} (76%) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchInputStream.java delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchScan.java rename sql/core/src/main/scala/org/apache/spark/sql/execution/{streaming/SimpleStreamingScanConfigBuilder.scala => datasources/v2/NoopScanConfigBuilder.scala} (62%) rename 
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/{RateControlMicroBatchReadSupport.scala => RateControlMicroBatchInputStream.scala} (87%) rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/{RateStreamMicroBatchReadSupport.scala => RateStreamMicroBatchInputStream.scala} (84%) rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/{socket.scala => TextSocketMicroBatchInputStream.scala} (62%) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala rename sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/{JavaSimpleReadSupport.java => JavaSimpleBatchReadTable.java} (78%) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReadSupport.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousInputStream.scala similarity index 83% rename from external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReadSupport.scala rename to external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousInputStream.scala index 1753a28fba2f..bd301503d1db 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReadSupport.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousInputStream.scala @@ -30,10 +30,9 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming._ -import org.apache.spark.sql.types.StructType /** - * A [[ContinuousReadSupport]] for data from kafka. + * A [[ContinuousInputStream]] that reads data from Kafka. * * @param offsetReader a reader used to get kafka offsets. Note that the actual data will be * read by per-task consumers generated later. @@ -46,17 +45,22 @@ import org.apache.spark.sql.types.StructType * scenarios, where some offsets after the specified initial ones can't be * properly read. */ -class KafkaContinuousReadSupport( +class KafkaContinuousInputStream( offsetReader: KafkaOffsetReader, kafkaParams: ju.Map[String, Object], sourceOptions: Map[String, String], metadataPath: String, initialOffsets: KafkaOffsetRangeLimit, failOnDataLoss: Boolean) - extends ContinuousReadSupport with Logging { + extends ContinuousInputStream with Logging { private val pollTimeoutMs = sourceOptions.getOrElse("kafkaConsumer.pollTimeoutMs", "512").toLong + // Initialized when creating read support. If this diverges from the partitions at the latest + // offsets, we need to reconfigure. + // Exposed outside this object only for unit tests. 
+ @volatile private[sql] var knownPartitions: Set[TopicPartition] = _ + override def initialOffset(): Offset = { val offsets = initialOffsets match { case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets()) @@ -67,28 +71,29 @@ class KafkaContinuousReadSupport( offsets } - override def fullSchema(): StructType = KafkaOffsetReader.kafkaSchema - - override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = { - new KafkaContinuousScanConfigBuilder(fullSchema(), start, offsetReader, reportDataLoss) - } - override def deserializeOffset(json: String): Offset = { KafkaSourceOffset(JsonUtils.partitionOffsets(json)) } - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { - val startOffsets = config.asInstanceOf[KafkaContinuousScanConfig].startOffsets - startOffsets.toSeq.map { - case (topicPartition, start) => - KafkaContinuousInputPartition( - topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss) - }.toArray - } + override def createContinuousScan(start: Offset): ContinuousScan = { + val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(start) - override def createContinuousReaderFactory( - config: ScanConfig): ContinuousPartitionReaderFactory = { - KafkaContinuousReaderFactory + val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet + val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet) + val newPartitionOffsets = offsetReader.fetchEarliestOffsets(newPartitions.toSeq) + + val deletedPartitions = oldStartPartitionOffsets.keySet.diff(currentPartitionSet) + if (deletedPartitions.nonEmpty) { + reportDataLoss(s"Some partitions were deleted: $deletedPartitions") + } + + val startOffsets = newPartitionOffsets ++ + oldStartPartitionOffsets.filterKeys(!deletedPartitions.contains(_)) + + knownPartitions = startOffsets.keySet + + new KafkaContinuousScan( + offsetReader, kafkaParams, pollTimeoutMs, failOnDataLoss, startOffsets) } /** Stop this source and free any resources it has allocated. */ @@ -105,9 +110,8 @@ class KafkaContinuousReadSupport( KafkaSourceOffset(mergedMap) } - override def needsReconfiguration(config: ScanConfig): Boolean = { - val knownPartitions = config.asInstanceOf[KafkaContinuousScanConfig].knownPartitions - offsetReader.fetchLatestOffsets().keySet != knownPartitions + override def needsReconfiguration(): Boolean = { + knownPartitions != null && offsetReader.fetchLatestOffsets().keySet != knownPartitions } override def toString(): String = s"KafkaSource[$offsetReader]" @@ -125,6 +129,25 @@ class KafkaContinuousReadSupport( } } +class KafkaContinuousScan( + offsetReader: KafkaOffsetReader, + kafkaParams: ju.Map[String, Object], + pollTimeoutMs: Long, + failOnDataLoss: Boolean, + startOffsets: Map[TopicPartition, Long]) extends ContinuousScan { + + override def createContinuousReaderFactory(): ContinuousPartitionReaderFactory = { + KafkaContinuousReaderFactory + } + + override def planInputPartitions(): Array[InputPartition] = { + startOffsets.toSeq.map { case (topicPartition, start) => + KafkaContinuousInputPartition( + topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss) + }.toArray + } +} + /** * An input partition for continuous Kafka processing. This will be serialized and transformed * into a full reader on executors. 
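[Editor's note, not part of the patch] For orientation: partition planning that previously lived in KafkaContinuousScanConfig is now computed when the input stream creates a scan, so the stream itself (rather than a ScanConfig) carries the known-partition state consulted by needsReconfiguration(). The Scala sketch below is hypothetical: it uses only the members that KafkaContinuousInputStream and KafkaContinuousScan override in this file (members of the ContinuousInputStream and Scan contracts that are not visible in this excerpt are not reproduced), and all Counter* names are made up.

import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.streaming._

// Offset for a single, monotonically increasing counter partition.
case class CounterOffset(n: Long) extends Offset {
  override def json(): String = n.toString
}

// Serializable description of the work for one task; the reader factory (elided below)
// would turn this into a per-task continuous reader on an executor.
case class CounterInputPartition(start: Long) extends InputPartition

class CounterContinuousInputStream extends ContinuousInputStream {

  override def initialOffset(): Offset = CounterOffset(0L)

  override def deserializeOffset(json: String): Offset = CounterOffset(json.toLong)

  // The engine asks the long-lived stream for a fresh scan whenever it (re)plans the query,
  // mirroring KafkaContinuousInputStream.createContinuousScan above.
  override def createContinuousScan(start: Offset): ContinuousScan = new ContinuousScan {

    override def planInputPartitions(): Array[InputPartition] =
      Array(CounterInputPartition(start.asInstanceOf[CounterOffset].n))

    override def createContinuousReaderFactory(): ContinuousPartitionReaderFactory =
      ??? // per-partition reader construction elided in this sketch
  }

  // The partition layout is static, so the engine never needs to restart the scan.
  override def needsReconfiguration(): Boolean = false

  override def stop(): Unit = ()
}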
@@ -151,41 +174,6 @@ object KafkaContinuousReaderFactory extends ContinuousPartitionReaderFactory { } } -class KafkaContinuousScanConfigBuilder( - schema: StructType, - startOffset: Offset, - offsetReader: KafkaOffsetReader, - reportDataLoss: String => Unit) - extends ScanConfigBuilder { - - override def build(): ScanConfig = { - val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(startOffset) - - val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet - val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet) - val newPartitionOffsets = offsetReader.fetchEarliestOffsets(newPartitions.toSeq) - - val deletedPartitions = oldStartPartitionOffsets.keySet.diff(currentPartitionSet) - if (deletedPartitions.nonEmpty) { - reportDataLoss(s"Some partitions were deleted: $deletedPartitions") - } - - val startOffsets = newPartitionOffsets ++ - oldStartPartitionOffsets.filterKeys(!deletedPartitions.contains(_)) - KafkaContinuousScanConfig(schema, startOffsets) - } -} - -case class KafkaContinuousScanConfig( - readSchema: StructType, - startOffsets: Map[TopicPartition, Long]) - extends ScanConfig { - - // Created when building the scan config builder. If this diverges from the partitions at the - // latest offsets, we need to reconfigure the kafka read support. - def knownPartitions: Set[TopicPartition] = startOffsets.keySet -} - /** * A per-task data reader for continuous Kafka processing. * diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchInputStream.scala similarity index 92% rename from external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala rename to external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchInputStream.scala index bb4de674c3c7..afacd81043fa 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchInputStream.scala @@ -29,17 +29,16 @@ import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset, SimpleStreamingScanConfig, SimpleStreamingScanConfigBuilder} -import org.apache.spark.sql.execution.streaming.sources.RateControlMicroBatchReadSupport +import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset} +import org.apache.spark.sql.execution.streaming.sources.RateControlMicroBatchInputStream import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchInputStream, MicroBatchScan, Offset} import org.apache.spark.util.UninterruptibleThread /** - * A [[MicroBatchReadSupport]] that reads data from Kafka. + * A [[MicroBatchInputStream]] that reads data from Kafka. 
* * The [[KafkaSourceOffset]] is the custom [[Offset]] defined for this source that contains * a map of TopicPartition -> offset. Note that this offset is 1 + (available offset). For @@ -54,13 +53,13 @@ import org.apache.spark.util.UninterruptibleThread * To avoid this issue, you should make sure stopping the query before stopping the Kafka brokers * and not use wrong broker addresses. */ -private[kafka010] class KafkaMicroBatchReadSupport( +private[kafka010] class KafkaMicroBatchInputStream( kafkaOffsetReader: KafkaOffsetReader, executorKafkaParams: ju.Map[String, Object], options: DataSourceOptions, metadataPath: String, startingOffsets: KafkaOffsetRangeLimit, - failOnDataLoss: Boolean) extends RateControlMicroBatchReadSupport with Logging { + failOnDataLoss: Boolean) extends RateControlMicroBatchInputStream with Logging { private val pollTimeoutMs = options.getLong( "kafkaConsumer.pollTimeoutMs", @@ -93,65 +92,16 @@ private[kafka010] class KafkaMicroBatchReadSupport( endPartitionOffsets } - override def fullSchema(): StructType = KafkaOffsetReader.kafkaSchema - - override def newScanConfigBuilder(start: Offset, end: Offset): ScanConfigBuilder = { - new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end)) - } - - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { - val sc = config.asInstanceOf[SimpleStreamingScanConfig] - val startPartitionOffsets = sc.start.asInstanceOf[KafkaSourceOffset].partitionToOffsets - val endPartitionOffsets = sc.end.get.asInstanceOf[KafkaSourceOffset].partitionToOffsets - - // Find the new partitions, and get their earliest offsets - val newPartitions = endPartitionOffsets.keySet.diff(startPartitionOffsets.keySet) - val newPartitionInitialOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq) - if (newPartitionInitialOffsets.keySet != newPartitions) { - // We cannot get from offsets for some partitions. It means they got deleted. - val deletedPartitions = newPartitions.diff(newPartitionInitialOffsets.keySet) - reportDataLoss( - s"Cannot find earliest offsets of ${deletedPartitions}. Some data may have been missed") - } - logInfo(s"Partitions added: $newPartitionInitialOffsets") - newPartitionInitialOffsets.filter(_._2 != 0).foreach { case (p, o) => - reportDataLoss( - s"Added partition $p starts from $o instead of 0. Some data may have been missed") - } - - // Find deleted partitions, and report data loss if required - val deletedPartitions = startPartitionOffsets.keySet.diff(endPartitionOffsets.keySet) - if (deletedPartitions.nonEmpty) { - reportDataLoss(s"$deletedPartitions are gone. Some data may have been missed") - } - - // Use the end partitions to calculate offset ranges to ignore partitions that have - // been deleted - val topicPartitions = endPartitionOffsets.keySet.filter { tp => - // Ignore partitions that we don't know the from offsets. - newPartitionInitialOffsets.contains(tp) || startPartitionOffsets.contains(tp) - }.toSeq - logDebug("TopicPartitions: " + topicPartitions.mkString(", ")) - - // Calculate offset ranges - val offsetRanges = rangeCalculator.getRanges( - fromOffsets = startPartitionOffsets ++ newPartitionInitialOffsets, - untilOffsets = endPartitionOffsets, - executorLocations = getSortedExecutorList()) - - // Reuse Kafka consumers only when all the offset ranges have distinct TopicPartitions, - // that is, concurrent tasks will not read the same TopicPartitions. 
- val reuseKafkaConsumer = offsetRanges.map(_.topicPartition).toSet.size == offsetRanges.size - - // Generate factories based on the offset ranges - offsetRanges.map { range => - KafkaMicroBatchInputPartition( - range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer) - }.toArray - } - - override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { - KafkaMicroBatchReaderFactory + override def createMicroBatchScan(start: Offset, end: Offset): MicroBatchScan = { + new KafkaMicroBatchScan( + kafkaOffsetReader, + rangeCalculator, + executorKafkaParams, + pollTimeoutMs, + failOnDataLoss, + reportDataLoss, + start.asInstanceOf[KafkaSourceOffset], + end.asInstanceOf[KafkaSourceOffset]) } override def deserializeOffset(json: String): Offset = { @@ -229,23 +179,6 @@ private[kafka010] class KafkaMicroBatchReadSupport( } } - private def getSortedExecutorList(): Array[String] = { - - def compare(a: ExecutorCacheTaskLocation, b: ExecutorCacheTaskLocation): Boolean = { - if (a.host == b.host) { - a.executorId > b.executorId - } else { - a.host > b.host - } - } - - val bm = SparkEnv.get.blockManager - bm.master.getPeers(bm.blockManagerId).toArray - .map(x => ExecutorCacheTaskLocation(x.host, x.executorId)) - .sortWith(compare) - .map(_.toString) - } - /** * If `failOnDataLoss` is true, this method will throw an `IllegalStateException`. * Otherwise, just log a warning. @@ -294,6 +227,88 @@ private[kafka010] class KafkaMicroBatchReadSupport( } } +private[kafka010] class KafkaMicroBatchScan( + kafkaOffsetReader: KafkaOffsetReader, + rangeCalculator: KafkaOffsetRangeCalculator, + executorKafkaParams: ju.Map[String, Object], + pollTimeoutMs: Long, + failOnDataLoss: Boolean, + reportDataLoss: String => Unit, + start: KafkaSourceOffset, + end: KafkaSourceOffset) extends MicroBatchScan with Logging { + + override def createReaderFactory(): PartitionReaderFactory = { + KafkaMicroBatchReaderFactory + } + + override def planInputPartitions(): Array[InputPartition] = { + val startPartitionOffsets = start.partitionToOffsets + val endPartitionOffsets = end.partitionToOffsets + + // Find the new partitions, and get their earliest offsets + val newPartitions = endPartitionOffsets.keySet.diff(startPartitionOffsets.keySet) + val newPartitionInitialOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq) + if (newPartitionInitialOffsets.keySet != newPartitions) { + // We cannot get from offsets for some partitions. It means they got deleted. + val deletedPartitions = newPartitions.diff(newPartitionInitialOffsets.keySet) + reportDataLoss( + s"Cannot find earliest offsets of ${deletedPartitions}. Some data may have been missed") + } + logInfo(s"Partitions added: $newPartitionInitialOffsets") + newPartitionInitialOffsets.filter(_._2 != 0).foreach { case (p, o) => + reportDataLoss( + s"Added partition $p starts from $o instead of 0. Some data may have been missed") + } + + // Find deleted partitions, and report data loss if required + val deletedPartitions = startPartitionOffsets.keySet.diff(endPartitionOffsets.keySet) + if (deletedPartitions.nonEmpty) { + reportDataLoss(s"$deletedPartitions are gone. Some data may have been missed") + } + + // Use the end partitions to calculate offset ranges to ignore partitions that have + // been deleted + val topicPartitions = endPartitionOffsets.keySet.filter { tp => + // Ignore partitions that we don't know the from offsets. 
+ newPartitionInitialOffsets.contains(tp) || startPartitionOffsets.contains(tp) + }.toSeq + logDebug("TopicPartitions: " + topicPartitions.mkString(", ")) + + // Calculate offset ranges + val offsetRanges = rangeCalculator.getRanges( + fromOffsets = startPartitionOffsets ++ newPartitionInitialOffsets, + untilOffsets = endPartitionOffsets, + executorLocations = getSortedExecutorList()) + + // Reuse Kafka consumers only when all the offset ranges have distinct TopicPartitions, + // that is, concurrent tasks will not read the same TopicPartitions. + val reuseKafkaConsumer = offsetRanges.map(_.topicPartition).toSet.size == offsetRanges.size + + // Generate factories based on the offset ranges + offsetRanges.map { range => + KafkaMicroBatchInputPartition( + range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer) + }.toArray + } + + private def getSortedExecutorList(): Array[String] = { + + def compare(a: ExecutorCacheTaskLocation, b: ExecutorCacheTaskLocation): Boolean = { + if (a.host == b.host) { + a.executorId > b.executorId + } else { + a.host > b.host + } + } + + val bm = SparkEnv.get.blockManager + bm.master.getPeers(bm.blockManagerId).toArray + .map(x => ExecutorCacheTaskLocation(x.host, x.executorId)) + .sortWith(compare) + .map(_.toString) + } +} + /** A [[InputPartition]] for reading Kafka data in a micro-batch streaming query. */ private[kafka010] case class KafkaMicroBatchInputPartition( offsetRange: KafkaOffsetRange, diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 28c9853bfea9..86f3f38837e7 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} -import java.util.{Locale, Optional, UUID} +import java.util.{Locale, UUID} import scala.collection.JavaConverters._ @@ -31,6 +31,8 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSessio import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.reader.ScanConfig +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputStream, MicroBatchInputStream} import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -46,8 +48,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with RelationProvider with CreatableRelationProvider with StreamingWriteSupportProvider - with ContinuousReadSupportProvider - with MicroBatchReadSupportProvider + with Format with Logging { import KafkaSourceProvider._ @@ -106,85 +107,96 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister failOnDataLoss(caseInsensitiveParams)) } - /** - * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReadSupport]] to read - * batches of Kafka data in a micro-batch streaming query. 
- */ - override def createMicroBatchReadSupport( - metadataPath: String, - options: DataSourceOptions): KafkaMicroBatchReadSupport = { - - val parameters = options.asMap().asScala.toMap - validateStreamOptions(parameters) - // Each running query should use its own group id. Otherwise, the query may be only assigned - // partial data since Kafka will assign partitions to multiple consumers having the same group - // id. Hence, we should generate a unique id for each query. - val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}" - - val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } - val specifiedKafkaParams = - parameters - .keySet - .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) - .map { k => k.drop(6).toString -> parameters(k) } - .toMap - - val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams, - STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) - - val kafkaOffsetReader = new KafkaOffsetReader( - strategy(caseInsensitiveParams), - kafkaParamsForDriver(specifiedKafkaParams), - parameters, - driverGroupIdPrefix = s"$uniqueGroupId-driver") - - new KafkaMicroBatchReadSupport( - kafkaOffsetReader, - kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId), - options, - metadataPath, - startingStreamOffsets, - failOnDataLoss(caseInsensitiveParams)) + override def getTable(options: DataSourceOptions): KafkaTable.type = { + KafkaTable } - /** - * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReadSupport]] to read - * Kafka data in a continuous streaming query. - */ - override def createContinuousReadSupport( - metadataPath: String, - options: DataSourceOptions): KafkaContinuousReadSupport = { - val parameters = options.asMap().asScala.toMap - validateStreamOptions(parameters) - // Each running query should use its own group id. Otherwise, the query may be only assigned - // partial data since Kafka will assign partitions to multiple consumers having the same group - // id. Hence, we should generate a unique id for each query. - val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}" - - val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } - val specifiedKafkaParams = - parameters - .keySet - .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) - .map { k => k.drop(6).toString -> parameters(k) } - .toMap + object KafkaTable extends Table + with SupportsMicroBatchRead with SupportsContinuousRead { - val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams, - STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) + override def schema(): StructType = KafkaOffsetReader.kafkaSchema - val kafkaOffsetReader = new KafkaOffsetReader( - strategy(caseInsensitiveParams), - kafkaParamsForDriver(specifiedKafkaParams), - parameters, - driverGroupIdPrefix = s"$uniqueGroupId-driver") + /** + * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchInputStream]] to read + * batches of Kafka data in a micro-batch streaming query. + */ + override def createMicroBatchInputStream( + checkpointLocation: String, + config: ScanConfig, + options: DataSourceOptions): MicroBatchInputStream = { + val parameters = options.asMap().asScala.toMap + validateStreamOptions(parameters) + // Each running query should use its own group id. 
Otherwise, the query may be only assigned + // partial data since Kafka will assign partitions to multiple consumers having the same group + // id. Hence, we should generate a unique id for each query. + val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${checkpointLocation.hashCode}" + + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } + val specifiedKafkaParams = + parameters + .keySet + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) + .map { k => k.drop(6).toString -> parameters(k) } + .toMap + + val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit( + caseInsensitiveParams, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) + + val kafkaOffsetReader = new KafkaOffsetReader( + strategy(caseInsensitiveParams), + kafkaParamsForDriver(specifiedKafkaParams), + parameters, + driverGroupIdPrefix = s"$uniqueGroupId-driver") + + new KafkaMicroBatchInputStream( + kafkaOffsetReader, + kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId), + options, + checkpointLocation, + startingStreamOffsets, + failOnDataLoss(caseInsensitiveParams)) + } - new KafkaContinuousReadSupport( - kafkaOffsetReader, - kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId), - parameters, - metadataPath, - startingStreamOffsets, - failOnDataLoss(caseInsensitiveParams)) + /** + * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.ContinuousInputStream]] to read + * Kafka data in a continuous streaming query. + */ + override def createContinuousInputStream( + checkpointLocation: String, + config: ScanConfig, + options: DataSourceOptions): ContinuousInputStream = { + val parameters = options.asMap().asScala.toMap + validateStreamOptions(parameters) + // Each running query should use its own group id. Otherwise, the query may be only assigned + // partial data since Kafka will assign partitions to multiple consumers having the same group + // id. Hence, we should generate a unique id for each query. 
+ val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${checkpointLocation.hashCode}" + + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } + val specifiedKafkaParams = + parameters + .keySet + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) + .map { k => k.drop(6).toString -> parameters(k) } + .toMap + + val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit( + caseInsensitiveParams, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) + + val kafkaOffsetReader = new KafkaOffsetReader( + strategy(caseInsensitiveParams), + kafkaParamsForDriver(specifiedKafkaParams), + parameters, + driverGroupIdPrefix = s"$uniqueGroupId-driver") + + new KafkaContinuousInputStream( + kafkaOffsetReader, + kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId), + parameters, + checkpointLocation, + startingStreamOffsets, + failOnDataLoss(caseInsensitiveParams)) + } } /** diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala index af510219a6f6..f2b796b78a34 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.kafka010 import org.apache.kafka.clients.producer.ProducerRecord import org.apache.spark.sql.Dataset -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec +import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.streaming.Trigger @@ -207,13 +207,13 @@ class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { testUtils.createTopic(topic2, partitions = 5) eventually(timeout(streamingTimeout)) { assert( - query.lastExecution.executedPlan.collectFirst { - case scan: DataSourceV2ScanExec - if scan.readSupport.isInstanceOf[KafkaContinuousReadSupport] => - scan.scanConfig.asInstanceOf[KafkaContinuousScanConfig] - }.exists { config => + query.lastExecution.logical.collectFirst { + case r: StreamingDataSourceV2Relation + if r.stream.isInstanceOf[KafkaContinuousInputStream] => + r.stream.asInstanceOf[KafkaContinuousInputStream] + }.exists { stream => // Ensure the new topic is present and the old topic is gone. 
- config.knownPartitions.exists(_.topic == topic2) + stream.knownPartitions.exists(_.topic == topic2) }, s"query never reconfigured to new topic $topic2") } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala index fa6bdc20bd4f..e7ada6b52c37 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala @@ -21,7 +21,7 @@ import java.util.concurrent.atomic.AtomicInteger import org.apache.spark.SparkContext import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd, SparkListenerTaskStart} -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec +import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming.StreamExecution import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.streaming.Trigger @@ -46,10 +46,10 @@ trait KafkaContinuousTest extends KafkaSourceTest { testUtils.addPartitions(topic, newCount) eventually(timeout(streamingTimeout)) { assert( - query.lastExecution.executedPlan.collectFirst { - case scan: DataSourceV2ScanExec - if scan.readSupport.isInstanceOf[KafkaContinuousReadSupport] => - scan.scanConfig.asInstanceOf[KafkaContinuousScanConfig] + query.lastExecution.logical.collectFirst { + case r: StreamingDataSourceV2Relation + if r.stream.isInstanceOf[KafkaContinuousInputStream] => + r.stream.asInstanceOf[KafkaContinuousInputStream] }.exists(_.knownPartitions.size == newCount), s"query never reconfigured to $newCount partitions") } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 5ee76990b54f..b8712f0ecb78 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -117,13 +117,15 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext with Kaf val sources: Seq[BaseStreamingSource] = { query.get.logicalPlan.collect { case StreamingExecutionRelation(source: KafkaSource, _) => source - case StreamingExecutionRelation(source: KafkaMicroBatchReadSupport, _) => source + case r: StreamingDataSourceV2Relation + if r.stream.isInstanceOf[KafkaMicroBatchInputStream] => + r.stream.asInstanceOf[KafkaMicroBatchInputStream] } ++ (query.get.lastExecution match { case null => Seq() case e => e.logical.collect { case r: StreamingDataSourceV2Relation - if r.readSupport.isInstanceOf[KafkaContinuousReadSupport] => - r.readSupport.asInstanceOf[KafkaContinuousReadSupport] + if r.stream.isInstanceOf[KafkaContinuousInputStream] => + r.stream.asInstanceOf[KafkaContinuousInputStream] } }) }.distinct @@ -978,7 +980,8 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { makeSureGetOffsetCalled, AssertOnQuery { query => query.logicalPlan.collect { - case StreamingExecutionRelation(_: KafkaMicroBatchReadSupport, _) => true + case r: StreamingDataSourceV2Relation + if r.stream.isInstanceOf[KafkaMicroBatchInputStream] => true }.nonEmpty } ) @@ -1003,12 +1006,14 @@ class 
KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { "kafka.bootstrap.servers" -> testUtils.brokerAddress, "subscribe" -> topic ) ++ Option(minPartitions).map { p => "minPartitions" -> p} - val readSupport = provider.createMicroBatchReadSupport( - dir.getAbsolutePath, new DataSourceOptions(options.asJava)) - val config = readSupport.newScanConfigBuilder( + val dsOptions = new DataSourceOptions(options.asJava) + val table = provider.getTable(dsOptions) + val config = table.newScanConfigBuilder(dsOptions).build() + val stream = table.createMicroBatchInputStream(dir.getAbsolutePath, config, dsOptions) + val scan = stream.createMicroBatchScan( KafkaSourceOffset(Map(tp -> 0L)), - KafkaSourceOffset(Map(tp -> 100L))).build() - val inputPartitions = readSupport.planInputPartitions(config) + KafkaSourceOffset(Map(tp -> 100L))) + val inputPartitions = scan.planInputPartitions() .map(_.asInstanceOf[KafkaMicroBatchInputPartition]) withClue(s"minPartitions = $minPartitions generated factories $inputPartitions\n\t") { assert(inputPartitions.size == numPartitionsGenerated) @@ -1326,7 +1331,7 @@ abstract class KafkaSourceSuiteBase extends KafkaSourceTest { val reader = spark .readStream .format("kafka") - .option("startingOffsets", s"latest") + .option("startingOffsets", "latest") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") .option("failOnDataLoss", failOnDataLoss.toString) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchReadSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchReadSupportProvider.java deleted file mode 100644 index f403dc619e86..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchReadSupportProvider.java +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils; -import org.apache.spark.sql.sources.v2.reader.BatchReadSupport; -import org.apache.spark.sql.types.StructType; - -/** - * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * provide data reading ability for batch processing. - * - * This interface is used to create {@link BatchReadSupport} instances when end users run - * {@code SparkSession.read.format(...).option(...).load()}. - */ -@InterfaceStability.Evolving -public interface BatchReadSupportProvider extends DataSourceV2 { - - /** - * Creates a {@link BatchReadSupport} instance to load the data from this data source with a user - * specified schema, which is called by Spark at the beginning of each batch query. 
- * - * Spark will call this method at the beginning of each batch query to create a - * {@link BatchReadSupport} instance. - * - * By default this method throws {@link UnsupportedOperationException}, implementations should - * override this method to handle user specified schema. - * - * @param schema the user specified schema. - * @param options the options for the returned data source reader, which is an immutable - * case-insensitive string-to-string map. - */ - default BatchReadSupport createBatchReadSupport(StructType schema, DataSourceOptions options) { - return DataSourceV2Utils.failForUserSpecifiedSchema(this); - } - - /** - * Creates a {@link BatchReadSupport} instance to scan the data from this data source, which is - * called by Spark at the beginning of each batch query. - * - * @param options the options for the returned data source reader, which is an immutable - * case-insensitive string-to-string map. - */ - BatchReadSupport createBatchReadSupport(DataSourceOptions options); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupportProvider.java deleted file mode 100644 index 824c290518ac..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupportProvider.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils; -import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReadSupport; -import org.apache.spark.sql.types.StructType; - -/** - * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * provide data reading ability for continuous stream processing. - * - * This interface is used to create {@link ContinuousReadSupport} instances when end users run - * {@code SparkSession.readStream.format(...).option(...).load()} with a continuous trigger. - */ -@InterfaceStability.Evolving -public interface ContinuousReadSupportProvider extends DataSourceV2 { - - /** - * Creates a {@link ContinuousReadSupport} instance to scan the data from this streaming data - * source with a user specified schema, which is called by Spark at the beginning of each - * continuous streaming query. - * - * By default this method throws {@link UnsupportedOperationException}, implementations should - * override this method to handle user specified schema. - * - * @param schema the user provided schema. - * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure - * recovery. 
Readers for the same logical source in the same query - * will be given the same checkpointLocation. - * @param options the options for the returned data source reader, which is an immutable - * case-insensitive string-to-string map. - */ - default ContinuousReadSupport createContinuousReadSupport( - StructType schema, - String checkpointLocation, - DataSourceOptions options) { - return DataSourceV2Utils.failForUserSpecifiedSchema(this); - } - - /** - * Creates a {@link ContinuousReadSupport} instance to scan the data from this streaming data - * source, which is called by Spark at the beginning of each continuous streaming query. - * - * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure - * recovery. Readers for the same logical source in the same query - * will be given the same checkpointLocation. - * @param options the options for the returned data source reader, which is an immutable - * case-insensitive string-to-string map. - */ - ContinuousReadSupport createContinuousReadSupport( - String checkpointLocation, - DataSourceOptions options); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java index 6e31e84bf6c7..257586a4a135 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java @@ -23,7 +23,7 @@ * The base interface for data source v2. Implementations must have a public, 0-arg constructor. * * Note that this is an empty interface. Data source implementations must mix in interfaces such as - * {@link BatchReadSupportProvider} or {@link BatchWriteSupportProvider}, which can provide + * {@link SupportsBatchRead} or {@link BatchWriteSupportProvider}, which can provide * batch or streaming read/write support instances. Otherwise it's just a dummy data source which * is un-readable/writable. * diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/Format.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/Format.java new file mode 100644 index 000000000000..6b54007ba8c2 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/Format.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.DataSourceRegister; +import org.apache.spark.sql.types.StructType; + +/** + * The base interface for data source v2. Implementations must have a public, 0-arg constructor. + * + * The major responsibility of this interface is to return a {@link Table} for read/write. 
+ */ +@InterfaceStability.Evolving +public interface Format extends DataSourceV2 { + + /** + * Return a {@link Table} instance to do read/write with user-specified options. + * + * @param options the user-specified options that can identify a table, e.g. path, table name, + * Kafka topic name, etc. It's an immutable case-insensitive string-to-string map. + */ + Table getTable(DataSourceOptions options); + + /** + * Return a {@link Table} instance to do read/write with user-specified schema and options. + * + * By default this method throws {@link UnsupportedOperationException}, implementations should + * override this method to handle user-specified schema. + * + * @param options the user-specified options that can identify a table, e.g. path, table name, + * Kafka topic name, etc. It's an immutable case-insensitive string-to-string map. + * @param schema the user-specified schema. + */ + default Table getTable(DataSourceOptions options, StructType schema) { + String name; + if (this instanceof DataSourceRegister) { + name = ((DataSourceRegister) this).shortName(); + } else { + name = this.getClass().getName(); + } + throw new UnsupportedOperationException( + name + " source does not support user-specified schema"); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java deleted file mode 100644 index 61c08e7fa89d..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils; -import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReadSupport; -import org.apache.spark.sql.types.StructType; - -/** - * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * provide data reading ability for micro-batch stream processing. - * - * This interface is used to create {@link MicroBatchReadSupport} instances when end users run - * {@code SparkSession.readStream.format(...).option(...).load()} with a micro-batch trigger. - */ -@InterfaceStability.Evolving -public interface MicroBatchReadSupportProvider extends DataSourceV2 { - - /** - * Creates a {@link MicroBatchReadSupport} instance to scan the data from this streaming data - * source with a user specified schema, which is called by Spark at the beginning of each - * micro-batch streaming query. 
- * - * By default this method throws {@link UnsupportedOperationException}, implementations should - * override this method to handle user specified schema. - * - * @param schema the user provided schema. - * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure - * recovery. Readers for the same logical source in the same query - * will be given the same checkpointLocation. - * @param options the options for the returned data source reader, which is an immutable - * case-insensitive string-to-string map. - */ - default MicroBatchReadSupport createMicroBatchReadSupport( - StructType schema, - String checkpointLocation, - DataSourceOptions options) { - return DataSourceV2Utils.failForUserSpecifiedSchema(this); - } - - /** - * Creates a {@link MicroBatchReadSupport} instance to scan the data from this streaming data - * source, which is called by Spark at the beginning of each micro-batch streaming query. - * - * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure - * recovery. Readers for the same logical source in the same query - * will be given the same checkpointLocation. - * @param options the options for the returned data source reader, which is an immutable - * case-insensitive string-to-string map. - */ - MicroBatchReadSupport createMicroBatchReadSupport( - String checkpointLocation, - DataSourceOptions options); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java new file mode 100644 index 000000000000..be2ab028fe77 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.reader.BatchScan; +import org.apache.spark.sql.sources.v2.reader.ScanConfig; +import org.apache.spark.sql.sources.v2.reader.ScanConfigBuilder; + +/** + * A mix-in interface for {@link Table}. Table implementations can mixin this interface to + * provide data reading ability for batch processing. + */ +@InterfaceStability.Evolving +public interface SupportsBatchRead extends Table { + + /** + * Creates a {@link BatchScan} instance with a {@link ScanConfig} and user-specified options. + * + * @param config a {@link ScanConfig} which may contains operator pushdown information. + * @param options the user-specified options, which is same as the one used to create the + * {@link ScanConfigBuilder} that built the given {@link ScanConfig}. 
+ */ + BatchScan createBatchScan(ScanConfig config, DataSourceOptions options); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsContinuousRead.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsContinuousRead.java new file mode 100644 index 000000000000..6773a5b40d8c --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsContinuousRead.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.reader.ScanConfig; +import org.apache.spark.sql.sources.v2.reader.ScanConfigBuilder; +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousInputStream; + +/** + * A mix-in interface for {@link Table}. Table implementations can mixin this interface to + * provide data reading ability for continuous stream processing. + */ +@InterfaceStability.Evolving +public interface SupportsContinuousRead extends Table { + + /** + * Creates a {@link ContinuousInputStream} instance with a checkpoint location, a + * {@link ScanConfig} and user-specified options. + * + * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure + * recovery. Input streams for the same logical source in the same query + * will be given the same checkpointLocation. + * @param config a {@link ScanConfig} which may contains operator pushdown information. + * @param options the user-specified options, which is same as the one used to create the + * {@link ScanConfigBuilder} that built the given {@link ScanConfig}. + */ + ContinuousInputStream createContinuousInputStream( + String checkpointLocation, + ScanConfig config, + DataSourceOptions options); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsMicroBatchRead.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsMicroBatchRead.java new file mode 100644 index 000000000000..04818e3a602d --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsMicroBatchRead.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.reader.ScanConfig; +import org.apache.spark.sql.sources.v2.reader.ScanConfigBuilder; +import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchInputStream; + +/** + * A mix-in interface for {@link Table}. Table implementations can mixin this interface to + * provide data reading ability for micro-batch stream processing. + */ +@InterfaceStability.Evolving +public interface SupportsMicroBatchRead extends Table { + + /** + * Creates a {@link MicroBatchInputStream} instance with a checkpoint location, a + * {@link ScanConfig} and user-specified options. + * + * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure + * recovery. Input streams for the same logical source in the same query + * will be given the same checkpointLocation. + * @param config a {@link ScanConfig} which may contains operator pushdown information. + * @param options the user-specified options, which is same as the one used to create the + * {@link ScanConfigBuilder} that built the given {@link ScanConfig}. + */ + MicroBatchInputStream createMicroBatchInputStream( + String checkpointLocation, + ScanConfig config, + DataSourceOptions options); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/Table.java similarity index 52% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchReadSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/Table.java index 452ee86675b4..3315306c8aa6 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/Table.java @@ -15,22 +15,34 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.reader; +package org.apache.spark.sql.sources.v2; import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.execution.datasources.v2.NoopScanConfigBuilder; +import org.apache.spark.sql.sources.v2.reader.ScanConfig; +import org.apache.spark.sql.sources.v2.reader.ScanConfigBuilder; +import org.apache.spark.sql.types.StructType; /** - * An interface that defines how to load the data from data source for batch processing. + * An interface representing a logical structured data set of a data source. For example, the + * implementation can be a directory on the file system, or a table in the catalog, etc. * - * The execution engine will get an instance of this interface from a data source provider - * (e.g. {@link org.apache.spark.sql.sources.v2.BatchReadSupportProvider}) at the start of a batch - * query, then call {@link #newScanConfigBuilder()} and create an instance of {@link ScanConfig}. - * The {@link ScanConfigBuilder} can apply operator pushdown and keep the pushdown result in - * {@link ScanConfig}. 
The {@link ScanConfig} will be used to create input partitions and reader - * factory to scan data from the data source with a Spark job. + * This interface can mixin the following interfaces to support different operations: + * */ @InterfaceStability.Evolving -public interface BatchReadSupport extends ReadSupport { +public interface Table { + + /** + * Returns the schema of this table. + */ + StructType schema(); /** * Returns a builder of {@link ScanConfig}. Spark will call this method and create a @@ -38,14 +50,8 @@ public interface BatchReadSupport extends ReadSupport { * * The builder can take some query specific information to do operators pushdown, and keep these * information in the created {@link ScanConfig}. - * - * This is the first step of the data scan. All other methods in {@link BatchReadSupport} needs - * to take {@link ScanConfig} as an input. - */ - ScanConfigBuilder newScanConfigBuilder(); - - /** - * Returns a factory, which produces one {@link PartitionReader} for one {@link InputPartition}. */ - PartitionReaderFactory createReaderFactory(ScanConfig config); + default ScanConfigBuilder newScanConfigBuilder(DataSourceOptions options) { + return new NoopScanConfigBuilder(schema()); + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchScan.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchScan.java new file mode 100644 index 000000000000..c97357dced11 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchScan.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.DataSourceOptions; +import org.apache.spark.sql.sources.v2.SupportsBatchRead; +import org.apache.spark.sql.sources.v2.Table; + +/** + * A {@link Scan} for batch queries. + * + * The execution engine will get an instance of {@link Table} first, then call + * {@link Table#newScanConfigBuilder(DataSourceOptions)} and create an instance of + * {@link ScanConfig}. The {@link ScanConfigBuilder} can apply operator pushdown and keep the + * pushdown result in {@link ScanConfig}. Then + * {@link SupportsBatchRead#createBatchScan(ScanConfig, DataSourceOptions)} will be called to create + * a {@link BatchScan} instance, which will be used to create input partitions and reader factory to + * scan data from the data source with a Spark job. + */ +@InterfaceStability.Evolving +public interface BatchScan extends Scan { + + /** + * Returns a factory, which produces one {@link PartitionReader} for one {@link InputPartition}. 
+ */ + PartitionReaderFactory createReaderFactory(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java index 95c30de907e4..cc9ce4694c3f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java @@ -23,7 +23,7 @@ /** * A serializable representation of an input partition returned by - * {@link ReadSupport#planInputPartitions(ScanConfig)}. + * {@link Scan#planInputPartitions()}. * * Note that {@link InputPartition} will be serialized and sent to executors, then * {@link PartitionReader} will be created by diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java similarity index 68% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java index a58ddb288f1e..cf9ee11d93bd 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java @@ -18,24 +18,19 @@ package org.apache.spark.sql.sources.v2.reader; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.types.StructType; /** - * The base interface for all the batch and streaming read supports. Data sources should implement - * concrete read support interfaces like {@link BatchReadSupport}. + * The base interface for all the batch and streaming scans. Data sources should implement + * concrete scan interfaces like {@link BatchScan}. + * + * A scan is used to create input partitions and reader factory to scan data from the data source + * with a Spark job. * * If Spark fails to execute any methods in the implementations of this interface (by throwing an * exception), the read action will fail and no Spark job will be submitted. */ @InterfaceStability.Evolving -public interface ReadSupport { - - /** - * Returns the full schema of this data source, which is usually the physical schema of the - * underlying storage. This full schema should not be affected by column pruning or other - * optimizations. - */ - StructType fullSchema(); +public interface Scan { /** * Returns a list of {@link InputPartition input partitions}. Each {@link InputPartition} @@ -43,8 +38,8 @@ public interface ReadSupport { * partitions returned here is the same as the number of RDD partitions this scan outputs. * * Note that, this may not be a full scan if the data source supports optimization like filter - * push-down. Implementations should check the input {@link ScanConfig} and adjust the resulting - * {@link InputPartition input partitions}. + * push-down. Implementations should check the {@link ScanConfig} that created this scan and + * adjust the resulting {@link InputPartition input partitions}. 
*/ - InputPartition[] planInputPartitions(ScanConfig config); + InputPartition[] planInputPartitions(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java index 7462ce282058..495334cb67dd 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java @@ -22,21 +22,18 @@ /** * An interface that carries query specific information for the data scanning job, like operator - * pushdown information and streaming query offsets. This is defined as an empty interface, and data - * sources should define their own {@link ScanConfig} classes. + * pushdown information. This is defined as an empty interface, and data sources should define + * their own {@link ScanConfig} classes. * - * For APIs that take a {@link ScanConfig} as input, like - * {@link ReadSupport#planInputPartitions(ScanConfig)}, - * {@link BatchReadSupport#createReaderFactory(ScanConfig)} and - * {@link SupportsReportStatistics#estimateStatistics(ScanConfig)}, implementations mostly need to - * cast the input {@link ScanConfig} to the concrete {@link ScanConfig} class of the data source. + * {@link Scan} implementations usually need to cast the input {@link ScanConfig} to the concrete + * {@link ScanConfig} class of the data source. */ @InterfaceStability.Evolving public interface ScanConfig { /** - * Returns the actual schema of this data source reader, which may be different from the physical - * schema of the underlying storage, as column pruning or other optimizations may happen. + * Returns the actual schema of this scan, which may be different from the table schema, as + * column pruning or other optimizations may happen. * * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java index 44799c7d4913..031c7a73c367 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java @@ -23,7 +23,7 @@ /** * An interface to represent statistics for a data source, which is returned by - * {@link SupportsReportStatistics#estimateStatistics(ScanConfig)}. + * {@link SupportsReportStatistics#estimateStatistics()}. */ @InterfaceStability.Evolving public interface Statistics { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java index db62cd451536..cdfc8bd22ab3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java @@ -21,17 +21,17 @@ import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning; /** - * A mix in interface for {@link BatchReadSupport}. Data sources can implement this interface to + * A mix in interface for {@link Scan}. Data sources can implement this interface to * report data partitioning and try to avoid shuffle at Spark side. 
* - * Note that, when a {@link ReadSupport} implementation creates exactly one {@link InputPartition}, + * Note that, when a {@link Scan} implementation creates exactly one {@link InputPartition}, * Spark may avoid adding a shuffle even if the reader does not implement this interface. */ @InterfaceStability.Evolving -public interface SupportsReportPartitioning extends ReadSupport { +public interface SupportsReportPartitioning extends Scan { /** - * Returns the output data partitioning that this reader guarantees. + * Returns the output data partitioning that this scan guarantees. */ - Partitioning outputPartitioning(ScanConfig config); + Partitioning outputPartitioning(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java index 1831488ba096..ab50e3ff4098 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java @@ -20,7 +20,7 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A mix in interface for {@link BatchReadSupport}. Data sources can implement this interface to + * A mix in interface for {@link Scan}. Data sources can implement this interface to * report statistics to Spark. * * As of Spark 2.4, statistics are reported to the optimizer before any operator is pushed to the @@ -28,10 +28,10 @@ * not improve query performance until the planner can push operators before getting stats. */ @InterfaceStability.Evolving -public interface SupportsReportStatistics extends ReadSupport { +public interface SupportsReportStatistics extends Scan { /** - * Returns the estimated statistics of this data source scan. + * Returns the estimated statistics of this scan. */ - Statistics estimateStatistics(ScanConfig config); + Statistics estimateStatistics(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java index fb0b6f1df43b..f460f6bfe3bb 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java @@ -19,13 +19,12 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.reader.InputPartition; -import org.apache.spark.sql.sources.v2.reader.ScanConfig; import org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning; /** * An interface to represent the output data partitioning for a data source, which is returned by - * {@link SupportsReportPartitioning#outputPartitioning(ScanConfig)}. Note that this should work - * like a snapshot. Once created, it should be deterministic and always report the same number of + * {@link SupportsReportPartitioning#outputPartitioning()}. Note that this should work like a + * snapshot. Once created, it should be deterministic and always report the same number of * partitions and the same "satisfy" result for a certain distribution. 
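For the two reporting mix-ins above, which now hang off Scan instead of the old ReadSupport, a hypothetical wrapper scan could surface estimates as sketched below; ReportingScan and its constructor parameters are invented for illustration, and Statistics/Partitioning keep their pre-existing methods (sizeInBytes/numRows, numPartitions/satisfy).

import java.util.OptionalLong

import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.partitioning.{Distribution, Partitioning}

// Hypothetical wrapper that adds statistics and partitioning reporting to an existing scan.
class ReportingScan(delegate: BatchScan, estimatedBytes: Long, estimatedRows: Long)
  extends BatchScan with SupportsReportStatistics with SupportsReportPartitioning {

  // Snapshot the partitions once so outputPartitioning() stays deterministic.
  private lazy val partitions: Array[InputPartition] = delegate.planInputPartitions()

  override def planInputPartitions(): Array[InputPartition] = partitions
  override def createReaderFactory(): PartitionReaderFactory = delegate.createReaderFactory()

  override def estimateStatistics(): Statistics = new Statistics {
    override def sizeInBytes(): OptionalLong = OptionalLong.of(estimatedBytes)
    override def numRows(): OptionalLong = OptionalLong.of(estimatedRows)
  }

  override def outputPartitioning(): Partitioning = new Partitioning {
    override def numPartitions(): Int = partitions.length
    // Conservatively claims no distribution is satisfied; a real source would inspect it.
    override def satisfy(distribution: Distribution): Boolean = false
  }
}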
 */ @InterfaceStability.Evolving diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousInputStream.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousInputStream.java new file mode 100644 index 000000000000..6ff1513a41b6 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousInputStream.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader.streaming; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * An {@link InputStream} for a streaming query in continuous mode. + */ +@InterfaceStability.Evolving +public interface ContinuousInputStream extends InputStream { + + /** + * Creates a {@link ContinuousScan} instance with a start offset, to scan the data from the start + * offset with an endless Spark job. The job will be terminated when {@link #needsReconfiguration()} + * returns true, and the execution engine will then call this method again, with a different start + * offset, to launch a new endless Spark job. + */ + ContinuousScan createContinuousScan(Offset start); + + /** + * Merges partitioned offsets coming from {@link ContinuousPartitionReader} instances + * for each partition into a single global offset. + */ + Offset mergeOffsets(PartitionOffset[] offsets); + + /** + * The execution engine will call this method in every epoch to determine if new input + * partitions need to be generated, which may be required if, for example, the underlying + * source system has had partitions added or removed. + * + * If true, the query will be shut down and restarted with a new {@link ContinuousScan} + * instance. + */ + default boolean needsReconfiguration() { + return false; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReadSupport.java deleted file mode 100644 index 9a3ad2eb8a80..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReadSupport.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License.
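Under the new split of responsibilities, the engine-side loop for continuous mode roughly follows the sketch below (an editor's approximation, not engine code): runEpochs is a placeholder for the long-running Spark job that reads the scan's partitions, while initialOffset() and stop() come from the renamed InputStream base interface.

import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsContinuousRead}
import org.apache.spark.sql.sources.v2.reader.ScanConfig
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputStream, ContinuousScan, Offset, PartitionOffset}

// Sketch of the driver-side continuous read loop under the assumptions stated above.
def runContinuously(
    table: SupportsContinuousRead,
    config: ScanConfig,
    options: DataSourceOptions,
    checkpointLocation: String,
    runEpochs: ContinuousScan => Array[PartitionOffset]): Unit = {
  val stream: ContinuousInputStream =
    table.createContinuousInputStream(checkpointLocation, config, options)
  var start: Offset = stream.initialOffset()
  do {
    val scan = stream.createContinuousScan(start)
    // `runEpochs` stands in for the long-running job that reads scan.planInputPartitions()
    // with scan.createContinuousReaderFactory(), and returns the last per-partition offsets
    // once the job ends (query stop or reconfiguration).
    val partitionOffsets = runEpochs(scan)
    start = stream.mergeOffsets(partitionOffsets)
  } while (stream.needsReconfiguration())
  stream.stop()
}

The design point this illustrates is that offset management stays with the ContinuousInputStream, while each ContinuousScan is a disposable snapshot of the partition layout.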
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.reader.streaming; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.execution.streaming.BaseStreamingSource; -import org.apache.spark.sql.sources.v2.reader.InputPartition; -import org.apache.spark.sql.sources.v2.reader.ScanConfig; -import org.apache.spark.sql.sources.v2.reader.ScanConfigBuilder; - -/** - * An interface that defines how to load the data from data source for continuous streaming - * processing. - * - * The execution engine will get an instance of this interface from a data source provider - * (e.g. {@link org.apache.spark.sql.sources.v2.ContinuousReadSupportProvider}) at the start of a - * streaming query, then call {@link #newScanConfigBuilder(Offset)} and create an instance of - * {@link ScanConfig} for the duration of the streaming query or until - * {@link #needsReconfiguration(ScanConfig)} is true. The {@link ScanConfig} will be used to create - * input partitions and reader factory to scan data with a Spark job for its duration. At the end - * {@link #stop()} will be called when the streaming execution is completed. Note that a single - * query may have multiple executions due to restart or failure recovery. - */ -@InterfaceStability.Evolving -public interface ContinuousReadSupport extends StreamingReadSupport, BaseStreamingSource { - - /** - * Returns a builder of {@link ScanConfig}. Spark will call this method and create a - * {@link ScanConfig} for each data scanning job. - * - * The builder can take some query specific information to do operators pushdown, store streaming - * offsets, etc., and keep these information in the created {@link ScanConfig}. - * - * This is the first step of the data scan. All other methods in {@link ContinuousReadSupport} - * needs to take {@link ScanConfig} as an input. - */ - ScanConfigBuilder newScanConfigBuilder(Offset start); - - /** - * Returns a factory, which produces one {@link ContinuousPartitionReader} for one - * {@link InputPartition}. - */ - ContinuousPartitionReaderFactory createContinuousReaderFactory(ScanConfig config); - - /** - * Merge partitioned offsets coming from {@link ContinuousPartitionReader} instances - * for each partition to a single global offset. - */ - Offset mergeOffsets(PartitionOffset[] offsets); - - /** - * The execution engine will call this method in every epoch to determine if new input - * partitions need to be generated, which may be required if for example the underlying - * source system has had partitions added or removed. - * - * If true, the query will be shut down and restarted with a new {@link ContinuousReadSupport} - * instance. 
- */ - default boolean needsReconfiguration(ScanConfig config) { - return false; - } -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousScan.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousScan.java new file mode 100644 index 000000000000..9b9090a810ca --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousScan.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader.streaming; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.DataSourceOptions; +import org.apache.spark.sql.sources.v2.SupportsContinuousRead; +import org.apache.spark.sql.sources.v2.Table; +import org.apache.spark.sql.sources.v2.reader.InputPartition; +import org.apache.spark.sql.sources.v2.reader.Scan; +import org.apache.spark.sql.sources.v2.reader.ScanConfig; +import org.apache.spark.sql.sources.v2.reader.ScanConfigBuilder; + +/** + * A {@link Scan} for streaming queries with continuous mode. + * + * The execution engine will get an instance of {@link Table} first, then call + * {@link Table#newScanConfigBuilder(DataSourceOptions)} and create an instance of + * {@link ScanConfig}. The {@link ScanConfigBuilder} can apply operator pushdown and keep the + * pushdown result in {@link ScanConfig}. Then + * {@link SupportsContinuousRead#createContinuousInputStream(String, ScanConfig, DataSourceOptions)} + * will be called to create a {@link ContinuousInputStream} instance. The + * {@link ContinuousInputStream} manages offsets and creates a {@link ContinuousScan} instance for + * the duration of the streaming query or until {@link ContinuousInputStream#needsReconfiguration()} + * returns true. The {@link ContinuousScan} will be used to create input partitions and reader + * factory to scan data with a Spark job for its duration. At the end {@link InputStream#stop()} + * will be called when the streaming execution is completed. Note that a single query may have + * multiple executions due to restart or failure recovery. + */ +@InterfaceStability.Evolving +public interface ContinuousScan extends Scan { + + /** + * Returns a factory, which produces one {@link ContinuousPartitionReader} for one + * {@link InputPartition}. 
+ */ + ContinuousPartitionReaderFactory createContinuousReaderFactory(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/StreamingReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/InputStream.java similarity index 76% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/StreamingReadSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/InputStream.java index 84872d1ebc26..e1b026dcc332 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/StreamingReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/InputStream.java @@ -17,14 +17,18 @@ package org.apache.spark.sql.sources.v2.reader.streaming; -import org.apache.spark.sql.sources.v2.reader.ReadSupport; +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.execution.streaming.BaseStreamingSource; /** - * A base interface for streaming read support. This is package private and is invisible to data - * sources. Data sources should implement concrete streaming read support interfaces: - * {@link MicroBatchReadSupport} or {@link ContinuousReadSupport}. + * An interface representing a readable data stream in a streaming query. It's responsible to manage + * the offsets of the streaming source in this streaming query. + * + * Data sources should implement concrete input stream interfaces: {@link MicroBatchInputStream} and + * {@link ContinuousInputStream}. */ -interface StreamingReadSupport extends ReadSupport { +@InterfaceStability.Evolving +public interface InputStream extends BaseStreamingSource { /** * Returns the initial offset for a streaming query to start reading from. Note that the diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchInputStream.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchInputStream.java new file mode 100644 index 000000000000..2e0e760da7e2 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchInputStream.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader.streaming; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * A {@link InputStream} for a streaming query with micro-batch mode. + */ +@InterfaceStability.Evolving +public interface MicroBatchInputStream extends InputStream { + + /** + * Creates a {@link MicroBatchScan} instance with a start and end offset, to scan the data within + * this offset range with a Spark job. 
+ */ + MicroBatchScan createMicroBatchScan(Offset start, Offset end); + + /** + * Returns the most recent offset available. + */ + Offset latestOffset(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java deleted file mode 100644 index edb0db11bff2..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.reader.streaming; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.execution.streaming.BaseStreamingSource; -import org.apache.spark.sql.sources.v2.reader.*; - -/** - * An interface that defines how to scan the data from data source for micro-batch streaming - * processing. - * - * The execution engine will get an instance of this interface from a data source provider - * (e.g. {@link org.apache.spark.sql.sources.v2.MicroBatchReadSupportProvider}) at the start of a - * streaming query, then call {@link #newScanConfigBuilder(Offset, Offset)} and create an instance - * of {@link ScanConfig} for each micro-batch. The {@link ScanConfig} will be used to create input - * partitions and reader factory to scan a micro-batch with a Spark job. At the end {@link #stop()} - * will be called when the streaming execution is completed. Note that a single query may have - * multiple executions due to restart or failure recovery. - */ -@InterfaceStability.Evolving -public interface MicroBatchReadSupport extends StreamingReadSupport, BaseStreamingSource { - - /** - * Returns a builder of {@link ScanConfig}. Spark will call this method and create a - * {@link ScanConfig} for each data scanning job. - * - * The builder can take some query specific information to do operators pushdown, store streaming - * offsets, etc., and keep these information in the created {@link ScanConfig}. - * - * This is the first step of the data scan. All other methods in {@link MicroBatchReadSupport} - * needs to take {@link ScanConfig} as an input. - */ - ScanConfigBuilder newScanConfigBuilder(Offset start, Offset end); - - /** - * Returns a factory, which produces one {@link PartitionReader} for one {@link InputPartition}. - */ - PartitionReaderFactory createReaderFactory(ScanConfig config); - - /** - * Returns the most recent offset available. 
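The micro-batch counterpart reduces to the loop below (again an editor's sketch): it assumes the MicroBatchInputStream was created once per query through SupportsMicroBatchRead.createMicroBatchInputStream, and runBatchJob is a placeholder for the Spark job that actually reads the scan.

import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchInputStream, MicroBatchScan, Offset}

// Advances one micro-batch and returns the end offset, which is persisted and reused
// as the start offset of the next micro-batch.
def advanceOneBatch(
    stream: MicroBatchInputStream,
    previousEnd: Option[Offset],
    runBatchJob: MicroBatchScan => Unit): Offset = {
  val start = previousEnd.getOrElse(stream.initialOffset())
  val end = stream.latestOffset()
  val scan = stream.createMicroBatchScan(start, end)
  // scan.planInputPartitions() and scan.createReaderFactory() give the engine everything
  // it needs to run this micro-batch as an ordinary Spark job.
  runBatchJob(scan)
  end
}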
- */ - Offset latestOffset(); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchScan.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchScan.java new file mode 100644 index 000000000000..45d640af5750 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchScan.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader.streaming; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.DataSourceOptions; +import org.apache.spark.sql.sources.v2.SupportsMicroBatchRead; +import org.apache.spark.sql.sources.v2.Table; +import org.apache.spark.sql.sources.v2.reader.*; + +/** + * A {@link Scan} for streaming queries with micro-batch mode. + * + * The execution engine will get an instance of {@link Table} first, then call + * {@link Table#newScanConfigBuilder(DataSourceOptions)} and create an instance of + * {@link ScanConfig}. The {@link ScanConfigBuilder} can apply operator pushdown and keep the + * pushdown result in {@link ScanConfig}. Then + * {@link SupportsMicroBatchRead#createMicroBatchInputStream(String, ScanConfig, DataSourceOptions)} + * will be called to create a {@link MicroBatchInputStream} instance. The + * {@link MicroBatchInputStream} manages offsets and creates a {@link MicroBatchScan} instance for + * each micro-batch. The {@link MicroBatchScan} will be used to create input partitions and + * reader factory to scan a micro-batch with a Spark job. At the end {@link InputStream#stop()} + * will be called when the streaming execution is completed. Note that a single query may have + * multiple executions due to restart or failure recovery. + */ +@InterfaceStability.Evolving +public interface MicroBatchScan extends Scan { + + /** + * Returns a factory, which produces one {@link PartitionReader} for one {@link InputPartition}. + */ + PartitionReaderFactory createReaderFactory(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java index 6cf27734867c..d89c96360af2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java @@ -20,8 +20,8 @@ import org.apache.spark.annotation.InterfaceStability; /** - * An abstract representation of progress through a {@link MicroBatchReadSupport} or - * {@link ContinuousReadSupport}. + * An abstract representation of progress through a {@link MicroBatchScan} or + * {@link ContinuousScan}. 
* During execution, offsets provided by the data source implementation will be logged and used as * restart checkpoints. Each source should provide an offset implementation which the source can use * to reconstruct a position in the stream up to which data has been seen/processed. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 4f6d8b8a0c34..9c9078dfe4e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils -import org.apache.spark.sql.sources.v2.{BatchReadSupportProvider, DataSourceOptions, DataSourceV2} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, Format} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.unsafe.types.UTF8String @@ -193,21 +193,18 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf) - if (classOf[DataSourceV2].isAssignableFrom(cls)) { - val ds = cls.newInstance().asInstanceOf[DataSourceV2] - if (ds.isInstanceOf[BatchReadSupportProvider]) { - val sessionOptions = DataSourceV2Utils.extractSessionConfigs( - ds = ds, conf = sparkSession.sessionState.conf) - val pathsOption = { - val objectMapper = new ObjectMapper() - DataSourceOptions.PATHS_KEY -> objectMapper.writeValueAsString(paths.toArray) - } - Dataset.ofRows(sparkSession, DataSourceV2Relation.create( - ds, sessionOptions ++ extraOptions.toMap + pathsOption, - userSpecifiedSchema = userSpecifiedSchema)) - } else { - loadV1Source(paths: _*) + if (classOf[Format].isAssignableFrom(cls)) { + val format = cls.newInstance().asInstanceOf[Format] + val sessionOptions = DataSourceV2Utils.extractSessionConfigs( + ds = format, conf = sparkSession.sessionState.conf) + val pathsOption = { + val objectMapper = new ObjectMapper() + DataSourceOptions.PATHS_KEY -> objectMapper.writeValueAsString(paths.toArray) } + DataSourceV2Relation.create( + format, sessionOptions ++ extraOptions.toMap + pathsOption, + userSpecifiedSchema = userSpecifiedSchema + ).map(Dataset.ofRows(sparkSession, _)).getOrElse(loadV1Source(paths: _*)) } else { loadV1Source(paths: _*) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 5a28870f5d3c..cdee3de261ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -252,7 +252,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val options = sessionOptions ++ extraOptions if (mode == SaveMode.Append) { - val relation = DataSourceV2Relation.create(source, options) + val relation = DataSourceV2Relation.create(source, options).getOrElse { + throw new AnalysisException(s"data source $source does not support append.") + } runCommand(df.sparkSession, "save") { AppendData.byName(relation, df.logicalPlan) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index f7e29593a635..8a593daf4b58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -24,11 +24,12 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.{AnalysisException, SaveMode} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation} -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{BatchReadSupportProvider, BatchWriteSupportProvider, DataSourceOptions, DataSourceV2} -import org.apache.spark.sql.sources.v2.reader.{BatchReadSupport, ReadSupport, ScanConfigBuilder, SupportsReportStatistics} +import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.reader.{Scan, SupportsReportStatistics} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputStream, InputStream, MicroBatchInputStream, Offset} import org.apache.spark.sql.sources.v2.writer.BatchWriteSupport import org.apache.spark.sql.types.StructType @@ -41,7 +42,7 @@ import org.apache.spark.sql.types.StructType */ case class DataSourceV2Relation( source: DataSourceV2, - readSupport: BatchReadSupport, + table: SupportsBatchRead, output: Seq[AttributeReference], options: Map[String, String], tableIdent: Option[TableIdentifier] = None, @@ -60,12 +61,16 @@ case class DataSourceV2Relation( def newWriteSupport(): BatchWriteSupport = source.createWriteSupport(options, schema) - override def computeStats(): Statistics = readSupport match { - case r: SupportsReportStatistics => - val statistics = r.estimateStatistics(readSupport.newScanConfigBuilder().build()) - Statistics(sizeInBytes = statistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) - case _ => - Statistics(sizeInBytes = conf.defaultSizeInBytes) + override def computeStats(): Statistics = { + val dsOptions = new DataSourceOptions(options.asJava) + val config = table.newScanConfigBuilder(dsOptions).build() + table.createBatchScan(config, dsOptions) match { + case r: SupportsReportStatistics => + val statistics = r.estimateStatistics() + Statistics(sizeInBytes = statistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) + case _ => + Statistics(sizeInBytes = conf.defaultSizeInBytes) + } } override def newInstance(): DataSourceV2Relation = { @@ -81,11 +86,12 @@ case class DataSourceV2Relation( * after we figure out how to apply operator push-down for streaming data sources. */ case class StreamingDataSourceV2Relation( - output: Seq[AttributeReference], + output: Seq[Attribute], source: DataSourceV2, options: Map[String, String], - readSupport: ReadSupport, - scanConfigBuilder: ScanConfigBuilder) + stream: InputStream, + startOffset: Option[Offset] = None, + endOffset: Option[Offset] = None) extends LeafNode with MultiInstanceRelation with DataSourceV2StringFormat { override def isStreaming: Boolean = true @@ -99,8 +105,8 @@ case class StreamingDataSourceV2Relation( // TODO: unify the equal/hashCode implementation for all data source v2 query plans. 
override def equals(other: Any): Boolean = other match { case other: StreamingDataSourceV2Relation => - output == other.output && readSupport.getClass == other.readSupport.getClass && - options == other.options + output == other.output && source.getClass == other.source.getClass && + options == other.options && startOffset == other.startOffset && endOffset == other.endOffset case _ => false } @@ -108,24 +114,30 @@ case class StreamingDataSourceV2Relation( Seq(output, source, options).hashCode() } - override def computeStats(): Statistics = readSupport match { + def createScan(): Scan = (startOffset, endOffset) match { + case (Some(start), Some(end)) => + stream.asInstanceOf[MicroBatchInputStream].createMicroBatchScan(start, end) + case (Some(start), None) => + stream.asInstanceOf[ContinuousInputStream].createContinuousScan(start) + case _ => + throw new IllegalStateException("[BUG] wrong offsets in StreamingDataSourceV2Relation.") + } + + override def computeStats(): Statistics = createScan() match { case r: SupportsReportStatistics => - val statistics = r.estimateStatistics(scanConfigBuilder.build()) + val statistics = r.estimateStatistics() Statistics(sizeInBytes = statistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) - case _ => - Statistics(sizeInBytes = conf.defaultSizeInBytes) + case _ => Statistics(sizeInBytes = conf.defaultSizeInBytes) } } object DataSourceV2Relation { private implicit class SourceHelpers(source: DataSourceV2) { - def asReadSupportProvider: BatchReadSupportProvider = { - source match { - case provider: BatchReadSupportProvider => - provider - case _ => - throw new AnalysisException(s"Data source is not readable: $name") - } + + def asFormat: Format = source match { + case f: Format => f + case _ => + throw new AnalysisException(s"Data source is not readable: $name") } def asWriteSupportProvider: BatchWriteSupportProvider = { @@ -146,15 +158,15 @@ object DataSourceV2Relation { } } - def createReadSupport( + def getTable( options: Map[String, String], - userSpecifiedSchema: Option[StructType]): BatchReadSupport = { + userSpecifiedSchema: Option[StructType]): Table = { val v2Options = new DataSourceOptions(options.asJava) userSpecifiedSchema match { case Some(s) => - asReadSupportProvider.createBatchReadSupport(s, v2Options) + asFormat.getTable(v2Options, s) case _ => - asReadSupportProvider.createBatchReadSupport(v2Options) + asFormat.getTable(v2Options) } } @@ -173,12 +185,17 @@ object DataSourceV2Relation { source: DataSourceV2, options: Map[String, String], tableIdent: Option[TableIdentifier] = None, - userSpecifiedSchema: Option[StructType] = None): DataSourceV2Relation = { - val readSupport = source.createReadSupport(options, userSpecifiedSchema) - val output = readSupport.fullSchema().toAttributes + userSpecifiedSchema: Option[StructType] = None): Option[DataSourceV2Relation] = { + val table = source.getTable(options, userSpecifiedSchema) + val output = table.schema().toAttributes val ident = tableIdent.orElse(tableFromOptions(options)) - DataSourceV2Relation( - source, readSupport, output, options, ident, userSpecifiedSchema) + table match { + case batch: SupportsBatchRead => + Some(DataSourceV2Relation( + source, batch, output, options, ident, userSpecifiedSchema)) + case _ => + None + } } private def tableFromOptions(options: Map[String, String]): Option[TableIdentifier] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 04a97735d024..743fd14174df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -26,18 +26,20 @@ import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeSta import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.DataSourceV2 import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousPartitionReaderFactory, ContinuousReadSupport, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.streaming._ /** * Physical plan node for scanning data from a data source. */ case class DataSourceV2ScanExec( - output: Seq[AttributeReference], + output: Seq[Attribute], @transient source: DataSourceV2, @transient options: Map[String, String], @transient pushedFilters: Seq[Expression], - @transient readSupport: ReadSupport, - @transient scanConfig: ScanConfig) + @transient scan: Scan, + // `ProgressReporter` needs to know which stream a physical scan node associates to, so that + // it can collect metrics for a stream correctly. + @transient stream: Option[InputStream] = None) extends LeafExecNode with DataSourceV2StringFormat with ColumnarBatchScan { override def simpleString: String = "ScanV2 " + metadataString @@ -45,33 +47,31 @@ case class DataSourceV2ScanExec( // TODO: unify the equal/hashCode implementation for all data source v2 query plans. override def equals(other: Any): Boolean = other match { case other: DataSourceV2ScanExec => - output == other.output && readSupport.getClass == other.readSupport.getClass && - options == other.options + output == other.output && source.getClass == other.source.getClass && options == other.options case _ => false } override def hashCode(): Int = { - Seq(output, source, options).hashCode() + Seq(output, source.getClass, options).hashCode() } - override def outputPartitioning: physical.Partitioning = readSupport match { + override def outputPartitioning: physical.Partitioning = scan match { case _ if partitions.length == 1 => SinglePartition case s: SupportsReportPartitioning => - new DataSourcePartitioning( - s.outputPartitioning(scanConfig), AttributeMap(output.map(a => a -> a.name))) + new DataSourcePartitioning(s.outputPartitioning(), AttributeMap(output.map(a => a -> a.name))) case _ => super.outputPartitioning } - private lazy val partitions: Seq[InputPartition] = readSupport.planInputPartitions(scanConfig) + private lazy val partitions: Seq[InputPartition] = scan.planInputPartitions() - private lazy val readerFactory = readSupport match { - case r: BatchReadSupport => r.createReaderFactory(scanConfig) - case r: MicroBatchReadSupport => r.createReaderFactory(scanConfig) - case r: ContinuousReadSupport => r.createContinuousReaderFactory(scanConfig) - case _ => throw new IllegalStateException("unknown read support: " + readSupport) + private lazy val readerFactory = scan match { + case scan: BatchScan => scan.createReaderFactory() + case scan: MicroBatchScan => scan.createReaderFactory() + case scan: ContinuousScan => scan.createContinuousReaderFactory() + case _ => throw new IllegalStateException("unknown read support: " + scan) } // TODO: clean this up when we have dedicated scan plan for continuous streaming. 
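The columnar branch of this node relies on the pre-existing PartitionReaderFactory API rather than anything introduced in this patch; a source could opt in roughly as in the hypothetical sketch below, where loadBatches is a stand-in for real ColumnarBatch construction and must be serializable because the factory is shipped to executors.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.reader.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.vectorized.ColumnarBatch

// Hypothetical factory that advertises columnar reads for every partition.
class ColumnarCapableReaderFactory(loadBatches: InputPartition => Iterator[ColumnarBatch])
  extends PartitionReaderFactory {

  override def supportColumnarReads(partition: InputPartition): Boolean = true

  override def createColumnarReader(partition: InputPartition): PartitionReader[ColumnarBatch] = {
    val batches = loadBatches(partition)
    new PartitionReader[ColumnarBatch] {
      private var current: ColumnarBatch = _
      override def next(): Boolean = {
        if (batches.hasNext) { current = batches.next(); true } else false
      }
      override def get(): ColumnarBatch = current
      override def close(): Unit = ()
    }
  }

  // Row-based path required by the interface; not reached here because
  // supportColumnarReads always returns true for this factory.
  override def createReader(partition: InputPartition): PartitionReader[InternalRow] =
    throw new UnsupportedOperationException("columnar-only reader factory")
}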
@@ -83,8 +83,8 @@ case class DataSourceV2ScanExec( partitions.exists(readerFactory.supportColumnarReads) } - private lazy val inputRDD: RDD[InternalRow] = readSupport match { - case _: ContinuousReadSupport => + private lazy val inputRDD: RDD[InternalRow] = scan match { + case _: ContinuousScan => assert(!supportsBatch, "continuous stream reader does not support columnar read yet.") EpochCoordinatorRef.get( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 9a3109e7c199..42b448e80b8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -17,17 +17,22 @@ package org.apache.spark.sql.execution.datasources.v2 +import scala.collection.JavaConverters._ import scala.collection.mutable +import org.apache.spark.rdd.RDD import org.apache.spark.sql.{sources, Strategy} -import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, AttributeSet, Expression} import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, Repartition} -import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} +import org.apache.spark.sql.execution.{FilterExec, LeafExecNode, ProjectExec, SparkPlan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, MicroBatchExecutionRelation, StreamingExecutionRelation} import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec} +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReadSupport +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputStream, MicroBatchInputStream} object DataSourceV2Strategy extends Strategy { @@ -102,7 +107,8 @@ object DataSourceV2Strategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(project, filters, relation: DataSourceV2Relation) => - val configBuilder = relation.readSupport.newScanConfigBuilder() + val dsOptions = new DataSourceOptions(relation.options.asJava) + val configBuilder = relation.table.newScanConfigBuilder(dsOptions) // `pushedFilters` will be pushed down and evaluated in the underlying data sources. // `postScanFilters` need to be evaluated after the scan. // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter. 
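Stripped of the pushdown negotiation, the batch planning path in this strategy boils down to the short sketch below; table and options mirror DataSourceV2Relation.table and DataSourceV2Relation.options, and the comment marks where the filter/column pushdown handled in the hunk above would plug in.

import scala.collection.JavaConverters._

import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsBatchRead}
import org.apache.spark.sql.sources.v2.reader.BatchScan

// Condensed view of how the planner turns a batch-readable table into a BatchScan.
def planBatchScan(table: SupportsBatchRead, options: Map[String, String]): BatchScan = {
  val dsOptions = new DataSourceOptions(options.asJava)
  val configBuilder = table.newScanConfigBuilder(dsOptions)
  // Operator pushdown (filters, column pruning) would be negotiated with configBuilder here,
  // mirroring the pushedFilters/postScanFilters handling above.
  val config = configBuilder.build()
  table.createBatchScan(config, dsOptions)
}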
@@ -121,8 +127,7 @@ object DataSourceV2Strategy extends Strategy { relation.source, relation.options, pushedFilters, - relation.readSupport, - config) + relation.table.createBatchScan(config, dsOptions)) val filterCondition = postScanFilters.reduceLeftOption(And) val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan) @@ -130,13 +135,41 @@ object DataSourceV2Strategy extends Strategy { // always add the projection, which will produce unsafe rows required by some operators ProjectExec(project, withFilter) :: Nil + // Ideally `StreamingExecutionRelation`, `MicroBatchExecutionRelation` and + // `ContinuousExecutionRelation` are temporary and we don't need to handle them in strategy + // rules. However, the current streaming framework keeps a base logical plan instead of physical + // plan, so we need to do a temp query planning at the beginning to get operator pushdown + // result. Here we catch these temp logical plans, return fake physical plans to report the + // operator pushdown result. + case r: StreamingExecutionRelation => + FakeStreamingScanExec(r.output) :: Nil + + case r: MicroBatchExecutionRelation => + val options = new DataSourceOptions(r.options.asJava) + val configBuilder = r.table.newScanConfigBuilder(options) + // TODO: operator pushdown + val config = configBuilder.build() + val stream = r.table.createMicroBatchInputStream(r.metadataPath, config, options) + FakeMicroBatchExec(r, stream, config.readSchema().toAttributes) :: Nil + + case r: ContinuousExecutionRelation => + val options = new DataSourceOptions(r.options.asJava) + val configBuilder = r.table.newScanConfigBuilder(options) + // TODO: operator pushdown + val config = configBuilder.build() + val stream = r.table.createContinuousInputStream(r.metadataPath, config, options) + FakeContinuousExec(r, stream, config.readSchema().toAttributes) :: Nil + case r: StreamingDataSourceV2Relation => - // TODO: support operator pushdown for streaming data sources. 
- val scanConfig = r.scanConfigBuilder.build() // ensure there is a projection, which will produce unsafe rows required by some operators ProjectExec(r.output, DataSourceV2ScanExec( - r.output, r.source, r.options, r.pushedFilters, r.readSupport, scanConfig)) :: Nil + r.output, + r.source, + r.options, + r.pushedFilters, + r.createScan(), + Some(r.stream))) :: Nil case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil @@ -149,7 +182,7 @@ object DataSourceV2Strategy extends Strategy { case Repartition(1, false, child) => val isContinuous = child.find { - case s: StreamingDataSourceV2Relation => s.readSupport.isInstanceOf[ContinuousReadSupport] + case s: StreamingDataSourceV2Relation => s.stream.isInstanceOf[ContinuousInputStream] case _ => false }.isDefined @@ -162,3 +195,27 @@ object DataSourceV2Strategy extends Strategy { case _ => Nil } } + +case class FakeStreamingScanExec(output: Seq[Attribute]) extends LeafExecNode { + override protected def doExecute(): RDD[InternalRow] = { + throw new IllegalStateException("cannot execute FakeStreamingScanExec") + } +} + +case class FakeMicroBatchExec( + relation: MicroBatchExecutionRelation, + stream: MicroBatchInputStream, + output: Seq[Attribute]) extends LeafExecNode { + override protected def doExecute(): RDD[InternalRow] = { + throw new IllegalStateException("cannot execute FakeMicroBatchExec") + } +} + +case class FakeContinuousExec( + relation: ContinuousExecutionRelation, + stream: ContinuousInputStream, + output: Seq[Attribute]) extends LeafExecNode { + override protected def doExecute(): RDD[InternalRow] = { + throw new IllegalStateException("cannot execute FakeContinuousExec") + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SimpleStreamingScanConfigBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/NoopScanConfigBuilder.scala similarity index 62% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SimpleStreamingScanConfigBuilder.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/NoopScanConfigBuilder.scala index 1be071614d92..56a5477d0c8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SimpleStreamingScanConfigBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/NoopScanConfigBuilder.scala @@ -15,26 +15,15 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.streaming +package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.sources.v2.reader.{ScanConfig, ScanConfigBuilder} import org.apache.spark.sql.types.StructType -/** - * A very simple [[ScanConfigBuilder]] implementation that creates a simple [[ScanConfig]] to - * carry schema and offsets for streaming data sources. 
- */ -class SimpleStreamingScanConfigBuilder( - schema: StructType, - start: Offset, - end: Option[Offset] = None) - extends ScanConfigBuilder { - - override def build(): ScanConfig = SimpleStreamingScanConfig(schema, start, end) +class NoopScanConfigBuilder(schema: StructType) extends ScanConfigBuilder { + override def build(): ScanConfig = new NoopScanConfig(schema) } -case class SimpleStreamingScanConfig( - readSchema: StructType, - start: Offset, - end: Option[Offset]) - extends ScanConfig +class NoopScanConfig(schema: StructType) extends ScanConfig { + override def readSchema(): StructType = schema +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 2cac86599ef1..a99da5bd81bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -17,18 +17,20 @@ package org.apache.spark.sql.execution.streaming +import java.util.IdentityHashMap + import scala.collection.JavaConverters._ import scala.collection.mutable.{Map => MutableMap} import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.encoders.RowEncoder -import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} -import org.apache.spark.sql.execution.streaming.sources.{MicroBatchWritSupport, RateControlMicroBatchReadSupport} +import org.apache.spark.sql.execution.datasources.v2.{FakeMicroBatchExec, StreamingDataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.execution.streaming.sources.{MicroBatchWritSupport, RateControlMicroBatchInputStream} import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset => OffsetV2} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchInputStream, Offset => OffsetV2} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.{Clock, Utils} @@ -49,9 +51,6 @@ class MicroBatchExecution( @volatile protected var sources: Seq[BaseStreamingSource] = Seq.empty - private val readSupportToDataSourceMap = - MutableMap.empty[MicroBatchReadSupport, (DataSourceV2, Map[String, String])] - private val triggerExecutor = trigger match { case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock) case OneTimeTrigger => OneTimeExecutor() @@ -67,6 +66,7 @@ class MicroBatchExecution( var nextSourceId = 0L val toExecutionRelationMap = MutableMap[StreamingRelation, StreamingExecutionRelation]() val v2ToExecutionRelationMap = MutableMap[StreamingRelationV2, StreamingExecutionRelation]() + val v2ToMicroBatchExecutionMap = MutableMap[StreamingRelationV2, MicroBatchExecutionRelation]() // We transform each distinct streaming relation into a StreamingExecutionRelation, keeping a // map as we go to ensure each identical relation gets the same 
StreamingExecutionRelation // object. For each microbatch, the StreamingExecutionRelation will be replaced with a logical @@ -89,21 +89,18 @@ class MicroBatchExecution( StreamingExecutionRelation(source, output)(sparkSession) }) case s @ StreamingRelationV2( - dataSourceV2: MicroBatchReadSupportProvider, sourceName, options, output, _) if - !disabledSources.contains(dataSourceV2.getClass.getCanonicalName) => - v2ToExecutionRelationMap.getOrElseUpdate(s, { - // Materialize source to avoid creating it in every batch + sourceName, ds, table: SupportsMicroBatchRead, options, output, _) + if !disabledSources.contains(ds.getClass.getCanonicalName) => + v2ToMicroBatchExecutionMap.getOrElseUpdate(s, { val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" - val readSupport = dataSourceV2.createMicroBatchReadSupport( - metadataPath, - new DataSourceOptions(options.asJava)) nextSourceId += 1 - readSupportToDataSourceMap(readSupport) = dataSourceV2 -> options - logInfo(s"Using MicroBatchReadSupport [$readSupport] from " + - s"DataSourceV2 named '$sourceName' [$dataSourceV2]") - StreamingExecutionRelation(readSupport, output)(sparkSession) + logInfo(s"Reading table [$table] from " + + s"DataSourceV2 named '$sourceName' [$ds]") + MicroBatchExecutionRelation( + sourceName, ds, table, output, metadataPath, options)(sparkSession) }) - case s @ StreamingRelationV2(dataSourceV2, sourceName, _, output, v1Relation) => + case s @ StreamingRelationV2( + sourceName, ds, _, _, output, v1Relation) => v2ToExecutionRelationMap.getOrElseUpdate(s, { // Materialize source to avoid creating it in every batch val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" @@ -113,13 +110,58 @@ class MicroBatchExecution( } val source = v1Relation.get.dataSource.createSource(metadataPath) nextSourceId += 1 - logInfo(s"Using Source [$source] from DataSourceV2 named '$sourceName' [$dataSourceV2]") + logInfo(s"Using Source [$source] from DataSourceV2 named '$sourceName' [$ds]") StreamingExecutionRelation(source, output)(sparkSession) }) } - sources = _logicalPlan.collect { case s: StreamingExecutionRelation => s.source } + + // This is a temporary query planning, to get operator pushdown result of v2 sources. + // TODO: update the streaming engine to do query planning only once. + val relationToStream = new IdentityHashMap[MicroBatchExecutionRelation, MicroBatchInputStream] + createExecution(_logicalPlan, sparkSession).sparkPlan.foreach { + case exec: FakeMicroBatchExec => + if (relationToStream.containsKey(exec.relation)) { + // This is a self-union/self-join, don't apply operator pushdown, since we want to keep + // one stream instance for the self-unioned/self-joined source. + // TODO: we can push down shared operators to the self-unioned/self-joined sources. 
+ val options = new DataSourceOptions(exec.relation.options.asJava) + val configBuilder = exec.relation.table.newScanConfigBuilder(options) + val config = configBuilder.build() + val stream = exec.relation.table.createMicroBatchInputStream( + exec.relation.metadataPath, config, options) + relationToStream.put(exec.relation, stream) + } else { + relationToStream.put(exec.relation, exec.stream) + } + + case _ => + } + + val finalPlan = _logicalPlan.transform { + case r: MicroBatchExecutionRelation => + val stream = relationToStream.get(r) + assert(stream != null) + StreamingDataSourceV2Relation(r.output, r.ds, r.options, stream) + } + + sources = finalPlan.collect { + case r: StreamingExecutionRelation => r.source + case r: StreamingDataSourceV2Relation => r.stream + } uniqueSources = sources.distinct - _logicalPlan + + finalPlan + } + + private def createExecution(plan: LogicalPlan, session: SparkSession): IncrementalExecution = { + new IncrementalExecution( + session, + plan, + outputMode, + checkpointFile("state"), + runId, + currentBatchId, + offsetSeqMetadata) } /** @@ -341,7 +383,7 @@ class MicroBatchExecution( reportTimeTaken("getOffset") { (s, s.getOffset) } - case s: RateControlMicroBatchReadSupport => + case s: RateControlMicroBatchInputStream => updateStatusMessage(s"Getting offsets from $s") reportTimeTaken("latestOffset") { val startOffset = availableOffsets @@ -349,7 +391,7 @@ class MicroBatchExecution( .getOrElse(s.initialOffset()) (s, Option(s.latestOffset(startOffset))) } - case s: MicroBatchReadSupport => + case s: MicroBatchInputStream => updateStatusMessage(s"Getting offsets from $s") reportTimeTaken("latestOffset") { (s, Option(s.latestOffset())) @@ -393,8 +435,8 @@ class MicroBatchExecution( if (prevBatchOff.isDefined) { prevBatchOff.get.toStreamProgress(sources).foreach { case (src: Source, off) => src.commit(off) - case (readSupport: MicroBatchReadSupport, off) => - readSupport.commit(readSupport.deserializeOffset(off.json)) + case (stream: MicroBatchInputStream, off) => + stream.commit(stream.deserializeOffset(off.json)) case (src, _) => throw new IllegalArgumentException( s"Unknown source is found at constructNextBatch: $src") @@ -439,39 +481,29 @@ class MicroBatchExecution( logDebug(s"Retrieving data from $source: $current -> $available") Some(source -> batch.logicalPlan) - // TODO(cloud-fan): for data source v2, the new batch is just a new `ScanConfigBuilder`, but - // to be compatible with streaming source v1, we return a logical plan as a new batch here. - case (readSupport: MicroBatchReadSupport, available) - if committedOffsets.get(readSupport).map(_ != available).getOrElse(true) => - val current = committedOffsets.get(readSupport).map { - off => readSupport.deserializeOffset(off.json) + case (stream: MicroBatchInputStream, available) + if committedOffsets.get(stream).map(_ != available).getOrElse(true) => + val current = committedOffsets.get(stream).map { + off => stream.deserializeOffset(off.json) } val endOffset: OffsetV2 = available match { - case v1: SerializedOffset => readSupport.deserializeOffset(v1.json) + case v1: SerializedOffset => stream.deserializeOffset(v1.json) case v2: OffsetV2 => v2 } - val startOffset = current.getOrElse(readSupport.initialOffset) - val scanConfigBuilder = readSupport.newScanConfigBuilder(startOffset, endOffset) - logDebug(s"Retrieving data from $readSupport: $current -> $endOffset") - - val (source, options) = readSupport match { - // `MemoryStream` is special. 
It's for test only and doesn't have a `DataSourceV2` - // implementation. We provide a fake one here for explain. - case _: MemoryStream[_] => MemoryStreamDataSource -> Map.empty[String, String] - // Provide a fake value here just in case something went wrong, e.g. the reader gives - // a wrong `equals` implementation. - case _ => readSupportToDataSourceMap.getOrElse(readSupport, { - FakeDataSourceV2 -> Map.empty[String, String] - }) - } - Some(readSupport -> StreamingDataSourceV2Relation( - readSupport.fullSchema().toAttributes, source, options, readSupport, scanConfigBuilder)) + val startOffset = current.getOrElse(stream.initialOffset) + logInfo(s"Retrieving data from $stream: $startOffset -> $endOffset") + + // To be compatible with the v1 source, the `newData` is represented as a logical plan, + // while the `newData` of v2 source is just the start and end offsets. Here we return a + // fake logical plan to carry the offsets. + Some(stream -> OffsetHolder(startOffset, endOffset)) case _ => None } } // Replace sources in the logical plan with data that has arrived since the last batch. val newBatchesPlan = logicalPlan transform { + // For v1 sources. case StreamingExecutionRelation(source, output) => newData.get(source).map { dataPlan => assert(output.size == dataPlan.output.size, @@ -485,6 +517,15 @@ class MicroBatchExecution( }.getOrElse { LocalRelation(output, isStreaming = true) } + + // For v2 sources. + case r: StreamingDataSourceV2Relation => + newData.get(r.stream).map { + case OffsetHolder(start, end) => + r.copy(startOffset = Some(start), endOffset = Some(end)) + }.getOrElse { + LocalRelation(r.output, isStreaming = true) + } } // Rewire the plan to use the new attributes that were returned by the source. @@ -497,7 +538,7 @@ class MicroBatchExecution( cd.dataType, cd.timeZoneId) } - val triggerLogicalPlan = sink match { + val planWithSink = sink match { case _: Sink => newAttributePlan case s: StreamingWriteSupportProvider => val writer = s.createStreamingWriteSupport( @@ -515,14 +556,7 @@ class MicroBatchExecution( StreamExecution.IS_CONTINUOUS_PROCESSING, false.toString) reportTimeTaken("queryPlanning") { - lastExecution = new IncrementalExecution( - sparkSessionToRunBatch, - triggerLogicalPlan, - outputMode, - checkpointFile("state"), - runId, - currentBatchId, - offsetSeqMetadata) + lastExecution = createExecution(planWithSink, sparkSessionToRunBatch) lastExecution.executedPlan // Force the lazy generation of execution plan } @@ -563,6 +597,6 @@ object MicroBatchExecution { val BATCH_ID_KEY = "streaming.sql.batchId" } -object MemoryStreamDataSource extends DataSourceV2 - -object FakeDataSourceV2 extends DataSourceV2 +case class OffsetHolder(start: OffsetV2, end: OffsetV2) extends LeafNode { + override def output: Seq[Attribute] = Nil +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index 392229bcb5f5..78d50e6111e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -28,8 +28,8 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.QueryExecution -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec 
-import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReadSupport +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, StreamingDataSourceV2Relation} +import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchInputStream import org.apache.spark.sql.streaming._ import org.apache.spark.sql.streaming.StreamingQueryListener.QueryProgressEvent import org.apache.spark.util.Clock @@ -245,10 +245,12 @@ trait ProgressReporter extends Logging { } val onlyDataSourceV2Sources = { - // Check whether the streaming query's logical plan has only V2 data sources - val allStreamingLeaves = - logicalPlan.collect { case s: StreamingExecutionRelation => s } - allStreamingLeaves.forall { _.source.isInstanceOf[MicroBatchReadSupport] } + // Check whether the streaming query's logical plan has only V2 micro-batch data sources + val allStreamingLeaves = logicalPlan.collect { + case s: StreamingDataSourceV2Relation => s.stream.isInstanceOf[MicroBatchInputStream] + case _: StreamingExecutionRelation => false + } + allStreamingLeaves.forall(_ == true) } if (onlyDataSourceV2Sources) { @@ -256,9 +258,9 @@ trait ProgressReporter extends Logging { // (can happen with self-unions or self-joins). This means the source is scanned multiple // times in the query, we should count the numRows for each scan. val sourceToInputRowsTuples = lastExecution.executedPlan.collect { - case s: DataSourceV2ScanExec if s.readSupport.isInstanceOf[BaseStreamingSource] => + case s: DataSourceV2ScanExec if s.stream.isDefined => val numRows = s.metrics.get("numOutputRows").map(_.value).getOrElse(0L) - val source = s.readSupport.asInstanceOf[BaseStreamingSource] + val source = s.stream.get source -> numRows } logDebug("Source -> # input rows\n\t" + sourceToInputRowsTuples.mkString("\n\t")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index 4b696dfa5735..b603f3d057b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -25,7 +25,8 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceV2} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2._ object StreamingRelation { def apply(dataSource: DataSource): StreamingRelation = { @@ -81,6 +82,54 @@ case class StreamingExecutionRelation( override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance()))(session) } +case class MicroBatchExecutionRelation( + source: String, + ds: DataSourceV2, + table: SupportsMicroBatchRead, + output: Seq[Attribute], + metadataPath: String, + options: Map[String, String])(session: SparkSession) + extends LeafNode with MultiInstanceRelation { + + override def otherCopyArgs: Seq[AnyRef] = session :: Nil + override def isStreaming: Boolean = true + override def toString: String = source + + // There's no sensible value here. On the execution path, this relation will be swapped out with + // `StreamingDataSourceV2Relation`. 
But some dataframe operations (in particular explain) do lead + // to this node surviving analysis. So we satisfy the LeafNode contract with the session default + // value. + override def computeStats(): Statistics = Statistics( + sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) + ) + + override def newInstance(): LogicalPlan = copy(output = output.map(_.newInstance()))(session) +} + +case class ContinuousExecutionRelation( + source: String, + ds: DataSourceV2, + table: SupportsContinuousRead, + output: Seq[Attribute], + metadataPath: String, + options: Map[String, String])(session: SparkSession) + extends LeafNode with MultiInstanceRelation { + + override def otherCopyArgs: Seq[AnyRef] = session :: Nil + override def isStreaming: Boolean = true + override def toString: String = source + + // There's no sensible value here. On the execution path, this relation will be swapped out with + // `StreamingDataSourceV2Relation`. But some dataframe operations (in particular explain) do lead + // to this node surviving analysis. So we satisfy the LeafNode contract with the session default + // value. + override def computeStats(): Statistics = Statistics( + sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) + ) + + override def newInstance(): LogicalPlan = copy(output = output.map(_.newInstance()))(session) +} + // We have to pack in the V1 data source as a shim, for the case when a source implements // continuous processing (which is always V2) but only has V1 microbatch support. We don't // know at read time whether the query is continuous or not, so we need to be able to @@ -92,8 +141,9 @@ case class StreamingExecutionRelation( * and should be converted before passing to [[StreamExecution]]. */ case class StreamingRelationV2( - dataSource: DataSourceV2, sourceName: String, + dataSource: DataSourceV2, + table: Table, extraOptions: Map[String, String], output: Seq[Attribute], v1Relation: Option[StreamingRelation])(session: SparkSession) @@ -109,30 +159,6 @@ case class StreamingRelationV2( override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance()))(session) } -/** - * Used to link a [[DataSourceV2]] into a continuous processing execution. - */ -case class ContinuousExecutionRelation( - source: ContinuousReadSupportProvider, - extraOptions: Map[String, String], - output: Seq[Attribute])(session: SparkSession) - extends LeafNode with MultiInstanceRelation { - - override def otherCopyArgs: Seq[AnyRef] = session :: Nil - override def isStreaming: Boolean = true - override def toString: String = source.toString - - // There's no sensible value here. On the execution path, this relation will be - // swapped out with microbatches. But some dataframe operations (in particular explain) do lead - // to this node surviving analysis. So we satisfy the LeafNode contract with the session default - // value. 
- override def computeStats(): Statistics = Statistics( - sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) - ) - - override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance()))(session) -} - /** * A dummy physical plan for [[StreamingRelation]] to support * [[org.apache.spark.sql.Dataset.explain]] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index f009c52449ad..edeb189886c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -17,25 +17,26 @@ package org.apache.spark.sql.execution.streaming.continuous +import java.util.IdentityHashMap import java.util.UUID import java.util.concurrent.TimeUnit import java.util.function.UnaryOperator import scala.collection.JavaConverters._ -import scala.collection.mutable.{ArrayBuffer, Map => MutableMap} +import scala.collection.mutable.{Map => MutableMap} import org.apache.spark.SparkEnv import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentDate, CurrentTimestamp} +import org.apache.spark.sql.catalyst.expressions.{CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, StreamingDataSourceV2Relation} -import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} +import org.apache.spark.sql.execution.datasources.v2.{FakeContinuousExec, StreamingDataSourceV2Relation} +import org.apache.spark.sql.execution.streaming.{StreamingRelationV2, _} import org.apache.spark.sql.sources.v2 -import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, StreamingWriteSupportProvider} -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, PartitionOffset} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamingWriteSupportProvider, SupportsContinuousRead} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputStream, InputStream, PartitionOffset} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} -import org.apache.spark.util.{Clock, Utils} +import org.apache.spark.util.Clock class ContinuousExecution( sparkSession: SparkSession, @@ -52,25 +53,74 @@ class ContinuousExecution( sparkSession, name, checkpointRoot, analyzedPlan, sink, trigger, triggerClock, outputMode, deleteCheckpointOnStop) { - @volatile protected var continuousSources: Seq[ContinuousReadSupport] = Seq() - override protected def sources: Seq[BaseStreamingSource] = continuousSources + @volatile protected var sources: Seq[InputStream] = Seq.empty // For use only in test harnesses. private[sql] var currentEpochCoordinatorId: String = _ override val logicalPlan: LogicalPlan = { - val toExecutionRelationMap = MutableMap[StreamingRelationV2, ContinuousExecutionRelation]() - analyzedPlan.transform { - case r @ StreamingRelationV2( - source: ContinuousReadSupportProvider, _, extraReaderOptions, output, _) => - // TODO: shall we create `ContinuousReadSupport` here instead of each reconfiguration? 
- toExecutionRelationMap.getOrElseUpdate(r, { - ContinuousExecutionRelation(source, extraReaderOptions, output)(sparkSession) + val v2ToContinuousExecutionMap = MutableMap[StreamingRelationV2, ContinuousExecutionRelation]() + var nextSourceId = 0 + val _logicalPlan = analyzedPlan.transform { + case s @ StreamingRelationV2( + sourceName, ds, table: SupportsContinuousRead, options, output, _) => + v2ToContinuousExecutionMap.getOrElseUpdate(s, { + val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" + nextSourceId += 1 + ContinuousExecutionRelation( + sourceName, ds, table, output, metadataPath, options)(sparkSession) }) - case StreamingRelationV2(_, sourceName, _, _, _) => + case r: StreamingRelationV2 => throw new UnsupportedOperationException( - s"Data source $sourceName does not support continuous processing.") + s"Data source ${r.sourceName} does not support continuous processing.") + } + + // This is a temporary query planning, to get operator pushdown result of v2 sources. + // TODO: update the streaming engine to do query planning only once. + val relationToStream = new IdentityHashMap[ContinuousExecutionRelation, ContinuousInputStream] + createExecution(_logicalPlan, sparkSession).sparkPlan.foreach { + case exec: FakeContinuousExec => + if (relationToStream.containsKey(exec.relation)) { + // This is a self-union/self-join, don't apply operator pushdown, since we want to keep + // one stream instance for the self-unioned/self-joined source. + // TODO: we can push down shared operators to the self-unioned/self-joined sources. + val options = new DataSourceOptions(exec.relation.options.asJava) + val configBuilder = exec.relation.table.newScanConfigBuilder(options) + val config = configBuilder.build() + val stream = exec.relation.table.createContinuousInputStream( + exec.relation.metadataPath, config, options) + relationToStream.put(exec.relation, stream) + } else { + relationToStream.put(exec.relation, exec.stream) + } + + case _ => } + + val finalPlan = _logicalPlan.transform { + case r: ContinuousExecutionRelation => + val stream = relationToStream.get(r) + assert(stream != null) + StreamingDataSourceV2Relation(r.output, r.ds, r.options, stream) + } + + sources = finalPlan.collect { + case r: StreamingDataSourceV2Relation => r.stream + } + uniqueSources = sources.distinct + + finalPlan + } + + private def createExecution(plan: LogicalPlan, session: SparkSession): IncrementalExecution = { + new IncrementalExecution( + session, + plan, + outputMode, + checkpointFile("state"), + runId, + currentBatchId, + offsetSeqMetadata) } private val triggerExecutor = trigger match { @@ -90,6 +140,8 @@ class ContinuousExecution( do { runContinuous(sparkSessionForStream) } while (state.updateAndGet(stateUpdate) == ACTIVE) + + stopSources() } /** @@ -130,7 +182,7 @@ class ContinuousExecution( // We are starting this stream for the first time. Offsets are all None. logInfo(s"Starting new streaming query.") currentBatchId = 0 - OffsetSeq.fill(continuousSources.map(_ => null): _*) + OffsetSeq.fill(sources.map(_ => null): _*) } } @@ -139,47 +191,17 @@ class ContinuousExecution( * @param sparkSessionForQuery Isolated [[SparkSession]] to run the continuous query with. */ private def runContinuous(sparkSessionForQuery: SparkSession): Unit = { - // A list of attributes that will need to be updated. - val replacements = new ArrayBuffer[(Attribute, Attribute)] - // Translate from continuous relation to the underlying data source. 
- var nextSourceId = 0 - continuousSources = logicalPlan.collect { - case ContinuousExecutionRelation(dataSource, extraReaderOptions, output) => - val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" - nextSourceId += 1 - - dataSource.createContinuousReadSupport( - metadataPath, - new DataSourceOptions(extraReaderOptions.asJava)) - } - uniqueSources = continuousSources.distinct - val offsets = getStartOffsets(sparkSessionForQuery) - var insertedSourceId = 0 - val withNewSources = logicalPlan transform { - case ContinuousExecutionRelation(source, options, output) => - val readSupport = continuousSources(insertedSourceId) - insertedSourceId += 1 - val newOutput = readSupport.fullSchema().toAttributes - - assert(output.size == newOutput.size, - s"Invalid reader: ${Utils.truncatedString(output, ",")} != " + - s"${Utils.truncatedString(newOutput, ",")}") - replacements ++= output.zip(newOutput) - + val withNewSources: LogicalPlan = logicalPlan transform { + case relation: StreamingDataSourceV2Relation => val loggedOffset = offsets.offsets(0) - val realOffset = loggedOffset.map(off => readSupport.deserializeOffset(off.json)) - val startOffset = realOffset.getOrElse(readSupport.initialOffset) - val scanConfigBuilder = readSupport.newScanConfigBuilder(startOffset) - StreamingDataSourceV2Relation(newOutput, source, options, readSupport, scanConfigBuilder) + val realOffset = loggedOffset.map(off => relation.stream.deserializeOffset(off.json)) + val startOffset = realOffset.getOrElse(relation.stream.initialOffset) + relation.copy(startOffset = Some(startOffset)) } - // Rewire the plan to use the new attributes that were returned by the source. - val replacementMap = AttributeMap(replacements) - val triggerLogicalPlan = withNewSources transformAllExpressions { - case a: Attribute if replacementMap.contains(a) => - replacementMap(a).withMetadata(a.metadata) + withNewSources transformAllExpressions { case (_: CurrentTimestamp | _: CurrentDate) => throw new IllegalStateException( "CurrentTimestamp and CurrentDate not yet supported for continuous processing") @@ -187,26 +209,19 @@ class ContinuousExecution( val writer = sink.createStreamingWriteSupport( s"$runId", - triggerLogicalPlan.schema, + withNewSources.schema, outputMode, new DataSourceOptions(extraOptions.asJava)) - val withSink = WriteToContinuousDataSource(writer, triggerLogicalPlan) + val planWithSink = WriteToContinuousDataSource(writer, withNewSources) reportTimeTaken("queryPlanning") { - lastExecution = new IncrementalExecution( - sparkSessionForQuery, - withSink, - outputMode, - checkpointFile("state"), - runId, - currentBatchId, - offsetSeqMetadata) + lastExecution = createExecution(planWithSink, sparkSessionForQuery) lastExecution.executedPlan // Force the lazy generation of execution plan } - val (readSupport, scanConfig) = lastExecution.executedPlan.collect { - case scan: DataSourceV2ScanExec if scan.readSupport.isInstanceOf[ContinuousReadSupport] => - scan.readSupport.asInstanceOf[ContinuousReadSupport] -> scan.scanConfig + val stream = planWithSink.collect { + case relation: StreamingDataSourceV2Relation => + relation.stream.asInstanceOf[ContinuousInputStream] }.head sparkSessionForQuery.sparkContext.setLocalProperty( @@ -226,16 +241,14 @@ class ContinuousExecution( // Use the parent Spark session for the endpoint since it's where this query ID is registered. 
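// A minimal sketch, not taken from the diff itself: the call sequence ContinuousExecution drives
// against a ContinuousInputStream after this refactoring, using only methods that appear in the
// new API above. The helper name `driveOneEpoch` and its parameters are hypothetical placeholders.
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputStream, ContinuousScan, PartitionOffset}

def driveOneEpoch(
    stream: ContinuousInputStream,
    loggedOffsetJson: Option[String],
    reportedPartitionOffsets: Array[PartitionOffset]): Unit = {
  // 1. Resolve the start offset, either from the offset log or from the stream itself.
  val start = loggedOffsetJson.map(stream.deserializeOffset).getOrElse(stream.initialOffset)
  // 2. Ask the stream for a scan; the scan owns partition planning and reader-factory creation.
  val scan: ContinuousScan = stream.createContinuousScan(start)
  val partitions = scan.planInputPartitions()
  val readerFactory = scan.createContinuousReaderFactory()
  // 3. While running, the trigger thread polls for reconfiguration instead of replanning eagerly,
  //    mirroring the needsReconfiguration check in runContinuous().
  if (stream.needsReconfiguration) {
    // interrupt the query thread and restart runContinuous(), as the surrounding code does
  }
  // 4. Once every partition of an epoch has reported an offset, the offsets are merged and
  //    committed back to the stream.
  val epochEnd = stream.mergeOffsets(reportedPartitionOffsets)
  stream.commit(epochEnd)
}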
val epochEndpoint = EpochCoordinatorRef.create( - writer, readSupport, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get) + writer, stream, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get) val epochUpdateThread = new Thread(new Runnable { override def run: Unit = { try { triggerExecutor.execute(() => { startTrigger() - val shouldReconfigure = readSupport.needsReconfiguration(scanConfig) && - state.compareAndSet(ACTIVE, RECONFIGURING) - if (shouldReconfigure) { + if (stream.needsReconfiguration && state.compareAndSet(ACTIVE, RECONFIGURING)) { if (queryExecutionThread.isAlive) { queryExecutionThread.interrupt() } @@ -276,7 +289,6 @@ class ContinuousExecution( epochUpdateThread.interrupt() epochUpdateThread.join() - stopSources() sparkSession.sparkContext.cancelJobGroup(runId.toString) } } @@ -286,11 +298,11 @@ class ContinuousExecution( */ def addOffset( epoch: Long, - readSupport: ContinuousReadSupport, + stream: ContinuousInputStream, partitionOffsets: Seq[PartitionOffset]): Unit = { - assert(continuousSources.length == 1, "only one continuous source supported currently") + assert(sources.length == 1, "only one continuous source supported currently") - val globalOffset = readSupport.mergeOffsets(partitionOffsets.toArray) + val globalOffset = stream.mergeOffsets(partitionOffsets.toArray) val oldOffset = synchronized { offsetLog.add(epoch, OffsetSeq.fill(globalOffset)) offsetLog.get(epoch - 1) @@ -314,7 +326,7 @@ class ContinuousExecution( * before this is called. */ def commit(epoch: Long): Unit = { - assert(continuousSources.length == 1, "only one continuous source supported currently") + assert(sources.length == 1, "only one continuous source supported currently") assert(offsetLog.get(epoch).isDefined, s"offset for epoch $epoch not reported before commit") synchronized { @@ -323,9 +335,9 @@ class ContinuousExecution( if (queryExecutionThread.isAlive) { commitLog.add(epoch, CommitMetadata()) val offset = - continuousSources(0).deserializeOffset(offsetLog.get(epoch).get.offsets(0).get.json) - committedOffsets ++= Seq(continuousSources(0) -> offset) - continuousSources(0).commit(offset.asInstanceOf[v2.reader.streaming.Offset]) + sources(0).deserializeOffset(offsetLog.get(epoch).get.offsets(0).get.json) + committedOffsets ++= Seq(sources(0) -> offset) + sources(0).commit(offset.asInstanceOf[v2.reader.streaming.Offset]) } else { return } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index a6cde2b8a710..3b6201049d4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -22,17 +22,16 @@ import org.json4s.jackson.Serialization import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming.{RateStreamOffset, SimpleStreamingScanConfig, SimpleStreamingScanConfigBuilder, ValueRunTimeMsPair} +import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ import 
org.apache.spark.sql.sources.v2.reader.streaming._ -import org.apache.spark.sql.types.StructType case class RateStreamPartitionOffset( partition: Int, currentValue: Long, currentTimeMs: Long) extends PartitionOffset -class RateStreamContinuousReadSupport(options: DataSourceOptions) extends ContinuousReadSupport { +class RateStreamContinuousInputStream(options: DataSourceOptions) extends ContinuousInputStream { implicit val defaultFormats: DefaultFormats = DefaultFormats val creationTime = System.currentTimeMillis() @@ -54,18 +53,36 @@ class RateStreamContinuousReadSupport(options: DataSourceOptions) extends Contin RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) } - override def fullSchema(): StructType = RateStreamProvider.SCHEMA + override def initialOffset: Offset = createInitialOffset(numPartitions, creationTime) - override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = { - new SimpleStreamingScanConfigBuilder(fullSchema(), start) + override def createContinuousScan(start: Offset): ContinuousScan = { + new RateStreamContinuousScan(numPartitions, perPartitionRate, start) } - override def initialOffset: Offset = createInitialOffset(numPartitions, creationTime) + override def commit(end: Offset): Unit = {} + override def stop(): Unit = {} - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { - val startOffset = config.asInstanceOf[SimpleStreamingScanConfig].start + private def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = { + RateStreamOffset(Range(0, numPartitions).map { i => + // Note that the starting offset is exclusive, so we have to decrement the starting value by + // the increment that will later be applied. The first row output in each partition will have + // a value equal to the partition index. + (i, ValueRunTimeMsPair((i - numPartitions).toLong, creationTimeMs)) + }.toMap) + } +} + +class RateStreamContinuousScan( + numPartitions: Int, + perPartitionRate: Double, + start: Offset) extends ContinuousScan { + + override def createContinuousReaderFactory(): ContinuousPartitionReaderFactory = { + RateStreamContinuousReaderFactory + } - val partitionStartMap = startOffset match { + override def planInputPartitions(): Array[InputPartition] = { + val partitionStartMap = start match { case off: RateStreamOffset => off.partitionToValueAndRunTimeMs case off => throw new IllegalArgumentException( @@ -74,8 +91,8 @@ class RateStreamContinuousReadSupport(options: DataSourceOptions) extends Contin if (partitionStartMap.keySet.size != numPartitions) { throw new IllegalArgumentException( s"The previous run contained ${partitionStartMap.keySet.size} partitions, but" + - s" $numPartitions partitions are currently configured. The numPartitions option" + - " cannot be changed.") + s" $numPartitions partitions are currently configured. 
The numPartitions option" + + " cannot be changed.") } Range(0, numPartitions).map { i => @@ -90,28 +107,6 @@ class RateStreamContinuousReadSupport(options: DataSourceOptions) extends Contin perPartitionRate) }.toArray } - - override def createContinuousReaderFactory( - config: ScanConfig): ContinuousPartitionReaderFactory = { - RateStreamContinuousReaderFactory - } - - override def commit(end: Offset): Unit = {} - override def stop(): Unit = {} - - private def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = { - RateStreamOffset( - Range(0, numPartitions).map { i => - // Note that the starting offset is exclusive, so we have to decrement the starting value - // by the increment that will later be applied. The first row output in each - // partition will have a value equal to the partition index. - (i, - ValueRunTimeMsPair( - (i - numPartitions).toLong, - creationTimeMs)) - }.toMap) - } - } case class RateStreamContinuousInputPartition( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala index 28ab2448a663..38b66a172e5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala @@ -38,20 +38,18 @@ import org.apache.spark.sql.execution.streaming.sources.TextSocketReader import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming._ -import org.apache.spark.sql.types.StructType import org.apache.spark.util.RpcUtils - /** - * A ContinuousReadSupport that reads text lines through a TCP socket, designed only for tutorials - * and debugging. This ContinuousReadSupport will *not* work in production applications due to + * A ContinuousInputStream that reads text lines through a TCP socket, designed only for tutorials + * and debugging. This ContinuousInputStream will *not* work in production applications due to * multiple reasons, including no support for fault recovery. * * The driver maintains a socket connection to the host-port, keeps the received messages in * buckets and serves the messages to the executors via a RPC endpoint. 
*/ -class TextSocketContinuousReadSupport(options: DataSourceOptions) - extends ContinuousReadSupport with Logging { +class TextSocketContinuousInputStream(options: DataSourceOptions) + extends ContinuousInputStream with Logging { implicit val defaultFormats: DefaultFormats = DefaultFormats @@ -60,7 +58,7 @@ class TextSocketContinuousReadSupport(options: DataSourceOptions) assert(SparkSession.getActiveSession.isDefined) private val spark = SparkSession.getActiveSession.get - private val numPartitions = spark.sparkContext.defaultParallelism + private val numPartitions: Int = spark.sparkContext.defaultParallelism @GuardedBy("this") private var socket: Socket = _ @@ -101,21 +99,8 @@ class TextSocketContinuousReadSupport(options: DataSourceOptions) startOffset } - override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = { - new SimpleStreamingScanConfigBuilder(fullSchema(), start) - } - - override def fullSchema(): StructType = { - if (includeTimestamp) { - TextSocketReader.SCHEMA_TIMESTAMP - } else { - TextSocketReader.SCHEMA_REGULAR - } - } - - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { - val startOffset = config.asInstanceOf[SimpleStreamingScanConfig] - .start.asInstanceOf[TextSocketOffset] + override def createContinuousScan(start: Offset): ContinuousScan = { + val startOffset = start.asInstanceOf[TextSocketOffset] recordEndpoint.setStartOffsets(startOffset.offsets) val endpointName = s"TextSocketContinuousReaderEndpoint-${java.util.UUID.randomUUID()}" endpointRef = recordEndpoint.rpcEnv.setupEndpoint(endpointName, recordEndpoint) @@ -134,15 +119,12 @@ class TextSocketContinuousReadSupport(options: DataSourceOptions) " cannot be changed.") } - startOffset.offsets.zipWithIndex.map { + val partitions: Array[InputPartition] = startOffset.offsets.zipWithIndex.map { case (offset, i) => TextSocketContinuousInputPartition(endpointName, i, offset, includeTimestamp) }.toArray - } - override def createContinuousReaderFactory( - config: ScanConfig): ContinuousPartitionReaderFactory = { - TextSocketReaderFactory + new TextSocketContinuousScan(partitions) } override def commit(end: Offset): Unit = synchronized { @@ -157,7 +139,7 @@ class TextSocketContinuousReadSupport(options: DataSourceOptions) val max = startOffset.offsets(partition) + buckets(partition).size if (offset > max) { throw new IllegalStateException("Invalid offset " + offset + " to commit" + - " for partition " + partition + ". Max valid offset: " + max) + " for partition " + partition + ". 
Max valid offset: " + max) } val n = offset - startOffset.offsets(partition) buckets(partition).trimStart(n) @@ -197,7 +179,7 @@ class TextSocketContinuousReadSupport(options: DataSourceOptions) logWarning(s"Stream closed by $host:$port") return } - TextSocketContinuousReadSupport.this.synchronized { + TextSocketContinuousInputStream.this.synchronized { currentOffset += 1 val newData = (line, Timestamp.valueOf( @@ -218,9 +200,20 @@ class TextSocketContinuousReadSupport(options: DataSourceOptions) override def toString: String = s"TextSocketContinuousReader[host: $host, port: $port]" private def includeTimestamp: Boolean = options.getBoolean("includeTimestamp", false) +} + +class TextSocketContinuousScan(partitions: Array[InputPartition]) extends ContinuousScan { + + override def createContinuousReaderFactory(): ContinuousPartitionReaderFactory = { + TextSocketContinuousReaderFactory + } + override def planInputPartitions(): Array[InputPartition] = { + partitions + } } + /** * Continuous text socket input partition. */ @@ -231,7 +224,7 @@ case class TextSocketContinuousInputPartition( includeTimestamp: Boolean) extends InputPartition -object TextSocketReaderFactory extends ContinuousPartitionReaderFactory { +object TextSocketContinuousReaderFactory extends ContinuousPartitionReaderFactory { override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = { val p = partition.asInstanceOf[TextSocketContinuousInputPartition] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index 2238ce26e7b4..e4ceeef28ca1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -23,7 +23,7 @@ import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputStream, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport import org.apache.spark.util.RpcUtils @@ -83,14 +83,14 @@ private[sql] object EpochCoordinatorRef extends Logging { */ def create( writeSupport: StreamingWriteSupport, - readSupport: ContinuousReadSupport, + inputStream: ContinuousInputStream, query: ContinuousExecution, epochCoordinatorId: String, startEpoch: Long, session: SparkSession, env: SparkEnv): RpcEndpointRef = synchronized { val coordinator = new EpochCoordinator( - writeSupport, readSupport, query, startEpoch, session, env.rpcEnv) + writeSupport, inputStream, query, startEpoch, session, env.rpcEnv) val ref = env.rpcEnv.setupEndpoint(endpointName(epochCoordinatorId), coordinator) logInfo("Registered EpochCoordinator endpoint") ref @@ -116,7 +116,7 @@ private[sql] object EpochCoordinatorRef extends Logging { */ private[continuous] class EpochCoordinator( writeSupport: StreamingWriteSupport, - readSupport: ContinuousReadSupport, + inputStream: ContinuousInputStream, query: ContinuousExecution, startEpoch: Long, session: SparkSession, @@ -220,7 +220,7 @@ private[continuous] 
class EpochCoordinator( partitionOffsets.collect { case ((e, _), o) if e == epoch => o } if (thisEpochOffsets.size == numReaderPartitions) { logDebug(s"Epoch $epoch has offsets reported from all partitions: $thisEpochOffsets") - query.addOffset(epoch, readSupport, thisEpochOffsets.toSeq) + query.addOffset(epoch, inputStream, thisEpochOffsets.toSeq) resolveCommitsAtEpoch(epoch) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index adf52aba21a0..03549e72b625 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -28,11 +28,12 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} -import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset => OffsetV2} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputStream, InputStream, MicroBatchInputStream, MicroBatchScan, Offset => OffsetV2} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -45,11 +46,17 @@ object MemoryStream { new MemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext) } +// This class is used to indicate the memory stream data source. We don't actually use it, as +// memory stream is for test only and we never look it up by name. +object MemoryStreamSource extends DataSourceV2 + /** * A base class for memory stream implementations. Supports adding data and resetting. 
*/ -abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends BaseStreamingSource { - protected val encoder = encoderFor[A] +abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) + extends BaseStreamingSource with InputStream { + + val encoder = encoderFor[A] protected val attributes = encoder.schema.toAttributes def toDS(): Dataset[A] = { @@ -64,24 +71,44 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Bas addData(data.toTraversable) } - def fullSchema(): StructType = encoder.schema - - protected def logicalPlan: LogicalPlan + protected val logicalPlan = StreamingRelationV2( + "memory", + MemoryStreamSource, + new MemoryStreamTable(this), + Map.empty, + attributes, + None)(sqlContext.sparkSession) def addData(data: TraversableOnce[A]): Offset } +class MemoryStreamTable(val stream: MemoryStreamBase[_]) extends Table + with SupportsMicroBatchRead with SupportsContinuousRead { + + override def schema(): StructType = stream.encoder.schema + + override def createMicroBatchInputStream( + checkpointLocation: String, + config: ScanConfig, + options: DataSourceOptions): MicroBatchInputStream = { + stream.asInstanceOf[MicroBatchInputStream] + } + + override def createContinuousInputStream( + checkpointLocation: String, + config: ScanConfig, + options: DataSourceOptions): ContinuousInputStream = { + stream.asInstanceOf[ContinuousInputStream] + } +} + /** * A [[Source]] that produces value stored in memory as they are added by the user. This [[Source]] * is intended for use in unit tests as it can only replay data when the object is still * available. */ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) - extends MemoryStreamBase[A](sqlContext) with MicroBatchReadSupport with Logging { - - protected val logicalPlan: LogicalPlan = - StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession) - protected val output = logicalPlan.output + extends MemoryStreamBase[A](sqlContext) with MicroBatchInputStream with Logging { /** * All batches from `lastCommittedOffset + 1` to `currentOffset`, inclusive. 
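// A minimal sketch, not part of the patch, of how the layers introduced by this refactoring fit
// together for a trivial micro-batch source: Format -> Table (SupportsMicroBatchRead) ->
// MicroBatchInputStream -> MicroBatchScan. `CounterFormat` and friends are hypothetical; only the
// interfaces come from this patch, and the reader factory is elided to keep the sketch short.
import org.apache.spark.sql.execution.streaming.LongOffset
import org.apache.spark.sql.sources.v2.{DataSourceOptions, Format, SupportsMicroBatchRead, Table}
import org.apache.spark.sql.sources.v2.reader.{InputPartition, PartitionReaderFactory, ScanConfig}
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchInputStream, MicroBatchScan, Offset}
import org.apache.spark.sql.types.{LongType, StructField, StructType}

class CounterFormat extends Format {
  // The format is only a named entry point: it hands out a table for the given options.
  override def getTable(options: DataSourceOptions): Table = new CounterTable
}

class CounterTable extends Table with SupportsMicroBatchRead {
  override def schema(): StructType = StructType(StructField("value", LongType) :: Nil)

  // One input stream per streaming query; the checkpoint location scopes its state.
  override def createMicroBatchInputStream(
      checkpointLocation: String,
      config: ScanConfig,
      options: DataSourceOptions): MicroBatchInputStream = new CounterInputStream
}

class CounterInputStream extends MicroBatchInputStream {
  override def initialOffset(): Offset = LongOffset(0L)
  override def latestOffset(): Offset = LongOffset(System.currentTimeMillis() / 1000)
  override def deserializeOffset(json: String): Offset = LongOffset(json.toLong)
  override def commit(end: Offset): Unit = {}
  override def stop(): Unit = {}

  // Each micro-batch is a fixed offset range; the scan only carries what was planned for it.
  override def createMicroBatchScan(start: Offset, end: Offset): MicroBatchScan =
    new CounterScan(start.asInstanceOf[LongOffset], end.asInstanceOf[LongOffset])
}

class CounterScan(start: LongOffset, end: LongOffset) extends MicroBatchScan {
  override def planInputPartitions(): Array[InputPartition] = Array.empty // elided in this sketch
  override def createReaderFactory(): PartitionReaderFactory =
    throw new UnsupportedOperationException("reader creation elided in this sketch")
}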
@@ -117,7 +144,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) } } - override def toString: String = s"MemoryStream[${Utils.truncatedString(output, ",")}]" + override def toString: String = s"MemoryStream[${Utils.truncatedString(logicalPlan.output, ",")}]" override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong) @@ -127,15 +154,10 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) if (currentOffset.offset == -1) null else currentOffset } - override def newScanConfigBuilder(start: OffsetV2, end: OffsetV2): ScanConfigBuilder = { - new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end)) - } - - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { - val sc = config.asInstanceOf[SimpleStreamingScanConfig] - val startOffset = sc.start.asInstanceOf[LongOffset] - val endOffset = sc.end.get.asInstanceOf[LongOffset] - synchronized { + override def createMicroBatchScan(start: OffsetV2, end: OffsetV2): MicroBatchScan = { + val startOffset = start.asInstanceOf[LongOffset] + val endOffset = end.asInstanceOf[LongOffset] + val partitions: Array[InputPartition] = synchronized { // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal) val startOrdinal = startOffset.offset.toInt + 1 val endOrdinal = endOffset.offset.toInt + 1 @@ -154,10 +176,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) new MemoryStreamInputPartition(block) }.toArray } - } - - override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { - MemoryStreamReaderFactory + new MemoryStreamMicroBatchScan(partitions) } private def generateDebugString( @@ -199,6 +218,16 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) } } +class MemoryStreamMicroBatchScan(partitions: Array[InputPartition]) extends MicroBatchScan { + + override def createReaderFactory(): PartitionReaderFactory = { + MemoryStreamReaderFactory + } + + override def planInputPartitions(): Array[InputPartition] = { + partitions + } +} class MemoryStreamInputPartition(val records: Array[UnsafeRow]) extends InputPartition diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala index dbcc4483e577..097461976172 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -30,9 +30,10 @@ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.{Encoder, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.streaming.{Offset => _, _} -import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions} -import org.apache.spark.sql.sources.v2.reader.{InputPartition, ScanConfig, ScanConfigBuilder} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, Format, SupportsContinuousRead, Table} +import org.apache.spark.sql.sources.v2.reader.{InputPartition, ScanConfig} import org.apache.spark.sql.sources.v2.reader.streaming._ +import org.apache.spark.sql.types.StructType import org.apache.spark.util.RpcUtils /** @@ -44,16 +45,10 @@ import org.apache.spark.util.RpcUtils * the specified offset within the list, or null if that offset doesn't yet have a record. 
*/ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPartitions: Int = 2) - extends MemoryStreamBase[A](sqlContext) - with ContinuousReadSupportProvider with ContinuousReadSupport { + extends MemoryStreamBase[A](sqlContext) with ContinuousInputStream with Format { private implicit val formats = Serialization.formats(NoTypeHints) - protected val logicalPlan = - StreamingRelationV2(this, "memory", Map(), attributes, None)(sqlContext.sparkSession) - - // ContinuousReader implementation - @GuardedBy("this") private val records = Seq.fill(numPartitions)(new ListBuffer[A]) @@ -86,14 +81,9 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa ) } - override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = { - new SimpleStreamingScanConfigBuilder(fullSchema(), start) - } - - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { - val startOffset = config.asInstanceOf[SimpleStreamingScanConfig] - .start.asInstanceOf[ContinuousMemoryStreamOffset] - synchronized { + override def createContinuousScan(start: Offset): ContinuousScan = { + val startOffset = start.asInstanceOf[ContinuousMemoryStreamOffset] + val partitions: Array[InputPartition] = synchronized { val endpointName = s"ContinuousMemoryStreamRecordEndpoint-${java.util.UUID.randomUUID()}-$id" endpointRef = recordEndpoint.rpcEnv.setupEndpoint(endpointName, recordEndpoint) @@ -102,11 +92,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa case (part, index) => ContinuousMemoryStreamInputPartition(endpointName, part, index) }.toArray } - } - - override def createContinuousReaderFactory( - config: ScanConfig): ContinuousPartitionReaderFactory = { - ContinuousMemoryStreamReaderFactory + new MemoryStreamContinuousScan(partitions) } override def stop(): Unit = { @@ -115,11 +101,33 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa override def commit(end: Offset): Unit = {} - // ContinuousReadSupportProvider implementation + // Format implementation // This is necessary because of how StreamTest finds the source for AddDataMemory steps. 
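// A minimal sketch, not part of the patch, of the resolution path that replaces the old
// createContinuousReadSupport() hook: the engine asks the Format for a Table, checks the
// SupportsContinuousRead capability, and only then creates the input stream. `metadataPath`,
// `userOptions` and the helper name are hypothetical placeholders.
import scala.collection.JavaConverters._
import org.apache.spark.sql.sources.v2.{DataSourceOptions, Format, SupportsContinuousRead}
import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousInputStream

def resolveContinuousStream(
    format: Format,
    metadataPath: String,
    userOptions: Map[String, String]): ContinuousInputStream = {
  val options = new DataSourceOptions(userOptions.asJava)
  format.getTable(options) match {
    case table: SupportsContinuousRead =>
      // Operator pushdown would go through the ScanConfigBuilder; here we build a bare config.
      val config = table.newScanConfigBuilder(options).build()
      table.createContinuousInputStream(metadataPath, config, options)
    case other =>
      throw new UnsupportedOperationException(s"$other does not support continuous processing")
  }
}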
- override def createContinuousReadSupport( - checkpointLocation: String, - options: DataSourceOptions): ContinuousReadSupport = this + override def getTable(options: DataSourceOptions): Table = { + new Table with SupportsContinuousRead { + override def schema(): StructType = { + ContinuousMemoryStream.this.encoder.schema + } + + def createContinuousInputStream( + checkpointLocation: String, + config: ScanConfig, + options: DataSourceOptions): ContinuousInputStream = { + ContinuousMemoryStream.this + } + } + } +} + +class MemoryStreamContinuousScan(partitions: Array[InputPartition]) extends ContinuousScan { + + override def createContinuousReaderFactory(): ContinuousPartitionReaderFactory = { + ContinuousMemoryStreamReaderFactory + } + + override def planInputPartitions(): Array[InputPartition] = { + partitions + } } object ContinuousMemoryStream { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchInputStream.scala similarity index 87% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchReadSupport.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchInputStream.scala index 90680ea38fbd..d7c32de10968 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchInputStream.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.execution.streaming.sources -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchInputStream, Offset} -// A special `MicroBatchReadSupport` that can get latestOffset with a start offset. -trait RateControlMicroBatchReadSupport extends MicroBatchReadSupport { +// A special `MicroBatchInputStream` that can get latestOffset with a start offset. 
+trait RateControlMicroBatchInputStream extends MicroBatchInputStream { override def latestOffset(): Offset = { throw new IllegalAccessException( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchInputStream.scala similarity index 84% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReadSupport.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchInputStream.scala index f5364047adff..3488a5f7d887 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchInputStream.scala @@ -31,12 +31,11 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchInputStream, MicroBatchScan, Offset} import org.apache.spark.util.{ManualClock, SystemClock} -class RateStreamMicroBatchReadSupport(options: DataSourceOptions, checkpointLocation: String) - extends MicroBatchReadSupport with Logging { +class RateStreamMicroBatchInputStream(options: DataSourceOptions, checkpointLocation: String) + extends MicroBatchInputStream with Logging { import RateStreamProvider._ private[sources] val clock = { @@ -60,6 +59,14 @@ class RateStreamMicroBatchReadSupport(options: DataSourceOptions, checkpointLoca s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.") } + private val numPartitions = { + val activeSession = SparkSession.getActiveSession + require(activeSession.isDefined) + Option(options.get(NUM_PARTITIONS).orElse(null.asInstanceOf[String])) + .map(_.toInt) + .getOrElse(activeSession.get.sparkContext.defaultParallelism) + } + private[sources] val creationTimeMs = { val session = SparkSession.getActiveSession.orElse(SparkSession.getDefaultSession) require(session.isDefined) @@ -70,7 +77,7 @@ class RateStreamMicroBatchReadSupport(options: DataSourceOptions, checkpointLoca val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) writer.write("v" + VERSION + "\n") writer.write(metadata.json) - writer.flush + writer.flush() } override def deserialize(in: InputStream): LongOffset = { @@ -117,16 +124,44 @@ class RateStreamMicroBatchReadSupport(options: DataSourceOptions, checkpointLoca LongOffset(json.toLong) } - override def fullSchema(): StructType = SCHEMA + override def createMicroBatchScan(start: Offset, end: Offset): MicroBatchScan = { + new RateSteamMicroBatchScan( + maxSeconds, + rowsPerSecond, + creationTimeMs, + rampUpTimeSeconds, + numPartitions, + start.asInstanceOf[LongOffset], end.asInstanceOf[LongOffset]) + } + + override def commit(end: Offset): Unit = {} + + override def stop(): Unit = {} + + override def toString: String = s"RateStreamV2[rowsPerSecond=$rowsPerSecond, " + + s"rampUpTimeSeconds=$rampUpTimeSeconds, " + + s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}" +} + +class RateSteamMicroBatchScan( + maxSeconds: Long, + rowsPerSecond: Long, + 
creationTimeMs: Long, + rampUpTimeSeconds: Long, + numPartitions: Int, + start: LongOffset, + end: LongOffset) extends MicroBatchScan with Logging { + import RateStreamProvider._ + + @volatile private var lastTimeMs: Long = creationTimeMs - override def newScanConfigBuilder(start: Offset, end: Offset): ScanConfigBuilder = { - new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end)) + override def createReaderFactory(): PartitionReaderFactory = { + RateStreamMicroBatchReaderFactory } - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { - val sc = config.asInstanceOf[SimpleStreamingScanConfig] - val startSeconds = sc.start.asInstanceOf[LongOffset].offset - val endSeconds = sc.end.get.asInstanceOf[LongOffset].offset + override def planInputPartitions(): Array[InputPartition] = { + val startSeconds = start.offset + val endSeconds = end.offset assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") if (endSeconds > maxSeconds) { throw new ArithmeticException("Integer overflow. Max offset with " + @@ -148,31 +183,12 @@ class RateStreamMicroBatchReadSupport(options: DataSourceOptions, checkpointLoca val localStartTimeMs = creationTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) val relativeMsPerValue = TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart) - val numPartitions = { - val activeSession = SparkSession.getActiveSession - require(activeSession.isDefined) - Option(options.get(NUM_PARTITIONS).orElse(null.asInstanceOf[String])) - .map(_.toInt) - .getOrElse(activeSession.get.sparkContext.defaultParallelism) - } (0 until numPartitions).map { p => - new RateStreamMicroBatchInputPartition( + RateStreamMicroBatchInputPartition( p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) }.toArray } - - override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { - RateStreamMicroBatchReaderFactory - } - - override def commit(end: Offset): Unit = {} - - override def stop(): Unit = {} - - override def toString: String = s"RateStreamV2[rowsPerSecond=$rowsPerSecond, " + - s"rampUpTimeSeconds=$rampUpTimeSeconds, " + - s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}" } case class RateStreamMicroBatchInputPartition( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala index 6942dfbfe0ec..bfb2a7ad2afb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala @@ -18,10 +18,11 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.network.util.JavaUtils -import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReadSupport +import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousInputStream import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.ScanConfig +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputStream, MicroBatchInputStream, MicroBatchScan} import org.apache.spark.sql.types._ /** @@ -38,14 +39,16 @@ import org.apache.spark.sql.types._ * 
generated rows. The source will try its best to reach `rowsPerSecond`, but the query may * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. */ -class RateStreamProvider extends DataSourceV2 - with MicroBatchReadSupportProvider with ContinuousReadSupportProvider with DataSourceRegister { +class RateStreamProvider extends Format with DataSourceRegister { import RateStreamProvider._ - override def createMicroBatchReadSupport( - checkpointLocation: String, - options: DataSourceOptions): MicroBatchReadSupport = { - if (options.get(ROWS_PER_SECOND).isPresent) { + override def getTable(options: DataSourceOptions): Table = { + validateOptions(options) + RateStreamTable + } + + private def validateOptions(options: DataSourceOptions): Unit = { + if (options.get(ROWS_PER_SECOND).isPresent) { val rowsPerSecond = options.get(ROWS_PER_SECOND).get().toLong if (rowsPerSecond <= 0) { throw new IllegalArgumentException( @@ -69,17 +72,29 @@ class RateStreamProvider extends DataSourceV2 s"Invalid value '$numPartitions'. The option 'numPartitions' must be positive") } } - - new RateStreamMicroBatchReadSupport(options, checkpointLocation) } - override def createContinuousReadSupport( - checkpointLocation: String, - options: DataSourceOptions): ContinuousReadSupport = { - new RateStreamContinuousReadSupport(options) + override def shortName(): String = "rate" +} + +object RateStreamTable extends Table + with SupportsMicroBatchRead with SupportsContinuousRead { + + override def schema(): StructType = RateStreamProvider.SCHEMA + + override def createMicroBatchInputStream( + checkpointLocation: String, + config: ScanConfig, + options: DataSourceOptions): MicroBatchInputStream = { + new RateStreamMicroBatchInputStream(options, checkpointLocation) } - override def shortName(): String = "rate" + override def createContinuousInputStream( + checkpointLocation: String, + config: ScanConfig, + options: DataSourceOptions): ContinuousInputStream = { + new RateStreamContinuousInputStream(options) + } } object RateStreamProvider { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketMicroBatchInputStream.scala similarity index 62% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketMicroBatchInputStream.scala index b2a573eae504..aa8c4f430a4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketMicroBatchInputStream.scala @@ -19,41 +19,29 @@ package org.apache.spark.sql.execution.streaming.sources import java.io.{BufferedReader, InputStreamReader, IOException} import java.net.Socket -import java.text.SimpleDateFormat -import java.util.{Calendar, Locale} +import java.util.Calendar import java.util.concurrent.atomic.AtomicBoolean import javax.annotation.concurrent.GuardedBy import scala.collection.mutable.ListBuffer -import scala.util.{Failure, Success, Try} import org.apache.spark.internal.Logging -import org.apache.spark.sql._ +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming.{LongOffset, SimpleStreamingScanConfig, 
SimpleStreamingScanConfigBuilder} -import org.apache.spark.sql.execution.streaming.continuous.TextSocketContinuousReadSupport -import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, DataSourceV2, MicroBatchReadSupportProvider} +import org.apache.spark.sql.execution.streaming.LongOffset +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, MicroBatchReadSupport, Offset} -import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchInputStream, MicroBatchScan, Offset} import org.apache.spark.unsafe.types.UTF8String -object TextSocketReader { - val SCHEMA_REGULAR = StructType(StructField("value", StringType) :: Nil) - val SCHEMA_TIMESTAMP = StructType(StructField("value", StringType) :: - StructField("timestamp", TimestampType) :: Nil) - val DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) -} - /** - * A MicroBatchReadSupport that reads text lines through a TCP socket, designed only for tutorials - * and debugging. This MicroBatchReadSupport will *not* work in production applications due to + * A MicroBatchInputStream that reads text lines through a TCP socket, designed only for tutorials + * and debugging. This MicroBatchInputStream will *not* work in production applications due to * multiple reasons, including no support for fault recovery. */ -class TextSocketMicroBatchReadSupport(options: DataSourceOptions) - extends MicroBatchReadSupport with Logging { +class TextSocketMicroBatchInputStream(options: DataSourceOptions) + extends MicroBatchInputStream with Logging { private val host: String = options.get("host").get() private val port: Int = options.get("port").get().toInt @@ -99,7 +87,7 @@ class TextSocketMicroBatchReadSupport(options: DataSourceOptions) logWarning(s"Stream closed by $host:$port") return } - TextSocketMicroBatchReadSupport.this.synchronized { + TextSocketMicroBatchInputStream.this.synchronized { val newData = ( UTF8String.fromString(line), DateTimeUtils.fromMillis(Calendar.getInstance().getTimeInMillis) @@ -124,22 +112,9 @@ class TextSocketMicroBatchReadSupport(options: DataSourceOptions) LongOffset(json.toLong) } - override def fullSchema(): StructType = { - if (options.getBoolean("includeTimestamp", false)) { - TextSocketReader.SCHEMA_TIMESTAMP - } else { - TextSocketReader.SCHEMA_REGULAR - } - } - - override def newScanConfigBuilder(start: Offset, end: Offset): ScanConfigBuilder = { - new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end)) - } - - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { - val sc = config.asInstanceOf[SimpleStreamingScanConfig] - val startOrdinal = sc.start.asInstanceOf[LongOffset].offset.toInt + 1 - val endOrdinal = sc.end.get.asInstanceOf[LongOffset].offset.toInt + 1 + override def createMicroBatchScan(start: Offset, end: Offset): MicroBatchScan = { + val startOrdinal = start.asInstanceOf[LongOffset].offset.toInt + 1 + val endOrdinal = end.asInstanceOf[LongOffset].offset.toInt + 1 // Internal buffer only holds the batches after lastOffsetCommitted val rawList = synchronized { @@ -161,29 +136,7 @@ class TextSocketMicroBatchReadSupport(options: DataSourceOptions) slices(idx % numPartitions).append(r) } - slices.map(TextSocketInputPartition) - } - - override def 
createReaderFactory(config: ScanConfig): PartitionReaderFactory = { - new PartitionReaderFactory { - override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { - val slice = partition.asInstanceOf[TextSocketInputPartition].slice - new PartitionReader[InternalRow] { - private var currentIdx = -1 - - override def next(): Boolean = { - currentIdx += 1 - currentIdx < slice.size - } - - override def get(): InternalRow = { - InternalRow(slice(currentIdx)._1, slice(currentIdx)._2) - } - - override def close(): Unit = {} - } - } - } + new TextSocketMicroBatchScan(slices.map(TextSocketInputPartition)) } override def commit(end: Offset): Unit = synchronized { @@ -219,44 +172,33 @@ class TextSocketMicroBatchReadSupport(options: DataSourceOptions) override def toString: String = s"TextSocketV2[host: $host, port: $port]" } -case class TextSocketInputPartition(slice: ListBuffer[(UTF8String, Long)]) extends InputPartition +class TextSocketMicroBatchScan(partitions: Array[InputPartition]) extends MicroBatchScan { -class TextSocketSourceProvider extends DataSourceV2 - with MicroBatchReadSupportProvider with ContinuousReadSupportProvider - with DataSourceRegister with Logging { + override def createReaderFactory(): PartitionReaderFactory = { + new PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val slice = partition.asInstanceOf[TextSocketInputPartition].slice + new PartitionReader[InternalRow] { + private var currentIdx = -1 - private def checkParameters(params: DataSourceOptions): Unit = { - logWarning("The socket source should not be used for production applications! " + - "It does not support recovery.") - if (!params.get("host").isPresent) { - throw new AnalysisException("Set a host to read from with option(\"host\", ...).") - } - if (!params.get("port").isPresent) { - throw new AnalysisException("Set a port to read from with option(\"port\", ...).") - } - Try { - params.get("includeTimestamp").orElse("false").toBoolean - } match { - case Success(_) => - case Failure(_) => - throw new AnalysisException("includeTimestamp must be set to either \"true\" or \"false\"") - } - } + override def next(): Boolean = { + currentIdx += 1 + currentIdx < slice.size + } - override def createMicroBatchReadSupport( - checkpointLocation: String, - options: DataSourceOptions): MicroBatchReadSupport = { - checkParameters(options) - new TextSocketMicroBatchReadSupport(options) - } + override def get(): InternalRow = { + InternalRow(slice(currentIdx)._1, slice(currentIdx)._2) + } - override def createContinuousReadSupport( - checkpointLocation: String, - options: DataSourceOptions): ContinuousReadSupport = { - checkParameters(options) - new TextSocketContinuousReadSupport(options) + override def close(): Unit = {} + } + } + } } - /** String that represents the format that this data source provider uses. 
*/ - override def shortName(): String = "socket" + override def planInputPartitions(): Array[InputPartition] = { + partitions + } } + +case class TextSocketInputPartition(slice: ListBuffer[(UTF8String, Long)]) extends InputPartition diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala new file mode 100644 index 000000000000..e2a4f15b6752 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.text.SimpleDateFormat +import java.util.Locale + +import scala.util.{Failure, Success, Try} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.streaming.continuous.TextSocketContinuousInputStream +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming._ +import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} + +object TextSocketReader { + val SCHEMA_REGULAR = StructType(StructField("value", StringType) :: Nil) + val SCHEMA_TIMESTAMP = StructType(StructField("value", StringType) :: + StructField("timestamp", TimestampType) :: Nil) + val DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) +} + +class TextSocketSourceProvider extends Format with DataSourceRegister with Logging { + + override def getTable(options: DataSourceOptions): TextSocketTable = { + new TextSocketTable(options) + } + + /** String that represents the format that this data source provider uses. 
*/ + override def shortName(): String = "socket" +} + +class TextSocketTable(options: DataSourceOptions) extends Table + with SupportsMicroBatchRead with SupportsContinuousRead with Logging { + + override def schema(): StructType = { + if (options.getBoolean("includeTimestamp", false)) { + TextSocketReader.SCHEMA_TIMESTAMP + } else { + TextSocketReader.SCHEMA_REGULAR + } + } + + override def createMicroBatchInputStream( + checkpointLocation: String, + config: ScanConfig, + options: DataSourceOptions): MicroBatchInputStream = { + checkParameters(options) + new TextSocketMicroBatchInputStream(options) + } + + override def createContinuousInputStream( + checkpointLocation: String, + config: ScanConfig, + options: DataSourceOptions): ContinuousInputStream = { + checkParameters(options) + new TextSocketContinuousInputStream(options) + } + + private def checkParameters(params: DataSourceOptions): Unit = { + logWarning("The socket source should not be used for production applications! " + + "It does not support recovery.") + if (!params.get("host").isPresent) { + throw new AnalysisException("Set a host to read from with option(\"host\", ...).") + } + if (!params.get("port").isPresent) { + throw new AnalysisException("Set a port to read from with option(\"port\", ...).") + } + Try { + params.get("includeTimestamp").orElse("false").toBoolean + } match { + case Success(_) => + case Failure(_) => + throw new AnalysisException("includeTimestamp must be set to either \"true\" or \"false\"") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 4c7dcedafeea..13255e0b6eed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -29,10 +29,8 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2} import org.apache.spark.sql.sources.StreamSourceProvider -import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, MicroBatchReadSupportProvider} -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.types.StructType -import org.apache.spark.util.Utils /** * Interface used to load a streaming `Dataset` from external storage systems (e.g. 
file systems, @@ -172,60 +170,27 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo case _ => None } ds match { - case s: MicroBatchReadSupportProvider => - val sessionOptions = DataSourceV2Utils.extractSessionConfigs( - ds = s, conf = sparkSession.sessionState.conf) - val options = sessionOptions ++ extraOptions - val dataSourceOptions = new DataSourceOptions(options.asJava) - var tempReadSupport: MicroBatchReadSupport = null - val schema = try { - val tmpCheckpointPath = Utils.createTempDir(namePrefix = s"tempCP").getCanonicalPath - tempReadSupport = if (userSpecifiedSchema.isDefined) { - s.createMicroBatchReadSupport( - userSpecifiedSchema.get, tmpCheckpointPath, dataSourceOptions) - } else { - s.createMicroBatchReadSupport(tmpCheckpointPath, dataSourceOptions) - } - tempReadSupport.fullSchema() - } finally { - // Stop tempReader to avoid side-effect thing - if (tempReadSupport != null) { - tempReadSupport.stop() - tempReadSupport = null - } + case f: Format => + val sessionOptions = DataSourceV2Utils.extractSessionConfigs( + ds = f, conf = sparkSession.sessionState.conf) + val options = sessionOptions ++ extraOptions + val dsOptions = new DataSourceOptions(options.asJava) + val table = userSpecifiedSchema match { + case Some(schema) => f.getTable(dsOptions, schema) + case _ => f.getTable(dsOptions) } - Dataset.ofRows( - sparkSession, - StreamingRelationV2( - s, source, options, - schema.toAttributes, v1Relation)(sparkSession)) - case s: ContinuousReadSupportProvider => - val sessionOptions = DataSourceV2Utils.extractSessionConfigs( - ds = s, conf = sparkSession.sessionState.conf) - val options = sessionOptions ++ extraOptions - val dataSourceOptions = new DataSourceOptions(options.asJava) - var tempReadSupport: ContinuousReadSupport = null - val schema = try { - val tmpCheckpointPath = Utils.createTempDir(namePrefix = s"tempCP").getCanonicalPath - tempReadSupport = if (userSpecifiedSchema.isDefined) { - s.createContinuousReadSupport( - userSpecifiedSchema.get, tmpCheckpointPath, dataSourceOptions) - } else { - s.createContinuousReadSupport(tmpCheckpointPath, dataSourceOptions) - } - tempReadSupport.fullSchema() - } finally { - // Stop tempReader to avoid side-effect thing - if (tempReadSupport != null) { - tempReadSupport.stop() - tempReadSupport = null - } + + table match { + case _: SupportsMicroBatchRead => + case _: SupportsContinuousRead => + case _ => return Dataset.ofRows(sparkSession, StreamingRelation(v1DataSource)) } + Dataset.ofRows( sparkSession, StreamingRelationV2( - s, source, options, - schema.toAttributes, v1Relation)(sparkSession)) + source, f, table, options, + table.schema().toAttributes, v1Relation)(sparkSession)) case _ => // Code path for data source v1. 
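        // Illustrative note, not part of this patch: the v2 branch above only builds a
        // StreamingRelationV2 when the resolved Table reports SupportsMicroBatchRead or
        // SupportsContinuousRead; any other table falls through to this v1 path, so a
        // (hypothetical) spark.readStream.format("some-v1-source").load() still resolves
        // to a plain StreamingRelation.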
Dataset.ofRows(sparkSession, StreamingRelation(v1DataSource)) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index 5602310219a7..c47ef696f62c 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -24,27 +24,42 @@ import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.sources.Filter; import org.apache.spark.sql.sources.GreaterThan; -import org.apache.spark.sql.sources.v2.BatchReadSupportProvider; -import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.*; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.StructType; -public class JavaAdvancedDataSourceV2 implements DataSourceV2, BatchReadSupportProvider { +public class JavaAdvancedDataSourceV2 implements Format { + + class MyTable implements Table, SupportsBatchRead { - public class ReadSupport extends JavaSimpleReadSupport { @Override - public ScanConfigBuilder newScanConfigBuilder() { + public ScanConfigBuilder newScanConfigBuilder(DataSourceOptions options) { return new AdvancedScanConfigBuilder(); } @Override - public InputPartition[] planInputPartitions(ScanConfig config) { - Filter[] filters = ((AdvancedScanConfigBuilder) config).filters; - List res = new ArrayList<>(); + public BatchScan createBatchScan(ScanConfig config, DataSourceOptions options) { + return new AdvancedBatchScan((AdvancedScanConfigBuilder) config); + } + + @Override + public StructType schema() { + return new StructType().add("i", "int").add("j", "int"); + } + } + public static class AdvancedBatchScan implements BatchScan { + public AdvancedScanConfigBuilder config; + + AdvancedBatchScan(AdvancedScanConfigBuilder config) { + this.config = config; + } + + @Override + public InputPartition[] planInputPartitions() { + List res = new ArrayList<>(); Integer lowerBound = null; - for (Filter filter : filters) { + for (Filter filter : config.filters) { if (filter instanceof GreaterThan) { GreaterThan f = (GreaterThan) filter; if ("i".equals(f.attribute()) && f.value() instanceof Integer) { @@ -68,12 +83,12 @@ public InputPartition[] planInputPartitions(ScanConfig config) { } @Override - public PartitionReaderFactory createReaderFactory(ScanConfig config) { - StructType requiredSchema = ((AdvancedScanConfigBuilder) config).requiredSchema; - return new AdvancedReaderFactory(requiredSchema); + public PartitionReaderFactory createReaderFactory() { + return new AdvancedReaderFactory(config.requiredSchema); } } + public static class AdvancedScanConfigBuilder implements ScanConfigBuilder, ScanConfig, SupportsPushDownFilters, SupportsPushDownRequiredColumns { @@ -166,9 +181,8 @@ public void close() throws IOException { } } - @Override - public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { - return new ReadSupport(); + public Table getTable(DataSourceOptions options) { + return new MyTable(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java index 28a933039831..df4f0c676591 100644 --- 
a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java @@ -21,21 +21,18 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; -import org.apache.spark.sql.sources.v2.BatchReadSupportProvider; -import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.*; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.sql.vectorized.ColumnarBatch; +public class JavaColumnarDataSourceV2 implements Format { -public class JavaColumnarDataSourceV2 implements DataSourceV2, BatchReadSupportProvider { - - class ReadSupport extends JavaSimpleReadSupport { + class MyTable extends SimpleBatchReadTable { @Override - public InputPartition[] planInputPartitions(ScanConfig config) { + public InputPartition[] planInputPartitions() { InputPartition[] partitions = new InputPartition[2]; partitions[0] = new JavaRangeInputPartition(0, 50); partitions[1] = new JavaRangeInputPartition(50, 90); @@ -43,7 +40,7 @@ public InputPartition[] planInputPartitions(ScanConfig config) { } @Override - public PartitionReaderFactory createReaderFactory(ScanConfig config) { + public PartitionReaderFactory createReaderFactory() { return new ColumnarReaderFactory(); } } @@ -108,7 +105,7 @@ public void close() throws IOException { } @Override - public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { - return new ReadSupport(); + public Table getTable(DataSourceOptions options) { + return new MyTable(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java index 18a11dde8219..560c54ac1c84 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java @@ -28,12 +28,17 @@ import org.apache.spark.sql.sources.v2.reader.partitioning.Distribution; import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning; -public class JavaPartitionAwareDataSource implements DataSourceV2, BatchReadSupportProvider { +public class JavaPartitionAwareDataSource implements Format { - class ReadSupport extends JavaSimpleReadSupport implements SupportsReportPartitioning { + class MyTable extends SimpleBatchReadTable implements SupportsReportPartitioning { @Override - public InputPartition[] planInputPartitions(ScanConfig config) { + public PartitionReaderFactory createReaderFactory() { + return new SpecificReaderFactory(); + } + + @Override + public InputPartition[] planInputPartitions() { InputPartition[] partitions = new InputPartition[2]; partitions[0] = new SpecificInputPartition(new int[]{1, 1, 3}, new int[]{4, 4, 6}); partitions[1] = new SpecificInputPartition(new int[]{2, 4, 4}, new int[]{6, 2, 2}); @@ -41,12 +46,7 @@ public InputPartition[] planInputPartitions(ScanConfig config) { } @Override - public PartitionReaderFactory createReaderFactory(ScanConfig config) { - return new SpecificReaderFactory(); - } - - @Override - public Partitioning outputPartitioning(ScanConfig config) { + public Partitioning outputPartitioning() { return new 
MyPartitioning(); } } @@ -108,7 +108,7 @@ public void close() throws IOException { } @Override - public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { - return new ReadSupport(); + public Table getTable(DataSourceOptions options) { + return new MyTable(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java index cc9ac04a0dad..2b68f987568d 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java @@ -17,39 +17,37 @@ package test.org.apache.spark.sql.sources.v2; -import org.apache.spark.sql.sources.v2.BatchReadSupportProvider; -import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.*; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.StructType; -public class JavaSchemaRequiredDataSource implements DataSourceV2, BatchReadSupportProvider { +public class JavaSchemaRequiredDataSource implements Format { - class ReadSupport extends JavaSimpleReadSupport { + class MyTable extends JavaSimpleBatchReadTable { private final StructType schema; - ReadSupport(StructType schema) { + MyTable(StructType schema) { this.schema = schema; } @Override - public StructType fullSchema() { - return schema; + public StructType schema() { + return this.schema; } @Override - public InputPartition[] planInputPartitions(ScanConfig config) { + public InputPartition[] planInputPartitions() { return new InputPartition[0]; } } @Override - public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { + public Table getTable(DataSourceOptions options) { throw new IllegalArgumentException("requires a user-supplied schema"); } @Override - public BatchReadSupport createBatchReadSupport(StructType schema, DataSourceOptions options) { - return new ReadSupport(schema); + public Table getTable(DataSourceOptions options, StructType schema) { + return new MyTable(schema); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleBatchReadTable.java similarity index 78% rename from sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java rename to sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleBatchReadTable.java index 685f9b9747e8..bafebc8cdea0 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleBatchReadTable.java @@ -21,46 +21,29 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.sources.v2.SupportsBatchRead; +import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.StructType; -abstract class JavaSimpleReadSupport implements BatchReadSupport { +abstract class JavaSimpleBatchReadTable implements SupportsBatchRead, BatchScan { @Override - public StructType fullSchema() { - return new StructType().add("i", "int").add("j", "int"); + public BatchScan createBatchScan(ScanConfig config, DataSourceOptions 
options) { + return this; } @Override - public ScanConfigBuilder newScanConfigBuilder() { - return new JavaNoopScanConfigBuilder(fullSchema()); + public StructType schema() { + return new StructType().add("i", "int").add("j", "int"); } @Override - public PartitionReaderFactory createReaderFactory(ScanConfig config) { + public PartitionReaderFactory createReaderFactory() { return new JavaSimpleReaderFactory(); } } -class JavaNoopScanConfigBuilder implements ScanConfigBuilder, ScanConfig { - - private StructType schema; - - JavaNoopScanConfigBuilder(StructType schema) { - this.schema = schema; - } - - @Override - public ScanConfig build() { - return this; - } - - @Override - public StructType readSchema() { - return schema; - } -} - class JavaSimpleReaderFactory implements PartitionReaderFactory { @Override diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java index 2cdbba84ec4a..23896e29c160 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java @@ -17,17 +17,14 @@ package test.org.apache.spark.sql.sources.v2; -import org.apache.spark.sql.sources.v2.BatchReadSupportProvider; -import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceOptions; +import org.apache.spark.sql.sources.v2.*; import org.apache.spark.sql.sources.v2.reader.*; -public class JavaSimpleDataSourceV2 implements DataSourceV2, BatchReadSupportProvider { - - class ReadSupport extends JavaSimpleReadSupport { +public class JavaSimpleDataSourceV2 implements Format { + class MyTable extends JavaSimpleBatchReadTable { @Override - public InputPartition[] planInputPartitions(ScanConfig config) { + public InputPartition[] planInputPartitions() { InputPartition[] partitions = new InputPartition[2]; partitions[0] = new JavaRangeInputPartition(0, 5); partitions[1] = new JavaRangeInputPartition(5, 10); @@ -36,7 +33,7 @@ public InputPartition[] planInputPartitions(ScanConfig config) { } @Override - public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { - return new ReadSupport(); + public Table getTable(DataSourceOptions options) { + return new MyTable(); } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala index dd74af873c2e..4bb467350467 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -25,15 +25,16 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, MicroBatchReadSupportProvider} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, Format, SupportsMicroBatchRead} import 
org.apache.spark.sql.sources.v2.reader.streaming.Offset import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.util.ManualClock -class RateSourceSuite extends StreamTest { +class RateStreamProviderSuite extends StreamTest { import testImplicits._ @@ -41,7 +42,9 @@ class RateSourceSuite extends StreamTest { override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { assert(query.nonEmpty) val rateSource = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source: RateStreamMicroBatchReadSupport, _) => source + case r: StreamingDataSourceV2Relation + if r.stream.isInstanceOf[RateStreamMicroBatchInputStream] => + r.stream.asInstanceOf[RateStreamMicroBatchInputStream] }.head rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) @@ -51,27 +54,16 @@ class RateSourceSuite extends StreamTest { } } - test("microbatch in registry") { - withTempDir { temp => - DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupportProvider => - val readSupport = ds.createMicroBatchReadSupport( - temp.getCanonicalPath, DataSourceOptions.empty()) - assert(readSupport.isInstanceOf[RateStreamMicroBatchReadSupport]) - case _ => - throw new IllegalStateException("Could not find read support for rate") - } - } + test("RateStreamProvider in registry") { + val ds = DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() + assert(ds.isInstanceOf[RateStreamProvider], "Could not find rate source") } test("compatible with old path in registry") { - DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.RateSourceProvider", - spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupportProvider => - assert(ds.isInstanceOf[RateStreamProvider]) - case _ => - throw new IllegalStateException("Could not find read support for rate") - } + val ds = DataSource.lookupDataSource( + "org.apache.spark.sql.execution.streaming.RateSourceProvider", + spark.sqlContext.conf).newInstance() + assert(ds.isInstanceOf[RateStreamProvider], "Could not find rate source") } test("microbatch - basic") { @@ -141,17 +133,17 @@ class RateSourceSuite extends StreamTest { test("microbatch - infer offsets") { withTempDir { temp => - val readSupport = new RateStreamMicroBatchReadSupport( + val stream = new RateStreamMicroBatchInputStream( new DataSourceOptions( Map("numPartitions" -> "1", "rowsPerSecond" -> "100", "useManualClock" -> "true").asJava), temp.getCanonicalPath) - readSupport.clock.asInstanceOf[ManualClock].advance(100000) - val startOffset = readSupport.initialOffset() + stream.clock.asInstanceOf[ManualClock].advance(100000) + val startOffset = stream.initialOffset() startOffset match { case r: LongOffset => assert(r.offset === 0L) case _ => throw new IllegalStateException("unexpected offset type") } - readSupport.latestOffset() match { + stream.latestOffset() match { case r: LongOffset => assert(r.offset >= 100) case _ => throw new IllegalStateException("unexpected offset type") } @@ -160,16 +152,14 @@ class RateSourceSuite extends StreamTest { test("microbatch - predetermined batch size") { withTempDir { temp => - val readSupport = new RateStreamMicroBatchReadSupport( + val stream = new RateStreamMicroBatchInputStream( new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava), temp.getCanonicalPath) - val startOffset = LongOffset(0L) - val endOffset = LongOffset(1L) - val config = readSupport.newScanConfigBuilder(startOffset, 
endOffset).build() - val tasks = readSupport.planInputPartitions(config) - val readerFactory = readSupport.createReaderFactory(config) - assert(tasks.size == 1) - val dataReader = readerFactory.createReader(tasks(0)) + val scan = stream.createMicroBatchScan(LongOffset(0L), LongOffset(1L)) + val partitions = scan.planInputPartitions() + val readerFactory = scan.createReaderFactory() + assert(partitions.size == 1) + val dataReader = readerFactory.createReader(partitions(0)) val data = ArrayBuffer[InternalRow]() while (dataReader.next()) { data.append(dataReader.get()) @@ -180,17 +170,15 @@ class RateSourceSuite extends StreamTest { test("microbatch - data read") { withTempDir { temp => - val readSupport = new RateStreamMicroBatchReadSupport( + val stream = new RateStreamMicroBatchInputStream( new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava), temp.getCanonicalPath) - val startOffset = LongOffset(0L) - val endOffset = LongOffset(1L) - val config = readSupport.newScanConfigBuilder(startOffset, endOffset).build() - val tasks = readSupport.planInputPartitions(config) - val readerFactory = readSupport.createReaderFactory(config) - assert(tasks.size == 11) - - val readData = tasks + val scan = stream.createMicroBatchScan(LongOffset(0L), LongOffset(1L)) + val partitions = scan.planInputPartitions() + val readerFactory = scan.createReaderFactory() + assert(partitions.size == 11) + + val readData = partitions .map(readerFactory.createReader) .flatMap { reader => val buf = scala.collection.mutable.ListBuffer[InternalRow]() @@ -319,29 +307,18 @@ class RateSourceSuite extends StreamTest { "rate source does not support user-specified schema")) } - test("continuous in registry") { - DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { - case ds: ContinuousReadSupportProvider => - val readSupport = ds.createContinuousReadSupport( - "", DataSourceOptions.empty()) - assert(readSupport.isInstanceOf[RateStreamContinuousReadSupport]) - case _ => - throw new IllegalStateException("Could not find read support for continuous rate") - } - } - test("continuous data") { - val readSupport = new RateStreamContinuousReadSupport( + val stream = new RateStreamContinuousInputStream( new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) - val config = readSupport.newScanConfigBuilder(readSupport.initialOffset).build() - val tasks = readSupport.planInputPartitions(config) - val readerFactory = readSupport.createContinuousReaderFactory(config) - assert(tasks.size == 2) + val scan = stream.createContinuousScan(stream.initialOffset) + val partitions = scan.planInputPartitions() + val readerFactory = scan.createContinuousReaderFactory() + assert(partitions.size == 2) val data = scala.collection.mutable.ListBuffer[InternalRow]() - tasks.foreach { + partitions.foreach { case t: RateStreamContinuousInputPartition => - val startTimeMs = readSupport.initialOffset() + val startTimeMs = stream.initialOffset() .asInstanceOf[RateStreamOffset] .partitionToValueAndRunTimeMs(t.partitionIndex) .runTimeMs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala index 409156e5ebc7..760c6f367d40 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala @@ -30,10 +30,11 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupportProvider} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsMicroBatchRead} import org.apache.spark.sql.sources.v2.reader.streaming.Offset import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} import org.apache.spark.sql.test.SharedSQLContext @@ -59,7 +60,9 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before "Cannot add data when there is no query for finding the active socket source") val sources = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source: TextSocketMicroBatchReadSupport, _) => source + case r: StreamingDataSourceV2Relation + if r.stream.isInstanceOf[TextSocketMicroBatchInputStream] => + r.stream.asInstanceOf[TextSocketMicroBatchInputStream] } if (sources.isEmpty) { throw new Exception( @@ -83,13 +86,10 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before } test("backward compatibility with old path") { - DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.TextSocketSourceProvider", - spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupportProvider => - assert(ds.isInstanceOf[TextSocketSourceProvider]) - case _ => - throw new IllegalStateException("Could not find socket source") - } + val ds = DataSource.lookupDataSource( + "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider", + spark.sqlContext.conf).newInstance() + assert(ds.isInstanceOf[TextSocketSourceProvider], "Could not find socket source") } test("basic usage") { @@ -173,39 +173,37 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before } test("params not given") { - val provider = new TextSocketSourceProvider + val table = new TextSocketSourceProvider().getTable(DataSourceOptions.empty()) intercept[AnalysisException] { - provider.createMicroBatchReadSupport( - "", new DataSourceOptions(Map.empty[String, String].asJava)) + table.createMicroBatchInputStream( + "", null, new DataSourceOptions(Map.empty[String, String].asJava)) } intercept[AnalysisException] { - provider.createMicroBatchReadSupport( - "", new DataSourceOptions(Map("host" -> "localhost").asJava)) + table.createMicroBatchInputStream( + "", null, new DataSourceOptions(Map("host" -> "localhost").asJava)) } intercept[AnalysisException] { - provider.createMicroBatchReadSupport( - "", new DataSourceOptions(Map("port" -> "1234").asJava)) + table.createMicroBatchInputStream( + "", null, new DataSourceOptions(Map("port" -> "1234").asJava)) } } test("non-boolean includeTimestamp") { - val provider = new TextSocketSourceProvider + val table = new TextSocketSourceProvider().getTable(DataSourceOptions.empty()) val params = Map("host" -> "localhost", "port" -> "1234", "includeTimestamp" -> "fasle") intercept[AnalysisException] { val a = new DataSourceOptions(params.asJava) - provider.createMicroBatchReadSupport("", a) + 
table.createMicroBatchInputStream("", null, a) } } test("user-specified schema given") { - val provider = new TextSocketSourceProvider + val provider = new TextSocketSourceProvider() val userSpecifiedSchema = StructType( StructField("name", StringType) :: StructField("area", StringType) :: Nil) - val params = Map("host" -> "localhost", "port" -> "1234") val exception = intercept[UnsupportedOperationException] { - provider.createMicroBatchReadSupport( - userSpecifiedSchema, "", new DataSourceOptions(params.asJava)) + provider.getTable(DataSourceOptions.empty(), userSpecifiedSchema) } assert(exception.getMessage.contains( "socket source does not support user-specified schema")) @@ -299,25 +297,24 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before serverThread = new ServerThread() serverThread.start() - val readSupport = new TextSocketContinuousReadSupport( + val stream = new TextSocketContinuousInputStream( new DataSourceOptions(Map("numPartitions" -> "2", "host" -> "localhost", "port" -> serverThread.port.toString).asJava)) - - val scanConfig = readSupport.newScanConfigBuilder(readSupport.initialOffset()).build() - val tasks = readSupport.planInputPartitions(scanConfig) - assert(tasks.size == 2) + val scan = stream.createContinuousScan(stream.initialOffset()) + val partitions = scan.planInputPartitions() + assert(partitions.size == 2) val numRecords = 10 val data = scala.collection.mutable.ListBuffer[Int]() val offsets = scala.collection.mutable.ListBuffer[Int]() - val readerFactory = readSupport.createContinuousReaderFactory(scanConfig) + val readerFactory = scan.createContinuousReaderFactory() import org.scalatest.time.SpanSugar._ failAfter(5 seconds) { // inject rows, read and check the data and offsets for (i <- 0 until numRecords) { serverThread.enqueue(i.toString) } - tasks.foreach { + partitions.foreach { case t: TextSocketContinuousInputPartition => val r = readerFactory.createReader(t).asInstanceOf[TextSocketContinuousPartitionReader] for (i <- 0 until numRecords / 2) { @@ -335,15 +332,15 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before data.clear() case _ => throw new IllegalStateException("Unexpected task type") } - assert(readSupport.startOffset.offsets == List(3, 3)) - readSupport.commit(TextSocketOffset(List(5, 5))) - assert(readSupport.startOffset.offsets == List(5, 5)) + assert(stream.startOffset.offsets == List(3, 3)) + stream.commit(TextSocketOffset(List(5, 5))) + assert(stream.startOffset.offsets == List(5, 5)) } def commitOffset(partition: Int, offset: Int): Unit = { - val offsetsToCommit = readSupport.startOffset.offsets.updated(partition, offset) - readSupport.commit(TextSocketOffset(offsetsToCommit)) - assert(readSupport.startOffset.offsets == offsetsToCommit) + val offsetsToCommit = stream.startOffset.offsets.updated(partition, offset) + stream.commit(TextSocketOffset(offsetsToCommit)) + assert(stream.startOffset.offsets == offsetsToCommit) } } @@ -351,13 +348,13 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before serverThread = new ServerThread() serverThread.start() - val readSupport = new TextSocketContinuousReadSupport( + val stream = new TextSocketContinuousInputStream( new DataSourceOptions(Map("numPartitions" -> "2", "host" -> "localhost", "port" -> serverThread.port.toString).asJava)) - readSupport.startOffset = TextSocketOffset(List(5, 5)) + stream.startOffset = TextSocketOffset(List(5, 5)) assertThrows[IllegalStateException] { - 
readSupport.commit(TextSocketOffset(List(6, 6))) + stream.commit(TextSocketOffset(List(6, 6))) } } @@ -365,21 +362,21 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before serverThread = new ServerThread() serverThread.start() - val readSupport = new TextSocketContinuousReadSupport( + val stream = new TextSocketContinuousInputStream( new DataSourceOptions(Map("numPartitions" -> "2", "host" -> "localhost", "includeTimestamp" -> "true", "port" -> serverThread.port.toString).asJava)) - val scanConfig = readSupport.newScanConfigBuilder(readSupport.initialOffset()).build() - val tasks = readSupport.planInputPartitions(scanConfig) - assert(tasks.size == 2) + val scan = stream.createContinuousScan(stream.initialOffset()) + val partitions = scan.planInputPartitions() + assert(partitions.size == 2) val numRecords = 4 // inject rows, read and check the data and offsets for (i <- 0 until numRecords) { serverThread.enqueue(i.toString) } - val readerFactory = readSupport.createContinuousReaderFactory(scanConfig) - tasks.foreach { + val readerFactory = scan.createContinuousReaderFactory() + partitions.foreach { case t: TextSocketContinuousInputPartition => val r = readerFactory.createReader(t).asInstanceOf[TextSocketContinuousPartitionReader] for (i <- 0 until numRecords / 2) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index e8f291af13ba..c4607086afdf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -41,7 +41,7 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { private def getScanConfig(query: DataFrame): AdvancedScanConfigBuilder = { query.queryExecution.executedPlan.collect { case d: DataSourceV2ScanExec => - d.scanConfig.asInstanceOf[AdvancedScanConfigBuilder] + d.scan.asInstanceOf[AdvancedBatchScan].config }.head } @@ -49,7 +49,7 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { query: DataFrame): JavaAdvancedDataSourceV2.AdvancedScanConfigBuilder = { query.queryExecution.executedPlan.collect { case d: DataSourceV2ScanExec => - d.scanConfig.asInstanceOf[JavaAdvancedDataSourceV2.AdvancedScanConfigBuilder] + d.scan.asInstanceOf[JavaAdvancedDataSourceV2.AdvancedBatchScan].config }.head } @@ -374,10 +374,6 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { case class RangeInputPartition(start: Int, end: Int) extends InputPartition -case class NoopScanConfigBuilder(readSchema: StructType) extends ScanConfigBuilder with ScanConfig { - override def build(): ScanConfig = this -} - object SimpleReaderFactory extends PartitionReaderFactory { override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { val RangeInputPartition(start, end) = partition @@ -396,83 +392,54 @@ object SimpleReaderFactory extends PartitionReaderFactory { } } -abstract class SimpleReadSupport extends BatchReadSupport { - override def fullSchema(): StructType = new StructType().add("i", "int").add("j", "int") +abstract class SimpleBatchReadTable extends Table with SupportsBatchRead with BatchScan { - override def newScanConfigBuilder(): ScanConfigBuilder = { - NoopScanConfigBuilder(fullSchema()) - } + override def createBatchScan(config: ScanConfig, options: DataSourceOptions): BatchScan = this - override def createReaderFactory(config: ScanConfig): 
PartitionReaderFactory = { - SimpleReaderFactory - } + override def schema(): StructType = new StructType().add("i", "int").add("j", "int") + + override def createReaderFactory(): PartitionReaderFactory = SimpleReaderFactory } -class SimpleSinglePartitionSource extends DataSourceV2 with BatchReadSupportProvider { +class SimpleSinglePartitionSource extends Format { - class ReadSupport extends SimpleReadSupport { - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + class MyTable extends SimpleBatchReadTable { + override def planInputPartitions(): Array[InputPartition] = { Array(RangeInputPartition(0, 5)) } } - override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { - new ReadSupport - } + override def getTable(options: DataSourceOptions): Table = new MyTable() } // This class is used by pyspark tests. If this class is modified/moved, make sure pyspark // tests still pass. -class SimpleDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider { +class SimpleDataSourceV2 extends Format { - class ReadSupport extends SimpleReadSupport { - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + class MyTable extends SimpleBatchReadTable { + override def planInputPartitions(): Array[InputPartition] = { Array(RangeInputPartition(0, 5), RangeInputPartition(5, 10)) } } - override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { - new ReadSupport - } + override def getTable(options: DataSourceOptions): Table = new MyTable } -class AdvancedDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider { - - class ReadSupport extends SimpleReadSupport { - override def newScanConfigBuilder(): ScanConfigBuilder = new AdvancedScanConfigBuilder() - - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { - val filters = config.asInstanceOf[AdvancedScanConfigBuilder].filters - - val lowerBound = filters.collectFirst { - case GreaterThan("i", v: Int) => v - } - - val res = scala.collection.mutable.ArrayBuffer.empty[InputPartition] - - if (lowerBound.isEmpty) { - res.append(RangeInputPartition(0, 5)) - res.append(RangeInputPartition(5, 10)) - } else if (lowerBound.get < 4) { - res.append(RangeInputPartition(lowerBound.get + 1, 5)) - res.append(RangeInputPartition(5, 10)) - } else if (lowerBound.get < 9) { - res.append(RangeInputPartition(lowerBound.get + 1, 10)) - } - - res.toArray +class AdvancedDataSourceV2 extends Format { + class MyTable extends SupportsBatchRead { + override def newScanConfigBuilder(options: DataSourceOptions): ScanConfigBuilder = { + new AdvancedScanConfigBuilder() } - override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { - val requiredSchema = config.asInstanceOf[AdvancedScanConfigBuilder].requiredSchema - new AdvancedReaderFactory(requiredSchema) + override def createBatchScan(config: ScanConfig, options: DataSourceOptions): BatchScan = { + new AdvancedBatchScan(config.asInstanceOf[AdvancedScanConfigBuilder]) } - } - override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { - new ReadSupport + override def schema(): StructType = new StructType().add("i", "int").add("j", "int") } + + override def getTable(options: DataSourceOptions): Table = new MyTable } class AdvancedScanConfigBuilder extends ScanConfigBuilder with ScanConfig @@ -501,6 +468,33 @@ class AdvancedScanConfigBuilder extends ScanConfigBuilder with ScanConfig override def build(): ScanConfig = this } +class 
AdvancedBatchScan(val config: AdvancedScanConfigBuilder) extends BatchScan { + + override def createReaderFactory(): PartitionReaderFactory = { + new AdvancedReaderFactory(config.requiredSchema) + } + + override def planInputPartitions(): Array[InputPartition] = { + val lowerBound = config.filters.collectFirst { + case GreaterThan("i", v: Int) => v + } + + val res = scala.collection.mutable.ArrayBuffer.empty[InputPartition] + + if (lowerBound.isEmpty) { + res.append(RangeInputPartition(0, 5)) + res.append(RangeInputPartition(5, 10)) + } else if (lowerBound.get < 4) { + res.append(RangeInputPartition(lowerBound.get + 1, 5)) + res.append(RangeInputPartition(5, 10)) + } else if (lowerBound.get < 9) { + res.append(RangeInputPartition(lowerBound.get + 1, 10)) + } + + res.toArray + } +} + class AdvancedReaderFactory(requiredSchema: StructType) extends PartitionReaderFactory { override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { val RangeInputPartition(start, end) = partition @@ -526,40 +520,30 @@ class AdvancedReaderFactory(requiredSchema: StructType) extends PartitionReaderF } -class SchemaRequiredDataSource extends DataSourceV2 with BatchReadSupportProvider { +class SchemaRequiredDataSource extends Format { - class ReadSupport(val schema: StructType) extends SimpleReadSupport { - override def fullSchema(): StructType = schema - - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = - Array.empty + class MyTable(override val schema: StructType) extends SimpleBatchReadTable { + override def planInputPartitions(): Array[InputPartition] = Array.empty } - override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { + override def getTable(options: DataSourceOptions): Table = { throw new IllegalArgumentException("requires a user-supplied schema") } - override def createBatchReadSupport( - schema: StructType, options: DataSourceOptions): BatchReadSupport = { - new ReadSupport(schema) - } + override def getTable(options: DataSourceOptions, schema: StructType): Table = new MyTable(schema) } -class ColumnarDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider { +class ColumnarDataSourceV2 extends Format { - class ReadSupport extends SimpleReadSupport { - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + class MyTable extends SimpleBatchReadTable { + override def planInputPartitions(): Array[InputPartition] = { Array(RangeInputPartition(0, 50), RangeInputPartition(50, 90)) } - override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { - ColumnarReaderFactory - } + override def createReaderFactory(): PartitionReaderFactory = ColumnarReaderFactory } - override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { - new ReadSupport - } + override def getTable(options: DataSourceOptions): Table = new MyTable } object ColumnarReaderFactory extends PartitionReaderFactory { @@ -608,21 +592,20 @@ object ColumnarReaderFactory extends PartitionReaderFactory { } -class PartitionAwareDataSource extends DataSourceV2 with BatchReadSupportProvider { +class PartitionAwareDataSource extends Format { + + class MyTable extends SimpleBatchReadTable with SupportsReportPartitioning { - class ReadSupport extends SimpleReadSupport with SupportsReportPartitioning { - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + override def createReaderFactory(): PartitionReaderFactory = SpecificReaderFactory + + override def 
planInputPartitions(): Array[InputPartition] = { // Note that we don't have same value of column `a` across partitions. Array( SpecificInputPartition(Array(1, 1, 3), Array(4, 4, 6)), SpecificInputPartition(Array(2, 4, 4), Array(6, 2, 2))) } - override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { - SpecificReaderFactory - } - - override def outputPartitioning(config: ScanConfig): Partitioning = new MyPartitioning + override def outputPartitioning(): Partitioning = new MyPartitioning } class MyPartitioning extends Partitioning { @@ -634,9 +617,7 @@ class PartitionAwareDataSource extends DataSourceV2 with BatchReadSupportProvide } } - override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { - new ReadSupport - } + override def getTable(options: DataSourceOptions): Table = new MyTable } case class SpecificInputPartition(i: Array[Int], j: Array[Int]) extends InputPartition @@ -662,7 +643,7 @@ object SpecificReaderFactory extends PartitionReaderFactory { class SchemaReadAttemptException(m: String) extends RuntimeException(m) class SimpleWriteOnlyDataSource extends SimpleWritableDataSource { - override def fullSchema(): StructType = { + override def schema(): StructType = { // This is a bit hacky since this source implements read support but throws // during schema retrieval. Might have to rewrite but it's done // such so for minimised changes. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index a7dfc2d1deac..3b6ca89630fe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration /** @@ -39,19 +39,17 @@ import org.apache.spark.util.SerializableConfiguration * Each job moves files from `target/_temporary/queryId/` to `target`. 
*/ class SimpleWritableDataSource extends DataSourceV2 - with BatchReadSupportProvider - with BatchWriteSupportProvider - with SessionConfigSupport { + with Format with BatchWriteSupportProvider with SessionConfigSupport { - protected def fullSchema(): StructType = new StructType().add("i", "long").add("j", "long") + protected def schema(): StructType = new StructType().add("i", "long").add("j", "long") override def keyPrefix: String = "simpleWritableDataSource" - class ReadSupport(path: String, conf: Configuration) extends SimpleReadSupport { + class MyTable(path: String, conf: Configuration) extends SimpleBatchReadTable { - override def fullSchema(): StructType = SimpleWritableDataSource.this.fullSchema() + override def schema(): StructType = SimpleWritableDataSource.this.schema() - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + override def planInputPartitions(): Array[InputPartition] = { val dataPath = new Path(path) val fs = dataPath.getFileSystem(conf) if (fs.exists(dataPath)) { @@ -66,7 +64,7 @@ class SimpleWritableDataSource extends DataSourceV2 } } - override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + override def createReaderFactory(): PartitionReaderFactory = { val serializableConf = new SerializableConfiguration(conf) new CSVReaderFactory(serializableConf) } @@ -105,10 +103,10 @@ class SimpleWritableDataSource extends DataSourceV2 } } - override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { + override def getTable(options: DataSourceOptions): Table = { val path = new Path(options.get("path").get()) val conf = SparkContext.getActive.get.hadoopConfiguration - new ReadSupport(path.toUri.toString, conf) + new MyTable(path.toUri.toString, conf) } override def createBatchWriteSupport( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index f55ddb5419d2..406d29474776 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -30,20 +30,27 @@ import org.apache.hadoop.conf.Configuration import org.scalatest.time.SpanSugar._ import org.apache.spark.{SparkConf, SparkContext, TaskContext} +import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.plans.logical.Range +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericInternalRow, UnsafeProjection} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Range} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProvider} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.StreamSourceProvider +import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.reader._ +import 
org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchInputStream, MicroBatchScan, Offset => OffsetV2} import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils class StreamSuite extends StreamTest { @@ -102,12 +109,10 @@ class StreamSuite extends StreamTest { } test("StreamingExecutionRelation.computeStats") { - val streamingExecutionRelation = MemoryStream[Int].toDF.logicalPlan collect { - case s: StreamingExecutionRelation => s - } - assert(streamingExecutionRelation.nonEmpty, "cannot find StreamingExecutionRelation") - assert(streamingExecutionRelation.head.computeStats.sizeInBytes - == spark.sessionState.conf.defaultSizeInBytes) + val memoryStream = MemoryStream[Int] + val executionRelation = StreamingExecutionRelation( + memoryStream, memoryStream.encoder.schema.toAttributes)(memoryStream.sqlContext.sparkSession) + assert(executionRelation.computeStats.sizeInBytes == spark.sessionState.conf.defaultSizeInBytes) } test("explain join with a normal source") { @@ -154,21 +159,25 @@ class StreamSuite extends StreamTest { } test("SPARK-20432: union one stream with itself") { - val df = spark.readStream.format(classOf[FakeDefaultSource].getName).load().select("a") - val unioned = df.union(df) - withTempDir { outputDir => - withTempDir { checkpointDir => - val query = - unioned - .writeStream.format("parquet") - .option("checkpointLocation", checkpointDir.getAbsolutePath) - .start(outputDir.getAbsolutePath) - try { - query.processAllAvailable() - val outputDf = spark.read.parquet(outputDir.getAbsolutePath).as[Long] - checkDatasetUnorderly[Long](outputDf, (0L to 10L).union((0L to 10L)).toArray: _*) - } finally { - query.stop() + val v1Source = spark.readStream.format(classOf[FakeDefaultSource].getName).load().select("a") + val v2Source = spark.readStream.format(classOf[FakeFormat].getName).load().select("a") + + Seq(v1Source, v2Source).foreach { df => + val unioned = df.union(df) + withTempDir { outputDir => + withTempDir { checkpointDir => + val query = + unioned + .writeStream.format("parquet") + .option("checkpointLocation", checkpointDir.getAbsolutePath) + .start(outputDir.getAbsolutePath) + try { + query.processAllAvailable() + val outputDf = spark.read.parquet(outputDir.getAbsolutePath).as[Long] + checkDatasetUnorderly[Long](outputDf, (0L to 10L).union((0L to 10L)).toArray: _*) + } finally { + query.stop() + } } } } @@ -381,7 +390,7 @@ class StreamSuite extends StreamTest { test("insert an extraStrategy") { try { - spark.experimental.extraStrategies = TestStrategy :: Nil + spark.experimental.extraStrategies = CustomStrategy :: Nil val inputData = MemoryStream[(String, Int)] val df = inputData.toDS().map(_._1).toDF("a") @@ -495,9 +504,9 @@ class StreamSuite extends StreamTest { val explainWithoutExtended = q.explainInternal(false) // `extended = false` only displays the physical plan. 
- assert("Streaming RelationV2 MemoryStreamDataSource".r + assert("Streaming RelationV2 MemoryStreamSource".r .findAllMatchIn(explainWithoutExtended).size === 0) - assert("ScanV2 MemoryStreamDataSource".r + assert("ScanV2 MemoryStreamSource".r .findAllMatchIn(explainWithoutExtended).size === 1) // Use "StateStoreRestore" to verify that it does output a streaming physical plan assert(explainWithoutExtended.contains("StateStoreRestore")) @@ -505,9 +514,9 @@ class StreamSuite extends StreamTest { val explainWithExtended = q.explainInternal(true) // `extended = true` displays 3 logical plans (Parsed/Optimized/Optimized) and 1 physical // plan. - assert("Streaming RelationV2 MemoryStreamDataSource".r + assert("Streaming RelationV2 MemoryStreamSource".r .findAllMatchIn(explainWithExtended).size === 3) - assert("ScanV2 MemoryStreamDataSource".r + assert("ScanV2 MemoryStreamSource".r .findAllMatchIn(explainWithExtended).size === 1) // Use "StateStoreRestore" to verify that it does output a streaming physical plan assert(explainWithExtended.contains("StateStoreRestore")) @@ -550,17 +559,17 @@ class StreamSuite extends StreamTest { val explainWithoutExtended = q.explainInternal(false) // `extended = false` only displays the physical plan. - assert("Streaming RelationV2 ContinuousMemoryStream".r + assert("Streaming RelationV2 MemoryStreamSource".r .findAllMatchIn(explainWithoutExtended).size === 0) - assert("ScanV2 ContinuousMemoryStream".r + assert("ScanV2 MemoryStreamSource".r .findAllMatchIn(explainWithoutExtended).size === 1) val explainWithExtended = q.explainInternal(true) // `extended = true` displays 3 logical plans (Parsed/Optimized/Optimized) and 1 physical // plan. - assert("Streaming RelationV2 ContinuousMemoryStream".r + assert("Streaming RelationV2 MemoryStreamSource".r .findAllMatchIn(explainWithExtended).size === 3) - assert("ScanV2 ContinuousMemoryStream".r + assert("ScanV2 MemoryStreamSource".r .findAllMatchIn(explainWithExtended).size === 1) } finally { q.stop() @@ -1137,6 +1146,67 @@ class FakeDefaultSource extends FakeSource { } } +// Similar to `FakeDefaultSource`, but with v2 source API. 
+class FakeFormat extends Format { + override def getTable(options: DataSourceOptions): Table = { + new SupportsMicroBatchRead { + override def createMicroBatchInputStream( + checkpointLocation: String, + config: ScanConfig, + options: DataSourceOptions): MicroBatchInputStream = { + FakeMicroBatchInputStream + } + + override def schema(): StructType = StructType(StructField("a", IntegerType) :: Nil) + } + } + + object FakeMicroBatchInputStream extends MicroBatchInputStream { + override def createMicroBatchScan(start: OffsetV2, end: OffsetV2): MicroBatchScan = { + val s = start.asInstanceOf[LongOffset].offset.toInt + val e = end.asInstanceOf[LongOffset].offset.toInt + new FakeMicroBatchReadSupport(s, e) + } + + override def latestOffset(): OffsetV2 = LongOffset(10) + + override def initialOffset(): OffsetV2 = LongOffset(0) + + override def deserializeOffset(json: String): OffsetV2 = { + LongOffset(json.toLong) + } + + override def commit(end: OffsetV2): Unit = {} + + override def stop(): Unit = {} + } + + class FakeMicroBatchReadSupport(start: Int, end: Int) extends MicroBatchScan { + override def createReaderFactory(): PartitionReaderFactory = { + new PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val RangeInputPartition(start, end) = partition + new PartitionReader[InternalRow] { + var current = start - 1 + override def next(): Boolean = { + current += 1 + current <= end + } + + override def get(): InternalRow = InternalRow(current) + + override def close(): Unit = {} + } + } + } + } + + override def planInputPartitions(): Array[InputPartition] = { + Array(RangeInputPartition(start, end)) + } + } +} + /** A fake source that throws the same IOException like pre Hadoop 2.8 when it's interrupted. */ class ThrowingIOExceptionLikeHadoop12074 extends FakeSource { import ThrowingIOExceptionLikeHadoop12074._ @@ -1244,3 +1314,23 @@ object ThrowingExceptionInCreateSource { @volatile var createSourceLatch: CountDownLatch = null @volatile var exception: Exception = null } + +object CustomStrategy extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case Project(Seq(attr), child) if attr.name == "a" => + CustomProjectExec(Seq(attr.toAttribute), planLater(child)) :: Nil + case _ => Nil + } +} + +case class CustomProjectExec(output: Seq[Attribute], child: SparkPlan) extends UnaryExecNode { + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitions { it => + val str = UTF8String.fromString("so fast") + val row = new GenericInternalRow(Array[Any](str)) + val unsafeProj = UnsafeProjection.create(schema) + val unsafeRow = unsafeProj(row) + it.map(_ => unsafeRow) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index d878c345c298..074de90d4fcc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -688,8 +688,14 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def findSourceIndex(plan: LogicalPlan): Option[Int] = { plan .collect { + // v1 source case r: StreamingExecutionRelation => r.source - case r: StreamingDataSourceV2Relation => r.readSupport + // v2 source + case r: StreamingDataSourceV2Relation => r.stream + // We can add data to memory stream before starting it. 
Then the input plan has + // not been processed by the streaming engine and contains `StreamingRelationV2`. + case r: StreamingRelationV2 if r.sourceName == "memory" => + r.table.asInstanceOf[MemoryStreamTable].stream } .zipWithIndex .find(_._1 == source) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala index 46eec736d402..13b8866c22b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala @@ -24,15 +24,14 @@ import scala.util.Random import scala.util.control.NonFatal import org.scalatest.BeforeAndAfter -import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.Span import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkException -import org.apache.spark.sql.{AnalysisException, Dataset} +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.BlockingSource import org.apache.spark.util.Utils @@ -304,8 +303,8 @@ class StreamingQueryManagerSuite extends StreamTest with BeforeAndAfter { if (withError) { logDebug(s"Terminating query ${queryToStop.name} with error") queryToStop.asInstanceOf[StreamingQueryWrapper].streamingQuery.logicalPlan.collect { - case StreamingExecutionRelation(source, _) => - source.asInstanceOf[MemoryStream[Int]].addData(0) + case r: StreamingDataSourceV2Relation => + r.stream.asInstanceOf[MemoryStream[Int]].addData(0) } } else { logDebug(s"Stopping query ${queryToStop.name}") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index c170641372d6..92e9186241cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -36,8 +36,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.TestForeachWriter import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.reader.{InputPartition, ScanConfig} -import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchScan, Offset => OffsetV2} import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider, StreamManualClock} import org.apache.spark.sql.types.StructType @@ -220,10 +219,10 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } // getBatch should take 100 ms the first time it is called - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + override def createMicroBatchScan(start: OffsetV2, end: OffsetV2): MicroBatchScan = { synchronized { clock.waitTillTime(1150) - super.planInputPartitions(config) + super.createMicroBatchScan(start, end) } } } @@ -906,12 +905,12 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(df.logicalPlan.toJSON.contains("StreamingRelationV2")) testStream(df)( - 
AssertOnQuery(_.logicalPlan.toJSON.contains("StreamingExecutionRelation")) + AssertOnQuery(_.logicalPlan.toJSON.contains("StreamingDataSourceV2Relation")) ) testStream(df, useV2Sink = true)( StartStream(trigger = Trigger.Continuous(100)), - AssertOnQuery(_.logicalPlan.toJSON.contains("ContinuousExecutionRelation")) + AssertOnQuery(_.logicalPlan.toJSON.contains("StreamingDataSourceV2Relation")) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala index d6819eacd07c..286675f68654 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousPartitionReader, ContinuousReadSupport, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputStream, ContinuousPartitionReader, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.types.{DataType, IntegerType, StructType} @@ -44,7 +44,7 @@ class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { super.beforeEach() epochEndpoint = EpochCoordinatorRef.create( mock[StreamingWriteSupport], - mock[ContinuousReadSupport], + mock[ContinuousInputStream], mock[ContinuousExecution], coordinatorId, startEpoch, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index 3d21bc63e0cc..f54970576b13 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.streaming.continuous import org.apache.spark.{SparkContext, SparkException} import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} import org.apache.spark.sql._ -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec +import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream @@ -40,13 +40,15 @@ class ContinuousSuiteBase extends StreamTest { query match { case s: ContinuousExecution => assert(numTriggers >= 2, "must wait for at least 2 triggers to ensure query is initialized") - val reader = s.lastExecution.executedPlan.collectFirst { - case DataSourceV2ScanExec(_, _, _, _, r: RateStreamContinuousReadSupport, _) => r + val stream = s.lastExecution.logical.collectFirst { + case r: StreamingDataSourceV2Relation + if r.stream.isInstanceOf[RateStreamContinuousInputStream] => + r.stream.asInstanceOf[RateStreamContinuousInputStream] }.get val deltaMs = numTriggers * 1000 + 300 - while (System.currentTimeMillis < 
reader.creationTime + deltaMs) { - Thread.sleep(reader.creationTime + deltaMs - System.currentTimeMillis) + while (System.currentTimeMillis < stream.creationTime + deltaMs) { + Thread.sleep(stream.creationTime + deltaMs - System.currentTimeMillis) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala index 3c973d8ebc70..60f58082347f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark._ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.LocalSparkSession import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputStream, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport import org.apache.spark.sql.test.TestSparkSession @@ -45,7 +45,7 @@ class EpochCoordinatorSuite private var orderVerifier: InOrder = _ override def beforeEach(): Unit = { - val reader = mock[ContinuousReadSupport] + val inputStream = mock[ContinuousInputStream] writeSupport = mock[StreamingWriteSupport] query = mock[ContinuousExecution] orderVerifier = inOrder(writeSupport, query) @@ -53,7 +53,7 @@ class EpochCoordinatorSuite spark = new TestSparkSession() epochCoordinator - = EpochCoordinatorRef.create(writeSupport, reader, query, "test", 1, spark, SparkEnv.get) + = EpochCoordinatorRef.create(writeSupport, inputStream, query, "test", 1, spark, SparkEnv.get) } test("single epoch") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index 3a0e780a7391..b99dd32f5b22 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -24,50 +24,49 @@ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.reader.{InputPartition, PartitionReaderFactory, ScanConfig, ScanConfigBuilder} +import org.apache.spark.sql.sources.v2.reader.ScanConfig import org.apache.spark.sql.sources.v2.reader.streaming._ import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery, StreamTest, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils -case class FakeReadSupport() extends MicroBatchReadSupport with ContinuousReadSupport { +class FakeInputStream extends MicroBatchInputStream with ContinuousInputStream { override def deserializeOffset(json: String): Offset = RateStreamOffset(Map()) override def commit(end: Offset): Unit = {} override def stop(): Unit = {} override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = RateStreamOffset(Map()) - override def 
fullSchema(): StructType = StructType(Seq()) - override def newScanConfigBuilder(start: Offset, end: Offset): ScanConfigBuilder = null override def initialOffset(): Offset = RateStreamOffset(Map()) override def latestOffset(): Offset = RateStreamOffset(Map()) - override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = null - override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + override def createMicroBatchScan(start: Offset, end: Offset): MicroBatchScan = { throw new IllegalStateException("fake source - cannot actually read") } - override def createContinuousReaderFactory( - config: ScanConfig): ContinuousPartitionReaderFactory = { - throw new IllegalStateException("fake source - cannot actually read") - } - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + override def createContinuousScan(start: Offset): ContinuousScan = { throw new IllegalStateException("fake source - cannot actually read") } } -trait FakeMicroBatchReadSupportProvider extends MicroBatchReadSupportProvider { - override def createMicroBatchReadSupport( +trait FakeMicroBatchReadTable extends Table with SupportsMicroBatchRead { + override def schema(): StructType = StructType(Seq()) + + override def createMicroBatchInputStream( checkpointLocation: String, - options: DataSourceOptions): MicroBatchReadSupport = { + config: ScanConfig, + options: DataSourceOptions): MicroBatchInputStream = { LastReadOptions.options = options - FakeReadSupport() + new FakeInputStream } } -trait FakeContinuousReadSupportProvider extends ContinuousReadSupportProvider { - override def createContinuousReadSupport( +trait FakeContinuousReadTable extends Table with SupportsContinuousRead { + override def schema(): StructType = StructType(Seq()) + + override def createContinuousInputStream( checkpointLocation: String, - options: DataSourceOptions): ContinuousReadSupport = { + config: ScanConfig, + options: DataSourceOptions): ContinuousInputStream = { LastReadOptions.options = options - FakeReadSupport() + new FakeInputStream } } @@ -82,31 +81,43 @@ trait FakeStreamingWriteSupportProvider extends StreamingWriteSupportProvider { } } -class FakeReadMicroBatchOnly - extends DataSourceRegister - with FakeMicroBatchReadSupportProvider - with SessionConfigSupport { +class FakeReadMicroBatchOnly extends Format with DataSourceRegister with SessionConfigSupport { override def shortName(): String = "fake-read-microbatch-only" override def keyPrefix: String = shortName() + + override def getTable(options: DataSourceOptions): Table = { + new FakeMicroBatchReadTable {} + } } -class FakeReadContinuousOnly - extends DataSourceRegister - with FakeContinuousReadSupportProvider - with SessionConfigSupport { +class FakeReadContinuousOnly extends Format with DataSourceRegister with SessionConfigSupport { override def shortName(): String = "fake-read-continuous-only" override def keyPrefix: String = shortName() + + override def getTable(options: DataSourceOptions): Table = { + new FakeContinuousReadTable {} + } } -class FakeReadBothModes extends DataSourceRegister - with FakeMicroBatchReadSupportProvider with FakeContinuousReadSupportProvider { +class FakeReadBothModes extends Format with DataSourceRegister { override def shortName(): String = "fake-read-microbatch-continuous" + + override def getTable(options: DataSourceOptions): Table = { + new Table + with FakeMicroBatchReadTable with FakeContinuousReadTable {} + } } -class FakeReadNeitherMode extends DataSourceRegister { +class 
FakeReadNeitherMode extends Format with DataSourceRegister { override def shortName(): String = "fake-read-neither-mode" + + override def getTable(options: DataSourceOptions): Table = { + new Table { + override def schema(): StructType = StructType(Nil) + } + } } class FakeWriteSupportProvider @@ -299,23 +310,24 @@ class StreamingDataSourceV2Suite extends StreamTest { for ((read, write, trigger) <- cases) { testQuietly(s"stream with read format $read, write format $write, trigger $trigger") { - val readSource = DataSource.lookupDataSource(read, spark.sqlContext.conf).newInstance() + val readSource = DataSource.lookupDataSource(read, spark.sqlContext.conf) + .newInstance().asInstanceOf[Format].getTable(DataSourceOptions.empty()) val writeSource = DataSource.lookupDataSource(write, spark.sqlContext.conf).newInstance() (readSource, writeSource, trigger) match { // Valid microbatch queries. - case (_: MicroBatchReadSupportProvider, _: StreamingWriteSupportProvider, t) + case (_: SupportsMicroBatchRead, _: StreamingWriteSupportProvider, t) if !t.isInstanceOf[ContinuousTrigger] => testPositiveCase(read, write, trigger) // Valid continuous queries. - case (_: ContinuousReadSupportProvider, _: StreamingWriteSupportProvider, + case (_: SupportsContinuousRead, _: StreamingWriteSupportProvider, _: ContinuousTrigger) => testPositiveCase(read, write, trigger) // Invalid - can't read at all case (r, _, _) - if !r.isInstanceOf[MicroBatchReadSupportProvider] - && !r.isInstanceOf[ContinuousReadSupportProvider] => + if !r.isInstanceOf[SupportsMicroBatchRead] + && !r.isInstanceOf[SupportsContinuousRead] => testNegativeCase(read, write, trigger, s"Data source $read does not support streamed reading") @@ -326,13 +338,13 @@ class StreamingDataSourceV2Suite extends StreamTest { // Invalid - trigger is continuous but reader is not case (r, _: StreamingWriteSupportProvider, _: ContinuousTrigger) - if !r.isInstanceOf[ContinuousReadSupportProvider] => + if !r.isInstanceOf[SupportsContinuousRead] => testNegativeCase(read, write, trigger, s"Data source $read does not support continuous processing") // Invalid - trigger is microbatch but reader is not case (r, _, t) - if !r.isInstanceOf[MicroBatchReadSupportProvider] && + if !r.isInstanceOf[SupportsMicroBatchRead] && !t.isInstanceOf[ContinuousTrigger] => testPostCreationNegativeCase(read, write, trigger, s"Data source $read does not support microbatch processing")
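
Editor's note, not part of the patch: the batch-read half of the refactoring is easiest to see end to end. The sketch below is a hypothetical source written against the interfaces the suites above exercise (Format.getTable, a Table mixing in SupportsBatchRead, BatchScan.planInputPartitions/createReaderFactory). Every name prefixed "Tiny" is invented for illustration; the signatures are copied from the test code in this diff and may still shift while the refactoring is under review.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.{DataSourceOptions, Format, SupportsBatchRead, Table}
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.types.StructType

// A single-partition batch source, following the same shortcut as SimpleBatchReadTable
// above: the Table acts as its own BatchScan.
class TinyBatchSource extends Format {

  class TinyTable extends Table with SupportsBatchRead with BatchScan {
    override def schema(): StructType = new StructType().add("i", "int").add("j", "int")

    // No pushdown: hand back this table as the scan for whatever ScanConfig was built.
    override def createBatchScan(config: ScanConfig, options: DataSourceOptions): BatchScan = this

    override def planInputPartitions(): Array[InputPartition] = Array(TinyPartition(0, 5))

    override def createReaderFactory(): PartitionReaderFactory = TinyReaderFactory
  }

  override def getTable(options: DataSourceOptions): Table = new TinyTable
}

case class TinyPartition(start: Int, end: Int) extends InputPartition

object TinyReaderFactory extends PartitionReaderFactory {
  override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
    val TinyPartition(start, end) = partition
    new PartitionReader[InternalRow] {
      private var current = start - 1
      override def next(): Boolean = { current += 1; current < end }
      override def get(): InternalRow = InternalRow(current, -current)
      override def close(): Unit = {}
    }
  }
}

Loading it with spark.read.format(classOf[TinyBatchSource].getName).load() should plan a single DataSourceV2ScanExec over one partition, the same path DataSourceV2Suite checks for SimpleSinglePartitionSource.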
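
For the streaming side, FakeFormat above shows the shape of a micro-batch source but not who calls what. The sketch below spells out the call sequence a micro-batch engine is expected to drive against the new interfaces. The method names all appear in this diff; the driver loop itself is a reconstruction of what MicroBatchExecution does (the helper name runOneMicroBatch is invented, and the println stands in for real row handling).

import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsMicroBatchRead}
import org.apache.spark.sql.sources.v2.reader.ScanConfig

object MicroBatchLifecycleSketch {
  def runOneMicroBatch(
      table: SupportsMicroBatchRead,
      checkpointLocation: String,
      config: ScanConfig,
      options: DataSourceOptions): Unit = {
    // One input stream per query run...
    val stream = table.createMicroBatchInputStream(checkpointLocation, config, options)
    val start = stream.initialOffset()   // or an offset recovered from the checkpoint log
    val end = stream.latestOffset()
    // ...and one scan per micro-batch, bounded by the two offsets.
    val scan = stream.createMicroBatchScan(start, end)
    val readerFactory = scan.createReaderFactory()
    scan.planInputPartitions().foreach { partition =>
      val reader = readerFactory.createReader(partition)
      while (reader.next()) println(reader.get())  // a real engine turns these rows into an RDD
      reader.close()
    }
    stream.commit(end)                   // only after the batch has been fully processed
    stream.stop()                        // when the query shuts down
  }
}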
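
The continuous path is exercised directly by TextSocketStreamSuite above; the sketch below just isolates those calls (stream.createContinuousScan, scan.createContinuousReaderFactory, scan.planInputPartitions). The helper name pollContinuous is invented, and the real orchestration of epochs and commits belongs to ContinuousExecution and the EpochCoordinator, which this sketch omits.

import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousInputStream

object ContinuousLifecycleSketch {
  def pollContinuous(stream: ContinuousInputStream): Unit = {
    val scan = stream.createContinuousScan(stream.initialOffset())
    val readerFactory = scan.createContinuousReaderFactory()
    scan.planInputPartitions().foreach { partition =>
      val reader = readerFactory.createReader(partition)
      // In a real query each task keeps reading indefinitely and reports its partition offset
      // to the epoch coordinator, which eventually calls stream.commit(); here we only pull a
      // handful of rows to show the reader contract.
      var remaining = 5
      while (remaining > 0 && reader.next()) {
        println(reader.get())
        remaining -= 1
      }
      reader.close()
    }
  }
}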