From 50f4c8ef215de4a2450b2bde99ce83cefc70dd82 Mon Sep 17 00:00:00 2001
From: Asankhaya Sharma
Date: Thu, 4 Sep 2025 17:03:28 +0800
Subject: [PATCH 1/3] fixes

---
 .../workflows/publish-docker-proxy-amd64.yml |    7 +-
 .../workflows/publish-docker-proxy-arm64.yml |    7 +-
 MANIFEST.in                                  |    1 -
 optillm.py                                   | 1234 +----------------
 optillm/__init__.py                          |   53 +-
 optillm/server.py                            | 1232 ++++++++++++++++
 pyproject.toml                               |    5 +-
 7 files changed, 1270 insertions(+), 1269 deletions(-)
 create mode 100644 optillm/server.py

diff --git a/.github/workflows/publish-docker-proxy-amd64.yml b/.github/workflows/publish-docker-proxy-amd64.yml
index 26efe60b..1a2ffeea 100644
--- a/.github/workflows/publish-docker-proxy-amd64.yml
+++ b/.github/workflows/publish-docker-proxy-amd64.yml
@@ -25,7 +25,12 @@ jobs:
 
       - name: Extract version from tag
         id: version
-        run: echo "VERSION=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT
+        run: |
+          VERSION=${GITHUB_REF#refs/tags/}
+          if [ -z "$VERSION" ] || [ "$VERSION" = "$GITHUB_REF" ]; then
+            VERSION="latest"
+          fi
+          echo "VERSION=$VERSION" >> $GITHUB_OUTPUT
 
       - name: Build and push proxy AMD64 image
         uses: docker/build-push-action@v5
diff --git a/.github/workflows/publish-docker-proxy-arm64.yml b/.github/workflows/publish-docker-proxy-arm64.yml
index c75dc886..ab1e8077 100644
--- a/.github/workflows/publish-docker-proxy-arm64.yml
+++ b/.github/workflows/publish-docker-proxy-arm64.yml
@@ -28,7 +28,12 @@ jobs:
 
       - name: Extract version from tag
         id: version
-        run: echo "VERSION=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT
+        run: |
+          VERSION=${GITHUB_REF#refs/tags/}
+          if [ -z "$VERSION" ] || [ "$VERSION" = "$GITHUB_REF" ]; then
+            VERSION="latest"
+          fi
+          echo "VERSION=$VERSION" >> $GITHUB_OUTPUT
 
       - name: Build and push proxy ARM64 image
         uses: docker/build-push-action@v5
diff --git a/MANIFEST.in b/MANIFEST.in
index e3a43fcb..e50b514b 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,4 +1,3 @@
-include optillm.py
 include optillm/plugins/*.py
 include optillm/cepo/*.py
 include optillm/cepo/configs/*.yaml
diff --git a/optillm.py b/optillm.py
index 5df85030..63410b23 100644
--- a/optillm.py
+++ b/optillm.py
@@ -1,1232 +1,12 @@
-import argparse
-import logging
-import os
-import secrets
-import time
-from pathlib import Path
-from flask import Flask, request, jsonify
-from cerebras.cloud.sdk import Cerebras
-from openai import AzureOpenAI, OpenAI
-from flask import Response
-import json
-import importlib
-import glob
-import asyncio
-import re
-from concurrent.futures import ThreadPoolExecutor
-from typing import Tuple, Optional, Union, Dict, Any, List
-from importlib.metadata import version
-from dataclasses import fields
-
-# Import approach modules
-from optillm.mcts import chat_with_mcts
-from optillm.bon import best_of_n_sampling
-from optillm.moa import mixture_of_agents
-from optillm.rto import round_trip_optimization
-from optillm.self_consistency import advanced_self_consistency_approach
-from optillm.pvg import inference_time_pv_game
-from optillm.z3_solver import Z3SymPySolverSystem
-from optillm.rstar import RStar
-from optillm.cot_reflection import cot_reflection
-from optillm.plansearch import plansearch
-from optillm.leap import leap
-from optillm.reread import re2_approach
-from optillm.cepo.cepo import cepo, CepoConfig, init_cepo_config
-from optillm.batching import RequestBatcher, BatchingError
-from optillm.conversation_logger import ConversationLogger
-import optillm.conversation_logger
+#!/usr/bin/env python3
+"""
+OptILLM - OpenAI API compatible optimizing inference proxy
+
+This is a thin
wrapper that imports and runs the main server from the optillm package. +For backwards compatibility with direct execution of optillm.py. +""" -# Setup logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) -logging_levels = { - "notset": logging.NOTSET, - "debug": logging.DEBUG, - "info": logging.INFO, - "warning": logging.WARNING, - "error": logging.ERROR, - "critical": logging.CRITICAL, -} - -# Initialize Flask app -app = Flask(__name__) - -# Global request batcher (initialized in main() if batch mode enabled) -request_batcher = None - -# Global conversation logger (initialized in main() if logging enabled) -conversation_logger = None - -def get_config(): - API_KEY = None - if os.environ.get("OPTILLM_API_KEY"): - # Use local inference engine - from optillm.inference import create_inference_client - API_KEY = os.environ.get("OPTILLM_API_KEY") - default_client = create_inference_client() - # Cerebras, OpenAI, Azure, or LiteLLM API configuration - elif os.environ.get("CEREBRAS_API_KEY"): - API_KEY = os.environ.get("CEREBRAS_API_KEY") - base_url = server_config['base_url'] - if base_url != "": - default_client = Cerebras(api_key=API_KEY, base_url=base_url) - else: - default_client = Cerebras(api_key=API_KEY) - elif os.environ.get("OPENAI_API_KEY"): - API_KEY = os.environ.get("OPENAI_API_KEY") - base_url = server_config['base_url'] - if base_url != "": - default_client = OpenAI(api_key=API_KEY, base_url=base_url) - else: - default_client = OpenAI(api_key=API_KEY) - elif os.environ.get("AZURE_OPENAI_API_KEY"): - API_KEY = os.environ.get("AZURE_OPENAI_API_KEY") - API_VERSION = os.environ.get("AZURE_API_VERSION") - AZURE_ENDPOINT = os.environ.get("AZURE_API_BASE") - if API_KEY is not None: - default_client = AzureOpenAI( - api_key=API_KEY, - api_version=API_VERSION, - azure_endpoint=AZURE_ENDPOINT, - ) - else: - from azure.identity import DefaultAzureCredential, get_bearer_token_provider - azure_credential = DefaultAzureCredential() - token_provider = get_bearer_token_provider(azure_credential, "https://cognitiveservices.azure.com/.default") - default_client = AzureOpenAI( - api_version=API_VERSION, - azure_endpoint=AZURE_ENDPOINT, - azure_ad_token_provider=token_provider - ) - else: - # Import the LiteLLM wrapper - from optillm.litellm_wrapper import LiteLLMWrapper - default_client = LiteLLMWrapper() - return default_client, API_KEY - -def count_reasoning_tokens(text: str, tokenizer=None) -> int: - """ - Count tokens within ... tags in the given text. - - Args: - text: The text to analyze - tokenizer: Optional tokenizer instance for precise counting - - Returns: - Number of reasoning tokens (0 if no think tags found) - """ - if not text or not isinstance(text, str): - return 0 - - # Extract all content within ... tags - # Handle both complete and truncated think blocks - - # First, find all complete ... 
blocks - complete_pattern = r'(.*?)' - complete_matches = re.findall(complete_pattern, text, re.DOTALL) - - # Then check for unclosed tag (truncated response) - # This finds that doesn't have a matching after it - truncated_pattern = r'(?!.*)(.*)$' - truncated_match = re.search(truncated_pattern, text, re.DOTALL) - - # Combine all thinking content - thinking_content = ''.join(complete_matches) - if truncated_match: - thinking_content += truncated_match.group(1) - - if not thinking_content: - return 0 - - if tokenizer and hasattr(tokenizer, 'encode'): - # Use tokenizer for precise counting - try: - tokens = tokenizer.encode(thinking_content) - return len(tokens) - except Exception as e: - logger.warning(f"Failed to count tokens with tokenizer: {e}") - - # Fallback: rough estimation (4 chars per token on average, minimum 1 token for non-empty content) - content_length = len(thinking_content.strip()) - return max(1, content_length // 4) if content_length > 0 else 0 - -# Server configuration -server_config = { - 'approach': 'none', - 'mcts_simulations': 2, - 'mcts_exploration': 0.2, - 'mcts_depth': 1, - 'best_of_n': 3, - 'model': 'gpt-4o-mini', - 'rstar_max_depth': 3, - 'rstar_num_rollouts': 5, - 'rstar_c': 1.4, - 'n': 1, - 'base_url': '', - 'optillm_api_key': '', - 'return_full_response': False, - 'port': 8000, - 'log': 'info', -} - -# List of known approaches -known_approaches = ["none", "mcts", "bon", "moa", "rto", "z3", "self_consistency", - "pvg", "rstar", "cot_reflection", "plansearch", "leap", "re2", "cepo"] - -plugin_approaches = {} - -def normalize_message_content(messages): - """ - Ensure all message content fields are strings, not lists. - Some models don't handle list-format content correctly. - """ - normalized_messages = [] - for message in messages: - normalized_message = message.copy() - content = message.get('content', '') - - # Convert list content to string if needed - if isinstance(content, list): - # Extract text content from the list - text_content = ' '.join( - item.get('text', '') for item in content - if isinstance(item, dict) and item.get('type') == 'text' - ) - normalized_message['content'] = text_content - - normalized_messages.append(normalized_message) - - return normalized_messages - -def none_approach( - client: Any, - model: str, - original_messages: List[Dict[str, str]], - request_id: str = None, - **kwargs -) -> Dict[str, Any]: - """ - Direct proxy approach that passes through all parameters to the underlying endpoint. 
- - Args: - client: OpenAI client instance - model: Model identifier - original_messages: Original messages from the request - request_id: Optional request ID for conversation logging - **kwargs: Additional parameters to pass through - - Returns: - Dict[str, Any]: Full OpenAI API response - """ - # Strip 'none-' prefix from model if present - if model.startswith('none-'): - model = model[5:] - - try: - # Normalize message content to ensure it's always string - normalized_messages = normalize_message_content(original_messages) - - # Prepare request data for logging - provider_request = { - "model": model, - "messages": normalized_messages, - **kwargs - } - - # Make the direct completion call with normalized messages and parameters - response = client.chat.completions.create( - model=model, - messages=normalized_messages, - **kwargs - ) - - # Convert to dict if it's not already - response_dict = response.model_dump() if hasattr(response, 'model_dump') else response - - # Log the provider call if conversation logging is enabled - if conversation_logger and request_id: - conversation_logger.log_provider_call(request_id, provider_request, response_dict) - - return response_dict - - except Exception as e: - # Log error if conversation logging is enabled - if conversation_logger and request_id: - conversation_logger.log_error(request_id, f"Error in none approach: {str(e)}") - logger.error(f"Error in none approach: {str(e)}") - raise - -def load_plugins(): - # Clear existing plugins first but modify the global dict in place - plugin_approaches.clear() - - # Get installed package plugins directory - import optillm - package_plugin_dir = os.path.join(os.path.dirname(optillm.__file__), 'plugins') - - # Get local project plugins directory - current_dir = os.getcwd() if server_config.get("plugins_dir", "") == "" else server_config["plugins_dir"] - local_plugin_dir = os.path.join(current_dir, 'optillm', 'plugins') - - plugin_dirs = [] - - # Add package plugin dir - plugin_dirs.append((package_plugin_dir, "package")) - - # Add local plugin dir only if it's different from package dir - if local_plugin_dir != package_plugin_dir: - plugin_dirs.append((local_plugin_dir, "local")) - - for plugin_dir, source in plugin_dirs: - logger.info(f"Looking for {source} plugins in: {plugin_dir}") - - if not os.path.exists(plugin_dir): - logger.debug(f"{source.capitalize()} plugin directory not found: {plugin_dir}") - continue - - plugin_files = glob.glob(os.path.join(plugin_dir, '*.py')) - if not plugin_files: - logger.debug(f"No plugin files found in {source} directory: {plugin_dir}") - continue - - logger.info(f"Found {source} plugin files: {plugin_files}") - - for plugin_file in plugin_files: - try: - module_name = os.path.basename(plugin_file)[:-3] # Remove .py extension - spec = importlib.util.spec_from_file_location(module_name, plugin_file) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - - if hasattr(module, 'SLUG') and hasattr(module, 'run'): - if module.SLUG in plugin_approaches: - logger.info(f"Overriding {source} plugin: {module.SLUG}") - plugin_approaches[module.SLUG] = module.run - logger.info(f"Loaded {source} plugin: {module.SLUG}") - else: - logger.warning(f"Plugin {module_name} from {source} missing required attributes (SLUG and run)") - except Exception as e: - logger.error(f"Error loading {source} plugin {plugin_file}: {str(e)}") - - if not plugin_approaches: - logger.warning("No plugins loaded from any location") - -def get_config_path(): - # Get installed package 
config directory - import optillm - package_config_dir = os.path.join(os.path.dirname(optillm.__file__), 'cepo', 'configs') - package_config_path = os.path.join(package_config_dir, 'cepo_config.yaml') - - # Get local project config directory - current_dir = os.getcwd() if server_config.get("config_dir", "") == "" else server_config["config_dir"] - local_config_dir = os.path.join(current_dir, 'optillm', 'cepo', 'configs') - local_config_path = os.path.join(local_config_dir, 'cepo_config.yaml') - - # If local config exists and is different from package config, use local - if os.path.exists(local_config_path) and local_config_path != package_config_path: - logger.debug(f"Using local config from: {local_config_path}") - return local_config_path - - # Otherwise use package config - logger.debug(f"Using package config from: {package_config_path}") - return package_config_path - -def parse_combined_approach(model: str, known_approaches: list, plugin_approaches: dict): - if model == 'auto': - return 'SINGLE', ['none'], model - - parts = model.split('-') - approaches = [] - operation = 'SINGLE' - model_parts = [] - parsing_approaches = True - - for part in parts: - if parsing_approaches: - if part in known_approaches or part in plugin_approaches: - approaches.append(part) - elif '&' in part: - operation = 'AND' - approaches.extend(part.split('&')) - elif '|' in part: - operation = 'OR' - approaches.extend(part.split('|')) - else: - parsing_approaches = False - model_parts.append(part) - else: - model_parts.append(part) - - if not approaches: - approaches = ['none'] - operation = 'SINGLE' - - actual_model = '-'.join(model_parts) - - return operation, approaches, actual_model - -def execute_single_approach(approach, system_prompt, initial_query, client, model, request_config: dict = None, request_id: str = None): - if approach in known_approaches: - if approach == 'none': - # Use the request_config that was already prepared and passed to this function - kwargs = request_config.copy() if request_config else {} - - # Remove items that are handled separately by the framework - kwargs.pop('n', None) # n is handled by execute_n_times - kwargs.pop('stream', None) # stream is handled by proxy() - - # Reconstruct original messages from system_prompt and initial_query - messages = [] - if system_prompt: - messages.append({"role": "system", "content": system_prompt}) - if initial_query: - messages.append({"role": "user", "content": initial_query}) - - response = none_approach(original_messages=messages, client=client, model=model, request_id=request_id, **kwargs) - # For none approach, we return the response and a token count of 0 - # since the full token count is already in the response - return response, 0 - elif approach == 'mcts': - return chat_with_mcts(system_prompt, initial_query, client, model, server_config['mcts_simulations'], - server_config['mcts_exploration'], server_config['mcts_depth'], request_id) - elif approach == 'bon': - return best_of_n_sampling(system_prompt, initial_query, client, model, server_config['best_of_n'], request_id) - elif approach == 'moa': - return mixture_of_agents(system_prompt, initial_query, client, model, request_id) - elif approach == 'rto': - return round_trip_optimization(system_prompt, initial_query, client, model, request_id) - elif approach == 'z3': - z3_solver = Z3SymPySolverSystem(system_prompt, client, model, request_id=request_id) - return z3_solver.process_query(initial_query) - elif approach == "self_consistency": - return 
advanced_self_consistency_approach(system_prompt, initial_query, client, model, request_id) - elif approach == "pvg": - return inference_time_pv_game(system_prompt, initial_query, client, model, request_id) - elif approach == "rstar": - rstar = RStar(system_prompt, client, model, - max_depth=server_config['rstar_max_depth'], num_rollouts=server_config['rstar_num_rollouts'], - c=server_config['rstar_c'], request_id=request_id) - return rstar.solve(initial_query) - elif approach == "cot_reflection": - return cot_reflection(system_prompt, initial_query, client, model, return_full_response=server_config['return_full_response'], request_config=request_config, request_id=request_id) - elif approach == 'plansearch': - return plansearch(system_prompt, initial_query, client, model, n=server_config['n'], request_id=request_id) - elif approach == 'leap': - return leap(system_prompt, initial_query, client, model, request_id) - elif approach == 're2': - return re2_approach(system_prompt, initial_query, client, model, n=server_config['n'], request_id=request_id) - elif approach == 'cepo': - return cepo(system_prompt, initial_query, client, model, cepo_config, request_id) - elif approach in plugin_approaches: - # Check if the plugin accepts request_config - plugin_func = plugin_approaches[approach] - import inspect - sig = inspect.signature(plugin_func) - - # Check if the plugin function is async - is_async = inspect.iscoroutinefunction(plugin_func) - - if is_async: - # For async functions, we need to run them in an event loop - import asyncio - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - if 'request_config' in sig.parameters: - # Plugin supports request_config - result = loop.run_until_complete(plugin_func(system_prompt, initial_query, client, model, request_config=request_config)) - else: - # Legacy plugin without request_config support - result = loop.run_until_complete(plugin_func(system_prompt, initial_query, client, model)) - return result - finally: - loop.close() - else: - # For synchronous functions, call directly - if 'request_config' in sig.parameters: - # Plugin supports request_config - return plugin_func(system_prompt, initial_query, client, model, request_config=request_config) - else: - # Legacy plugin without request_config support - return plugin_func(system_prompt, initial_query, client, model) - else: - raise ValueError(f"Unknown approach: {approach}") - -def execute_combined_approaches(approaches, system_prompt, initial_query, client, model, request_config: dict = None): - final_response = initial_query - total_tokens = 0 - for approach in approaches: - response, tokens = execute_single_approach(approach, system_prompt, final_response, client, model, request_config) - final_response = response - total_tokens += tokens - return final_response, total_tokens - -async def execute_parallel_approaches(approaches, system_prompt, initial_query, client, model, request_config: dict = None): - async def run_approach(approach): - return await asyncio.to_thread(execute_single_approach, approach, system_prompt, initial_query, client, model, request_config) - - tasks = [run_approach(approach) for approach in approaches] - results = await asyncio.gather(*tasks) - responses, tokens = zip(*results) - return list(responses), sum(tokens) - -def execute_n_times(n: int, approaches, operation: str, system_prompt: str, initial_query: str, client: Any, model: str, - request_config: dict = None, request_id: str = None) -> Tuple[Union[str, List[str]], int]: - """ - Execute the 
pipeline n times and return n responses. - - Args: - n (int): Number of times to run the pipeline - approaches (list): List of approaches to execute - operation (str): Operation type ('SINGLE', 'AND', or 'OR') - system_prompt (str): System prompt - initial_query (str): Initial query - client: OpenAI client instance - model (str): Model identifier - - Returns: - Tuple[Union[str, List[str]], int]: List of responses and total token count - """ - responses = [] - total_tokens = 0 - - for _ in range(n): - if operation == 'SINGLE': - response, tokens = execute_single_approach(approaches[0], system_prompt, initial_query, client, model, request_config, request_id) - elif operation == 'AND': - response, tokens = execute_combined_approaches(approaches, system_prompt, initial_query, client, model, request_config) - elif operation == 'OR': - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - response, tokens = loop.run_until_complete(execute_parallel_approaches(approaches, system_prompt, initial_query, client, model, request_config)) - loop.close() - else: - raise ValueError(f"Unknown operation: {operation}") - - # If response is already a list (from OR operation), extend responses - # Otherwise append the single response - if isinstance(response, list): - responses.extend(response) - else: - responses.append(response) - total_tokens += tokens - - # If n=1 and we got a single response, return it as is - # Otherwise return the list of responses - if n == 1 and len(responses) == 1: - return responses[0], total_tokens - return responses, total_tokens - -def generate_streaming_response(final_response, model): - # Yield the final response - if isinstance(final_response, list): - for index, response in enumerate(final_response): - yield "data: " + json.dumps({ - "choices": [{"delta": {"content": response}, "index": index, "finish_reason": "stop"}], - "model": model, - }) + "\n\n" - else: - yield "data: " + json.dumps({ - "choices": [{"delta": {"content": final_response}, "index": 0, "finish_reason": "stop"}], - "model": model, - }) + "\n\n" - - # Yield the final message to indicate the stream has ended - yield "data: [DONE]\n\n" - -def extract_contents(response_obj): - contents = [] - # Handle both single response and list of responses - responses = response_obj if isinstance(response_obj, list) else [response_obj] - - for response in responses: - # Extract content from first choice if it exists - if (response.get('choices') and - len(response['choices']) > 0 and - response['choices'][0].get('message') and - response['choices'][0]['message'].get('content')): - contents.append(response['choices'][0]['message']['content']) - - return contents - -def parse_conversation(messages): - system_prompt = "" - conversation = [] - optillm_approach = None - - for message in messages: - role = message['role'] - content = message['content'] - - # Handle content that could be a list or string - if isinstance(content, list): - # Extract text content from the list - text_content = ' '.join( - item['text'] for item in content - if isinstance(item, dict) and item.get('type') == 'text' - ) - else: - text_content = content - - if role == 'system': - system_prompt, optillm_approach = extract_optillm_approach(text_content) - elif role == 'user': - if not optillm_approach: - text_content, optillm_approach = extract_optillm_approach(text_content) - conversation.append(f"User: {text_content}") - elif role == 'assistant': - conversation.append(f"Assistant: {text_content}") - - initial_query = "\n".join(conversation) - 
return system_prompt, initial_query, optillm_approach - -def tagged_conversation_to_messages(response_text): - """Convert a tagged conversation string or list of strings into a list of messages. - If the input doesn't contain User:/Assistant: tags, return it as is. - - Args: - response_text: Either a string containing "User:" and "Assistant:" tags, - or a list of such strings. - - Returns: - If input has tags: A list of message dictionaries. - If input has no tags: The original input. - """ - def has_conversation_tags(text): - return "User:" in text or "Assistant:" in text - - def process_single_response(text): - if not has_conversation_tags(text): - return text - - messages = [] - # Split on "User:" or "Assistant:" while keeping the delimiter - parts = re.split(r'(?=(User:|Assistant:))', text.strip()) - # Remove empty strings - parts = [p for p in parts if p.strip()] - - for part in parts: - part = part.strip() - if part.startswith('User:'): - messages.append({ - 'role': 'user', - 'content': part[5:].strip() - }) - elif part.startswith('Assistant:'): - messages.append({ - 'role': 'assistant', - 'content': part[10:].strip() - }) - return messages - - if isinstance(response_text, list): - processed = [process_single_response(text) for text in response_text] - # If none of the responses had tags, return original list - if all(isinstance(p, str) for p in processed): - return response_text - return processed - else: - return process_single_response(response_text) - -def extract_optillm_approach(content): - match = re.search(r'(.*?)', content) - if match: - approach = match.group(1) - content = re.sub(r'.*?', '', content).strip() - return content, approach - return content, None - -# Optional API key configuration to secure the proxy -@app.before_request -def check_api_key(): - if server_config['optillm_api_key']: - if request.path == "/health": - return - - auth_header = request.headers.get('Authorization') - if not auth_header or not auth_header.startswith('Bearer '): - return jsonify({"error": "Invalid Authorization header. 
Expected format: 'Authorization: Bearer YOUR_API_KEY'"}), 401 - - client_key = auth_header.split('Bearer ', 1)[1].strip() - if not secrets.compare_digest(client_key, server_config['optillm_api_key']): - return jsonify({"error": "Invalid API key"}), 401 - -@app.route('/v1/chat/completions', methods=['POST']) -def proxy(): - logger.info('Received request to /v1/chat/completions') - data = request.get_json() - auth_header = request.headers.get("Authorization") - bearer_token = "" - - if auth_header and auth_header.startswith("Bearer "): - bearer_token = auth_header.split("Bearer ")[1].strip() - logger.debug(f"Intercepted Bearer Token: {bearer_token}") - - logger.debug(f'Request data: {data}') - - stream = data.get('stream', False) - messages = data.get('messages', []) - model = data.get('model', server_config['model']) - n = data.get('n', server_config['n']) # Get n value from request or config - # Extract response_format if present - response_format = data.get("response_format", None) - - # Explicit keys that we are already handling - explicit_keys = {'stream', 'messages', 'model', 'n', 'response_format'} - - # Copy the rest into request_config - request_config = {k: v for k, v in data.items() if k not in explicit_keys} - - # Add the explicitly handled ones - request_config.update({ - "stream": stream, - "n": n, - "response_format": response_format, # Add response_format to config - }) - - optillm_approach = data.get('optillm_approach', server_config['approach']) - logger.debug(data) - server_config['mcts_depth'] = data.get('mcts_depth', server_config['mcts_depth']) - server_config['mcts_exploration'] = data.get('mcts_exploration', server_config['mcts_exploration']) - server_config['mcts_simulations'] = data.get('mcts_simulations', server_config['mcts_simulations']) - - system_prompt, initial_query, message_optillm_approach = parse_conversation(messages) - - if message_optillm_approach: - optillm_approach = message_optillm_approach - - if optillm_approach != "auto": - model = f"{optillm_approach}-{model}" - - base_url = server_config['base_url'] - default_client, api_key = get_config() - - operation, approaches, model = parse_combined_approach(model, known_approaches, plugin_approaches) - - # Start conversation logging if enabled - request_id = None - if conversation_logger and conversation_logger.enabled: - request_id = conversation_logger.start_conversation( - client_request={ - 'messages': messages, - 'model': data.get('model', server_config['model']), - 'stream': stream, - 'n': n, - **{k: v for k, v in data.items() if k not in {'messages', 'model', 'stream', 'n'}} - }, - approach=approaches[0] if len(approaches) == 1 else f"{operation}({','.join(approaches)})", - model=model - ) - - # Log approach and request start with ID for terminal monitoring - request_id_str = f' [Request: {request_id}]' if request_id else '' - logger.info(f'Using approach(es) {approaches}, operation {operation}, with model {model}{request_id_str}') - if request_id: - logger.info(f'Request {request_id}: Starting processing') - - if bearer_token != "" and bearer_token.startswith("sk-"): - api_key = bearer_token - if base_url != "": - client = OpenAI(api_key=api_key, base_url=base_url) - else: - client = OpenAI(api_key=api_key) - else: - client = default_client - - try: - # Route to batch processing if batch mode is enabled - if request_batcher is not None: - try: - # Create request data for batching - batch_request_data = { - 'system_prompt': system_prompt, - 'initial_query': initial_query, - 'client': client, - 
'model': model, - 'request_config': request_config, - 'approaches': approaches, - 'operation': operation, - 'n': n, - 'stream': stream, - 'optillm_approach': optillm_approach - } - - logger.debug("Routing request to batch processor") - result = request_batcher.add_request(batch_request_data) - return jsonify(result), 200 - - except BatchingError as e: - logger.error(f"Batch processing failed: {e}") - return jsonify({"error": str(e)}), 500 - - # Check if any of the approaches is 'none' - contains_none = any(approach == 'none' for approach in approaches) - - if operation == 'SINGLE' and approaches[0] == 'none': - # Pass through the request including the n parameter - result, completion_tokens = execute_single_approach(approaches[0], system_prompt, initial_query, client, model, request_config, request_id) - - logger.debug(f'Direct proxy response: {result}') - - # Log the final response and finalize conversation logging - if conversation_logger and request_id: - conversation_logger.log_final_response(request_id, result) - conversation_logger.finalize_conversation(request_id) - - if stream: - if request_id: - logger.info(f'Request {request_id}: Completed (streaming response)') - return Response(generate_streaming_response(extract_contents(result), model), content_type='text/event-stream') - else : - if request_id: - logger.info(f'Request {request_id}: Completed') - return jsonify(result), 200 - - elif operation == 'AND' or operation == 'OR': - if contains_none: - raise ValueError("'none' approach cannot be combined with other approaches") - - # Handle non-none approaches with n attempts - response, completion_tokens = execute_n_times(n, approaches, operation, system_prompt, initial_query, client, model, request_config, request_id) - - except Exception as e: - # Log error to conversation logger if enabled - if conversation_logger and request_id: - conversation_logger.log_error(request_id, str(e)) - conversation_logger.finalize_conversation(request_id) - - request_id_str = f' {request_id}' if request_id else '' - logger.error(f"Error processing request{request_id_str}: {str(e)}") - return jsonify({"error": str(e)}), 500 - - # Convert tagged conversation to messages format if needed - if isinstance(response, list): - processed_response = tagged_conversation_to_messages(response) - # If processed_response is a list of message lists, extract last message content - if processed_response != response: # Only process if format changed - response = [msg[-1]['content'] if isinstance(msg, list) and msg else msg - for msg in processed_response] - # Otherwise keep original response - else: - messages = tagged_conversation_to_messages(response) - if isinstance(messages, list) and messages: # Only process if format changed - response = messages[-1]['content'] - - if stream: - return Response(generate_streaming_response(response, model), content_type='text/event-stream') - else: - # Calculate reasoning tokens from the response - reasoning_tokens = 0 - if isinstance(response, str): - reasoning_tokens = count_reasoning_tokens(response) - elif isinstance(response, list) and response: - # For multiple responses, sum up reasoning tokens from all - reasoning_tokens = sum(count_reasoning_tokens(resp) for resp in response if isinstance(resp, str)) - - response_data = { - 'model': model, - 'choices': [], - 'usage': { - 'completion_tokens': completion_tokens, - 'completion_tokens_details': { - 'reasoning_tokens': reasoning_tokens - } - } - } - - if isinstance(response, list): - for index, resp in enumerate(response): - 
response_data['choices'].append({ - 'index': index, - 'message': { - 'role': 'assistant', - 'content': resp, - }, - 'finish_reason': 'stop' - }) - else: - response_data['choices'].append({ - 'index': 0, - 'message': { - 'role': 'assistant', - 'content': response, - }, - 'finish_reason': 'stop' - }) - - # Log the final response and finalize conversation logging - if conversation_logger and request_id: - conversation_logger.log_final_response(request_id, response_data) - conversation_logger.finalize_conversation(request_id) - - logger.debug(f'API response: {response_data}') - if request_id: - logger.info(f'Request {request_id}: Completed') - return jsonify(response_data), 200 - -@app.route('/v1/models', methods=['GET']) -def proxy_models(): - logger.info('Received request to /v1/models') - default_client, API_KEY = get_config() - try: - if server_config['base_url']: - client = OpenAI(api_key=API_KEY, base_url=server_config['base_url']) - else: - client = default_client - - # Fetch models using the OpenAI client and return the raw response - models_response = client.models.list().json() - - logger.debug('Models retrieved successfully') - return models_response, 200 - except Exception as e: - logger.error(f"Error fetching models: {str(e)}") - return jsonify({"error": f"Error fetching models: {str(e)}"}), 500 - -@app.route('/health', methods=['GET']) -def health(): - return jsonify({"status": "ok"}), 200 - -def parse_args(): - parser = argparse.ArgumentParser(description="Run LLM inference with various approaches.") - - try: - from optillm import __version__ as package_version - except ImportError: - package_version = "unknown" - - parser.add_argument('--version', action='version', - version=f'%(prog)s {package_version}', - help="Show program's version number and exit") - - # Define arguments and their corresponding environment variables - args_env = [ - ("--optillm-api-key", "OPTILLM_API_KEY", str, "", "Optional API key for client authentication to optillm"), - ("--approach", "OPTILLM_APPROACH", str, "auto", "Inference approach to use", known_approaches + list(plugin_approaches.keys())), - ("--mcts-simulations", "OPTILLM_SIMULATIONS", int, 2, "Number of MCTS simulations"), - ("--mcts-exploration", "OPTILLM_EXPLORATION", float, 0.2, "Exploration weight for MCTS"), - ("--mcts-depth", "OPTILLM_DEPTH", int, 1, "Simulation depth for MCTS"), - ("--model", "OPTILLM_MODEL", str, "gpt-4o-mini", "OpenAI model to use"), - ("--rstar-max-depth", "OPTILLM_RSTAR_MAX_DEPTH", int, 3, "Maximum depth for rStar algorithm"), - ("--rstar-num-rollouts", "OPTILLM_RSTAR_NUM_ROLLOUTS", int, 5, "Number of rollouts for rStar algorithm"), - ("--rstar-c", "OPTILLM_RSTAR_C", float, 1.4, "Exploration constant for rStar algorithm"), - ("--n", "OPTILLM_N", int, 1, "Number of final responses to be returned"), - ("--return-full-response", "OPTILLM_RETURN_FULL_RESPONSE", bool, False, "Return the full response including the CoT with tags"), - ("--port", "OPTILLM_PORT", int, 8000, "Specify the port to run the proxy"), - ("--log", "OPTILLM_LOG", str, "info", "Specify the logging level", list(logging_levels.keys())), - ("--launch-gui", "OPTILLM_LAUNCH_GUI", bool, False, "Launch a Gradio chat interface"), - ("--plugins-dir", "OPTILLM_PLUGINS_DIR", str, "", "Path to the plugins directory"), - ("--log-conversations", "OPTILLM_LOG_CONVERSATIONS", bool, False, "Enable conversation logging with full metadata"), - ("--conversation-log-dir", "OPTILLM_CONVERSATION_LOG_DIR", str, str(Path.home() / ".optillm" / "conversations"), "Directory to 
save conversation logs"), - ] - - for arg, env, type_, default, help_text, *extra in args_env: - env_value = os.environ.get(env) - if env_value is not None: - if type_ == bool: - default = env_value.lower() in ('true', '1', 'yes') - else: - default = type_(env_value) - if extra and extra[0]: # Check if there are choices for this argument - parser.add_argument(arg, type=type_, default=default, help=help_text, choices=extra[0]) - else: - if type_ == bool: - # For boolean flags, use store_true action - parser.add_argument(arg, action='store_true', default=default, help=help_text) - else: - parser.add_argument(arg, type=type_, default=default, help=help_text) - - # Special handling for best_of_n to support both formats - best_of_n_default = int(os.environ.get("OPTILLM_BEST_OF_N", 3)) - parser.add_argument("--best-of-n", "--best_of_n", dest="best_of_n", type=int, default=best_of_n_default, - help="Number of samples for best_of_n approach") - - # Special handling for base_url to support both formats - base_url_default = os.environ.get("OPTILLM_BASE_URL", "") - parser.add_argument("--base-url", "--base_url", dest="base_url", type=str, default=base_url_default, - help="Base url for OpenAI compatible endpoint") - - # Use the function to get the default path - default_config_path = get_config_path() - - # Batch mode arguments - batch_mode_default = os.environ.get("OPTILLM_BATCH_MODE", "false").lower() == "true" - batch_size_default = int(os.environ.get("OPTILLM_BATCH_SIZE", 4)) - batch_wait_ms_default = int(os.environ.get("OPTILLM_BATCH_WAIT_MS", 50)) - - parser.add_argument("--batch-mode", action="store_true", default=batch_mode_default, - help="Enable automatic request batching (fail-fast, no fallback)") - parser.add_argument("--batch-size", type=int, default=batch_size_default, - help="Maximum batch size for request batching") - parser.add_argument("--batch-wait-ms", dest="batch_wait_ms", type=int, default=batch_wait_ms_default, - help="Maximum wait time in milliseconds for batch formation") - - # Special handling of all the CePO Configurations - for field in fields(CepoConfig): - parser.add_argument(f"--cepo_{field.name}", - dest=f"cepo_{field.name}", - type=field.type, - default=None, - help=f"CePO configuration for {field.name}") - - parser.add_argument("--cepo_config_file", - dest="cepo_config_file", - type=str, - default=default_config_path, - help="Path to CePO configuration file") - - args = parser.parse_args() - - # Convert argument names to match server_config keys - args_dict = vars(args) - for key in list(args_dict.keys()): - new_key = key.replace("-", "_") - if new_key != key: - args_dict[new_key] = args_dict.pop(key) - - return args - -def main(): - global server_config - global cepo_config - global request_batcher - global conversation_logger - # Call this function at the start of main() - args = parse_args() - # Update server_config with all argument values - server_config.update(vars(args)) - - load_plugins() - - port = server_config['port'] - - # Initialize request batcher if batch mode is enabled - if server_config.get('batch_mode', False): - logger.info(f"Batch mode enabled: size={server_config['batch_size']}, " - f"wait={server_config['batch_wait_ms']}ms") - request_batcher = RequestBatcher( - max_batch_size=server_config['batch_size'], - max_wait_ms=server_config['batch_wait_ms'], - enable_logging=True - ) - - # Set up the batch processor function - def process_batch_requests(batch_requests): - """ - Process a batch of requests using true batching when possible - - Args: - 
batch_requests: List of request data dictionaries - - Returns: - List of response dictionaries - """ - import time - from optillm.batching import BatchingError - - if not batch_requests: - return [] - - logger.info(f"Processing batch of {len(batch_requests)} requests") - - # Check if we can use true batching (all requests compatible and using 'none' approach) - can_use_true_batching = True - first_req = batch_requests[0] - - # Check compatibility across all requests - for req_data in batch_requests: - if (req_data['stream'] or - req_data['approaches'] != first_req['approaches'] or - req_data['operation'] != first_req['operation'] or - req_data['model'] != first_req['model']): - can_use_true_batching = False - break - - # For now, implement sequential processing but with proper infrastructure - # TODO: Implement true PyTorch/MLX batching in next phase - responses = [] - - for i, req_data in enumerate(batch_requests): - try: - logger.debug(f"Processing batch request {i+1}/{len(batch_requests)}") - - # Extract request parameters - system_prompt = req_data['system_prompt'] - initial_query = req_data['initial_query'] - client = req_data['client'] - model = req_data['model'] - request_config = req_data['request_config'] - approaches = req_data['approaches'] - operation = req_data['operation'] - n = req_data['n'] - stream = req_data['stream'] - - # Validate request - if stream: - raise BatchingError("Streaming requests cannot be batched") - - # Check if any of the approaches is 'none' - contains_none = any(approach == 'none' for approach in approaches) - - if operation == 'SINGLE' and approaches[0] == 'none': - # Pass through the request including the n parameter - result, completion_tokens = execute_single_approach( - approaches[0], system_prompt, initial_query, client, model, request_config) - elif operation == 'AND' or operation == 'OR': - if contains_none: - raise ValueError("'none' approach cannot be combined with other approaches") - # Handle non-none approaches with n attempts - result, completion_tokens = execute_n_times( - n, approaches, operation, system_prompt, initial_query, client, model, request_config) - else: - # Handle non-none approaches with n attempts - result, completion_tokens = execute_n_times( - n, approaches, operation, system_prompt, initial_query, client, model, request_config) - - # Convert tagged conversation to messages format if needed - if isinstance(result, list): - processed_response = tagged_conversation_to_messages(result) - if processed_response != result: # Only process if format changed - result = [msg[-1]['content'] if isinstance(msg, list) and msg else msg - for msg in processed_response] - else: - messages = tagged_conversation_to_messages(result) - if isinstance(messages, list) and messages: # Only process if format changed - result = messages[-1]['content'] - - # Generate the response in OpenAI format - if isinstance(result, list): - choices = [] - for j, res in enumerate(result): - choices.append({ - "index": j, - "message": { - "role": "assistant", - "content": res - }, - "finish_reason": "stop" - }) - else: - choices = [{ - "index": 0, - "message": { - "role": "assistant", - "content": result - }, - "finish_reason": "stop" - }] - - response_dict = { - "id": f"chatcmpl-{int(time.time()*1000)}-{i}", - "object": "chat.completion", - "created": int(time.time()), - "model": model, - "choices": choices, - "usage": { - "prompt_tokens": 0, # Will be calculated properly later - "completion_tokens": completion_tokens if isinstance(completion_tokens, int) else 
0, - "total_tokens": completion_tokens if isinstance(completion_tokens, int) else 0 - } - } - - responses.append(response_dict) - - except Exception as e: - logger.error(f"Error processing batch request {i+1}: {e}") - raise BatchingError(f"Failed to process request {i+1}: {str(e)}") - - logger.info(f"Completed batch processing of {len(responses)} requests") - return responses - - # Set the processor function on the batcher - request_batcher.set_processor(process_batch_requests) - - # Set logging level from user request - logging_level = server_config['log'] - if logging_level in logging_levels.keys(): - logger.setLevel(logging_levels[logging_level]) - - # Initialize conversation logger if enabled - global conversation_logger - conversation_logger = ConversationLogger( - log_dir=Path(server_config['conversation_log_dir']), - enabled=server_config['log_conversations'] - ) - # Set the global logger instance for access from approach modules - optillm.conversation_logger.set_global_logger(conversation_logger) - if server_config['log_conversations']: - logger.info(f"Conversation logging enabled. Logs will be saved to: {server_config['conversation_log_dir']}") - - # set and log the cepo configs - cepo_config = init_cepo_config(server_config) - if args.approach == 'cepo': - logger.info(f"CePO Config: {cepo_config}") - - logger.info(f"Starting server with approach: {server_config['approach']}") - server_config_clean = server_config.copy() - if server_config_clean['optillm_api_key']: - server_config_clean['optillm_api_key'] = '[REDACTED]' - logger.info(f"Server configuration: {server_config_clean}") - - # Launch GUI if requested - if server_config.get('launch_gui'): - try: - import gradio as gr - # Start server in a separate thread - import threading - server_thread = threading.Thread(target=app.run, kwargs={'host': '0.0.0.0', 'port': port}) - server_thread.daemon = True - server_thread.start() - - # Configure the base URL for the Gradio interface - base_url = f"http://localhost:{port}/v1" - logger.info(f"Launching Gradio interface connected to {base_url}") - - # Create custom chat function with extended timeout - def chat_with_optillm(message, history): - import httpx - from openai import OpenAI - - # Create client with extended timeout and no retries - custom_client = OpenAI( - api_key="optillm", - base_url=base_url, - timeout=httpx.Timeout(1800.0, connect=5.0), # 30 min timeout - max_retries=0 # No retries - prevents duplicate requests - ) - - # Convert history to messages format - messages = [] - for h in history: - if h[0]: # User message - messages.append({"role": "user", "content": h[0]}) - if h[1]: # Assistant message - messages.append({"role": "assistant", "content": h[1]}) - messages.append({"role": "user", "content": message}) - - # Make request - try: - response = custom_client.chat.completions.create( - model=server_config['model'], - messages=messages - ) - return response.choices[0].message.content - except Exception as e: - return f"Error: {str(e)}" - - # Create Gradio interface with queue for long operations - demo = gr.ChatInterface( - chat_with_optillm, - title="OptILLM Chat Interface", - description=f"Connected to OptILLM proxy at {base_url}" - ) - demo.queue() # Enable queue to handle long operations properly - demo.launch(server_name="0.0.0.0", share=False) - except ImportError: - logger.error("Gradio is required for GUI. 
Install it with: pip install gradio")
-            return
-
-    app.run(host='0.0.0.0', port=port)
+from optillm import main
 
 if __name__ == "__main__":
     main()
diff --git a/optillm/__init__.py b/optillm/__init__.py
index 8c831666..01f1face 100644
--- a/optillm/__init__.py
+++ b/optillm/__init__.py
@@ -1,42 +1,25 @@
-from importlib import util
-import os
-
 # Version information
 __version__ = "0.2.1"
 
-# Get the path to the root optillm.py
-spec = util.spec_from_file_location(
-    "optillm.root",
-    os.path.join(os.path.dirname(os.path.dirname(__file__)), "optillm.py")
+# Import from server module
+from .server import (
+    main,
+    server_config,
+    app,
+    known_approaches,
+    plugin_approaches,
+    parse_combined_approach,
+    parse_conversation,
+    extract_optillm_approach,
+    get_config,
+    load_plugins,
+    count_reasoning_tokens,
+    parse_args,
+    execute_single_approach,
+    execute_combined_approaches,
+    execute_parallel_approaches,
+    generate_streaming_response,
 )
-module = util.module_from_spec(spec)
-spec.loader.exec_module(module)
-
-# Export the main entry point
-main = module.main
-
-# Export the core configuration and server components
-server_config = module.server_config
-app = module.app
-known_approaches = module.known_approaches
-plugin_approaches = module.plugin_approaches
-
-# Export utility functions
-parse_combined_approach = module.parse_combined_approach
-parse_conversation = module.parse_conversation
-extract_optillm_approach = module.extract_optillm_approach
-get_config = module.get_config
-load_plugins = module.load_plugins
-count_reasoning_tokens = module.count_reasoning_tokens
-parse_args = module.parse_args
-
-# Export execution functions
-execute_single_approach = module.execute_single_approach
-execute_combined_approaches = module.execute_combined_approaches
-execute_parallel_approaches = module.execute_parallel_approaches
-
-# Export streaming response generation
-generate_streaming_response = module.generate_streaming_response
 
 # List of exported symbols
 __all__ = [
diff --git a/optillm/server.py b/optillm/server.py
new file mode 100644
index 00000000..29271d34
--- /dev/null
+++ b/optillm/server.py
@@ -0,0 +1,1232 @@
+import argparse
+import logging
+import os
+import secrets
+import time
+from pathlib import Path
+from flask import Flask, request, jsonify
+from cerebras.cloud.sdk import Cerebras
+from openai import AzureOpenAI, OpenAI
+from flask import Response
+import json
+import importlib
+import glob
+import asyncio
+import re
+from concurrent.futures import ThreadPoolExecutor
+from typing import Tuple, Optional, Union, Dict, Any, List
+from importlib.metadata import version
+from dataclasses import fields
+
+# Import approach modules
+from optillm.mcts import chat_with_mcts
+from optillm.bon import best_of_n_sampling
+from optillm.moa import mixture_of_agents
+from optillm.rto import round_trip_optimization
+from optillm.self_consistency import advanced_self_consistency_approach
+from optillm.pvg import inference_time_pv_game
+from optillm.z3_solver import Z3SymPySolverSystem
+from optillm.rstar import RStar
+from optillm.cot_reflection import cot_reflection
+from optillm.plansearch import plansearch
+from optillm.leap import leap
+from optillm.reread import re2_approach
+from optillm.cepo.cepo import cepo, CepoConfig, init_cepo_config
+from optillm.batching import RequestBatcher, BatchingError
+from optillm.conversation_logger import ConversationLogger
+import optillm.conversation_logger
+
+# Setup logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - 
%(message)s') +logger = logging.getLogger(__name__) +logging_levels = { + "notset": logging.NOTSET, + "debug": logging.DEBUG, + "info": logging.INFO, + "warning": logging.WARNING, + "error": logging.ERROR, + "critical": logging.CRITICAL, +} + +# Initialize Flask app +app = Flask(__name__) + +# Global request batcher (initialized in main() if batch mode enabled) +request_batcher = None + +# Global conversation logger (initialized in main() if logging enabled) +conversation_logger = None + +def get_config(): + API_KEY = None + if os.environ.get("OPTILLM_API_KEY"): + # Use local inference engine + from optillm.inference import create_inference_client + API_KEY = os.environ.get("OPTILLM_API_KEY") + default_client = create_inference_client() + # Cerebras, OpenAI, Azure, or LiteLLM API configuration + elif os.environ.get("CEREBRAS_API_KEY"): + API_KEY = os.environ.get("CEREBRAS_API_KEY") + base_url = server_config['base_url'] + if base_url != "": + default_client = Cerebras(api_key=API_KEY, base_url=base_url) + else: + default_client = Cerebras(api_key=API_KEY) + elif os.environ.get("OPENAI_API_KEY"): + API_KEY = os.environ.get("OPENAI_API_KEY") + base_url = server_config['base_url'] + if base_url != "": + default_client = OpenAI(api_key=API_KEY, base_url=base_url) + else: + default_client = OpenAI(api_key=API_KEY) + elif os.environ.get("AZURE_OPENAI_API_KEY"): + API_KEY = os.environ.get("AZURE_OPENAI_API_KEY") + API_VERSION = os.environ.get("AZURE_API_VERSION") + AZURE_ENDPOINT = os.environ.get("AZURE_API_BASE") + if API_KEY is not None: + default_client = AzureOpenAI( + api_key=API_KEY, + api_version=API_VERSION, + azure_endpoint=AZURE_ENDPOINT, + ) + else: + from azure.identity import DefaultAzureCredential, get_bearer_token_provider + azure_credential = DefaultAzureCredential() + token_provider = get_bearer_token_provider(azure_credential, "https://cognitiveservices.azure.com/.default") + default_client = AzureOpenAI( + api_version=API_VERSION, + azure_endpoint=AZURE_ENDPOINT, + azure_ad_token_provider=token_provider + ) + else: + # Import the LiteLLM wrapper + from optillm.litellm_wrapper import LiteLLMWrapper + default_client = LiteLLMWrapper() + return default_client, API_KEY + +def count_reasoning_tokens(text: str, tokenizer=None) -> int: + """ + Count tokens within ... tags in the given text. + + Args: + text: The text to analyze + tokenizer: Optional tokenizer instance for precise counting + + Returns: + Number of reasoning tokens (0 if no think tags found) + """ + if not text or not isinstance(text, str): + return 0 + + # Extract all content within ... tags + # Handle both complete and truncated think blocks + + # First, find all complete ... 
blocks
+    complete_pattern = r'<think>(.*?)</think>'
+    complete_matches = re.findall(complete_pattern, text, re.DOTALL)
+
+    # Then check for unclosed <think> tag (truncated response)
+    # This finds <think> that doesn't have a matching </think> after it
+    truncated_pattern = r'<think>(?!.*</think>)(.*)$'
+    truncated_match = re.search(truncated_pattern, text, re.DOTALL)
+
+    # Combine all thinking content
+    thinking_content = ''.join(complete_matches)
+    if truncated_match:
+        thinking_content += truncated_match.group(1)
+
+    if not thinking_content:
+        return 0
+
+    if tokenizer and hasattr(tokenizer, 'encode'):
+        # Use tokenizer for precise counting
+        try:
+            tokens = tokenizer.encode(thinking_content)
+            return len(tokens)
+        except Exception as e:
+            logger.warning(f"Failed to count tokens with tokenizer: {e}")
+
+    # Fallback: rough estimation (4 chars per token on average, minimum 1 token for non-empty content)
+    content_length = len(thinking_content.strip())
+    return max(1, content_length // 4) if content_length > 0 else 0
+
+# Server configuration
+server_config = {
+    'approach': 'none',
+    'mcts_simulations': 2,
+    'mcts_exploration': 0.2,
+    'mcts_depth': 1,
+    'best_of_n': 3,
+    'model': 'gpt-4o-mini',
+    'rstar_max_depth': 3,
+    'rstar_num_rollouts': 5,
+    'rstar_c': 1.4,
+    'n': 1,
+    'base_url': '',
+    'optillm_api_key': '',
+    'return_full_response': False,
+    'port': 8000,
+    'log': 'info',
+}
+
+# List of known approaches
+known_approaches = ["none", "mcts", "bon", "moa", "rto", "z3", "self_consistency",
+                    "pvg", "rstar", "cot_reflection", "plansearch", "leap", "re2", "cepo"]
+
+plugin_approaches = {}
+
+def normalize_message_content(messages):
+    """
+    Ensure all message content fields are strings, not lists.
+    Some models don't handle list-format content correctly.
+    """
+    normalized_messages = []
+    for message in messages:
+        normalized_message = message.copy()
+        content = message.get('content', '')
+
+        # Convert list content to string if needed
+        if isinstance(content, list):
+            # Extract text content from the list
+            text_content = ' '.join(
+                item.get('text', '') for item in content
+                if isinstance(item, dict) and item.get('type') == 'text'
+            )
+            normalized_message['content'] = text_content
+
+        normalized_messages.append(normalized_message)
+
+    return normalized_messages
+
+def none_approach(
+    client: Any,
+    model: str,
+    original_messages: List[Dict[str, str]],
+    request_id: str = None,
+    **kwargs
+) -> Dict[str, Any]:
+    """
+    Direct proxy approach that passes through all parameters to the underlying endpoint.
+ + Args: + client: OpenAI client instance + model: Model identifier + original_messages: Original messages from the request + request_id: Optional request ID for conversation logging + **kwargs: Additional parameters to pass through + + Returns: + Dict[str, Any]: Full OpenAI API response + """ + # Strip 'none-' prefix from model if present + if model.startswith('none-'): + model = model[5:] + + try: + # Normalize message content to ensure it's always string + normalized_messages = normalize_message_content(original_messages) + + # Prepare request data for logging + provider_request = { + "model": model, + "messages": normalized_messages, + **kwargs + } + + # Make the direct completion call with normalized messages and parameters + response = client.chat.completions.create( + model=model, + messages=normalized_messages, + **kwargs + ) + + # Convert to dict if it's not already + response_dict = response.model_dump() if hasattr(response, 'model_dump') else response + + # Log the provider call if conversation logging is enabled + if conversation_logger and request_id: + conversation_logger.log_provider_call(request_id, provider_request, response_dict) + + return response_dict + + except Exception as e: + # Log error if conversation logging is enabled + if conversation_logger and request_id: + conversation_logger.log_error(request_id, f"Error in none approach: {str(e)}") + logger.error(f"Error in none approach: {str(e)}") + raise + +def load_plugins(): + # Clear existing plugins first but modify the global dict in place + plugin_approaches.clear() + + # Get installed package plugins directory + import optillm + package_plugin_dir = os.path.join(os.path.dirname(optillm.__file__), 'plugins') + + # Get local project plugins directory + current_dir = os.getcwd() if server_config.get("plugins_dir", "") == "" else server_config["plugins_dir"] + local_plugin_dir = os.path.join(current_dir, 'optillm', 'plugins') + + plugin_dirs = [] + + # Add package plugin dir + plugin_dirs.append((package_plugin_dir, "package")) + + # Add local plugin dir only if it's different from package dir + if local_plugin_dir != package_plugin_dir: + plugin_dirs.append((local_plugin_dir, "local")) + + for plugin_dir, source in plugin_dirs: + logger.info(f"Looking for {source} plugins in: {plugin_dir}") + + if not os.path.exists(plugin_dir): + logger.debug(f"{source.capitalize()} plugin directory not found: {plugin_dir}") + continue + + plugin_files = glob.glob(os.path.join(plugin_dir, '*.py')) + if not plugin_files: + logger.debug(f"No plugin files found in {source} directory: {plugin_dir}") + continue + + logger.info(f"Found {source} plugin files: {plugin_files}") + + for plugin_file in plugin_files: + try: + module_name = os.path.basename(plugin_file)[:-3] # Remove .py extension + spec = importlib.util.spec_from_file_location(module_name, plugin_file) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + if hasattr(module, 'SLUG') and hasattr(module, 'run'): + if module.SLUG in plugin_approaches: + logger.info(f"Overriding {source} plugin: {module.SLUG}") + plugin_approaches[module.SLUG] = module.run + logger.info(f"Loaded {source} plugin: {module.SLUG}") + else: + logger.warning(f"Plugin {module_name} from {source} missing required attributes (SLUG and run)") + except Exception as e: + logger.error(f"Error loading {source} plugin {plugin_file}: {str(e)}") + + if not plugin_approaches: + logger.warning("No plugins loaded from any location") + +def get_config_path(): + # Get installed package 
config directory + import optillm + package_config_dir = os.path.join(os.path.dirname(optillm.__file__), 'cepo', 'configs') + package_config_path = os.path.join(package_config_dir, 'cepo_config.yaml') + + # Get local project config directory + current_dir = os.getcwd() if server_config.get("config_dir", "") == "" else server_config["config_dir"] + local_config_dir = os.path.join(current_dir, 'optillm', 'cepo', 'configs') + local_config_path = os.path.join(local_config_dir, 'cepo_config.yaml') + + # If local config exists and is different from package config, use local + if os.path.exists(local_config_path) and local_config_path != package_config_path: + logger.debug(f"Using local config from: {local_config_path}") + return local_config_path + + # Otherwise use package config + logger.debug(f"Using package config from: {package_config_path}") + return package_config_path + +def parse_combined_approach(model: str, known_approaches: list, plugin_approaches: dict): + if model == 'auto': + return 'SINGLE', ['none'], model + + parts = model.split('-') + approaches = [] + operation = 'SINGLE' + model_parts = [] + parsing_approaches = True + + for part in parts: + if parsing_approaches: + if part in known_approaches or part in plugin_approaches: + approaches.append(part) + elif '&' in part: + operation = 'AND' + approaches.extend(part.split('&')) + elif '|' in part: + operation = 'OR' + approaches.extend(part.split('|')) + else: + parsing_approaches = False + model_parts.append(part) + else: + model_parts.append(part) + + if not approaches: + approaches = ['none'] + operation = 'SINGLE' + + actual_model = '-'.join(model_parts) + + return operation, approaches, actual_model + +def execute_single_approach(approach, system_prompt, initial_query, client, model, request_config: dict = None, request_id: str = None): + if approach in known_approaches: + if approach == 'none': + # Use the request_config that was already prepared and passed to this function + kwargs = request_config.copy() if request_config else {} + + # Remove items that are handled separately by the framework + kwargs.pop('n', None) # n is handled by execute_n_times + kwargs.pop('stream', None) # stream is handled by proxy() + + # Reconstruct original messages from system_prompt and initial_query + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + if initial_query: + messages.append({"role": "user", "content": initial_query}) + + response = none_approach(original_messages=messages, client=client, model=model, request_id=request_id, **kwargs) + # For none approach, we return the response and a token count of 0 + # since the full token count is already in the response + return response, 0 + elif approach == 'mcts': + return chat_with_mcts(system_prompt, initial_query, client, model, server_config['mcts_simulations'], + server_config['mcts_exploration'], server_config['mcts_depth'], request_id) + elif approach == 'bon': + return best_of_n_sampling(system_prompt, initial_query, client, model, server_config['best_of_n'], request_id) + elif approach == 'moa': + return mixture_of_agents(system_prompt, initial_query, client, model, request_id) + elif approach == 'rto': + return round_trip_optimization(system_prompt, initial_query, client, model, request_id) + elif approach == 'z3': + z3_solver = Z3SymPySolverSystem(system_prompt, client, model, request_id=request_id) + return z3_solver.process_query(initial_query) + elif approach == "self_consistency": + return 
advanced_self_consistency_approach(system_prompt, initial_query, client, model, request_id) + elif approach == "pvg": + return inference_time_pv_game(system_prompt, initial_query, client, model, request_id) + elif approach == "rstar": + rstar = RStar(system_prompt, client, model, + max_depth=server_config['rstar_max_depth'], num_rollouts=server_config['rstar_num_rollouts'], + c=server_config['rstar_c'], request_id=request_id) + return rstar.solve(initial_query) + elif approach == "cot_reflection": + return cot_reflection(system_prompt, initial_query, client, model, return_full_response=server_config['return_full_response'], request_config=request_config, request_id=request_id) + elif approach == 'plansearch': + return plansearch(system_prompt, initial_query, client, model, n=server_config['n'], request_id=request_id) + elif approach == 'leap': + return leap(system_prompt, initial_query, client, model, request_id) + elif approach == 're2': + return re2_approach(system_prompt, initial_query, client, model, n=server_config['n'], request_id=request_id) + elif approach == 'cepo': + return cepo(system_prompt, initial_query, client, model, cepo_config, request_id) + elif approach in plugin_approaches: + # Check if the plugin accepts request_config + plugin_func = plugin_approaches[approach] + import inspect + sig = inspect.signature(plugin_func) + + # Check if the plugin function is async + is_async = inspect.iscoroutinefunction(plugin_func) + + if is_async: + # For async functions, we need to run them in an event loop + import asyncio + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + if 'request_config' in sig.parameters: + # Plugin supports request_config + result = loop.run_until_complete(plugin_func(system_prompt, initial_query, client, model, request_config=request_config)) + else: + # Legacy plugin without request_config support + result = loop.run_until_complete(plugin_func(system_prompt, initial_query, client, model)) + return result + finally: + loop.close() + else: + # For synchronous functions, call directly + if 'request_config' in sig.parameters: + # Plugin supports request_config + return plugin_func(system_prompt, initial_query, client, model, request_config=request_config) + else: + # Legacy plugin without request_config support + return plugin_func(system_prompt, initial_query, client, model) + else: + raise ValueError(f"Unknown approach: {approach}") + +def execute_combined_approaches(approaches, system_prompt, initial_query, client, model, request_config: dict = None): + final_response = initial_query + total_tokens = 0 + for approach in approaches: + response, tokens = execute_single_approach(approach, system_prompt, final_response, client, model, request_config) + final_response = response + total_tokens += tokens + return final_response, total_tokens + +async def execute_parallel_approaches(approaches, system_prompt, initial_query, client, model, request_config: dict = None): + async def run_approach(approach): + return await asyncio.to_thread(execute_single_approach, approach, system_prompt, initial_query, client, model, request_config) + + tasks = [run_approach(approach) for approach in approaches] + results = await asyncio.gather(*tasks) + responses, tokens = zip(*results) + return list(responses), sum(tokens) + +def execute_n_times(n: int, approaches, operation: str, system_prompt: str, initial_query: str, client: Any, model: str, + request_config: dict = None, request_id: str = None) -> Tuple[Union[str, List[str]], int]: + """ + Execute the 
pipeline n times and return n responses. + + Args: + n (int): Number of times to run the pipeline + approaches (list): List of approaches to execute + operation (str): Operation type ('SINGLE', 'AND', or 'OR') + system_prompt (str): System prompt + initial_query (str): Initial query + client: OpenAI client instance + model (str): Model identifier + + Returns: + Tuple[Union[str, List[str]], int]: List of responses and total token count + """ + responses = [] + total_tokens = 0 + + for _ in range(n): + if operation == 'SINGLE': + response, tokens = execute_single_approach(approaches[0], system_prompt, initial_query, client, model, request_config, request_id) + elif operation == 'AND': + response, tokens = execute_combined_approaches(approaches, system_prompt, initial_query, client, model, request_config) + elif operation == 'OR': + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + response, tokens = loop.run_until_complete(execute_parallel_approaches(approaches, system_prompt, initial_query, client, model, request_config)) + loop.close() + else: + raise ValueError(f"Unknown operation: {operation}") + + # If response is already a list (from OR operation), extend responses + # Otherwise append the single response + if isinstance(response, list): + responses.extend(response) + else: + responses.append(response) + total_tokens += tokens + + # If n=1 and we got a single response, return it as is + # Otherwise return the list of responses + if n == 1 and len(responses) == 1: + return responses[0], total_tokens + return responses, total_tokens + +def generate_streaming_response(final_response, model): + # Yield the final response + if isinstance(final_response, list): + for index, response in enumerate(final_response): + yield "data: " + json.dumps({ + "choices": [{"delta": {"content": response}, "index": index, "finish_reason": "stop"}], + "model": model, + }) + "\n\n" + else: + yield "data: " + json.dumps({ + "choices": [{"delta": {"content": final_response}, "index": 0, "finish_reason": "stop"}], + "model": model, + }) + "\n\n" + + # Yield the final message to indicate the stream has ended + yield "data: [DONE]\n\n" + +def extract_contents(response_obj): + contents = [] + # Handle both single response and list of responses + responses = response_obj if isinstance(response_obj, list) else [response_obj] + + for response in responses: + # Extract content from first choice if it exists + if (response.get('choices') and + len(response['choices']) > 0 and + response['choices'][0].get('message') and + response['choices'][0]['message'].get('content')): + contents.append(response['choices'][0]['message']['content']) + + return contents + +def parse_conversation(messages): + system_prompt = "" + conversation = [] + optillm_approach = None + + for message in messages: + role = message['role'] + content = message['content'] + + # Handle content that could be a list or string + if isinstance(content, list): + # Extract text content from the list + text_content = ' '.join( + item['text'] for item in content + if isinstance(item, dict) and item.get('type') == 'text' + ) + else: + text_content = content + + if role == 'system': + system_prompt, optillm_approach = extract_optillm_approach(text_content) + elif role == 'user': + if not optillm_approach: + text_content, optillm_approach = extract_optillm_approach(text_content) + conversation.append(f"User: {text_content}") + elif role == 'assistant': + conversation.append(f"Assistant: {text_content}") + + initial_query = "\n".join(conversation) + 
return system_prompt, initial_query, optillm_approach
+
+def tagged_conversation_to_messages(response_text):
+    """Convert a tagged conversation string or list of strings into a list of messages.
+    If the input doesn't contain User:/Assistant: tags, return it as is.
+
+    Args:
+        response_text: Either a string containing "User:" and "Assistant:" tags,
+            or a list of such strings.
+
+    Returns:
+        If input has tags: A list of message dictionaries.
+        If input has no tags: The original input.
+    """
+    def has_conversation_tags(text):
+        return "User:" in text or "Assistant:" in text
+
+    def process_single_response(text):
+        if not has_conversation_tags(text):
+            return text
+
+        messages = []
+        # Split on "User:" or "Assistant:" while keeping the delimiter
+        parts = re.split(r'(?=(User:|Assistant:))', text.strip())
+        # Remove empty strings
+        parts = [p for p in parts if p.strip()]
+
+        for part in parts:
+            part = part.strip()
+            if part.startswith('User:'):
+                messages.append({
+                    'role': 'user',
+                    'content': part[5:].strip()
+                })
+            elif part.startswith('Assistant:'):
+                messages.append({
+                    'role': 'assistant',
+                    'content': part[10:].strip()
+                })
+        return messages
+
+    if isinstance(response_text, list):
+        processed = [process_single_response(text) for text in response_text]
+        # If none of the responses had tags, return original list
+        if all(isinstance(p, str) for p in processed):
+            return response_text
+        return processed
+    else:
+        return process_single_response(response_text)
+
+def extract_optillm_approach(content):
+    match = re.search(r'<optillm_approach>(.*?)</optillm_approach>', content)
+    if match:
+        approach = match.group(1)
+        content = re.sub(r'<optillm_approach>.*?</optillm_approach>', '', content).strip()
+        return content, approach
+    return content, None
+
+# Optional API key configuration to secure the proxy
+@app.before_request
+def check_api_key():
+    if server_config['optillm_api_key']:
+        if request.path == "/health":
+            return
+
+        auth_header = request.headers.get('Authorization')
+        if not auth_header or not auth_header.startswith('Bearer '):
+            return jsonify({"error": "Invalid Authorization header. 
Expected format: 'Authorization: Bearer YOUR_API_KEY'"}), 401 + + client_key = auth_header.split('Bearer ', 1)[1].strip() + if not secrets.compare_digest(client_key, server_config['optillm_api_key']): + return jsonify({"error": "Invalid API key"}), 401 + +@app.route('/v1/chat/completions', methods=['POST']) +def proxy(): + logger.info('Received request to /v1/chat/completions') + data = request.get_json() + auth_header = request.headers.get("Authorization") + bearer_token = "" + + if auth_header and auth_header.startswith("Bearer "): + bearer_token = auth_header.split("Bearer ")[1].strip() + logger.debug(f"Intercepted Bearer Token: {bearer_token}") + + logger.debug(f'Request data: {data}') + + stream = data.get('stream', False) + messages = data.get('messages', []) + model = data.get('model', server_config['model']) + n = data.get('n', server_config['n']) # Get n value from request or config + # Extract response_format if present + response_format = data.get("response_format", None) + + # Explicit keys that we are already handling + explicit_keys = {'stream', 'messages', 'model', 'n', 'response_format'} + + # Copy the rest into request_config + request_config = {k: v for k, v in data.items() if k not in explicit_keys} + + # Add the explicitly handled ones + request_config.update({ + "stream": stream, + "n": n, + "response_format": response_format, # Add response_format to config + }) + + optillm_approach = data.get('optillm_approach', server_config['approach']) + logger.debug(data) + server_config['mcts_depth'] = data.get('mcts_depth', server_config['mcts_depth']) + server_config['mcts_exploration'] = data.get('mcts_exploration', server_config['mcts_exploration']) + server_config['mcts_simulations'] = data.get('mcts_simulations', server_config['mcts_simulations']) + + system_prompt, initial_query, message_optillm_approach = parse_conversation(messages) + + if message_optillm_approach: + optillm_approach = message_optillm_approach + + if optillm_approach != "auto": + model = f"{optillm_approach}-{model}" + + base_url = server_config['base_url'] + default_client, api_key = get_config() + + operation, approaches, model = parse_combined_approach(model, known_approaches, plugin_approaches) + + # Start conversation logging if enabled + request_id = None + if conversation_logger and conversation_logger.enabled: + request_id = conversation_logger.start_conversation( + client_request={ + 'messages': messages, + 'model': data.get('model', server_config['model']), + 'stream': stream, + 'n': n, + **{k: v for k, v in data.items() if k not in {'messages', 'model', 'stream', 'n'}} + }, + approach=approaches[0] if len(approaches) == 1 else f"{operation}({','.join(approaches)})", + model=model + ) + + # Log approach and request start with ID for terminal monitoring + request_id_str = f' [Request: {request_id}]' if request_id else '' + logger.info(f'Using approach(es) {approaches}, operation {operation}, with model {model}{request_id_str}') + if request_id: + logger.info(f'Request {request_id}: Starting processing') + + if bearer_token != "" and bearer_token.startswith("sk-"): + api_key = bearer_token + if base_url != "": + client = OpenAI(api_key=api_key, base_url=base_url) + else: + client = OpenAI(api_key=api_key) + else: + client = default_client + + try: + # Route to batch processing if batch mode is enabled + if request_batcher is not None: + try: + # Create request data for batching + batch_request_data = { + 'system_prompt': system_prompt, + 'initial_query': initial_query, + 'client': client, + 
'model': model, + 'request_config': request_config, + 'approaches': approaches, + 'operation': operation, + 'n': n, + 'stream': stream, + 'optillm_approach': optillm_approach + } + + logger.debug("Routing request to batch processor") + result = request_batcher.add_request(batch_request_data) + return jsonify(result), 200 + + except BatchingError as e: + logger.error(f"Batch processing failed: {e}") + return jsonify({"error": str(e)}), 500 + + # Check if any of the approaches is 'none' + contains_none = any(approach == 'none' for approach in approaches) + + if operation == 'SINGLE' and approaches[0] == 'none': + # Pass through the request including the n parameter + result, completion_tokens = execute_single_approach(approaches[0], system_prompt, initial_query, client, model, request_config, request_id) + + logger.debug(f'Direct proxy response: {result}') + + # Log the final response and finalize conversation logging + if conversation_logger and request_id: + conversation_logger.log_final_response(request_id, result) + conversation_logger.finalize_conversation(request_id) + + if stream: + if request_id: + logger.info(f'Request {request_id}: Completed (streaming response)') + return Response(generate_streaming_response(extract_contents(result), model), content_type='text/event-stream') + else : + if request_id: + logger.info(f'Request {request_id}: Completed') + return jsonify(result), 200 + + elif operation == 'AND' or operation == 'OR': + if contains_none: + raise ValueError("'none' approach cannot be combined with other approaches") + + # Handle non-none approaches with n attempts + response, completion_tokens = execute_n_times(n, approaches, operation, system_prompt, initial_query, client, model, request_config, request_id) + + except Exception as e: + # Log error to conversation logger if enabled + if conversation_logger and request_id: + conversation_logger.log_error(request_id, str(e)) + conversation_logger.finalize_conversation(request_id) + + request_id_str = f' {request_id}' if request_id else '' + logger.error(f"Error processing request{request_id_str}: {str(e)}") + return jsonify({"error": str(e)}), 500 + + # Convert tagged conversation to messages format if needed + if isinstance(response, list): + processed_response = tagged_conversation_to_messages(response) + # If processed_response is a list of message lists, extract last message content + if processed_response != response: # Only process if format changed + response = [msg[-1]['content'] if isinstance(msg, list) and msg else msg + for msg in processed_response] + # Otherwise keep original response + else: + messages = tagged_conversation_to_messages(response) + if isinstance(messages, list) and messages: # Only process if format changed + response = messages[-1]['content'] + + if stream: + return Response(generate_streaming_response(response, model), content_type='text/event-stream') + else: + # Calculate reasoning tokens from the response + reasoning_tokens = 0 + if isinstance(response, str): + reasoning_tokens = count_reasoning_tokens(response) + elif isinstance(response, list) and response: + # For multiple responses, sum up reasoning tokens from all + reasoning_tokens = sum(count_reasoning_tokens(resp) for resp in response if isinstance(resp, str)) + + response_data = { + 'model': model, + 'choices': [], + 'usage': { + 'completion_tokens': completion_tokens, + 'completion_tokens_details': { + 'reasoning_tokens': reasoning_tokens + } + } + } + + if isinstance(response, list): + for index, resp in enumerate(response): + 
response_data['choices'].append({ + 'index': index, + 'message': { + 'role': 'assistant', + 'content': resp, + }, + 'finish_reason': 'stop' + }) + else: + response_data['choices'].append({ + 'index': 0, + 'message': { + 'role': 'assistant', + 'content': response, + }, + 'finish_reason': 'stop' + }) + + # Log the final response and finalize conversation logging + if conversation_logger and request_id: + conversation_logger.log_final_response(request_id, response_data) + conversation_logger.finalize_conversation(request_id) + + logger.debug(f'API response: {response_data}') + if request_id: + logger.info(f'Request {request_id}: Completed') + return jsonify(response_data), 200 + +@app.route('/v1/models', methods=['GET']) +def proxy_models(): + logger.info('Received request to /v1/models') + default_client, API_KEY = get_config() + try: + if server_config['base_url']: + client = OpenAI(api_key=API_KEY, base_url=server_config['base_url']) + else: + client = default_client + + # Fetch models using the OpenAI client and return the raw response + models_response = client.models.list().json() + + logger.debug('Models retrieved successfully') + return models_response, 200 + except Exception as e: + logger.error(f"Error fetching models: {str(e)}") + return jsonify({"error": f"Error fetching models: {str(e)}"}), 500 + +@app.route('/health', methods=['GET']) +def health(): + return jsonify({"status": "ok"}), 200 + +def parse_args(): + parser = argparse.ArgumentParser(description="Run LLM inference with various approaches.") + + try: + from optillm import __version__ as package_version + except ImportError: + package_version = "unknown" + + parser.add_argument('--version', action='version', + version=f'%(prog)s {package_version}', + help="Show program's version number and exit") + + # Define arguments and their corresponding environment variables + args_env = [ + ("--optillm-api-key", "OPTILLM_API_KEY", str, "", "Optional API key for client authentication to optillm"), + ("--approach", "OPTILLM_APPROACH", str, "auto", "Inference approach to use", known_approaches + list(plugin_approaches.keys())), + ("--mcts-simulations", "OPTILLM_SIMULATIONS", int, 2, "Number of MCTS simulations"), + ("--mcts-exploration", "OPTILLM_EXPLORATION", float, 0.2, "Exploration weight for MCTS"), + ("--mcts-depth", "OPTILLM_DEPTH", int, 1, "Simulation depth for MCTS"), + ("--model", "OPTILLM_MODEL", str, "gpt-4o-mini", "OpenAI model to use"), + ("--rstar-max-depth", "OPTILLM_RSTAR_MAX_DEPTH", int, 3, "Maximum depth for rStar algorithm"), + ("--rstar-num-rollouts", "OPTILLM_RSTAR_NUM_ROLLOUTS", int, 5, "Number of rollouts for rStar algorithm"), + ("--rstar-c", "OPTILLM_RSTAR_C", float, 1.4, "Exploration constant for rStar algorithm"), + ("--n", "OPTILLM_N", int, 1, "Number of final responses to be returned"), + ("--return-full-response", "OPTILLM_RETURN_FULL_RESPONSE", bool, False, "Return the full response including the CoT with tags"), + ("--port", "OPTILLM_PORT", int, 8000, "Specify the port to run the proxy"), + ("--log", "OPTILLM_LOG", str, "info", "Specify the logging level", list(logging_levels.keys())), + ("--launch-gui", "OPTILLM_LAUNCH_GUI", bool, False, "Launch a Gradio chat interface"), + ("--plugins-dir", "OPTILLM_PLUGINS_DIR", str, "", "Path to the plugins directory"), + ("--log-conversations", "OPTILLM_LOG_CONVERSATIONS", bool, False, "Enable conversation logging with full metadata"), + ("--conversation-log-dir", "OPTILLM_CONVERSATION_LOG_DIR", str, str(Path.home() / ".optillm" / "conversations"), "Directory to 
save conversation logs"), + ] + + for arg, env, type_, default, help_text, *extra in args_env: + env_value = os.environ.get(env) + if env_value is not None: + if type_ == bool: + default = env_value.lower() in ('true', '1', 'yes') + else: + default = type_(env_value) + if extra and extra[0]: # Check if there are choices for this argument + parser.add_argument(arg, type=type_, default=default, help=help_text, choices=extra[0]) + else: + if type_ == bool: + # For boolean flags, use store_true action + parser.add_argument(arg, action='store_true', default=default, help=help_text) + else: + parser.add_argument(arg, type=type_, default=default, help=help_text) + + # Special handling for best_of_n to support both formats + best_of_n_default = int(os.environ.get("OPTILLM_BEST_OF_N", 3)) + parser.add_argument("--best-of-n", "--best_of_n", dest="best_of_n", type=int, default=best_of_n_default, + help="Number of samples for best_of_n approach") + + # Special handling for base_url to support both formats + base_url_default = os.environ.get("OPTILLM_BASE_URL", "") + parser.add_argument("--base-url", "--base_url", dest="base_url", type=str, default=base_url_default, + help="Base url for OpenAI compatible endpoint") + + # Use the function to get the default path + default_config_path = get_config_path() + + # Batch mode arguments + batch_mode_default = os.environ.get("OPTILLM_BATCH_MODE", "false").lower() == "true" + batch_size_default = int(os.environ.get("OPTILLM_BATCH_SIZE", 4)) + batch_wait_ms_default = int(os.environ.get("OPTILLM_BATCH_WAIT_MS", 50)) + + parser.add_argument("--batch-mode", action="store_true", default=batch_mode_default, + help="Enable automatic request batching (fail-fast, no fallback)") + parser.add_argument("--batch-size", type=int, default=batch_size_default, + help="Maximum batch size for request batching") + parser.add_argument("--batch-wait-ms", dest="batch_wait_ms", type=int, default=batch_wait_ms_default, + help="Maximum wait time in milliseconds for batch formation") + + # Special handling of all the CePO Configurations + for field in fields(CepoConfig): + parser.add_argument(f"--cepo_{field.name}", + dest=f"cepo_{field.name}", + type=field.type, + default=None, + help=f"CePO configuration for {field.name}") + + parser.add_argument("--cepo_config_file", + dest="cepo_config_file", + type=str, + default=default_config_path, + help="Path to CePO configuration file") + + args = parser.parse_args() + + # Convert argument names to match server_config keys + args_dict = vars(args) + for key in list(args_dict.keys()): + new_key = key.replace("-", "_") + if new_key != key: + args_dict[new_key] = args_dict.pop(key) + + return args + +def main(): + global server_config + global cepo_config + global request_batcher + global conversation_logger + # Call this function at the start of main() + args = parse_args() + # Update server_config with all argument values + server_config.update(vars(args)) + + load_plugins() + + port = server_config['port'] + + # Initialize request batcher if batch mode is enabled + if server_config.get('batch_mode', False): + logger.info(f"Batch mode enabled: size={server_config['batch_size']}, " + f"wait={server_config['batch_wait_ms']}ms") + request_batcher = RequestBatcher( + max_batch_size=server_config['batch_size'], + max_wait_ms=server_config['batch_wait_ms'], + enable_logging=True + ) + + # Set up the batch processor function + def process_batch_requests(batch_requests): + """ + Process a batch of requests using true batching when possible + + Args: + 
batch_requests: List of request data dictionaries + + Returns: + List of response dictionaries + """ + import time + from optillm.batching import BatchingError + + if not batch_requests: + return [] + + logger.info(f"Processing batch of {len(batch_requests)} requests") + + # Check if we can use true batching (all requests compatible and using 'none' approach) + can_use_true_batching = True + first_req = batch_requests[0] + + # Check compatibility across all requests + for req_data in batch_requests: + if (req_data['stream'] or + req_data['approaches'] != first_req['approaches'] or + req_data['operation'] != first_req['operation'] or + req_data['model'] != first_req['model']): + can_use_true_batching = False + break + + # For now, implement sequential processing but with proper infrastructure + # TODO: Implement true PyTorch/MLX batching in next phase + responses = [] + + for i, req_data in enumerate(batch_requests): + try: + logger.debug(f"Processing batch request {i+1}/{len(batch_requests)}") + + # Extract request parameters + system_prompt = req_data['system_prompt'] + initial_query = req_data['initial_query'] + client = req_data['client'] + model = req_data['model'] + request_config = req_data['request_config'] + approaches = req_data['approaches'] + operation = req_data['operation'] + n = req_data['n'] + stream = req_data['stream'] + + # Validate request + if stream: + raise BatchingError("Streaming requests cannot be batched") + + # Check if any of the approaches is 'none' + contains_none = any(approach == 'none' for approach in approaches) + + if operation == 'SINGLE' and approaches[0] == 'none': + # Pass through the request including the n parameter + result, completion_tokens = execute_single_approach( + approaches[0], system_prompt, initial_query, client, model, request_config) + elif operation == 'AND' or operation == 'OR': + if contains_none: + raise ValueError("'none' approach cannot be combined with other approaches") + # Handle non-none approaches with n attempts + result, completion_tokens = execute_n_times( + n, approaches, operation, system_prompt, initial_query, client, model, request_config) + else: + # Handle non-none approaches with n attempts + result, completion_tokens = execute_n_times( + n, approaches, operation, system_prompt, initial_query, client, model, request_config) + + # Convert tagged conversation to messages format if needed + if isinstance(result, list): + processed_response = tagged_conversation_to_messages(result) + if processed_response != result: # Only process if format changed + result = [msg[-1]['content'] if isinstance(msg, list) and msg else msg + for msg in processed_response] + else: + messages = tagged_conversation_to_messages(result) + if isinstance(messages, list) and messages: # Only process if format changed + result = messages[-1]['content'] + + # Generate the response in OpenAI format + if isinstance(result, list): + choices = [] + for j, res in enumerate(result): + choices.append({ + "index": j, + "message": { + "role": "assistant", + "content": res + }, + "finish_reason": "stop" + }) + else: + choices = [{ + "index": 0, + "message": { + "role": "assistant", + "content": result + }, + "finish_reason": "stop" + }] + + response_dict = { + "id": f"chatcmpl-{int(time.time()*1000)}-{i}", + "object": "chat.completion", + "created": int(time.time()), + "model": model, + "choices": choices, + "usage": { + "prompt_tokens": 0, # Will be calculated properly later + "completion_tokens": completion_tokens if isinstance(completion_tokens, int) else 
0, + "total_tokens": completion_tokens if isinstance(completion_tokens, int) else 0 + } + } + + responses.append(response_dict) + + except Exception as e: + logger.error(f"Error processing batch request {i+1}: {e}") + raise BatchingError(f"Failed to process request {i+1}: {str(e)}") + + logger.info(f"Completed batch processing of {len(responses)} requests") + return responses + + # Set the processor function on the batcher + request_batcher.set_processor(process_batch_requests) + + # Set logging level from user request + logging_level = server_config['log'] + if logging_level in logging_levels.keys(): + logger.setLevel(logging_levels[logging_level]) + + # Initialize conversation logger if enabled + global conversation_logger + conversation_logger = ConversationLogger( + log_dir=Path(server_config['conversation_log_dir']), + enabled=server_config['log_conversations'] + ) + # Set the global logger instance for access from approach modules + optillm.conversation_logger.set_global_logger(conversation_logger) + if server_config['log_conversations']: + logger.info(f"Conversation logging enabled. Logs will be saved to: {server_config['conversation_log_dir']}") + + # set and log the cepo configs + cepo_config = init_cepo_config(server_config) + if args.approach == 'cepo': + logger.info(f"CePO Config: {cepo_config}") + + logger.info(f"Starting server with approach: {server_config['approach']}") + server_config_clean = server_config.copy() + if server_config_clean['optillm_api_key']: + server_config_clean['optillm_api_key'] = '[REDACTED]' + logger.info(f"Server configuration: {server_config_clean}") + + # Launch GUI if requested + if server_config.get('launch_gui'): + try: + import gradio as gr + # Start server in a separate thread + import threading + server_thread = threading.Thread(target=app.run, kwargs={'host': '0.0.0.0', 'port': port}) + server_thread.daemon = True + server_thread.start() + + # Configure the base URL for the Gradio interface + base_url = f"http://localhost:{port}/v1" + logger.info(f"Launching Gradio interface connected to {base_url}") + + # Create custom chat function with extended timeout + def chat_with_optillm(message, history): + import httpx + from openai import OpenAI + + # Create client with extended timeout and no retries + custom_client = OpenAI( + api_key="optillm", + base_url=base_url, + timeout=httpx.Timeout(1800.0, connect=5.0), # 30 min timeout + max_retries=0 # No retries - prevents duplicate requests + ) + + # Convert history to messages format + messages = [] + for h in history: + if h[0]: # User message + messages.append({"role": "user", "content": h[0]}) + if h[1]: # Assistant message + messages.append({"role": "assistant", "content": h[1]}) + messages.append({"role": "user", "content": message}) + + # Make request + try: + response = custom_client.chat.completions.create( + model=server_config['model'], + messages=messages + ) + return response.choices[0].message.content + except Exception as e: + return f"Error: {str(e)}" + + # Create Gradio interface with queue for long operations + demo = gr.ChatInterface( + chat_with_optillm, + title="OptILLM Chat Interface", + description=f"Connected to OptILLM proxy at {base_url}" + ) + demo.queue() # Enable queue to handle long operations properly + demo.launch(server_name="0.0.0.0", share=False) + except ImportError: + logger.error("Gradio is required for GUI. 
Install it with: pip install gradio")
+            return
+
+    app.run(host='0.0.0.0', port=port)
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 0a5cffe5..a296a8f4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -81,7 +81,4 @@ optillm = [
     "plugins/*.py",
     "cepo/*.py",
     "cepo/configs/*.yaml",
-]
-
-[tool.setuptools.data-files]
-"" = ["optillm.py"]
\ No newline at end of file
+]
\ No newline at end of file

From b4813fb5a0a0aadecf02b83ee7a8c75268c90e9f Mon Sep 17 00:00:00 2001
From: Asankhaya Sharma
Date: Thu, 4 Sep 2025 17:04:36 +0800
Subject: [PATCH 2/3] f

---
 .github/workflows/publish-docker-full-amd64.yml | 7 ++++++-
 .github/workflows/publish-docker-full-arm64.yml | 7 ++++++-
 2 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/publish-docker-full-amd64.yml b/.github/workflows/publish-docker-full-amd64.yml
index bcabf18f..216fe49f 100644
--- a/.github/workflows/publish-docker-full-amd64.yml
+++ b/.github/workflows/publish-docker-full-amd64.yml
@@ -25,7 +25,12 @@ jobs:
 
       - name: Extract version from tag
         id: version
-        run: echo "VERSION=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT
+        run: |
+          VERSION=${GITHUB_REF#refs/tags/}
+          if [ -z "$VERSION" ] || [ "$VERSION" = "$GITHUB_REF" ]; then
+            VERSION="latest"
+          fi
+          echo "VERSION=$VERSION" >> $GITHUB_OUTPUT
 
       - name: Build and push full AMD64 image
         uses: docker/build-push-action@v5
diff --git a/.github/workflows/publish-docker-full-arm64.yml b/.github/workflows/publish-docker-full-arm64.yml
index db098467..f3f43e2d 100644
--- a/.github/workflows/publish-docker-full-arm64.yml
+++ b/.github/workflows/publish-docker-full-arm64.yml
@@ -28,7 +28,12 @@ jobs:
 
       - name: Extract version from tag
         id: version
-        run: echo "VERSION=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT
+        run: |
+          VERSION=${GITHUB_REF#refs/tags/}
+          if [ -z "$VERSION" ] || [ "$VERSION" = "$GITHUB_REF" ]; then
+            VERSION="latest"
+          fi
+          echo "VERSION=$VERSION" >> $GITHUB_OUTPUT
 
       - name: Build and push full ARM64 image
         uses: docker/build-push-action@v5

From c2bad7d850c693f0b9ac0ddfde37145be15adf6e Mon Sep 17 00:00:00 2001
From: Asankhaya Sharma
Date: Thu, 4 Sep 2025 17:05:12 +0800
Subject: [PATCH 3/3] bump versions

---
 optillm/__init__.py | 2 +-
 pyproject.toml      | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/optillm/__init__.py b/optillm/__init__.py
index 01f1face..129977a6 100644
--- a/optillm/__init__.py
+++ b/optillm/__init__.py
@@ -1,5 +1,5 @@
 # Version information
-__version__ = "0.2.1"
+__version__ = "0.2.2"
 
 # Import from server module
 from .server import (
diff --git a/pyproject.toml b/pyproject.toml
index a296a8f4..df090553 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "optillm"
-version = "0.2.1"
+version = "0.2.2"
 description = "An optimizing inference proxy for LLMs."
 readme = "README.md"
 license = "Apache-2.0"
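
The approach-selection path moved into optillm/server.py (the model-name prefix handling in parse_combined_approach and the optillm_approach field read in proxy()) can be exercised from any OpenAI-compatible client. The snippet below is an illustrative sketch only: it assumes the proxy is running locally on the default port 8000, that the default gpt-4o-mini upstream model is configured, and that the openai Python package is installed.

from openai import OpenAI

# Point the standard OpenAI client at the local optillm proxy.
# The api_key deliberately does not start with "sk-", so the proxy's own
# default upstream client is used instead of forwarding this token.
client = OpenAI(api_key="optillm", base_url="http://localhost:8000/v1")

# 1) Select an approach via the model prefix: "moa-" routes through mixture_of_agents.
resp = client.chat.completions.create(
    model="moa-gpt-4o-mini",
    messages=[{"role": "user", "content": "Explain the CAP theorem briefly."}],
)
print(resp.choices[0].message.content)

# 2) Combine approaches: "&" chains them as a pipeline (AND), "|" runs them in parallel (OR).
resp = client.chat.completions.create(
    model="bon|moa-gpt-4o-mini",
    messages=[{"role": "user", "content": "Name three sorting algorithms."}],
)

# 3) Select the approach in the request body instead of the model name.
resp = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "What is 17 * 23?"}],
    extra_body={"optillm_approach": "re2"},
)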
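
The new --batch-mode, --batch-size, and --batch-wait-ms options let the server coalesce concurrent, non-streaming requests. The sketch below is an assumption-laden illustration, not part of the patch: it assumes the server was started with batch mode enabled (for example by running the server module with --batch-mode --batch-size 4 --batch-wait-ms 50) and reachable on the default port 8000.

import concurrent.futures
from openai import OpenAI

client = OpenAI(api_key="optillm", base_url="http://localhost:8000/v1")

def ask(question: str) -> str:
    # Streaming must stay off: the batcher rejects streaming requests,
    # and only requests with the same model and approach can share a batch.
    resp = client.chat.completions.create(
        model="none-gpt-4o-mini",  # 'none' passes straight through to the upstream model
        messages=[{"role": "user", "content": question}],
    )
    return resp.choices[0].message.content

questions = ["What is a mutex?", "What is a semaphore?", "What is a condition variable?"]

# Requests arriving within the batch wait window can be grouped and processed together.
with concurrent.futures.ThreadPoolExecutor(max_workers=len(questions)) as pool:
    for answer in pool.map(ask, questions):
        print(answer[:80])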