Commit a5e3588

Merge pull request #266 from algorithmicsuperintelligence/fix-max-tokens
Add robust response validation and token config support
2 parents 427b8fe + 3e314c6 · commit a5e3588

11 files changed: +303 −103 lines changed


optillm/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,5 +1,5 @@
 # Version information
-__version__ = "0.3.3"
+__version__ = "0.3.4"
 
 # Import from server module
 from .server import (

optillm/bon.py

Lines changed: 44 additions & 14 deletions

@@ -22,16 +22,24 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
             "temperature": 1
         }
         response = client.chat.completions.create(**provider_request)
-
+
         # Log provider call
         if request_id:
             response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
             conversation_logger.log_provider_call(request_id, provider_request, response_dict)
-
-        completions = [choice.message.content for choice in response.choices]
+
+        # Check for valid response with None-checking
+        if response is None or not response.choices:
+            raise Exception("Response is None or has no choices")
+
+        completions = [choice.message.content for choice in response.choices if choice.message.content is not None]
         logger.info(f"Generated {len(completions)} initial completions using n parameter. Tokens used: {response.usage.completion_tokens}")
         bon_completion_tokens += response.usage.completion_tokens
-
+
+        # Check if any valid completions were generated
+        if not completions:
+            raise Exception("No valid completions generated (all were None)")
+
     except Exception as e:
         logger.warning(f"n parameter not supported by provider: {str(e)}")
         logger.info(f"Falling back to generating {n} completions one by one")

@@ -46,12 +54,20 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
                     "temperature": 1
                 }
                 response = client.chat.completions.create(**provider_request)
-
+
                 # Log provider call
                 if request_id:
                     response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
                     conversation_logger.log_provider_call(request_id, provider_request, response_dict)
-
+
+                # Check for valid response with None-checking
+                if (response is None or
+                    not response.choices or
+                    response.choices[0].message.content is None or
+                    response.choices[0].finish_reason == "length"):
+                    logger.warning(f"Completion {i+1}/{n} truncated or empty, skipping")
+                    continue
+
                 completions.append(response.choices[0].message.content)
                 bon_completion_tokens += response.usage.completion_tokens
                 logger.debug(f"Generated completion {i+1}/{n}")

@@ -65,11 +81,16 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
             return "Error: Could not generate any completions", 0
 
     logger.info(f"Generated {len(completions)} completions using fallback method. Total tokens used: {bon_completion_tokens}")
-
+
+    # Double-check we have completions before rating
+    if not completions:
+        logger.error("No completions available for rating")
+        return "Error: Could not generate any completions", bon_completion_tokens
+
     # Rate the completions
     rating_messages = messages.copy()
     rating_messages.append({"role": "system", "content": "Rate the following responses on a scale from 0 to 10, where 0 is poor and 10 is excellent. Consider factors such as relevance, coherence, and helpfulness. Respond with only a number."})
-
+
     ratings = []
     for completion in completions:
         rating_messages.append({"role": "assistant", "content": completion})

@@ -83,18 +104,27 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
             "temperature": 0.1
         }
         rating_response = client.chat.completions.create(**provider_request)
-
+
         # Log provider call
         if request_id:
             response_dict = rating_response.model_dump() if hasattr(rating_response, 'model_dump') else rating_response
             conversation_logger.log_provider_call(request_id, provider_request, response_dict)
-
+
         bon_completion_tokens += rating_response.usage.completion_tokens
-        try:
-            rating = float(rating_response.choices[0].message.content.strip())
-            ratings.append(rating)
-        except ValueError:
+
+        # Check for valid response with None-checking
+        if (rating_response is None or
+            not rating_response.choices or
+            rating_response.choices[0].message.content is None or
+            rating_response.choices[0].finish_reason == "length"):
+            logger.warning("Rating response truncated or empty, using default rating of 0")
             ratings.append(0)
+        else:
+            try:
+                rating = float(rating_response.choices[0].message.content.strip())
+                ratings.append(rating)
+            except ValueError:
+                ratings.append(0)
 
     rating_messages = rating_messages[:-2]

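Note on the pattern: the guard this PR adds is the same four-part check at every call site — the response object exists, choices is non-empty, the first message's content is not None, and finish_reason is not "length" (the marker for a completion cut off at the token limit). A minimal sketch of how it could be factored into one helper; is_valid_chat_response is a hypothetical name, not something this PR introduces:

def is_valid_chat_response(response, index: int = 0) -> bool:
    # Hypothetical helper, not part of this PR: collects the validity
    # check the diff repeats before each use of a chat completion.
    return (
        response is not None
        and bool(response.choices)
        and response.choices[index].message.content is not None
        and response.choices[index].finish_reason != "length"
    )

With such a helper, the per-completion fallback above would reduce to "if not is_valid_chat_response(response): continue".
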
optillm/mcts.py

Lines changed: 31 additions & 8 deletions

@@ -122,13 +122,18 @@ def generate_actions(self, state: DialogueState) -> List[str]:
             "temperature": 1
         }
         response = self.client.chat.completions.create(**provider_request)
-
+
         # Log provider call
         if self.request_id:
             response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
             conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
-
-        completions = [choice.message.content.strip() for choice in response.choices]
+
+        # Check for valid response with None-checking
+        if response is None or not response.choices:
+            logger.error("Failed to get valid completions from the model")
+            return []
+
+        completions = [choice.message.content.strip() for choice in response.choices if choice.message.content is not None]
         self.completion_tokens += response.usage.completion_tokens
         logger.info(f"Received {len(completions)} completions from the model")
         return completions

@@ -151,13 +156,22 @@ def apply_action(self, state: DialogueState, action: str) -> DialogueState:
             "temperature": 1
         }
         response = self.client.chat.completions.create(**provider_request)
-
+
         # Log provider call
         if self.request_id:
             response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
             conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
-
-        next_query = response.choices[0].message.content
+
+        # Check for valid response with None-checking
+        if (response is None or
+            not response.choices or
+            response.choices[0].message.content is None or
+            response.choices[0].finish_reason == "length"):
+            logger.warning("Next query response truncated or empty, using default")
+            next_query = "Please continue."
+        else:
+            next_query = response.choices[0].message.content
+
         self.completion_tokens += response.usage.completion_tokens
         logger.info(f"Generated next user query: {next_query}")
         return DialogueState(state.system_prompt, new_history, next_query)

@@ -181,13 +195,22 @@ def evaluate_state(self, state: DialogueState) -> float:
             "temperature": 0.1
         }
         response = self.client.chat.completions.create(**provider_request)
-
+
         # Log provider call
         if self.request_id:
             response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
             conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
-
+
         self.completion_tokens += response.usage.completion_tokens
+
+        # Check for valid response with None-checking
+        if (response is None or
+            not response.choices or
+            response.choices[0].message.content is None or
+            response.choices[0].finish_reason == "length"):
+            logger.warning("Evaluation response truncated or empty. Using default value 0.5")
+            return 0.5
+
         try:
             score = float(response.choices[0].message.content.strip())
             score = max(0, min(score, 1))  # Ensure the score is between 0 and 1

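The evaluate_state change is easy to exercise without a live provider, since the guard only reads three attributes (choices, message.content, finish_reason). A small sketch using stand-in objects; the stub shape is inferred from the attribute accesses above, not taken from this PR:

from types import SimpleNamespace

# Stub shaped like a provider response whose completion was cut off
# at the token limit: finish_reason == "length", content is None.
truncated = SimpleNamespace(
    choices=[SimpleNamespace(
        message=SimpleNamespace(content=None),
        finish_reason="length",
    )]
)

# The same condition the diff adds in evaluate_state:
invalid = (
    truncated is None
    or not truncated.choices
    or truncated.choices[0].message.content is None
    or truncated.choices[0].finish_reason == "length"
)
assert invalid  # evaluate_state now returns the neutral 0.5 here instead of raising
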
optillm/moa.py

Lines changed: 59 additions & 18 deletions

@@ -25,18 +25,26 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
         }
 
         response = client.chat.completions.create(**provider_request)
-
+
         # Convert response to dict for logging
         response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
-
+
         # Log provider call if conversation logging is enabled
         if request_id:
             conversation_logger.log_provider_call(request_id, provider_request, response_dict)
-
-        completions = [choice.message.content for choice in response.choices]
+
+        # Check for valid response with None-checking
+        if response is None or not response.choices:
+            raise Exception("Response is None or has no choices")
+
+        completions = [choice.message.content for choice in response.choices if choice.message.content is not None]
         moa_completion_tokens += response.usage.completion_tokens
         logger.info(f"Generated {len(completions)} initial completions using n parameter. Tokens used: {response.usage.completion_tokens}")
-
+
+        # Check if any valid completions were generated
+        if not completions:
+            raise Exception("No valid completions generated (all were None)")
+
     except Exception as e:
         logger.warning(f"n parameter not supported by provider: {str(e)}")
         logger.info("Falling back to generating 3 completions one by one")

@@ -56,14 +64,22 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
            }
 
            response = client.chat.completions.create(**provider_request)
-
+
            # Convert response to dict for logging
            response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
-
+
            # Log provider call if conversation logging is enabled
            if request_id:
                conversation_logger.log_provider_call(request_id, provider_request, response_dict)
-
+
+           # Check for valid response with None-checking
+           if (response is None or
+               not response.choices or
+               response.choices[0].message.content is None or
+               response.choices[0].finish_reason == "length"):
+               logger.warning(f"Completion {i+1}/3 truncated or empty, skipping")
+               continue
+
            completions.append(response.choices[0].message.content)
            moa_completion_tokens += response.usage.completion_tokens
            logger.debug(f"Generated completion {i+1}/3")

@@ -77,7 +93,12 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
            return "Error: Could not generate any completions", 0
 
    logger.info(f"Generated {len(completions)} completions using fallback method. Total tokens used: {moa_completion_tokens}")
-
+
+   # Double-check we have at least one completion
+   if not completions:
+       logger.error("No completions available for processing")
+       return "Error: Could not generate any completions", moa_completion_tokens
+
    # Handle case where fewer than 3 completions were generated
    if len(completions) < 3:
        original_count = len(completions)

@@ -118,15 +139,24 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
    }
 
    critique_response = client.chat.completions.create(**provider_request)
-
+
    # Convert response to dict for logging
    response_dict = critique_response.model_dump() if hasattr(critique_response, 'model_dump') else critique_response
-
+
    # Log provider call if conversation logging is enabled
    if request_id:
        conversation_logger.log_provider_call(request_id, provider_request, response_dict)
-
-   critiques = critique_response.choices[0].message.content
+
+   # Check for valid response with None-checking
+   if (critique_response is None or
+       not critique_response.choices or
+       critique_response.choices[0].message.content is None or
+       critique_response.choices[0].finish_reason == "length"):
+       logger.warning("Critique response truncated or empty, using generic critique")
+       critiques = "All candidates show reasonable approaches to the problem."
+   else:
+       critiques = critique_response.choices[0].message.content
+
    moa_completion_tokens += critique_response.usage.completion_tokens
    logger.info(f"Generated critiques. Tokens used: {critique_response.usage.completion_tokens}")

@@ -165,16 +195,27 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
    }
 
    final_response = client.chat.completions.create(**provider_request)
-
+
    # Convert response to dict for logging
    response_dict = final_response.model_dump() if hasattr(final_response, 'model_dump') else final_response
-
+
    # Log provider call if conversation logging is enabled
    if request_id:
        conversation_logger.log_provider_call(request_id, provider_request, response_dict)
-
+
    moa_completion_tokens += final_response.usage.completion_tokens
    logger.info(f"Generated final response. Tokens used: {final_response.usage.completion_tokens}")
-
+
+   # Check for valid response with None-checking
+   if (final_response is None or
+       not final_response.choices or
+       final_response.choices[0].message.content is None or
+       final_response.choices[0].finish_reason == "length"):
+       logger.error("Final response truncated or empty. Consider increasing max_tokens.")
+       # Return best completion if final response failed
+       result = completions[0] if completions else "Error: Response was truncated due to token limit. Please increase max_tokens or max_completion_tokens."
+   else:
+       result = final_response.choices[0].message.content
+
    logger.info(f"Total completion tokens used: {moa_completion_tokens}")
-   return final_response.choices[0].message.content, moa_completion_tokens
+   return result, moa_completion_tokens

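Each file handles a failed check differently: bon.py raises so its one-by-one fallback takes over, mcts.py substitutes neutral defaults ("Please continue.", a 0.5 score), and moa.py degrades stepwise, preferring a surviving candidate over an error string. A compressed sketch of that final-response policy; the function name is illustrative, not from the PR:

def choose_moa_result(final_content, completions):
    # Illustrative only: the fallback order moa.py now follows when the
    # final synthesis response is missing, empty, or truncated.
    if final_content is not None:
        return final_content  # normal path: the synthesized answer
    if completions:
        return completions[0]  # degrade to a surviving candidate
    return ("Error: Response was truncated due to token limit. "
            "Please increase max_tokens or max_completion_tokens.")
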