Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiPredicate;
import java.util.function.Supplier;
Expand Down Expand Up @@ -55,6 +57,7 @@
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.ApplicationListener;
import org.springframework.context.SmartLifecycle;
import org.springframework.context.event.ContextStoppedEvent;
import org.springframework.core.log.LogAccessor;
import org.springframework.kafka.KafkaException;
Expand Down Expand Up @@ -111,7 +114,7 @@
*/
public class DefaultKafkaProducerFactory<K, V> extends KafkaResourceFactory
implements ProducerFactory<K, V>, ApplicationContextAware,
BeanNameAware, ApplicationListener<ContextStoppedEvent>, DisposableBean {
BeanNameAware, ApplicationListener<ContextStoppedEvent>, DisposableBean, SmartLifecycle {

private static final LogAccessor LOGGER = new LogAccessor(LogFactory.getLog(DefaultKafkaProducerFactory.class));

Expand All @@ -123,6 +126,8 @@ public class DefaultKafkaProducerFactory<K, V> extends KafkaResourceFactory

private final ThreadLocal<CloseSafeProducer<K, V>> threadBoundProducers = new ThreadLocal<>();

private final Set<CloseSafeProducer<K, V>> threadBoundProducersAll = ConcurrentHashMap.newKeySet();

private final AtomicInteger epoch = new AtomicInteger();

private final AtomicInteger clientIdCounter = new AtomicInteger();
Expand All @@ -131,6 +136,8 @@ public class DefaultKafkaProducerFactory<K, V> extends KafkaResourceFactory

private final List<ProducerPostProcessor<K, V>> postProcessors = new ArrayList<>();

private final AtomicBoolean running = new AtomicBoolean();

private Supplier<Serializer<K>> keySerializerSupplier;

private Supplier<Serializer<V>> valueSerializerSupplier;
Expand Down Expand Up @@ -519,6 +526,27 @@ public void setMaxAge(Duration maxAge) {
this.maxAge = maxAge.toMillis();
}

@Override
public void start() {
this.running.set(true);
}

@Override
public void stop() {
this.running.set(false);
destroy();
}

@Override
public boolean isRunning() {
return this.running.get();
}

@Override
public int getPhase() {
return Integer.MIN_VALUE;
}

/**
* Copy properties of the instance and the given properties to create a new producer factory.
* <p>If the {@link org.springframework.kafka.core.DefaultKafkaProducerFactory} makes a
Expand Down Expand Up @@ -677,7 +705,12 @@ public void destroy() {
this.producer = null;
}
if (producerToClose != null) {
producerToClose.closeDelegate(this.physicalCloseTimeout, this.listeners);
try {
producerToClose.closeDelegate(this.physicalCloseTimeout, this.listeners);
}
catch (Exception e) {
LOGGER.error(e, "Exception while closing producer");
}
}
this.cache.values().forEach(queue -> {
CloseSafeProducer<K, V> next = queue.poll();
Expand All @@ -691,6 +724,16 @@ public void destroy() {
next = queue.poll();
}
});
this.cache.clear();
this.threadBoundProducersAll.forEach(prod -> {
try {
prod.closeDelegate(this.physicalCloseTimeout, this.listeners);
}
catch (Exception e) {
LOGGER.error(e, "Exception while closing producer");
}
});
this.threadBoundProducersAll.clear();
this.epoch.incrementAndGet();
}

Expand Down Expand Up @@ -760,6 +803,7 @@ private Producer<K, V> getOrCreateThreadBoundProducer() {
CloseSafeProducer<K, V> tlProducer = this.threadBoundProducers.get();
if (tlProducer != null && (tlProducer.closed || this.epoch.get() != tlProducer.epoch || expire(tlProducer))) {
closeThreadBoundProducer();
this.threadBoundProducersAll.remove(tlProducer);
tlProducer = null;
}
if (tlProducer == null) {
Expand All @@ -769,6 +813,7 @@ private Producer<K, V> getOrCreateThreadBoundProducer() {
listener.producerAdded(tlProducer.clientId, tlProducer);
}
this.threadBoundProducers.set(tlProducer);
this.threadBoundProducersAll.add(tlProducer);
}
return tlProducer;
}
Expand Down Expand Up @@ -907,6 +952,7 @@ public void closeThreadBoundProducer() {
CloseSafeProducer<K, V> tlProducer = this.threadBoundProducers.get();
if (tlProducer != null) {
this.threadBoundProducers.remove();
this.threadBoundProducersAll.remove(tlProducer);
tlProducer.closeDelegate(this.physicalCloseTimeout, this.listeners);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
Expand Down Expand Up @@ -151,6 +152,35 @@ protected Producer createRawProducer(Map configs) {
verify(producer, times(2)).close(any(Duration.class));
}

@Test
@SuppressWarnings({ "rawtypes", "unchecked" })
void singleLifecycle() throws InterruptedException {
final Producer producer = mock(Producer.class);
DefaultKafkaProducerFactory pf = new DefaultKafkaProducerFactory(new HashMap<>()) {

@Override
protected Producer createRawProducer(Map configs) {
return producer;
}

};
Producer aProducer = pf.createProducer();
assertThat(aProducer).isNotNull();
Producer bProducer = pf.createProducer();
assertThat(bProducer).isSameAs(aProducer);
aProducer.close(ProducerFactoryUtils.DEFAULT_CLOSE_TIMEOUT);
assertThat(KafkaTestUtils.getPropertyValue(pf, "producer")).isNotNull();
pf.setMaxAge(Duration.ofMillis(10));
Thread.sleep(50);
aProducer = pf.createProducer();
assertThat(aProducer).isNotSameAs(bProducer);
Map<?, ?> cache = KafkaTestUtils.getPropertyValue(pf, "cache", Map.class);
assertThat(cache.size()).isEqualTo(0);
pf.stop();
assertThat(KafkaTestUtils.getPropertyValue(pf, "producer")).isNull();
verify(producer, times(2)).close(any(Duration.class));
}

@Test
@SuppressWarnings({ "rawtypes", "unchecked" })
void testResetTx() throws Exception {
Expand Down Expand Up @@ -186,6 +216,42 @@ protected Producer createRawProducer(Map configs) {
verify(producer).close(any(Duration.class));
}

@Test
@SuppressWarnings({ "rawtypes", "unchecked" })
void txLifecycle() throws Exception {
final Producer producer = mock(Producer.class);
ApplicationContext ctx = mock(ApplicationContext.class);
DefaultKafkaProducerFactory pf = new DefaultKafkaProducerFactory(new HashMap<>()) {

@Override
protected Producer createRawProducer(Map configs) {
return producer;
}

};
pf.setApplicationContext(ctx);
pf.setTransactionIdPrefix("foo");
Producer aProducer = pf.createProducer();
assertThat(aProducer).isNotNull();
aProducer.close();
Producer bProducer = pf.createProducer();
assertThat(bProducer).isSameAs(aProducer);
bProducer.close();
assertThat(KafkaTestUtils.getPropertyValue(pf, "producer")).isNull();
Map<?, ?> cache = KafkaTestUtils.getPropertyValue(pf, "cache", Map.class);
assertThat(cache.size()).isEqualTo(1);
Queue queue = (Queue) cache.get("foo");
assertThat(queue.size()).isEqualTo(1);
pf.setMaxAge(Duration.ofMillis(10));
Thread.sleep(50);
aProducer = pf.createProducer();
assertThat(aProducer).isNotSameAs(bProducer);
pf.stop();
assertThat(queue.size()).isEqualTo(0);
assertThat(cache.size()).isEqualTo(0);
verify(producer).close(any(Duration.class));
}

@Test
@SuppressWarnings({ "rawtypes", "unchecked" })
void dontReturnToCacheAfterReset() {
Expand Down Expand Up @@ -255,6 +321,32 @@ protected Producer createKafkaProducer() {
verify(producer, times(3)).close(any(Duration.class));
}

@Test
@SuppressWarnings({ "rawtypes", "unchecked" })
void threadLocalLifecycle() throws InterruptedException {
final Producer producer = mock(Producer.class);
AtomicBoolean created = new AtomicBoolean();
DefaultKafkaProducerFactory pf = new DefaultKafkaProducerFactory(new HashMap<>()) {

@Override
protected Producer createKafkaProducer() {
assertThat(created.get()).isFalse();
created.set(true);
return producer;
}

};
pf.setProducerPerThread(true);
Producer aProducer = pf.createProducer();
assertThat(aProducer).isNotNull();
aProducer.close();
assertThat(KafkaTestUtils.getPropertyValue(pf, "producer")).isNull();
assertThat(KafkaTestUtils.getPropertyValue(pf, "threadBoundProducers", ThreadLocal.class).get()).isNotNull();
pf.stop();
assertThat(KafkaTestUtils.getPropertyValue(pf, "threadBoundProducersAll", Set.class)).hasSize(0);
verify(producer).close(any(Duration.class));
}

@Test
@SuppressWarnings({ "rawtypes", "unchecked" })
void testThreadLocalReset() {
Expand All @@ -281,7 +373,7 @@ protected Producer createKafkaProducer() {
bProducer = pf.createProducer();
assertThat(bProducer).isNotSameAs(aProducer);
bProducer.close();
verify(producer1).close(any(Duration.class));
verify(producer1, times(2)).close(any(Duration.class));
}

@Test
Expand Down