
Commit d99727a

Merge pull request #28 from codelion/feat-cot-decoding
Feat cot decoding
2 parents 22c9fc4 + 390d787 commit d99727a

File tree

4 files changed: +153 -1 lines changed


README.md

Lines changed: 3 additions & 0 deletions
@@ -3,6 +3,7 @@
 optillm is an OpenAI API compatible optimizing inference proxy which implements several state-of-the-art techniques that can improve the accuracy and performance of LLMs. The current focus is on implementing techniques that improve reasoning over coding, logical and mathematical queries. It is possible to beat the frontier models using these techniques across diverse tasks by doing additional compute at inference time.
 
 [![Open in Spaces](https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-sm.svg)](https://huggingface.co/spaces/codelion/optillm)
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1SpuUb8d9xAoTh32M-9wJsB50AOH54EaH?usp=sharing)
 
 ## Patchwork with optillm
 
@@ -142,6 +143,7 @@ or your own code where you want to use the results from optillm. You can use it
 | PlanSearch | `plansearch` | Implements a search algorithm over candidate plans for solving a problem in natural language |
 | LEAP | `leap` | Learns task-specific principles from few shot examples |
 | ReRead | `re2` | Implements rereading to improve reasoning by processing queries twice |
+| CoT Decoding | N/A for proxy | Implements chain-of-thought decoding to elicit reasoning without explicit prompting |
 
 ## Available Parameters
 
@@ -211,6 +213,7 @@ Authorization: Bearer your_secret_api_key
 
 ## References
 
+- [Chain-of-Thought Reasoning Without Prompting](https://arxiv.org/abs/2402.10200)
 - [Re-Reading Improves Reasoning in Large Language Models](https://arxiv.org/abs/2309.06275)
 - [In-Context Principle Learning from Mistakes](https://arxiv.org/abs/2402.05403)
 - [Planning In Natural Language Improves LLM Search For Code Generation](https://arxiv.org/abs/2409.03733)
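
For context on the new table row: CoT Decoding is listed as "N/A for proxy" presumably because it needs the model's raw next-token probabilities, so the new `optillm/cot_decoding.py` module below loads a Hugging Face model locally rather than calling an OpenAI-compatible endpoint. It branches on the top-k candidates for the first generated token, completes each branch, and scores every completion with the answer-confidence Δ from the linked paper, i.e. the average gap between the top-1 and top-2 token probabilities over the answer tokens. A sketch of that score (notation mine, adapted from the paper):

```latex
% Answer-confidence for one decoding path with n answer tokens,
% where p_t^{(1)} and p_t^{(2)} are the largest and second-largest
% next-token probabilities at answer position t.
\Delta = \frac{1}{n} \sum_{t=1}^{n} \left( p_t^{(1)} - p_t^{(2)} \right)
```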

optillm/cot_decoding.py

Lines changed: 145 additions & 0 deletions
@@ -0,0 +1,145 @@
import torch
from transformers import PreTrainedModel, PreTrainedTokenizer
from typing import List, Tuple, Dict, Optional
import numpy as np

def get_device():
    if torch.cuda.is_available():
        return torch.device("cuda")
    else:
        return torch.device("cpu")

def calculate_confidence(logits: List[torch.Tensor], answer_ids: torch.Tensor) -> float:
    """
    Calculate the confidence score (Δ) as specified in the paper.

    Args:
        logits: List of logits for each decoding step
        answer_ids: Tensor of token ids for the answer

    Returns:
        Confidence score (Δ)
    """
    confidence_sum = 0.0
    valid_tokens = 0
    for t, token_id in enumerate(answer_ids):
        if t >= len(logits):
            break
        token_logits = logits[t]
        probs = torch.softmax(token_logits, dim=-1)
        if probs.size(0) > 1:
            top_2_probs, _ = torch.topk(probs, min(2, probs.size(0)))
            if top_2_probs.size(0) > 1:
                confidence_sum += (top_2_probs[0] - top_2_probs[1]).item()
            else:
                confidence_sum += 1.0  # Max confidence if there's only one token
        else:
            confidence_sum += 1.0  # Max confidence if there's only one token
        valid_tokens += 1

    return confidence_sum / valid_tokens if valid_tokens > 0 else 0.0

def aggregate_paths_based_on_scores(paths: List[Tuple[str, float]]) -> Tuple[str, float]:
    """Aggregate multiple paths based on their confidence scores."""
    answer_scores = {}
    for answer, delta in paths:
        answer_scores[answer] = answer_scores.get(answer, 0) + delta
    best_answer = max(answer_scores, key=answer_scores.get)
    return best_answer, answer_scores[best_answer]

def cot_decode(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    messages: List[Dict[str, str]],
    k: int = 10,
    num_beams: int = 1,
    max_new_tokens: int = 512,
    temperature: float = 1.0,
    top_p: float = 1.0,
    repetition_penalty: float = 1.0,
    length_penalty: float = 1.0,
    no_repeat_ngram_size: int = 0,
    early_stopping: bool = False,
    aggregate_paths: bool = False,
) -> Tuple[str, float]:
    """
    Implement CoT-decoding for a given chat input.

    Args:
        model: The Hugging Face transformer model.
        tokenizer: The associated tokenizer.
        messages: List of chat messages in the format [{"role": "user", "content": "..."}]
        k: The number of alternative tokens to consider at the first step.
        num_beams: Number of beams for beam search.
        max_new_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus sampling probability.
        repetition_penalty: Repetition penalty factor.
        length_penalty: Length penalty factor.
        no_repeat_ngram_size: Size of n-grams to avoid repeating.
        early_stopping: Whether to stop generation when all beams are finished.
        aggregate_paths: Whether to aggregate multiple paths.

    Returns:
        A tuple containing the best path (or aggregated result) and its confidence score.
    """
    device = get_device()
    model.to(device)

    # Use the chat template to format the input
    if tokenizer.chat_template:
        input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    else:
        # Fallback for tokenizers without chat templates
        input_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
        input_text += "\nassistant:"

    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
    attention_mask = torch.ones_like(input_ids).to(device)

    # Set pad_token_id if it's not set
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # Get the top-k tokens for the first decoding step
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        first_token_logits = outputs.logits[0, -1, :]
        top_k_logits, top_k_indices = torch.topk(first_token_logits, k)

    paths = []
    for idx in top_k_indices:
        # Generate sequence starting with the selected token
        start_ids = torch.cat([input_ids, idx.unsqueeze(0).unsqueeze(0)], dim=-1)
        start_mask = torch.cat([attention_mask, torch.ones((1, 1), dtype=torch.long, device=device)], dim=-1)

        output = model.generate(
            start_ids,
            attention_mask=start_mask,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            early_stopping=early_stopping,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            output_scores=True,
            return_dict_in_generate=True,
        )

        generated_sequence = output.sequences[0]
        answer_ids = generated_sequence[len(input_ids[0]):]
        answer_text = tokenizer.decode(answer_ids, skip_special_tokens=True)

        # Calculate confidence score (Δ)
        confidence = calculate_confidence(output.scores, answer_ids)
        paths.append((answer_text, confidence))

    if aggregate_paths:
        return aggregate_paths_based_on_scores(paths)
    else:
        return max(paths, key=lambda x: x[1])
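
For reference, here is a minimal way the new module might be exercised locally. Nothing below is part of the commit: the checkpoint name is only an example, and any causal LM should work (with or without a chat template, thanks to the fallback above). Setting `aggregate_paths=True` instead sums Δ over paths that decode to the same answer text, a self-consistency-style vote.

```python
# Hypothetical usage sketch for optillm/cot_decoding.py (not part of this commit).
from transformers import AutoModelForCausalLM, AutoTokenizer

from optillm.cot_decoding import cot_decode

model_name = "Qwen/Qwen2-0.5B-Instruct"  # illustrative checkpoint; substitute any HF causal LM
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

messages = [
    {"role": "user", "content": "I have 3 apples and my dad has 2 more apples than me. How many apples do we have together?"}
]

# Branch on the top-10 first tokens and keep the single highest-confidence path.
answer, confidence = cot_decode(model, tokenizer, messages, k=10, max_new_tokens=256)
print(f"Confidence (Δ): {confidence:.3f}")
print(answer)
```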

requirements.txt

Lines changed: 3 additions & 1 deletion
@@ -4,4 +4,6 @@ openai
 z3-solver
 aiohttp
 flask
-azure.identity
+torch
+transformers
+azure.identity

setup.py

Lines changed: 2 additions & 0 deletions
@@ -11,6 +11,8 @@
         "z3-solver",
         "aiohttp",
         "flask",
+        "torch",
+        "transformers",
         "azure-identity",
     ],
     author="codelion",
