
Commit c2ce1be

brkyvz authored and zsxwing committed
[SPARK-18475] Be able to increase parallelism in StructuredStreaming Kafka source
## What changes were proposed in this pull request?

This PR adds the configuration `numPartitions` to the Structured Streaming Kafka source. Setting it higher than the number of `TopicPartitions` being consumed allows Spark to run multiple tasks reading from the same `TopicPartition`, which lets users handle skewed partitions. The number of `TopicPartitions` can change from batch to batch, e.g. when topics are created or deleted, but in ETL use cases the set of TopicPartitions is generally static, and there this configuration has been very useful. If the `TopicPartitions` are dynamic, the parallelism is always `max(topicPartitions.length, numPartitions)`.

## How was this patch tested?

Unit tests. I used this on production data and it certainly helped in handling peak loads and skewed partitions.

Author: Burak Yavuz <[email protected]>

Closes apache#166 from brkyvz/kafka-par-split.
1 parent f8bf2b0 commit c2ce1be
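
For context, a minimal usage sketch (not part of this commit): it assumes an active SparkSession named `spark`, placeholder broker addresses and topic name, and uses the option key `minNumParitions` exactly as it is spelled in the KafkaSource change below (the commit message refers to the setting as `numPartitions`).

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical session, brokers, and topic; option key spelled as in this patch.
val spark = SparkSession.builder().appName("kafka-par-split-example").getOrCreate()

val df = spark.readStream
  .format("kafka")
  .option("kafka.bootstrap.servers", "host1:9092,host2:9092")
  .option("subscribe", "events")
  // Request roughly 40 read tasks even if the topic has fewer TopicPartitions;
  // per the commit message, parallelism becomes max(topicPartitions.length, 40).
  .option("minNumParitions", "40")
  .load()
```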

File tree

5 files changed: +307 -44 lines changed

external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala

Lines changed: 8 additions & 3 deletions
@@ -271,7 +271,7 @@ private[kafka010] case class CachedKafkaConsumer private(
     }
   }
 
-  private def close(): Unit = consumer.close()
+  private[kafka010] def close(): Unit = consumer.close()
 
   private def seek(offset: Long): Unit = {
     logDebug(s"Seeking to $groupId $topicPartition $offset")
@@ -334,22 +334,27 @@ private[kafka010] object CachedKafkaConsumer extends Logging {
   def getOrCreate(
       topic: String,
       partition: Int,
-      kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer = synchronized {
+      kafkaParams: ju.Map[String, Object],
+      reuse: Boolean): CachedKafkaConsumer = synchronized {
     val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String]
     val topicPartition = new TopicPartition(topic, partition)
     val key = CacheKey(groupId, topicPartition)
 
     // If this is reattempt at running the task, then invalidate cache and start with
     // a new consumer
-    if (TaskContext.get != null && TaskContext.get.attemptNumber > 1) {
+    if (!reuse || TaskContext.get != null && TaskContext.get.attemptNumber > 1) {
+      logDebug("Creating new CachedKafkaConsumer")
       val removedConsumer = cache.remove(key)
       if (removedConsumer != null) {
         removedConsumer.close()
       }
       new CachedKafkaConsumer(topicPartition, kafkaParams)
     } else {
       if (!cache.containsKey(key)) {
+        logDebug("Creating new CachedKafkaConsumer")
         cache.put(key, new CachedKafkaConsumer(topicPartition, kafkaParams))
+      } else {
+        logDebug(s"CachedKafkaConsumer exists for key: $key. Reusing.")
       }
       cache.get(key)
     }
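
The getOrCreate change above reduces consumer lookup to one decision: reuse a cached consumer keyed by (groupId, TopicPartition), or invalidate and create a fresh one when reuse is disallowed or the task is a retry. A condensed, self-contained sketch of that decision (simplified names, not the actual Spark classes):

```scala
import java.util.concurrent.ConcurrentHashMap

object ConsumerCacheSketch {
  final case class Key(groupId: String, topic: String, partition: Int)
  final class FakeConsumer { def close(): Unit = () } // stand-in for a Kafka consumer

  private val cache = new ConcurrentHashMap[Key, FakeConsumer]()

  def getOrCreate(key: Key, reuse: Boolean, isRetry: Boolean): FakeConsumer = synchronized {
    if (!reuse || isRetry) {
      // Drop any cached consumer and hand back a brand-new, uncached one.
      Option(cache.remove(key)).foreach(_.close())
      new FakeConsumer
    } else {
      // Reuse path: create once, then keep serving the cached instance.
      if (!cache.containsKey(key)) {
        cache.put(key, new FakeConsumer)
      }
      cache.get(key)
    }
  }
}
```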

external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala

Lines changed: 93 additions & 31 deletions
@@ -119,6 +119,13 @@ private[kafka010] class KafkaSource(
     groupId
   }
 
+  /**
+   * Number of partitions to read from Kafka. If this value is greater than the number of Kafka
+   * topicPartitions, we will not use the CachedConsumer.
+   */
+  private val minNumParitions =
+    sourceOptions.getOrElse("minNumParitions", "0").toInt
+
   /**
    * A KafkaConsumer used in the driver to query the latest Kafka offsets. This only queries the
    * offsets and never commits them.
@@ -279,39 +286,15 @@ private[kafka010] class KafkaSource(
     }.toSeq
     logDebug("TopicPartitions: " + topicPartitions.mkString(", "))
 
-    val sortedExecutors = getSortedExecutorList(sc)
-    val numExecutors = sortedExecutors.length
-    logDebug("Sorted executors: " + sortedExecutors.mkString(", "))
-
-    // Calculate offset ranges
-    val offsetRanges = topicPartitions.map { tp =>
-      val fromOffset = fromPartitionOffsets.get(tp).getOrElse {
-        newPartitionOffsets.getOrElse(tp, {
-          // This should not happen since newPartitionOffsets contains all partitions not in
-          // fromPartitionOffsets
-          throw new IllegalStateException(s"$tp doesn't have a from offset")
-        })
-      }
-      val untilOffset = untilPartitionOffsets(tp)
-      val preferredLoc = if (numExecutors > 0) {
-        // This allows cached KafkaConsumers in the executors to be re-used to read the same
-        // partition in every batch.
-        Some(sortedExecutors(floorMod(tp.hashCode, numExecutors)))
-      } else None
-      KafkaSourceRDDOffsetRange(tp, fromOffset, untilOffset, preferredLoc)
-    }.filter { range =>
-      if (range.untilOffset < range.fromOffset) {
-        reportDataLoss(s"Partition ${range.topicPartition}'s offset was changed from " +
-          s"${range.fromOffset} to ${range.untilOffset}, some data may have been missed")
-        false
-      } else {
-        true
-      }
-    }.toArray
+    val offsetRanges = getOffsetRanges(topicPartitions, fromPartitionOffsets, newPartitionOffsets,
+      untilPartitionOffsets)
+    // We can't re-use CachedConsumers if we are using multiple partitions to read from a
+    // single Kafka TopicPartition
+    val reuseCachedConsumers = canReuseCachedConsumers(topicPartitions.length)
 
     // Create an RDD that reads from Kafka and get the (key, value) pair as byte arrays.
-    val rdd = new KafkaSourceRDD(
-      sc, executorKafkaParams, offsetRanges, pollTimeoutMs, failOnDataLoss).map { cr =>
+    val rdd = new KafkaSourceRDD(sc, executorKafkaParams, offsetRanges, pollTimeoutMs,
+      failOnDataLoss, reuseCachedConsumers).map { cr =>
       InternalRow(
         cr.key,
         cr.value,
@@ -393,6 +376,85 @@ private[kafka010] class KafkaSource(
     partitionOffsets
   }
 
+  /**
+   * If we divide topic partitions into multiple read tasks, we can't re-use CachedConsumers on
+   * the executors.
+   */
+  private def canReuseCachedConsumers(numTopicPartitions: Int): Boolean = {
+    math.max(minNumParitions, numTopicPartitions) == numTopicPartitions
+  }
+
+  /**
+   * Calculate the offset ranges that we are going to process this batch. If `numPartitions`
+   * is not set or is set less than or equal the number of `topicPartitions` that we're going to
+   * consume, then we fall back to a 1-1 mapping of Spark tasks to Kafka partitions. If
+   * `numPartitions` is set higher than the number of our `topicPartitions`, then we will split up
+   * the read tasks of the skewed partitions to multiple Spark tasks.
+   * The number of Spark tasks will be *approximately* `numPartitions`. It can be less or more
+   * depending on rounding errors or Kafka partitions that didn't receive any new data.
+   */
+  private def getOffsetRanges(
+      topicPartitions: Seq[TopicPartition],
+      fromPartitionOffsets: Map[TopicPartition, Long],
+      newPartitionOffsets: Map[TopicPartition, Long],
+      untilPartitionOffsets: Map[TopicPartition, Long]): Seq[KafkaSourceRDDOffsetRange] = {
+    val numPartitionsToRead = math.max(minNumParitions, topicPartitions.length)
+
+    val offsets = topicPartitions.flatMap { tp =>
+      val fromOffset = fromPartitionOffsets.get(tp).getOrElse {
+        newPartitionOffsets.getOrElse(tp, {
+          // This should not happen since newPartitionOffsets contains all partitions not in
+          // fromPartitionOffsets
+          throw new IllegalStateException(s"$tp doesn't have a from offset")
+        })
+      }
+      val untilOffset = untilPartitionOffsets(tp)
+      if (untilOffset < fromOffset) {
+        reportDataLoss(s"Partition $tp's offset was changed from " +
+          s"$fromOffset to $untilOffset, some data may have been missed")
+        None
+      } else {
+        Some(KafkaSourceRDDOffsetRange(tp, fromOffset, untilOffset, None))
+      }
+    }
+
+    if (numPartitionsToRead == topicPartitions.length) {
+      val sortedExecutors = getSortedExecutorList(sc)
+      val numExecutors = sortedExecutors.length
+      logDebug("Sorted executors: " + sortedExecutors.mkString(", "))
+
+      // One-to-One mapping
+      offsets.map { case KafkaSourceRDDOffsetRange(tp, fromOffset, untilOffset, _) =>
+        val preferredLoc = if (numExecutors > 0) {
+          // This allows cached KafkaConsumers in the executors to be re-used to read the same
+          // partition in every batch.
+          Some(sortedExecutors(floorMod(tp.hashCode, numExecutors)))
+        } else None
+        KafkaSourceRDDOffsetRange(tp, fromOffset, untilOffset, preferredLoc)
+      }.toList
+    } else {
+      // one-to-many mapping. We can't re-use CachedConsumers in this instance.
+      val totalSize = offsets.map(o => o.untilOffset - o.fromOffset).sum
+      offsets.flatMap { offsetRange =>
+        val tp = offsetRange.topicPartition
+        val size = offsetRange.untilOffset - offsetRange.fromOffset
+        // number of partitions to divvy up this topic partition to
+        val parts = math.max(math.round(size * 1.0 / totalSize * numPartitionsToRead), 1).toInt
+        var remaining = size
+        var startOffset = offsetRange.fromOffset
+        (0 until parts).map { part =>
+          // Fine to do integer division. Last partition will consume all the round off errors
+          val thisPartition = remaining / (parts - part)
+          remaining -= thisPartition
+          val endOffset = startOffset + thisPartition
+          val offsetRange = KafkaSourceRDDOffsetRange(tp, startOffset, endOffset, None)
+          startOffset = endOffset
+          offsetRange
        }
+      }.toList
+    }
+  }
+
   /**
    * Fetch the latest offset of partitions.
    */
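
To make the one-to-many split in getOffsetRanges concrete, here is a self-contained sketch of the same arithmetic with hypothetical numbers, using plain (topic, from, until) tuples instead of KafkaSourceRDDOffsetRange: with a requested parallelism of 6 and two topic partitions holding 1000 and 200 new records, the skewed partition is split into 5 ranges and the small one stays whole.

```scala
object SplitExample extends App {
  // Hypothetical offset ranges: (topicPartition, fromOffset, untilOffset)
  val ranges = Seq(("events-0", 0L, 1000L), ("events-1", 0L, 200L))
  val numPartitionsToRead = 6 // stands in for max(minNumParitions, topicPartitions.length)

  val totalSize = ranges.map { case (_, from, until) => until - from }.sum
  val split = ranges.flatMap { case (tp, from, until) =>
    val size = until - from
    // proportional share of the requested parallelism, but at least one range
    val parts = math.max(math.round(size * 1.0 / totalSize * numPartitionsToRead), 1).toInt
    var remaining = size
    var start = from
    (0 until parts).map { part =>
      val step = remaining / (parts - part) // last range absorbs the rounding error
      remaining -= step
      val r = (tp, start, start + step)
      start += step
      r
    }
  }
  split.foreach(println)
  // events-0 becomes five ranges of 200 offsets each; events-1 stays a single range
}
```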

external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala

Lines changed: 1 addition & 1 deletion
@@ -212,7 +212,7 @@ private[kafka010] class KafkaSourceProvider extends StreamSourceProvider
         |Instead set the source option '$STARTING_OFFSETS_OPTION_KEY' to 'earliest' or 'latest'
         |to specify where to start. Structured Streaming manages which offsets are consumed
         |internally, rather than relying on the kafkaConsumer to do it. This will ensure that no
-        |data is missed when when new topics/partitions are dynamically subscribed. Note that
+        |data is missed when new topics/partitions are dynamically subscribed. Note that
         |'$STARTING_OFFSETS_OPTION_KEY' only applies when a new Streaming query is started, and
         |that resuming will always pick up from where the query left off. See the docs for more
         |details.
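
As a side note, a minimal sketch of the guidance in that error message (assuming an active SparkSession named `spark` and placeholder broker/topic names): control the starting position with the source option rather than the consumer's auto.offset.reset.

```scala
// Hypothetical stream definition; "startingOffsets" only applies when the
// query is first started, resuming always picks up from the checkpoint.
val stream = spark.readStream
  .format("kafka")
  .option("kafka.bootstrap.servers", "host1:9092")
  .option("subscribe", "events")
  .option("startingOffsets", "earliest")
  .load()
```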

external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala

Lines changed: 37 additions & 6 deletions
@@ -21,13 +21,14 @@ import java.{util => ju}
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.kafka.clients.consumer.ConsumerRecord
+import org.apache.kafka.clients.consumer.{ConsumerConfig, ConsumerRecord}
 import org.apache.kafka.common.TopicPartition
 
 import org.apache.spark.{Partition, SparkContext, TaskContext}
 import org.apache.spark.partial.{BoundedDouble, PartialResult}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.CompletionIterator
 import org.apache.spark.util.NextIterator
 
 
@@ -63,7 +64,8 @@ private[kafka010] class KafkaSourceRDD(
     executorKafkaParams: ju.Map[String, Object],
     offsetRanges: Seq[KafkaSourceRDDOffsetRange],
     pollTimeoutMs: Long,
-    failOnDataLoss: Boolean)
+    failOnDataLoss: Boolean,
+    reuseCachedConsumers: Boolean = true)
   extends RDD[ConsumerRecord[Array[Byte], Array[Byte]]](sc, Nil) {
 
   override def persist(newLevel: StorageLevel): this.type = {
@@ -119,6 +121,15 @@ private[kafka010] class KafkaSourceRDD(
     part.offsetRange.preferredLoc.map(Seq(_)).getOrElse(Seq.empty)
   }
 
+  /** Pulled out for mockability in testing. */
+  protected def getOrCreateKafkaConsumer(
+      topic: String,
+      partition: Int,
+      kafkaParams: ju.Map[String, Object],
+      reuseCachedConsumers: Boolean): CachedKafkaConsumer = {
+    CachedKafkaConsumer.getOrCreate(topic, partition, executorKafkaParams, reuseCachedConsumers)
+  }
+
   override def compute(
       thePart: Partition,
       context: TaskContext): Iterator[ConsumerRecord[Array[Byte], Array[Byte]]] = {
@@ -133,9 +144,18 @@ private[kafka010] class KafkaSourceRDD(
         s"skipping ${range.topic} ${range.partition}")
       Iterator.empty
     } else {
-      new NextIterator[ConsumerRecord[Array[Byte], Array[Byte]]]() {
-        val consumer = CachedKafkaConsumer.getOrCreate(
-          range.topic, range.partition, executorKafkaParams)
+      if (!reuseCachedConsumers) {
+        // if we can't reuse CachedKafkaConsumers, let's reset the groupId, because we will have
+        // multiple tasks reading from the same topic partitions
+        val old = executorKafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String]
+        executorKafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, old + "-" + thePart.index.toString)
+      }
+
+      logDebug(s"Creating iterator for $range")
+
+      val underlying = new NextIterator[ConsumerRecord[Array[Byte], Array[Byte]]]() {
+        val consumer = getOrCreateKafkaConsumer(range.topic, range.partition, executorKafkaParams,
+          reuseCachedConsumers)
         var requestOffset = range.fromOffset
 
         override def getNext(): ConsumerRecord[Array[Byte], Array[Byte]] = {
@@ -156,8 +176,19 @@ private[kafka010] class KafkaSourceRDD(
           }
         }
 
-        override protected def close(): Unit = {}
+        override protected def close(): Unit = {
+          if (!reuseCachedConsumers) {
+            consumer.close()
+          }
+        }
+      }
+      if (!reuseCachedConsumers) {
+        // Don't forget to close consumers! You may take down your Kafka cluster.
+        context.addTaskCompletionListener { _ =>
+          underlying.closeIfNeeded()
+        }
       }
+      underlying
     }
   }
 }
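
A minimal sketch, without Spark classes, of the two safeguards added above for the non-reusable case: a unique consumer group per task, and a cleanup hook that always closes the task's private consumer (the real code registers this via the task completion listener).

```scala
object CloseOnCompletionSketch extends App {
  final class FakeConsumer(val groupId: String) {
    def close(): Unit = println(s"closed consumer for group $groupId")
  }

  def runTask(partitionIndex: Int, baseGroupId: String = "spark-kafka-source-42"): Unit = {
    // One consumer group per Spark partition index, so several tasks can read
    // the same TopicPartition without sharing cached consumers.
    val consumer = new FakeConsumer(s"$baseGroupId-$partitionIndex")
    try {
      // ... poll and emit records with `consumer` ...
    } finally {
      consumer.close() // stand-in for the task-completion cleanup in the RDD
    }
  }

  runTask(0) // prints: closed consumer for group spark-kafka-source-42-0
}
```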
