Merged
47 changes: 41 additions & 6 deletions src/strands/models/openai.py
@@ -15,6 +15,7 @@
from typing_extensions import Unpack, override

from ..types.content import ContentBlock, Messages
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
from ..types.streaming import StreamEvent
from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse
from ._validation import validate_config_keys
@@ -372,6 +373,10 @@ async def stream(

Yields:
Formatted message chunks from the model.

Raises:
ContextWindowOverflowException: If the input exceeds the model's context window.
ModelThrottledException: If the request is throttled by OpenAI (rate limits).
"""
logger.debug("formatting request")
request = self.format_request(messages, tool_specs, system_prompt, tool_choice)
@@ -383,7 +388,20 @@
# client. The asyncio event loop does not allow connections to be shared. For more details, please refer to
# https://github.com/encode/httpx/discussions/2959.
async with openai.AsyncOpenAI(**self.client_args) as client:
response = await client.chat.completions.create(**request)
try:
response = await client.chat.completions.create(**request)
except openai.BadRequestError as e:
# Check if this is a context length exceeded error
if hasattr(e, "code") and e.code == "context_length_exceeded":
logger.warning("OpenAI threw context window overflow error")
raise ContextWindowOverflowException(str(e)) from e
# Re-raise other BadRequestError exceptions
raise
except openai.RateLimitError as e:
# All rate limit errors should be treated as throttling, not context overflow
# Rate limits (including TPM) require waiting/retrying, not context reduction
logger.warning("OpenAI threw rate limit error")
raise ModelThrottledException(str(e)) from e

logger.debug("got response from model")
yield self.format_chunk({"chunk_type": "message_start"})
@@ -452,16 +470,33 @@ async def structured_output(

Yields:
Model events with the last being the structured output.

Raises:
ContextWindowOverflowException: If the input exceeds the model's context window.
ModelThrottledException: If the request is throttled by OpenAI (rate limits).
"""
# We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying httpx
# client. The asyncio event loop does not allow connections to be shared. For more details, please refer to
# https://github.com/encode/httpx/discussions/2959.
async with openai.AsyncOpenAI(**self.client_args) as client:
response: ParsedChatCompletion = await client.beta.chat.completions.parse(
model=self.get_config()["model_id"],
messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
response_format=output_model,
)
try:
response: ParsedChatCompletion = await client.beta.chat.completions.parse(
model=self.get_config()["model_id"],
messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
response_format=output_model,
)
except openai.BadRequestError as e:
# Check if this is a context length exceeded error
if hasattr(e, "code") and e.code == "context_length_exceeded":
logger.warning("OpenAI threw context window overflow error")
raise ContextWindowOverflowException(str(e)) from e
# Re-raise other BadRequestError exceptions
raise
except openai.RateLimitError as e:
# All rate limit errors should be treated as throttling, not context overflow
# Rate limits (including TPM) require waiting/retrying, not context reduction
logger.warning("OpenAI threw rate limit error")
raise ModelThrottledException(str(e)) from e

parsed: T | None = None
# Find the first choice with tool_calls
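
The change above maps OpenAI's BadRequestError with code context_length_exceeded and its RateLimitError onto the SDK's provider-agnostic exceptions. Below is a minimal caller-side sketch of how the two are usually handled differently; the stream_with_handling helper, the attempt count, and the backoff values are illustrative assumptions rather than part of this diff:

import asyncio

from strands.models.openai import OpenAIModel
from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException


async def stream_with_handling(model: OpenAIModel, messages) -> None:
    for attempt in range(3):
        try:
            async for _event in model.stream(messages):
                pass  # consume formatted stream events
            return
        except ContextWindowOverflowException:
            # Retrying an oversized request cannot succeed; the conversation has
            # to be reduced (or the error surfaced to the caller) instead.
            raise
        except ModelThrottledException:
            if attempt == 2:
                raise
            # Rate limited (requests or tokens per minute): back off and retry.
            await asyncio.sleep(2**attempt)
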
148 changes: 148 additions & 0 deletions tests/strands/models/test_openai.py
@@ -1,10 +1,12 @@
import unittest.mock

import openai
import pydantic
import pytest

import strands
from strands.models.openai import OpenAIModel
from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException


@pytest.fixture
@@ -752,3 +754,149 @@ def test_tool_choice_none_no_warning(model, messages, captured_warnings):
model.format_request(messages, tool_choice=None)

assert len(captured_warnings) == 0


@pytest.mark.asyncio
async def test_stream_context_overflow_exception(openai_client, model, messages):
"""Test that OpenAI context overflow errors are properly converted to ContextWindowOverflowException."""
# Create a mock OpenAI BadRequestError with context_length_exceeded code
mock_error = openai.BadRequestError(
message="This model's maximum context length is 4096 tokens. However, your messages resulted in 5000 tokens.",
response=unittest.mock.MagicMock(),
body={"error": {"code": "context_length_exceeded"}},
)
mock_error.code = "context_length_exceeded"

# Configure the mock client to raise the context overflow error
openai_client.chat.completions.create.side_effect = mock_error

# Test that the stream method converts the error properly
with pytest.raises(ContextWindowOverflowException) as exc_info:
async for _ in model.stream(messages):
pass

# Verify the exception message contains the original error
assert "maximum context length" in str(exc_info.value)
assert exc_info.value.__cause__ == mock_error


@pytest.mark.asyncio
async def test_stream_other_bad_request_errors_passthrough(openai_client, model, messages):
"""Test that other BadRequestError exceptions are not converted to ContextWindowOverflowException."""
# Create a mock OpenAI BadRequestError with a different error code
mock_error = openai.BadRequestError(
message="Invalid parameter value",
response=unittest.mock.MagicMock(),
body={"error": {"code": "invalid_parameter"}},
)
mock_error.code = "invalid_parameter"

# Configure the mock client to raise the non-context error
openai_client.chat.completions.create.side_effect = mock_error

# Test that other BadRequestError exceptions pass through unchanged
with pytest.raises(openai.BadRequestError) as exc_info:
async for _ in model.stream(messages):
pass

# Verify the original exception is raised, not ContextWindowOverflowException
assert exc_info.value == mock_error


@pytest.mark.asyncio
async def test_structured_output_context_overflow_exception(openai_client, model, messages, test_output_model_cls):
"""Test that structured output also handles context overflow properly."""
# Create a mock OpenAI BadRequestError with context_length_exceeded code
mock_error = openai.BadRequestError(
message="This model's maximum context length is 4096 tokens. However, your messages resulted in 5000 tokens.",
response=unittest.mock.MagicMock(),
body={"error": {"code": "context_length_exceeded"}},
)
mock_error.code = "context_length_exceeded"

# Configure the mock client to raise the context overflow error
openai_client.beta.chat.completions.parse.side_effect = mock_error

# Test that the structured_output method converts the error properly
with pytest.raises(ContextWindowOverflowException) as exc_info:
async for _ in model.structured_output(test_output_model_cls, messages):
pass

# Verify the exception message contains the original error
assert "maximum context length" in str(exc_info.value)
assert exc_info.value.__cause__ == mock_error


@pytest.mark.asyncio
async def test_stream_rate_limit_as_throttle(openai_client, model, messages):
"""Test that all rate limit errors are converted to ModelThrottledException."""

# Create a mock OpenAI RateLimitError (any type of rate limit)
mock_error = openai.RateLimitError(
message="Request too large for gpt-4o on tokens per min (TPM): Limit 30000, Requested 117505.",
response=unittest.mock.MagicMock(),
body={"error": {"code": "rate_limit_exceeded"}},
)
mock_error.code = "rate_limit_exceeded"

# Configure the mock client to raise the rate limit error
openai_client.chat.completions.create.side_effect = mock_error

# Test that the stream method converts the error properly
with pytest.raises(ModelThrottledException) as exc_info:
async for _ in model.stream(messages):
pass

# Verify the exception message contains the original error
assert "tokens per min" in str(exc_info.value)
assert exc_info.value.__cause__ == mock_error


@pytest.mark.asyncio
async def test_stream_request_rate_limit_as_throttle(openai_client, model, messages):
"""Test that request-based rate limit errors are converted to ModelThrottledException."""

# Create a mock OpenAI RateLimitError for request-based rate limiting
mock_error = openai.RateLimitError(
message="Rate limit reached for requests per minute.",
response=unittest.mock.MagicMock(),
body={"error": {"code": "rate_limit_exceeded"}},
)
mock_error.code = "rate_limit_exceeded"

# Configure the mock client to raise the request rate limit error
openai_client.chat.completions.create.side_effect = mock_error

# Test that the stream method converts the error properly
with pytest.raises(ModelThrottledException) as exc_info:
async for _ in model.stream(messages):
pass

# Verify the exception message contains the original error
assert "Rate limit reached" in str(exc_info.value)
assert exc_info.value.__cause__ == mock_error


@pytest.mark.asyncio
async def test_structured_output_rate_limit_as_throttle(openai_client, model, messages, test_output_model_cls):
"""Test that structured output handles rate limit errors properly."""

# Create a mock OpenAI RateLimitError
mock_error = openai.RateLimitError(
message="Request too large for gpt-4o on tokens per min (TPM): Limit 30000, Requested 117505.",
response=unittest.mock.MagicMock(),
body={"error": {"code": "rate_limit_exceeded"}},
)
mock_error.code = "rate_limit_exceeded"

# Configure the mock client to raise the rate limit error
openai_client.beta.chat.completions.parse.side_effect = mock_error

# Test that the structured_output method converts the error properly
with pytest.raises(ModelThrottledException) as exc_info:
async for _ in model.structured_output(test_output_model_cls, messages):
pass

# Verify the exception message contains the original error
assert "tokens per min" in str(exc_info.value)
assert exc_info.value.__cause__ == mock_error
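
The unit tests above rely on the module's pre-existing openai_client fixture, which is not shown in this diff. A rough, hypothetical sketch of such a fixture follows (the real fixture in tests/strands/models/test_openai.py may differ); it patches openai.AsyncOpenAI so the mocked client is what the model gets back from its async with block:

import unittest.mock

import pytest


@pytest.fixture
def openai_client():
    # Hypothetical sketch only; the actual fixture in this test module may differ.
    client = unittest.mock.AsyncMock()
    with unittest.mock.patch("strands.models.openai.openai.AsyncOpenAI") as mock_cls:
        # OpenAIModel opens the client via `async with openai.AsyncOpenAI(...)`,
        # so the patched class must hand the mock back from __aenter__.
        mock_cls.return_value.__aenter__.return_value = client
        mock_cls.return_value.__aexit__.return_value = None
        yield client
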
54 changes: 54 additions & 0 deletions tests_integ/models/test_model_openai.py
@@ -1,11 +1,13 @@
import os
import unittest.mock

import pydantic
import pytest

import strands
from strands import Agent, tool
from strands.models.openai import OpenAIModel
from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException
from tests_integ.models import providers

# these tests only run if we have the openai api key
@@ -167,3 +169,55 @@ def tool_with_image_return():
# 'user', but this message with role 'tool' contains an image URL."
# See https://github.com/strands-agents/sdk-python/issues/320 for additional details
agent("Run the the tool and analyze the image")


def test_context_window_overflow_integration():
"""Integration test for context window overflow with OpenAI.

This test verifies that when a request exceeds the model's context window,
the OpenAI model properly raises a ContextWindowOverflowException.
"""
# Use gpt-4o-mini which has a smaller context window to make this test more reliable
mini_model = OpenAIModel(
model_id="gpt-4o-mini-2024-07-18",
client_args={
"api_key": os.getenv("OPENAI_API_KEY"),
},
)

agent = Agent(model=mini_model)

# Create a very long text that should exceed context window
# This text is designed to be long enough to exceed context but not hit token rate limits
long_text = (
"This text is longer than context window, but short enough to not get caught in token rate limit. " * 6800
)

# This should raise ContextWindowOverflowException which gets handled by conversation manager
# The agent should attempt to reduce context and retry
with pytest.raises(ContextWindowOverflowException):
agent(long_text)


def test_rate_limit_throttling_integration_no_retries(model):
"""Integration test for rate limit handling with retries disabled.

This test verifies that when a request exceeds OpenAI's rate limits,
the model properly raises a ModelThrottledException. We disable retries
to avoid waiting for the exponential backoff during testing.
"""
# Patch the event loop constants to disable retries for this test
with unittest.mock.patch("strands.event_loop.event_loop.MAX_ATTEMPTS", 1):
agent = Agent(model=model)

# Create a message that's very long to trigger token-per-minute rate limits
# This should be large enough to exceed TPM limits immediately
very_long_text = "Really long text " * 20000

# This should raise ModelThrottledException without retries
with pytest.raises(ModelThrottledException) as exc_info:
agent(very_long_text)

# Verify it's a rate limit error
error_message = str(exc_info.value).lower()
assert "rate limit" in error_message or "tokens per min" in error_message