
Commit 486f99e

[SPARK-23541][SS] Allow Kafka source to read data with greater parallelism than the number of topic-partitions
## What changes were proposed in this pull request?

Currently, when the Kafka source reads from Kafka, it generates as many tasks as the number of partitions in the topic(s) to be read. In some cases, it may be beneficial to read the data with greater parallelism, that is, with a larger number of partitions/tasks. That means the offset ranges must be divided into smaller ranges such that the number of records per partition ~= total records in batch / desired partitions. This would also balance out any data skew between topic-partitions. In this patch, I have added a new option called `minPartitions`, which allows the user to specify the desired level of parallelism.

## How was this patch tested?

New tests in KafkaMicroBatchV2SourceSuite.

Author: Tathagata Das <[email protected]>

Closes #20698 from tdas/SPARK-23541.
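The splitting arithmetic described above amounts to picking an ideal range size of total records / desired partitions and cutting each topic-partition's offset range into roughly that many pieces. Below is a minimal, self-contained sketch of that idea with illustrative names only; it is not the committed code (the real logic lives in the new KafkaOffsetRangeCalculator shown further down).

// Minimal sketch of the splitting idea (illustrative; not the committed implementation).
object SplitSketch {
  final case class Range(topic: String, partition: Int, from: Long, until: Long) {
    def size: Long = until - from
  }

  // Cut each range so every piece holds roughly totalRecords / desiredPartitions records.
  def split(ranges: Seq[Range], desiredPartitions: Int): Seq[Range] = {
    val totalSize = ranges.map(_.size).sum
    if (totalSize == 0 || desiredPartitions <= ranges.size) {
      ranges
    } else {
      val idealSize = totalSize.toDouble / desiredPartitions
      ranges.flatMap { r =>
        val splits = math.max(1, math.round(r.size / idealSize).toInt)
        (0 until splits).map { i =>
          r.copy(
            from = r.from + (r.size * i.toDouble / splits).toLong,
            until = r.from + (r.size * (i + 1).toDouble / splits).toLong)
        }
      }
    }
  }

  def main(args: Array[String]): Unit = {
    // A skewed topic: partition 0 holds 9000 records, partition 1 holds 1000.
    val in = Seq(Range("t", 0, 0L, 9000L), Range("t", 1, 0L, 1000L))
    // Asking for ~10 tasks yields ~9 splits for partition 0 and 1 for partition 1.
    split(in, 10).foreach(println)
  }
}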
1 parent dea381d commit 486f99e


6 files changed: +388 -60 lines changed


external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala

Lines changed: 51 additions & 58 deletions
@@ -24,7 +24,6 @@ import java.nio.charset.StandardCharsets
 import scala.collection.JavaConverters._
 
 import org.apache.commons.io.IOUtils
-import org.apache.kafka.common.TopicPartition
 
 import org.apache.spark.SparkEnv
 import org.apache.spark.internal.Logging
@@ -64,8 +63,6 @@ private[kafka010] class KafkaMicroBatchReader(
     failOnDataLoss: Boolean)
   extends MicroBatchReader with SupportsScanUnsafeRow with Logging {
 
-  type PartitionOffsetMap = Map[TopicPartition, Long]
-
   private var startPartitionOffsets: PartitionOffsetMap = _
   private var endPartitionOffsets: PartitionOffsetMap = _
 
@@ -76,6 +73,7 @@ private[kafka010] class KafkaMicroBatchReader(
   private val maxOffsetsPerTrigger =
     Option(options.get("maxOffsetsPerTrigger").orElse(null)).map(_.toLong)
 
+  private val rangeCalculator = KafkaOffsetRangeCalculator(options)
   /**
    * Lazily initialize `initialPartitionOffsets` to make sure that `KafkaConsumer.poll` is only
    * called in StreamExecutionThread. Otherwise, interrupting a thread while running
@@ -106,15 +104,15 @@ private[kafka010] class KafkaMicroBatchReader(
   override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = {
     // Find the new partitions, and get their earliest offsets
     val newPartitions = endPartitionOffsets.keySet.diff(startPartitionOffsets.keySet)
-    val newPartitionOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq)
-    if (newPartitionOffsets.keySet != newPartitions) {
+    val newPartitionInitialOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq)
+    if (newPartitionInitialOffsets.keySet != newPartitions) {
       // We cannot get from offsets for some partitions. It means they got deleted.
-      val deletedPartitions = newPartitions.diff(newPartitionOffsets.keySet)
+      val deletedPartitions = newPartitions.diff(newPartitionInitialOffsets.keySet)
       reportDataLoss(
         s"Cannot find earliest offsets of ${deletedPartitions}. Some data may have been missed")
     }
-    logInfo(s"Partitions added: $newPartitionOffsets")
-    newPartitionOffsets.filter(_._2 != 0).foreach { case (p, o) =>
+    logInfo(s"Partitions added: $newPartitionInitialOffsets")
+    newPartitionInitialOffsets.filter(_._2 != 0).foreach { case (p, o) =>
       reportDataLoss(
         s"Added partition $p starts from $o instead of 0. Some data may have been missed")
     }
@@ -125,46 +123,28 @@ private[kafka010] class KafkaMicroBatchReader(
       reportDataLoss(s"$deletedPartitions are gone. Some data may have been missed")
     }
 
-    // Use the until partitions to calculate offset ranges to ignore partitions that have
+    // Use the end partitions to calculate offset ranges to ignore partitions that have
     // been deleted
     val topicPartitions = endPartitionOffsets.keySet.filter { tp =>
       // Ignore partitions that we don't know the from offsets.
-      newPartitionOffsets.contains(tp) || startPartitionOffsets.contains(tp)
+      newPartitionInitialOffsets.contains(tp) || startPartitionOffsets.contains(tp)
     }.toSeq
     logDebug("TopicPartitions: " + topicPartitions.mkString(", "))
 
-    val sortedExecutors = getSortedExecutorList()
-    val numExecutors = sortedExecutors.length
-    logDebug("Sorted executors: " + sortedExecutors.mkString(", "))
-
     // Calculate offset ranges
-    val factories = topicPartitions.flatMap { tp =>
-      val fromOffset = startPartitionOffsets.get(tp).getOrElse {
-        newPartitionOffsets.getOrElse(
-          tp, {
-            // This should not happen since newPartitionOffsets contains all partitions not in
-            // fromPartitionOffsets
-            throw new IllegalStateException(s"$tp doesn't have a from offset")
-          })
-      }
-      val untilOffset = endPartitionOffsets(tp)
-
-      if (untilOffset >= fromOffset) {
-        // This allows cached KafkaConsumers in the executors to be re-used to read the same
-        // partition in every batch.
-        val preferredLoc = if (numExecutors > 0) {
-          Some(sortedExecutors(Math.floorMod(tp.hashCode, numExecutors)))
-        } else None
-        val range = KafkaOffsetRange(tp, fromOffset, untilOffset)
-        Some(
-          new KafkaMicroBatchDataReaderFactory(
-            range, preferredLoc, executorKafkaParams, pollTimeoutMs, failOnDataLoss))
-      } else {
-        reportDataLoss(
-          s"Partition $tp's offset was changed from " +
-            s"$fromOffset to $untilOffset, some data may have been missed")
-        None
-      }
+    val offsetRanges = rangeCalculator.getRanges(
+      fromOffsets = startPartitionOffsets ++ newPartitionInitialOffsets,
+      untilOffsets = endPartitionOffsets,
+      executorLocations = getSortedExecutorList())
+
+    // Reuse Kafka consumers only when all the offset ranges have distinct TopicPartitions,
+    // that is, concurrent tasks will not read the same TopicPartitions.
+    val reuseKafkaConsumer = offsetRanges.map(_.topicPartition).toSet.size == offsetRanges.size
+
+    // Generate factories based on the offset ranges
+    val factories = offsetRanges.map { range =>
+      new KafkaMicroBatchDataReaderFactory(
+        range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer)
     }
     factories.map(_.asInstanceOf[DataReaderFactory[UnsafeRow]]).asJava
   }
@@ -320,28 +300,39 @@ private[kafka010] class KafkaMicroBatchReader(
 }
 
 /** A [[DataReaderFactory]] for reading Kafka data in a micro-batch streaming query. */
-private[kafka010] class KafkaMicroBatchDataReaderFactory(
-    range: KafkaOffsetRange,
-    preferredLoc: Option[String],
+private[kafka010] case class KafkaMicroBatchDataReaderFactory(
+    offsetRange: KafkaOffsetRange,
     executorKafkaParams: ju.Map[String, Object],
     pollTimeoutMs: Long,
-    failOnDataLoss: Boolean) extends DataReaderFactory[UnsafeRow] {
+    failOnDataLoss: Boolean,
+    reuseKafkaConsumer: Boolean) extends DataReaderFactory[UnsafeRow] {
 
-  override def preferredLocations(): Array[String] = preferredLoc.toArray
+  override def preferredLocations(): Array[String] = offsetRange.preferredLoc.toArray
 
   override def createDataReader(): DataReader[UnsafeRow] = new KafkaMicroBatchDataReader(
-    range, executorKafkaParams, pollTimeoutMs, failOnDataLoss)
+    offsetRange, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer)
 }
 
 /** A [[DataReader]] for reading Kafka data in a micro-batch streaming query. */
-private[kafka010] class KafkaMicroBatchDataReader(
+private[kafka010] case class KafkaMicroBatchDataReader(
     offsetRange: KafkaOffsetRange,
     executorKafkaParams: ju.Map[String, Object],
     pollTimeoutMs: Long,
-    failOnDataLoss: Boolean) extends DataReader[UnsafeRow] with Logging {
+    failOnDataLoss: Boolean,
+    reuseKafkaConsumer: Boolean) extends DataReader[UnsafeRow] with Logging {
+
+  private val consumer = {
+    if (!reuseKafkaConsumer) {
+      // If we can't reuse CachedKafkaConsumers, creating a new CachedKafkaConsumer. We
+      // uses `assign` here, hence we don't need to worry about the "group.id" conflicts.
+      CachedKafkaConsumer.createUncached(
+        offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams)
+    } else {
+      CachedKafkaConsumer.getOrCreate(
+        offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams)
+    }
+  }
 
-  private val consumer = CachedKafkaConsumer.getOrCreate(
-    offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams)
   private val rangeToRead = resolveRange(offsetRange)
   private val converter = new KafkaRecordToUnsafeRowConverter
 
@@ -369,9 +360,14 @@ private[kafka010] class KafkaMicroBatchDataReader(
   }
 
   override def close(): Unit = {
-    // Indicate that we're no longer using this consumer
-    CachedKafkaConsumer.releaseKafkaConsumer(
-      offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams)
+    if (!reuseKafkaConsumer) {
+      // Don't forget to close non-reuse KafkaConsumers. You may take down your cluster!
+      consumer.close()
+    } else {
+      // Indicate that we're no longer using this consumer
+      CachedKafkaConsumer.releaseKafkaConsumer(
+        offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams)
+    }
   }
 
   private def resolveRange(range: KafkaOffsetRange): KafkaOffsetRange = {
@@ -392,12 +388,9 @@ private[kafka010] class KafkaMicroBatchDataReader(
       } else {
         range.untilOffset
       }
-      KafkaOffsetRange(range.topicPartition, fromOffset, untilOffset)
+      KafkaOffsetRange(range.topicPartition, fromOffset, untilOffset, None)
     } else {
       range
     }
   }
 }
-
-private[kafka010] case class KafkaOffsetRange(
-  topicPartition: TopicPartition, fromOffset: Long, untilOffset: Long)
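
To see why the new reuse flag matters: splitting can place the same TopicPartition into several offset ranges, so several concurrent tasks may read that partition, and the cached per-TopicPartition consumer can no longer be shared. A small illustration with hypothetical values, assuming it sits in the org.apache.spark.sql.kafka010 package so the package-private KafkaOffsetRange is visible:

import org.apache.kafka.common.TopicPartition

val tp0 = new TopicPartition("t", 0)
// Partition t-0 was split into two ranges, so two concurrent tasks would read it.
val ranges = Seq(
  KafkaOffsetRange(tp0, 0L, 100L, None),
  KafkaOffsetRange(tp0, 100L, 200L, None))
// Same check as in createUnsafeRowReaderFactories: false here, so uncached consumers are used.
val reuseKafkaConsumer = ranges.map(_.topicPartition).toSet.size == ranges.size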
Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import org.apache.kafka.common.TopicPartition
+
+import org.apache.spark.sql.sources.v2.DataSourceOptions
+
+
+/**
+ * Class to calculate offset ranges to process based on the the from and until offsets, and
+ * the configured `minPartitions`.
+ */
+private[kafka010] class KafkaOffsetRangeCalculator(val minPartitions: Option[Int]) {
+  require(minPartitions.isEmpty || minPartitions.get > 0)
+
+  import KafkaOffsetRangeCalculator._
+  /**
+   * Calculate the offset ranges that we are going to process this batch. If `minPartitions`
+   * is not set or is set less than or equal the number of `topicPartitions` that we're going to
+   * consume, then we fall back to a 1-1 mapping of Spark tasks to Kafka partitions. If
+   * `numPartitions` is set higher than the number of our `topicPartitions`, then we will split up
+   * the read tasks of the skewed partitions to multiple Spark tasks.
+   * The number of Spark tasks will be *approximately* `numPartitions`. It can be less or more
+   * depending on rounding errors or Kafka partitions that didn't receive any new data.
+   */
+  def getRanges(
+      fromOffsets: PartitionOffsetMap,
+      untilOffsets: PartitionOffsetMap,
+      executorLocations: Seq[String] = Seq.empty): Seq[KafkaOffsetRange] = {
+    val partitionsToRead = untilOffsets.keySet.intersect(fromOffsets.keySet)
+
+    val offsetRanges = partitionsToRead.toSeq.map { tp =>
+      KafkaOffsetRange(tp, fromOffsets(tp), untilOffsets(tp), preferredLoc = None)
+    }.filter(_.size > 0)
+
+    // If minPartitions not set or there are enough partitions to satisfy minPartitions
+    if (minPartitions.isEmpty || offsetRanges.size > minPartitions.get) {
+      // Assign preferred executor locations to each range such that the same topic-partition is
+      // preferentially read from the same executor and the KafkaConsumer can be reused.
+      offsetRanges.map { range =>
+        range.copy(preferredLoc = getLocation(range.topicPartition, executorLocations))
+      }
+    } else {
+
+      // Splits offset ranges with relatively large amount of data to smaller ones.
+      val totalSize = offsetRanges.map(_.size).sum
+      val idealRangeSize = totalSize.toDouble / minPartitions.get
+
+      offsetRanges.flatMap { range =>
+        // Split the current range into subranges as close to the ideal range size
+        val numSplitsInRange = math.round(range.size.toDouble / idealRangeSize).toInt
+
+        (0 until numSplitsInRange).map { i =>
+          val splitStart = range.fromOffset + range.size * (i.toDouble / numSplitsInRange)
+          val splitEnd = range.fromOffset + range.size * ((i.toDouble + 1) / numSplitsInRange)
+          KafkaOffsetRange(
+            range.topicPartition, splitStart.toLong, splitEnd.toLong, preferredLoc = None)
+        }
+      }
+    }
+  }
+
+  private def getLocation(tp: TopicPartition, executorLocations: Seq[String]): Option[String] = {
+    def floorMod(a: Long, b: Int): Int = ((a % b).toInt + b) % b
+
+    val numExecutors = executorLocations.length
+    if (numExecutors > 0) {
+      // This allows cached KafkaConsumers in the executors to be re-used to read the same
+      // partition in every batch.
+      Some(executorLocations(floorMod(tp.hashCode, numExecutors)))
+    } else None
+  }
+}

+private[kafka010] object KafkaOffsetRangeCalculator {
+
+  def apply(options: DataSourceOptions): KafkaOffsetRangeCalculator = {
+    val optionalValue = Option(options.get("minPartitions").orElse(null)).map(_.toInt)
+    new KafkaOffsetRangeCalculator(optionalValue)
+  }
+}
+
+private[kafka010] case class KafkaOffsetRange(
+    topicPartition: TopicPartition,
+    fromOffset: Long,
+    untilOffset: Long,
+    preferredLoc: Option[String]) {
+  lazy val size: Long = untilOffset - fromOffset
+}
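
A rough usage illustration of the new calculator with made-up offsets; since the class is package-private, a snippet like this would have to live under org.apache.spark.sql.kafka010 (for example in a test):

import org.apache.kafka.common.TopicPartition

val calc = new KafkaOffsetRangeCalculator(minPartitions = Some(4))
val tp0 = new TopicPartition("t", 0)
val tp1 = new TopicPartition("t", 1)
val ranges = calc.getRanges(
  fromOffsets = Map(tp0 -> 0L, tp1 -> 0L),
  untilOffsets = Map(tp0 -> 300L, tp1 -> 100L))
// Total size is 400 and the ideal range size is 100, so tp0 (300 records) is cut into
// 3 ranges and tp1 stays whole, giving approximately minPartitions = 4 ranges overall.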

external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala

Lines changed: 7 additions & 0 deletions
@@ -348,6 +348,12 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
       throw new IllegalArgumentException("Unknown option")
     }
 
+    // Validate minPartitions value if present
+    if (caseInsensitiveParams.contains(MIN_PARTITIONS_OPTION_KEY)) {
+      val p = caseInsensitiveParams(MIN_PARTITIONS_OPTION_KEY).toInt
+      if (p <= 0) throw new IllegalArgumentException("minPartitions must be positive")
+    }
+
     // Validate user-specified Kafka options
 
     if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.GROUP_ID_CONFIG}")) {
@@ -455,6 +461,7 @@ private[kafka010] object KafkaSourceProvider extends Logging {
   private[kafka010] val STARTING_OFFSETS_OPTION_KEY = "startingoffsets"
   private[kafka010] val ENDING_OFFSETS_OPTION_KEY = "endingoffsets"
   private val FAIL_ON_DATA_LOSS_OPTION_KEY = "failondataloss"
+  private val MIN_PARTITIONS_OPTION_KEY = "minpartitions"
 
   val TOPIC_OPTION_KEY = "topic"
 
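For reference, this is how a streaming query would opt in to the new option; the bootstrap servers and topic below are placeholders and `spark` is assumed to be an existing SparkSession:

val df = spark.readStream
  .format("kafka")
  .option("kafka.bootstrap.servers", "host1:9092,host2:9092")
  .option("subscribe", "topic1")
  // Ask for roughly 20 Spark tasks per batch even if topic1 has fewer Kafka partitions.
  .option("minPartitions", "20")
  .load()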
Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import org.apache.kafka.common.TopicPartition
+
+package object kafka010 { // scalastyle:ignore
+  // ^^ scalastyle:ignore is for ignoring warnings about digits in package name
+  type PartitionOffsetMap = Map[TopicPartition, Long]
+}
