diff --git a/spring-integration-kafka/src/main/java/org/springframework/integration/kafka/inbound/KafkaMessageSource.java b/spring-integration-kafka/src/main/java/org/springframework/integration/kafka/inbound/KafkaMessageSource.java index 03e9aa57a61..88c0d7c8ca6 100644 --- a/spring-integration-kafka/src/main/java/org/springframework/integration/kafka/inbound/KafkaMessageSource.java +++ b/spring-integration-kafka/src/main/java/org/springframework/integration/kafka/inbound/KafkaMessageSource.java @@ -138,6 +138,8 @@ public class KafkaMessageSource extends AbstractMessageSource impl private Duration closeTimeout = Duration.ofSeconds(DEFAULT_CLOSE_TIMEOUT); + public boolean newAssignment; + private volatile Consumer consumer; private volatile boolean pausing; @@ -146,6 +148,8 @@ public class KafkaMessageSource extends AbstractMessageSource impl private volatile Iterator> recordsIterator; + private volatile boolean stopped; + /** * Construct an instance with the supplied parameters. Fetching multiple * records per poll will be disabled. @@ -386,12 +390,14 @@ public synchronized boolean isRunning() { @Override public synchronized void start() { this.running = true; + this.stopped = false; } @Override public synchronized void stop() { stopConsumer(); this.running = false; + this.stopped = true; } @Override @@ -411,6 +417,10 @@ public boolean isPaused() { @Override protected synchronized Object doReceive() { + if (this.stopped) { + this.logger.debug("Message source is stopped; no records will be returned"); + return null; + } if (this.consumer == null) { createConsumer(); this.running = true; @@ -511,14 +521,27 @@ private ConsumerRecord pollRecord() { } else { synchronized (this.consumerMonitor) { - ConsumerRecords records = this.consumer - .poll(this.assignedPartitions.isEmpty() ? this.assignTimeout : this.pollTimeout); - if (records == null || records.count() == 0) { + try { + ConsumerRecords records = this.consumer + .poll(this.assignedPartitions.isEmpty() ? this.assignTimeout : this.pollTimeout); + this.logger.debug(() -> records == null + ? "Received null" + : "Received " + records.count() + " records"); + if (records == null || records.count() == 0) { + return null; + } + this.remainingCount.set(records.count()); + this.recordsIterator = records.iterator(); + return nextRecord(); + } + catch (WakeupException ex) { + this.logger.debug("Woken"); + if (this.newAssignment) { + this.newAssignment = false; + return pollRecord(); + } return null; } - this.remainingCount.set(records.count()); - this.recordsIterator = records.iterator(); - return nextRecord(); } } } @@ -632,6 +655,8 @@ public void onPartitionsAssigned(Collection partitions) { this.providedRebalanceListener.onPartitionsAssigned(partitions); } } + KafkaMessageSource.this.consumer.wakeup(); + KafkaMessageSource.this.newAssignment = true; } }