Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2001-2022 the original author or authors.
* Copyright 2001-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -22,6 +22,7 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;

Expand Down Expand Up @@ -59,7 +60,6 @@
* <p>
* {@link org.springframework.context.Lifecycle} methods delegate to the underlying {@link AbstractConnectionFactory}.
*
*
* @author Gary Russell
* @author Artem Bilan
*
Expand Down Expand Up @@ -223,7 +223,17 @@ protected Object handleRequestMessage(Message<?> requestMessage) {
this.pendingReplies.put(connectionId, reply);
String connectionIdToLog = connectionId;
logger.debug(() -> "Added pending reply " + connectionIdToLog);
connection.send(requestMessage);
try {
connection.send(requestMessage);
}
catch (Exception ex) {
// If it cannot send, then no reply for this connection.
// Therefor release resources for subsequent requests.
if (async) {
cleanUp(haveSemaphore, connection, connectionId);
}
throw ex;
}
if (this.closeStreamAfterSend) {
connection.shutdownOutput();
}
Expand Down Expand Up @@ -326,7 +336,7 @@ public boolean onMessage(Message<?> message) {
if (reply == null) {
if (message instanceof ErrorMessage) {
/*
* Socket errors are sent here so they can be conveyed to any waiting thread.
* Socket errors are sent here, so they can be conveyed to any waiting thread.
* If there's not one, simply ignore.
*/
return false;
Expand Down Expand Up @@ -427,7 +437,11 @@ private final class AsyncReply {

private final boolean haveSemaphore;

private final CompletableFuture<Message<?>> future = new CompletableFuture<>();
private final ScheduledFuture<?> noResponseFuture;

private final CompletableFuture<Message<?>> future =
new CompletableFuture<Message<?>>()
.thenApply(this::cancelNoResponseFutureIfAny);

private volatile Message<?> reply;

Expand All @@ -440,13 +454,25 @@ private final class AsyncReply {
this.connection = connection;
this.haveSemaphore = haveSemaphore;
if (async && remoteTimeout > 0) {
getTaskScheduler()
.schedule(() -> {
TcpOutboundGateway.this.pendingReplies.remove(connection.getConnectionId());
this.future.completeExceptionally(
new MessageTimeoutException(requestMessage, "Timed out waiting for response"));
}, Instant.now().plusMillis(remoteTimeout));
this.noResponseFuture =
getTaskScheduler()
.schedule(() -> {
cleanUp(this.haveSemaphore, this.connection, this.connection.getConnectionId());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there's a new race here; we could end up releasing the semaphore twice because onMessage calls cleanup before completing the future.

Perhaps check if pendingReplies was actually removed in cleanup before releasing the semaphore?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See some fixed I have pushed.

this.future.completeExceptionally(
new MessageTimeoutException(requestMessage,
"Timed out waiting for response"));
}, Instant.now().plusMillis(remoteTimeout));
}
else {
this.noResponseFuture = null;
}
}

private Message<?> cancelNoResponseFutureIfAny(Message<?> message) {
if (this.noResponseFuture != null) {
this.noResponseFuture.cancel(true);
}
return message;
}

TcpConnection getConnection() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
import org.springframework.integration.ip.tcp.connection.AbstractClientConnectionFactory;
import org.springframework.integration.ip.tcp.connection.CachingClientConnectionFactory;
import org.springframework.integration.ip.tcp.connection.FailoverClientConnectionFactory;
import org.springframework.integration.ip.tcp.connection.TcpConnection;
import org.springframework.integration.ip.tcp.connection.TcpConnectionSupport;
import org.springframework.integration.ip.tcp.connection.TcpNetClientConnectionFactory;
import org.springframework.integration.ip.tcp.connection.TcpNioClientConnectionFactory;
Expand All @@ -80,9 +81,14 @@
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.catchThrowable;
import static org.assertj.core.api.Assertions.fail;
import static org.awaitility.Awaitility.await;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.BDDMockito.willReturn;
import static org.mockito.BDDMockito.willThrow;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
Expand Down Expand Up @@ -397,7 +403,7 @@ private void testGoodNetGWTimeoutGuts(AbstractClientConnectionFactory ccf,

Expression remoteTimeoutExpression = Mockito.mock(Expression.class);

when(remoteTimeoutExpression.getValue(Mockito.any(EvaluationContext.class), Mockito.any(Message.class),
when(remoteTimeoutExpression.getValue(any(EvaluationContext.class), any(Message.class),
Mockito.eq(Long.class))).thenReturn(50L, 60000L);

gateway.setRemoteTimeoutExpression(remoteTimeoutExpression);
Expand Down Expand Up @@ -488,7 +494,7 @@ void testCachingFailover() throws Exception {
TcpConnectionSupport mockConn1 = makeMockConnection();
when(factory1.getConnection()).thenReturn(mockConn1);
doThrow(new UncheckedIOException(new IOException("fail")))
.when(mockConn1).send(Mockito.any(Message.class));
.when(mockConn1).send(any(Message.class));

AbstractClientConnectionFactory factory2 = new TcpNetClientConnectionFactory("localhost",
serverSocket.get().getLocalPort());
Expand Down Expand Up @@ -521,7 +527,7 @@ void testCachingFailover() throws Exception {
assertThat(reply.getPayload()).isEqualTo("bar");
done.set(true);
gateway.stop();
verify(mockConn1).send(Mockito.any(Message.class));
verify(mockConn1).send(any(Message.class));
factory2.stop();
serverSocket.get().close();
}
Expand Down Expand Up @@ -571,7 +577,7 @@ void testFailoverCached() throws Exception {
when(factory1.getConnection()).thenReturn(mockConn1);
when(factory1.isSingleUse()).thenReturn(true);
doThrow(new UncheckedIOException(new IOException("fail")))
.when(mockConn1).send(Mockito.any(Message.class));
.when(mockConn1).send(any(Message.class));
CachingClientConnectionFactory cachingFactory1 = new CachingClientConnectionFactory(factory1, 1);

AbstractClientConnectionFactory factory2 = new TcpNetClientConnectionFactory("localhost",
Expand Down Expand Up @@ -606,7 +612,7 @@ void testFailoverCached() throws Exception {
assertThat(reply.getPayload()).isEqualTo("bar");
done.set(true);
gateway.stop();
verify(mockConn1).send(Mockito.any(Message.class));
verify(mockConn1).send(any(Message.class));
factory2.stop();
serverSocket.get().close();
}
Expand Down Expand Up @@ -1081,4 +1087,37 @@ void testAsyncTimeout() throws Exception {
}
}

@Test
void semaphoreIsReleasedOnAsyncSendFailure() throws InterruptedException {
AbstractClientConnectionFactory ccf = mock(AbstractClientConnectionFactory.class);

TcpConnection connection = mock(TcpConnectionSupport.class);

given(connection.getConnectionId()).willReturn("testId");
willThrow(new RuntimeException("intentional"))
.given(connection)
.send(any(Message.class));

willReturn(connection)
.given(ccf)
.getConnection();

TcpOutboundGateway gateway = new TcpOutboundGateway();
gateway.setConnectionFactory(ccf);
gateway.setAsync(true);
gateway.setBeanFactory(mock(BeanFactory.class));
gateway.setRemoteTimeout(-1);
gateway.afterPropertiesSet();

assertThatExceptionOfType(MessageHandlingException.class)
.isThrownBy(() -> gateway.handleMessage(new GenericMessage<>("Test1")))
.withCauseExactlyInstanceOf(RuntimeException.class)
.withStackTraceContaining("intentional");

assertThatExceptionOfType(MessageHandlingException.class)
.isThrownBy(() -> gateway.handleMessage(new GenericMessage<>("Test2")))
.withCauseExactlyInstanceOf(RuntimeException.class)
.withStackTraceContaining("intentional");
}

}