From a9ac2717739265866d9a2db85e7ea4c12e907762 Mon Sep 17 00:00:00 2001
From: YunKui Lu
Date: Fri, 31 Oct 2025 22:11:27 +0800
Subject: [PATCH] feat(zhipuai): Add `prompt_tokens_details` and update default chat options for tests

- Introduced `prompt_tokens_details` with `cached_tokens` field to `ZhiPuAiApi.Usage`
- Updated test cases to replace inline `ChatOptions` with `DEFAULT_CHAT_OPTIONS`
- Refactored tests to default to the `glm-4-flash` and `glm-4v-flash` models
- Added metadata validations for `promptTokensDetails` in the response

Signed-off-by: YunKui Lu
---
 .../ai/zhipuai/api/ZhiPuAiApi.java          | 20 +++++-
 .../ai/zhipuai/chat/ZhiPuAiChatModelIT.java | 63 +++++++++++--------
 2 files changed, 57 insertions(+), 26 deletions(-)

diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java
index c5a99f99eb3..2cd86a267b8 100644
--- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java
+++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java
@@ -1077,13 +1077,31 @@ public record TopLogProbs(// @formatter:off
	 * @param promptTokens Number of tokens in the prompt.
	 * @param totalTokens Total number of tokens used in the request (prompt +
	 * completion).
+	 * @param promptTokensDetails Details about the prompt tokens used. Supported by
+	 * GLM-4.5 and later models.
	 */
	@JsonInclude(Include.NON_NULL)
	@JsonIgnoreProperties(ignoreUnknown = true)
	public record Usage(// @formatter:off
		@JsonProperty("completion_tokens") Integer completionTokens,
		@JsonProperty("prompt_tokens") Integer promptTokens,
-		@JsonProperty("total_tokens") Integer totalTokens) { // @formatter:on
+		@JsonProperty("total_tokens") Integer totalTokens,
+		@JsonProperty("prompt_tokens_details") PromptTokensDetails promptTokensDetails) { // @formatter:on
+
+		public Usage(Integer completionTokens, Integer promptTokens, Integer totalTokens) {
+			this(completionTokens, promptTokens, totalTokens, null);
+		}
+
+		/**
+		 * Details about the prompt tokens used.
+		 *
+		 * @param cachedTokens Number of tokens in the prompt that were cached.
+		 */
+		@JsonInclude(Include.NON_NULL)
+		@JsonIgnoreProperties(ignoreUnknown = true)
+		public record PromptTokensDetails(// @formatter:off
+			@JsonProperty("cached_tokens") Integer cachedTokens) { // @formatter:on
+		}
	}

diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelIT.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelIT.java
index e06abdf5d3c..2c1bc90f705 100644
--- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelIT.java
+++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelIT.java
@@ -40,7 +40,6 @@
 import org.springframework.ai.chat.model.ChatResponse;
 import org.springframework.ai.chat.model.Generation;
 import org.springframework.ai.chat.model.StreamingChatModel;
-import org.springframework.ai.chat.prompt.ChatOptions;
 import org.springframework.ai.chat.prompt.Prompt;
 import org.springframework.ai.chat.prompt.PromptTemplate;
 import org.springframework.ai.chat.prompt.SystemPromptTemplate;
@@ -85,13 +84,22 @@ class ZhiPuAiChatModelIT {
	@Value("classpath:/prompts/system-message.st")
	private Resource systemResource;

+	/**
+	 * Default chat options to use for the tests.
+	 * <p>
+	 * glm-4-flash is a free model, so it is used by default on the tests.
+	 */
+	private static final ZhiPuAiChatOptions DEFAULT_CHAT_OPTIONS = ZhiPuAiChatOptions.builder()
+		.model(ZhiPuAiApi.ChatModel.GLM_4_Flash.getValue())
+		.build();
+
	@Test
	void roleTest() {
		UserMessage userMessage = new UserMessage(
				"Tell me about 3 famous pirates from the Golden Age of Piracy and what they did.");
		SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource);
		Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate"));
-		Prompt prompt = new Prompt(List.of(userMessage, systemMessage), ChatOptions.builder().build());
+		Prompt prompt = new Prompt(List.of(userMessage, systemMessage), DEFAULT_CHAT_OPTIONS);
		ChatResponse response = this.chatModel.call(prompt);
		assertThat(response.getResults()).hasSize(1);
		assertThat(response.getResults().get(0).getOutput().getText()).contains("Blackbeard");
@@ -104,7 +112,7 @@ void streamRoleTest() {
				"Tell me about 3 famous pirates from the Golden Age of Piracy and what they did.");
		SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource);
		Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate"));
-		Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
+		Prompt prompt = new Prompt(List.of(userMessage, systemMessage), DEFAULT_CHAT_OPTIONS);
		Flux<ChatResponse> flux = this.streamingChatModel.stream(prompt);
		List<ChatResponse> responses = flux.collectList().block();
@@ -135,7 +143,7 @@ void listOutputConverter() {
			.template(template)
			.variables(Map.of("subject", "ice cream flavors", "format", format))
			.build();
-		Prompt prompt = new Prompt(promptTemplate.createMessage(), ChatOptions.builder().build());
+		Prompt prompt = new Prompt(promptTemplate.createMessage(), DEFAULT_CHAT_OPTIONS);
		Generation generation = this.chatModel.call(prompt).getResult();

		List<String> list = outputConverter.convert(generation.getOutput().getText());
@@ -157,8 +165,9 @@ void mapOutputConverter() {
			.variables(Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format",
					format))
			.build();
-		Prompt prompt = new Prompt(promptTemplate.createMessage(), ChatOptions.builder().build());
-		Generation generation = this.chatModel.call(prompt).getResult();
+		Prompt prompt = new Prompt(promptTemplate.createMessage(), DEFAULT_CHAT_OPTIONS);
+		ChatResponse chatResponse = this.chatModel.call(prompt);
+		Generation generation = chatResponse.getResult();

		Map<String, Object> result = outputConverter.convert(generation.getOutput().getText());
		assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
@@ -179,7 +188,7 @@ void beanOutputConverter() {
			.template(template)
			.variables(Map.of("format", format))
			.build();
-		Prompt prompt = new Prompt(promptTemplate.createMessage(), ChatOptions.builder().build());
+		Prompt prompt = new Prompt(promptTemplate.createMessage(), DEFAULT_CHAT_OPTIONS);
		Generation generation = this.chatModel.call(prompt).getResult();

		ActorsFilms actorsFilms = outputConverter.convert(generation.getOutput().getText());
@@ -198,7 +207,7 @@ void beanOutputConverterRecords() {
			.template(template)
			.variables(Map.of("format", format))
			.build();
-		Prompt prompt = new Prompt(promptTemplate.createMessage(), ChatOptions.builder().build());
+		Prompt prompt = new Prompt(promptTemplate.createMessage(), DEFAULT_CHAT_OPTIONS);
		Generation generation = this.chatModel.call(prompt).getResult();

		ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getText());
@@ -221,7 +230,7 @@ void beanStreamOutputConverterRecords() {
			.template(template)
			.variables(Map.of("format", format))
			.build();
-		Prompt prompt = new Prompt(promptTemplate.createMessage());
+		Prompt prompt = new Prompt(promptTemplate.createMessage(), DEFAULT_CHAT_OPTIONS);

		String generationTextFromStream = Objects
			.requireNonNull(this.streamingChatModel.stream(prompt).collectList().block())
@@ -253,7 +262,10 @@ void jsonObjectResponseFormatOutputConverterRecords() {
			.variables(Map.of("format", format))
			.build();
		Prompt prompt = new Prompt(promptTemplate.createMessage(),
-				ZhiPuAiChatOptions.builder().responseFormat(ChatCompletionRequest.ResponseFormat.jsonObject()).build());
+				ZhiPuAiChatOptions.builder()
+					.model(ZhiPuAiApi.ChatModel.GLM_4_Flash.getValue())
+					.responseFormat(ChatCompletionRequest.ResponseFormat.jsonObject())
+					.build());

		String generationTextFromStream = Objects
			.requireNonNull(this.streamingChatModel.stream(prompt).collectList().block())
@@ -281,7 +293,7 @@ void functionCallTest() {
		List<Message> messages = new ArrayList<>(List.of(userMessage));

		var promptOptions = ZhiPuAiChatOptions.builder()
-			.model(ZhiPuAiApi.ChatModel.GLM_4.getValue())
+			.model(ZhiPuAiApi.ChatModel.GLM_4_Flash.getValue())
			.toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
				.description("Get the weather in location")
				.inputType(MockWeatherService.Request.class)
@@ -306,7 +318,7 @@ void streamFunctionCallTest() {
		List<Message> messages = new ArrayList<>(List.of(userMessage));

		var promptOptions = ZhiPuAiChatOptions.builder()
-			.model(ZhiPuAiApi.ChatModel.GLM_4.getValue())
+			.model(ZhiPuAiApi.ChatModel.GLM_4_Flash.getValue())
			.toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
				.description("Get the weather in location")
				.inputType(MockWeatherService.Request.class)
@@ -332,8 +344,7 @@ void streamFunctionCallTest() {

	@ParameterizedTest(name = "{0} : {displayName} ")
	@ValueSource(strings = { "glm-4.5-flash" })
	void enabledThinkingTest(String modelName) {
-		UserMessage userMessage = new UserMessage(
-				"Are there an infinite number of prime numbers such that n mod 4 == 3?");
+		UserMessage userMessage = new UserMessage("9.11 and 9.8, which is greater?");

		var promptOptions = ZhiPuAiChatOptions.builder()
			.model(modelName)
@@ -344,14 +355,16 @@
		ChatResponse response = this.chatModel.call(new Prompt(List.of(userMessage), promptOptions));
		logger.info("Response: {}", response);

-		for (Generation generation : response.getResults()) {
-			AssistantMessage message = generation.getOutput();
+		Generation generation = response.getResult();
+		AssistantMessage message = generation.getOutput();

-			assertThat(message).isInstanceOf(ZhiPuAiAssistantMessage.class);
+		assertThat(message).isInstanceOf(ZhiPuAiAssistantMessage.class);

-			assertThat(message.getText()).isNotBlank();
-			assertThat(((ZhiPuAiAssistantMessage) message).getReasoningContent()).isNotBlank();
-		}
+		assertThat(message.getText()).isNotBlank();
+		assertThat(((ZhiPuAiAssistantMessage) message).getReasoningContent()).isNotBlank();
+
+		ZhiPuAiApi.Usage nativeUsage = (ZhiPuAiApi.Usage) response.getMetadata().getUsage().getNativeUsage();
+		assertThat(nativeUsage.promptTokensDetails()).isNotNull();
	}

	@ParameterizedTest(name = "{0} : {displayName} ")
@@ -382,8 +395,7 @@ void disabledThinkingTest(String modelName) {

	@ParameterizedTest(name = "{0} : {displayName} ")
	@ValueSource(strings = { "glm-4.5-flash" })
"glm-4.5-flash" }) void streamAndEnableThinkingTest(String modelName) { - UserMessage userMessage = new UserMessage( - "Are there an infinite number of prime numbers such that n mod 4 == 3?"); + UserMessage userMessage = new UserMessage("9.11 and 9.8, which is greater?"); var promptOptions = ZhiPuAiChatOptions.builder() .model(modelName) @@ -408,6 +420,7 @@ void streamAndEnableThinkingTest(String modelName) { } return message.getText(); }) + .filter(StringUtils::hasText) .collect(Collectors.joining()); logger.info("reasoningContent: {}", reasoningContent); @@ -420,7 +433,7 @@ void streamAndEnableThinkingTest(String modelName) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "glm-4v" }) + @ValueSource(strings = { "glm-4v-flash" }) void multiModalityEmbeddedImage(String modelName) throws IOException { var imageData = new ClassPathResource("/test.png"); @@ -461,7 +474,7 @@ void reasonerMultiModalityEmbeddedImageThinkingModel(String modelName) throws IO } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "glm-4v", "glm-4.1v-thinking-flash" }) + @ValueSource(strings = { "glm-4v-flash", "glm-4.1v-thinking-flash" }) void multiModalityImageUrl(String modelName) throws IOException { var userMessage = UserMessage.builder() @@ -505,7 +518,7 @@ void reasonerMultiModalityImageUrl(String modelName) throws IOException { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "glm-4v" }) + @ValueSource(strings = { "glm-4v-flash" }) void streamingMultiModalityImageUrl(String modelName) throws IOException { var userMessage = UserMessage.builder()