Skip to content
Merged

CePO #152

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
ce62922
Update requirements to include Cerebras dependency
pawelf-cerebras Dec 18, 2024
d0d2570
Add support for CePO
pawelf-cerebras Dec 18, 2024
e83d89c
Initial version of CePO
pawelf-cerebras Dec 18, 2024
8f55a56
Add support for Cerebras client
pawelf-cerebras Dec 19, 2024
82bc4b8
Add pairwise rating and clean up
pawelf-cerebras Dec 19, 2024
0617e0c
Fix default rating type
pawelf-cerebras Dec 19, 2024
8b4617e
Add modification of CePO configs through yaml and cli arguments
erich-cerebras Dec 19, 2024
48a13eb
Fix default CePO config yaml
erich-cerebras Dec 19, 2024
3ef183b
Fix check if cepo_config_file provided
pawelf-cerebras Dec 26, 2024
83dfb30
Add cepo to readme
pawelf-cerebras Jan 2, 2025
581e26f
Minor cleaning
pawelf-cerebras Jan 2, 2025
20c98a0
Add VS Code to ignore list
pawelf-cerebras Jan 2, 2025
e168ea4
Minor readibility improvements
pawelf-cerebras Jan 3, 2025
1b3aa8a
Removed unnecessary comment
pawelf-cerebras Jan 6, 2025
05ff108
Add cepo results
pawelf-cerebras Jan 9, 2025
bef08cf
Make cepo_config.yaml define the default values instead of the datacl…
pawelf-cerebras Jan 9, 2025
cbebb2c
Updated documentation of CePO
pawelf-cerebras Jan 9, 2025
c39c1cf
Add description of CePO method
pawelf-cerebras Jan 9, 2025
d845eef
Update CePO section of README
pawelf-cerebras Jan 10, 2025
a1ae99e
Add results for LiveCodeBench and SimpleQA
pawelf-cerebras Jan 10, 2025
3807724
Correct type of the output of cepo
pawelf-cerebras Jan 13, 2025
2337abb
Minor fixes and add docstrings
pawelf-cerebras Jan 13, 2025
643ad9c
Updatee README.md to add the discord/research channel link
emmac-cerebras Jan 15, 2025
76b6181
Create NOTICE.md
emmac-cerebras Jan 15, 2025
d75239b
Update .gitignore
emmac-cerebras Jan 15, 2025
9ad5957
Update NOTICE.md
emmac-cerebras Jan 15, 2025
5f5553c
Update optillm.py
emmac-cerebras Jan 15, 2025
23035b5
Update cepo.py
emmac-cerebras Jan 15, 2025
550421f
Fix typo
pawelf-cerebras Jan 15, 2025
e1bc7da
Make cepo_config required parameter
pawelf-cerebras Jan 17, 2025
42e5f6a
Remove unused imports
pawelf-cerebras Jan 20, 2025
474e60b
Add debug logging to CePO
pawelf-cerebras Jan 20, 2025
e7dbd1a
Remove unneeded license notes
pawelf-cerebras Jan 21, 2025
f72224e
Add a flag to print intermediate outputs in CePO
pawelf-cerebras Jan 21, 2025
b3beadf
Fix formatting
pawelf-cerebras Jan 21, 2025
b54bb0d
Update README.md
emmac-cerebras Jan 22, 2025
81847da
Update README.md
emmac-cerebras Jan 22, 2025
92877dc
Update README.md
emmac-cerebras Jan 22, 2025
7d9c9ee
Move CePO into its own directory
pawelf-cerebras Jan 23, 2025
888bfd1
Move CePO documentation to its own README
pawelf-cerebras Jan 23, 2025
a2c301b
Revert deletion of a comment
pawelf-cerebras Jan 23, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,6 @@ cython_debug/
# Ignore Mac DS_Store files
.DS_Store
**/.DS_Store

# VS Code
.vscode/
90 changes: 58 additions & 32 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -212,22 +212,23 @@ response = client.chat.completions.create(

## Implemented techniques

| Approach | Slug | Description |
| ----------------------- | ------------------ | ---------------------------------------------------------------------------------------------- |
| CoT with Reflection | `cot_reflection` | Implements chain-of-thought reasoning with \<thinking\>, \<reflection> and \<output\> sections |
| PlanSearch | `plansearch` | Implements a search algorithm over candidate plans for solving a problem in natural language |
| ReRead | `re2` | Implements rereading to improve reasoning by processing queries twice |
| Self-Consistency | `self_consistency` | Implements an advanced self-consistency method |
| Z3 Solver | `z3` | Utilizes the Z3 theorem prover for logical reasoning |
| R* Algorithm | `rstar` | Implements the R* algorithm for problem-solving |
| LEAP | `leap` | Learns task-specific principles from few shot examples |
| Round Trip Optimization | `rto` | Optimizes responses through a round-trip process |
| Best of N Sampling | `bon` | Generates multiple responses and selects the best one |
| Mixture of Agents | `moa` | Combines responses from multiple critiques |
| Monte Carlo Tree Search | `mcts` | Uses MCTS for decision-making in chat responses |
| PV Game | `pvg` | Applies a prover-verifier game approach at inference time |
| CoT Decoding | N/A for proxy | Implements chain-of-thought decoding to elicit reasoning without explicit prompting |
| Entropy Decoding | N/A for proxy | Implements adaptive sampling based on the uncertainty of tokens during generation |
| Approach | Slug | Description |
| ------------------------------------ | ------------------ | ---------------------------------------------------------------------------------------------- |
| Cerebras Planning and Optimimization | `cepo` | Combines Best of N, Chain-of-Thought, Self-Reflection, Self-Improvement, and various prompting techniques |
| CoT with Reflection | `cot_reflection` | Implements chain-of-thought reasoning with \<thinking\>, \<reflection> and \<output\> sections |
| PlanSearch | `plansearch` | Implements a search algorithm over candidate plans for solving a problem in natural language |
| ReRead | `re2` | Implements rereading to improve reasoning by processing queries twice |
| Self-Consistency | `self_consistency` | Implements an advanced self-consistency method |
| Z3 Solver | `z3` | Utilizes the Z3 theorem prover for logical reasoning |
| R* Algorithm | `rstar` | Implements the R* algorithm for problem-solving |
| LEAP | `leap` | Learns task-specific principles from few shot examples |
| Round Trip Optimization | `rto` | Optimizes responses through a round-trip process |
| Best of N Sampling | `bon` | Generates multiple responses and selects the best one |
| Mixture of Agents | `moa` | Combines responses from multiple critiques |
| Monte Carlo Tree Search | `mcts` | Uses MCTS for decision-making in chat responses |
| PV Game | `pvg` | Applies a prover-verifier game approach at inference time |
| CoT Decoding | N/A for proxy | Implements chain-of-thought decoding to elicit reasoning without explicit prompting |
| Entropy Decoding | N/A for proxy | Implements adaptive sampling based on the uncertainty of tokens during generation |

## Implemented plugins

Expand All @@ -244,22 +245,38 @@ response = client.chat.completions.create(

optillm supports various command-line arguments and environment variables for configuration.

| Parameter | Description | Default Value |
|--------------------------|-----------------------------------------------------------------|-----------------|
| `--approach` | Inference approach to use | `"auto"` |
| `--simulations` | Number of MCTS simulations | 2 |
| `--exploration` | Exploration weight for MCTS | 0.2 |
| `--depth` | Simulation depth for MCTS | 1 |
| `--best-of-n` | Number of samples for best_of_n approach | 3 |
| `--model` | OpenAI model to use | `"gpt-4o-mini"` |
| `--base-url` | Base URL for OpenAI compatible endpoint | `""` |
| `--rstar-max-depth` | Maximum depth for rStar algorithm | 3 |
| `--rstar-num-rollouts` | Number of rollouts for rStar algorithm | 5 |
| `--rstar-c` | Exploration constant for rStar algorithm | 1.4 |
| `--n` | Number of final responses to be returned | 1 |
| `--return-full-response` | Return the full response including the CoT with <thinking> tags | `False` |
| `--port` | Specify the port to run the proxy | 8000 |
| `--optillm-api-key` | Optional API key for client authentication to optillm | `""` |
| Parameter | Description | Default Value |
|-------------------------------------|-----------------------------------------------------------------|-----------------|
| `--approach` | Inference approach to use | `"auto"` |
| `--simulations` | Number of MCTS simulations | 2 |
| `--exploration` | Exploration weight for MCTS | 0.2 |
| `--depth` | Simulation depth for MCTS | 1 |
| `--best-of-n` | Number of samples for best_of_n approach | 3 |
| `--model` | OpenAI model to use | `"gpt-4o-mini"` |
| `--base-url` | Base URL for OpenAI compatible endpoint | `""` |
| `--rstar-max-depth` | Maximum depth for rStar algorithm | 3 |
| `--rstar-num-rollouts` | Number of rollouts for rStar algorithm | 5 |
| `--rstar-c` | Exploration constant for rStar algorithm | 1.4 |
| `--n` | Number of final responses to be returned | 1 |
| `--return-full-response` | Return the full response including the CoT with <thinking> tags | `False` |
| `--port` | Specify the port to run the proxy | 8000 |
| `--optillm-api-key` | Optional API key for client authentication to optillm | `""` |
| `--cepo_bestofn_n` | Number of responses to be generated in best of n stage | 3 |
| `--cepo_bestofn_temperature` | Temperature for verifier in best of n stage | 0.1 |
| `--cepo_bestofn_max_tokens` | Maximum number of tokens for verifier in best of n stage | 4096 |
| `--cepo_bestofn_rating_type` | Type of rating in best of n stage ("absolute" or "pairwise") | `"absolute"` |
| `--cepo_planning_n` | Number of plans generated in planning stage | 3 |
| `--cepo_planning_m` | Number of attempts to generate n plans in planning stage | 6 |
| `--cepo_planning_temperature_step1` | Temperature for generator in step 1 of planning stage | 0.55 |
| `--cepo_planning_temperature_step2` | Temperature for generator in step 2 of planning stage | 0.25 |
| `--cepo_planning_temperature_step3` | Temperature for generator in step 3 of planning stage | 0.1 |
| `--cepo_planning_temperature_step4` | Temperature for generator in step 4 of planning stage | 0 |
| `--cepo_planning_max_tokens_step1` | Maximum number of tokens in step 1 of planning stage | 4096 |
| `--cepo_planning_max_tokens_step2` | Maximum number of tokens in step 2 of planning stage | 4096 |
| `--cepo_planning_max_tokens_step3` | Maximum number of tokens in step 3 of planning stage | 4096 |
| `--cepo_planning_max_tokens_step4` | Maximum number of tokens in step 4 of planning stage | 4096 |
| `--cepo_print_output` | Whether to print the output of each stage | `False` |
| `--cepo_config_file` | Path to CePO configuration file | None |

When using Docker, these can be set as environment variables prefixed with `OPTILLM_`.

Expand Down Expand Up @@ -308,6 +325,15 @@ Authorization: Bearer your_secret_api_key

## SOTA results on benchmarks with optillm

### CePO on math and code benchmarks

| Method | Math-L5 | MMLU-Pro (Math) | GPQA | CRUX | LiveCodeBench (pass@1) | Simple QA |
| -------------------------: | :-----: | :-------------: | :--: | :--: | :--------------------: | :-------: |
| Llama 3.1 70B | 41.6 | 72.9 | 41.7 | 64.2 | 24.5 | 14.7 |
| Llama 3.3 70B | 51.0 | 78.6 | 49.1 | 72.6 | 27.1 | 20.9 |
| Llama 3.1 405B | 49.8 | 79.2 | 50.7 | 73.0 | 31.8 | 13.5 |
| CePO (using Llama 3.3 70B) | 69.6 | 84.8 | 55.5 | 80.1 | 31.9 | 22.6 |

### coc-claude-3-5-sonnet-20241022 on AIME 2024 pass@1 (Nov 2024)

| Model | Score |
Expand Down
30 changes: 27 additions & 3 deletions optillm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import secrets
from flask import Flask, request, jsonify
from cerebras.cloud.sdk import Cerebras
from openai import AzureOpenAI, OpenAI
from flask import Response
import json
Expand All @@ -13,6 +14,7 @@
from concurrent.futures import ThreadPoolExecutor
from typing import Tuple, Optional, Union, Dict, Any, List
from importlib.metadata import version
from dataclasses import fields

# Import approach modules
from optillm.mcts import chat_with_mcts
Expand All @@ -27,6 +29,7 @@
from optillm.plansearch import plansearch
from optillm.leap import leap
from optillm.reread import re2_approach
from optillm.cepo.cepo import cepo, CepoConfig, init_cepo_config

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
Expand All @@ -50,7 +53,14 @@ def get_config():
from optillm.inference import create_inference_client
API_KEY = os.environ.get("OPTILLM_API_KEY")
default_client = create_inference_client()
# OpenAI, Azure, or LiteLLM API configuration
# Cerebras, OpenAI, Azure, or LiteLLM API configuration
elif os.environ.get("CEREBRAS_API_KEY"):
API_KEY = os.environ.get("CEREBRAS_API_KEY")
base_url = server_config['base_url']
if base_url != "":
default_client = Cerebras(api_key=API_KEY, base_url=base_url)
else:
default_client = Cerebras(api_key=API_KEY)
elif os.environ.get("OPENAI_API_KEY"):
API_KEY = os.environ.get("OPENAI_API_KEY")
base_url = server_config['base_url']
Expand Down Expand Up @@ -104,7 +114,7 @@ def get_config():

# List of known approaches
known_approaches = ["none", "mcts", "bon", "moa", "rto", "z3", "self_consistency",
"pvg", "rstar", "cot_reflection", "plansearch", "leap", "re2"]
"pvg", "rstar", "cot_reflection", "plansearch", "leap", "re2", "cepo"]

plugin_approaches = {}

Expand All @@ -124,7 +134,7 @@ def none_approach(
model: Model identifier
original_messages: Original messages from the request
**kwargs: Additional parameters to pass through

Returns:
Dict[str, Any]: Full OpenAI API response
"""
Expand Down Expand Up @@ -282,6 +292,8 @@ def execute_single_approach(approach, system_prompt, initial_query, client, mode
return leap(system_prompt, initial_query, client, model)
elif approach == 're2':
return re2_approach(system_prompt, initial_query, client, model, n=server_config['n'])
elif approach == 'cepo':
return cepo(system_prompt, initial_query, client, model, cepo_config)
elif approach in plugin_approaches:
return plugin_approaches[approach](system_prompt, initial_query, client, model)
else:
Expand Down Expand Up @@ -690,6 +702,12 @@ def parse_args():
parser.add_argument("--base-url", "--base_url", dest="base_url", type=str, default=base_url_default,
help="Base url for OpenAI compatible endpoint")

# Special handling of all the CePO Configurations
for field in fields(CepoConfig):
parser.add_argument(f"--cepo_{field.name}", dest=f"cepo_{field.name}", type=field.type, default=None, help=f"CePO configuration for {field.name}")

parser.add_argument(f"--cepo_config_file", dest=f"cepo_config_file", type=str, default="./optillm/cepo/configs/cepo_config.yaml", help="Path to CePO configuration file")

args = parser.parse_args()

# Convert argument names to match server_config keys
Expand All @@ -703,6 +721,7 @@ def parse_args():

def main():
global server_config
global cepo_config
# Call this function at the start of main()
args = parse_args()
# Update server_config with all argument values
Expand All @@ -717,6 +736,11 @@ def main():
if logging_level in logging_levels.keys():
logger.setLevel(logging_levels[logging_level])

# set and log the cepo configs
cepo_config = init_cepo_config(server_config)
if args.approach == 'cepo':
logger.info(f"CePO Config: {cepo_config}")

logger.info(f"Starting server with approach: {server_config['approach']}")
server_config_clean = server_config.copy()
if server_config_clean['optillm_api_key']:
Expand Down
44 changes: 44 additions & 0 deletions optillm/cepo/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# The Cerebras Planning and Optimization (CePO) Method

CePO is an inference-time computation method designed to enhance the accuracy of large language models (LLMs) on tasks requiring reasoning and planning, such as solving math or coding problems. It integrates several advanced techniques, including Best of N, Chain of Thought (CoT), Self-Reflection, Self-Improvement, and Prompt Engineering.

If you have any questions or want to contribute, please reach out to us on [cerebras.ai/discord](cerebras.ai/discord)

## CePO Methodology

In CePO, the Best of N technique is applied to `bestofn_n` solution candidates. Each solution is generated through the following four steps:

**Step 1**: Plan Generation
The model generates a detailed, step-by-step plan to solve the problem, along with its confidence level for each step.

**Step 2**: Initial Solution
Using the plan from Step 1, the model produces an initial solution.

Steps 1 and 2 are repeated `planning_n` times to generate multiple solution proposals.
If the model exceeds the token budget during Step 1 or 2, the plan/solution is marked as incomplete, rejected, and regenerated. A maximum of `planning_m` attempts is made to generate `planning_n` valid proposals.

**Step 3**: Plan Refinement
The model reviews all generated solution proposals and their associated plans, identifying inconsistencies. Based on this analysis, a refined, final step-by-step plan is constructed.

**Step 4**: Final Solution
The model uses the refined plan from Step 3 to produce the final answer.

## CePO Current Status

This project is a work in progress, and the provided code is in an early experimental stage. While the proposed approach works well across the benchmarks we tested, further improvements can be achieved by task-specific customizations to prompts.

## CePO Ablation studies

We conducted ablation studies to evaluate the impact of various hyperparameters in the CePO framework. Our results indicate that the chosen hyperparameter settings strike a good balance between computational cost and accuracy.

Interestingly, the self-critique and quality improvement capabilities of existing off-the-shelf models do not always scale proportionally with increased inference compute. Addressing this limitation remains a key focus, and we plan to explore custom model fine-tuning as a potential solution in the future.

| bestofn_n | planning_n | planning_m | bestofn_rating_type | Math-L5 | MMLU-Pro (Math) | GPQA | CRUX | Comments |
| :-------: | :--------: | :--------: | :-----------------: | :-----: | :-------------: | :---: | :---: | :------------- |
| 3 | 3 | 6 | absolute | 69.6 | 84.8 | 55.5 | 80.1 | Default config |
| 3 | 3 | 6 | pairwise | 67.7 | 83.5 | 55.6 | 79.8 | |
| 3 | 2 | 5 | absolute | 67.1 | 85.1 | 55.1 | 79.0 | |
| 3 | 5 | 8 | absolute | 69.4 | 84.3 | 55.6 | 81.1 | |
| 5 | 3 | 6 | absolute | 68.7 | 85.4 | 54.8 | 79.9 | |
| 7 | 3 | 6 | absolute | 69.6 | 82.8 | 54.7 | 78.4 | |
| 9 | 3 | 6 | absolute | 68.9 | 83.4 | 55.7 | 80.6 | |
Loading