diff --git a/tests/test_interpreter.py b/tests/test_interpreter.py index cf71e95863..eaae40ab1d 100644 --- a/tests/test_interpreter.py +++ b/tests/test_interpreter.py @@ -37,9 +37,9 @@ def test_hallucinations(): assert chunk.get("content") == "22" break - code = """{ - "language": "python", - "code": "10+12" + code = """{ + "language": "python", + "code": "10+12" }""" interpreter.messages = [ @@ -50,9 +50,9 @@ def test_hallucinations(): assert chunk.get("content") == "22" break - code = """functions.execute({ - "language": "python", - "code": "10+12" + code = """functions.execute({ + "language": "python", + "code": "10+12" })""" interpreter.messages = [ @@ -236,6 +236,37 @@ def run_server(): async_interpreter.server.run() +async def wait_for_websocket_complete(websocket, max_attempts=5): + """Wait for WebSocket 'complete' status message with retry limit.""" + + import asyncio + import json + + accumulated_content = "" + + for attempt in range(1, max_attempts + 1): + try: + message = await websocket.recv() + message_data = json.loads(message) + if "error" in message_data: + raise Exception(message_data["content"]) + print("Received from WebSocket:", message_data) + if type(message_data.get("content")) == str: + accumulated_content += message_data.get("content") + if message_data == { + "role": "server", + "type": "status", + "content": "complete", + }: + print("Received expected message from server") + return accumulated_content + except Exception as e: + print(f"WebSocket receive failed (attempt {attempt}/{max_attempts}): {e}") + await asyncio.sleep(1) + else: + raise Exception(f"Never received 'complete' status after {max_attempts} attempts") + + # @pytest.mark.skip(reason="Requires uvicorn, which we don't require by default") def test_server(): # Start the server in a new process @@ -299,22 +330,7 @@ async def test_fastapi_server(): print("WebSocket chunks sent") # Wait for a specific response - accumulated_content = "" - while True: - message = await websocket.recv() - message_data = json.loads(message) - if "error" in message_data: - raise Exception(message_data["content"]) - print("Received from WebSocket:", message_data) - if type(message_data.get("content")) == str: - accumulated_content += message_data.get("content") - if message_data == { - "role": "server", - "type": "status", - "content": "complete", - }: - print("Received expected message from server") - break + accumulated_content = await wait_for_websocket_complete(websocket) assert "crunk" in accumulated_content @@ -355,22 +371,7 @@ async def test_fastapi_server(): print("WebSocket chunks sent") # Wait for a specific response - accumulated_content = "" - while True: - message = await websocket.recv() - message_data = json.loads(message) - if "error" in message_data: - raise Exception(message_data["content"]) - print("Received from WebSocket:", message_data) - if message_data.get("content"): - accumulated_content += message_data.get("content") - if message_data == { - "role": "server", - "type": "status", - "content": "complete", - }: - print("Received expected message from server") - break + accumulated_content = await wait_for_websocket_complete(websocket) assert "barloney" in accumulated_content @@ -404,22 +405,7 @@ async def test_fastapi_server(): print("WebSocket chunks sent") # Wait for response - accumulated_content = "" - while True: - message = await websocket.recv() - message_data = json.loads(message) - if "error" in message_data: - raise Exception(message_data["content"]) - print("Received from WebSocket:", message_data) - if message_data.get("content"): - accumulated_content += message_data.get("content") - if message_data == { - "role": "server", - "type": "status", - "content": "complete", - }: - print("Received expected message from server") - break + accumulated_content = await wait_for_websocket_complete(websocket) time.sleep(5) @@ -454,23 +440,7 @@ async def test_fastapi_server(): ) # Wait for a specific response - accumulated_content = "" - while True: - message = await websocket.recv() - message_data = json.loads(message) - if "error" in message_data: - raise Exception(message_data["content"]) - print("Received from WebSocket:", message_data) - if message_data.get("content"): - if type(message_data.get("content")) == str: - accumulated_content += message_data.get("content") - if message_data == { - "role": "server", - "type": "status", - "content": "complete", - }: - print("Received expected message from server") - break + accumulated_content = await wait_for_websocket_complete(websocket) assert "18893094989" in accumulated_content.replace(",", "") @@ -525,22 +495,7 @@ async def test_fastapi_server(): print("WebSocket chunks sent") # Wait for response - accumulated_content = "" - while True: - message = await websocket.recv() - message_data = json.loads(message) - if "error" in message_data: - raise Exception(message_data["content"]) - print("Received from WebSocket:", message_data) - if type(message_data.get("content")) == str: - accumulated_content += message_data.get("content") - if message_data == { - "role": "server", - "type": "status", - "content": "complete", - }: - print("Received expected message from server") - break + accumulated_content = await wait_for_websocket_complete(websocket) # Get messages get_url = "http://localhost:8000/settings/messages" @@ -602,22 +557,7 @@ async def test_fastapi_server(): print("WebSocket chunks sent") # Wait for response - accumulated_content = "" - while True: - message = await websocket.recv() - message_data = json.loads(message) - if "error" in message_data: - raise Exception(message_data["content"]) - print("Received from WebSocket:", message_data) - if type(message_data.get("content")) == str: - accumulated_content += message_data.get("content") - if message_data == { - "role": "server", - "type": "status", - "content": "complete", - }: - print("Received expected message from server") - break + accumulated_content = await wait_for_websocket_complete(websocket) # Get messages get_url = "http://localhost:8000/settings/messages" @@ -1198,7 +1138,7 @@ def test_math(): order_of_operations_message = f""" Please perform the calculation `{n1} + {n2} * ({n1} - {n2}) / ({n2} + {n1})` then reply with just the answer, nothing else. No confirmation. No explanation. No words. Do not use commas. Do not show your work. Just return the result of the calculation. Do not introduce the results with a phrase like \"The result of the calculation is...\" or \"The answer is...\" - + Round to 2 decimal places. """.strip() @@ -1215,20 +1155,20 @@ def test_break_execution(): """ code = r"""print("starting") -import time -import os - +import time +import os + # Always create a fresh file open('numbers.txt', 'w').close() - -# Open the file in append mode -with open('numbers.txt', 'a+') as f: - # Loop through the numbers 1 to 5 - for i in [1,2,3,4,5]: - # Print the number - print("adding", i, "to file") - # Append the number to the file - f.write(str(i) + '\n') + +# Open the file in append mode +with open('numbers.txt', 'a+') as f: + # Loop through the numbers 1 to 5 + for i in [1,2,3,4,5]: + # Print the number + print("adding", i, "to file") + # Append the number to the file + f.write(str(i) + '\n') # Wait for 0.5 second print("starting to sleep") time.sleep(1)