Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,15 @@ import org.apache.kafka.common.TopicPartition

import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset}
import org.apache.spark.sql.sources.v2.reader.streaming._
import org.apache.spark.sql.types.StructType

/**
* A [[ContinuousReader]] for data from kafka.
* A [[ContinuousReadSupport]] for data from kafka.
*
* @param offsetReader a reader used to get kafka offsets. Note that the actual data will be
* read by per-task consumers generated later.
Expand All @@ -47,70 +46,49 @@ import org.apache.spark.sql.types.StructType
* scenarios, where some offsets after the specified initial ones can't be
* properly read.
*/
class KafkaContinuousReader(
class KafkaContinuousReadSupport(
offsetReader: KafkaOffsetReader,
kafkaParams: ju.Map[String, Object],
sourceOptions: Map[String, String],
metadataPath: String,
initialOffsets: KafkaOffsetRangeLimit,
failOnDataLoss: Boolean)
extends ContinuousReader with Logging {

private lazy val session = SparkSession.getActiveSession.get
private lazy val sc = session.sparkContext
extends ContinuousReadSupport with Logging {

private val pollTimeoutMs = sourceOptions.getOrElse("kafkaConsumer.pollTimeoutMs", "512").toLong

// Initialized when creating reader factories. If this diverges from the partitions at the latest
// offsets, we need to reconfigure.
// Exposed outside this object only for unit tests.
@volatile private[sql] var knownPartitions: Set[TopicPartition] = _
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved to KafkaContinuousScanConfig


override def readSchema: StructType = KafkaOffsetReader.kafkaSchema

private var offset: Offset = _
override def setStartOffset(start: ju.Optional[Offset]): Unit = {
offset = start.orElse {
val offsets = initialOffsets match {
case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets())
case LatestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchLatestOffsets())
case SpecificOffsetRangeLimit(p) => offsetReader.fetchSpecificOffsets(p, reportDataLoss)
}
logInfo(s"Initial offsets: $offsets")
offsets
override def initialOffset(): Offset = {
val offsets = initialOffsets match {
case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets())
case LatestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchLatestOffsets())
case SpecificOffsetRangeLimit(p) => offsetReader.fetchSpecificOffsets(p, reportDataLoss)
}
logInfo(s"Initial offsets: $offsets")
offsets
}

override def getStartOffset(): Offset = offset
override def fullSchema(): StructType = KafkaOffsetReader.kafkaSchema

override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = {
new KafkaContinuousScanConfigBuilder(fullSchema(), start, offsetReader, reportDataLoss)
}

override def deserializeOffset(json: String): Offset = {
KafkaSourceOffset(JsonUtils.partitionOffsets(json))
}

override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = {
import scala.collection.JavaConverters._

val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset)

val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet
val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet)
val newPartitionOffsets = offsetReader.fetchEarliestOffsets(newPartitions.toSeq)

val deletedPartitions = oldStartPartitionOffsets.keySet.diff(currentPartitionSet)
if (deletedPartitions.nonEmpty) {
reportDataLoss(s"Some partitions were deleted: $deletedPartitions")
}

val startOffsets = newPartitionOffsets ++
oldStartPartitionOffsets.filterKeys(!deletedPartitions.contains(_))
knownPartitions = startOffsets.keySet

override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
val startOffsets = config.asInstanceOf[KafkaContinuousScanConfig].startOffsets
startOffsets.toSeq.map {
case (topicPartition, start) =>
KafkaContinuousInputPartition(
topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss
): InputPartition[InternalRow]
}.asJava
topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss)
}.toArray
}

override def createContinuousReaderFactory(
config: ScanConfig): ContinuousPartitionReaderFactory = {
KafkaContinuousReaderFactory
}

/** Stop this source and free any resources it has allocated. */
Expand All @@ -127,8 +105,9 @@ class KafkaContinuousReader(
KafkaSourceOffset(mergedMap)
}

override def needsReconfiguration(): Boolean = {
knownPartitions != null && offsetReader.fetchLatestOffsets().keySet != knownPartitions
override def needsReconfiguration(config: ScanConfig): Boolean = {
val knownPartitions = config.asInstanceOf[KafkaContinuousScanConfig].knownPartitions
offsetReader.fetchLatestOffsets().keySet != knownPartitions
}

override def toString(): String = s"KafkaSource[$offsetReader]"
Expand Down Expand Up @@ -162,23 +141,51 @@ case class KafkaContinuousInputPartition(
startOffset: Long,
kafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
failOnDataLoss: Boolean) extends ContinuousInputPartition[InternalRow] {

override def createContinuousReader(
offset: PartitionOffset): InputPartitionReader[InternalRow] = {
val kafkaOffset = offset.asInstanceOf[KafkaSourcePartitionOffset]
require(kafkaOffset.topicPartition == topicPartition,
s"Expected topicPartition: $topicPartition, but got: ${kafkaOffset.topicPartition}")
new KafkaContinuousInputPartitionReader(
topicPartition, kafkaOffset.partitionOffset, kafkaParams, pollTimeoutMs, failOnDataLoss)
failOnDataLoss: Boolean) extends InputPartition

object KafkaContinuousReaderFactory extends ContinuousPartitionReaderFactory {
override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = {
val p = partition.asInstanceOf[KafkaContinuousInputPartition]
new KafkaContinuousPartitionReader(
p.topicPartition, p.startOffset, p.kafkaParams, p.pollTimeoutMs, p.failOnDataLoss)
}
}

class KafkaContinuousScanConfigBuilder(
schema: StructType,
startOffset: Offset,
offsetReader: KafkaOffsetReader,
reportDataLoss: String => Unit)
extends ScanConfigBuilder {

override def build(): ScanConfig = {
val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(startOffset)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet
val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet)
val newPartitionOffsets = offsetReader.fetchEarliestOffsets(newPartitions.toSeq)

override def createPartitionReader(): KafkaContinuousInputPartitionReader = {
new KafkaContinuousInputPartitionReader(
topicPartition, startOffset, kafkaParams, pollTimeoutMs, failOnDataLoss)
val deletedPartitions = oldStartPartitionOffsets.keySet.diff(currentPartitionSet)
if (deletedPartitions.nonEmpty) {
reportDataLoss(s"Some partitions were deleted: $deletedPartitions")
}

val startOffsets = newPartitionOffsets ++
oldStartPartitionOffsets.filterKeys(!deletedPartitions.contains(_))
KafkaContinuousScanConfig(schema, startOffsets)
}
}

case class KafkaContinuousScanConfig(
readSchema: StructType,
startOffsets: Map[TopicPartition, Long])
extends ScanConfig {

// Created when building the scan config builder. If this diverges from the partitions at the
// latest offsets, we need to reconfigure the kafka read support.
def knownPartitions: Set[TopicPartition] = startOffsets.keySet
}

/**
* A per-task data reader for continuous Kafka processing.
*
Expand All @@ -189,12 +196,12 @@ case class KafkaContinuousInputPartition(
* @param failOnDataLoss Flag indicating whether data reader should fail if some offsets
* are skipped.
*/
class KafkaContinuousInputPartitionReader(
class KafkaContinuousPartitionReader(
topicPartition: TopicPartition,
startOffset: Long,
kafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
failOnDataLoss: Boolean) extends ContinuousInputPartitionReader[InternalRow] {
failOnDataLoss: Boolean) extends ContinuousPartitionReader[InternalRow] {
private val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams, useCache = false)
private val converter = new KafkaRecordToUnsafeRowConverter

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ import java.{util => ju}
import java.io._
import java.nio.charset.StandardCharsets

import scala.collection.JavaConverters._

import org.apache.commons.io.IOUtils
import org.apache.kafka.common.TopicPartition

Expand All @@ -32,16 +30,17 @@ import org.apache.spark.scheduler.ExecutorCacheTaskLocation
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset}
import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset, SimpleStreamingScanConfig, SimpleStreamingScanConfigBuilder}
import org.apache.spark.sql.execution.streaming.sources.RateControlMicroBatchReadSupport
import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
import org.apache.spark.sql.sources.v2.{CustomMetrics, DataSourceOptions}
import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader}
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset, SupportsCustomReaderMetrics}
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset, SupportsCustomReaderMetrics}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.UninterruptibleThread

/**
* A [[MicroBatchReader]] that reads data from Kafka.
* A [[MicroBatchReadSupport]] that reads data from Kafka.
*
* The [[KafkaSourceOffset]] is the custom [[Offset]] defined for this source that contains
* a map of TopicPartition -> offset. Note that this offset is 1 + (available offset). For
Expand All @@ -56,17 +55,14 @@ import org.apache.spark.util.UninterruptibleThread
* To avoid this issue, you should make sure stopping the query before stopping the Kafka brokers
* and not use wrong broker addresses.
*/
private[kafka010] class KafkaMicroBatchReader(
private[kafka010] class KafkaMicroBatchReadSupport(
kafkaOffsetReader: KafkaOffsetReader,
executorKafkaParams: ju.Map[String, Object],
options: DataSourceOptions,
metadataPath: String,
startingOffsets: KafkaOffsetRangeLimit,
failOnDataLoss: Boolean)
extends MicroBatchReader with SupportsCustomReaderMetrics with Logging {

private var startPartitionOffsets: PartitionOffsetMap = _
private var endPartitionOffsets: PartitionOffsetMap = _
extends RateControlMicroBatchReadSupport with SupportsCustomReaderMetrics with Logging {

private val pollTimeoutMs = options.getLong(
"kafkaConsumer.pollTimeoutMs",
Expand All @@ -76,34 +72,40 @@ private[kafka010] class KafkaMicroBatchReader(
Option(options.get("maxOffsetsPerTrigger").orElse(null)).map(_.toLong)

private val rangeCalculator = KafkaOffsetRangeCalculator(options)

private var endPartitionOffsets: KafkaSourceOffset = _

/**
* Lazily initialize `initialPartitionOffsets` to make sure that `KafkaConsumer.poll` is only
* called in StreamExecutionThread. Otherwise, interrupting a thread while running
* `KafkaConsumer.poll` may hang forever (KAFKA-1894).
*/
private lazy val initialPartitionOffsets = getOrCreateInitialPartitionOffsets()

override def setOffsetRange(start: ju.Optional[Offset], end: ju.Optional[Offset]): Unit = {
// Make sure initialPartitionOffsets is initialized
initialPartitionOffsets

startPartitionOffsets = Option(start.orElse(null))
.map(_.asInstanceOf[KafkaSourceOffset].partitionToOffsets)
.getOrElse(initialPartitionOffsets)

endPartitionOffsets = Option(end.orElse(null))
.map(_.asInstanceOf[KafkaSourceOffset].partitionToOffsets)
.getOrElse {
val latestPartitionOffsets = kafkaOffsetReader.fetchLatestOffsets()
maxOffsetsPerTrigger.map { maxOffsets =>
rateLimit(maxOffsets, startPartitionOffsets, latestPartitionOffsets)
}.getOrElse {
latestPartitionOffsets
}
}
override def initialOffset(): Offset = {
KafkaSourceOffset(getOrCreateInitialPartitionOffsets())
}

override def latestOffset(start: Offset): Offset = {
val startPartitionOffsets = start.asInstanceOf[KafkaSourceOffset].partitionToOffsets
val latestPartitionOffsets = kafkaOffsetReader.fetchLatestOffsets()
endPartitionOffsets = KafkaSourceOffset(maxOffsetsPerTrigger.map { maxOffsets =>
rateLimit(maxOffsets, startPartitionOffsets, latestPartitionOffsets)
}.getOrElse {
latestPartitionOffsets
})
endPartitionOffsets
}

override def fullSchema(): StructType = KafkaOffsetReader.kafkaSchema

override def newScanConfigBuilder(start: Offset, end: Offset): ScanConfigBuilder = {
new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end))
}

override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = {
override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
val sc = config.asInstanceOf[SimpleStreamingScanConfig]
val startPartitionOffsets = sc.start.asInstanceOf[KafkaSourceOffset].partitionToOffsets
val endPartitionOffsets = sc.end.get.asInstanceOf[KafkaSourceOffset].partitionToOffsets

// Find the new partitions, and get their earliest offsets
val newPartitions = endPartitionOffsets.keySet.diff(startPartitionOffsets.keySet)
val newPartitionInitialOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq)
Expand Down Expand Up @@ -145,30 +147,26 @@ private[kafka010] class KafkaMicroBatchReader(

// Generate factories based on the offset ranges
offsetRanges.map { range =>
new KafkaMicroBatchInputPartition(
range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer
): InputPartition[InternalRow]
}.asJava
}

override def getStartOffset: Offset = {
KafkaSourceOffset(startPartitionOffsets)
KafkaMicroBatchInputPartition(
range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer)
}.toArray
}

override def getEndOffset: Offset = {
KafkaSourceOffset(endPartitionOffsets)
override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = {
KafkaMicroBatchReaderFactory
}

override def getCustomMetrics: CustomMetrics = {
KafkaCustomMetrics(kafkaOffsetReader.fetchLatestOffsets(), endPartitionOffsets)
// TODO: figure out the life cycle of custom metrics, and make this method take `ScanConfig` as
// a parameter.
override def getCustomMetrics(): CustomMetrics = {
KafkaCustomMetrics(
kafkaOffsetReader.fetchLatestOffsets(), endPartitionOffsets.partitionToOffsets)
}

override def deserializeOffset(json: String): Offset = {
KafkaSourceOffset(JsonUtils.partitionOffsets(json))
}

override def readSchema(): StructType = KafkaOffsetReader.kafkaSchema

override def commit(end: Offset): Unit = {}

override def stop(): Unit = {
Expand Down Expand Up @@ -311,22 +309,23 @@ private[kafka010] case class KafkaMicroBatchInputPartition(
executorKafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
failOnDataLoss: Boolean,
reuseKafkaConsumer: Boolean) extends InputPartition[InternalRow] {
reuseKafkaConsumer: Boolean) extends InputPartition

override def preferredLocations(): Array[String] = offsetRange.preferredLoc.toArray

override def createPartitionReader(): InputPartitionReader[InternalRow] =
new KafkaMicroBatchInputPartitionReader(offsetRange, executorKafkaParams, pollTimeoutMs,
failOnDataLoss, reuseKafkaConsumer)
private[kafka010] object KafkaMicroBatchReaderFactory extends PartitionReaderFactory {
override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
val p = partition.asInstanceOf[KafkaMicroBatchInputPartition]
KafkaMicroBatchPartitionReader(p.offsetRange, p.executorKafkaParams, p.pollTimeoutMs,
p.failOnDataLoss, p.reuseKafkaConsumer)
}
}

/** A [[InputPartitionReader]] for reading Kafka data in a micro-batch streaming query. */
private[kafka010] case class KafkaMicroBatchInputPartitionReader(
/** A [[PartitionReader]] for reading Kafka data in a micro-batch streaming query. */
private[kafka010] case class KafkaMicroBatchPartitionReader(
offsetRange: KafkaOffsetRange,
executorKafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
failOnDataLoss: Boolean,
reuseKafkaConsumer: Boolean) extends InputPartitionReader[InternalRow] with Logging {
reuseKafkaConsumer: Boolean) extends PartitionReader[InternalRow] with Logging {

private val consumer = KafkaDataConsumer.acquire(
offsetRange.topicPartition, executorKafkaParams, reuseKafkaConsumer)
Expand Down
Loading