
Commit 823baca

zsxwing authored and tdas committed
[SPARK-20452][SS][KAFKA] Fix a potential ConcurrentModificationException for batch Kafka DataFrame
## What changes were proposed in this pull request?

Cancelling a batch Kafka query while one of its tasks cannot be cancelled, and then rerunning the same DataFrame, may cause a ConcurrentModificationException because two tasks sharing the same group id may end up running at the same time. This PR always creates a new consumer when `reuseKafkaConsumer = false` to avoid the ConcurrentModificationException. It also contains other minor fixes.

## How was this patch tested?

Jenkins.

Author: Shixiong Zhu <[email protected]>

Closes #17752 from zsxwing/kafka-fix.
1 parent 01c999e commit 823baca
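For orientation, the scenario the description refers to can be reproduced with the public batch reader API. The sketch below is illustrative only: the broker address and topic name are placeholders, and the comments restate the behavior described above rather than code from this commit.

import org.apache.spark.sql.SparkSession

object KafkaBatchRerunSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("kafka-batch-sketch").getOrCreate()

    // A batch (non-streaming) Kafka DataFrame; "localhost:9092" and "test-topic"
    // are placeholder values.
    val df = spark.read
      .format("kafka")
      .option("kafka.bootstrap.servers", "localhost:9092")
      .option("subscribe", "test-topic")
      .load()

    // Running the same DataFrame twice launches two independent batch scans.
    // Before this fix they could share a group id; after it, each scan gets a
    // unique "spark-kafka-relation-<uuid>" group id and an uncached consumer.
    df.count()
    df.count()

    spark.stop()
  }
}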

File tree

6 files changed: +119 / -97 lines changed

external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala

Lines changed: 10 additions & 2 deletions
@@ -287,7 +287,7 @@ private[kafka010] case class CachedKafkaConsumer private(
     reportDataLoss0(failOnDataLoss, finalMessage, cause)
   }
 
-  private def close(): Unit = consumer.close()
+  def close(): Unit = consumer.close()
 
   private def seek(offset: Long): Unit = {
     logDebug(s"Seeking to $groupId $topicPartition $offset")
@@ -382,7 +382,7 @@ private[kafka010] object CachedKafkaConsumer extends Logging {
 
     // If this is reattempt at running the task, then invalidate cache and start with
     // a new consumer
-    if (TaskContext.get != null && TaskContext.get.attemptNumber > 1) {
+    if (TaskContext.get != null && TaskContext.get.attemptNumber >= 1) {
       removeKafkaConsumer(topic, partition, kafkaParams)
       val consumer = new CachedKafkaConsumer(topicPartition, kafkaParams)
       consumer.inuse = true
@@ -398,6 +398,14 @@ private[kafka010] object CachedKafkaConsumer extends Logging {
     }
   }
 
+  /** Create an [[CachedKafkaConsumer]] but don't put it into cache. */
+  def createUncached(
+      topic: String,
+      partition: Int,
+      kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer = {
+    new CachedKafkaConsumer(new TopicPartition(topic, partition), kafkaParams)
+  }
+
   private def reportDataLoss0(
       failOnDataLoss: Boolean,
       finalMessage: String,

external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala

Lines changed: 4 additions & 2 deletions
@@ -95,8 +95,10 @@ private[kafka010] class KafkaOffsetReader(
    * Closes the connection to Kafka, and cleans up state.
    */
   def close(): Unit = {
-    consumer.close()
-    kafkaReaderThread.shutdownNow()
+    runUninterruptibly {
+      consumer.close()
+    }
+    kafkaReaderThread.shutdown()
   }
 
   /**
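For background on the shutdown() change, the difference between the two ExecutorService calls is standard java.util.concurrent behavior: shutdownNow() interrupts running tasks, while shutdown() only stops accepting new work and lets in-flight tasks, such as the consumer.close() above, complete. A self-contained sketch of that difference (unrelated to Spark's own thread helpers):

import java.util.concurrent.{Executors, TimeUnit}

object ShutdownSketch {
  def main(args: Array[String]): Unit = {
    val pool = Executors.newSingleThreadExecutor()
    pool.submit(new Runnable {
      override def run(): Unit = {
        // Stand-in for a blocking call such as consumer.close().
        try Thread.sleep(1000)
        catch { case _: InterruptedException => println("interrupted mid-close") }
      }
    })
    pool.shutdown()                            // let the in-flight task finish
    // pool.shutdownNow()                      // would interrupt it instead
    pool.awaitTermination(5, TimeUnit.SECONDS)
  }
}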

external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala

Lines changed: 26 additions & 4 deletions
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.kafka010
 
 import java.{util => ju}
+import java.util.UUID
 
 import org.apache.kafka.common.TopicPartition
 
@@ -33,9 +34,9 @@ import org.apache.spark.unsafe.types.UTF8String
 
 private[kafka010] class KafkaRelation(
     override val sqlContext: SQLContext,
-    kafkaReader: KafkaOffsetReader,
-    executorKafkaParams: ju.Map[String, Object],
+    strategy: ConsumerStrategy,
     sourceOptions: Map[String, String],
+    specifiedKafkaParams: Map[String, String],
     failOnDataLoss: Boolean,
     startingOffsets: KafkaOffsetRangeLimit,
     endingOffsets: KafkaOffsetRangeLimit)
@@ -53,9 +54,27 @@ private[kafka010] class KafkaRelation(
   override def schema: StructType = KafkaOffsetReader.kafkaSchema
 
   override def buildScan(): RDD[Row] = {
+    // Each running query should use its own group id. Otherwise, the query may be only assigned
+    // partial data since Kafka will assign partitions to multiple consumers having the same group
+    // id. Hence, we should generate a unique id for each query.
+    val uniqueGroupId = s"spark-kafka-relation-${UUID.randomUUID}"
+
+    val kafkaOffsetReader = new KafkaOffsetReader(
+      strategy,
+      KafkaSourceProvider.kafkaParamsForDriver(specifiedKafkaParams),
+      sourceOptions,
+      driverGroupIdPrefix = s"$uniqueGroupId-driver")
+
     // Leverage the KafkaReader to obtain the relevant partition offsets
-    val fromPartitionOffsets = getPartitionOffsets(startingOffsets)
-    val untilPartitionOffsets = getPartitionOffsets(endingOffsets)
+    val (fromPartitionOffsets, untilPartitionOffsets) = {
+      try {
+        (getPartitionOffsets(kafkaOffsetReader, startingOffsets),
+          getPartitionOffsets(kafkaOffsetReader, endingOffsets))
+      } finally {
+        kafkaOffsetReader.close()
+      }
+    }
+
     // Obtain topicPartitions in both from and until partition offset, ignoring
     // topic partitions that were added and/or deleted between the two above calls.
     if (fromPartitionOffsets.keySet != untilPartitionOffsets.keySet) {
@@ -82,6 +101,8 @@ private[kafka010] class KafkaRelation(
       offsetRanges.sortBy(_.topicPartition.toString).mkString(", "))
 
     // Create an RDD that reads from Kafka and get the (key, value) pair as byte arrays.
+    val executorKafkaParams =
+      KafkaSourceProvider.kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId)
     val rdd = new KafkaSourceRDD(
       sqlContext.sparkContext, executorKafkaParams, offsetRanges,
       pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer = false).map { cr =>
@@ -98,6 +119,7 @@ private[kafka010] class KafkaRelation(
   }
 
   private def getPartitionOffsets(
+      kafkaReader: KafkaOffsetReader,
       kafkaOffsets: KafkaOffsetRangeLimit): Map[TopicPartition, Long] = {
     def validateTopicPartitions(partitions: Set[TopicPartition],
         partitionOffsets: Map[TopicPartition, Long]): Map[TopicPartition, Long] = {
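For reference, this relation is what serves batch reads like the following (a usage sketch with placeholder broker and topic names, assuming a SparkSession named spark is in scope); every action on such a DataFrame now goes through the buildScan path above with its own unique group id:

val batchDf = spark.read
  .format("kafka")
  .option("kafka.bootstrap.servers", "broker1:9092")  // placeholder
  .option("subscribe", "events")                      // placeholder
  .option("startingOffsets", "earliest")
  .option("endingOffsets", "latest")
  .load()
  .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")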

external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala

Lines changed: 69 additions & 78 deletions
@@ -111,10 +111,6 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
       sqlContext: SQLContext,
       parameters: Map[String, String]): BaseRelation = {
     validateBatchOptions(parameters)
-    // Each running query should use its own group id. Otherwise, the query may be only assigned
-    // partial data since Kafka will assign partitions to multiple consumers having the same group
-    // id. Hence, we should generate a unique id for each query.
-    val uniqueGroupId = s"spark-kafka-relation-${UUID.randomUUID}"
     val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
     val specifiedKafkaParams =
       parameters
@@ -131,20 +127,14 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
       ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
     assert(endingRelationOffsets != EarliestOffsetRangeLimit)
 
-    val kafkaOffsetReader = new KafkaOffsetReader(
-      strategy(caseInsensitiveParams),
-      kafkaParamsForDriver(specifiedKafkaParams),
-      parameters,
-      driverGroupIdPrefix = s"$uniqueGroupId-driver")
-
     new KafkaRelation(
       sqlContext,
-      kafkaOffsetReader,
-      kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId),
-      parameters,
-      failOnDataLoss(caseInsensitiveParams),
-      startingRelationOffsets,
-      endingRelationOffsets)
+      strategy(caseInsensitiveParams),
+      sourceOptions = parameters,
+      specifiedKafkaParams = specifiedKafkaParams,
+      failOnDataLoss = failOnDataLoss(caseInsensitiveParams),
+      startingOffsets = startingRelationOffsets,
+      endingOffsets = endingRelationOffsets)
   }
 
   override def createSink(
@@ -213,46 +203,6 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
       ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName)
   }
 
-  private def kafkaParamsForDriver(specifiedKafkaParams: Map[String, String]) =
-    ConfigUpdater("source", specifiedKafkaParams)
-      .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName)
-      .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName)
-
-      // Set to "earliest" to avoid exceptions. However, KafkaSource will fetch the initial
-      // offsets by itself instead of counting on KafkaConsumer.
-      .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest")
-
-      // So that consumers in the driver does not commit offsets unnecessarily
-      .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false")
-
-      // So that the driver does not pull too much data
-      .set(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, new java.lang.Integer(1))
-
-      // If buffer config is not set, set it to reasonable value to work around
-      // buffer issues (see KAFKA-3135)
-      .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer)
-      .build()
-
-  private def kafkaParamsForExecutors(
-      specifiedKafkaParams: Map[String, String], uniqueGroupId: String) =
-    ConfigUpdater("executor", specifiedKafkaParams)
-      .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName)
-      .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName)
-
-      // Make sure executors do only what the driver tells them.
-      .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "none")
-
-      // So that consumers in executors do not mess with any existing group id
-      .set(ConsumerConfig.GROUP_ID_CONFIG, s"$uniqueGroupId-executor")
-
-      // So that consumers in executors does not commit offsets unnecessarily
-      .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false")
-
-      // If buffer config is not set, set it to reasonable value to work around
-      // buffer issues (see KAFKA-3135)
-      .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer)
-      .build()
-
   private def strategy(caseInsensitiveParams: Map[String, String]) =
     caseInsensitiveParams.find(x => STRATEGY_OPTION_KEYS.contains(x._1)).get match {
       case ("assign", value) =>
@@ -414,30 +364,9 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
       logWarning("maxOffsetsPerTrigger option ignored in batch queries")
     }
   }
-
-  /** Class to conveniently update Kafka config params, while logging the changes */
-  private case class ConfigUpdater(module: String, kafkaParams: Map[String, String]) {
-    private val map = new ju.HashMap[String, Object](kafkaParams.asJava)
-
-    def set(key: String, value: Object): this.type = {
-      map.put(key, value)
-      logInfo(s"$module: Set $key to $value, earlier value: ${kafkaParams.getOrElse(key, "")}")
-      this
-    }
-
-    def setIfUnset(key: String, value: Object): ConfigUpdater = {
-      if (!map.containsKey(key)) {
-        map.put(key, value)
-        logInfo(s"$module: Set $key to $value")
-      }
-      this
-    }
-
-    def build(): ju.Map[String, Object] = map
-  }
 }
 
-private[kafka010] object KafkaSourceProvider {
+private[kafka010] object KafkaSourceProvider extends Logging {
   private val STRATEGY_OPTION_KEYS = Set("subscribe", "subscribepattern", "assign")
   private[kafka010] val STARTING_OFFSETS_OPTION_KEY = "startingoffsets"
   private[kafka010] val ENDING_OFFSETS_OPTION_KEY = "endingoffsets"
@@ -459,4 +388,66 @@ private[kafka010] object KafkaSourceProvider {
       case None => defaultOffsets
     }
   }
+
+  def kafkaParamsForDriver(specifiedKafkaParams: Map[String, String]): ju.Map[String, Object] =
+    ConfigUpdater("source", specifiedKafkaParams)
+      .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName)
+      .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName)
+
+      // Set to "earliest" to avoid exceptions. However, KafkaSource will fetch the initial
+      // offsets by itself instead of counting on KafkaConsumer.
+      .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest")
+
+      // So that consumers in the driver does not commit offsets unnecessarily
+      .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false")
+
+      // So that the driver does not pull too much data
+      .set(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, new java.lang.Integer(1))
+
+      // If buffer config is not set, set it to reasonable value to work around
+      // buffer issues (see KAFKA-3135)
+      .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer)
+      .build()
+
+  def kafkaParamsForExecutors(
+      specifiedKafkaParams: Map[String, String],
+      uniqueGroupId: String): ju.Map[String, Object] =
+    ConfigUpdater("executor", specifiedKafkaParams)
+      .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName)
+      .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName)
+
+      // Make sure executors do only what the driver tells them.
+      .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "none")
+
+      // So that consumers in executors do not mess with any existing group id
+      .set(ConsumerConfig.GROUP_ID_CONFIG, s"$uniqueGroupId-executor")
+
+      // So that consumers in executors does not commit offsets unnecessarily
+      .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false")
+
+      // If buffer config is not set, set it to reasonable value to work around
+      // buffer issues (see KAFKA-3135)
+      .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer)
+      .build()
+
+  /** Class to conveniently update Kafka config params, while logging the changes */
+  private case class ConfigUpdater(module: String, kafkaParams: Map[String, String]) {
+    private val map = new ju.HashMap[String, Object](kafkaParams.asJava)
+
+    def set(key: String, value: Object): this.type = {
+      map.put(key, value)
+      logDebug(s"$module: Set $key to $value, earlier value: ${kafkaParams.getOrElse(key, "")}")
+      this
+    }
+
+    def setIfUnset(key: String, value: Object): ConfigUpdater = {
+      if (!map.containsKey(key)) {
+        map.put(key, value)
+        logDebug(s"$module: Set $key to $value")
+      }
+      this
+    }
+
+    def build(): ju.Map[String, Object] = map
+  }
 }
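Taken together with the KafkaRelation change, the helpers that moved into this companion object produce the following group id layout for a single batch query. This is a sketch summarizing values already visible in the diffs, not new behavior:

// Generated once per batch query in KafkaRelation.buildScan:
val uniqueGroupId = s"spark-kafka-relation-${java.util.UUID.randomUUID}"

// Passed to KafkaOffsetReader as driverGroupIdPrefix:
val driverGroupIdPrefix = s"$uniqueGroupId-driver"

// Set as ConsumerConfig.GROUP_ID_CONFIG by kafkaParamsForExecutors:
val executorGroupId = s"$uniqueGroupId-executor"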

external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala

Lines changed: 9 additions & 10 deletions
@@ -125,16 +125,15 @@ private[kafka010] class KafkaSourceRDD(
       context: TaskContext): Iterator[ConsumerRecord[Array[Byte], Array[Byte]]] = {
     val sourcePartition = thePart.asInstanceOf[KafkaSourceRDDPartition]
     val topic = sourcePartition.offsetRange.topic
-    if (!reuseKafkaConsumer) {
-      // if we can't reuse CachedKafkaConsumers, let's reset the groupId to something unique
-      // to each task (i.e., append the task's unique partition id), because we will have
-      // multiple tasks (e.g., in the case of union) reading from the same topic partitions
-      val old = executorKafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String]
-      val id = TaskContext.getPartitionId()
-      executorKafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, old + "-" + id)
-    }
     val kafkaPartition = sourcePartition.offsetRange.partition
-    val consumer = CachedKafkaConsumer.getOrCreate(topic, kafkaPartition, executorKafkaParams)
+    val consumer =
+      if (!reuseKafkaConsumer) {
+        // If we can't reuse CachedKafkaConsumers, creating a new CachedKafkaConsumer. As here we
+        // uses `assign`, we don't need to worry about the "group.id" conflicts.
+        CachedKafkaConsumer.createUncached(topic, kafkaPartition, executorKafkaParams)
+      } else {
+        CachedKafkaConsumer.getOrCreate(topic, kafkaPartition, executorKafkaParams)
+      }
     val range = resolveRange(consumer, sourcePartition.offsetRange)
     assert(
       range.fromOffset <= range.untilOffset,
@@ -170,7 +169,7 @@ private[kafka010] class KafkaSourceRDD(
     override protected def close(): Unit = {
       if (!reuseKafkaConsumer) {
         // Don't forget to close non-reuse KafkaConsumers. You may take down your cluster!
-        CachedKafkaConsumer.removeKafkaConsumer(topic, kafkaPartition, executorKafkaParams)
+        consumer.close()
       } else {
         // Indicate that we're no longer using this consumer
         CachedKafkaConsumer.releaseKafkaConsumer(topic, kafkaPartition, executorKafkaParams)
external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala

Lines changed: 1 addition & 1 deletion
@@ -199,7 +199,7 @@ private[spark] class KafkaRDD[K, V](
 
     val consumer = if (useConsumerCache) {
       CachedKafkaConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor)
-      if (context.attemptNumber > 1) {
+      if (context.attemptNumber >= 1) {
         // just in case the prior attempt failures were cache related
         CachedKafkaConsumer.remove(groupId, part.topic, part.partition)
       }