diff --git a/README.md b/README.md
index 1f97e2e4..ae868cd6 100644
--- a/README.md
+++ b/README.md
@@ -3,6 +3,7 @@
 optillm is an OpenAI API compatible optimizing inference proxy which implements several state-of-the-art techniques that can improve the accuracy and performance of LLMs. The current focus is on implementing techniques that improve reasoning over coding, logical and mathematical queries. It is possible to beat the frontier models using these techniques across diverse tasks by doing additional compute at inference time.
 
 [![Open in Spaces](https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-sm.svg)](https://huggingface.co/spaces/codelion/optillm)
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1SpuUb8d9xAoTh32M-9wJsB50AOH54EaH?usp=sharing)
 
 ## Patchwork with optillm
 
@@ -142,6 +143,7 @@ or your own code where you want to use the results from optillm. You can use it
 | PlanSearch | `plansearch` | Implements a search algorithm over candidate plans for solving a problem in natural language |
 | LEAP | `leap` | Learns task-specific principles from few shot examples |
 | ReRead | `re2` | Implements rereading to improve reasoning by processing queries twice |
+| CoT Decoding | N/A for proxy | Implements chain-of-thought decoding to elicit reasoning without explicit prompting |
 
 ## Available Parameters
 
@@ -211,6 +213,7 @@ Authorization: Bearer your_secret_api_key
 
 ## References
 
+- [Chain-of-Thought Reasoning Without Prompting](https://arxiv.org/abs/2402.10200)
 - [Re-Reading Improves Reasoning in Large Language Models](https://arxiv.org/abs/2309.06275)
 - [In-Context Principle Learning from Mistakes](https://arxiv.org/abs/2402.05403)
 - [Planning In Natural Language Improves LLM Search For Code Generation](https://arxiv.org/abs/2409.03733)
diff --git a/optillm/cot_decoding.py b/optillm/cot_decoding.py
new file mode 100644
index 00000000..2f5b173c
--- /dev/null
+++ b/optillm/cot_decoding.py
@@ -0,0 +1,145 @@
+import torch
+from transformers import PreTrainedModel, PreTrainedTokenizer
+from typing import List, Tuple, Dict, Optional
+import numpy as np
+
+def get_device():
+    if torch.cuda.is_available():
+        return torch.device("cuda")
+    else:
+        return torch.device("cpu")
+
+def calculate_confidence(logits: List[torch.Tensor], answer_ids: torch.Tensor) -> float:
+    """
+    Calculate the confidence score (Δ) as specified in the paper.
+
+    Args:
+        logits: List of logits for each decoding step
+        answer_ids: Tensor of token ids for the answer
+
+    Returns:
+        Confidence score (Δ)
+    """
+    confidence_sum = 0.0
+    valid_tokens = 0
+    for t, token_id in enumerate(answer_ids):
+        if t >= len(logits):
+            break
+        token_logits = logits[t]
+        probs = torch.softmax(token_logits, dim=-1)
+        if probs.size(0) > 1:
+            top_2_probs, _ = torch.topk(probs, min(2, probs.size(0)))
+            if top_2_probs.size(0) > 1:
+                confidence_sum += (top_2_probs[0] - top_2_probs[1]).item()
+            else:
+                confidence_sum += 1.0  # Max confidence if there's only one token
+        else:
+            confidence_sum += 1.0  # Max confidence if there's only one token
+        valid_tokens += 1
+
+    return confidence_sum / valid_tokens if valid_tokens > 0 else 0.0
+
+def aggregate_paths_based_on_scores(paths: List[Tuple[str, float]]) -> Tuple[str, float]:
+    """Aggregate multiple paths based on their confidence scores."""
+    answer_scores = {}
+    for answer, delta in paths:
+        answer_scores[answer] = answer_scores.get(answer, 0) + delta
+    best_answer = max(answer_scores, key=answer_scores.get)
+    return best_answer, answer_scores[best_answer]
+
+def cot_decode(
+    model: PreTrainedModel,
+    tokenizer: PreTrainedTokenizer,
+    messages: List[Dict[str, str]],
+    k: int = 10,
+    num_beams: int = 1,
+    max_new_tokens: int = 512,
+    temperature: float = 1.0,
+    top_p: float = 1.0,
+    repetition_penalty: float = 1.0,
+    length_penalty: float = 1.0,
+    no_repeat_ngram_size: int = 0,
+    early_stopping: bool = False,
+    aggregate_paths: bool = False,
+) -> Tuple[str, float]:
+    """
+    Implement CoT-decoding for a given chat input.
+
+    Args:
+        model: The Hugging Face transformer model.
+        tokenizer: The associated tokenizer.
+        messages: List of chat messages in the format [{"role": "user", "content": "..."}]
+        k: The number of alternative tokens to consider at the first step.
+        num_beams: Number of beams for beam search.
+        max_new_tokens: Maximum number of new tokens to generate.
+        temperature: Sampling temperature.
+        top_p: Nucleus sampling probability.
+        repetition_penalty: Repetition penalty factor.
+        length_penalty: Length penalty factor.
+        no_repeat_ngram_size: Size of n-grams to avoid repeating.
+        early_stopping: Whether to stop generation when all beams are finished.
+        aggregate_paths: Whether to aggregate multiple paths.
+
+    Returns:
+        A tuple containing the best path (or aggregated result) and its confidence score.
+    """
+    device = get_device()
+    model.to(device)
+
+    # Use the chat template to format the input
+    if tokenizer.chat_template:
+        input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    else:
+        # Fallback for tokenizers without chat templates
+        input_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
+        input_text += "\nassistant:"
+
+    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
+    attention_mask = torch.ones_like(input_ids).to(device)
+
+    # Set pad_token_id if it's not set
+    if tokenizer.pad_token_id is None:
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+
+    # Get the top-k tokens for the first decoding step
+    with torch.no_grad():
+        outputs = model(input_ids, attention_mask=attention_mask)
+        first_token_logits = outputs.logits[0, -1, :]
+        top_k_logits, top_k_indices = torch.topk(first_token_logits, k)
+
+    paths = []
+    for idx in top_k_indices:
+        # Generate sequence starting with the selected token
+        start_ids = torch.cat([input_ids, idx.unsqueeze(0).unsqueeze(0)], dim=-1)
+        start_mask = torch.cat([attention_mask, torch.ones((1, 1), dtype=torch.long, device=device)], dim=-1)
+
+        output = model.generate(
+            start_ids,
+            attention_mask=start_mask,
+            max_new_tokens=max_new_tokens,
+            num_beams=num_beams,
+            temperature=temperature,
+            top_p=top_p,
+            repetition_penalty=repetition_penalty,
+            length_penalty=length_penalty,
+            no_repeat_ngram_size=no_repeat_ngram_size,
+            early_stopping=early_stopping,
+            pad_token_id=tokenizer.pad_token_id,
+            eos_token_id=tokenizer.eos_token_id,
+            output_scores=True,
+            return_dict_in_generate=True,
+        )
+
+        generated_sequence = output.sequences[0]
+        answer_ids = generated_sequence[len(input_ids[0]):]
+        answer_text = tokenizer.decode(answer_ids, skip_special_tokens=True)
+
+        # Calculate confidence score (Δ)
+        confidence = calculate_confidence(output.scores, answer_ids)
+        paths.append((answer_text, confidence))
+
+    if aggregate_paths:
+        return aggregate_paths_based_on_scores(paths)
+    else:
+        return max(paths, key=lambda x: x[1])
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 834a0567..f771d90d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,4 +4,6 @@ openai
 z3-solver
 aiohttp
 flask
-azure.identity
\ No newline at end of file
+torch
+transformers
+azure.identity
diff --git a/setup.py b/setup.py
index e72f3a90..61579c36 100644
--- a/setup.py
+++ b/setup.py
@@ -11,6 +11,8 @@
         "z3-solver",
         "aiohttp",
         "flask",
+        "torch",
+        "transformers",
         "azure-identity",
     ],
     author="codelion",