30 changes: 17 additions & 13 deletions optillm.py
@@ -123,37 +123,38 @@ def proxy():


logger.info(f'Using approach {approach}, with {model}')
completion_tokens = 0

try:
if approach == 'mcts':
final_response = chat_with_mcts(system_prompt, initial_query, client, model, server_config['mcts_simulations'],
final_response, completion_tokens = chat_with_mcts(system_prompt, initial_query, client, model, server_config['mcts_simulations'],
server_config['mcts_exploration'], server_config['mcts_depth'])
elif approach == 'bon':
final_response = best_of_n_sampling(system_prompt, initial_query, client, model, server_config['best_of_n'])
final_response, completion_tokens = best_of_n_sampling(system_prompt, initial_query, client, model, server_config['best_of_n'])
elif approach == 'moa':
final_response = mixture_of_agents(system_prompt, initial_query, client, model)
final_response, completion_tokens = mixture_of_agents(system_prompt, initial_query, client, model)
elif approach == 'rto':
final_response = round_trip_optimization(system_prompt, initial_query, client, model)
final_response, completion_tokens = round_trip_optimization(system_prompt, initial_query, client, model)
elif approach == 'z3':
z3_solver = Z3SolverSystem(system_prompt, client, model)
final_response = z3_solver.process_query(initial_query)
final_response, completion_tokens = z3_solver.process_query(initial_query)
elif approach == "self_consistency":
final_response = advanced_self_consistency_approach(system_prompt, initial_query, client, model)
final_response, completion_tokens = advanced_self_consistency_approach(system_prompt, initial_query, client, model)
elif approach == "pvg":
final_response = inference_time_pv_game(system_prompt, initial_query, client, model)
final_response, completion_tokens = inference_time_pv_game(system_prompt, initial_query, client, model)
elif approach == "rstar":
rstar = RStar(system_prompt, client, model,
max_depth=server_config['rstar_max_depth'], num_rollouts=server_config['rstar_num_rollouts'],
c=server_config['rstar_c'])
final_response = rstar.solve(initial_query)
final_response, completion_tokens = rstar.solve(initial_query)
elif approach == "cot_reflection":
final_response = cot_reflection(system_prompt, initial_query, client, model, return_full_response=server_config['return_full_response'])
final_response, completion_tokens = cot_reflection(system_prompt, initial_query, client, model, return_full_response=server_config['return_full_response'])
elif approach == 'plansearch':
final_response = plansearch(system_prompt, initial_query, client, model, n=n)
final_response, completion_tokens = plansearch(system_prompt, initial_query, client, model, n=n)
elif approach == 'leap':
final_response = leap(system_prompt, initial_query, client, model)
final_response, completion_tokens = leap(system_prompt, initial_query, client, model)
elif approach == 're2':
final_response = re2_approach(system_prompt, initial_query, client, model, n=n)
final_response, completion_tokens = re2_approach(system_prompt, initial_query, client, model, n=n)
else:
raise ValueError(f"Unknown approach: {approach}")
except Exception as e:
@@ -162,7 +163,10 @@ def proxy():

response_data = {
'model': model,
'choices': []
'choices': [],
'usage': {
'completion_tokens': completion_tokens,
}
}

if isinstance(final_response, list):
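For context, the dispatch changes above switch every approach to returning a `(final_response, completion_tokens)` pair, which proxy() then reports under an OpenAI-style `usage` key. A minimal sketch of that contract, assuming hypothetical function names and a simplified `choices` shape (neither is taken from the repository):

```python
# Illustrative sketch of the tuple-return contract used by proxy() above.
# `example_approach` and `build_response_data` are made-up names.
from typing import Tuple


def example_approach(system_prompt: str, initial_query: str) -> Tuple[str, int]:
    # A real approach would call the model here and sum
    # response.usage.completion_tokens over every request it makes.
    completion_tokens = 0
    final_response = f"(placeholder answer to: {initial_query})"
    return final_response, completion_tokens


def build_response_data(model: str, final_response: str, completion_tokens: int) -> dict:
    # Mirrors the usage block added to response_data in the diff above;
    # the choices entry is simplified for the sketch.
    return {
        "model": model,
        "choices": [{"message": {"role": "assistant", "content": final_response}}],
        "usage": {"completion_tokens": completion_tokens},
    }


answer, tokens = example_approach("You are a helpful assistant.", "What is 2 + 2?")
print(build_response_data("gpt-4o-mini", answer, tokens))
```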
7 changes: 5 additions & 2 deletions optillm/bon.py
@@ -3,6 +3,8 @@
logger = logging.getLogger(__name__)

def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: str, n: int = 3) -> str:
bon_completion_tokens = 0

messages = [{"role": "system", "content": system_prompt},
{"role": "user", "content": initial_query}]

@@ -16,6 +18,7 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
temperature=1
)
completions = [choice.message.content for choice in response.choices]
bon_completion_tokens += response.usage.completion_tokens

# Rate the completions
rating_messages = messages.copy()
@@ -33,7 +36,7 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
n=1,
temperature=0.1
)

bon_completion_tokens += rating_response.usage.completion_tokens
try:
rating = float(rating_response.choices[0].message.content.strip())
ratings.append(rating)
@@ -43,4 +46,4 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
rating_messages = rating_messages[:-2]

best_index = ratings.index(max(ratings))
return completions[best_index]
return completions[best_index], bon_completion_tokens
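The bon.py change threads a `bon_completion_tokens` counter through both the n-way sampling call and each rating call. A condensed, self-contained sketch of that flow against an OpenAI-compatible client follows; the helper name and the exact rating prompt wording are illustrative, not the module's text:

```python
# Illustrative best-of-n sketch with token accounting, not the repository's bon.py.
def best_of_n_sketch(client, model: str, system_prompt: str, query: str, n: int = 3):
    tokens = 0
    messages = [{"role": "system", "content": system_prompt},
                {"role": "user", "content": query}]

    # One request that samples n candidate completions.
    response = client.chat.completions.create(
        model=model, messages=messages, n=n, temperature=1
    )
    tokens += response.usage.completion_tokens
    completions = [choice.message.content for choice in response.choices]

    # Rate each candidate with a follow-up request, counting those tokens too.
    ratings = []
    for completion in completions:
        rating_response = client.chat.completions.create(
            model=model,
            messages=messages + [
                {"role": "assistant", "content": completion},
                {"role": "user", "content": "Rate the answer above from 0 to 10. Reply with a number only."},
            ],
            n=1,
            temperature=0.1,
        )
        tokens += rating_response.usage.completion_tokens
        try:
            ratings.append(float(rating_response.choices[0].message.content.strip()))
        except ValueError:
            ratings.append(0.0)  # an unparseable rating counts as the lowest score

    best = completions[ratings.index(max(ratings))]
    return best, tokens
```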
6 changes: 4 additions & 2 deletions optillm/cot_reflection.py
@@ -4,6 +4,7 @@
logger = logging.getLogger(__name__)

def cot_reflection(system_prompt, initial_query, client, model: str, return_full_response: bool=False):
cot_completion_tokens = 0
cot_prompt = f"""
{system_prompt}

@@ -44,6 +45,7 @@ def cot_reflection(system_prompt, initial_query, client, model: str, return_full

# Extract the full response
full_response = response.choices[0].message.content
cot_completion_tokens += response.usage.completion_tokens
logger.info(f"CoT with Reflection :\n{full_response}")

# Use regex to extract the content within <thinking> and <output> tags
@@ -56,7 +58,7 @@ def cot_reflection(system_prompt, initial_query, client, model: str, return_full
logger.info(f"Final output :\n{output}")

if return_full_response:
return full_response
return full_response, cot_completion_tokens
else:
return output
return output, cot_completion_tokens

8 changes: 7 additions & 1 deletion optillm/leap.py
@@ -14,6 +14,7 @@ def __init__(self, system_prompt: str, client, model: str):
self.model = model
self.low_level_principles = []
self.high_level_principles = []
self.leap_completion_tokens = 0

def extract_output(self, text: str) -> str:
match = re.search(r'<output>(.*?)(?:</output>|$)', text, re.DOTALL)
@@ -46,6 +47,7 @@ def extract_examples_from_query(self, initial_query: str) -> List[Tuple[str, str
"""}
]
)
self.leap_completion_tokens += response.usage.completion_tokens
examples_str = self.extract_output(response.choices[0].message.content)
logger.debug(f"Extracted examples: {examples_str}")
examples = []
@@ -80,6 +82,7 @@ def generate_mistakes(self, examples: List[Tuple[str, str]]) -> List[Tuple[str,
],
temperature=0.7,
)
self.leap_completion_tokens += response.usage.completion_tokens
generated_reasoning = response.choices[0].message.content
generated_answer = self.extract_output(generated_reasoning)
if generated_answer != correct_answer:
@@ -110,6 +113,7 @@ def generate_low_level_principles(self, mistakes: List[Tuple[str, str, str, str]
"""}
]
)
self.leap_completion_tokens += response.usage.completion_tokens
self.low_level_principles.append(self.extract_output(response.choices[0].message.content))
return self.low_level_principles

@@ -134,6 +138,7 @@ def generate_high_level_principles(self) -> List[str]:
"""}
]
)
self.leap_completion_tokens += response.usage.completion_tokens
self.high_level_principles = self.extract_output(response.choices[0].message.content).split("\n")
return self.high_level_principles

@@ -154,6 +159,7 @@ def apply_principles(self, query: str) -> str:
"""}
]
)
self.leap_completion_tokens += response.usage.completion_tokens
return response.choices[0].message.content

def solve(self, initial_query: str) -> str:
@@ -171,4 +177,4 @@ def solve(self, initial_query: str) -> str:

def leap(system_prompt: str, initial_query: str, client, model: str) -> str:
leap_solver = LEAP(system_prompt, client, model)
return leap_solver.solve(initial_query)
return leap_solver.solve(initial_query), leap_solver.leap_completion_tokens
7 changes: 5 additions & 2 deletions optillm/mcts.py
@@ -32,6 +32,7 @@ def __init__(self, simulation_depth, exploration_weight, client, model):
self.node_labels = {}
self.client = client
self.model = model
self.completion_tokens = 0

def select(self, node: MCTSNode) -> MCTSNode:
logger.debug(f"Selecting node. Current node visits: {node.visits}, value: {node.value}")
@@ -118,6 +119,7 @@ def generate_actions(self, state: DialogueState) -> List[str]:
temperature=1
)
completions = [choice.message.content.strip() for choice in response.choices]
self.completion_tokens += response.usage.completion_tokens
logger.info(f"Received {len(completions)} completions from the model")
return completions

@@ -140,6 +142,7 @@ def apply_action(self, state: DialogueState, action: str) -> DialogueState:
)

next_query = response.choices[0].message.content
self.completion_tokens += response.usage.completion_tokens
logger.info(f"Generated next user query: {next_query}")
return DialogueState(state.system_prompt, new_history, next_query)

@@ -161,7 +164,7 @@ def evaluate_state(self, state: DialogueState) -> float:
n=1,
temperature=0.1
)

self.completion_tokens += response.usage.completion_tokens
try:
score = float(response.choices[0].message.content.strip())
score = max(0, min(score, 1)) # Ensure the score is between 0 and 1
@@ -181,4 +184,4 @@ def chat_with_mcts(system_prompt: str, initial_query: str, client, model: str, n
final_state = mcts.search(initial_state, num_simulations)
response = final_state.conversation_history[-1]['content'] if final_state.conversation_history else ""
logger.info(f"MCTS chat complete. Final response: {response[:100]}...")
return response
return response, mcts.completion_tokens
7 changes: 5 additions & 2 deletions optillm/moa.py
@@ -3,6 +3,7 @@
logger = logging.getLogger(__name__)

def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str) -> str:
moa_completion_tokens = 0
completions = []

response = client.chat.completions.create(
@@ -16,6 +17,7 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
temperature=1
)
completions = [choice.message.content for choice in response.choices]
moa_completion_tokens += response.usage.completion_tokens

critique_prompt = f"""
Original query: {initial_query}
@@ -45,6 +47,7 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
temperature=0.1
)
critiques = critique_response.choices[0].message.content
moa_completion_tokens += critique_response.usage.completion_tokens

final_prompt = f"""
Original query: {initial_query}
@@ -76,5 +79,5 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
n=1,
temperature=0.1
)

return final_response.choices[0].message.content
moa_completion_tokens += final_response.usage.completion_tokens
return final_response.choices[0].message.content, moa_completion_tokens
11 changes: 6 additions & 5 deletions optillm/plansearch.py
@@ -8,6 +8,7 @@ def __init__(self, system_prompt: str, client, model: str):
self.system_prompt = system_prompt
self.client = client
self.model = model
self.plansearch_completion_tokens = 0

def generate_observations(self, problem: str, num_observations: int = 3) -> List[str]:
prompt = f"""You are an expert Python programmer. You will be given a competitive programming question
@@ -28,7 +29,7 @@ def generate_observations(self, problem: str, num_observations: int = 3) -> List
{"role": "user", "content": prompt}
]
)

self.plansearch_completion_tokens += response.usage.completion_tokens
observations = response.choices[0].message.content.strip().split('\n')
return [obs.strip() for obs in observations if obs.strip()]

@@ -55,7 +56,7 @@ def generate_derived_observations(self, problem: str, observations: List[str], n
{"role": "user", "content": prompt}
]
)

self.plansearch_completion_tokens += response.usage.completion_tokens
new_observations = response.choices[0].message.content.strip().split('\n')
return [obs.strip() for obs in new_observations if obs.strip()]

@@ -80,7 +81,7 @@ def generate_solution(self, problem: str, observations: List[str]) -> str:
{"role": "user", "content": prompt}
]
)

self.plansearch_completion_tokens += response.usage.completion_tokens
return response.choices[0].message.content.strip()

def implement_solution(self, problem: str, solution: str) -> str:
@@ -105,7 +106,7 @@ def implement_solution(self, problem: str, solution: str) -> str:
{"role": "user", "content": prompt}
]
)

self.plansearch_completion_tokens += response.usage.completion_tokens
return response.choices[0].message.content.strip()

def solve(self, problem: str, num_initial_observations: int = 3, num_derived_observations: int = 2) -> Tuple[str, str]:
@@ -134,4 +135,4 @@ def solve_multiple(self, problem: str, n: int, num_initial_observations: int = 3

def plansearch(system_prompt: str, initial_query: str, client, model: str, n: int = 1) -> List[str]:
planner = PlanSearch(system_prompt, client, model)
return planner.solve_multiple(initial_query, n)
return planner.solve_multiple(initial_query, n), planner.plansearch_completion_tokens
10 changes: 9 additions & 1 deletion optillm/pvg.py
@@ -4,7 +4,10 @@

logger = logging.getLogger(__name__)

pvg_completion_tokens = 0

def generate_solutions(client, system_prompt: str, query: str, model: str, num_solutions: int, is_sneaky: bool = False, temperature: float = 0.7) -> List[str]:
global pvg_completion_tokens
role = "sneaky" if is_sneaky else "helpful"
logger.info(f"Generating {num_solutions} {role} solutions")

@@ -34,11 +37,13 @@ def generate_solutions(client, system_prompt: str, query: str, model: str, num_s
max_tokens=4096,
temperature=temperature,
)
pvg_completion_tokens += response.usage.completion_tokens
solutions = [choice.message.content for choice in response.choices]
logger.debug(f"Generated {role} solutions: {solutions}")
return solutions

def verify_solutions(client, system_prompt: str, initial_query: str, solutions: List[str], model: str) -> List[float]:
global pvg_completion_tokens
logger.info(f"Verifying {len(solutions)} solutions")
verify_prompt = f"""{system_prompt}
You are a verifier tasked with evaluating the correctness and clarity of solutions to the given problem.
@@ -75,6 +80,7 @@ def verify_solutions(client, system_prompt: str, initial_query: str, solutions:
max_tokens=1024,
temperature=0.2,
)
pvg_completion_tokens += response.usage.completion_tokens
rating = response.choices[0].message.content
logger.debug(f"Raw rating for solution {i+1}: {rating}")

@@ -130,6 +136,7 @@ def extract_answer(final_state: str) -> Tuple[str, float]:
return "", 0.0

def inference_time_pv_game(system_prompt: str, initial_query: str, client, model: str, num_rounds: int = 2, num_solutions: int = 3) -> str:
global pvg_completion_tokens
logger.info(f"Starting inference-time PV game with {num_rounds} rounds and {num_solutions} solutions per round")

best_solution = ""
@@ -178,9 +185,10 @@ def inference_time_pv_game(system_prompt: str, initial_query: str, client, model
max_tokens=1024,
temperature=0.5,
)
pvg_completion_tokens += response.usage.completion_tokens
initial_query = response.choices[0].message.content
logger.debug(f"Refined query: {initial_query}")

logger.info(f"Inference-time PV game completed. Best solution score: {best_score}")

return best_solution
return best_solution, pvg_completion_tokens
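Unlike mcts, leap, and plansearch, which keep their counters as instance attributes, pvg.py accumulates into a module-level global that lives for the whole process. A minimal sketch of that pattern in isolation, with a stub standing in for the API call (all names here are illustrative):

```python
# Illustrative module-level counter pattern, mirroring pvg_completion_tokens above.
pvg_completion_tokens_sketch = 0


def _fake_model_call(prompt: str):
    # Stand-in for client.chat.completions.create(...): returns text plus a
    # pretend completion-token count so the accumulation can be shown end to end.
    return f"answer to: {prompt}", 7


def pv_game_sketch(query: str, num_rounds: int = 2):
    global pvg_completion_tokens_sketch
    best_solution = ""
    for _ in range(num_rounds):
        solution, used = _fake_model_call(query)
        pvg_completion_tokens_sketch += used  # every call bumps the shared counter
        best_solution = solution
    return best_solution, pvg_completion_tokens_sketch


print(pv_game_sketch("What is 2 + 2?"))
```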
7 changes: 4 additions & 3 deletions optillm/reread.py
@@ -17,6 +17,7 @@ def re2_approach(system_prompt, initial_query, client, model, n=1):
str or list: The generated response(s) from the model.
"""
logger.info("Using RE2 approach for query processing")
re2_completion_tokens = 0

# Construct the RE2 prompt
re2_prompt = f"{initial_query}\nRead the question again: {initial_query}"
@@ -32,11 +33,11 @@ def re2_approach(system_prompt, initial_query, client, model, n=1):
messages=messages,
n=n
)

re2_completion_tokens += response.usage.completion_tokens
if n == 1:
return response.choices[0].message.content.strip()
return response.choices[0].message.content.strip(), re2_completion_tokens
else:
return [choice.message.content.strip() for choice in response.choices]
return [choice.message.content.strip() for choice in response.choices], re2_completion_tokens

except Exception as e:
logger.error(f"Error in RE2 approach: {str(e)}")
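On the consumer side, the new usage block reads like any OpenAI-style response. A small client sketch against the proxy, assuming it listens on localhost:8000 with an OpenAI-compatible /v1 route and selects the approach via a model-name prefix; both of those details are assumptions, not confirmed by this diff:

```python
# Illustrative client-side read of the completion_tokens the proxy now reports.
# The base_url and the "<approach>-<model>" naming are assumptions for the sketch.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy-key")

response = client.chat.completions.create(
    model="bon-gpt-4o-mini",  # assumed approach-prefix convention
    messages=[{"role": "user", "content": "What is 2 + 2?"}],
)

print(response.choices[0].message.content)
print("completion tokens used:", response.usage.completion_tokens)
```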