@@ -138,17 +138,17 @@ private[spark] class DirectKafkaInputDStream[K, V](

         lagPerPartition.map { case (tp, lag) =>
           val maxRateLimitPerPartition = ppc.maxRatePerPartition(tp)
-          val backpressureRate = Math.round(lag / totalLag.toFloat * rate)
+          val backpressureRate = lag / totalLag.toDouble * rate
           tp -> (if (maxRateLimitPerPartition > 0) {
             Math.min(backpressureRate, maxRateLimitPerPartition)} else backpressureRate)
         }
-      case None => offsets.map { case (tp, offset) => tp -> ppc.maxRatePerPartition(tp) }
+      case None => offsets.map { case (tp, offset) => tp -> ppc.maxRatePerPartition(tp).toDouble }
     }

     if (effectiveRateLimitPerPartition.values.sum > 0) {
       val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000
       Some(effectiveRateLimitPerPartition.map {
-        case (tp, limit) => tp -> (secsPerBatch * limit).toLong
+        case (tp, limit) => tp -> Math.max((secsPerBatch * limit).toLong, 1L)
       })
     } else {
       None
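Note on the change above: the old code rounded the per-partition backpressure rate to a whole number of messages per second (Math.round on a Float) before scaling by the batch length, so any partition whose share of the estimated rate was below 0.5 msg/s was allotted zero messages for the whole batch. The new code keeps the rate as a Double, truncates only after multiplying by the batch length, and floors the result at one message. A minimal standalone sketch of the arithmetic (the object and variable names are illustrative, not part of the patch), using the same numbers as the new test below: estimated rate 1 msg/s, 60 s batches, one partition lagging 100 out of 600 total messages:

object BackpressureRateSketch {
  def main(args: Array[String]): Unit = {
    val lag = 100L            // lag of this partition
    val totalLag = 600L       // lag summed over all partitions
    val rate = 1L             // rate estimate from the rate controller, msgs/sec
    val secsPerBatch = 60.0   // 60000 ms batch interval

    // Old behaviour: Math.round on a Float collapses 100/600 = 0.166... to 0,
    // so this partition would get 0 messages for the entire batch.
    val oldRate = Math.round(lag / totalLag.toFloat * rate)
    val oldLimit = (secsPerBatch * oldRate).toLong                // 0

    // New behaviour: keep the fractional rate, truncate only after scaling
    // by the batch length, and never go below one message.
    val newRate = lag / totalLag.toDouble * rate
    val newLimit = Math.max((secsPerBatch * newRate).toLong, 1L)  // 10

    println(s"old = $oldLimit, new = $newLimit")
  }
}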
@@ -617,6 +617,54 @@ class DirectKafkaStreamSuite
     ssc.stop()
   }

test("maxMessagesPerPartition with zero offset and rate equal to one") {
val topic = "backpressure"
val kafkaParams = getKafkaParams()
val batchIntervalMilliseconds = 60000
val sparkConf = new SparkConf()
// Safe, even with streaming, because we're using the direct API.
// Using 1 core is useful to make the test more predictable.
.setMaster("local[1]")
.setAppName(this.getClass.getSimpleName)
.set("spark.streaming.kafka.maxRatePerPartition", "100")

// Setup the streaming context
ssc = new StreamingContext(sparkConf, Milliseconds(batchIntervalMilliseconds))
val estimateRate = 1L
val fromOffsets = Map(
new TopicPartition(topic, 0) -> 0L,
new TopicPartition(topic, 1) -> 0L,
new TopicPartition(topic, 2) -> 0L,
new TopicPartition(topic, 3) -> 0L
)
val kafkaStream = withClue("Error creating direct stream") {
new DirectKafkaInputDStream[String, String](
ssc,
preferredHosts,
ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala),
new DefaultPerPartitionConfig(sparkConf)
) {
currentOffsets = fromOffsets
override val rateController = Some(new ConstantRateController(id, null, estimateRate))
}
}

val offsets = Map[TopicPartition, Long](
new TopicPartition(topic, 0) -> 0,
new TopicPartition(topic, 1) -> 100L,
new TopicPartition(topic, 2) -> 200L,
new TopicPartition(topic, 3) -> 300L
)
val result = kafkaStream.maxMessagesPerPartition(offsets)
val expected = Map(
new TopicPartition(topic, 0) -> 1L,
new TopicPartition(topic, 1) -> 10L,
new TopicPartition(topic, 2) -> 20L,
new TopicPartition(topic, 3) -> 30L
)
assert(result.contains(expected), s"Number of messages per partition must be at least 1")
}

   /** Get the generated offset ranges from the DirectKafkaStream */
   private def getOffsetRanges[K, V](
       kafkaStream: DStream[ConsumerRecord[K, V]]): Seq[(Time, Array[OffsetRange])] = {
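In the test above, the batch interval is 60 s and the estimated rate is 1 msg/s, so with lags of 0, 100, 200 and 300 (600 total) the per-partition caps work out to 60 × 100/600 = 10, 60 × 200/600 = 20 and 60 × 300/600 = 30 messages, while the zero-lag partition is still granted 1 message by the new Math.max(..., 1L) floor, matching the assertion that every partition gets at least one message per batch. The two hunks that follow apply the same fix and an equivalent test to the older Kafka 0.8 direct stream, which uses TopicAndPartition and MessageAndMetadata rather than the new consumer API.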
@@ -104,17 +104,17 @@ class DirectKafkaInputDStream[
         val totalLag = lagPerPartition.values.sum

         lagPerPartition.map { case (tp, lag) =>
-          val backpressureRate = Math.round(lag / totalLag.toFloat * rate)
+          val backpressureRate = lag / totalLag.toDouble * rate
           tp -> (if (maxRateLimitPerPartition > 0) {
             Math.min(backpressureRate, maxRateLimitPerPartition)} else backpressureRate)
         }
-      case None => offsets.map { case (tp, offset) => tp -> maxRateLimitPerPartition }
+      case None => offsets.map { case (tp, offset) => tp -> maxRateLimitPerPartition.toDouble }
     }

     if (effectiveRateLimitPerPartition.values.sum > 0) {
       val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000
       Some(effectiveRateLimitPerPartition.map {
-        case (tp, limit) => tp -> (secsPerBatch * limit).toLong
+        case (tp, limit) => tp -> Math.max((secsPerBatch * limit).toLong, 1L)
       })
     } else {
       None
@@ -456,6 +456,57 @@ class DirectKafkaStreamSuite
     ssc.stop()
   }

test("maxMessagesPerPartition with zero offset and rate equal to one") {
val topic = "backpressure"
val kafkaParams = Map(
"metadata.broker.list" -> kafkaTestUtils.brokerAddress,
"auto.offset.reset" -> "smallest"
)

val batchIntervalMilliseconds = 60000
val sparkConf = new SparkConf()
// Safe, even with streaming, because we're using the direct API.
// Using 1 core is useful to make the test more predictable.
.setMaster("local[1]")
.setAppName(this.getClass.getSimpleName)
.set("spark.streaming.kafka.maxRatePerPartition", "100")

// Setup the streaming context
ssc = new StreamingContext(sparkConf, Milliseconds(batchIntervalMilliseconds))
val estimatedRate = 1L
val kafkaStream = withClue("Error creating direct stream") {
val messageHandler = (mmd: MessageAndMetadata[String, String]) => (mmd.key, mmd.message)
val fromOffsets = Map(
TopicAndPartition(topic, 0) -> 0L,
TopicAndPartition(topic, 1) -> 0L,
TopicAndPartition(topic, 2) -> 0L,
TopicAndPartition(topic, 3) -> 0L
)
new DirectKafkaInputDStream[String, String, StringDecoder, StringDecoder, (String, String)](
ssc, kafkaParams, fromOffsets, messageHandler) {
override protected[streaming] val rateController =
Some(new DirectKafkaRateController(id, null) {
override def getLatestRate() = estimatedRate
})
}
}

val offsets = Map(
TopicAndPartition(topic, 0) -> 0L,
TopicAndPartition(topic, 1) -> 100L,
TopicAndPartition(topic, 2) -> 200L,
TopicAndPartition(topic, 3) -> 300L
)
val result = kafkaStream.maxMessagesPerPartition(offsets)
val expected = Map(
TopicAndPartition(topic, 0) -> 1L,
TopicAndPartition(topic, 1) -> 10L,
TopicAndPartition(topic, 2) -> 20L,
TopicAndPartition(topic, 3) -> 30L
)
assert(result.contains(expected), s"Number of messages per partition must be at least 1")
}

   /** Get the generated offset ranges from the DirectKafkaStream */
   private def getOffsetRanges[K, V](
       kafkaStream: DStream[(K, V)]): Seq[(Time, Array[OffsetRange])] = {