diff --git a/spring-integration-ip/src/main/java/org/springframework/integration/ip/tcp/connection/AbstractConnectionFactory.java b/spring-integration-ip/src/main/java/org/springframework/integration/ip/tcp/connection/AbstractConnectionFactory.java index 0b104aa8719..2408204845a 100644 --- a/spring-integration-ip/src/main/java/org/springframework/integration/ip/tcp/connection/AbstractConnectionFactory.java +++ b/spring-integration-ip/src/main/java/org/springframework/integration/ip/tcp/connection/AbstractConnectionFactory.java @@ -563,10 +563,15 @@ public void stop() { synchronized (this.connections) { Iterator> iterator = this.connections.entrySet().iterator(); while (iterator.hasNext()) { - TcpConnection connection = iterator.next().getValue(); + TcpConnectionSupport connection = iterator.next().getValue(); connection.close(); iterator.remove(); - getSenders().forEach(sender -> sender.removeDeadConnection(connection)); + if (connection instanceof TcpConnectionInterceptorSupport) { + ((TcpConnectionInterceptorSupport) connection).removeDeadConnection(connection); + } + else { + connection.getSenders().forEach(sender -> sender.removeDeadConnection(connection)); + } } } synchronized (this.lifecycleMonitor) { @@ -866,7 +871,12 @@ private List removeClosedConnectionsAndReturnOpenConnectionIds() { TcpConnectionSupport connection = entry.getValue(); if (!connection.isOpen()) { iterator.remove(); - getSenders().forEach(sender -> sender.removeDeadConnection(connection)); + if (connection instanceof TcpConnectionInterceptorSupport) { + ((TcpConnectionInterceptorSupport) connection).removeDeadConnection(connection); + } + else { + connection.getSenders().forEach(sender -> sender.removeDeadConnection(connection)); + } logger.debug(() -> getComponentName() + ": Removed closed connection: " + connection.getConnectionId()); } @@ -944,7 +954,7 @@ public boolean closeConnection(String connectionId) { try { connection.close(); closed = true; - getSenders().forEach(sender -> sender.removeDeadConnection(connection)); + connection.getSenders().forEach(sender -> sender.removeDeadConnection(connection)); } catch (Exception ex) { logger.debug(ex, () -> "Failed to close connection " + connectionId); diff --git a/spring-integration-ip/src/main/java/org/springframework/integration/ip/tcp/connection/TcpConnectionInterceptorSupport.java b/spring-integration-ip/src/main/java/org/springframework/integration/ip/tcp/connection/TcpConnectionInterceptorSupport.java index 05867f56695..f3d649625ee 100644 --- a/spring-integration-ip/src/main/java/org/springframework/integration/ip/tcp/connection/TcpConnectionInterceptorSupport.java +++ b/spring-integration-ip/src/main/java/org/springframework/integration/ip/tcp/connection/TcpConnectionInterceptorSupport.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,11 +16,15 @@ package org.springframework.integration.ip.tcp.connection; +import java.util.Collections; +import java.util.List; + import javax.net.ssl.SSLSession; import org.springframework.context.ApplicationEventPublisher; import org.springframework.core.serializer.Deserializer; import org.springframework.core.serializer.Serializer; +import org.springframework.lang.Nullable; import org.springframework.messaging.Message; import org.springframework.messaging.support.ErrorMessage; @@ -38,9 +42,12 @@ public abstract class TcpConnectionInterceptorSupport extends TcpConnectionSuppo private TcpListener tcpListener; - private TcpSender tcpSender; + private boolean realSender; + + private List interceptedSenders; + + private boolean removed; - private Boolean realSender; public TcpConnectionInterceptorSupport() { } @@ -92,10 +99,29 @@ public void registerListener(TcpListener listener) { @Override public void registerSender(TcpSender sender) { - this.tcpSender = sender; this.theConnection.registerSender(this); } + @Override + public void registerSenders(List sendersToRegister) { + this.interceptedSenders = sendersToRegister; + if (sendersToRegister.size() > 0) { + if (!(sendersToRegister.get(0) instanceof TcpConnectionInterceptorSupport)) { + this.realSender = true; + } + else { + this.realSender = ((TcpConnectionInterceptorSupport) this.interceptedSenders.get(0)) + .hasRealSender(); + } + } + if (this.theConnection instanceof TcpConnectionInterceptorSupport) { + this.theConnection.registerSenders(Collections.singletonList(this)); + } + else { + super.registerSender(this); + } + } + /** * {@inheritDoc} *

@@ -198,21 +224,30 @@ public void setTheConnection(TcpConnectionSupport theConnection) { * @return the listener */ @Override + @Nullable public TcpListener getListener() { return this.tcpListener; } @Override public void addNewConnection(TcpConnection connection) { - if (this.tcpSender != null) { - this.tcpSender.addNewConnection(this); + if (this.interceptedSenders != null) { + this.interceptedSenders.forEach(sender -> sender.addNewConnection(connection)); } } @Override - public void removeDeadConnection(TcpConnection connection) { - if (this.tcpSender != null) { - this.tcpSender.removeDeadConnection(this); + public synchronized void removeDeadConnection(TcpConnection connection) { + if (this.removed) { + return; + } + this.removed = true; + if (this.theConnection instanceof TcpConnectionInterceptorSupport && !this.theConnection.equals(this)) { + ((TcpConnectionInterceptorSupport) this.theConnection).removeDeadConnection(this); + } + TcpSender sender = getSender(); + if (sender != null && !(sender instanceof TcpConnectionInterceptorSupport)) { + this.interceptedSenders.forEach(snder -> snder.removeDeadConnection(connection)); } } @@ -222,19 +257,21 @@ public long incrementAndGetConnectionSequence() { } @Override + @Nullable public TcpSender getSender() { - return this.tcpSender; + return this.interceptedSenders != null && this.interceptedSenders.size() > 0 + ? this.interceptedSenders.get(0) + : null; + } + + @Override + public List getSenders() { + return this.interceptedSenders == null + ? super.getSenders() + : Collections.unmodifiableList(this.interceptedSenders); } protected boolean hasRealSender() { - if (this.realSender != null) { - return this.realSender; - } - TcpSender sender = getSender(); - while (sender instanceof TcpConnectionInterceptorSupport) { - sender = ((TcpConnectionInterceptorSupport) sender).getSender(); - } - this.realSender = sender != null; return this.realSender; } diff --git a/spring-integration-ip/src/main/java/org/springframework/integration/ip/tcp/connection/TcpConnectionSupport.java b/spring-integration-ip/src/main/java/org/springframework/integration/ip/tcp/connection/TcpConnectionSupport.java index 3c403422e4d..c9b218c5f95 100644 --- a/spring-integration-ip/src/main/java/org/springframework/integration/ip/tcp/connection/TcpConnectionSupport.java +++ b/spring-integration-ip/src/main/java/org/springframework/integration/ip/tcp/connection/TcpConnectionSupport.java @@ -1,5 +1,5 @@ /* - * Copyright 2001-2021 the original author or authors. + * Copyright 2001-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -199,9 +199,7 @@ protected void closeConnection(boolean isException) { outerListener = nextListener; } outerListener.close(); - for (TcpSender sender : getSenders()) { - sender.removeDeadConnection(outerListener); - } + outerListener.removeDeadConnection(outerListener); if (isException) { // ensure physical close in case the interceptor did not close this.close(); @@ -337,6 +335,7 @@ public void registerSenders(List sendersToRegister) { * @return the listener */ @Override + @Nullable public TcpListener getListener() { if (this.needsTest && this.testListener != null) { this.needsTest = false; diff --git a/spring-integration-ip/src/test/java/org/springframework/integration/ip/tcp/connection/TcpSenderTests.java b/spring-integration-ip/src/test/java/org/springframework/integration/ip/tcp/connection/TcpSenderTests.java index b34a48292d1..c2b980ea2b1 100644 --- a/spring-integration-ip/src/test/java/org/springframework/integration/ip/tcp/connection/TcpSenderTests.java +++ b/spring-integration-ip/src/test/java/org/springframework/integration/ip/tcp/connection/TcpSenderTests.java @@ -18,8 +18,13 @@ import static org.assertj.core.api.Assertions.assertThat; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import org.junit.jupiter.api.Test; @@ -69,26 +74,55 @@ void senderCalledForDeadConnectionClientNio() throws InterruptedException { private void senderCalledForDeadConnectionClient(AbstractClientConnectionFactory client) throws InterruptedException { CountDownLatch adds = new CountDownLatch(2); CountDownLatch removes = new CountDownLatch(2); + CountDownLatch interceptorAddCalled = new CountDownLatch(6); + CountDownLatch interceptorRemCalled = new CountDownLatch(6); TcpConnectionInterceptorFactoryChain chain = new TcpConnectionInterceptorFactoryChain(); - chain.setInterceptor(new HelloWorldInterceptorFactory() { + AtomicInteger instances = new AtomicInteger(); + List addOrder = Collections.synchronizedList(new ArrayList<>()); + List remOrder = Collections.synchronizedList(new ArrayList<>()); + AtomicReference thread = new AtomicReference<>(); + class InterceptorFactory extends HelloWorldInterceptorFactory { @Override public TcpConnectionInterceptorSupport getInterceptor() { return new TcpConnectionInterceptorSupport() { + + private final int instance = instances.incrementAndGet(); + + @Override + public void addNewConnection(TcpConnection connection) { + addOrder.add(this.instance); + interceptorAddCalled.countDown(); + super.addNewConnection(connection); + } + + @Override + public synchronized void removeDeadConnection(TcpConnection connection) { + super.removeDeadConnection(connection); + // can be called multiple times on different threads. + if (!remOrder.contains(this.instance)) { + remOrder.add(this.instance); + interceptorRemCalled.countDown(); + } + } + }; } - }); + } + chain.setInterceptor(new InterceptorFactory(), new InterceptorFactory(), new InterceptorFactory()); client.setInterceptorFactoryChain(chain); client.registerSender(new TcpSender() { @Override public void addNewConnection(TcpConnection connection) { + addOrder.add(99); adds.countDown(); } @Override - public void removeDeadConnection(TcpConnection connection) { + public synchronized void removeDeadConnection(TcpConnection connection) { + remOrder.add(99); removes.countDown(); } @@ -97,12 +131,18 @@ public void removeDeadConnection(TcpConnection connection) { client.afterPropertiesSet(); client.start(); TcpConnectionSupport conn = client.getConnection(); + assertThat(((TcpConnectionInterceptorSupport) conn).hasRealSender()).isTrue(); conn.close(); conn = client.getConnection(); assertThat(adds.await(10, TimeUnit.SECONDS)).isTrue(); + assertThat(addOrder).containsExactly(1, 2, 3, 99, 4, 5, 6, 99); conn.close(); client.stop(); assertThat(removes.await(10, TimeUnit.SECONDS)).isTrue(); + // 9x before 3, 6 due to ordering in overridden interceptor method + assertThat(remOrder).containsExactly(1, 2, 99, 3, 4, 5, 99, 6); + assertThat(interceptorAddCalled.await(10, TimeUnit.SECONDS)).isTrue(); + assertThat(interceptorRemCalled.await(10, TimeUnit.SECONDS)).isTrue(); } }