Commit a7d42a5

Merge pull request #30 from codelion/feat-add-completion-tokens-count
Add completion tokens
2 parents 2a82be4 + 0bb8ece commit a7d42a5
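
Note: the change follows one pattern across every approach module. Each approach function now returns a (response_text, completion_tokens) tuple instead of a bare string, and the proxy sums those counts into a usage object on the response. A minimal sketch of the new contract, assuming the OpenAI Python client; my_approach is an illustrative name, not a function in this diff:

def my_approach(system_prompt: str, initial_query: str, client, model: str):
    completion_tokens = 0  # running count of generated tokens for this request
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": initial_query},
        ],
    )
    # Each underlying call contributes its generated-token count to the total.
    completion_tokens += response.usage.completion_tokens
    return response.choices[0].message.content, completion_tokens

The proxy then unpacks the tuple (final_response, completion_tokens = my_approach(...)) and reports the total under response_data['usage']['completion_tokens'].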

16 files changed: 683 additions, 224 deletions

optillm.py

Lines changed: 17 additions & 13 deletions
@@ -123,37 +123,38 @@ def proxy():
     logger.info(f'Using approach {approach}, with {model}')
+    completion_tokens = 0

     try:
         if approach == 'mcts':
-            final_response = chat_with_mcts(system_prompt, initial_query, client, model, server_config['mcts_simulations'],
+            final_response, completion_tokens = chat_with_mcts(system_prompt, initial_query, client, model, server_config['mcts_simulations'],
                                             server_config['mcts_exploration'], server_config['mcts_depth'])
         elif approach == 'bon':
-            final_response = best_of_n_sampling(system_prompt, initial_query, client, model, server_config['best_of_n'])
+            final_response, completion_tokens = best_of_n_sampling(system_prompt, initial_query, client, model, server_config['best_of_n'])
         elif approach == 'moa':
-            final_response = mixture_of_agents(system_prompt, initial_query, client, model)
+            final_response, completion_tokens = mixture_of_agents(system_prompt, initial_query, client, model)
         elif approach == 'rto':
-            final_response = round_trip_optimization(system_prompt, initial_query, client, model)
+            final_response, completion_tokens = round_trip_optimization(system_prompt, initial_query, client, model)
         elif approach == 'z3':
             z3_solver = Z3SolverSystem(system_prompt, client, model)
-            final_response = z3_solver.process_query(initial_query)
+            final_response, completion_tokens = z3_solver.process_query(initial_query)
         elif approach == "self_consistency":
-            final_response = advanced_self_consistency_approach(system_prompt, initial_query, client, model)
+            final_response, completion_tokens = advanced_self_consistency_approach(system_prompt, initial_query, client, model)
         elif approach == "pvg":
-            final_response = inference_time_pv_game(system_prompt, initial_query, client, model)
+            final_response, completion_tokens = inference_time_pv_game(system_prompt, initial_query, client, model)
         elif approach == "rstar":
             rstar = RStar(system_prompt, client, model,
                           max_depth=server_config['rstar_max_depth'], num_rollouts=server_config['rstar_num_rollouts'],
                           c=server_config['rstar_c'])
-            final_response = rstar.solve(initial_query)
+            final_response, completion_tokens = rstar.solve(initial_query)
         elif approach == "cot_reflection":
-            final_response = cot_reflection(system_prompt, initial_query, client, model, return_full_response=server_config['return_full_response'])
+            final_response, completion_tokens = cot_reflection(system_prompt, initial_query, client, model, return_full_response=server_config['return_full_response'])
         elif approach == 'plansearch':
-            final_response = plansearch(system_prompt, initial_query, client, model, n=n)
+            final_response, completion_tokens = plansearch(system_prompt, initial_query, client, model, n=n)
         elif approach == 'leap':
-            final_response = leap(system_prompt, initial_query, client, model)
+            final_response, completion_tokens = leap(system_prompt, initial_query, client, model)
         elif approach == 're2':
-            final_response = re2_approach(system_prompt, initial_query, client, model, n=n)
+            final_response, completion_tokens = re2_approach(system_prompt, initial_query, client, model, n=n)
         else:
             raise ValueError(f"Unknown approach: {approach}")
     except Exception as e:
@@ -162,7 +163,10 @@ def proxy():

     response_data = {
         'model': model,
-        'choices': []
+        'choices': [],
+        'usage': {
+            'completion_tokens': completion_tokens,
+        }
     }

     if isinstance(final_response, list):
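
With the usage block in place, a caller of the proxy can read the aggregated count straight from the JSON response. A hedged usage sketch; the host, port, route, and the "moa-gpt-4o-mini" model slug are assumptions for illustration, not values taken from this diff:

import requests

# Assumed local proxy address and OpenAI-style route; adjust to your deployment.
url = "http://localhost:8000/v1/chat/completions"
payload = {
    "model": "moa-gpt-4o-mini",  # hypothetical "<approach>-<model>" name
    "messages": [{"role": "user", "content": "How many r's are in strawberry?"}],
}
data = requests.post(url, json=payload).json()
# Total completion tokens spent across all underlying model calls for this request.
print(data["usage"]["completion_tokens"])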

optillm/bon.py

Lines changed: 5 additions & 2 deletions
@@ -3,6 +3,8 @@
 logger = logging.getLogger(__name__)

 def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: str, n: int = 3) -> str:
+    bon_completion_tokens = 0
+
     messages = [{"role": "system", "content": system_prompt},
                 {"role": "user", "content": initial_query}]

@@ -16,6 +18,7 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
         temperature=1
     )
     completions = [choice.message.content for choice in response.choices]
+    bon_completion_tokens += response.usage.completion_tokens

     # Rate the completions
     rating_messages = messages.copy()
@@ -33,7 +36,7 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
             n=1,
             temperature=0.1
         )
-
+        bon_completion_tokens += rating_response.usage.completion_tokens
         try:
             rating = float(rating_response.choices[0].message.content.strip())
             ratings.append(rating)
@@ -43,4 +46,4 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
         rating_messages = rating_messages[:-2]

     best_index = ratings.index(max(ratings))
-    return completions[best_index]
+    return completions[best_index], bon_completion_tokens
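
The bookkeeping above is the pattern every module below repeats: initialize a counter, add response.usage.completion_tokens after each client.chat.completions.create(...) call, and return the total alongside the text. A small standalone sketch of that accumulation, assuming the OpenAI Python client and non-streaming responses (where usage is populated); the helper itself is illustrative, not part of optillm:

from typing import List

def count_completion_tokens(client, model: str, prompts: List[str]) -> int:
    # Sum the generated-token counts over several sequential calls.
    total = 0
    for prompt in prompts:
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
        )
        # usage.completion_tokens counts only generated tokens, not the prompt.
        total += response.usage.completion_tokens
    return total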

optillm/cot_reflection.py

Lines changed: 4 additions & 2 deletions
@@ -4,6 +4,7 @@
 logger = logging.getLogger(__name__)

 def cot_reflection(system_prompt, initial_query, client, model: str, return_full_response: bool=False):
+    cot_completion_tokens = 0
     cot_prompt = f"""
     {system_prompt}

@@ -44,6 +45,7 @@ def cot_reflection(system_prompt, initial_query, client, model: str, return_full

     # Extract the full response
     full_response = response.choices[0].message.content
+    cot_completion_tokens += response.usage.completion_tokens
     logger.info(f"CoT with Reflection :\n{full_response}")

     # Use regex to extract the content within <thinking> and <output> tags
@@ -56,7 +58,7 @@ def cot_reflection(system_prompt, initial_query, client, model: str, return_full
     logger.info(f"Final output :\n{output}")

     if return_full_response:
-        return full_response
+        return full_response, cot_completion_tokens
     else:
-        return output
+        return output, cot_completion_tokens

optillm/leap.py

Lines changed: 7 additions & 1 deletion
@@ -14,6 +14,7 @@ def __init__(self, system_prompt: str, client, model: str):
         self.model = model
         self.low_level_principles = []
         self.high_level_principles = []
+        self.leap_completion_tokens = 0

     def extract_output(self, text: str) -> str:
         match = re.search(r'<output>(.*?)(?:</output>|$)', text, re.DOTALL)
@@ -46,6 +47,7 @@ def extract_examples_from_query(self, initial_query: str) -> List[Tuple[str, str
                 """}
             ]
         )
+        self.leap_completion_tokens += response.usage.completion_tokens
         examples_str = self.extract_output(response.choices[0].message.content)
         logger.debug(f"Extracted examples: {examples_str}")
         examples = []
@@ -80,6 +82,7 @@ def generate_mistakes(self, examples: List[Tuple[str, str]]) -> List[Tuple[str,
             ],
             temperature=0.7,
         )
+        self.leap_completion_tokens += response.usage.completion_tokens
         generated_reasoning = response.choices[0].message.content
         generated_answer = self.extract_output(generated_reasoning)
         if generated_answer != correct_answer:
@@ -110,6 +113,7 @@ def generate_low_level_principles(self, mistakes: List[Tuple[str, str, str, str]
                 """}
             ]
         )
+        self.leap_completion_tokens += response.usage.completion_tokens
         self.low_level_principles.append(self.extract_output(response.choices[0].message.content))
         return self.low_level_principles

@@ -134,6 +138,7 @@ def generate_high_level_principles(self) -> List[str]:
                 """}
             ]
         )
+        self.leap_completion_tokens += response.usage.completion_tokens
         self.high_level_principles = self.extract_output(response.choices[0].message.content).split("\n")
         return self.high_level_principles

@@ -154,6 +159,7 @@ def apply_principles(self, query: str) -> str:
                 """}
             ]
         )
+        self.leap_completion_tokens += response.usage.completion_tokens
         return response.choices[0].message.content

     def solve(self, initial_query: str) -> str:
@@ -171,4 +177,4 @@ def solve(self, initial_query: str) -> str:

 def leap(system_prompt: str, initial_query: str, client, model: str) -> str:
     leap_solver = LEAP(system_prompt, client, model)
-    return leap_solver.solve(initial_query)
+    return leap_solver.solve(initial_query), leap_solver.leap_completion_tokens

optillm/mcts.py

Lines changed: 5 additions & 2 deletions
@@ -32,6 +32,7 @@ def __init__(self, simulation_depth, exploration_weight, client, model):
         self.node_labels = {}
         self.client = client
         self.model = model
+        self.completion_tokens = 0

     def select(self, node: MCTSNode) -> MCTSNode:
         logger.debug(f"Selecting node. Current node visits: {node.visits}, value: {node.value}")
@@ -118,6 +119,7 @@ def generate_actions(self, state: DialogueState) -> List[str]:
             temperature=1
         )
         completions = [choice.message.content.strip() for choice in response.choices]
+        self.completion_tokens += response.usage.completion_tokens
         logger.info(f"Received {len(completions)} completions from the model")
         return completions

@@ -140,6 +142,7 @@ def apply_action(self, state: DialogueState, action: str) -> DialogueState:
         )

         next_query = response.choices[0].message.content
+        self.completion_tokens += response.usage.completion_tokens
         logger.info(f"Generated next user query: {next_query}")
         return DialogueState(state.system_prompt, new_history, next_query)

@@ -161,7 +164,7 @@ def evaluate_state(self, state: DialogueState) -> float:
             n=1,
             temperature=0.1
         )
-
+        self.completion_tokens += response.usage.completion_tokens
         try:
             score = float(response.choices[0].message.content.strip())
             score = max(0, min(score, 1))  # Ensure the score is between 0 and 1
@@ -181,4 +184,4 @@ def chat_with_mcts(system_prompt: str, initial_query: str, client, model: str, n
     final_state = mcts.search(initial_state, num_simulations)
     response = final_state.conversation_history[-1]['content'] if final_state.conversation_history else ""
     logger.info(f"MCTS chat complete. Final response: {response[:100]}...")
-    return response
+    return response, mcts.completion_tokens

optillm/moa.py

Lines changed: 5 additions & 2 deletions
@@ -3,6 +3,7 @@
 logger = logging.getLogger(__name__)

 def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str) -> str:
+    moa_completion_tokens = 0
     completions = []

     response = client.chat.completions.create(
@@ -16,6 +17,7 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
         temperature=1
     )
     completions = [choice.message.content for choice in response.choices]
+    moa_completion_tokens += response.usage.completion_tokens

     critique_prompt = f"""
     Original query: {initial_query}
@@ -45,6 +47,7 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
         temperature=0.1
     )
     critiques = critique_response.choices[0].message.content
+    moa_completion_tokens += critique_response.usage.completion_tokens

     final_prompt = f"""
     Original query: {initial_query}
@@ -76,5 +79,5 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
         n=1,
         temperature=0.1
     )
-
-    return final_response.choices[0].message.content
+    moa_completion_tokens += final_response.usage.completion_tokens
+    return final_response.choices[0].message.content, moa_completion_tokens

optillm/plansearch.py

Lines changed: 6 additions & 5 deletions
@@ -8,6 +8,7 @@ def __init__(self, system_prompt: str, client, model: str):
         self.system_prompt = system_prompt
         self.client = client
         self.model = model
+        self.plansearch_completion_tokens = 0

     def generate_observations(self, problem: str, num_observations: int = 3) -> List[str]:
         prompt = f"""You are an expert Python programmer. You will be given a competitive programming question
@@ -28,7 +29,7 @@ def generate_observations(self, problem: str, num_observations: int = 3) -> List
                 {"role": "user", "content": prompt}
             ]
         )
-
+        self.plansearch_completion_tokens += response.usage.completion_tokens
        observations = response.choices[0].message.content.strip().split('\n')
        return [obs.strip() for obs in observations if obs.strip()]

@@ -55,7 +56,7 @@ def generate_derived_observations(self, problem: str, observations: List[str], n
                 {"role": "user", "content": prompt}
             ]
         )
-
+        self.plansearch_completion_tokens += response.usage.completion_tokens
        new_observations = response.choices[0].message.content.strip().split('\n')
        return [obs.strip() for obs in new_observations if obs.strip()]

@@ -80,7 +81,7 @@ def generate_solution(self, problem: str, observations: List[str]) -> str:
                 {"role": "user", "content": prompt}
             ]
         )
-
+        self.plansearch_completion_tokens += response.usage.completion_tokens
        return response.choices[0].message.content.strip()

     def implement_solution(self, problem: str, solution: str) -> str:
@@ -105,7 +106,7 @@ def implement_solution(self, problem: str, solution: str) -> str:
                 {"role": "user", "content": prompt}
             ]
         )
-
+        self.plansearch_completion_tokens += response.usage.completion_tokens
        return response.choices[0].message.content.strip()

     def solve(self, problem: str, num_initial_observations: int = 3, num_derived_observations: int = 2) -> Tuple[str, str]:
@@ -134,4 +135,4 @@ def solve_multiple(self, problem: str, n: int, num_initial_observations: int = 3

 def plansearch(system_prompt: str, initial_query: str, client, model: str, n: int = 1) -> List[str]:
     planner = PlanSearch(system_prompt, client, model)
-    return planner.solve_multiple(initial_query, n)
+    return planner.solve_multiple(initial_query, n), planner.plansearch_completion_tokens

optillm/pvg.py

Lines changed: 9 additions & 1 deletion
@@ -4,7 +4,10 @@

 logger = logging.getLogger(__name__)

+pvg_completion_tokens = 0
+
 def generate_solutions(client, system_prompt: str, query: str, model: str, num_solutions: int, is_sneaky: bool = False, temperature: float = 0.7) -> List[str]:
+    global pvg_completion_tokens
     role = "sneaky" if is_sneaky else "helpful"
     logger.info(f"Generating {num_solutions} {role} solutions")

@@ -34,11 +37,13 @@ def generate_solutions(client, system_prompt: str, query: str, model: str, num_s
         max_tokens=4096,
         temperature=temperature,
     )
+    pvg_completion_tokens += response.usage.completion_tokens
     solutions = [choice.message.content for choice in response.choices]
     logger.debug(f"Generated {role} solutions: {solutions}")
     return solutions

 def verify_solutions(client, system_prompt: str, initial_query: str, solutions: List[str], model: str) -> List[float]:
+    global pvg_completion_tokens
     logger.info(f"Verifying {len(solutions)} solutions")
     verify_prompt = f"""{system_prompt}
 You are a verifier tasked with evaluating the correctness and clarity of solutions to the given problem.
@@ -75,6 +80,7 @@ def verify_solutions(client, system_prompt: str, initial_query: str, solutions:
         max_tokens=1024,
         temperature=0.2,
     )
+    pvg_completion_tokens += response.usage.completion_tokens
     rating = response.choices[0].message.content
     logger.debug(f"Raw rating for solution {i+1}: {rating}")

@@ -130,6 +136,7 @@ def extract_answer(final_state: str) -> Tuple[str, float]:
     return "", 0.0

 def inference_time_pv_game(system_prompt: str, initial_query: str, client, model: str, num_rounds: int = 2, num_solutions: int = 3) -> str:
+    global pvg_completion_tokens
     logger.info(f"Starting inference-time PV game with {num_rounds} rounds and {num_solutions} solutions per round")

     best_solution = ""
@@ -178,9 +185,10 @@ def inference_time_pv_game(system_prompt: str, initial_query: str, client, model
         max_tokens=1024,
         temperature=0.5,
     )
+    pvg_completion_tokens += response.usage.completion_tokens
     initial_query = response.choices[0].message.content
     logger.debug(f"Refined query: {initial_query}")

     logger.info(f"Inference-time PV game completed. Best solution score: {best_score}")

-    return best_solution
+    return best_solution, pvg_completion_tokens
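
Unlike the class-based modules (leap, mcts, plansearch), pvg.py keeps its counter in a module-level global because its helpers are free functions. As far as this diff shows, nothing resets the global, so in a long-lived process the returned value is cumulative across calls rather than per request. A hedged sketch of a per-request reset; reset_pvg_counter is hypothetical and not part of the diff:

def reset_pvg_counter() -> None:
    # Hypothetical helper: zero the module-level counter at the start of a request
    # so the count returned by inference_time_pv_game is per-request.
    global pvg_completion_tokens
    pvg_completion_tokens = 0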

optillm/reread.py

Lines changed: 4 additions & 3 deletions
@@ -17,6 +17,7 @@ def re2_approach(system_prompt, initial_query, client, model, n=1):
         str or list: The generated response(s) from the model.
     """
     logger.info("Using RE2 approach for query processing")
+    re2_completion_tokens = 0

     # Construct the RE2 prompt
     re2_prompt = f"{initial_query}\nRead the question again: {initial_query}"
@@ -32,11 +33,11 @@ def re2_approach(system_prompt, initial_query, client, model, n=1):
             messages=messages,
             n=n
         )
-
+        re2_completion_tokens += response.usage.completion_tokens
         if n == 1:
-            return response.choices[0].message.content.strip()
+            return response.choices[0].message.content.strip(), re2_completion_tokens
         else:
-            return [choice.message.content.strip() for choice in response.choices]
+            return [choice.message.content.strip() for choice in response.choices], re2_completion_tokens

     except Exception as e:
         logger.error(f"Error in RE2 approach: {str(e)}")
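
Since every approach now returns a (text, completion_tokens) tuple, the contract can be checked without a real model by stubbing the client. A hedged test sketch; the stub classes and the test are illustrative and not part of the repository:

from types import SimpleNamespace
from optillm.reread import re2_approach

class _StubCompletions:
    def create(self, **kwargs):
        # Mimic the OpenAI response shape: choices with message.content, plus usage.
        n = kwargs.get("n", 1)
        return SimpleNamespace(
            choices=[SimpleNamespace(message=SimpleNamespace(content="ok"))] * n,
            usage=SimpleNamespace(completion_tokens=7),
        )

class _StubClient:
    chat = SimpleNamespace(completions=_StubCompletions())

def test_re2_returns_text_and_token_count():
    text, tokens = re2_approach("system prompt", "What is 2 + 2?", _StubClient(), "stub-model")
    assert isinstance(text, str)
    assert tokens == 7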