Skip to content

Commit b25ebc9

Browse files
committed
feat: MockedConstructionExtension
1 parent 1138dc3 commit b25ebc9

File tree

5 files changed

+165
-82
lines changed

5 files changed

+165
-82
lines changed

int-aws/src/test/java/org/springframework/integration/aws/config/xml/ParserTestBase.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@
1010
import org.springframework.beans.factory.xml.XmlBeanDefinitionReader;
1111
import org.springframework.context.expression.StandardBeanExpressionResolver;
1212
import org.springframework.core.io.ByteArrayResource;
13+
import org.springframework.integration.junit.MockedConstructionExtension;
1314

1415
import java.io.IOException;
1516
import java.io.UncheckedIOException;
1617
import java.nio.charset.StandardCharsets;
1718

19+
@ExtendWith(MockedConstructionExtension.class)
1820
@ExtendWith(MockitoExtension.class)
1921
public abstract class ParserTestBase extends BDDMockito {
2022

int-aws/src/test/java/org/springframework/integration/aws/config/xml/SqsMessageDrivenChannelAdapterParserTest.java

Lines changed: 58 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,11 @@
1313
import org.mockito.ArgumentCaptor;
1414
import org.mockito.Captor;
1515
import org.mockito.Mock;
16+
import org.mockito.MockedConstruction;
17+
import org.mockito.MockedConstruction.Context;
1618
import org.springframework.core.task.TaskExecutor;
1719
import org.springframework.integration.aws.inbound.SqsMessageDrivenChannelAdapter;
20+
import org.springframework.integration.junit.ConstructionMock;
1821
import software.amazon.awssdk.services.sqs.SqsAsyncClient;
1922

2023
import java.time.Duration;
@@ -35,64 +38,67 @@ class SqsMessageDrivenChannelAdapterParserTest extends ParserTestBase {
3538
@Mock
3639
private TaskExecutor taskExecutor;
3740

41+
@ConstructionMock(SqsMessageDrivenChannelAdapter.class)
3842
@Test
39-
void testBeanDefinition() {
43+
void testBeanDefinition(MockedConstruction<SqsMessageDrivenChannelAdapter> mocked) {
4044
registerBean("sqs", SqsAsyncClient.class, sqs);
4145
registerBean("mc", MessagingMessageConverter.class, messageConverter);
4246
registerBean("ex", TaskExecutor.class, taskExecutor);
4347

44-
try (var mocked = mockConstruction(SqsMessageDrivenChannelAdapter.class,
45-
(mock, context) -> assertThat(context.arguments()).asInstanceOf(InstanceOfAssertFactories.LIST).contains(sqs, new String[] {"q"}))) {
46-
var adapter = loadBean(SqsMessageDrivenChannelAdapter.class, """
47-
<int-aws:sqs-message-driven-channel-adapter queues="q" sqs="sqs"
48-
id="i"
49-
channel="c"
50-
error-channel="ec"
51-
send-timeout="#{50}"
52-
acknowledgement-interval="#{50}"
53-
acknowledgement-ordering="#{'ORDERED_BY_GROUP'}"
54-
acknowledgement-mode="#{'ALWAYS'}"
55-
acknowledgement-threshold="#{5}"
56-
back-pressure-mode="#{'FIXED_HIGH_THROUGHPUT'}"
57-
queue-not-found-strategy="#{'FAIL'}"
58-
fifo-batch-grouping-strategy="#{'PROCESS_MESSAGE_GROUPS_IN_PARALLEL_BATCHES'}"
59-
listener-mode="#{'BATCH'}"
60-
message-visibility="#{5}"
61-
max-concurrent-messages="#{5}"
62-
max-messages-per-poll="#{5}"
63-
max-delay-between-polls="#{5}"
64-
poll-timeout="#{5}"
65-
listener-shutdown-timeout="#{5}"
66-
message-converter="mc"
67-
components-task-executor="ex"/>
68-
""");
48+
var adapter = loadBean(SqsMessageDrivenChannelAdapter.class, """
49+
<int-aws:sqs-message-driven-channel-adapter queues="q" sqs="sqs"
50+
id="i"
51+
channel="c"
52+
error-channel="ec"
53+
send-timeout="#{50}"
54+
acknowledgement-interval="#{50}"
55+
acknowledgement-ordering="#{'ORDERED_BY_GROUP'}"
56+
acknowledgement-mode="#{'ALWAYS'}"
57+
acknowledgement-threshold="#{5}"
58+
back-pressure-mode="#{'FIXED_HIGH_THROUGHPUT'}"
59+
queue-not-found-strategy="#{'FAIL'}"
60+
fifo-batch-grouping-strategy="#{'PROCESS_MESSAGE_GROUPS_IN_PARALLEL_BATCHES'}"
61+
listener-mode="#{'BATCH'}"
62+
message-visibility="#{5}"
63+
max-concurrent-messages="#{5}"
64+
max-messages-per-poll="#{5}"
65+
max-delay-between-polls="#{5}"
66+
poll-timeout="#{5}"
67+
listener-shutdown-timeout="#{5}"
68+
message-converter="mc"
69+
components-task-executor="ex"/>
70+
""");
6971

70-
assertThat(mocked.constructed()).size().isOne();
72+
verify(adapter).setBeanName("i");
73+
verify(adapter).setOutputChannelName("c");
74+
verify(adapter).setErrorChannelName("ec");
75+
verify(adapter).setSendTimeout(50);
76+
verify(adapter).setSqsContainerOptions(containerOptions.capture());
7177

72-
verify(adapter).setBeanName("i");
73-
verify(adapter).setOutputChannelName("c");
74-
verify(adapter).setErrorChannelName("ec");
75-
verify(adapter).setSendTimeout(50);
76-
verify(adapter).setSqsContainerOptions(containerOptions.capture());
77-
78-
assertThat(containerOptions.getValue())
79-
.returns(Duration.ofMillis(50), SqsContainerOptions::getAcknowledgementInterval)
80-
.returns(AcknowledgementOrdering.ORDERED_BY_GROUP, SqsContainerOptions::getAcknowledgementOrdering)
81-
.returns(AcknowledgementMode.ALWAYS, SqsContainerOptions::getAcknowledgementMode)
82-
.returns(5, SqsContainerOptions::getAcknowledgementThreshold)
83-
.returns(BackPressureMode.FIXED_HIGH_THROUGHPUT, SqsContainerOptions::getBackPressureMode)
84-
.returns(QueueNotFoundStrategy.FAIL, SqsContainerOptions::getQueueNotFoundStrategy)
85-
.returns(FifoBatchGroupingStrategy.PROCESS_MESSAGE_GROUPS_IN_PARALLEL_BATCHES, SqsContainerOptions::getFifoBatchGroupingStrategy)
86-
.returns(ListenerMode.BATCH, SqsContainerOptions::getListenerMode)
87-
.returns(Duration.ofSeconds(5), SqsContainerOptions::getMessageVisibility)
88-
.returns(5, SqsContainerOptions::getMaxConcurrentMessages)
89-
.returns(5, SqsContainerOptions::getMaxMessagesPerPoll)
90-
.returns(Duration.ofSeconds(5), SqsContainerOptions::getMaxDelayBetweenPolls)
91-
.returns(Duration.ofSeconds(5), SqsContainerOptions::getPollTimeout)
92-
.returns(Duration.ofSeconds(5), SqsContainerOptions::getListenerShutdownTimeout)
93-
.returns(messageConverter, SqsContainerOptions::getMessageConverter)
94-
.returns(taskExecutor, SqsContainerOptions::getComponentsTaskExecutor)
95-
;
96-
}
78+
assertThat(containerOptions.getValue())
79+
.returns(Duration.ofMillis(50), SqsContainerOptions::getAcknowledgementInterval)
80+
.returns(AcknowledgementOrdering.ORDERED_BY_GROUP, SqsContainerOptions::getAcknowledgementOrdering)
81+
.returns(AcknowledgementMode.ALWAYS, SqsContainerOptions::getAcknowledgementMode)
82+
.returns(5, SqsContainerOptions::getAcknowledgementThreshold)
83+
.returns(BackPressureMode.FIXED_HIGH_THROUGHPUT, SqsContainerOptions::getBackPressureMode)
84+
.returns(QueueNotFoundStrategy.FAIL, SqsContainerOptions::getQueueNotFoundStrategy)
85+
.returns(FifoBatchGroupingStrategy.PROCESS_MESSAGE_GROUPS_IN_PARALLEL_BATCHES, SqsContainerOptions::getFifoBatchGroupingStrategy)
86+
.returns(ListenerMode.BATCH, SqsContainerOptions::getListenerMode)
87+
.returns(Duration.ofSeconds(5), SqsContainerOptions::getMessageVisibility)
88+
.returns(5, SqsContainerOptions::getMaxConcurrentMessages)
89+
.returns(5, SqsContainerOptions::getMaxMessagesPerPoll)
90+
.returns(Duration.ofSeconds(5), SqsContainerOptions::getMaxDelayBetweenPolls)
91+
.returns(Duration.ofSeconds(5), SqsContainerOptions::getPollTimeout)
92+
.returns(Duration.ofSeconds(5), SqsContainerOptions::getListenerShutdownTimeout)
93+
.returns(messageConverter, SqsContainerOptions::getMessageConverter)
94+
.returns(taskExecutor, SqsContainerOptions::getComponentsTaskExecutor)
95+
;
96+
97+
assertThat(mocked.constructed()).size().isOne();
98+
}
99+
100+
void testBeanDefinition(SqsMessageDrivenChannelAdapter mock, Context context) {
101+
assertThat(context.arguments()).asInstanceOf(InstanceOfAssertFactories.LIST)
102+
.contains(sqs, new String[] {"q"});
97103
}
98104
}

int-aws/src/test/java/org/springframework/integration/aws/config/xml/SqsOutboundChannelAdapterParserTest.java

Lines changed: 37 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33
import io.awspring.cloud.sqs.listener.QueueNotFoundStrategy;
44
import org.junit.jupiter.api.Test;
55
import org.mockito.Mock;
6+
import org.mockito.MockedConstruction;
7+
import org.mockito.MockedConstruction.Context;
68
import org.springframework.integration.aws.outbound.SqsMessageHandler;
9+
import org.springframework.integration.junit.ConstructionMock;
710
import org.springframework.messaging.converter.MessageConverter;
811
import software.amazon.awssdk.services.sqs.SqsAsyncClient;
912

@@ -17,39 +20,43 @@ class SqsOutboundChannelAdapterParserTest extends ParserTestBase {
1720
@Mock
1821
private MessageConverter messageConverter;
1922

23+
@ConstructionMock(SqsMessageHandler.class)
2024
@Test
21-
void testParser() {
25+
void testParser(MockedConstruction<SqsMessageHandler> mocked) {
2226
registerBean("sqs", SqsAsyncClient.class, sqs);
2327
registerBean("mc", MessageConverter.class, messageConverter);
2428

25-
try (var mocked = mockConstruction(SqsMessageHandler.class, (mock, context) -> assertThat(context.arguments().get(0)).isSameAs(sqs))) {
26-
var handler = loadBean(SqsMessageHandler.class, """
27-
<int-aws:sqs-outbound-channel-adapter sqs="sqs"
28-
channel="in"
29-
async="#{true}"
30-
delay="#{50}"
31-
queue="#{'q'}"
32-
order="#{5}"
33-
output-channel="#{'out'}"
34-
queue-not-found-strategy="#{'CREATE'}"
35-
send-timeout="#{50}"
36-
message-deduplication-id="#{'dd'}"
37-
message-group-id="#{'mg'}"
38-
message-converter="mc"/>
39-
""");
40-
41-
assertThat(mocked.constructed()).size().isOne();
42-
43-
verify(handler).setAsync(true);
44-
verify(handler).setDelay(50);
45-
verify(handler).setQueue("q");
46-
verify(handler).setOrder(5);
47-
verify(handler).setOutputChannelName("out");
48-
verify(handler).setQueueNotFoundStrategy(QueueNotFoundStrategy.CREATE);
49-
verify(handler).setSendTimeout(50);
50-
verify(handler).setMessageDeduplicationId("dd");
51-
verify(handler).setMessageGroupId("mg");
52-
verify(handler).setMessageConverter(messageConverter);
53-
}
29+
var handler = loadBean(SqsMessageHandler.class, """
30+
<int-aws:sqs-outbound-channel-adapter sqs="sqs"
31+
channel="in"
32+
async="#{true}"
33+
delay="#{50}"
34+
queue="#{'q'}"
35+
order="#{5}"
36+
output-channel="#{'out'}"
37+
queue-not-found-strategy="#{'CREATE'}"
38+
send-timeout="#{50}"
39+
message-deduplication-id="#{'dd'}"
40+
message-group-id="#{'mg'}"
41+
message-converter="mc"/>
42+
""");
43+
44+
verify(handler).setAsync(true);
45+
verify(handler).setDelay(50);
46+
verify(handler).setQueue("q");
47+
verify(handler).setOrder(5);
48+
verify(handler).setOutputChannelName("out");
49+
verify(handler).setQueueNotFoundStrategy(QueueNotFoundStrategy.CREATE);
50+
verify(handler).setSendTimeout(50);
51+
verify(handler).setMessageDeduplicationId("dd");
52+
verify(handler).setMessageGroupId("mg");
53+
verify(handler).setMessageConverter(messageConverter);
54+
55+
assertThat(mocked.constructed()).size().isOne();
56+
}
57+
58+
void testParser(SqsMessageHandler mock, Context context) {
59+
assertThat(context.arguments())
60+
.singleElement().isSameAs(sqs);
5461
}
5562
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package org.springframework.integration.junit;
2+
3+
import java.lang.annotation.Retention;
4+
import java.lang.annotation.Target;
5+
6+
import static java.lang.annotation.ElementType.METHOD;
7+
import static java.lang.annotation.RetentionPolicy.RUNTIME;
8+
9+
@Retention(RUNTIME)
10+
@Target(METHOD)
11+
public @interface ConstructionMock {
12+
13+
Class<?> value();
14+
}
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package org.springframework.integration.junit;
2+
3+
import org.junit.jupiter.api.extension.AfterTestExecutionCallback;
4+
import org.junit.jupiter.api.extension.BeforeTestExecutionCallback;
5+
import org.junit.jupiter.api.extension.ExtensionContext;
6+
import org.junit.jupiter.api.extension.ExtensionContext.Namespace;
7+
import org.junit.jupiter.api.extension.ParameterContext;
8+
import org.junit.jupiter.api.extension.ParameterResolver;
9+
import org.junit.platform.commons.support.AnnotationSupport;
10+
import org.junit.platform.commons.support.ReflectionSupport;
11+
import org.mockito.MockedConstruction;
12+
import org.mockito.MockedConstruction.Context;
13+
import org.mockito.Mockito;
14+
15+
public class MockedConstructionExtension implements ParameterResolver, BeforeTestExecutionCallback, AfterTestExecutionCallback {
16+
17+
private final Namespace namespace = Namespace.create(MockedConstructionExtension.class);
18+
19+
@Override
20+
public boolean supportsParameter(ParameterContext parameterContext, ExtensionContext extensionContext) {
21+
return extensionContext.getTestMethod().isPresent()
22+
&& parameterContext.getDeclaringExecutable().equals(extensionContext.getRequiredTestMethod())
23+
&& extensionContext.getRequiredTestMethod().isAnnotationPresent(ConstructionMock.class)
24+
&& parameterContext.getParameter().getType() == MockedConstruction.class;
25+
}
26+
27+
@Override
28+
public Object resolveParameter(ParameterContext parameterContext, ExtensionContext extensionContext) {
29+
return extensionContext.getStore(namespace).get("mocked");
30+
}
31+
32+
@Override
33+
public void beforeTestExecution(ExtensionContext context) {
34+
var annotation = context.getTestMethod()
35+
.flatMap(method -> AnnotationSupport.findAnnotation(method, ConstructionMock.class)).orElse(null);
36+
if (annotation != null) {
37+
context.getStore(namespace).put("mocked", mockConstruction(context, annotation.value()));
38+
}
39+
}
40+
41+
private <T> MockedConstruction<T> mockConstruction(ExtensionContext context, Class<T> mockType) {
42+
return ReflectionSupport.findMethod(context.getRequiredTestClass(), context.getRequiredTestMethod().getName(), mockType, Context.class)
43+
.map(method -> Mockito.mockConstruction(mockType, (m, c) -> ReflectionSupport.invokeMethod(method, context.getRequiredTestInstance(), m, c)))
44+
.orElseGet(() -> Mockito.mockConstruction(mockType));
45+
}
46+
47+
@Override
48+
public void afterTestExecution(ExtensionContext context) throws Exception {
49+
var mocked = context.getStore(namespace).get("mocked");
50+
if (mocked != null) {
51+
((AutoCloseable) mocked).close();
52+
}
53+
}
54+
}

0 commit comments

Comments
 (0)