@@ -25,15 +25,16 @@ import org.apache.kafka.common.TopicPartition
 
 import org.apache.spark.TaskContext
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
 import org.apache.spark.sql.sources.v2.reader._
-import org.apache.spark.sql.sources.v2.reader.streaming._
+import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset}
 import org.apache.spark.sql.types.StructType
 
 /**
- * A [[ContinuousReadSupport]] for data from kafka.
+ * A [[ContinuousReader]] for data from kafka.
  *
  * @param offsetReader  a reader used to get kafka offsets. Note that the actual data will be
  *                      read by per-task consumers generated later.
@@ -46,49 +47,70 @@ import org.apache.spark.sql.types.StructType
  *                       scenarios, where some offsets after the specified initial ones can't be
  *                       properly read.
  */
-class KafkaContinuousReadSupport(
+class KafkaContinuousReader(
     offsetReader: KafkaOffsetReader,
     kafkaParams: ju.Map[String, Object],
     sourceOptions: Map[String, String],
     metadataPath: String,
     initialOffsets: KafkaOffsetRangeLimit,
     failOnDataLoss: Boolean)
-  extends ContinuousReadSupport with Logging {
+  extends ContinuousReader with Logging {
+
+  private lazy val session = SparkSession.getActiveSession.get
+  private lazy val sc = session.sparkContext
 
   private val pollTimeoutMs = sourceOptions.getOrElse("kafkaConsumer.pollTimeoutMs", "512").toLong
 
-  override def initialOffset(): Offset = {
-    val offsets = initialOffsets match {
-      case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets())
-      case LatestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchLatestOffsets())
-      case SpecificOffsetRangeLimit(p) => offsetReader.fetchSpecificOffsets(p, reportDataLoss)
-    }
-    logInfo(s"Initial offsets: $offsets")
-    offsets
-  }
+  // Initialized when creating reader factories. If this diverges from the partitions at the latest
+  // offsets, we need to reconfigure.
+  // Exposed outside this object only for unit tests.
+  @volatile private[sql] var knownPartitions: Set[TopicPartition] = _
 
-  override def fullSchema(): StructType = KafkaOffsetReader.kafkaSchema
+  override def readSchema: StructType = KafkaOffsetReader.kafkaSchema
 
-  override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = {
-    new KafkaContinuousScanConfigBuilder(fullSchema(), start, offsetReader, reportDataLoss)
+  private var offset: Offset = _
+  override def setStartOffset(start: ju.Optional[Offset]): Unit = {
+    offset = start.orElse {
+      val offsets = initialOffsets match {
+        case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets())
+        case LatestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchLatestOffsets())
+        case SpecificOffsetRangeLimit(p) => offsetReader.fetchSpecificOffsets(p, reportDataLoss)
+      }
+      logInfo(s"Initial offsets: $offsets")
+      offsets
+    }
   }
 
+  override def getStartOffset(): Offset = offset
+
   override def deserializeOffset(json: String): Offset = {
     KafkaSourceOffset(JsonUtils.partitionOffsets(json))
   }
 
-  override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
-    val startOffsets = config.asInstanceOf[KafkaContinuousScanConfig].startOffsets
+  override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = {
+    import scala.collection.JavaConverters._
+
+    val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset)
+
+    val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet
+    val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet)
+    val newPartitionOffsets = offsetReader.fetchEarliestOffsets(newPartitions.toSeq)
+
+    val deletedPartitions = oldStartPartitionOffsets.keySet.diff(currentPartitionSet)
+    if (deletedPartitions.nonEmpty) {
+      reportDataLoss(s"Some partitions were deleted: $deletedPartitions")
+    }
+
+    val startOffsets = newPartitionOffsets ++
+      oldStartPartitionOffsets.filterKeys(!deletedPartitions.contains(_))
+    knownPartitions = startOffsets.keySet
+
     startOffsets.toSeq.map {
       case (topicPartition, start) =>
         KafkaContinuousInputPartition(
-          topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss)
-    }.toArray
-  }
-
-  override def createContinuousReaderFactory(
-      config: ScanConfig): ContinuousPartitionReaderFactory = {
-    KafkaContinuousReaderFactory
+          topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss
+        ): InputPartition[InternalRow]
+    }.asJava
   }
 
   /** Stop this source and free any resources it has allocated. */
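
Note: the partition bookkeeping that previously lived in `KafkaContinuousScanConfigBuilder.build()` now happens inline in `planInputPartitions()`, which also records `knownPartitions` for the reconfiguration check further down. A minimal standalone sketch of that merge, using plain Scala collections and made-up partition names (not the actual Spark code):

```scala
// Standalone sketch of the merge performed by the new planInputPartitions() above.
// Plain collections only; partition names, offsets and the earliestOffsetFor helper
// are made up for illustration and are not part of the Spark/Kafka APIs.
object PartitionMergeSketch {
  def mergeStartOffsets(
      oldStartOffsets: Map[String, Long],   // offsets the query was already reading from
      currentPartitions: Set[String],       // partitions Kafka reports right now
      earliestOffsetFor: String => Long     // stand-in for offsetReader.fetchEarliestOffsets
  ): Map[String, Long] = {
    // Partitions that appeared since the last plan start from their earliest offset.
    val newPartitions = currentPartitions.diff(oldStartOffsets.keySet)
    val newPartitionOffsets = newPartitions.map(p => p -> earliestOffsetFor(p)).toMap

    // Partitions that disappeared are dropped (the real reader calls reportDataLoss here).
    val deletedPartitions = oldStartOffsets.keySet.diff(currentPartitions)

    newPartitionOffsets ++ oldStartOffsets.filter { case (p, _) => !deletedPartitions.contains(p) }
  }

  def main(args: Array[String]): Unit = {
    val merged = mergeStartOffsets(
      oldStartOffsets = Map("t-0" -> 42L, "t-1" -> 7L),
      currentPartitions = Set("t-0", "t-2"),   // t-1 was deleted, t-2 was added
      earliestOffsetFor = _ => 0L)
    println(merged)   // Map(t-2 -> 0, t-0 -> 42)
  }
}
```

Deleted partitions are dropped from the start offsets after data loss is reported, while brand-new partitions are picked up from their earliest available offset.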
@@ -105,9 +127,8 @@ class KafkaContinuousReadSupport(
     KafkaSourceOffset(mergedMap)
   }
 
-  override def needsReconfiguration(config: ScanConfig): Boolean = {
-    val knownPartitions = config.asInstanceOf[KafkaContinuousScanConfig].knownPartitions
-    offsetReader.fetchLatestOffsets().keySet != knownPartitions
+  override def needsReconfiguration(): Boolean = {
+    knownPartitions != null && offsetReader.fetchLatestOffsets().keySet != knownPartitions
   }
 
   override def toString(): String = s"KafkaSource[$offsetReader]"
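
The reconfiguration check now reads the reader's `@volatile knownPartitions` field instead of a `ScanConfig`; the null guard is needed because that field is only assigned once `planInputPartitions()` has run. A toy illustration of the check (plain Scala, invented topic names, not the actual class):

```scala
// Toy illustration of the needsReconfiguration() guard above. The fetchLatestPartitions
// function and topic names are invented; only the null-then-compare logic is the point.
class ReconfigurationCheckSketch(fetchLatestPartitions: () => Set[String]) {
  // Mirrors the reader's @volatile knownPartitions field, unset until the first plan.
  @volatile var knownPartitions: Set[String] = _

  def needsReconfiguration(): Boolean =
    knownPartitions != null && fetchLatestPartitions() != knownPartitions
}

object ReconfigurationCheckSketch {
  def main(args: Array[String]): Unit = {
    val check = new ReconfigurationCheckSketch(() => Set("t-0", "t-1"))
    println(check.needsReconfiguration())   // false: nothing has been planned yet
    check.knownPartitions = Set("t-0")
    println(check.needsReconfiguration())   // true: a partition appeared since planning
  }
}
```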
@@ -141,51 +162,23 @@ case class KafkaContinuousInputPartition(
     startOffset: Long,
     kafkaParams: ju.Map[String, Object],
     pollTimeoutMs: Long,
-    failOnDataLoss: Boolean) extends InputPartition
-
-object KafkaContinuousReaderFactory extends ContinuousPartitionReaderFactory {
-  override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = {
-    val p = partition.asInstanceOf[KafkaContinuousInputPartition]
-    new KafkaContinuousPartitionReader(
-      p.topicPartition, p.startOffset, p.kafkaParams, p.pollTimeoutMs, p.failOnDataLoss)
+    failOnDataLoss: Boolean) extends ContinuousInputPartition[InternalRow] {
+
+  override def createContinuousReader(
+      offset: PartitionOffset): InputPartitionReader[InternalRow] = {
+    val kafkaOffset = offset.asInstanceOf[KafkaSourcePartitionOffset]
+    require(kafkaOffset.topicPartition == topicPartition,
+      s"Expected topicPartition: $topicPartition, but got: ${kafkaOffset.topicPartition}")
+    new KafkaContinuousInputPartitionReader(
+      topicPartition, kafkaOffset.partitionOffset, kafkaParams, pollTimeoutMs, failOnDataLoss)
   }
-}
-
-class KafkaContinuousScanConfigBuilder(
-    schema: StructType,
-    startOffset: Offset,
-    offsetReader: KafkaOffsetReader,
-    reportDataLoss: String => Unit)
-  extends ScanConfigBuilder {
-
-  override def build(): ScanConfig = {
-    val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(startOffset)
-
-    val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet
-    val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet)
-    val newPartitionOffsets = offsetReader.fetchEarliestOffsets(newPartitions.toSeq)
 
-    val deletedPartitions = oldStartPartitionOffsets.keySet.diff(currentPartitionSet)
-    if (deletedPartitions.nonEmpty) {
-      reportDataLoss(s"Some partitions were deleted: $deletedPartitions")
-    }
-
-    val startOffsets = newPartitionOffsets ++
-      oldStartPartitionOffsets.filterKeys(!deletedPartitions.contains(_))
-    KafkaContinuousScanConfig(schema, startOffsets)
+  override def createPartitionReader(): KafkaContinuousInputPartitionReader = {
+    new KafkaContinuousInputPartitionReader(
+      topicPartition, startOffset, kafkaParams, pollTimeoutMs, failOnDataLoss)
   }
 }
 
-case class KafkaContinuousScanConfig(
-    readSchema: StructType,
-    startOffsets: Map[TopicPartition, Long])
-  extends ScanConfig {
-
-  // Created when building the scan config builder. If this diverges from the partitions at the
-  // latest offsets, we need to reconfigure the kafka read support.
-  def knownPartitions: Set[TopicPartition] = startOffsets.keySet
-}
-
 /**
  * A per-task data reader for continuous Kafka processing.
  *
@@ -196,12 +189,12 @@ case class KafkaContinuousScanConfig(
  * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets
  *                       are skipped.
  */
-class KafkaContinuousPartitionReader(
+class KafkaContinuousInputPartitionReader(
     topicPartition: TopicPartition,
     startOffset: Long,
     kafkaParams: ju.Map[String, Object],
     pollTimeoutMs: Long,
-    failOnDataLoss: Boolean) extends ContinuousPartitionReader[InternalRow] {
+    failOnDataLoss: Boolean) extends ContinuousInputPartitionReader[InternalRow] {
   private val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams, useCache = false)
   private val converter = new KafkaRecordToUnsafeRowConverter
 
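
The rest of `KafkaContinuousInputPartitionReader` is cut off here; it wraps a non-cached `KafkaDataConsumer` and an `UnsafeRow` converter. As a rough, self-contained illustration of the per-partition reader contract it implements (a toy trait standing in for Spark's `ContinuousInputPartitionReader`; the blocking poll and real Kafka offset handling are not reproduced):

```scala
// Toy mirror of the continuous per-partition reader contract: next() advances to the next
// record, get() returns it, getOffset reports how far this task has read so the epoch
// checkpoint can resume per partition. The trait and sample data are illustrative only.
trait ToyContinuousPartitionReader[T] {
  def next(): Boolean   // advance; a real continuous reader blocks until a record arrives
  def get(): T          // the record made current by the last successful next()
  def getOffset: Long   // offset to record in the current epoch's checkpoint
  def close(): Unit
}

class ToyPartitionReader(records: Iterator[(Long, String)])
  extends ToyContinuousPartitionReader[String] {

  private var current: (Long, String) = _

  override def next(): Boolean = {
    if (!records.hasNext) return false   // toy shortcut; a continuous reader would keep polling
    current = records.next()
    true
  }
  override def get(): String = current._2
  override def getOffset: Long = current._1 + 1   // "next offset to read", as Kafka offsets work
  override def close(): Unit = ()
}

object ToyPartitionReaderExample {
  def main(args: Array[String]): Unit = {
    val reader = new ToyPartitionReader(Iterator(0L -> "a", 1L -> "b"))
    while (reader.next()) println(s"offset=${reader.getOffset} value=${reader.get()}")
    reader.close()
  }
}
```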