@@ -18,111 +18,191 @@
package org.apache.spark.sql.kafka010

import java.{util => ju}
import java.util.concurrent.{ConcurrentMap, ExecutionException, TimeUnit}
import java.util.concurrent.{ConcurrentLinkedQueue, ConcurrentMap, ExecutionException, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger
import javax.annotation.concurrent.GuardedBy

import scala.collection.JavaConverters._
import scala.util.control.NonFatal

import com.google.common.cache._
import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException}
import org.apache.kafka.clients.producer.KafkaProducer
import scala.collection.JavaConverters._
import scala.util.control.NonFatal

import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.kafka010.{KafkaConfigUpdater, KafkaRedactionUtil}

private[kafka010] object CachedKafkaProducer extends Logging {
private[kafka010] case class CachedKafkaProducer(
private val id: String = ju.UUID.randomUUID().toString,
private val inUseCount: AtomicInteger = new AtomicInteger(0),
private val kafkaParams: Seq[(String, Object)]) extends Logging {

private val configMap = kafkaParams.map(x => x._1 -> x._2).toMap.asJava

lazy val kafkaProducer: KafkaProducer[Array[Byte], Array[Byte]] = {
val producer = new KafkaProducer[Array[Byte], Array[Byte]](configMap)
if (log.isDebugEnabled()) {
val redactedParamsSeq = KafkaRedactionUtil.redactParams(kafkaParams)
logDebug(s"Created a new instance of KafkaProducer for $redactedParamsSeq, with Id: $id.")
}
closed = false
producer
}
@GuardedBy("this")
private var closed: Boolean = true
private def close(): Unit = {
try {
this.synchronized {
if (!closed) {
closed = true
if (log.isInfoEnabled()) {
val redactedParamsSeq = KafkaRedactionUtil.redactParams(kafkaParams)
logInfo(s"Closing the KafkaProducer with params: ${redactedParamsSeq.mkString("\n")}.")
}
kafkaProducer.close()
}
}
} catch {
case NonFatal(e) =>
logWarning(s"Error while closing kafka producer with params: $kafkaParams", e)
}
}

private def inUse(): Boolean = inUseCount.get() > 0

private[kafka010] def getInUseCount: Int = inUseCount.get()

private type Producer = KafkaProducer[Array[Byte], Array[Byte]]
private[kafka010] def getKafkaParams: Seq[(String, Object)] = kafkaParams

private[kafka010] def flush(): Unit = kafkaProducer.flush()

private[kafka010] def isClosed: Boolean = closed
}

private[kafka010] object CachedKafkaProducer extends Logging {

private val defaultCacheExpireTimeout = TimeUnit.MINUTES.toMillis(10)

private lazy val cacheExpireTimeout: Long = Option(SparkEnv.get)
.map(_.conf.get(PRODUCER_CACHE_TIMEOUT))
.getOrElse(defaultCacheExpireTimeout)

private val cacheLoader = new CacheLoader[Seq[(String, Object)], Producer] {
override def load(config: Seq[(String, Object)]): Producer = {
createKafkaProducer(config)
private val cacheLoader = new CacheLoader[Seq[(String, Object)], CachedKafkaProducer] {
override def load(params: Seq[(String, Object)]): CachedKafkaProducer = {
CachedKafkaProducer(kafkaParams = params)
}
}

private val removalListener = new RemovalListener[Seq[(String, Object)], Producer]() {
private def updatedAuthConfigIfNeeded(kafkaParamsMap: ju.Map[String, Object]) =
KafkaConfigUpdater("executor", kafkaParamsMap.asScala.toMap)
.setAuthenticationConfigIfNeeded()
.build()

private val closeQueue = new ConcurrentLinkedQueue[CachedKafkaProducer]()

private val removalListener = new RemovalListener[Seq[(String, Object)], CachedKafkaProducer]() {
override def onRemoval(
notification: RemovalNotification[Seq[(String, Object)], Producer]): Unit = {
val paramsSeq: Seq[(String, Object)] = notification.getKey
val producer: Producer = notification.getValue
notification: RemovalNotification[Seq[(String, Object)], CachedKafkaProducer]): Unit = {
val producer: CachedKafkaProducer = notification.getValue
if (log.isDebugEnabled()) {
val redactedParamsSeq = KafkaRedactionUtil.redactParams(paramsSeq)
val redactedParamsSeq = KafkaRedactionUtil.redactParams(producer.kafkaParams)
logDebug(s"Evicting kafka producer $producer params: $redactedParamsSeq, " +
s"due to ${notification.getCause}")
}
close(paramsSeq, producer)
if (producer.inUse()) {
// When an in-use producer is evicted, we wait for it to be released by all the tasks
// before finally closing it.
closeQueue.add(producer)
} else {
producer.close()
}
}
}

private lazy val guavaCache: LoadingCache[Seq[(String, Object)], Producer] =
private lazy val guavaCache: LoadingCache[Seq[(String, Object)], CachedKafkaProducer] =
CacheBuilder.newBuilder().expireAfterAccess(cacheExpireTimeout, TimeUnit.MILLISECONDS)
.removalListener(removalListener)
.build[Seq[(String, Object)], Producer](cacheLoader)

private def createKafkaProducer(paramsSeq: Seq[(String, Object)]): Producer = {
val kafkaProducer: Producer = new Producer(paramsSeq.toMap.asJava)
if (log.isDebugEnabled()) {
val redactedParamsSeq = KafkaRedactionUtil.redactParams(paramsSeq)
logDebug(s"Created a new instance of KafkaProducer for $redactedParamsSeq.")
}
kafkaProducer
}
.build[Seq[(String, Object)], CachedKafkaProducer](cacheLoader)

/**
* Get a cached KafkaProducer for a given configuration. If a matching KafkaProducer doesn't
* exist, a new one will be created. KafkaProducer is thread safe, so it is best to keep
* one instance per specified kafkaParams.
*/
private[kafka010] def getOrCreate(kafkaParams: ju.Map[String, Object]): Producer = {
val updatedKafkaProducerConfiguration =
KafkaConfigUpdater("executor", kafkaParams.asScala.toMap)
.setAuthenticationConfigIfNeeded()
.build()
val paramsSeq: Seq[(String, Object)] = paramsToSeq(updatedKafkaProducerConfiguration)
private[kafka010] def acquire(kafkaParamsMap: ju.Map[String, Object]): CachedKafkaProducer = {
val paramsSeq: Seq[(String, Object)] = paramsToSeq(updatedAuthConfigIfNeeded(kafkaParamsMap))
try {
guavaCache.get(paramsSeq)
val producer = this.synchronized {

Reviewer (Contributor): Is this required? It's risky to add new global locks to things.

val cachedKafkaProducer: CachedKafkaProducer = guavaCache.get(paramsSeq)
cachedKafkaProducer.inUseCount.incrementAndGet()
logDebug(s"Granted producer $cachedKafkaProducer")
cachedKafkaProducer
}
producer
} catch {
case e @ (_: ExecutionException | _: UncheckedExecutionException | _: ExecutionError)
case e@(_: ExecutionException | _: UncheckedExecutionException | _: ExecutionError)
if e.getCause != null =>
throw e.getCause
}
}
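
To make the acquire/release contract concrete, here is a small usage sketch (not part of this patch) of the life cycle the writer classes later in this diff follow. `ProducerUsageSketch`, `writeOnce`, and their parameters are illustrative names, and the sketch assumes it lives in the same `org.apache.spark.sql.kafka010` package so the `private[kafka010]` members are visible.

```scala
import java.{util => ju}

import org.apache.kafka.clients.producer.ProducerRecord

object ProducerUsageSketch {
  // Acquire a cached producer, write, flush, and always release in a finally block.
  // Passing failing = true on release also invalidates the cache entry so that a
  // fresh producer is created for subsequent tasks.
  def writeOnce(
      params: ju.Map[String, Object],
      record: ProducerRecord[Array[Byte], Array[Byte]]): Unit = {
    val producer = CachedKafkaProducer.acquire(params)
    var failed = false
    try {
      producer.kafkaProducer.send(record)
      producer.flush()
    } catch {
      case e: Exception =>
        failed = true
        throw e
    } finally {
      CachedKafkaProducer.release(producer, failing = failed)
    }
  }
}
```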

private def paramsToSeq(kafkaParams: ju.Map[String, Object]): Seq[(String, Object)] = {
val paramsSeq: Seq[(String, Object)] = kafkaParams.asScala.toSeq.sortBy(x => x._1)
private def paramsToSeq(kafkaParamsMap: ju.Map[String, Object]): Seq[(String, Object)] = {
val paramsSeq: Seq[(String, Object)] = kafkaParamsMap.asScala.toSeq.sortBy(x => x._1)
paramsSeq
}

/** For explicitly closing kafka producer */
private[kafka010] def close(kafkaParams: ju.Map[String, Object]): Unit = {
val paramsSeq = paramsToSeq(kafkaParams)
guavaCache.invalidate(paramsSeq)
/* Release a kafka producer back to the cache. We simply decrement its in-use count. */
private[kafka010] def release(producer: CachedKafkaProducer, failing: Boolean): Unit = {
this.synchronized {
// It should be ok to call release multiple times on the same producer object.

Reviewer (Contributor): But it's not really okay, right? If task A calls release multiple times, the producer might have its inUseCount decremented to 0 even though task B is using it.

if (producer.inUse()) {
producer.inUseCount.decrementAndGet()
logDebug(s"Released producer $producer.")
} else {
logWarning(s"Tried to release a not in use producer, $producer.")
}
if (failing) {
// If this producer is failing to write, we remove it from the cache
// so that it is eventually re-created.
val cachedProducer = guavaCache.getIfPresent(producer.kafkaParams)
if (cachedProducer != null && cachedProducer.id == producer.id) {
logDebug(s"Invalidating a failing producer: $producer.")
guavaCache.invalidate(producer.kafkaParams)
}
}
}
// We need a close queue, so that we can close the producer(s) outside of a synchronized block.
processPendingClose()
}
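
On the double-release concern raised above, one possible mitigation (a sketch, not part of this patch) is to hand each task its own lease object, so that repeated release calls from the same task reach the shared counter at most once. `ProducerLease` and its members are illustrative names.

```scala
import java.util.concurrent.atomic.AtomicBoolean

final class ProducerLease(val producer: CachedKafkaProducer) {
  private val released = new AtomicBoolean(false)

  // Only the first release() on a lease decrements the shared in-use count;
  // later calls from the same task become no-ops.
  def release(failing: Boolean): Unit = {
    if (released.compareAndSet(false, true)) {
      CachedKafkaProducer.release(producer, failing)
    }
  }
}
```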

/** Auto close on cache evict */
private def close(paramsSeq: Seq[(String, Object)], producer: Producer): Unit = {
try {
if (log.isInfoEnabled()) {
val redactedParamsSeq = KafkaRedactionUtil.redactParams(paramsSeq)
logInfo(s"Closing the KafkaProducer with params: ${redactedParamsSeq.mkString("\n")}.")
/** Process pending closes. */
private def processPendingClose(): Unit = {
// Check and close any other producers previously evicted, but pending to be closed.
for (p <- closeQueue.iterator().asScala) {
if (!p.inUse()) {
closeQueue.remove(p)
p.close()
}
producer.close()
} catch {
case NonFatal(e) => logWarning("Error while closing kafka producer.", e)
}
}

// For testing only.
private[kafka010] def clear(): Unit = {
logInfo("Cleaning up guava cache.")
logInfo("Cleaning up guava cache and force closing all kafka producers.")
guavaCache.invalidateAll()
for (p <- closeQueue.iterator().asScala) {
p.close()
}
closeQueue.clear()
}

// For testing only.
private[kafka010] def evict(params: Seq[(String, Object)]): Unit = {
guavaCache.invalidate(params)
}

// Intended for testing purpose only.
private def getAsMap: ConcurrentMap[Seq[(String, Object)], Producer] = guavaCache.asMap()
// For testing only.
private[kafka010] def getAsMap: ConcurrentMap[Seq[(String, Object)], CachedKafkaProducer] =
guavaCache.asMap()
}
@@ -44,7 +44,7 @@ private[kafka010] class KafkaDataWriter(
inputSchema: Seq[Attribute])
extends KafkaRowWriter(inputSchema, targetTopic) with DataWriter[InternalRow] {

private lazy val producer = CachedKafkaProducer.getOrCreate(producerParams)
protected lazy val producer = CachedKafkaProducer.acquire(producerParams)

def write(row: InternalRow): Unit = {
checkForErrors()
@@ -61,14 +61,17 @@
KafkaDataWriterCommitMessage
}

def abort(): Unit = {}
def abort(): Unit = {
close()
}

def close(): Unit = {
checkForErrors()
if (producer != null) {
try {
checkForErrors()
producer.flush()
checkForErrors()
CachedKafkaProducer.close(producerParams)
} finally {
CachedKafkaProducer.release(producer, failedWrite != null)
}
}
}
@@ -19,7 +19,7 @@ package org.apache.spark.sql.kafka010

import java.{util => ju}

import org.apache.kafka.clients.producer.{Callback, KafkaProducer, ProducerRecord, RecordMetadata}
import org.apache.kafka.clients.producer.{Callback, ProducerRecord, RecordMetadata}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection}
@@ -35,32 +35,35 @@ private[kafka010] class KafkaWriteTask(
inputSchema: Seq[Attribute],
topic: Option[String]) extends KafkaRowWriter(inputSchema, topic) {
// used to synchronize with Kafka callbacks
private var producer: KafkaProducer[Array[Byte], Array[Byte]] = _
protected val producer: CachedKafkaProducer =
CachedKafkaProducer.acquire(producerConfiguration)

Reviewer (Contributor): This is a change in lifecycle for the producer. Are we sure that's safe?
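
If the lifecycle change is a concern, one low-risk option (a sketch, assuming the rest of the class stays as in this diff) is to mirror KafkaDataWriter above and make the field lazy, so the producer is only acquired when the first row is actually written:

```scala
// Defer acquisition until first use, matching the previous behaviour where the
// producer was obtained inside execute(). KafkaDataWriter already does this.
protected lazy val producer: CachedKafkaProducer =
  CachedKafkaProducer.acquire(producerConfiguration)
```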

/**
* Writes key value data out to topics.
*/
def execute(iterator: Iterator[InternalRow]): Unit = {
producer = CachedKafkaProducer.getOrCreate(producerConfiguration)
while (iterator.hasNext && failedWrite == null) {
val currentRow = iterator.next()
sendRow(currentRow, producer)
}
}

def close(): Unit = {
checkForErrors()
if (producer != null) {
try {
checkForErrors()
producer.flush()
checkForErrors()
producer = null
} finally {
CachedKafkaProducer.release(producer, failedWrite != null)
}
}

}

private[kafka010] abstract class KafkaRowWriter(
inputSchema: Seq[Attribute], topic: Option[String]) {

protected val producer: CachedKafkaProducer
// used to synchronize with Kafka callbacks
@volatile protected var failedWrite: Exception = _
protected val projection = createProjection
@@ -79,7 +82,7 @@
* assuming the row is in Kafka.
*/
protected def sendRow(
row: InternalRow, producer: KafkaProducer[Array[Byte], Array[Byte]]): Unit = {
row: InternalRow, producer: CachedKafkaProducer): Unit = {
val projectedRow = projection(row)
val topic = projectedRow.getUTF8String(0)
val key = projectedRow.getBinary(1)
@@ -89,7 +92,7 @@
s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.")
}
val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value)
producer.send(record, callback)
producer.kafkaProducer.send(record, callback)

Reviewer (Contributor): Should we add a wrapper since we have one for flush?

}
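
A sketch of the wrapper suggested above, parallel to the existing flush() pass-through on CachedKafkaProducer; the method name, visibility, and return type are assumptions:

```scala
import java.util.concurrent.Future

import org.apache.kafka.clients.producer.{Callback, ProducerRecord, RecordMetadata}

// Inside CachedKafkaProducer: forward send() the same way flush() is forwarded,
// so callers never have to reach into the underlying kafkaProducer directly.
private[kafka010] def send(
    record: ProducerRecord[Array[Byte], Array[Byte]],
    callback: Callback): Future[RecordMetadata] =
  kafkaProducer.send(record, callback)
```

Call sites like the one above would then read producer.send(record, callback).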

protected def checkForErrors(): Unit = {