From 032aab777aa380792055ab9f1cb442a7298e664f Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Mon, 4 Nov 2024 19:48:02 +0800 Subject: [PATCH 1/4] update training scripts --- scripts/gen_optillm_ground_truth_dataset.py | 130 ++++++++++++++++++++ scripts/train_optillm_classifier.py | 54 +++++++- 2 files changed, 179 insertions(+), 5 deletions(-) create mode 100644 scripts/gen_optillm_ground_truth_dataset.py diff --git a/scripts/gen_optillm_ground_truth_dataset.py b/scripts/gen_optillm_ground_truth_dataset.py new file mode 100644 index 00000000..8227fb32 --- /dev/null +++ b/scripts/gen_optillm_ground_truth_dataset.py @@ -0,0 +1,130 @@ +import os +import json +import argparse +import asyncio +from tqdm import tqdm +from datasets import load_dataset +from openai import AsyncOpenAI +from typing import List, Dict, Any, Tuple +import random + +# OptILM approaches remain the same as in original script +APPROACHES = ["none", "mcts", "bon", "moa", "rto", "z3", "self_consistency", "pvg", "rstar", "cot_reflection", "plansearch", "leap", "re2"] + +# Dataset configurations +DATASET_CONFIGS = [ + ("MixEval", "free_form"), + ("MixEval", "multiple_choice"), + ("MixEval_Hard", "free_form"), + ("MixEval_Hard", "multiple_choice") +] + +def construct_prompt(sample: Dict[str, Any], split_type: str) -> str: + """Construct prompt based on split type.""" + context = sample.get("context", "") + prompt = sample["prompt"] + + if split_type == "multiple_choice": + options = sample["options"] + options_text = "\nOptions:\n" + "\n".join([f"{i+1}. {opt}" for i, opt in enumerate(options)]) + return f"Context: {context}\n\nQuestion: {prompt}{options_text}\n\nProvide the correct answer from the options above." + else: + return f"Context: {context}\n\nQuestion: {prompt}\n\nProvide your answer." + +def is_correct_response(response: str, targets: List[str]) -> bool: + """Check if response matches any of the target answers.""" + response = response.strip().lower() + return any(target.strip().lower() == response for target in targets) + +async def generate_response(prompt: str, approach: str) -> Dict[str, Any]: + """Generate a response using the specified approach.""" + if approach == "none": + client = AsyncOpenAI() + response = await client.chat.completions.create( + model="gpt-4o-mini", + messages=[{"role": "user", "content": prompt}], + ) + return { + "content": response.choices[0].message.content, + "tokens": response.usage.completion_tokens, + } + else: + client = AsyncOpenAI(api_key="none", base_url="http://localhost:8000/v1") + response = await client.chat.completions.create( + model=f"{approach}-gpt-4o-mini", + messages=[{"role": "user", "content": prompt}], + ) + return { + "content": response.choices[0].message.content, + "tokens": response.usage.completion_tokens, + } + +def rank_responses(responses: List[Dict[str, Any]], targets: List[str]) -> List[int]: + """Rank responses based on correctness and token efficiency.""" + # Create tuples of (index, is_correct, tokens) for sorting + ranked_data = [] + for i, response in enumerate(responses): + is_correct = is_correct_response(response["content"], targets) + ranked_data.append((i, is_correct, response["tokens"])) + + # Sort by correctness (True first) and then by tokens (ascending) + ranked_data.sort(key=lambda x: (-int(x[1]), x[2])) + + # Extract indices for final ranking + return [idx for idx, _, _ in ranked_data] + +async def process_sample(sample: Dict[str, Any], split_type: str) -> Dict[str, Any]: + """Process a single sample from the dataset.""" + prompt = construct_prompt(sample, split_type) + results = [] + + # Generate responses for each approach + for approach in APPROACHES: + response = await generate_response(prompt, approach) + results.append({"approach": approach, **response}) + + # Rank the responses based on correctness and token efficiency + rankings = rank_responses(results, sample["target"]) + + # Add rankings to results + for rank, idx in enumerate(rankings): + results[idx]["rank"] = rank + + return { + "prompt": prompt, + "results": results, + } + +async def generate_dataset(num_samples: int, output_file: str): + """Generate the dataset and save it to a JSONL file.""" + with open(output_file, "w") as f: + for config, split_type in DATASET_CONFIGS: + print(f"Processing {config} - {split_type}") + dataset = load_dataset("MixEval/MixEval", config, split=split_type) + + # Calculate samples per configuration + samples_per_config = max(1, num_samples // len(DATASET_CONFIGS)) + + for sample in tqdm(dataset.select(range(samples_per_config)), + total=samples_per_config, + desc=f"{config}-{split_type}"): + try: + result = await process_sample(sample, split_type) + f.write(json.dumps(result) + "\n") + except Exception as e: + print(f"Error processing sample: {str(e)}") + +def main(): + parser = argparse.ArgumentParser(description="Generate OptILM Ground Truth dataset") + parser.add_argument("--num_samples", type=int, default=100, + help="Total number of samples to process (divided among configurations)") + parser.add_argument("--output_file", type=str, + default="optillm_ground_truth_dataset.jsonl", + help="Output file path") + args = parser.parse_args() + + asyncio.run(generate_dataset(args.num_samples, args.output_file)) + print(f"Dataset generated and saved to {args.output_file}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/train_optillm_classifier.py b/scripts/train_optillm_classifier.py index fac25832..904d5f82 100644 --- a/scripts/train_optillm_classifier.py +++ b/scripts/train_optillm_classifier.py @@ -57,7 +57,7 @@ def __getitem__(self, idx): } def load_and_preprocess_data(tokenizer): - dataset = load_dataset('json', data_files='optillm_dataset.jsonl') + dataset = load_dataset('json', data_files='optillm_combined_dataset.jsonl') data_items = [] @@ -290,11 +290,54 @@ def main(args): best_model.eval() test_prompts = [ + # Linear Programming (likely MCTS or Z3) "Maximize x + y subject to: x + 2y <= 10, x >= 0, y >= 0", + # Graph Theory (likely MCTS or RTO) "Find the shortest path between nodes A and B in the given graph", + # Recursive Problem (likely MOA or COT) "Solve the Tower of Hanoi problem with 4 disks", + # Number Theory (likely NONE or Z3) "Determine if the given number is prime", - "Find all possible combinations of coins that sum up to $1" + # Combinatorics (likely MCTS or BON) + "Find all possible combinations of coins that sum up to $1", + # Symbolic Mathematics (likely Z3 or LEAP) + "Solve the equation: 2x^3 - 5x^2 + 3x - 7 = 0", + # Natural Language Processing (likely PVG or SELF_CONSISTENCY) + "Summarize the main points of the given article in three sentences", + # Computer Vision (likely RSTAR or PVG) + "Describe the contents of the image, including any text present", + # Game Theory (likely MCTS or BON) + "Find the Nash equilibrium for the prisoner's dilemma game", + # Constraint Satisfaction (likely Z3 or PLANSEARCH) + "Solve the Sudoku puzzle given the following initial configuration", + # Optimization (likely MCTS or RSTAR) + "Find the optimal route for a salesperson visiting 10 cities", + # Logical Reasoning (likely COT_REFLECTION or SELF_CONSISTENCY) + "If all A are B, and some B are C, what can we conclude about A and C?", + # Time Series Analysis (likely RSTAR or PVG) + "Predict the stock price for the next week given the past year's data", + # Robotics (likely MCTS or RTO) + "Plan a path for a robot to navigate through a room with obstacles", + # Natural Language Understanding (likely PVG or LEAP) + "Identify the sentiment and main topics in the following customer review", + # Theorem Proving (likely Z3 or COT_REFLECTION) + "Prove that the square root of 2 is irrational", + # Reinforcement Learning (likely MCTS or RSTAR) + "Design a policy for an agent to maximize its score in a given game environment", + # Information Retrieval (likely PVG or SELF_CONSISTENCY) + "Find the most relevant documents in the corpus for the given query", + # Cryptography (likely Z3 or LEAP) + "Decrypt the following message encrypted with a simple substitution cipher", + # Quantum Computing (likely NONE or Z3) + "Simulate a quantum circuit with 3 qubits and measure the output", + # Computer Graphics (likely RSTAR or PVG) + "Generate a 3D model of a house based on the given floor plan", + # Bioinformatics (likely Z3 or LEAP) + "Find potential binding sites for a given protein sequence in a DNA strand", + # Automated Reasoning (likely COT_REFLECTION or Z3) + "Given a set of logical statements, determine if the conclusion follows", + # Natural Language Generation (likely PVG or SELF_CONSISTENCY) + "Write a short story in the style of Edgar Allan Poe about a haunted lighthouse" ] effort_levels = [0.0, 0.2, 0.5, 0.8, 1.0] @@ -310,8 +353,8 @@ def main(args): parser = argparse.ArgumentParser(description="Train OptILM classifier") parser.add_argument("--model_name", type=str, default="google-bert/bert-large-uncased", help="Pretrained model name") parser.add_argument("--batch_size", type=int, default=4, help="Batch size for training") - parser.add_argument("--learning_rate", type=float, default=1e-6, help="Learning rate") - parser.add_argument("--num_epochs", type=int, default=10, help="Maximum number of training epochs") + parser.add_argument("--learning_rate", type=float, default=5e-7, help="Learning rate") + parser.add_argument("--num_epochs", type=int, default=20, help="Maximum number of training epochs") parser.add_argument("--push_to_hub", action="store_true", help="Push model to Hugging Face Hub") parser.add_argument("--hub_model_id", type=str, help="Model ID for Hugging Face Hub") parser.add_argument("--k_folds", type=int, default=5, help="Number of folds for cross-validation") @@ -319,4 +362,5 @@ def main(args): parser.add_argument("--clip_value", type=float, default=1.0, help="Gradient clipping value") args = parser.parse_args() - main(args) \ No newline at end of file + main(args) + \ No newline at end of file From a8e1b2033f09d652d76fb10dfb2c6866b9b33164 Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Tue, 5 Nov 2024 21:58:13 +0800 Subject: [PATCH 2/4] Update optillm.py --- optillm.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/optillm.py b/optillm.py index 12b196f1..64ac46f9 100644 --- a/optillm.py +++ b/optillm.py @@ -12,9 +12,6 @@ import re from concurrent.futures import ThreadPoolExecutor -# Import the LiteLLM wrapper -from optillm.litellm_wrapper import LiteLLMWrapper - # Import approach modules from optillm.mcts import chat_with_mcts from optillm.bon import best_of_n_sampling @@ -74,6 +71,8 @@ def get_config(): azure_ad_token_provider=token_provider ) else: + # Import the LiteLLM wrapper + from optillm.litellm_wrapper import LiteLLMWrapper default_client = LiteLLMWrapper() return default_client, API_KEY From 18eda14740b455abee84f84cd321f2f4c6086f19 Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Thu, 7 Nov 2024 10:04:39 +0800 Subject: [PATCH 3/4] Update README.md --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index c758e595..18725d59 100644 --- a/README.md +++ b/README.md @@ -180,6 +180,7 @@ or your own code where you want to use the results from optillm. You can use it | Plugin | Slug | Description | | ----------------------- | ------------------ | ---------------------------------------------------------------------------------------------- | +| Router | `router` | Uses the [optillm-bert-uncased](https://huggingface.co/codelion/optillm-bert-uncased) model to route requests to different approaches based on the user prompt | | Memory | `memory` | Implements a short term memory layer, enables you to use unbounded context length with any LLM | | Privacy | `privacy` | Anonymize PII data in request and deanonymize it back to original value in response | | Read URLs | `readurls` | Reads all URLs found in the request, fetches the content at the URL and adds it to the context | From fb8bd63da4bc6a26d20fba9e4df371d60fa82cf7 Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Thu, 7 Nov 2024 10:14:13 +0800 Subject: [PATCH 4/4] Update setup.py bump version for new release --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index c90ea4bc..23b3a365 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="optillm", - version="0.0.8", + version="0.0.9", packages=find_packages(), py_modules=['optillm'], package_data={