Skip to content

Commit 3a7af77

Browse files
authored
models - litellm - start and stop reasoning (#947)
1 parent 26862e4 commit 3a7af77

File tree

3 files changed

+104
-21
lines changed

3 files changed

+104
-21
lines changed

src/strands/models/litellm.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,26 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]
111111

112112
return super().format_request_message_content(content)
113113

114+
def _stream_switch_content(self, data_type: str, prev_data_type: str | None) -> tuple[list[StreamEvent], str]:
115+
"""Handle switching to a new content stream.
116+
117+
Args:
118+
data_type: The next content data type.
119+
prev_data_type: The previous content data type.
120+
121+
Returns:
122+
Tuple containing:
123+
- Stop block for previous content and the start block for the next content.
124+
- Next content data type.
125+
"""
126+
chunks = []
127+
if data_type != prev_data_type:
128+
if prev_data_type is not None:
129+
chunks.append(self.format_chunk({"chunk_type": "content_stop", "data_type": prev_data_type}))
130+
chunks.append(self.format_chunk({"chunk_type": "content_start", "data_type": data_type}))
131+
132+
return chunks, data_type
133+
114134
@override
115135
async def stream(
116136
self,
@@ -146,38 +166,46 @@ async def stream(
146166

147167
logger.debug("got response from model")
148168
yield self.format_chunk({"chunk_type": "message_start"})
149-
yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
150169

151170
tool_calls: dict[int, list[Any]] = {}
171+
data_type: str | None = None
152172

153173
async for event in response:
154174
# Defensive: skip events with empty or missing choices
155175
if not getattr(event, "choices", None):
156176
continue
157177
choice = event.choices[0]
158178

159-
if choice.delta.content:
160-
yield self.format_chunk(
161-
{"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}
162-
)
163-
164179
if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content:
180+
chunks, data_type = self._stream_switch_content("reasoning_content", data_type)
181+
for chunk in chunks:
182+
yield chunk
183+
165184
yield self.format_chunk(
166185
{
167186
"chunk_type": "content_delta",
168-
"data_type": "reasoning_content",
187+
"data_type": data_type,
169188
"data": choice.delta.reasoning_content,
170189
}
171190
)
172191

192+
if choice.delta.content:
193+
chunks, data_type = self._stream_switch_content("text", data_type)
194+
for chunk in chunks:
195+
yield chunk
196+
197+
yield self.format_chunk(
198+
{"chunk_type": "content_delta", "data_type": data_type, "data": choice.delta.content}
199+
)
200+
173201
for tool_call in choice.delta.tool_calls or []:
174202
tool_calls.setdefault(tool_call.index, []).append(tool_call)
175203

176204
if choice.finish_reason:
205+
if data_type:
206+
yield self.format_chunk({"chunk_type": "content_stop", "data_type": data_type})
177207
break
178208

179-
yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})
180-
181209
for tool_deltas in tool_calls.values():
182210
yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]})
183211

tests/strands/models/test_litellm.py

Lines changed: 51 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -142,39 +142,71 @@ def test_format_request_message_content(content, exp_result):
142142

143143
@pytest.mark.asyncio
144144
async def test_stream(litellm_acompletion, api_key, model_id, model, agenerator, alist):
145-
mock_tool_call_1_part_1 = unittest.mock.Mock(index=0)
146-
mock_tool_call_2_part_1 = unittest.mock.Mock(index=1)
147145
mock_delta_1 = unittest.mock.Mock(
148146
reasoning_content="",
149147
content=None,
150148
tool_calls=None,
151149
)
150+
152151
mock_delta_2 = unittest.mock.Mock(
153152
reasoning_content="\nI'm thinking",
154153
content=None,
155154
tool_calls=None,
156155
)
157156
mock_delta_3 = unittest.mock.Mock(
157+
reasoning_content=None,
158+
content="One second",
159+
tool_calls=None,
160+
)
161+
mock_delta_4 = unittest.mock.Mock(
162+
reasoning_content="\nI'm think",
163+
content=None,
164+
tool_calls=None,
165+
)
166+
mock_delta_5 = unittest.mock.Mock(
167+
reasoning_content="ing again",
168+
content=None,
169+
tool_calls=None,
170+
)
171+
172+
mock_tool_call_1_part_1 = unittest.mock.Mock(index=0)
173+
mock_tool_call_2_part_1 = unittest.mock.Mock(index=1)
174+
mock_delta_6 = unittest.mock.Mock(
158175
content="I'll calculate", tool_calls=[mock_tool_call_1_part_1, mock_tool_call_2_part_1], reasoning_content=None
159176
)
160177

161178
mock_tool_call_1_part_2 = unittest.mock.Mock(index=0)
162179
mock_tool_call_2_part_2 = unittest.mock.Mock(index=1)
163-
mock_delta_4 = unittest.mock.Mock(
180+
mock_delta_7 = unittest.mock.Mock(
164181
content="that for you", tool_calls=[mock_tool_call_1_part_2, mock_tool_call_2_part_2], reasoning_content=None
165182
)
166183

167-
mock_delta_5 = unittest.mock.Mock(content="", tool_calls=None, reasoning_content=None)
184+
mock_delta_8 = unittest.mock.Mock(content="", tool_calls=None, reasoning_content=None)
168185

169186
mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_1)])
170187
mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_2)])
171188
mock_event_3 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_3)])
172189
mock_event_4 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_4)])
173-
mock_event_5 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls", delta=mock_delta_5)])
174-
mock_event_6 = unittest.mock.Mock()
190+
mock_event_5 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_5)])
191+
mock_event_6 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_6)])
192+
mock_event_7 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_7)])
193+
mock_event_8 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls", delta=mock_delta_8)])
194+
mock_event_9 = unittest.mock.Mock()
175195

176196
litellm_acompletion.side_effect = unittest.mock.AsyncMock(
177-
return_value=agenerator([mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5, mock_event_6])
197+
return_value=agenerator(
198+
[
199+
mock_event_1,
200+
mock_event_2,
201+
mock_event_3,
202+
mock_event_4,
203+
mock_event_5,
204+
mock_event_6,
205+
mock_event_7,
206+
mock_event_8,
207+
mock_event_9,
208+
]
209+
)
178210
)
179211

180212
messages = [{"role": "user", "content": [{"type": "text", "text": "calculate 2+2"}]}]
@@ -184,6 +216,15 @@ async def test_stream(litellm_acompletion, api_key, model_id, model, agenerator,
184216
{"messageStart": {"role": "assistant"}},
185217
{"contentBlockStart": {"start": {}}},
186218
{"contentBlockDelta": {"delta": {"reasoningContent": {"text": "\nI'm thinking"}}}},
219+
{"contentBlockStop": {}},
220+
{"contentBlockStart": {"start": {}}},
221+
{"contentBlockDelta": {"delta": {"text": "One second"}}},
222+
{"contentBlockStop": {}},
223+
{"contentBlockStart": {"start": {}}},
224+
{"contentBlockDelta": {"delta": {"reasoningContent": {"text": "\nI'm think"}}}},
225+
{"contentBlockDelta": {"delta": {"reasoningContent": {"text": "ing again"}}}},
226+
{"contentBlockStop": {}},
227+
{"contentBlockStart": {"start": {}}},
187228
{"contentBlockDelta": {"delta": {"text": "I'll calculate"}}},
188229
{"contentBlockDelta": {"delta": {"text": "that for you"}}},
189230
{"contentBlockStop": {}},
@@ -211,9 +252,9 @@ async def test_stream(litellm_acompletion, api_key, model_id, model, agenerator,
211252
{
212253
"metadata": {
213254
"usage": {
214-
"inputTokens": mock_event_6.usage.prompt_tokens,
215-
"outputTokens": mock_event_6.usage.completion_tokens,
216-
"totalTokens": mock_event_6.usage.total_tokens,
255+
"inputTokens": mock_event_9.usage.prompt_tokens,
256+
"outputTokens": mock_event_9.usage.completion_tokens,
257+
"totalTokens": mock_event_9.usage.total_tokens,
217258
},
218259
"metrics": {"latencyMs": 0},
219260
}
@@ -253,8 +294,6 @@ async def test_stream_empty(litellm_acompletion, api_key, model_id, model, agene
253294
tru_events = await alist(response)
254295
exp_events = [
255296
{"messageStart": {"role": "assistant"}},
256-
{"contentBlockStart": {"start": {}}},
257-
{"contentBlockStop": {}},
258297
{"messageStop": {"stopReason": "end_turn"}},
259298
]
260299

tests_integ/models/test_model_litellm.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,22 @@ async def test_agent_stream_async(agent):
121121
assert all(string in text for string in ["12:00", "sunny"])
122122

123123

124+
def test_agent_invoke_reasoning(agent, model):
125+
model.update_config(
126+
params={
127+
"thinking": {
128+
"budget_tokens": 1024,
129+
"type": "enabled",
130+
},
131+
},
132+
)
133+
134+
result = agent("Please reason about the equation 2+2.")
135+
136+
assert "reasoningContent" in result.message["content"][0]
137+
assert result.message["content"][0]["reasoningContent"]["reasoningText"]["text"]
138+
139+
124140
def test_structured_output(agent, weather):
125141
tru_weather = agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny")
126142
exp_weather = weather

0 commit comments

Comments
 (0)