Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 141 additions & 1 deletion tests/unit/backends/test_openai_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

from __future__ import annotations

from unittest.mock import Mock, patch
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock, Mock, patch

import httpx
import pytest
Expand Down Expand Up @@ -33,6 +35,46 @@ def test_usage_metrics():
assert metrics_with_values.text_characters == 50


class FakeHandler:
def __init__(self):
self.lines = []

def add_streaming_line(self, chunk):
if chunk == "[END]":
return None
self.lines.append(chunk)
return 1

def compile_streaming(self, request):
return {"compiled": "".join(self.lines)}

def compile_non_streaming(self, request, data):
return {"non_streamed": data}


def make_request(stream=True, body=None):
args = SimpleNamespace(
stream=stream,
method=None,
params=None,
headers=None,
body=body,
files=None,
)
return SimpleNamespace(request_type="chat_completions", arguments=args)


def make_request_info():
timings = SimpleNamespace(
request_start=None,
request_end=None,
first_iteration=None,
last_iteration=None,
iterations=None,
)
return SimpleNamespace(timings=timings)


class TestOpenAIHTTPBackend:
"""Test cases for OpenAIHTTPBackend."""

Expand Down Expand Up @@ -467,3 +509,101 @@ async def test_resolve_chat_completions(self):
assert len(responses) == 1
final_response = responses[0][0]
assert final_response.request_id == "test-id"

@pytest.mark.regression
@pytest.mark.asyncio
@async_timeout(10.0)
async def test_resolve_chat_completions_streaming(self):
backend = OpenAIHTTPBackend(target="http://test")
fake_handler = FakeHandler()

lines = ["hello\n", " world\n", "[END]"]

async def aiter_lines_gen():
for line in lines:
await asyncio.sleep(0)
yield line

stream_obj = SimpleNamespace(
raise_for_status=lambda: None,
aiter_lines=aiter_lines_gen,
)

mock_cm = AsyncMock()
mock_cm.__aenter__.return_value = stream_obj
mock_cm.__aexit__.return_value = None

mock_stream = Mock(return_value=mock_cm)

with (
patch.object(
backend, "_resolve_response_handler", lambda request_type: fake_handler
),
patch.object(backend, "_async_client", SimpleNamespace(stream=mock_stream)),
):
request = make_request(stream=True)
request_info = make_request_info()

results = []
async for resp, info in backend.resolve(request, request_info):
results.append((resp, info))

mock_stream.assert_called_once() # ensure stream was called
# one final compiled streaming result expected
assert len(results) == 1
compiled, returned_info = results[0]
assert compiled == {
"compiled": "hello\n world\n"
} # [END] treated as end marker
assert returned_info is request_info

# timings should be recorded
assert returned_info.timings.request_start is not None
assert returned_info.timings.request_end is not None
assert returned_info.timings.first_iteration is not None
assert returned_info.timings.last_iteration is not None
assert (
returned_info.timings.iterations == 2
) # two real iterations from our handler

@pytest.mark.regression
@pytest.mark.asyncio
@async_timeout(10.0)
async def test_resolve_chat_completions_streaming_cancelled(self):
backend = OpenAIHTTPBackend(target="http://test")
fake_handler = FakeHandler()

async def aiter_lines_gen():
await asyncio.sleep(0)
yield "partial\n"
await asyncio.sleep(0)
raise asyncio.CancelledError

stream_obj = SimpleNamespace(
raise_for_status=lambda: None,
aiter_lines=aiter_lines_gen,
)

mock_cm = AsyncMock()
mock_cm.__aenter__.return_value = stream_obj
mock_cm.__aexit__.return_value = None

mock_stream = Mock(return_value=mock_cm)

with (
patch.object(
backend, "_resolve_response_handler", lambda request_type: fake_handler
),
patch.object(backend, "_async_client", SimpleNamespace(stream=mock_stream)),
):
request = make_request(stream=True)
request_info = make_request_info()

agen = backend.resolve(request, request_info).__aiter__()

first = await agen.__anext__()
compiled, info = first
assert compiled == {"compiled": "partial\n"}

with pytest.raises(asyncio.CancelledError):
await agen.__anext__()
Loading