Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ public class MqttPahoMessageDrivenChannelAdapter extends AbstractMqttMessageDriv

private boolean manualAcks;

private ApplicationEventPublisher applicationEventPublisher;

private volatile IMqttClient client;

private volatile ScheduledFuture<?> reconnectFuture;
Expand All @@ -94,8 +96,6 @@ public class MqttPahoMessageDrivenChannelAdapter extends AbstractMqttMessageDriv

private volatile ConsumerStopAction consumerStopAction;

private ApplicationEventPublisher applicationEventPublisher;

/**
* Use this constructor for a single url (although it may be overridden if the server
* URI(s) are provided by the {@link MqttConnectOptions#getServerURIs()} provided by
Expand Down Expand Up @@ -311,15 +311,17 @@ private synchronized void connectAndSubscribe() throws MqttException {
this.applicationEventPublisher.publishEvent(new MqttConnectionFailedEvent(this, ex));
}
logger.error(ex, () -> "Error connecting or subscribing to " + Arrays.toString(topics));
this.client.disconnectForcibly(this.disconnectCompletionTimeout);
try {
this.client.setCallback(null);
this.client.close();
}
catch (MqttException e1) {
// NOSONAR
if (this.client != null) { // Could be reset during event handling before
this.client.disconnectForcibly(this.disconnectCompletionTimeout);
try {
this.client.setCallback(null);
this.client.close();
}
catch (MqttException e1) {
// NOSONAR
}
this.client = null;
}
this.client = null;
throw ex;
}
finally {
Expand Down Expand Up @@ -355,25 +357,27 @@ private synchronized void cancelReconnect() {

private synchronized void scheduleReconnect() {
cancelReconnect();
try {
this.reconnectFuture = getTaskScheduler().schedule(() -> {
try {
logger.debug("Attempting reconnect");
synchronized (MqttPahoMessageDrivenChannelAdapter.this) {
if (!MqttPahoMessageDrivenChannelAdapter.this.connected) {
connectAndSubscribe();
MqttPahoMessageDrivenChannelAdapter.this.reconnectFuture = null;
if (isActive()) {
try {
this.reconnectFuture = getTaskScheduler().schedule(() -> {
try {
logger.debug("Attempting reconnect");
synchronized (MqttPahoMessageDrivenChannelAdapter.this) {
if (!MqttPahoMessageDrivenChannelAdapter.this.connected) {
connectAndSubscribe();
MqttPahoMessageDrivenChannelAdapter.this.reconnectFuture = null;
}
}
}
}
catch (MqttException ex) {
logger.error(ex, "Exception while connecting and subscribing");
scheduleReconnect();
}
}, new Date(System.currentTimeMillis() + this.recoveryInterval));
}
catch (Exception ex) {
logger.error(ex, "Failed to schedule reconnect");
catch (MqttException ex) {
logger.error(ex, "Exception while connecting and subscribing");
scheduleReconnect();
}
}, new Date(System.currentTimeMillis() + this.recoveryInterval));
}
catch (Exception ex) {
logger.error(ex, "Failed to schedule reconnect");
}
}
}

Expand Down Expand Up @@ -412,7 +416,7 @@ public void messageArrived(String topic, MqttMessage mqttMessage) {
sendMessage(message);
}
catch (RuntimeException ex) {
logger.error(ex, () -> "Unhandled exception for " + message.toString());
logger.error(ex, () -> "Unhandled exception for " + message);
throw ex;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isNull;
import static org.mockito.BDDMockito.given;
import static org.mockito.BDDMockito.willAnswer;
Expand Down Expand Up @@ -53,7 +54,6 @@
import org.assertj.core.api.Condition;
import org.eclipse.paho.client.mqttv3.IMqttAsyncClient;
import org.eclipse.paho.client.mqttv3.IMqttClient;
import org.eclipse.paho.client.mqttv3.IMqttMessageListener;
import org.eclipse.paho.client.mqttv3.IMqttToken;
import org.eclipse.paho.client.mqttv3.MqttAsyncClient;
import org.eclipse.paho.client.mqttv3.MqttCallback;
Expand All @@ -65,6 +65,7 @@
import org.eclipse.paho.client.mqttv3.MqttToken;
import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatchers;
import org.mockito.internal.stubbing.answers.CallsRealMethods;

Expand Down Expand Up @@ -541,8 +542,7 @@ public void testDifferentQos() throws Exception {
new DirectFieldAccessor(client).setPropertyValue("aClient", aClient);
willAnswer(new CallsRealMethods()).given(client).connect(any(MqttConnectOptions.class));
willAnswer(new CallsRealMethods()).given(client).subscribe(any(String[].class), any(int[].class));
willAnswer(new CallsRealMethods()).given(client).subscribe(any(String[].class), any(int[].class),
(IMqttMessageListener[]) isNull());
willAnswer(new CallsRealMethods()).given(client).subscribe(any(String[].class), any(int[].class), isNull());
willReturn(alwaysComplete).given(aClient).connect(any(MqttConnectOptions.class), any(), any());

IMqttToken token = mock(IMqttToken.class);
Expand Down Expand Up @@ -573,8 +573,51 @@ public void testDifferentQos() throws Exception {
verify(client).disconnectForcibly(5_000L);
}

@Test
public void testNoNPEOnReconnectAndStopRaceCondition() throws Exception {
final IMqttClient client = mock(IMqttClient.class);
MqttPahoMessageDrivenChannelAdapter adapter = buildAdapterIn(client, null, ConsumerStopAction.UNSUBSCRIBE_NEVER);
adapter.setRecoveryInterval(10);

MqttException mqttException = new MqttException(MqttException.REASON_CODE_SUBSCRIBE_FAILED);

willThrow(mqttException)
.given(client)
.subscribe(any(), ArgumentMatchers.<int[]>any());

LogAccessor logger = spy(TestUtils.getPropertyValue(adapter, "logger", LogAccessor.class));
new DirectFieldAccessor(adapter).setPropertyValue("logger", logger);
CountDownLatch exceptionLatch = new CountDownLatch(1);
ArgumentCaptor<MqttException> mqttExceptionArgumentCaptor = ArgumentCaptor.forClass(MqttException.class);
willAnswer(i -> {
exceptionLatch.countDown();
return null;
})
.given(logger)
.error(mqttExceptionArgumentCaptor.capture(), eq("Exception while connecting and subscribing"));

ThreadPoolTaskScheduler taskScheduler = new ThreadPoolTaskScheduler();
taskScheduler.initialize();
adapter.setTaskScheduler(taskScheduler);

adapter.setApplicationEventPublisher(event -> {
if (event instanceof MqttConnectionFailedEvent) {
adapter.destroy();
}
});
adapter.start();

assertThat(exceptionLatch.await(10, TimeUnit.SECONDS)).isTrue();
assertThat(mqttExceptionArgumentCaptor.getValue())
.isNotNull()
.isSameAs(mqttException);

taskScheduler.destroy();
}

private MqttPahoMessageDrivenChannelAdapter buildAdapterIn(final IMqttClient client, Boolean cleanSession,
ConsumerStopAction action) throws MqttException {
ConsumerStopAction action) {

DefaultMqttPahoClientFactory factory = new DefaultMqttPahoClientFactory() {

@Override
Expand Down Expand Up @@ -605,7 +648,7 @@ private MqttPahoMessageHandler buildAdapterOut(final IMqttAsyncClient client) {
DefaultMqttPahoClientFactory factory = new DefaultMqttPahoClientFactory() {

@Override
public IMqttAsyncClient getAsyncClientInstance(String uri, String clientId) throws MqttException {
public IMqttAsyncClient getAsyncClientInstance(String uri, String clientId) {
return client;
}

Expand Down