From 2a90c8adb3764fdda0fea573dbc1f8f0a0491b11 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow Date: Sat, 14 Jun 2025 10:56:36 -0400 Subject: [PATCH] chore: Inline event loop helper functions While reading the through the event loop, it made more sense to me to inline the implementation - most of the actual apis are signature + docs, while the actual code is 2 lines each. --- src/strands/event_loop/event_loop.py | 44 +------- tests/strands/event_loop/test_event_loop.py | 107 +++++++++++++------- 2 files changed, 78 insertions(+), 73 deletions(-) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 711659265..bbbdc9395 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -33,23 +33,6 @@ MAX_DELAY = 240 # 4 minutes -def initialize_state(**kwargs: Any) -> Any: - """Initialize the request state if not present. - - Creates an empty request_state dictionary if one doesn't already exist in the - provided keyword arguments. - - Args: - **kwargs: Keyword arguments that may contain a request_state. - - Returns: - The updated kwargs dictionary with request_state initialized if needed. - """ - if "request_state" not in kwargs: - kwargs["request_state"] = {} - return kwargs - - def event_loop_cycle( model: Model, system_prompt: Optional[str], @@ -107,7 +90,8 @@ def event_loop_cycle( event_loop_metrics: EventLoopMetrics = kwargs.get("event_loop_metrics", EventLoopMetrics()) # Initialize state and get cycle trace - kwargs = initialize_state(**kwargs) + if "request_state" not in kwargs: + kwargs["request_state"] = {} cycle_start_time, cycle_trace = event_loop_metrics.start_cycle() kwargs["event_loop_cycle_trace"] = cycle_trace @@ -309,26 +293,6 @@ def recurse_event_loop( ) -def prepare_next_cycle(kwargs: Dict[str, Any], event_loop_metrics: EventLoopMetrics) -> Dict[str, Any]: - """Prepare state for the next event loop cycle. - - Updates the keyword arguments with the current event loop metrics and stores the current cycle ID as the parent - cycle ID for the next cycle. This maintains the parent-child relationship between cycles for tracing and metrics. - - Args: - kwargs: Current keyword arguments containing event loop state. - event_loop_metrics: The metrics object tracking event loop execution. - - Returns: - Updated keyword arguments ready for the next cycle. - """ - # Store parent cycle ID - kwargs["event_loop_metrics"] = event_loop_metrics - kwargs["event_loop_parent_cycle_id"] = kwargs["event_loop_cycle_id"] - - return kwargs - - def _handle_tool_execution( stop_reason: StopReason, message: Message, @@ -402,7 +366,9 @@ def _handle_tool_execution( parallel_tool_executor=tool_execution_handler, ) - kwargs = prepare_next_cycle(kwargs, event_loop_metrics) + # Store parent cycle ID for the next cycle + kwargs["event_loop_metrics"] = event_loop_metrics + kwargs["event_loop_parent_cycle_id"] = kwargs["event_loop_cycle_id"] tool_result_message: Message = { "role": "user", diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 8c46e009b..4849b926a 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -104,27 +104,6 @@ def mock_tracer(): return tracer -@pytest.mark.parametrize( - ("kwargs", "exp_state"), - [ - ( - {"request_state": {"key1": "value1"}}, - {"key1": "value1"}, - ), - ( - {}, - {}, - ), - ], -) -def test_initialize_state(kwargs, exp_state): - kwargs = strands.event_loop.event_loop.initialize_state(**kwargs) - - tru_state = kwargs["request_state"] - - assert tru_state == exp_state - - def test_event_loop_cycle_text_response( model, model_id, @@ -411,19 +390,6 @@ def test_event_loop_cycle_stop( assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state -def test_prepare_next_cycle(): - kwargs = {"event_loop_cycle_id": "c1"} - event_loop_metrics = strands.telemetry.metrics.EventLoopMetrics() - tru_result = strands.event_loop.event_loop.prepare_next_cycle(kwargs, event_loop_metrics) - exp_result = { - "event_loop_cycle_id": "c1", - "event_loop_parent_cycle_id": "c1", - "event_loop_metrics": event_loop_metrics, - } - - assert tru_result == exp_result - - def test_cycle_exception( model, system_prompt, @@ -679,3 +645,76 @@ def test_event_loop_cycle_with_parent_span( mock_tracer.start_event_loop_cycle_span.assert_called_once_with( event_loop_kwargs=unittest.mock.ANY, parent_span=parent_span, messages=messages ) + + +def test_request_state_initialization(): + # Call without providing request_state + tru_stop_reason, tru_message, _, tru_request_state = strands.event_loop.event_loop.event_loop_cycle( + model=MagicMock(), + model_id=MagicMock(), + system_prompt=MagicMock(), + messages=MagicMock(), + tool_config=MagicMock(), + callback_handler=MagicMock(), + tool_handler=MagicMock(), + tool_execution_handler=MagicMock(), + ) + + # Verify request_state was initialized to empty dict + assert tru_request_state == {} + + # Call with pre-existing request_state + initial_request_state = {"key": "value"} + tru_stop_reason, tru_message, _, tru_request_state = strands.event_loop.event_loop.event_loop_cycle( + model=MagicMock(), + model_id=MagicMock(), + system_prompt=MagicMock(), + messages=MagicMock(), + tool_config=MagicMock(), + callback_handler=MagicMock(), + tool_handler=MagicMock(), + request_state=initial_request_state, + ) + + # Verify existing request_state was preserved + assert tru_request_state == initial_request_state + + +def test_prepare_next_cycle_in_tool_execution(model, tool_stream): + """Test that cycle ID and metrics are properly updated during tool execution.""" + model.converse.side_effect = [ + tool_stream, + [ + {"contentBlockStop": {}}, + ], + ] + + # Create a mock for recurse_event_loop to capture the kwargs passed to it + with unittest.mock.patch.object(strands.event_loop.event_loop, "recurse_event_loop") as mock_recurse: + # Set up mock to return a valid response + mock_recurse.return_value = ( + "end_turn", + {"role": "assistant", "content": [{"text": "test text"}]}, + strands.telemetry.metrics.EventLoopMetrics(), + {}, + ) + + # Call event_loop_cycle which should execute a tool and then call recurse_event_loop + strands.event_loop.event_loop.event_loop_cycle( + model=model, + model_id=MagicMock(), + system_prompt=MagicMock(), + messages=MagicMock(), + tool_config=MagicMock(), + callback_handler=MagicMock(), + tool_handler=MagicMock(), + tool_execution_handler=MagicMock(), + ) + + assert mock_recurse.called + + # Verify required properties are present + recursive_kwargs = mock_recurse.call_args[1] + assert "event_loop_metrics" in recursive_kwargs + assert "event_loop_parent_cycle_id" in recursive_kwargs + assert recursive_kwargs["event_loop_parent_cycle_id"] == recursive_kwargs["event_loop_cycle_id"]