From 1da61110628b198eb8322a570c0c908a6dc19b74 Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Thu, 19 Sep 2024 04:28:05 -0700 Subject: [PATCH 01/13] initial impl --- cot_decoding.py | 162 +++++++++++++++++++++++++++++++++++++++++++++++ requirements.txt | 4 +- 2 files changed, 165 insertions(+), 1 deletion(-) create mode 100644 cot_decoding.py diff --git a/cot_decoding.py b/cot_decoding.py new file mode 100644 index 00000000..50d83472 --- /dev/null +++ b/cot_decoding.py @@ -0,0 +1,162 @@ +import torch +from transformers import PreTrainedModel, PreTrainedTokenizer +from typing import List, Tuple, Dict, Optional +import numpy as np + +def get_device(): + if torch.cuda.is_available(): + return torch.device("cuda") + else: + return torch.device("cpu") + +def calculate_confidence(logits: List[torch.Tensor], answer_ids: torch.Tensor) -> float: + """ + Calculate the confidence score (Δ) as specified in the paper. + + Args: + logits: List of logits for each decoding step + answer_ids: Tensor of token ids for the answer + + Returns: + Confidence score (Δ) + """ + confidence_sum = 0.0 + valid_tokens = 0 + for t, token_id in enumerate(answer_ids): + if t >= len(logits): + break + token_logits = logits[t] + probs = torch.softmax(token_logits, dim=-1) + if probs.size(0) > 1: + top_2_probs, _ = torch.topk(probs, min(2, probs.size(0))) + if top_2_probs.size(0) > 1: + confidence_sum += (top_2_probs[0] - top_2_probs[1]).item() + else: + confidence_sum += 1.0 # Max confidence if there's only one token + else: + confidence_sum += 1.0 # Max confidence if there's only one token + valid_tokens += 1 + + return confidence_sum / valid_tokens if valid_tokens > 0 else 0.0 + +def aggregate_paths_based_on_scores(paths: List[Tuple[str, float]]) -> Tuple[str, float]: + """Aggregate multiple paths based on their confidence scores.""" + answer_scores = {} + for answer, delta in paths: + answer_scores[answer] = answer_scores.get(answer, 0) + delta + best_answer = max(answer_scores, key=answer_scores.get) + return best_answer, answer_scores[best_answer] + +def cot_decode( + model: PreTrainedModel, + tokenizer: PreTrainedTokenizer, + messages: List[Dict[str, str]], + k: int = 10, + num_beams: int = 1, + max_length: int = 512, + temperature: float = 1.0, + top_p: float = 1.0, + repetition_penalty: float = 1.0, + length_penalty: float = 1.0, + no_repeat_ngram_size: int = 0, + early_stopping: bool = False, + aggregate_paths: bool = False, +) -> Tuple[str, float]: + """ + Implement CoT-decoding for a given chat input. + + Args: + model: The Hugging Face transformer model. + tokenizer: The associated tokenizer. + messages: List of chat messages in the format [{"role": "user", "content": "..."}] + k: The number of alternative tokens to consider at the first step. + num_beams: Number of beams for beam search. + max_length: Maximum length of generated sequence. + temperature: Sampling temperature. + top_p: Nucleus sampling probability. + repetition_penalty: Repetition penalty factor. + length_penalty: Length penalty factor. + no_repeat_ngram_size: Size of n-grams to avoid repeating. + early_stopping: Whether to stop generation when all beams are finished. + aggregate_paths: Whether to aggregate multiple paths. + + Returns: + A tuple containing the best path (or aggregated result) and its confidence score. + """ + device = get_device() + model.to(device) + + # Use the chat template to format the input + if tokenizer.chat_template: + input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + else: + # Fallback for tokenizers without chat templates + input_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) + input_text += "\nassistant:" + + input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device) + attention_mask = torch.ones_like(input_ids).to(device) + + # Set pad_token_id if it's not set + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + + # Get the top-k tokens for the first decoding step + with torch.no_grad(): + outputs = model(input_ids, attention_mask=attention_mask) + first_token_logits = outputs.logits[0, -1, :] + top_k_logits, top_k_indices = torch.topk(first_token_logits, k) + + paths = [] + for idx in top_k_indices: + # Generate sequence starting with the selected token + start_ids = torch.cat([input_ids, idx.unsqueeze(0).unsqueeze(0)], dim=-1) + start_mask = torch.cat([attention_mask, torch.ones((1, 1), dtype=torch.long, device=device)], dim=-1) + + output = model.generate( + start_ids, + attention_mask=start_mask, + max_length=max_length, + num_beams=num_beams, + temperature=temperature, + top_p=top_p, + repetition_penalty=repetition_penalty, + length_penalty=length_penalty, + no_repeat_ngram_size=no_repeat_ngram_size, + early_stopping=early_stopping, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + output_scores=True, + return_dict_in_generate=True, + ) + + generated_sequence = output.sequences[0] + answer_ids = generated_sequence[len(input_ids[0]):] + answer_text = tokenizer.decode(answer_ids, skip_special_tokens=True) + + # Calculate confidence score (Δ) + confidence = calculate_confidence(output.scores, answer_ids) + paths.append((answer_text, confidence)) + + if aggregate_paths: + return aggregate_paths_based_on_scores(paths) + else: + return max(paths, key=lambda x: x[1]) + +# Example usage +if __name__ == "__main__": + from transformers import AutoModelForCausalLM, AutoTokenizer + + # Change this to a chat-based model that supports chat templates + model_name = "Qwen/Qwen2.5-0.5B-Instruct" + model = AutoModelForCausalLM.from_pretrained(model_name) + tokenizer = AutoTokenizer.from_pretrained(model_name) + + messages = [ + {"role": "user", "content": "I have 3 apples, my dad has 2 more apples than me, how many apples do we have in total?"} + ] + + result, confidence = cot_decode(model, tokenizer, messages, aggregate_paths=True) + print(f"Result: {result}") + print(f"Confidence: {confidence}") + print(f"Using device: {get_device()}") \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 2bbd104c..ee68a52d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,6 @@ networkx openai z3-solver aiohttp -flask \ No newline at end of file +flask +torch +transformers \ No newline at end of file From 051db8679f356d955ab64f497a91498c5f3cfc1d Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Thu, 19 Sep 2024 07:06:39 -0700 Subject: [PATCH 02/13] Update optillm.py --- optillm.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/optillm.py b/optillm.py index 1044b7aa..22d1beb5 100644 --- a/optillm.py +++ b/optillm.py @@ -4,6 +4,7 @@ import secrets from flask import Flask, request, jsonify from openai import AzureOpenAI, OpenAI +from transformers import AutoModelForCausalLM, AutoTokenizer # Import approach modules from mcts import chat_with_mcts @@ -18,6 +19,7 @@ from plansearch import plansearch from leap import leap from agent import agent_approach +from cot_decoding import cot_decode # Setup logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') @@ -57,7 +59,7 @@ # List of known approaches known_approaches = ["mcts", "bon", "moa", "rto", "z3", "self_consistency", "pvg", "rstar", - "cot_reflection", "plansearch", "leap", "agent"] + "cot_reflection", "plansearch", "leap", "agent", "cot_decoding"] # Optional API key configuration to secure the proxy @app.before_request @@ -139,6 +141,10 @@ def proxy(): final_response = leap(system_prompt, initial_query, client, model) elif approach == 'agent': final_response = agent_approach(system_prompt, initial_query, client, model, max_attempts=3) + elif approach == 'cot_decoding': + local_model = AutoModelForCausalLM.from_pretrained(model) + tokenizer = AutoTokenizer.from_pretrained(local_model) + final_response, _ = cot_decode(local_model, tokenizer, messages, aggregate_paths=True) else: raise ValueError(f"Unknown approach: {approach}") except Exception as e: From 3eb177686a93606a47d79f3b15ccf5a17e9e5d36 Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Thu, 19 Sep 2024 10:02:29 -0700 Subject: [PATCH 03/13] Update optillm.py --- optillm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/optillm.py b/optillm.py index 22d1beb5..7d733751 100644 --- a/optillm.py +++ b/optillm.py @@ -143,8 +143,8 @@ def proxy(): final_response = agent_approach(system_prompt, initial_query, client, model, max_attempts=3) elif approach == 'cot_decoding': local_model = AutoModelForCausalLM.from_pretrained(model) - tokenizer = AutoTokenizer.from_pretrained(local_model) - final_response, _ = cot_decode(local_model, tokenizer, messages, aggregate_paths=True) + tokenizer = AutoTokenizer.from_pretrained(model) + final_response, _ = cot_decode(local_model, tokenizer, messages, max_length=2048, aggregate_paths=True) else: raise ValueError(f"Unknown approach: {approach}") except Exception as e: From bc0e982bf137ba00003ca4e62dcdc59531189ac1 Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Fri, 20 Sep 2024 16:17:48 -0700 Subject: [PATCH 04/13] move file --- cot_decoding.py => optillm/cot_decoding.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename cot_decoding.py => optillm/cot_decoding.py (100%) diff --git a/cot_decoding.py b/optillm/cot_decoding.py similarity index 100% rename from cot_decoding.py rename to optillm/cot_decoding.py From bc56731e4dc2248fdf300717a0555582322a4bdf Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Fri, 20 Sep 2024 16:22:30 -0700 Subject: [PATCH 05/13] Update setup.py --- setup.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/setup.py b/setup.py index e72f3a90..07b0edbb 100644 --- a/setup.py +++ b/setup.py @@ -11,6 +11,8 @@ "z3-solver", "aiohttp", "flask", + "torch", + "transformers" "azure-identity", ], author="codelion", From 211d3ace2c3c59679dd9ff4abae45a8aaa345d22 Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Fri, 20 Sep 2024 16:22:54 -0700 Subject: [PATCH 06/13] Update setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 07b0edbb..61579c36 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ "aiohttp", "flask", "torch", - "transformers" + "transformers", "azure-identity", ], author="codelion", From ffff9263b0ba211487370c607b48214752d8b0ad Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Fri, 20 Sep 2024 16:37:29 -0700 Subject: [PATCH 07/13] Update cot_decoding.py --- optillm/cot_decoding.py | 19 +------------------ 1 file changed, 1 insertion(+), 18 deletions(-) diff --git a/optillm/cot_decoding.py b/optillm/cot_decoding.py index 50d83472..9e5994a4 100644 --- a/optillm/cot_decoding.py +++ b/optillm/cot_decoding.py @@ -142,21 +142,4 @@ def cot_decode( return aggregate_paths_based_on_scores(paths) else: return max(paths, key=lambda x: x[1]) - -# Example usage -if __name__ == "__main__": - from transformers import AutoModelForCausalLM, AutoTokenizer - - # Change this to a chat-based model that supports chat templates - model_name = "Qwen/Qwen2.5-0.5B-Instruct" - model = AutoModelForCausalLM.from_pretrained(model_name) - tokenizer = AutoTokenizer.from_pretrained(model_name) - - messages = [ - {"role": "user", "content": "I have 3 apples, my dad has 2 more apples than me, how many apples do we have in total?"} - ] - - result, confidence = cot_decode(model, tokenizer, messages, aggregate_paths=True) - print(f"Result: {result}") - print(f"Confidence: {confidence}") - print(f"Using device: {get_device()}") \ No newline at end of file + \ No newline at end of file From 0fc305d65cee24f350f8221e79b936e2d0704025 Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Fri, 20 Sep 2024 17:25:50 -0700 Subject: [PATCH 08/13] Update optillm.py --- optillm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/optillm.py b/optillm.py index 91a684cf..020b4961 100644 --- a/optillm.py +++ b/optillm.py @@ -16,6 +16,7 @@ from optillm.z3_solver import Z3SolverSystem from optillm.rstar import RStar from optillm.cot_reflection import cot_reflection +from optillm.cot_decoding import cot_decode from optillm.plansearch import plansearch from optillm.leap import leap from optillm.agent import agent_approach From bb9c56dbfdc610000fecb87984f02e0677e12089 Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Sat, 21 Sep 2024 03:22:19 -0700 Subject: [PATCH 09/13] Update cot_decoding.py --- optillm/cot_decoding.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/optillm/cot_decoding.py b/optillm/cot_decoding.py index 9e5994a4..2f5b173c 100644 --- a/optillm/cot_decoding.py +++ b/optillm/cot_decoding.py @@ -53,7 +53,7 @@ def cot_decode( messages: List[Dict[str, str]], k: int = 10, num_beams: int = 1, - max_length: int = 512, + max_new_tokens: int = 512, temperature: float = 1.0, top_p: float = 1.0, repetition_penalty: float = 1.0, @@ -71,7 +71,7 @@ def cot_decode( messages: List of chat messages in the format [{"role": "user", "content": "..."}] k: The number of alternative tokens to consider at the first step. num_beams: Number of beams for beam search. - max_length: Maximum length of generated sequence. + max_new_tokens: Maximum number of new tokens to generate. temperature: Sampling temperature. top_p: Nucleus sampling probability. repetition_penalty: Repetition penalty factor. @@ -116,7 +116,7 @@ def cot_decode( output = model.generate( start_ids, attention_mask=start_mask, - max_length=max_length, + max_new_tokens=max_new_tokens, num_beams=num_beams, temperature=temperature, top_p=top_p, From f0c964022629dde0e9411de34b8d3c95342a31a3 Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Mon, 23 Sep 2024 16:38:59 -0700 Subject: [PATCH 10/13] Update optillm.py --- optillm.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/optillm.py b/optillm.py index 020b4961..c7482037 100644 --- a/optillm.py +++ b/optillm.py @@ -4,7 +4,6 @@ import secrets from flask import Flask, request, jsonify from openai import AzureOpenAI, OpenAI -from transformers import AutoModelForCausalLM, AutoTokenizer # Import approach modules from optillm.mcts import chat_with_mcts @@ -16,7 +15,6 @@ from optillm.z3_solver import Z3SolverSystem from optillm.rstar import RStar from optillm.cot_reflection import cot_reflection -from optillm.cot_decoding import cot_decode from optillm.plansearch import plansearch from optillm.leap import leap from optillm.agent import agent_approach @@ -72,7 +70,7 @@ # List of known approaches known_approaches = ["mcts", "bon", "moa", "rto", "z3", "self_consistency", "pvg", "rstar", - "cot_reflection", "plansearch", "leap", "agent", "cot_decoding"] + "cot_reflection", "plansearch", "leap", "agent"] # Optional API key configuration to secure the proxy @app.before_request @@ -154,10 +152,6 @@ def proxy(): final_response = leap(system_prompt, initial_query, client, model) elif approach == 'agent': final_response = agent_approach(system_prompt, initial_query, client, model, max_attempts=3) - elif approach == 'cot_decoding': - local_model = AutoModelForCausalLM.from_pretrained(model) - tokenizer = AutoTokenizer.from_pretrained(model) - final_response, _ = cot_decode(local_model, tokenizer, messages, max_length=2048, aggregate_paths=True) else: raise ValueError(f"Unknown approach: {approach}") except Exception as e: From f4cd64b57c3746d3b8ccab79c7e6ad7d33a730a8 Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Mon, 23 Sep 2024 16:41:21 -0700 Subject: [PATCH 11/13] Update README.md --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 1f97e2e4..53d65562 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,7 @@ optillm is an OpenAI API compatible optimizing inference proxy which implements several state-of-the-art techniques that can improve the accuracy and performance of LLMs. The current focus is on implementing techniques that improve reasoning over coding, logical and mathematical queries. It is possible to beat the frontier models using these techniques across diverse tasks by doing additional compute at inference time. [![Open in Spaces](https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-sm.svg)](https://huggingface.co/spaces/codelion/optillm) +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1SpuUb8d9xAoTh32M-9wJsB50AOH54EaH?usp=sharing) ## Patchwork with optillm From d14c8df2de79b70ed9290ff72872d1f1d9c55dcd Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Mon, 23 Sep 2024 16:42:34 -0700 Subject: [PATCH 12/13] Update README.md --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 53d65562..564bc46f 100644 --- a/README.md +++ b/README.md @@ -212,6 +212,7 @@ Authorization: Bearer your_secret_api_key ## References +- [Chain-of-Thought Reasoning Without Prompting](https://arxiv.org/abs/2402.10200) - [Re-Reading Improves Reasoning in Large Language Models](https://arxiv.org/abs/2309.06275) - [In-Context Principle Learning from Mistakes](https://arxiv.org/abs/2402.05403) - [Planning In Natural Language Improves LLM Search For Code Generation](https://arxiv.org/abs/2409.03733) From 390d787cb2c43f6e7a33afbf7cc97505ee560628 Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Mon, 23 Sep 2024 16:45:58 -0700 Subject: [PATCH 13/13] Update README.md --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 564bc46f..ae868cd6 100644 --- a/README.md +++ b/README.md @@ -143,6 +143,7 @@ or your own code where you want to use the results from optillm. You can use it | PlanSearch | `plansearch` | Implements a search algorithm over candidate plans for solving a problem in natural language | | LEAP | `leap` | Learns task-specific principles from few shot examples | | ReRead | `re2` | Implements rereading to improve reasoning by processing queries twice | +| CoT Decoding | N/A for proxy | Implements chain-of-thought decoding to elicit reasoning without explicit prompting | ## Available Parameters