diff --git a/spring-integration-core/src/main/java/org/springframework/integration/handler/advice/ExpressionEvaluatingRequestHandlerAdvice.java b/spring-integration-core/src/main/java/org/springframework/integration/handler/advice/ExpressionEvaluatingRequestHandlerAdvice.java index 9d581d85637..65b039c8f3f 100644 --- a/spring-integration-core/src/main/java/org/springframework/integration/handler/advice/ExpressionEvaluatingRequestHandlerAdvice.java +++ b/spring-integration-core/src/main/java/org/springframework/integration/handler/advice/ExpressionEvaluatingRequestHandlerAdvice.java @@ -164,7 +164,7 @@ private void evaluateSuccessExpression(Message message) throws Exception { evaluationFailed = true; } if (evalResult != null && this.successChannel != null) { - AdviceMessage resultMessage = new AdviceMessage(evalResult, message); + AdviceMessage resultMessage = new AdviceMessage(evalResult, message); this.messagingTemplate.send(this.successChannel, resultMessage); } if (evaluationFailed && this.propagateOnSuccessEvaluationFailures) { diff --git a/spring-integration-core/src/main/java/org/springframework/integration/history/MessageHistory.java b/spring-integration-core/src/main/java/org/springframework/integration/history/MessageHistory.java index b656559e6e0..568b3f54d3d 100644 --- a/spring-integration-core/src/main/java/org/springframework/integration/history/MessageHistory.java +++ b/spring-integration-core/src/main/java/org/springframework/integration/history/MessageHistory.java @@ -25,20 +25,32 @@ import java.util.ListIterator; import java.util.Properties; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.integration.IntegrationMessageHeaderAccessor; +import org.springframework.integration.message.AdviceMessage; import org.springframework.integration.support.DefaultMessageBuilderFactory; import org.springframework.integration.support.MessageBuilderFactory; +import org.springframework.integration.support.MutableMessage; +import org.springframework.integration.support.MutableMessageBuilderFactory; import org.springframework.integration.support.context.NamedComponent; import org.springframework.messaging.Message; +import org.springframework.messaging.support.ErrorMessage; +import org.springframework.messaging.support.GenericMessage; import org.springframework.util.Assert; import org.springframework.util.StringUtils; /** * @author Mark Fisher + * @author Artem Bilan * @since 2.0 */ @SuppressWarnings("serial") public final class MessageHistory implements List, Serializable { + private static final Log logger = LogFactory.getLog(MessageHistory.class); + public static final String HEADER_NAME = "history"; public static final String NAME_PROPERTY = "name"; @@ -47,21 +59,21 @@ public final class MessageHistory implements List, Serializable { public static final String TIMESTAMP_PROPERTY = "timestamp"; - private static final MessageBuilderFactory mesageBuilderFactory = new DefaultMessageBuilderFactory(); + private static final MessageBuilderFactory MESSAGE_BUILDER_FACTORY = new DefaultMessageBuilderFactory(); private final List components; public static MessageHistory read(Message message) { - return (message != null) ? - message.getHeaders().get(HEADER_NAME, MessageHistory.class) : null; + return message != null ? message.getHeaders().get(HEADER_NAME, MessageHistory.class) : null; } public static Message write(Message message, NamedComponent component) { - return write(message, component, mesageBuilderFactory); + return write(message, component, MESSAGE_BUILDER_FACTORY); } + @SuppressWarnings("unchecked") public static Message write(Message message, NamedComponent component, MessageBuilderFactory messageBuilderFactory) { Assert.notNull(message, "Message must not be null"); @@ -73,7 +85,38 @@ public static Message write(Message message, NamedComponent component, new ArrayList(previousHistory) : new ArrayList(); components.add(metadata); MessageHistory history = new MessageHistory(components); - message = messageBuilderFactory.fromMessage(message).setHeader(HEADER_NAME, history).build(); + + if (message instanceof MutableMessage) { + message.getHeaders().put(HEADER_NAME, history); + } + else if (message instanceof ErrorMessage) { + IntegrationMessageHeaderAccessor headerAccessor = new IntegrationMessageHeaderAccessor(message); + headerAccessor.setHeader(HEADER_NAME, history); + Throwable payload = ((ErrorMessage) message).getPayload(); + ErrorMessage errorMessage = new ErrorMessage(payload, headerAccessor.toMessageHeaders()); + message = (Message) errorMessage; + } + else if (message instanceof AdviceMessage) { + IntegrationMessageHeaderAccessor headerAccessor = new IntegrationMessageHeaderAccessor(message); + headerAccessor.setHeader(HEADER_NAME, history); + message = new AdviceMessage(message.getPayload(), headerAccessor.toMessageHeaders(), + ((AdviceMessage) message).getInputMessage()); + } + else { + if (!(message instanceof GenericMessage) && + (messageBuilderFactory instanceof DefaultMessageBuilderFactory || + messageBuilderFactory instanceof MutableMessageBuilderFactory)) { + if (logger.isWarnEnabled()) { + logger.warn("MessageHistory rebuilds the message and produces the result of the [" + + messageBuilderFactory + "], not an instance of the provided type [" + + message.getClass() + "]. Consider to supply a custom MessageBuilderFactory " + + "to retain custom messages during MessageHistory tracking."); + } + } + message = messageBuilderFactory.fromMessage(message) + .setHeader(HEADER_NAME, history) + .build(); + } } return message; } @@ -263,6 +306,7 @@ public String getTimestamp() { private void setTimestamp(String timestamp) { this.setProperty(TIMESTAMP_PROPERTY, timestamp); } + } } diff --git a/spring-integration-core/src/main/java/org/springframework/integration/message/AdviceMessage.java b/spring-integration-core/src/main/java/org/springframework/integration/message/AdviceMessage.java index c442f9deacf..67dd3ec3393 100644 --- a/spring-integration-core/src/main/java/org/springframework/integration/message/AdviceMessage.java +++ b/spring-integration-core/src/main/java/org/springframework/integration/message/AdviceMessage.java @@ -29,20 +29,22 @@ * handler. * . * @author Gary Russell + * @author Artem Bilan + * * @since 2.2 */ -public class AdviceMessage extends GenericMessage { +public class AdviceMessage extends GenericMessage { private static final long serialVersionUID = 1L; private final Message inputMessage; - public AdviceMessage(Object payload, Message inputMessage) { + public AdviceMessage(T payload, Message inputMessage) { super(payload); this.inputMessage = inputMessage; } - public AdviceMessage(Object payload, Map headers, Message inputMessage) { + public AdviceMessage(T payload, Map headers, Message inputMessage) { super(payload, headers); this.inputMessage = inputMessage; } diff --git a/spring-integration-core/src/test/java/org/springframework/integration/core/MessageHistoryTests.java b/spring-integration-core/src/test/java/org/springframework/integration/core/MessageHistoryTests.java index d21b2290b07..bcae61d47a4 100644 --- a/spring-integration-core/src/test/java/org/springframework/integration/core/MessageHistoryTests.java +++ b/spring-integration-core/src/test/java/org/springframework/integration/core/MessageHistoryTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2010 the original author or authors. + * Copyright 2002-2016 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,22 +16,30 @@ package org.springframework.integration.core; +import static org.hamcrest.Matchers.instanceOf; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThat; import java.util.Properties; import org.junit.Test; import org.springframework.integration.history.MessageHistory; -import org.springframework.messaging.support.GenericMessage; +import org.springframework.integration.message.AdviceMessage; import org.springframework.integration.support.MessageBuilder; +import org.springframework.integration.support.MutableMessage; import org.springframework.integration.support.context.NamedComponent; import org.springframework.messaging.Message; +import org.springframework.messaging.support.ErrorMessage; +import org.springframework.messaging.support.GenericMessage; /** * @author Mark Fisher + * @author Artem Bilan * @since 2.0 */ public class MessageHistoryTests { @@ -57,6 +65,69 @@ public void verifyImmutability() { history.add(new Properties()); } + @Test + public void testCorrectMutableMessageAfterWrite() { + MutableMessage original = new MutableMessage<>("foo"); + assertNull(MessageHistory.read(original)); + Message result1 = MessageHistory.write(original, new TestComponent(1)); + assertThat(result1, instanceOf(MutableMessage.class)); + assertSame(original, result1); + MessageHistory history1 = MessageHistory.read(result1); + assertNotNull(history1); + assertEquals("testComponent-1", history1.toString()); + Message result2 = MessageHistory.write(result1, new TestComponent(2)); + assertSame(original, result2); + MessageHistory history2 = MessageHistory.read(result2); + assertNotNull(history2); + assertEquals("testComponent-1,testComponent-2", history2.toString()); + } + + @Test + public void testCorrectErrorMessageAfterWrite() { + RuntimeException payload = new RuntimeException(); + ErrorMessage original = new ErrorMessage(payload); + assertNull(MessageHistory.read(original)); + Message result1 = MessageHistory.write(original, new TestComponent(1)); + assertThat(result1, instanceOf(ErrorMessage.class)); + assertNotSame(original, result1); + assertSame(original.getPayload(), result1.getPayload()); + MessageHistory history1 = MessageHistory.read(result1); + assertNotNull(history1); + assertEquals("testComponent-1", history1.toString()); + Message result2 = MessageHistory.write(result1, new TestComponent(2)); + assertThat(result2, instanceOf(ErrorMessage.class)); + assertNotSame(original, result2); + assertNotSame(result1, result2); + assertSame(original.getPayload(), result2.getPayload()); + MessageHistory history2 = MessageHistory.read(result2); + assertNotNull(history2); + assertEquals("testComponent-1,testComponent-2", history2.toString()); + } + + @Test + public void testCorrectAdviceMessageAfterWrite() { + Message inputMessage = new GenericMessage<>("input"); + AdviceMessage original = new AdviceMessage<>("foo", inputMessage); + assertNull(MessageHistory.read(original)); + Message result1 = MessageHistory.write(original, new TestComponent(1)); + assertThat(result1, instanceOf(AdviceMessage.class)); + assertNotSame(original, result1); + assertSame(original.getPayload(), result1.getPayload()); + assertSame(original.getInputMessage(), ((AdviceMessage) result1).getInputMessage()); + MessageHistory history1 = MessageHistory.read(result1); + assertNotNull(history1); + assertEquals("testComponent-1", history1.toString()); + Message result2 = MessageHistory.write(result1, new TestComponent(2)); + assertThat(result2, instanceOf(AdviceMessage.class)); + assertNotSame(original, result2); + assertSame(original.getPayload(), result2.getPayload()); + assertSame(original.getInputMessage(), ((AdviceMessage) result2).getInputMessage()); + assertNotSame(result1, result2); + MessageHistory history2 = MessageHistory.read(result2); + assertNotNull(history2); + assertEquals("testComponent-1,testComponent-2", history2.toString()); + } + private static class TestComponent implements NamedComponent { diff --git a/spring-integration-core/src/test/java/org/springframework/integration/handler/advice/AdvisedMessageHandlerTests.java b/spring-integration-core/src/test/java/org/springframework/integration/handler/advice/AdvisedMessageHandlerTests.java index b8c17d1dc26..b081d74e28f 100644 --- a/spring-integration-core/src/test/java/org/springframework/integration/handler/advice/AdvisedMessageHandlerTests.java +++ b/spring-integration-core/src/test/java/org/springframework/integration/handler/advice/AdvisedMessageHandlerTests.java @@ -164,7 +164,7 @@ protected Object handleRequestMessage(Message requestMessage) { Message success = successChannel.receive(1000); assertNotNull(success); - assertEquals("Hello, world!", ((AdviceMessage) success).getInputMessage().getPayload()); + assertEquals("Hello, world!", ((AdviceMessage) success).getInputMessage().getPayload()); assertEquals("foo", success.getPayload()); // advice with failure, not trapped @@ -244,7 +244,7 @@ protected Object handleRequestMessage(Message requestMessage) { Message success = successChannel.receive(1000); assertNotNull(success); - assertEquals("Hello, world!", ((AdviceMessage) success).getInputMessage().getPayload()); + assertEquals("Hello, world!", ((AdviceMessage) success).getInputMessage().getPayload()); assertEquals(ArithmeticException.class, success.getPayload().getClass()); assertEquals("/ by zero", ((Exception) success.getPayload()).getMessage()); @@ -262,7 +262,7 @@ protected Object handleRequestMessage(Message requestMessage) { success = successChannel.receive(1000); assertNotNull(success); - assertEquals("Hello, world!", ((AdviceMessage) success).getInputMessage().getPayload()); + assertEquals("Hello, world!", ((AdviceMessage) success).getInputMessage().getPayload()); assertEquals(ArithmeticException.class, success.getPayload().getClass()); assertEquals("/ by zero", ((Exception) success.getPayload()).getMessage()); diff --git a/spring-integration-mongodb/src/main/java/org/springframework/integration/mongodb/store/MongoDbMessageStore.java b/spring-integration-mongodb/src/main/java/org/springframework/integration/mongodb/store/MongoDbMessageStore.java index 4eb6ff5c96a..0f936708380 100644 --- a/spring-integration-mongodb/src/main/java/org/springframework/integration/mongodb/store/MongoDbMessageStore.java +++ b/spring-integration-mongodb/src/main/java/org/springframework/integration/mongodb/store/MongoDbMessageStore.java @@ -19,13 +19,11 @@ import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; -import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Properties; -import java.util.Set; import java.util.UUID; import org.springframework.beans.BeansException; @@ -34,9 +32,7 @@ import org.springframework.beans.factory.InitializingBean; import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContextAware; -import org.springframework.core.convert.TypeDescriptor; import org.springframework.core.convert.converter.Converter; -import org.springframework.core.convert.converter.GenericConverter; import org.springframework.core.serializer.support.DeserializingConverter; import org.springframework.core.serializer.support.SerializingConverter; import org.springframework.data.annotation.Id; @@ -64,6 +60,7 @@ import org.springframework.integration.store.MessageGroup; import org.springframework.integration.store.MessageGroupStore; import org.springframework.integration.store.MessageStore; +import org.springframework.integration.support.MutableMessage; import org.springframework.integration.support.MutableMessageBuilder; import org.springframework.jmx.export.annotation.ManagedAttribute; import org.springframework.messaging.Message; @@ -650,44 +647,27 @@ public GenericMessage convert(DBObject source) { } - private final class DBObjectToMutableMessageConverter implements GenericConverter { + private final class DBObjectToMutableMessageConverter implements Converter> { - private final Class mutableMessageClass; - - private DBObjectToMutableMessageConverter() { - try { - this.mutableMessageClass = ClassUtils.forName("org.springframework.integration.support.MutableMessage", - MongoDbMessageStore.this.classLoader); - } - catch (ClassNotFoundException e) { - throw new IllegalStateException(e); - } - } @Override - public Set getConvertibleTypes() { - Set convertiblePairs = new HashSet(); - convertiblePairs.add(new ConvertiblePair(DBObject.class, this.mutableMessageClass)); - return convertiblePairs; - } - - @Override - public Object convert(Object source, TypeDescriptor sourceType, TypeDescriptor targetType) { - DBObject dbObject = (DBObject) source; + public MutableMessage convert(DBObject source) { @SuppressWarnings("unchecked") Map headers = - MongoDbMessageStore.this.converter.normalizeHeaders((Map) dbObject.get("headers")); + MongoDbMessageStore.this.converter.normalizeHeaders((Map) source.get("headers")); - return MutableMessageBuilder.withPayload(MongoDbMessageStore.this.converter.extractPayload(dbObject)) + Object payload = MongoDbMessageStore.this.converter.extractPayload(source); + return (MutableMessage) MutableMessageBuilder.withPayload(payload) .copyHeaders(headers) .build(); } + } - private class DBObjectToAdviceMessageConverter implements Converter { + private class DBObjectToAdviceMessageConverter implements Converter> { @Override - public AdviceMessage convert(DBObject source) { + public AdviceMessage convert(DBObject source) { @SuppressWarnings("unchecked") Map headers = MongoDbMessageStore.this.converter.normalizeHeaders((Map) source.get("headers")); @@ -698,16 +678,18 @@ public AdviceMessage convert(DBObject source) { DBObject inputMessageObject = (DBObject) source.get("inputMessage"); Object inputMessageType = inputMessageObject.get("_class"); try { - Class messageClass = ClassUtils.forName(inputMessageType.toString(), MongoDbMessageStore.this.classLoader); - inputMessage = (Message) MongoDbMessageStore.this.converter.read(messageClass, inputMessageObject); + Class messageClass = ClassUtils.forName(inputMessageType.toString(), + MongoDbMessageStore.this.classLoader); + inputMessage = (Message) MongoDbMessageStore.this.converter.read(messageClass, + inputMessageObject); } catch (Exception e) { throw new IllegalStateException("failed to load class: " + inputMessageType, e); } } - AdviceMessage message = - new AdviceMessage(MongoDbMessageStore.this.converter.extractPayload(source), headers, inputMessage); + AdviceMessage message = new AdviceMessage( + MongoDbMessageStore.this.converter.extractPayload(source), headers, inputMessage); enhanceHeaders(message.getHeaders(), headers); return message; diff --git a/spring-integration-mongodb/src/test/java/org/springframework/integration/mongodb/store/AbstractMongoDbMessageStoreTests.java b/spring-integration-mongodb/src/test/java/org/springframework/integration/mongodb/store/AbstractMongoDbMessageStoreTests.java index 92380a6592f..a9ee9e3d6ac 100644 --- a/spring-integration-mongodb/src/test/java/org/springframework/integration/mongodb/store/AbstractMongoDbMessageStoreTests.java +++ b/spring-integration-mongodb/src/test/java/org/springframework/integration/mongodb/store/AbstractMongoDbMessageStoreTests.java @@ -17,6 +17,7 @@ package org.springframework.integration.mongodb.store; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.instanceOf; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; @@ -38,6 +39,7 @@ import org.springframework.integration.mongodb.rules.MongoDbAvailableTests; import org.springframework.integration.store.MessageStore; import org.springframework.integration.support.MessageBuilder; +import org.springframework.integration.support.MutableMessage; import org.springframework.integration.support.MutableMessageBuilder; import org.springframework.messaging.Message; import org.springframework.messaging.MessagingException; @@ -186,7 +188,7 @@ public void testInt3076AdviceMessage() throws Exception { p.setFname("John"); p.setLname("Doe"); Message inputMessage = MessageBuilder.withPayload(p).build(); - Message messageToStore = new AdviceMessage("foo", inputMessage); + Message messageToStore = new AdviceMessage("foo", inputMessage); store.addMessage(messageToStore); Message retrievedMessage = store.getMessage(messageToStore.getHeaders().getId()); assertNotNull(retrievedMessage); @@ -205,12 +207,12 @@ public void testAdviceMessageAsPayload() throws Exception { p.setFname("John"); p.setLname("Doe"); Message inputMessage = MessageBuilder.withPayload(p).build(); - Message messageToStore = new GenericMessage>(new AdviceMessage("foo", inputMessage)); + Message messageToStore = new GenericMessage>(new AdviceMessage("foo", inputMessage)); store.addMessage(messageToStore); Message retrievedMessage = store.getMessage(messageToStore.getHeaders().getId()); assertNotNull(retrievedMessage); assertTrue(retrievedMessage.getPayload() instanceof AdviceMessage); - AdviceMessage adviceMessage = (AdviceMessage) retrievedMessage.getPayload(); + AdviceMessage adviceMessage = (AdviceMessage) retrievedMessage.getPayload(); assertEquals("foo", adviceMessage.getPayload()); assertEquals(messageToStore.getHeaders(), retrievedMessage.getHeaders()); assertEquals(inputMessage, adviceMessage.getInputMessage()); @@ -228,7 +230,7 @@ public void testMutableMessageAsPayload() throws Exception { store.addMessage(messageToStore); Message retrievedMessage = store.getMessage(messageToStore.getHeaders().getId()); assertNotNull(retrievedMessage); - assertEquals("org.springframework.integration.support.MutableMessage", retrievedMessage.getPayload().getClass().getName()); + assertThat(retrievedMessage.getPayload(), instanceOf(MutableMessage.class)); assertEquals(messageToStore.getPayload(), retrievedMessage.getPayload()); assertEquals(messageToStore.getHeaders(), retrievedMessage.getHeaders()); assertEquals(((Message) messageToStore.getPayload()).getPayload(), p);