Skip to content

Commit 68d176f

Browse files
sunyuhan1998ilayaperumalg
authored andcommitted
fix: GH-4586 fixed a handling exception in DeepSeekStreamFunctionCallingHelper when toolCalls() returns an empty list instead of null; corrected a flawed ternary expression; and added corresponding unit tests to improve code robustness.
Signed-off-by: Sun Yuhan <[email protected]>
1 parent d7977aa commit 68d176f

File tree

2 files changed

+221
-5
lines changed

2 files changed

+221
-5
lines changed

models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/DeepSeekStreamFunctionCallingHelper.java

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -34,6 +34,7 @@
3434
* ChatCompletionChunk in case of function calling message.
3535
*
3636
* @author Geng Rong
37+
* @author Sun Yuhan
3738
*/
3839
public class DeepSeekStreamFunctionCallingHelper {
3940

@@ -76,22 +77,23 @@ private ChunkChoice merge(ChunkChoice previous, ChunkChoice current) {
7677
}
7778

7879
private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompletionMessage current) {
79-
String content = (current.content() != null ? current.content()
80-
: "" + ((previous.content() != null) ? previous.content() : ""));
80+
String content = (current.content() != null
81+
? (previous.content() != null ? previous.content() + current.content() : current.content())
82+
: (previous.content() != null ? previous.content() : ""));
8183
Role role = (current.role() != null ? current.role() : previous.role());
8284
role = (role != null ? role : Role.ASSISTANT); // default to ASSISTANT (if null
8385
String name = (current.name() != null ? current.name() : previous.name());
8486
String toolCallId = (current.toolCallId() != null ? current.toolCallId() : previous.toolCallId());
8587

8688
List<ToolCall> toolCalls = new ArrayList<>();
8789
ToolCall lastPreviousTooCall = null;
88-
if (previous.toolCalls() != null) {
90+
if (!CollectionUtils.isEmpty(previous.toolCalls())) {
8991
lastPreviousTooCall = previous.toolCalls().get(previous.toolCalls().size() - 1);
9092
if (previous.toolCalls().size() > 1) {
9193
toolCalls.addAll(previous.toolCalls().subList(0, previous.toolCalls().size() - 1));
9294
}
9395
}
94-
if (current.toolCalls() != null) {
96+
if (!CollectionUtils.isEmpty(current.toolCalls())) {
9597
if (current.toolCalls().size() > 1) {
9698
throw new IllegalStateException("Currently only one tool call is supported per message!");
9799
}
Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
/*
2+
* Copyright 2025-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.deepseek.api;
18+
19+
import java.util.List;
20+
21+
import org.junit.jupiter.api.BeforeEach;
22+
import org.junit.jupiter.api.Test;
23+
24+
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionChunk;
25+
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage;
26+
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage.ChatCompletionFunction;
27+
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage.Role;
28+
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage.ToolCall;
29+
30+
import static org.assertj.core.api.Assertions.assertThat;
31+
32+
/**
33+
* Unit test for {@link DeepSeekStreamFunctionCallingHelper}.
34+
*
35+
* @author Sun Yuhan
36+
*/
37+
class DeepSeekStreamFunctionCallingHelperTest {
38+
39+
private DeepSeekStreamFunctionCallingHelper helper;
40+
41+
@BeforeEach
42+
void setUp() {
43+
this.helper = new DeepSeekStreamFunctionCallingHelper();
44+
}
45+
46+
@Test
47+
void mergeWhenPreviousIsNullShouldReturnCurrent() {
48+
// Given
49+
ChatCompletionChunk current = new ChatCompletionChunk("id1", List.of(), 123L, "model1", null, null, null, null);
50+
51+
// When
52+
ChatCompletionChunk result = this.helper.merge(null, current);
53+
54+
// Then
55+
assertThat(result).isEqualTo(current);
56+
}
57+
58+
@Test
59+
void mergeShouldMergeBasicFieldsFromCurrentAndPrevious() {
60+
// Given
61+
ChatCompletionChunk previous = new ChatCompletionChunk("id1", List.of(), 123L, "model1", null, null, null,
62+
null);
63+
ChatCompletionChunk current = new ChatCompletionChunk("id2", List.of(), null, null, null, null, null, null);
64+
65+
// When
66+
ChatCompletionChunk result = this.helper.merge(previous, current);
67+
68+
// Then
69+
assertThat(result.id()).isEqualTo("id2"); // from current
70+
assertThat(result.created()).isEqualTo(123L); // from previous
71+
assertThat(result.model()).isEqualTo("model1"); // from previous
72+
}
73+
74+
@Test
75+
void mergeShouldMergeMessagesContent() {
76+
// Given
77+
ChatCompletionMessage previousMsg = new ChatCompletionMessage("Hello ", Role.ASSISTANT, null, null, null);
78+
ChatCompletionMessage currentMsg = new ChatCompletionMessage("World!", Role.ASSISTANT, null, null, null);
79+
80+
ChatCompletionChunk previous = new ChatCompletionChunk("id",
81+
List.of(new ChatCompletionChunk.ChunkChoice(null, 0, previousMsg, null)), 123L, "model", null, null,
82+
null, null);
83+
84+
ChatCompletionChunk current = new ChatCompletionChunk("id",
85+
List.of(new ChatCompletionChunk.ChunkChoice(null, 0, currentMsg, null)), 123L, "model", null, null,
86+
null, null);
87+
88+
// When
89+
ChatCompletionChunk result = this.helper.merge(previous, current);
90+
91+
// Then
92+
assertThat(result.choices().get(0).delta().content()).isEqualTo("Hello World!");
93+
}
94+
95+
@Test
96+
void mergeShouldHandleToolCallsMerging() {
97+
// Given
98+
ChatCompletionFunction func1 = new ChatCompletionFunction("func1", "{\"arg1\":");
99+
ToolCall toolCall1 = new ToolCall("call_123", "function", func1);
100+
ChatCompletionMessage previousMsg = new ChatCompletionMessage(null, Role.ASSISTANT, null, null,
101+
List.of(toolCall1));
102+
103+
ChatCompletionFunction func2 = new ChatCompletionFunction("func1", "\"value1\"}");
104+
ToolCall toolCall2 = new ToolCall(null, "function", func2); // No ID -
105+
// continuation
106+
ChatCompletionMessage currentMsg = new ChatCompletionMessage(null, Role.ASSISTANT, null, null,
107+
List.of(toolCall2));
108+
109+
ChatCompletionChunk previous = new ChatCompletionChunk("id",
110+
List.of(new ChatCompletionChunk.ChunkChoice(null, 0, previousMsg, null)), 123L, "model", null, null,
111+
null, null);
112+
113+
ChatCompletionChunk current = new ChatCompletionChunk("id",
114+
List.of(new ChatCompletionChunk.ChunkChoice(null, 0, currentMsg, null)), 123L, "model", null, null,
115+
null, null);
116+
117+
// When
118+
ChatCompletionChunk result = this.helper.merge(previous, current);
119+
120+
// Then
121+
assertThat(result.choices()).hasSize(1);
122+
assertThat(result.choices().get(0).delta().toolCalls()).hasSize(1);
123+
ToolCall mergedToolCall = result.choices().get(0).delta().toolCalls().get(0);
124+
assertThat(mergedToolCall.id()).isEqualTo("call_123");
125+
assertThat(mergedToolCall.function().name()).isEqualTo("func1");
126+
assertThat(mergedToolCall.function().arguments()).isEqualTo("{\"arg1\":\"value1\"}");
127+
}
128+
129+
@Test
130+
void mergeWithSingleToolCallShouldWork() {
131+
// Given
132+
ToolCall toolCall = new ToolCall("call_1", "function", new ChatCompletionFunction("func1", "{}"));
133+
ChatCompletionMessage msg = new ChatCompletionMessage(null, Role.ASSISTANT, null, null, List.of(toolCall));
134+
135+
ChatCompletionChunk previous = new ChatCompletionChunk("id", List.of(), 123L, "model", null, null, null, null);
136+
ChatCompletionChunk current = new ChatCompletionChunk("id",
137+
List.of(new ChatCompletionChunk.ChunkChoice(null, 0, msg, null)), 123L, "model", null, null, null,
138+
null);
139+
140+
// When
141+
ChatCompletionChunk result = this.helper.merge(previous, current);
142+
143+
// Then
144+
assertThat(result).isNotNull();
145+
assertThat(result.choices().get(0).delta().toolCalls()).hasSize(1);
146+
}
147+
148+
@Test
149+
void isStreamingToolFunctionCallWhenNullChunkShouldReturnFalse() {
150+
// When & Then
151+
assertThat(this.helper.isStreamingToolFunctionCall(null)).isFalse();
152+
}
153+
154+
@Test
155+
void isStreamingToolFunctionCallWhenEmptyChoicesShouldReturnFalse() {
156+
// Given
157+
ChatCompletionChunk chunk = new ChatCompletionChunk("id", List.of(), 123L, "model", null, null, null, null);
158+
159+
// When & Then
160+
assertThat(this.helper.isStreamingToolFunctionCall(chunk)).isFalse();
161+
}
162+
163+
@Test
164+
void isStreamingToolFunctionCallWhenHasToolCallsShouldReturnTrue() {
165+
// Given
166+
ToolCall toolCall = new ToolCall("call_1", "function", new ChatCompletionFunction("func", "{}"));
167+
ChatCompletionMessage msg = new ChatCompletionMessage(null, Role.ASSISTANT, null, null, List.of(toolCall));
168+
ChatCompletionChunk chunk = new ChatCompletionChunk("id",
169+
List.of(new ChatCompletionChunk.ChunkChoice(null, 0, msg, null)), 123L, "model", null, null, null,
170+
null);
171+
172+
// When & Then
173+
assertThat(this.helper.isStreamingToolFunctionCall(chunk)).isTrue();
174+
}
175+
176+
@Test
177+
void isStreamingToolFunctionCallFinishWhenFinishReasonIsToolCallsShouldReturnTrue() {
178+
// Given
179+
ChatCompletionMessage msg = new ChatCompletionMessage(null, Role.ASSISTANT, null, null, null);
180+
ChatCompletionChunk.ChunkChoice choice = new ChatCompletionChunk.ChunkChoice(
181+
DeepSeekApi.ChatCompletionFinishReason.TOOL_CALLS, 0, msg, null);
182+
ChatCompletionChunk chunk = new ChatCompletionChunk("id", List.of(choice), 123L, "model", null, null, null,
183+
null);
184+
185+
// When & Then
186+
assertThat(this.helper.isStreamingToolFunctionCallFinish(chunk)).isTrue();
187+
}
188+
189+
@Test
190+
void mergeWhenCurrentToolCallsIsEmptyListShouldNotThrowException() {
191+
// Given
192+
ToolCall toolCall = new ToolCall("call_1", "function", new ChatCompletionFunction("func1", "{}"));
193+
ChatCompletionMessage previousMsg = new ChatCompletionMessage(null, Role.ASSISTANT, null, null,
194+
List.of(toolCall));
195+
196+
// Empty list instead of null
197+
ChatCompletionMessage currentMsg = new ChatCompletionMessage(null, Role.ASSISTANT, null, null, List.of());
198+
199+
ChatCompletionChunk previous = new ChatCompletionChunk("id",
200+
List.of(new ChatCompletionChunk.ChunkChoice(null, 0, previousMsg, null)), 123L, "model", null, null,
201+
null, null);
202+
203+
ChatCompletionChunk current = new ChatCompletionChunk("id",
204+
List.of(new ChatCompletionChunk.ChunkChoice(null, 0, currentMsg, null)), 123L, "model", null, null,
205+
null, null);
206+
207+
// When
208+
ChatCompletionChunk result = this.helper.merge(previous, current);
209+
210+
// Then
211+
assertThat(result).isNotNull();
212+
}
213+
214+
}

0 commit comments

Comments
 (0)