From 6d2603975cd80cb3b5922827f645f0468d48ff21 Mon Sep 17 00:00:00 2001
From: Gary Russell
Date: Wed, 2 Aug 2023 10:40:13 -0400
Subject: [PATCH] GH-2760: Add SmartLifecycle to Producer Factory

Resolves https://github.com/spring-projects/spring-kafka/issues/2760
---
 .../core/DefaultKafkaProducerFactory.java     | 50 +++++++++-
 .../DefaultKafkaProducerFactoryTests.java     | 94 ++++++++++++++++++-
 2 files changed, 141 insertions(+), 3 deletions(-)

diff --git a/spring-kafka/src/main/java/org/springframework/kafka/core/DefaultKafkaProducerFactory.java b/spring-kafka/src/main/java/org/springframework/kafka/core/DefaultKafkaProducerFactory.java
index 4f8812e35e..30ed0f6900 100644
--- a/spring-kafka/src/main/java/org/springframework/kafka/core/DefaultKafkaProducerFactory.java
+++ b/spring-kafka/src/main/java/org/springframework/kafka/core/DefaultKafkaProducerFactory.java
@@ -23,10 +23,12 @@
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.Future;
 import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.BiPredicate;
 import java.util.function.Supplier;
@@ -55,6 +57,7 @@
 import org.springframework.context.ApplicationContext;
 import org.springframework.context.ApplicationContextAware;
 import org.springframework.context.ApplicationListener;
+import org.springframework.context.SmartLifecycle;
 import org.springframework.context.event.ContextStoppedEvent;
 import org.springframework.core.log.LogAccessor;
 import org.springframework.kafka.KafkaException;
@@ -111,7 +114,7 @@
  */
 public class DefaultKafkaProducerFactory<K, V> extends KafkaResourceFactory
 		implements ProducerFactory<K, V>, ApplicationContextAware,
-		BeanNameAware, ApplicationListener<ContextStoppedEvent>, DisposableBean {
+		BeanNameAware, ApplicationListener<ContextStoppedEvent>, DisposableBean, SmartLifecycle {
 
 	private static final LogAccessor LOGGER = new LogAccessor(LogFactory.getLog(DefaultKafkaProducerFactory.class));
 
@@ -123,6 +126,8 @@ public class DefaultKafkaProducerFactory<K, V> extends KafkaResourceFactory
 
 	private final ThreadLocal<CloseSafeProducer<K, V>> threadBoundProducers = new ThreadLocal<>();
 
+	private final Set<CloseSafeProducer<K, V>> threadBoundProducersAll = ConcurrentHashMap.newKeySet();
+
 	private final AtomicInteger epoch = new AtomicInteger();
 
 	private final AtomicInteger clientIdCounter = new AtomicInteger();
@@ -131,6 +136,8 @@ public class DefaultKafkaProducerFactory<K, V> extends KafkaResourceFactory
 
 	private final List<ProducerPostProcessor<K, V>> postProcessors = new ArrayList<>();
 
+	private final AtomicBoolean running = new AtomicBoolean();
+
 	private Supplier<Serializer<K>> keySerializerSupplier;
 
 	private Supplier<Serializer<V>> valueSerializerSupplier;
@@ -519,6 +526,27 @@ public void setMaxAge(Duration maxAge) {
 		this.maxAge = maxAge.toMillis();
 	}
 
+	@Override
+	public void start() {
+		this.running.set(true);
+	}
+
+	@Override
+	public void stop() {
+		this.running.set(false);
+		destroy();
+	}
+
+	@Override
+	public boolean isRunning() {
+		return this.running.get();
+	}
+
+	@Override
+	public int getPhase() {
+		return Integer.MIN_VALUE;
+	}
+
 	/**
 	 * Copy properties of the instance and the given properties to create a new producer factory.
 	 * <p>If the {@link org.springframework.kafka.core.DefaultKafkaProducerFactory} makes a
@@ -677,7 +705,12 @@ public void destroy() {
 			this.producer = null;
 		}
 		if (producerToClose != null) {
-			producerToClose.closeDelegate(this.physicalCloseTimeout, this.listeners);
+			try {
+				producerToClose.closeDelegate(this.physicalCloseTimeout, this.listeners);
+			}
+			catch (Exception e) {
+				LOGGER.error(e, "Exception while closing producer");
+			}
 		}
 		this.cache.values().forEach(queue -> {
 			CloseSafeProducer<K, V> next = queue.poll();
@@ -691,6 +724,16 @@ public void destroy() {
 				next = queue.poll();
 			}
 		});
+		this.cache.clear();
+		this.threadBoundProducersAll.forEach(prod -> {
+			try {
+				prod.closeDelegate(this.physicalCloseTimeout, this.listeners);
+			}
+			catch (Exception e) {
+				LOGGER.error(e, "Exception while closing producer");
+			}
+		});
+		this.threadBoundProducersAll.clear();
 		this.epoch.incrementAndGet();
 	}
 
@@ -760,6 +803,7 @@ private Producer<K, V> getOrCreateThreadBoundProducer() {
 		CloseSafeProducer<K, V> tlProducer = this.threadBoundProducers.get();
 		if (tlProducer != null && (tlProducer.closed || this.epoch.get() != tlProducer.epoch || expire(tlProducer))) {
 			closeThreadBoundProducer();
+			this.threadBoundProducersAll.remove(tlProducer);
 			tlProducer = null;
 		}
 		if (tlProducer == null) {
@@ -769,6 +813,7 @@ private Producer<K, V> getOrCreateThreadBoundProducer() {
 				listener.producerAdded(tlProducer.clientId, tlProducer);
 			}
 			this.threadBoundProducers.set(tlProducer);
+			this.threadBoundProducersAll.add(tlProducer);
 		}
 		return tlProducer;
 	}
@@ -907,6 +952,7 @@ public void closeThreadBoundProducer() {
 		CloseSafeProducer<K, V> tlProducer = this.threadBoundProducers.get();
 		if (tlProducer != null) {
 			this.threadBoundProducers.remove();
+			this.threadBoundProducersAll.remove(tlProducer);
 			tlProducer.closeDelegate(this.physicalCloseTimeout, this.listeners);
 		}
 	}
diff --git a/spring-kafka/src/test/java/org/springframework/kafka/core/DefaultKafkaProducerFactoryTests.java b/spring-kafka/src/test/java/org/springframework/kafka/core/DefaultKafkaProducerFactoryTests.java
index 1d479b6778..2d2288633a 100644
--- a/spring-kafka/src/test/java/org/springframework/kafka/core/DefaultKafkaProducerFactoryTests.java
+++ b/spring-kafka/src/test/java/org/springframework/kafka/core/DefaultKafkaProducerFactoryTests.java
@@ -36,6 +36,7 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Queue;
+import java.util.Set;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
@@ -151,6 +152,35 @@ protected Producer createRawProducer(Map configs) {
 		verify(producer, times(2)).close(any(Duration.class));
 	}
 
+	@Test
+	@SuppressWarnings({ "rawtypes", "unchecked" })
+	void singleLifecycle() throws InterruptedException {
+		final Producer producer = mock(Producer.class);
+		DefaultKafkaProducerFactory pf = new DefaultKafkaProducerFactory(new HashMap<>()) {
+
+			@Override
+			protected Producer createRawProducer(Map configs) {
+				return producer;
+			}
+
+		};
+		Producer aProducer = pf.createProducer();
+		assertThat(aProducer).isNotNull();
+		Producer bProducer = pf.createProducer();
+		assertThat(bProducer).isSameAs(aProducer);
+		aProducer.close(ProducerFactoryUtils.DEFAULT_CLOSE_TIMEOUT);
+		assertThat(KafkaTestUtils.getPropertyValue(pf, "producer")).isNotNull();
+		pf.setMaxAge(Duration.ofMillis(10));
+		Thread.sleep(50);
+		aProducer = pf.createProducer();
+		assertThat(aProducer).isNotSameAs(bProducer);
+		Map cache = KafkaTestUtils.getPropertyValue(pf, "cache", Map.class);
+		assertThat(cache.size()).isEqualTo(0);
+		pf.stop();
+		assertThat(KafkaTestUtils.getPropertyValue(pf, "producer")).isNull();
+		verify(producer, times(2)).close(any(Duration.class));
+	}
+
 	@Test
 	@SuppressWarnings({ "rawtypes", "unchecked" })
 	void testResetTx() throws Exception {
@@ -186,6 +216,42 @@ protected Producer createRawProducer(Map configs) {
 		verify(producer).close(any(Duration.class));
 	}
 
+	@Test
+	@SuppressWarnings({ "rawtypes", "unchecked" })
+	void txLifecycle() throws Exception {
+		final Producer producer = mock(Producer.class);
+		ApplicationContext ctx = mock(ApplicationContext.class);
+		DefaultKafkaProducerFactory pf = new DefaultKafkaProducerFactory(new HashMap<>()) {
+
+			@Override
+			protected Producer createRawProducer(Map configs) {
+				return producer;
+			}
+
+		};
+		pf.setApplicationContext(ctx);
+		pf.setTransactionIdPrefix("foo");
+		Producer aProducer = pf.createProducer();
+		assertThat(aProducer).isNotNull();
+		aProducer.close();
+		Producer bProducer = pf.createProducer();
+		assertThat(bProducer).isSameAs(aProducer);
+		bProducer.close();
+		assertThat(KafkaTestUtils.getPropertyValue(pf, "producer")).isNull();
+		Map cache = KafkaTestUtils.getPropertyValue(pf, "cache", Map.class);
+		assertThat(cache.size()).isEqualTo(1);
+		Queue queue = (Queue) cache.get("foo");
+		assertThat(queue.size()).isEqualTo(1);
+		pf.setMaxAge(Duration.ofMillis(10));
+		Thread.sleep(50);
+		aProducer = pf.createProducer();
+		assertThat(aProducer).isNotSameAs(bProducer);
+		pf.stop();
+		assertThat(queue.size()).isEqualTo(0);
+		assertThat(cache.size()).isEqualTo(0);
+		verify(producer).close(any(Duration.class));
+	}
+
 	@Test
 	@SuppressWarnings({ "rawtypes", "unchecked" })
 	void dontReturnToCacheAfterReset() {
@@ -255,6 +321,32 @@ protected Producer createKafkaProducer() {
 		verify(producer, times(3)).close(any(Duration.class));
 	}
 
+	@Test
+	@SuppressWarnings({ "rawtypes", "unchecked" })
+	void threadLocalLifecycle() throws InterruptedException {
+		final Producer producer = mock(Producer.class);
+		AtomicBoolean created = new AtomicBoolean();
+		DefaultKafkaProducerFactory pf = new DefaultKafkaProducerFactory(new HashMap<>()) {
+
+			@Override
+			protected Producer createKafkaProducer() {
+				assertThat(created.get()).isFalse();
+				created.set(true);
+				return producer;
+			}
+
+		};
+		pf.setProducerPerThread(true);
+		Producer aProducer = pf.createProducer();
+		assertThat(aProducer).isNotNull();
+		aProducer.close();
+		assertThat(KafkaTestUtils.getPropertyValue(pf, "producer")).isNull();
+		assertThat(KafkaTestUtils.getPropertyValue(pf, "threadBoundProducers", ThreadLocal.class).get()).isNotNull();
+		pf.stop();
+		assertThat(KafkaTestUtils.getPropertyValue(pf, "threadBoundProducersAll", Set.class)).hasSize(0);
+		verify(producer).close(any(Duration.class));
+	}
+
 	@Test
 	@SuppressWarnings({ "rawtypes", "unchecked" })
 	void testThreadLocalReset() {
@@ -281,7 +373,7 @@ protected Producer createKafkaProducer() {
 		bProducer = pf.createProducer();
 		assertThat(bProducer).isNotSameAs(aProducer);
 		bProducer.close();
-		verify(producer1).close(any(Duration.class));
+		verify(producer1, times(2)).close(any(Duration.class));
 	}
 
 	@Test
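
A usage sketch, not part of the patch: with SmartLifecycle implemented, a DefaultKafkaProducerFactory declared as a bean is started and stopped by the application context's lifecycle processor, and stop() now physically closes cached and thread-bound producers via destroy(). The class name, broker address, and serializer settings below are illustrative assumptions, not taken from the patch.

// Illustrative sketch only -- not part of the patch; configuration values are assumptions.
import java.util.Map;

import org.apache.kafka.clients.producer.ProducerConfig;
import org.apache.kafka.common.serialization.StringSerializer;

import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.kafka.core.DefaultKafkaProducerFactory;

@Configuration
class ProducerFactoryLifecycleSketch {

	@Bean
	DefaultKafkaProducerFactory<String, String> producerFactory() {
		Map<String, Object> configs = Map.of(
				ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092",
				ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, StringSerializer.class,
				ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, StringSerializer.class);
		DefaultKafkaProducerFactory<String, String> pf = new DefaultKafkaProducerFactory<>(configs);
		pf.setProducerPerThread(true); // thread-bound producers are now also closed by stop()/destroy()
		return pf;
	}

	public static void main(String[] args) {
		try (AnnotationConfigApplicationContext ctx =
				new AnnotationConfigApplicationContext(ProducerFactoryLifecycleSketch.class)) {
			DefaultKafkaProducerFactory<?, ?> pf = ctx.getBean(DefaultKafkaProducerFactory.class);
			pf.createProducer().close(); // close() does not physically close the factory-managed producer
			ctx.stop();                  // lifecycle stop() -> destroy(): producers are physically closed
		}
	}
}

Because getPhase() returns Integer.MIN_VALUE, the factory starts before and stops after components in higher lifecycle phases, so producers remain usable until those components have shut down.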