Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1077,13 +1077,31 @@ public record TopLogProbs(// @formatter:off
* @param promptTokens Number of tokens in the prompt.
* @param totalTokens Total number of tokens used in the request (prompt +
* completion).
* @param promptTokensDetails Details about the prompt tokens used. Support for
* GLM-4.5 and later models.
*/
@JsonInclude(Include.NON_NULL)
@JsonIgnoreProperties(ignoreUnknown = true)
public record Usage(// @formatter:off
@JsonProperty("completion_tokens") Integer completionTokens,
@JsonProperty("prompt_tokens") Integer promptTokens,
@JsonProperty("total_tokens") Integer totalTokens) { // @formatter:on
@JsonProperty("total_tokens") Integer totalTokens,
@JsonProperty("prompt_tokens_details") PromptTokensDetails promptTokensDetails) { // @formatter:on

public Usage(Integer completionTokens, Integer promptTokens, Integer totalTokens) {
this(completionTokens, promptTokens, totalTokens, null);
}

/**
* Details about the prompt tokens used.
*
* @param cachedTokens Number of tokens in the prompt that were cached.
*/
@JsonInclude(Include.NON_NULL)
@JsonIgnoreProperties(ignoreUnknown = true)
public record PromptTokensDetails(// @formatter:off
@JsonProperty("cached_tokens") Integer cachedTokens) { // @formatter:on
}

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
Expand Down Expand Up @@ -85,13 +84,22 @@ class ZhiPuAiChatModelIT {
@Value("classpath:/prompts/system-message.st")
private Resource systemResource;

/**
* Default chat options to use for the tests.
* <p>
* glm-4-flash is a free model, so it is used by default on the tests.
*/
private static final ZhiPuAiChatOptions DEFAULT_CHAT_OPTIONS = ZhiPuAiChatOptions.builder()
.model(ZhiPuAiApi.ChatModel.GLM_4_Flash.getValue())
.build();

@Test
void roleTest() {
UserMessage userMessage = new UserMessage(
"Tell me about 3 famous pirates from the Golden Age of Piracy and what they did.");
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource);
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate"));
Prompt prompt = new Prompt(List.of(userMessage, systemMessage), ChatOptions.builder().build());
Prompt prompt = new Prompt(List.of(userMessage, systemMessage), DEFAULT_CHAT_OPTIONS);
ChatResponse response = this.chatModel.call(prompt);
assertThat(response.getResults()).hasSize(1);
assertThat(response.getResults().get(0).getOutput().getText()).contains("Blackbeard");
Expand All @@ -104,7 +112,7 @@ void streamRoleTest() {
"Tell me about 3 famous pirates from the Golden Age of Piracy and what they did.");
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource);
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate"));
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
Prompt prompt = new Prompt(List.of(userMessage, systemMessage), DEFAULT_CHAT_OPTIONS);
Flux<ChatResponse> flux = this.streamingChatModel.stream(prompt);

List<ChatResponse> responses = flux.collectList().block();
Expand Down Expand Up @@ -135,7 +143,7 @@ void listOutputConverter() {
.template(template)
.variables(Map.of("subject", "ice cream flavors", "format", format))
.build();
Prompt prompt = new Prompt(promptTemplate.createMessage(), ChatOptions.builder().build());
Prompt prompt = new Prompt(promptTemplate.createMessage(), DEFAULT_CHAT_OPTIONS);
Generation generation = this.chatModel.call(prompt).getResult();

List<String> list = outputConverter.convert(generation.getOutput().getText());
Expand All @@ -157,8 +165,9 @@ void mapOutputConverter() {
.variables(Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format",
format))
.build();
Prompt prompt = new Prompt(promptTemplate.createMessage(), ChatOptions.builder().build());
Generation generation = this.chatModel.call(prompt).getResult();
Prompt prompt = new Prompt(promptTemplate.createMessage(), DEFAULT_CHAT_OPTIONS);
ChatResponse chatResponse = this.chatModel.call(prompt);
Generation generation = chatResponse.getResult();

Map<String, Object> result = outputConverter.convert(generation.getOutput().getText());
assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
Expand All @@ -179,7 +188,7 @@ void beanOutputConverter() {
.template(template)
.variables(Map.of("format", format))
.build();
Prompt prompt = new Prompt(promptTemplate.createMessage(), ChatOptions.builder().build());
Prompt prompt = new Prompt(promptTemplate.createMessage(), DEFAULT_CHAT_OPTIONS);
Generation generation = this.chatModel.call(prompt).getResult();

ActorsFilms actorsFilms = outputConverter.convert(generation.getOutput().getText());
Expand All @@ -198,7 +207,7 @@ void beanOutputConverterRecords() {
.template(template)
.variables(Map.of("format", format))
.build();
Prompt prompt = new Prompt(promptTemplate.createMessage(), ChatOptions.builder().build());
Prompt prompt = new Prompt(promptTemplate.createMessage(), DEFAULT_CHAT_OPTIONS);
Generation generation = this.chatModel.call(prompt).getResult();

ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getText());
Expand All @@ -221,7 +230,7 @@ void beanStreamOutputConverterRecords() {
.template(template)
.variables(Map.of("format", format))
.build();
Prompt prompt = new Prompt(promptTemplate.createMessage());
Prompt prompt = new Prompt(promptTemplate.createMessage(), DEFAULT_CHAT_OPTIONS);

String generationTextFromStream = Objects
.requireNonNull(this.streamingChatModel.stream(prompt).collectList().block())
Expand Down Expand Up @@ -253,7 +262,10 @@ void jsonObjectResponseFormatOutputConverterRecords() {
.variables(Map.of("format", format))
.build();
Prompt prompt = new Prompt(promptTemplate.createMessage(),
ZhiPuAiChatOptions.builder().responseFormat(ChatCompletionRequest.ResponseFormat.jsonObject()).build());
ZhiPuAiChatOptions.builder()
.model(ZhiPuAiApi.ChatModel.GLM_4_Flash.getValue())
.responseFormat(ChatCompletionRequest.ResponseFormat.jsonObject())
.build());

String generationTextFromStream = Objects
.requireNonNull(this.streamingChatModel.stream(prompt).collectList().block())
Expand Down Expand Up @@ -281,7 +293,7 @@ void functionCallTest() {
List<Message> messages = new ArrayList<>(List.of(userMessage));

var promptOptions = ZhiPuAiChatOptions.builder()
.model(ZhiPuAiApi.ChatModel.GLM_4.getValue())
.model(ZhiPuAiApi.ChatModel.GLM_4_Flash.getValue())
.toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
.description("Get the weather in location")
.inputType(MockWeatherService.Request.class)
Expand All @@ -306,7 +318,7 @@ void streamFunctionCallTest() {
List<Message> messages = new ArrayList<>(List.of(userMessage));

var promptOptions = ZhiPuAiChatOptions.builder()
.model(ZhiPuAiApi.ChatModel.GLM_4.getValue())
.model(ZhiPuAiApi.ChatModel.GLM_4_Flash.getValue())
.toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
.description("Get the weather in location")
.inputType(MockWeatherService.Request.class)
Expand All @@ -332,8 +344,7 @@ void streamFunctionCallTest() {
@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "glm-4.5-flash" })
void enabledThinkingTest(String modelName) {
UserMessage userMessage = new UserMessage(
"Are there an infinite number of prime numbers such that n mod 4 == 3?");
UserMessage userMessage = new UserMessage("9.11 and 9.8, which is greater?");
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To reduce model output latency, this test uses a simpler question.


var promptOptions = ZhiPuAiChatOptions.builder()
.model(modelName)
Expand All @@ -344,14 +355,16 @@ void enabledThinkingTest(String modelName) {
ChatResponse response = this.chatModel.call(new Prompt(List.of(userMessage), promptOptions));
logger.info("Response: {}", response);

for (Generation generation : response.getResults()) {
AssistantMessage message = generation.getOutput();
Generation generation = response.getResult();
AssistantMessage message = generation.getOutput();

assertThat(message).isInstanceOf(ZhiPuAiAssistantMessage.class);
assertThat(message).isInstanceOf(ZhiPuAiAssistantMessage.class);

assertThat(message.getText()).isNotBlank();
assertThat(((ZhiPuAiAssistantMessage) message).getReasoningContent()).isNotBlank();
}
assertThat(message.getText()).isNotBlank();
assertThat(((ZhiPuAiAssistantMessage) message).getReasoningContent()).isNotBlank();

ZhiPuAiApi.Usage nativeUsage = (ZhiPuAiApi.Usage) response.getMetadata().getUsage().getNativeUsage();
assertThat(nativeUsage.promptTokensDetails()).isNotNull();
}

@ParameterizedTest(name = "{0} : {displayName} ")
Expand Down Expand Up @@ -382,8 +395,7 @@ void disabledThinkingTest(String modelName) {
@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "glm-4.5-flash" })
void streamAndEnableThinkingTest(String modelName) {
UserMessage userMessage = new UserMessage(
"Are there an infinite number of prime numbers such that n mod 4 == 3?");
UserMessage userMessage = new UserMessage("9.11 and 9.8, which is greater?");

var promptOptions = ZhiPuAiChatOptions.builder()
.model(modelName)
Expand All @@ -408,6 +420,7 @@ void streamAndEnableThinkingTest(String modelName) {
}
return message.getText();
})
.filter(StringUtils::hasText)
.collect(Collectors.joining());

logger.info("reasoningContent: {}", reasoningContent);
Expand All @@ -420,7 +433,7 @@ void streamAndEnableThinkingTest(String modelName) {
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "glm-4v" })
@ValueSource(strings = { "glm-4v-flash" })
void multiModalityEmbeddedImage(String modelName) throws IOException {

var imageData = new ClassPathResource("/test.png");
Expand Down Expand Up @@ -461,7 +474,7 @@ void reasonerMultiModalityEmbeddedImageThinkingModel(String modelName) throws IO
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "glm-4v", "glm-4.1v-thinking-flash" })
@ValueSource(strings = { "glm-4v-flash", "glm-4.1v-thinking-flash" })
void multiModalityImageUrl(String modelName) throws IOException {

var userMessage = UserMessage.builder()
Expand Down Expand Up @@ -505,7 +518,7 @@ void reasonerMultiModalityImageUrl(String modelName) throws IOException {
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "glm-4v" })
@ValueSource(strings = { "glm-4v-flash" })
void streamingMultiModalityImageUrl(String modelName) throws IOException {

var userMessage = UserMessage.builder()
Expand Down