From 9289ef477785943fd4b308e0fe3f22566d2a49a4 Mon Sep 17 00:00:00 2001 From: Ahoo Wang Date: Thu, 19 Jun 2025 17:31:46 +0800 Subject: [PATCH] refactor(ai-client-chat): optimize conversation memory handling in PromptChatMemoryAdvisor - When the memory message is empty, the `ChatClientRequest` is returned intact. - Add new user message to conversation memory at the beginning of the process - Reorder and optimize the steps for processing memory messages Signed-off-by: Ahoo Wang --- .../advisor/PromptChatMemoryAdvisor.java | 23 +++++++++---------- .../chat/client/ChatClientAdvisorTests.java | 20 ++-------------- 2 files changed, 13 insertions(+), 30 deletions(-) diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java index de88715e896..585f38c9735 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java @@ -113,29 +113,28 @@ public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChai logger.debug("[PromptChatMemoryAdvisor.before] Memory before processing for conversationId={}: {}", conversationId, memoryMessages); - // 2. Process memory messages as a string. + // 2. Add the new user message to the conversation memory. + UserMessage userMessage = chatClientRequest.prompt().getUserMessage(); + this.chatMemory.add(conversationId, userMessage); + // 3. Check if memory is empty and return the request as is. + if (memoryMessages.isEmpty()) { + return chatClientRequest; + } + // 4. Process memory messages as a string. String memory = memoryMessages.stream() .filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT) .map(m -> m.getMessageType() + ":" + m.getText()) .collect(Collectors.joining(System.lineSeparator())); - // 3. Augment the system message. + // 5. Augment the system message. SystemMessage systemMessage = chatClientRequest.prompt().getSystemMessage(); String augmentedSystemText = this.systemPromptTemplate .render(Map.of("instructions", systemMessage.getText(), "memory", memory)); - // 4. Create a new request with the augmented system message. - ChatClientRequest processedChatClientRequest = chatClientRequest.mutate() + // 6. Create a new request with the augmented system message. + return chatClientRequest.mutate() .prompt(chatClientRequest.prompt().augmentSystemMessage(augmentedSystemText)) .build(); - - // 5. Add all user messages from the current prompt to memory (after system - // message is generated) - // 4. Add the new user message to the conversation memory. - UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage(); - this.chatMemory.add(conversationId, userMessage); - - return processedChatClientRequest; } @Override diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientAdvisorTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientAdvisorTests.java index e4514358638..ea79d011979 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientAdvisorTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientAdvisorTests.java @@ -97,15 +97,7 @@ public void promptChatMemory() { // Capture and verify the system message instructions Message systemMessage = this.promptCaptor.getValue().getInstructions().get(0); - assertThat(systemMessage.getText()).isEqualToIgnoringWhitespace(""" - Default system text. - - Use the conversation memory from the MEMORY section to provide accurate answers. - - --------------------- - MEMORY: - --------------------- - """); + assertThat(systemMessage.getText()).isEqualToIgnoringWhitespace("Default system text."); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); // Capture and verify the user message instructions @@ -175,15 +167,7 @@ public void streamingPromptChatMemory() { // Capture and verify the system message instructions Message systemMessage = this.promptCaptor.getValue().getInstructions().get(0); - assertThat(systemMessage.getText()).isEqualToIgnoringWhitespace(""" - Default system text. - - Use the conversation memory from the MEMORY section to provide accurate answers. - - --------------------- - MEMORY: - --------------------- - """); + assertThat(systemMessage.getText()).isEqualToIgnoringWhitespace("Default system text."); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); // Capture and verify the user message instructions