Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 56 additions & 116 deletions tests/test_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ def test_hallucinations():
assert chunk.get("content") == "22"
break

code = """{
"language": "python",
"code": "10+12"
code = """{
"language": "python",
"code": "10+12"
}"""

interpreter.messages = [
Expand All @@ -50,9 +50,9 @@ def test_hallucinations():
assert chunk.get("content") == "22"
break

code = """functions.execute({
"language": "python",
"code": "10+12"
code = """functions.execute({
"language": "python",
"code": "10+12"
})"""

interpreter.messages = [
Expand Down Expand Up @@ -236,6 +236,37 @@ def run_server():
async_interpreter.server.run()


async def wait_for_websocket_complete(websocket, max_attempts=5):
"""Wait for WebSocket 'complete' status message with retry limit."""

import asyncio
import json

accumulated_content = ""

for attempt in range(1, max_attempts + 1):
try:
message = await websocket.recv()
message_data = json.loads(message)
if "error" in message_data:
raise Exception(message_data["content"])
print("Received from WebSocket:", message_data)
if type(message_data.get("content")) == str:
accumulated_content += message_data.get("content")
if message_data == {
"role": "server",
"type": "status",
"content": "complete",
}:
print("Received expected message from server")
return accumulated_content
except Exception as e:
print(f"WebSocket receive failed (attempt {attempt}/{max_attempts}): {e}")
await asyncio.sleep(1)
else:
raise Exception(f"Never received 'complete' status after {max_attempts} attempts")


# @pytest.mark.skip(reason="Requires uvicorn, which we don't require by default")
def test_server():
# Start the server in a new process
Expand Down Expand Up @@ -299,22 +330,7 @@ async def test_fastapi_server():
print("WebSocket chunks sent")

# Wait for a specific response
accumulated_content = ""
while True:
message = await websocket.recv()
message_data = json.loads(message)
if "error" in message_data:
raise Exception(message_data["content"])
print("Received from WebSocket:", message_data)
if type(message_data.get("content")) == str:
accumulated_content += message_data.get("content")
if message_data == {
"role": "server",
"type": "status",
"content": "complete",
}:
print("Received expected message from server")
break
accumulated_content = await wait_for_websocket_complete(websocket)

assert "crunk" in accumulated_content

Expand Down Expand Up @@ -355,22 +371,7 @@ async def test_fastapi_server():
print("WebSocket chunks sent")

# Wait for a specific response
accumulated_content = ""
while True:
message = await websocket.recv()
message_data = json.loads(message)
if "error" in message_data:
raise Exception(message_data["content"])
print("Received from WebSocket:", message_data)
if message_data.get("content"):
accumulated_content += message_data.get("content")
if message_data == {
"role": "server",
"type": "status",
"content": "complete",
}:
print("Received expected message from server")
break
accumulated_content = await wait_for_websocket_complete(websocket)

assert "barloney" in accumulated_content

Expand Down Expand Up @@ -404,22 +405,7 @@ async def test_fastapi_server():
print("WebSocket chunks sent")

# Wait for response
accumulated_content = ""
while True:
message = await websocket.recv()
message_data = json.loads(message)
if "error" in message_data:
raise Exception(message_data["content"])
print("Received from WebSocket:", message_data)
if message_data.get("content"):
accumulated_content += message_data.get("content")
if message_data == {
"role": "server",
"type": "status",
"content": "complete",
}:
print("Received expected message from server")
break
accumulated_content = await wait_for_websocket_complete(websocket)

time.sleep(5)

Expand Down Expand Up @@ -454,23 +440,7 @@ async def test_fastapi_server():
)

# Wait for a specific response
accumulated_content = ""
while True:
message = await websocket.recv()
message_data = json.loads(message)
if "error" in message_data:
raise Exception(message_data["content"])
print("Received from WebSocket:", message_data)
if message_data.get("content"):
if type(message_data.get("content")) == str:
accumulated_content += message_data.get("content")
if message_data == {
"role": "server",
"type": "status",
"content": "complete",
}:
print("Received expected message from server")
break
accumulated_content = await wait_for_websocket_complete(websocket)

assert "18893094989" in accumulated_content.replace(",", "")

Expand Down Expand Up @@ -525,22 +495,7 @@ async def test_fastapi_server():
print("WebSocket chunks sent")

# Wait for response
accumulated_content = ""
while True:
message = await websocket.recv()
message_data = json.loads(message)
if "error" in message_data:
raise Exception(message_data["content"])
print("Received from WebSocket:", message_data)
if type(message_data.get("content")) == str:
accumulated_content += message_data.get("content")
if message_data == {
"role": "server",
"type": "status",
"content": "complete",
}:
print("Received expected message from server")
break
accumulated_content = await wait_for_websocket_complete(websocket)

# Get messages
get_url = "http://localhost:8000/settings/messages"
Expand Down Expand Up @@ -602,22 +557,7 @@ async def test_fastapi_server():
print("WebSocket chunks sent")

# Wait for response
accumulated_content = ""
while True:
message = await websocket.recv()
message_data = json.loads(message)
if "error" in message_data:
raise Exception(message_data["content"])
print("Received from WebSocket:", message_data)
if type(message_data.get("content")) == str:
accumulated_content += message_data.get("content")
if message_data == {
"role": "server",
"type": "status",
"content": "complete",
}:
print("Received expected message from server")
break
accumulated_content = await wait_for_websocket_complete(websocket)

# Get messages
get_url = "http://localhost:8000/settings/messages"
Expand Down Expand Up @@ -1198,7 +1138,7 @@ def test_math():

order_of_operations_message = f"""
Please perform the calculation `{n1} + {n2} * ({n1} - {n2}) / ({n2} + {n1})` then reply with just the answer, nothing else. No confirmation. No explanation. No words. Do not use commas. Do not show your work. Just return the result of the calculation. Do not introduce the results with a phrase like \"The result of the calculation is...\" or \"The answer is...\"

Round to 2 decimal places.
""".strip()

Expand All @@ -1215,20 +1155,20 @@ def test_break_execution():
"""

code = r"""print("starting")
import time
import os
import time
import os

# Always create a fresh file
open('numbers.txt', 'w').close()
# Open the file in append mode
with open('numbers.txt', 'a+') as f:
# Loop through the numbers 1 to 5
for i in [1,2,3,4,5]:
# Print the number
print("adding", i, "to file")
# Append the number to the file
f.write(str(i) + '\n')

# Open the file in append mode
with open('numbers.txt', 'a+') as f:
# Loop through the numbers 1 to 5
for i in [1,2,3,4,5]:
# Print the number
print("adding", i, "to file")
# Append the number to the file
f.write(str(i) + '\n')
# Wait for 0.5 second
print("starting to sleep")
time.sleep(1)
Expand Down
Loading