From 36a78004d02ea29cdc7fd126971ece423d89cb67 Mon Sep 17 00:00:00 2001 From: Artem Bilan Date: Thu, 9 Sep 2021 14:17:57 -0400 Subject: [PATCH] GH-3627: Fix race condition NPE in MqttPahoMDCA Fixes https://github.com/spring-projects/spring-integration/issues/3627 The `destroy()`, and therefore `stop()` could be called from the `MqttConnectionFailedEvent` handling in the same thread resetting the `client` property to `null`. * Check for `this.client != null` in the next block of the `connectAndSubscribe()` to avoid NPE * Check for `isActive()` in the `scheduleReconnect()` to be sure do not reconnect if channel adapter has been stopped already **Cherry-pick to `5.4.x`** --- .../MqttPahoMessageDrivenChannelAdapter.java | 60 ++++++++++--------- .../integration/mqtt/MqttAdapterTests.java | 53 ++++++++++++++-- 2 files changed, 80 insertions(+), 33 deletions(-) diff --git a/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/inbound/MqttPahoMessageDrivenChannelAdapter.java b/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/inbound/MqttPahoMessageDrivenChannelAdapter.java index 5fec854ef04..d7b99d0f317 100644 --- a/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/inbound/MqttPahoMessageDrivenChannelAdapter.java +++ b/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/inbound/MqttPahoMessageDrivenChannelAdapter.java @@ -84,6 +84,8 @@ public class MqttPahoMessageDrivenChannelAdapter extends AbstractMqttMessageDriv private boolean manualAcks; + private ApplicationEventPublisher applicationEventPublisher; + private volatile IMqttClient client; private volatile ScheduledFuture reconnectFuture; @@ -94,8 +96,6 @@ public class MqttPahoMessageDrivenChannelAdapter extends AbstractMqttMessageDriv private volatile ConsumerStopAction consumerStopAction; - private ApplicationEventPublisher applicationEventPublisher; - /** * Use this constructor for a single url (although it may be overridden if the server * URI(s) are provided by the {@link MqttConnectOptions#getServerURIs()} provided by @@ -311,15 +311,17 @@ private synchronized void connectAndSubscribe() throws MqttException { this.applicationEventPublisher.publishEvent(new MqttConnectionFailedEvent(this, ex)); } logger.error(ex, () -> "Error connecting or subscribing to " + Arrays.toString(topics)); - this.client.disconnectForcibly(this.disconnectCompletionTimeout); - try { - this.client.setCallback(null); - this.client.close(); - } - catch (MqttException e1) { - // NOSONAR + if (this.client != null) { // Could be reset during event handling before + this.client.disconnectForcibly(this.disconnectCompletionTimeout); + try { + this.client.setCallback(null); + this.client.close(); + } + catch (MqttException e1) { + // NOSONAR + } + this.client = null; } - this.client = null; throw ex; } finally { @@ -355,25 +357,27 @@ private synchronized void cancelReconnect() { private synchronized void scheduleReconnect() { cancelReconnect(); - try { - this.reconnectFuture = getTaskScheduler().schedule(() -> { - try { - logger.debug("Attempting reconnect"); - synchronized (MqttPahoMessageDrivenChannelAdapter.this) { - if (!MqttPahoMessageDrivenChannelAdapter.this.connected) { - connectAndSubscribe(); - MqttPahoMessageDrivenChannelAdapter.this.reconnectFuture = null; + if (isActive()) { + try { + this.reconnectFuture = getTaskScheduler().schedule(() -> { + try { + logger.debug("Attempting reconnect"); + synchronized (MqttPahoMessageDrivenChannelAdapter.this) { + if (!MqttPahoMessageDrivenChannelAdapter.this.connected) { + connectAndSubscribe(); + MqttPahoMessageDrivenChannelAdapter.this.reconnectFuture = null; + } } } - } - catch (MqttException ex) { - logger.error(ex, "Exception while connecting and subscribing"); - scheduleReconnect(); - } - }, new Date(System.currentTimeMillis() + this.recoveryInterval)); - } - catch (Exception ex) { - logger.error(ex, "Failed to schedule reconnect"); + catch (MqttException ex) { + logger.error(ex, "Exception while connecting and subscribing"); + scheduleReconnect(); + } + }, new Date(System.currentTimeMillis() + this.recoveryInterval)); + } + catch (Exception ex) { + logger.error(ex, "Failed to schedule reconnect"); + } } } @@ -412,7 +416,7 @@ public void messageArrived(String topic, MqttMessage mqttMessage) { sendMessage(message); } catch (RuntimeException ex) { - logger.error(ex, () -> "Unhandled exception for " + message.toString()); + logger.error(ex, () -> "Unhandled exception for " + message); throw ex; } } diff --git a/spring-integration-mqtt/src/test/java/org/springframework/integration/mqtt/MqttAdapterTests.java b/spring-integration-mqtt/src/test/java/org/springframework/integration/mqtt/MqttAdapterTests.java index 87fad048bd9..9569fefc1d1 100644 --- a/spring-integration-mqtt/src/test/java/org/springframework/integration/mqtt/MqttAdapterTests.java +++ b/spring-integration-mqtt/src/test/java/org/springframework/integration/mqtt/MqttAdapterTests.java @@ -22,6 +22,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isNull; import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.willAnswer; @@ -53,7 +54,6 @@ import org.assertj.core.api.Condition; import org.eclipse.paho.client.mqttv3.IMqttAsyncClient; import org.eclipse.paho.client.mqttv3.IMqttClient; -import org.eclipse.paho.client.mqttv3.IMqttMessageListener; import org.eclipse.paho.client.mqttv3.IMqttToken; import org.eclipse.paho.client.mqttv3.MqttAsyncClient; import org.eclipse.paho.client.mqttv3.MqttCallback; @@ -65,6 +65,7 @@ import org.eclipse.paho.client.mqttv3.MqttToken; import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; import org.mockito.ArgumentMatchers; import org.mockito.internal.stubbing.answers.CallsRealMethods; @@ -541,8 +542,7 @@ public void testDifferentQos() throws Exception { new DirectFieldAccessor(client).setPropertyValue("aClient", aClient); willAnswer(new CallsRealMethods()).given(client).connect(any(MqttConnectOptions.class)); willAnswer(new CallsRealMethods()).given(client).subscribe(any(String[].class), any(int[].class)); - willAnswer(new CallsRealMethods()).given(client).subscribe(any(String[].class), any(int[].class), - (IMqttMessageListener[]) isNull()); + willAnswer(new CallsRealMethods()).given(client).subscribe(any(String[].class), any(int[].class), isNull()); willReturn(alwaysComplete).given(aClient).connect(any(MqttConnectOptions.class), any(), any()); IMqttToken token = mock(IMqttToken.class); @@ -573,8 +573,51 @@ public void testDifferentQos() throws Exception { verify(client).disconnectForcibly(5_000L); } + @Test + public void testNoNPEOnReconnectAndStopRaceCondition() throws Exception { + final IMqttClient client = mock(IMqttClient.class); + MqttPahoMessageDrivenChannelAdapter adapter = buildAdapterIn(client, null, ConsumerStopAction.UNSUBSCRIBE_NEVER); + adapter.setRecoveryInterval(10); + + MqttException mqttException = new MqttException(MqttException.REASON_CODE_SUBSCRIBE_FAILED); + + willThrow(mqttException) + .given(client) + .subscribe(any(), ArgumentMatchers.any()); + + LogAccessor logger = spy(TestUtils.getPropertyValue(adapter, "logger", LogAccessor.class)); + new DirectFieldAccessor(adapter).setPropertyValue("logger", logger); + CountDownLatch exceptionLatch = new CountDownLatch(1); + ArgumentCaptor mqttExceptionArgumentCaptor = ArgumentCaptor.forClass(MqttException.class); + willAnswer(i -> { + exceptionLatch.countDown(); + return null; + }) + .given(logger) + .error(mqttExceptionArgumentCaptor.capture(), eq("Exception while connecting and subscribing")); + + ThreadPoolTaskScheduler taskScheduler = new ThreadPoolTaskScheduler(); + taskScheduler.initialize(); + adapter.setTaskScheduler(taskScheduler); + + adapter.setApplicationEventPublisher(event -> { + if (event instanceof MqttConnectionFailedEvent) { + adapter.destroy(); + } + }); + adapter.start(); + + assertThat(exceptionLatch.await(10, TimeUnit.SECONDS)).isTrue(); + assertThat(mqttExceptionArgumentCaptor.getValue()) + .isNotNull() + .isSameAs(mqttException); + + taskScheduler.destroy(); + } + private MqttPahoMessageDrivenChannelAdapter buildAdapterIn(final IMqttClient client, Boolean cleanSession, - ConsumerStopAction action) throws MqttException { + ConsumerStopAction action) { + DefaultMqttPahoClientFactory factory = new DefaultMqttPahoClientFactory() { @Override @@ -605,7 +648,7 @@ private MqttPahoMessageHandler buildAdapterOut(final IMqttAsyncClient client) { DefaultMqttPahoClientFactory factory = new DefaultMqttPahoClientFactory() { @Override - public IMqttAsyncClient getAsyncClientInstance(String uri, String clientId) throws MqttException { + public IMqttAsyncClient getAsyncClientInstance(String uri, String clientId) { return client; }