
Commit 6b803e2

Merge pull request #129 from codelion/fix-parsing-tagged-conv-bug
Fix parsing tagged conv bug
2 parents 7a23694 + f96b435 commit 6b803e2

3 files changed: +67 -17 lines changed

optillm.py

Lines changed: 22 additions & 7 deletions
@@ -416,16 +416,23 @@ def parse_conversation(messages):
 
 def tagged_conversation_to_messages(response_text):
     """Convert a tagged conversation string or list of strings into a list of messages.
+    If the input doesn't contain User:/Assistant: tags, return it as is.
 
     Args:
        response_text: Either a string containing "User:" and "Assistant:" tags,
                       or a list of such strings.
 
     Returns:
-        If input is a string: A list of message dictionaries.
-        If input is a list: A list of lists of message dictionaries.
+        If input has tags: A list of message dictionaries.
+        If input has no tags: The original input.
     """
+    def has_conversation_tags(text):
+        return "User:" in text or "Assistant:" in text
+
     def process_single_response(text):
+        if not has_conversation_tags(text):
+            return text
+
         messages = []
         # Split on "User:" or "Assistant:" while keeping the delimiter
         parts = re.split(r'(?=(User:|Assistant:))', text.strip())
@@ -447,7 +454,11 @@ def process_single_response(text):
         return messages
 
     if isinstance(response_text, list):
-        return [process_single_response(text) for text in response_text]
+        processed = [process_single_response(text) for text in response_text]
+        # If none of the responses had tags, return original list
+        if all(isinstance(p, str) for p in processed):
+            return response_text
+        return processed
     else:
        return process_single_response(response_text)
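
The effect on tagged_conversation_to_messages is easiest to see at the call site. A minimal sketch, assuming optillm.py is on the import path; printed outputs are illustrative rather than verbatim:

    from optillm import tagged_conversation_to_messages

    # Tagged input is still split into message dictionaries.
    tagged = "User: What is 2+2? Assistant: 4"
    print(tagged_conversation_to_messages(tagged))
    # -> roughly [{'role': 'user', ...}, {'role': 'assistant', ...}]

    # Untagged input used to be mangled by the splitter; it now passes through.
    plain = "The answer is 4."
    print(tagged_conversation_to_messages(plain))    # -> 'The answer is 4.'

    # A list whose entries all lack tags is likewise returned unchanged.
    print(tagged_conversation_to_messages([plain]))  # -> ['The answer is 4.']

The companion change in proxy() applies the same guard at the point where responses are unwrapped: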

@@ -555,14 +566,18 @@ def proxy():
     except Exception as e:
         logger.error(f"Error processing request: {str(e)}")
         return jsonify({"error": str(e)}), 500
-
+
     # Convert tagged conversation to messages format if needed
     if isinstance(response, list):
-        response = [msg[-1]['content'] if isinstance(msg, list) and msg else msg
-                    for msg in tagged_conversation_to_messages(response)]
+        processed_response = tagged_conversation_to_messages(response)
+        # If processed_response is a list of message lists, extract last message content
+        if processed_response != response:  # Only process if format changed
+            response = [msg[-1]['content'] if isinstance(msg, list) and msg else msg
+                        for msg in processed_response]
+        # Otherwise keep original response
     else:
         messages = tagged_conversation_to_messages(response)
-        if messages:  # Only take the last message if we have any
+        if isinstance(messages, list) and messages:  # Only process if format changed
             response = messages[-1]['content']
 
     if stream:
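
A standalone sketch of that guard, assuming the same import and using a hypothetical pair of untagged completions as the response list:

    from optillm import tagged_conversation_to_messages

    # Hypothetical proxy output: two candidate completions without tags.
    response = ["First candidate answer.", "Second candidate answer."]

    processed_response = tagged_conversation_to_messages(response)
    if processed_response != response:  # format changed, so tags were parsed
        # Entries are now message lists; keep only the last message's content.
        response = [msg[-1]['content'] if isinstance(msg, list) and msg else msg
                    for msg in processed_response]
    # Otherwise the original strings survive untouched; previously they were
    # forced through the extractor even when there was nothing to parse.

    print(response)  # -> ['First candidate answer.', 'Second candidate answer.']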

scripts/eval_aime_benchmark.py

Lines changed: 44 additions & 9 deletions
@@ -4,7 +4,7 @@
 import logging
 import re
 import time
-from typing import List, Dict, Tuple, Optional
+from typing import List, Dict, Tuple, Optional, Union
 from datetime import datetime
 from openai import OpenAI
 from datasets import load_dataset
@@ -89,9 +89,17 @@ def extract_answer(response: str) -> Optional[int]:
 
     return None
 
-def get_llm_response(problem: str, model: str) -> str:
+def get_llm_response(problem: str, model: str) -> Union[str, List[Dict]]:
     """
     Get response from the LLM for a given problem.
+    If multiple choices are returned, formats them as attempt dictionaries.
+
+    Args:
+        problem (str): The problem text
+        model (str): The model identifier
+
+    Returns:
+        Union[str, List[Dict]]: Either a string response or list of attempt dictionaries
     """
     try:
         response = client.with_options(timeout=1000.0).chat.completions.create(
@@ -101,7 +109,23 @@ def get_llm_response(problem: str, model: str) -> str:
             ],
             max_tokens=8192,
         )
+
+        # If there's more than one choice, format as attempts
+        if len(response.choices) > 1:
+            attempts = []
+            for i, choice in enumerate(response.choices):
+                response_text = choice.message.content.strip()
+                predicted_answer = extract_answer(response_text)
+                attempts.append({
+                    "attempt_number": i + 1,
+                    "response": response_text,
+                    "predicted_answer": predicted_answer
+                })
+            return attempts
+
+        # If single choice, return as before
         return response.choices[0].message.content.strip()
+
     except Exception as e:
         logger.error(f"Error getting LLM response: {e}")
         return ""
@@ -119,14 +143,25 @@ def make_n_attempts(problem: str, model: str, n: int) -> List[Dict]:
         List[Dict]: List of dictionaries containing response and predicted answer for each attempt
     """
     attempts = []
-    for i in range(n):
+    remaining_attempts = n
+
+    while remaining_attempts > 0:
         response = get_llm_response(problem, model)
-        predicted_answer = extract_answer(response)
-        attempts.append({
-            "attempt_number": i + 1,
-            "response": response,
-            "predicted_answer": predicted_answer
-        })
+
+        # If response is already formatted as attempts
+        if isinstance(response, list):
+            attempts.extend(response)
+            remaining_attempts = n - len(attempts)
+        else:
+            # Process single response as before
+            predicted_answer = extract_answer(response)
+            attempts.append({
+                "attempt_number": len(attempts) + 1,
+                "response": response,
+                "predicted_answer": predicted_answer
+            })
+            remaining_attempts -= 1
 
     return attempts
 
 def evaluate_pass_at_n(attempts: List[Dict], correct_answer: int) -> Tuple[bool, Optional[int]]:
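
The loop now counts collected attempts rather than iterations, so one batched response can satisfy several of the n requested attempts. A toy trace of the bookkeeping, with get_llm_response stubbed out and all names illustrative:

    # Stub: the first call returns a batch of 3 pre-formatted attempts,
    # later calls return a plain string -- the two shapes the real
    # get_llm_response can now produce.
    calls = {"count": 0}

    def fake_llm_response(problem, model):
        calls["count"] += 1
        if calls["count"] == 1:
            return [{"attempt_number": i + 1, "response": "...", "predicted_answer": None}
                    for i in range(3)]
        return "single response"

    n = 4
    attempts = []
    remaining_attempts = n
    while remaining_attempts > 0:
        response = fake_llm_response("problem text", "model-name")
        if isinstance(response, list):
            attempts.extend(response)               # batch adds 3 at once
            remaining_attempts = n - len(attempts)  # 4 - 3 = 1 still needed
        else:
            attempts.append({"attempt_number": len(attempts) + 1,
                             "response": response, "predicted_answer": None})
            remaining_attempts -= 1                 # 1 - 1 = 0, loop ends

    print(len(attempts), calls["count"])  # -> 4 2 (four attempts from two calls)

Because the batch branch recomputes remaining_attempts from len(attempts), a server that returns more choices than requested simply overshoots n and the loop exits on the next check.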

setup.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 
 setup(
     name="optillm",
-    version="0.0.23",
+    version="0.0.24",
     packages=find_packages(),
     py_modules=['optillm'],
     package_data={
