118 changes: 118 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,118 @@
import os
import sys
import pytest
from typing import Generator
from unittest.mock import MagicMock
from fastapi.testclient import TestClient

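# Make the repository root importable regardless of where pytest is invoked from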
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from openai_harmony import (
HarmonyEncodingName,
load_harmony_encoding,
)
from gpt_oss.responses_api.api_server import create_api_server


@pytest.fixture(scope="session")
def harmony_encoding():
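    """Load the Harmony encoding once and share it across the whole test session."""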
return load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)


@pytest.fixture
def mock_infer_token(harmony_encoding):
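    """Stub inference that replays a pre-encoded Harmony 'Test response' token stream.

    The queue refills itself once exhausted, so a single fixture instance can
    serve any number of sequential requests deterministically.
    """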
fake_tokens = harmony_encoding.encode(
"<|channel|>final<|message|>Test response<|return|>",
allowed_special="all"
)
token_queue = fake_tokens.copy()

def _mock_infer(tokens: list[int], temperature: float = 0.0, new_request: bool = False) -> int:
nonlocal token_queue
if len(token_queue) == 0:
token_queue = fake_tokens.copy()
return token_queue.pop(0)
return _mock_infer


@pytest.fixture
def api_client(harmony_encoding, mock_infer_token) -> Generator[TestClient, None, None]:
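    """TestClient wired to the Responses API server with mocked inference."""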
app = create_api_server(
infer_next_token=mock_infer_token,
encoding=harmony_encoding
)
with TestClient(app) as client:
yield client


@pytest.fixture
def sample_request_data():
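    """Baseline /v1/responses payload that individual tests copy and mutate."""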
return {
"model": "gpt-oss-120b",
"input": "Hello, how can I help you today?",
"stream": False,
"reasoning_effort": "low",
"temperature": 0.7,
"tools": []
}


@pytest.fixture
def mock_browser_tool():
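    """Browser-tool double returning canned search/open/find results."""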
mock = MagicMock()
mock.search.return_value = ["Result 1", "Result 2"]
mock.open_page.return_value = "Page content"
mock.find_on_page.return_value = "Found text"
return mock


@pytest.fixture
def mock_python_tool():
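    """Python-tool double returning a canned successful execution result."""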
mock = MagicMock()
    mock.execute.return_value = {
        "output": "Hello\n",  # stdout produced by the executed code, not the source itself
        "error": None,
        "exit_code": 0
    }
return mock


@pytest.fixture(autouse=True)
def reset_test_environment():
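    """Strip API-related environment variables for each test, then restore them."""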
test_env_vars = ['OPENAI_API_KEY', 'GPT_OSS_MODEL_PATH']
original_values = {}

for var in test_env_vars:
if var in os.environ:
original_values[var] = os.environ[var]
del os.environ[var]

yield

for var, value in original_values.items():
os.environ[var] = value


@pytest.fixture
def performance_timer():
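    """Simple monotonic timer for coarse performance assertions."""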
import time

class Timer:
def __init__(self):
self.start_time = None
self.end_time = None

        def start(self):
            self.start_time = time.perf_counter()

        def stop(self):
            self.end_time = time.perf_counter()
            return self.elapsed

        @property
        def elapsed(self):
            # Compare against None explicitly: 0.0 is falsy but a valid timestamp.
            if self.start_time is not None and self.end_time is not None:
                return self.end_time - self.start_time
            return None

return Timer()
230 changes: 230 additions & 0 deletions tests/test_api_endpoints.py
@@ -0,0 +1,230 @@
import json

from fastapi import status


class TestResponsesEndpoint:
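    """Happy-path coverage for POST /v1/responses."""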

def test_basic_response_creation(self, api_client, sample_request_data):
response = api_client.post("/v1/responses", json=sample_request_data)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert "id" in data
assert data["object"] == "response"
assert data["model"] == sample_request_data["model"]

def test_response_with_high_reasoning(self, api_client, sample_request_data):
sample_request_data["reasoning_effort"] = "high"
response = api_client.post("/v1/responses", json=sample_request_data)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert "id" in data
assert data["status"] == "completed"

def test_response_with_medium_reasoning(self, api_client, sample_request_data):
sample_request_data["reasoning_effort"] = "medium"
response = api_client.post("/v1/responses", json=sample_request_data)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert "id" in data
assert data["status"] == "completed"

def test_response_with_invalid_model(self, api_client, sample_request_data):
sample_request_data["model"] = "invalid-model"
response = api_client.post("/v1/responses", json=sample_request_data)
        # The mock server does not validate model names, so the request still succeeds
assert response.status_code == status.HTTP_200_OK

def test_response_with_empty_input(self, api_client, sample_request_data):
sample_request_data["input"] = ""
response = api_client.post("/v1/responses", json=sample_request_data)
assert response.status_code == status.HTTP_200_OK

def test_response_with_tools(self, api_client, sample_request_data):
sample_request_data["tools"] = [
{
"type": "browser_search"
}
]
response = api_client.post("/v1/responses", json=sample_request_data)
assert response.status_code == status.HTTP_200_OK

def test_response_with_custom_temperature(self, api_client, sample_request_data):
for temp in [0.0, 0.5, 1.0, 1.5, 2.0]:
sample_request_data["temperature"] = temp
response = api_client.post("/v1/responses", json=sample_request_data)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert "usage" in data

def test_streaming_response(self, api_client, sample_request_data):
sample_request_data["stream"] = True
with api_client.stream("POST", "/v1/responses", json=sample_request_data) as response:
assert response.status_code == status.HTTP_200_OK
# Verify we get SSE events
for line in response.iter_lines():
if line and line.startswith("data: "):
event_data = line[6:] # Remove "data: " prefix
if event_data != "[DONE]":
json.loads(event_data) # Should be valid JSON
break


class TestResponsesWithSession:
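    """Session reuse and response continuation behavior."""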

def test_response_with_session_id(self, api_client, sample_request_data):
session_id = "test-session-123"
sample_request_data["session_id"] = session_id

# First request
response1 = api_client.post("/v1/responses", json=sample_request_data)
assert response1.status_code == status.HTTP_200_OK
data1 = response1.json()

# Second request with same session
sample_request_data["input"] = "Follow up question"
response2 = api_client.post("/v1/responses", json=sample_request_data)
assert response2.status_code == status.HTTP_200_OK
data2 = response2.json()

# Should have different response IDs
assert data1["id"] != data2["id"]

def test_response_continuation(self, api_client, sample_request_data):
# Create initial response
response1 = api_client.post("/v1/responses", json=sample_request_data)
assert response1.status_code == status.HTTP_200_OK
data1 = response1.json()
response_id = data1["id"]

# Continue the response
continuation_request = {
"model": sample_request_data["model"],
"response_id": response_id,
"input": "Continue the previous thought"
}
response2 = api_client.post("/v1/responses", json=continuation_request)
assert response2.status_code == status.HTTP_200_OK


class TestErrorHandling:
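    """Validation and malformed-input handling."""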

def test_missing_required_fields(self, api_client):
        # "model" has a default, so an empty body fails validation on the remaining required fields
response = api_client.post("/v1/responses", json={})
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY

def test_invalid_reasoning_effort(self, api_client, sample_request_data):
sample_request_data["reasoning_effort"] = "invalid"
response = api_client.post("/v1/responses", json=sample_request_data)
        # The server may coerce the invalid value or reject the request outright
assert response.status_code in [status.HTTP_200_OK, status.HTTP_422_UNPROCESSABLE_ENTITY]

def test_malformed_json(self, api_client):
response = api_client.post(
"/v1/responses",
data="not json",
headers={"Content-Type": "application/json"}
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY

def test_extremely_long_input(self, api_client, sample_request_data):
# Test with very long input
sample_request_data["input"] = "x" * 100000
response = api_client.post("/v1/responses", json=sample_request_data)
assert response.status_code == status.HTTP_200_OK


class TestToolIntegration:
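    """Requests that attach browser_search and function tools."""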

def test_browser_search_tool(self, api_client, sample_request_data):
sample_request_data["tools"] = [
{
"type": "browser_search"
}
]
response = api_client.post("/v1/responses", json=sample_request_data)
assert response.status_code == status.HTTP_200_OK

def test_function_tool_integration(self, api_client, sample_request_data):
sample_request_data["tools"] = [
{
"type": "function",
"name": "test_function",
"parameters": {"type": "object", "properties": {}},
"description": "Test function"
}
]
response = api_client.post("/v1/responses", json=sample_request_data)
assert response.status_code == status.HTTP_200_OK

def test_multiple_tools(self, api_client, sample_request_data):
sample_request_data["tools"] = [
{
"type": "browser_search"
},
{
"type": "function",
"name": "test_function",
"parameters": {"type": "object", "properties": {}},
"description": "Test function"
}
]
response = api_client.post("/v1/responses", json=sample_request_data)
assert response.status_code == status.HTTP_200_OK


class TestPerformance:
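    """Latency and repeated-request sanity checks against the mock server."""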

def test_response_time_under_threshold(self, api_client, sample_request_data, performance_timer):
performance_timer.start()
response = api_client.post("/v1/responses", json=sample_request_data)
elapsed = performance_timer.stop()

assert response.status_code == status.HTTP_200_OK
# Response should be reasonably fast for mock inference
assert elapsed < 5.0 # 5 seconds threshold

def test_multiple_sequential_requests(self, api_client, sample_request_data):
# Test multiple requests work correctly
for i in range(3):
data = sample_request_data.copy()
data["input"] = f"Request {i}"
response = api_client.post("/v1/responses", json=data)
assert response.status_code == status.HTTP_200_OK


class TestUsageTracking:
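    """Shape and arithmetic of the usage object."""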

def test_usage_object_structure(self, api_client, sample_request_data):
response = api_client.post("/v1/responses", json=sample_request_data)
assert response.status_code == status.HTTP_200_OK
data = response.json()

assert "usage" in data
usage = data["usage"]
assert "input_tokens" in usage
assert "output_tokens" in usage
assert "total_tokens" in usage
# reasoning_tokens may not always be present
# assert "reasoning_tokens" in usage

# Basic validation
assert usage["input_tokens"] >= 0
assert usage["output_tokens"] >= 0
assert usage["total_tokens"] == usage["input_tokens"] + usage["output_tokens"]

def test_usage_increases_with_longer_input(self, api_client, sample_request_data):
# Short input
response1 = api_client.post("/v1/responses", json=sample_request_data)
usage1 = response1.json()["usage"]

# Longer input
sample_request_data["input"] = sample_request_data["input"] * 10
response2 = api_client.post("/v1/responses", json=sample_request_data)
usage2 = response2.json()["usage"]

# Longer input should use more tokens
assert usage2["input_tokens"] > usage1["input_tokens"]