diff --git a/optillm/__init__.py b/optillm/__init__.py
index cbb0f196..ca7b60ec 100644
--- a/optillm/__init__.py
+++ b/optillm/__init__.py
@@ -1,5 +1,5 @@
 # Version information
-__version__ = "0.3.3"
+__version__ = "0.3.4"
 
 # Import from server module
 from .server import (
diff --git a/optillm/bon.py b/optillm/bon.py
index 3e5885df..c23b6432 100644
--- a/optillm/bon.py
+++ b/optillm/bon.py
@@ -22,16 +22,24 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
             "temperature": 1
         }
         response = client.chat.completions.create(**provider_request)
-        
+
         # Log provider call
         if request_id:
             response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
             conversation_logger.log_provider_call(request_id, provider_request, response_dict)
-        
-        completions = [choice.message.content for choice in response.choices]
+
+        # Check for valid response with None-checking
+        if response is None or not response.choices:
+            raise Exception("Response is None or has no choices")
+
+        completions = [choice.message.content for choice in response.choices if choice.message.content is not None]
         logger.info(f"Generated {len(completions)} initial completions using n parameter. Tokens used: {response.usage.completion_tokens}")
         bon_completion_tokens += response.usage.completion_tokens
-        
+
+        # Check if any valid completions were generated
+        if not completions:
+            raise Exception("No valid completions generated (all were None)")
+
     except Exception as e:
         logger.warning(f"n parameter not supported by provider: {str(e)}")
         logger.info(f"Falling back to generating {n} completions one by one")
@@ -46,12 +54,20 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
                 "temperature": 1
             }
             response = client.chat.completions.create(**provider_request)
-            
+
             # Log provider call
             if request_id:
                 response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
                 conversation_logger.log_provider_call(request_id, provider_request, response_dict)
-            
+
+            # Check for valid response with None-checking
+            if (response is None or
+                not response.choices or
+                response.choices[0].message.content is None or
+                response.choices[0].finish_reason == "length"):
+                logger.warning(f"Completion {i+1}/{n} truncated or empty, skipping")
+                continue
+
             completions.append(response.choices[0].message.content)
             bon_completion_tokens += response.usage.completion_tokens
             logger.debug(f"Generated completion {i+1}/{n}")
@@ -65,11 +81,16 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
         return "Error: Could not generate any completions", 0
 
     logger.info(f"Generated {len(completions)} completions using fallback method. Total tokens used: {bon_completion_tokens}")
-    
+
+    # Double-check we have completions before rating
+    if not completions:
+        logger.error("No completions available for rating")
+        return "Error: Could not generate any completions", bon_completion_tokens
+
     # Rate the completions
     rating_messages = messages.copy()
     rating_messages.append({"role": "system", "content": "Rate the following responses on a scale from 0 to 10, where 0 is poor and 10 is excellent. Consider factors such as relevance, coherence, and helpfulness. Respond with only a number."})
-    
+
     ratings = []
     for completion in completions:
         rating_messages.append({"role": "assistant", "content": completion})
@@ -83,18 +104,27 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
             "temperature": 0.1
         }
         rating_response = client.chat.completions.create(**provider_request)
-        
+
         # Log provider call
         if request_id:
             response_dict = rating_response.model_dump() if hasattr(rating_response, 'model_dump') else rating_response
             conversation_logger.log_provider_call(request_id, provider_request, response_dict)
-        
+
         bon_completion_tokens += rating_response.usage.completion_tokens
-        try:
-            rating = float(rating_response.choices[0].message.content.strip())
-            ratings.append(rating)
-        except ValueError:
+
+        # Check for valid response with None-checking
+        if (rating_response is None or
+            not rating_response.choices or
+            rating_response.choices[0].message.content is None or
+            rating_response.choices[0].finish_reason == "length"):
+            logger.warning("Rating response truncated or empty, using default rating of 0")
             ratings.append(0)
+        else:
+            try:
+                rating = float(rating_response.choices[0].message.content.strip())
+                ratings.append(rating)
+            except ValueError:
+                ratings.append(0)
 
         rating_messages = rating_messages[:-2]
 
diff --git a/optillm/mcts.py b/optillm/mcts.py
index d1727ea2..85177497 100644
--- a/optillm/mcts.py
+++ b/optillm/mcts.py
@@ -122,13 +122,18 @@ def generate_actions(self, state: DialogueState) -> List[str]:
             "temperature": 1
         }
         response = self.client.chat.completions.create(**provider_request)
-        
+
         # Log provider call
         if self.request_id:
             response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
             conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
-        
-        completions = [choice.message.content.strip() for choice in response.choices]
+
+        # Check for valid response with None-checking
+        if response is None or not response.choices:
+            logger.error("Failed to get valid completions from the model")
+            return []
+
+        completions = [choice.message.content.strip() for choice in response.choices if choice.message.content is not None]
         self.completion_tokens += response.usage.completion_tokens
         logger.info(f"Received {len(completions)} completions from the model")
         return completions
@@ -151,13 +156,22 @@ def apply_action(self, state: DialogueState, action: str) -> DialogueState:
             "temperature": 1
         }
         response = self.client.chat.completions.create(**provider_request)
-        
+
         # Log provider call
         if self.request_id:
             response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
             conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
-        
-        next_query = response.choices[0].message.content
+
+        # Check for valid response with None-checking
+        if (response is None or
+            not response.choices or
+            response.choices[0].message.content is None or
+            response.choices[0].finish_reason == "length"):
+            logger.warning("Next query response truncated or empty, using default")
+            next_query = "Please continue."
+        else:
+            next_query = response.choices[0].message.content
+
         self.completion_tokens += response.usage.completion_tokens
         logger.info(f"Generated next user query: {next_query}")
         return DialogueState(state.system_prompt, new_history, next_query)
@@ -181,13 +195,22 @@ def evaluate_state(self, state: DialogueState) -> float:
             "temperature": 0.1
         }
         response = self.client.chat.completions.create(**provider_request)
-        
+
         # Log provider call
         if self.request_id:
             response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
             conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
-        
+
         self.completion_tokens += response.usage.completion_tokens
+
+        # Check for valid response with None-checking
+        if (response is None or
+            not response.choices or
+            response.choices[0].message.content is None or
+            response.choices[0].finish_reason == "length"):
+            logger.warning("Evaluation response truncated or empty. Using default value 0.5")
+            return 0.5
+
         try:
             score = float(response.choices[0].message.content.strip())
             score = max(0, min(score, 1))  # Ensure the score is between 0 and 1
diff --git a/optillm/moa.py b/optillm/moa.py
index 9f6fd034..86371c1f 100644
--- a/optillm/moa.py
+++ b/optillm/moa.py
@@ -25,18 +25,26 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
         }
 
         response = client.chat.completions.create(**provider_request)
-        
+
         # Convert response to dict for logging
         response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
-        
+
         # Log provider call if conversation logging is enabled
         if request_id:
             conversation_logger.log_provider_call(request_id, provider_request, response_dict)
-        
-        completions = [choice.message.content for choice in response.choices]
+
+        # Check for valid response with None-checking
+        if response is None or not response.choices:
+            raise Exception("Response is None or has no choices")
+
+        completions = [choice.message.content for choice in response.choices if choice.message.content is not None]
         moa_completion_tokens += response.usage.completion_tokens
         logger.info(f"Generated {len(completions)} initial completions using n parameter. Tokens used: {response.usage.completion_tokens}")
-        
+
+        # Check if any valid completions were generated
+        if not completions:
+            raise Exception("No valid completions generated (all were None)")
+
     except Exception as e:
         logger.warning(f"n parameter not supported by provider: {str(e)}")
         logger.info("Falling back to generating 3 completions one by one")
@@ -56,14 +64,22 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
             }
 
             response = client.chat.completions.create(**provider_request)
-            
+
             # Convert response to dict for logging
             response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
-            
+
             # Log provider call if conversation logging is enabled
             if request_id:
                 conversation_logger.log_provider_call(request_id, provider_request, response_dict)
-            
+
+            # Check for valid response with None-checking
+            if (response is None or
+                not response.choices or
+                response.choices[0].message.content is None or
+                response.choices[0].finish_reason == "length"):
+                logger.warning(f"Completion {i+1}/3 truncated or empty, skipping")
+                continue
+
             completions.append(response.choices[0].message.content)
             moa_completion_tokens += response.usage.completion_tokens
             logger.debug(f"Generated completion {i+1}/3")
@@ -77,7 +93,12 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
         return "Error: Could not generate any completions", 0
 
     logger.info(f"Generated {len(completions)} completions using fallback method. Total tokens used: {moa_completion_tokens}")
-    
+
+    # Double-check we have at least one completion
+    if not completions:
+        logger.error("No completions available for processing")
+        return "Error: Could not generate any completions", moa_completion_tokens
+
     # Handle case where fewer than 3 completions were generated
     if len(completions) < 3:
         original_count = len(completions)
@@ -118,15 +139,24 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
     }
 
     critique_response = client.chat.completions.create(**provider_request)
-    
+
     # Convert response to dict for logging
     response_dict = critique_response.model_dump() if hasattr(critique_response, 'model_dump') else critique_response
-    
+
     # Log provider call if conversation logging is enabled
     if request_id:
         conversation_logger.log_provider_call(request_id, provider_request, response_dict)
-    
-    critiques = critique_response.choices[0].message.content
+
+    # Check for valid response with None-checking
+    if (critique_response is None or
+        not critique_response.choices or
+        critique_response.choices[0].message.content is None or
+        critique_response.choices[0].finish_reason == "length"):
+        logger.warning("Critique response truncated or empty, using generic critique")
+        critiques = "All candidates show reasonable approaches to the problem."
+    else:
+        critiques = critique_response.choices[0].message.content
+
     moa_completion_tokens += critique_response.usage.completion_tokens
     logger.info(f"Generated critiques. Tokens used: {critique_response.usage.completion_tokens}")
 
@@ -165,16 +195,27 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
     }
 
     final_response = client.chat.completions.create(**provider_request)
-    
+
     # Convert response to dict for logging
     response_dict = final_response.model_dump() if hasattr(final_response, 'model_dump') else final_response
-    
+
     # Log provider call if conversation logging is enabled
     if request_id:
         conversation_logger.log_provider_call(request_id, provider_request, response_dict)
-    
+
     moa_completion_tokens += final_response.usage.completion_tokens
     logger.info(f"Generated final response. Tokens used: {final_response.usage.completion_tokens}")
-    
+
+    # Check for valid response with None-checking
+    if (final_response is None or
+        not final_response.choices or
+        final_response.choices[0].message.content is None or
+        final_response.choices[0].finish_reason == "length"):
+        logger.error("Final response truncated or empty. Consider increasing max_tokens.")
+        # Return best completion if final response failed
+        result = completions[0] if completions else "Error: Response was truncated due to token limit. Please increase max_tokens or max_completion_tokens."
+    else:
+        result = final_response.choices[0].message.content
+
     logger.info(f"Total completion tokens used: {moa_completion_tokens}")
-    return final_response.choices[0].message.content, moa_completion_tokens
\ No newline at end of file
+    return result, moa_completion_tokens
\ No newline at end of file
diff --git a/optillm/plansearch.py b/optillm/plansearch.py
index 517bccc9..f91c9a8e 100644
--- a/optillm/plansearch.py
+++ b/optillm/plansearch.py
@@ -35,12 +35,21 @@ def generate_observations(self, problem: str, num_observations: int = 3) -> List
         }
 
         response = self.client.chat.completions.create(**provider_request)
-        
+
         # Log provider call if conversation logging is enabled
         if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and self.request_id:
             response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
             optillm.conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
 
         self.plansearch_completion_tokens += response.usage.completion_tokens
+
+        # Check for valid response with None-checking
+        if (response is None or
+            not response.choices or
+            response.choices[0].message.content is None or
+            response.choices[0].finish_reason == "length"):
+            logger.warning("Observations response truncated or empty, returning empty list")
+            return []
+
         observations = response.choices[0].message.content.strip().split('\n')
         return [obs.strip() for obs in observations if obs.strip()]
@@ -70,12 +79,21 @@ def generate_derived_observations(self, problem: str, observations: List[str], n
         }
 
         response = self.client.chat.completions.create(**provider_request)
-        
+
         # Log provider call if conversation logging is enabled
         if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and self.request_id:
             response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
             optillm.conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
 
         self.plansearch_completion_tokens += response.usage.completion_tokens
+
+        # Check for valid response with None-checking
+        if (response is None or
+            not response.choices or
+            response.choices[0].message.content is None or
+            response.choices[0].finish_reason == "length"):
+            logger.warning("Derived observations response truncated or empty, returning empty list")
+            return []
+
         new_observations = response.choices[0].message.content.strip().split('\n')
         return [obs.strip() for obs in new_observations if obs.strip()]
@@ -101,14 +119,23 @@ def generate_solution(self, problem: str, observations: List[str]) -> str:
                 {"role": "user", "content": prompt}
             ]
         }
-        
+
         response = self.client.chat.completions.create(**provider_request)
-        
+
         # Log provider call if conversation logging is enabled
         if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and self.request_id:
             response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
             optillm.conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
         self.plansearch_completion_tokens += response.usage.completion_tokens
+
+        # Check for valid response with None-checking
+        if (response is None or
+            not response.choices or
+            response.choices[0].message.content is None or
+            response.choices[0].finish_reason == "length"):
+            logger.error("Solution generation response truncated or empty. Consider increasing max_tokens.")
+            return "Error: Response was truncated due to token limit. Please increase max_tokens or max_completion_tokens."
+
         return response.choices[0].message.content.strip()
 
     def implement_solution(self, problem: str, solution: str) -> str:
@@ -134,14 +161,23 @@
                 {"role": "user", "content": prompt}
             ]
         }
-        
+
         response = self.client.chat.completions.create(**provider_request)
-        
+
         # Log provider call if conversation logging is enabled
         if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and self.request_id:
             response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
             optillm.conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
         self.plansearch_completion_tokens += response.usage.completion_tokens
+
+        # Check for valid response with None-checking
+        if (response is None or
+            not response.choices or
+            response.choices[0].message.content is None or
+            response.choices[0].finish_reason == "length"):
+            logger.error("Implementation response truncated or empty. Consider increasing max_tokens.")
+            return "Error: Response was truncated due to token limit. Please increase max_tokens or max_completion_tokens."
+
         return response.choices[0].message.content.strip()
 
     def solve(self, problem: str, num_initial_observations: int = 3, num_derived_observations: int = 2) -> Tuple[str, str]:
diff --git a/optillm/plugins/deepthink/self_discover.py b/optillm/plugins/deepthink/self_discover.py
index 71426241..5564cf99 100644
--- a/optillm/plugins/deepthink/self_discover.py
+++ b/optillm/plugins/deepthink/self_discover.py
@@ -16,16 +16,20 @@ class SelfDiscover:
     """
     Implementation of the SELF-DISCOVER framework.
-    
+
     The framework operates in two stages:
     1. Stage 1: Discover task-specific reasoning structure (SELECT, ADAPT, IMPLEMENT)
     2. Stage 2: Use discovered structure to solve problem instances
     """
-    
-    def __init__(self, client, model: str, max_tokens: int = 16382):
+
+    def __init__(self, client, model: str, max_tokens: int = 16382, request_config: Dict[str, Any] = None):
         self.client = client
         self.model = model
-        self.max_tokens = max_tokens
+        # Read max_completion_tokens (preferred) or max_tokens (deprecated) from request_config
+        if request_config:
+            self.max_tokens = request_config.get('max_completion_tokens') or request_config.get('max_tokens', max_tokens)
+        else:
+            self.max_tokens = max_tokens
         self.reasoning_modules = get_all_modules()
         self.completion_tokens = 0
 
@@ -95,10 +99,18 @@ def _select_modules(self, task_description: str, task_examples: List[str] = None
             max_tokens=1024,
             temperature=0.3
         )
-        
+
         self.completion_tokens += response.usage.completion_tokens
-        
+
         try:
+            # Check for truncated or empty response
+            if (response is None or
+                not response.choices or
+                response.choices[0].message.content is None or
+                response.choices[0].finish_reason == "length"):
+                logger.warning("Response truncated or empty in module selection, using fallback modules")
+                return self.reasoning_modules[:5]
+
             # Extract JSON from response
             response_text = response.choices[0].message.content.strip()
             # Look for JSON array in the response
@@ -153,9 +165,18 @@ def _adapt_modules(self, selected_modules: List[Dict[str, Any]], task_descriptio
             max_tokens=2048,
             temperature=0.3
         )
-        
+
         self.completion_tokens += response.usage.completion_tokens
-        
+
+        # Check for truncated or empty response
+        if (response is None or
+            not response.choices or
+            response.choices[0].message.content is None or
+            response.choices[0].finish_reason == "length"):
+            logger.warning("Response truncated or empty in module adaptation, using generic descriptions")
+            # Return generic adapted versions of the selected modules
+            return [module.get('description', 'Apply reasoning to solve the problem') for module in selected_modules]
+
         response_text = response.choices[0].message.content.strip()
 
         # Extract adapted modules from numbered list
@@ -221,9 +242,24 @@ def _implement_structure(self, adapted_modules: List[str], task_description: str
             max_tokens=2048,
             temperature=0.3
         )
-        
+
         self.completion_tokens += response.usage.completion_tokens
-        
+
+        # Check for truncated or empty response
+        if (response is None or
+            not response.choices or
+            response.choices[0].message.content is None or
+            response.choices[0].finish_reason == "length"):
+            logger.warning("Response truncated or empty in structure implementation, using fallback structure")
+            # Return the fallback structure directly
+            return {
+                "problem_understanding": "Analyze and understand the problem requirements",
+                "solution_approach": "Determine the best approach based on problem characteristics",
+                "step_by_step_reasoning": "Work through the problem systematically",
+                "verification": "Verify the solution is correct and complete",
+                "final_answer": "State the final answer clearly"
+            }
+
         response_text = response.choices[0].message.content.strip()
 
         # Extract and parse JSON from response with improved error handling
@@ -386,7 +422,15 @@ def solve_with_structure(self, problem: str, reasoning_structure: Dict[str, Any]
             max_tokens=self.max_tokens,
             temperature=0.7
         )
-        
+
         self.completion_tokens += response.usage.completion_tokens
-        
+
+        # Check for truncated or empty response
+        if (response is None or
+            not response.choices or
+            response.choices[0].message.content is None or
+            response.choices[0].finish_reason == "length"):
+            logger.error("Response truncated or empty when solving with structure. Consider increasing max_tokens.")
+            return "Error: Response was truncated due to token limit. Please increase max_tokens or max_completion_tokens."
+
         return response.choices[0].message.content.strip()
diff --git a/optillm/plugins/deepthink/uncertainty_cot.py b/optillm/plugins/deepthink/uncertainty_cot.py
index 65bd7a25..d9e056c7 100644
--- a/optillm/plugins/deepthink/uncertainty_cot.py
+++ b/optillm/plugins/deepthink/uncertainty_cot.py
@@ -17,17 +17,21 @@ class UncertaintyRoutedCoT:
     """
     Implements uncertainty-routed chain-of-thought reasoning.
-    
+
     The approach:
    1. Generate k chain-of-thought samples
    2. Evaluate confidence through consistency analysis
    3. Route to majority vote (high confidence) or greedy sample (low confidence)
     """
-    
-    def __init__(self, client, model: str, max_tokens: int = 16382):
+
+    def __init__(self, client, model: str, max_tokens: int = 16382, request_config: Dict[str, Any] = None):
         self.client = client
         self.model = model
-        self.max_tokens = max_tokens
+        # Read max_completion_tokens (preferred) or max_tokens (deprecated) from request_config
+        if request_config:
+            self.max_tokens = request_config.get('max_completion_tokens') or request_config.get('max_tokens', max_tokens)
+        else:
+            self.max_tokens = max_tokens
         self.completion_tokens = 0
 
     def generate_with_uncertainty_routing(
@@ -127,9 +131,18 @@ def _generate_multiple_samples(
                 temperature=temperature,
                 top_p=top_p
             )
-            
+
             self.completion_tokens += response.usage.completion_tokens
-            samples.append(response.choices[0].message.content.strip())
+
+            # Check for truncated or empty response
+            if (response is None or
+                not response.choices or
+                response.choices[0].message.content is None or
+                response.choices[0].finish_reason == "length"):
+                logger.warning(f"Sample {i+1}/{num_samples} truncated or empty, using empty string")
+                samples.append("")
+            else:
+                samples.append(response.choices[0].message.content.strip())
 
         return samples
 
@@ -143,9 +156,17 @@ def _generate_greedy_sample(self, prompt: str) -> str:
             max_tokens=self.max_tokens,
             temperature=0.0  # Greedy decoding
         )
-        
+
         self.completion_tokens += response.usage.completion_tokens
-        
+
+        # Check for truncated or empty response
+        if (response is None or
+            not response.choices or
+            response.choices[0].message.content is None or
+            response.choices[0].finish_reason == "length"):
+            logger.error("Greedy sample truncated or empty. Consider increasing max_tokens.")
+            return "Error: Response was truncated due to token limit. Please increase max_tokens or max_completion_tokens."
+
         return response.choices[0].message.content.strip()
 
     def _extract_thinking(self, response: str) -> str:
diff --git a/optillm/plugins/deepthink_plugin.py b/optillm/plugins/deepthink_plugin.py
index bdf76021..7bcddf9d 100644
--- a/optillm/plugins/deepthink_plugin.py
+++ b/optillm/plugins/deepthink_plugin.py
@@ -46,13 +46,15 @@ def run(
     self_discover = SelfDiscover(
         client=client,
         model=model,
-        max_tokens=config["max_tokens"]
+        max_tokens=config["max_tokens"],
+        request_config=request_config
     )
-    
+
     uncertainty_cot = UncertaintyRoutedCoT(
         client=client,
         model=model,
-        max_tokens=config["max_tokens"]
+        max_tokens=config["max_tokens"],
+        request_config=request_config
     )
 
     total_tokens = 0
@@ -82,7 +84,7 @@ def run(
 
     # Stage 2: Uncertainty-routed generation
     logger.info("Generating response with uncertainty routing")
-    
+
     generation_result = uncertainty_cot.generate_with_uncertainty_routing(
         prompt=enhanced_prompt,
         num_samples=config["deepthink_samples"],
@@ -90,14 +92,20 @@ def run(
         temperature=config["temperature"],
         top_p=config["top_p"]
     )
-    
+
     total_tokens += generation_result["completion_tokens"]
-    
+
     # Log routing decision
     logger.info(f"Routing decision: {generation_result['routing_decision']} "
                f"(confidence: {generation_result['confidence_score']:.3f})")
-    
+
     final_response = generation_result["final_response"]
+
+    # Check if response is an error message or empty
+    if not final_response or final_response.startswith("Error:"):
+        logger.error("Deep Think generation failed or was truncated")
+        if not final_response:
+            final_response = "Error: Failed to generate a response. The model may have exceeded token limits."
 
     # Clean up the response if needed
     final_response = _clean_response(final_response)
@@ -108,7 +116,7 @@ def run(
 
 def _parse_config(request_config: Dict[str, Any]) -> Dict[str, Any]:
     """Parse and validate configuration parameters."""
-    
+
     default_config = {
         "deepthink_samples": 3,
         "confidence_threshold": 0.7,
@@ -118,11 +126,17 @@ def _parse_config(request_config: Dict[str, Any]) -> Dict[str, Any]:
         "enable_self_discover": True,
         "reasoning_modules_limit": 7
     }
-    
+
     # Override with request config values
     for key, value in request_config.items():
         if key in default_config:
             default_config[key] = value
+
+    # Handle max_completion_tokens (preferred) or max_tokens (deprecated)
+    if 'max_completion_tokens' in request_config:
+        default_config['max_tokens'] = request_config['max_completion_tokens']
+    elif 'max_tokens' in request_config:
+        default_config['max_tokens'] = request_config['max_tokens']
 
     # Validate ranges
     default_config["deepthink_samples"] = max(1, min(10, default_config["deepthink_samples"]))
diff --git a/optillm/plugins/spl_plugin.py b/optillm/plugins/spl_plugin.py
index 506071d1..ace64674 100644
--- a/optillm/plugins/spl_plugin.py
+++ b/optillm/plugins/spl_plugin.py
@@ -13,10 +13,8 @@
 LLM incrementally better at solving problems by learning from its experiences.
 """
 
-import os
-import sys
-import importlib.util
 from typing import Tuple
+from optillm.plugins.spl import run_spl
 
 # Plugin identifier
 SLUG = "spl"
@@ -24,7 +22,7 @@
 def run(system_prompt: str, initial_query: str, client, model: str, request_config: dict = None) -> Tuple[str, int]:
     """
     Plugin entry point for System Prompt Learning.
-    
+
     Args:
         system_prompt: The system prompt
         initial_query: The user's query
@@ -32,27 +30,8 @@ def run(system_prompt: str, initial_query: str, client, model: str, request_conf
         model: The model identifier
         request_config: Optional request configuration
                        Can include {'spl_learning': True} to enable learning mode
-    
+
     Returns:
         Tuple[str, int]: The LLM response and token count
     """
-    # Get the directory where this plugin is located
-    plugin_dir = os.path.dirname(os.path.abspath(__file__))
-    spl_dir = os.path.join(plugin_dir, 'spl')
-    main_file = os.path.join(spl_dir, 'main.py')
-    
-    # Load the main module dynamically
-    spec = importlib.util.spec_from_file_location("spl_main", main_file)
-    spl_main = importlib.util.module_from_spec(spec)
-    
-    # Add the spl directory to the Python path temporarily
-    if spl_dir not in sys.path:
-        sys.path.insert(0, spl_dir)
-    
-    try:
-        spec.loader.exec_module(spl_main)
-        return spl_main.run_spl(system_prompt, initial_query, client, model, request_config)
-    finally:
-        # Remove from path after use
-        if spl_dir in sys.path:
-            sys.path.remove(spl_dir)
+    return run_spl(system_prompt, initial_query, client, model, request_config)
diff --git a/optillm/server.py b/optillm/server.py
index 8fbad81a..17b15796 100644
--- a/optillm/server.py
+++ b/optillm/server.py
@@ -699,8 +699,13 @@ def proxy():
     # Extract response_format if present
     response_format = data.get("response_format", None)
 
+    # Handle max_completion_tokens (preferred) and max_tokens (deprecated but supported)
+    # Priority: max_completion_tokens > max_tokens
+    max_completion_tokens = data.get('max_completion_tokens')
+    max_tokens = data.get('max_tokens')
+
     # Explicit keys that we are already handling
-    explicit_keys = {'stream', 'messages', 'model', 'n', 'response_format'}
+    explicit_keys = {'stream', 'messages', 'model', 'n', 'response_format', 'max_completion_tokens', 'max_tokens'}
 
     # Copy the rest into request_config
     request_config = {k: v for k, v in data.items() if k not in explicit_keys}
@@ -712,6 +717,13 @@ def proxy():
         "response_format": response_format,  # Add response_format to config
     })
 
+    # Add token limits to request_config with proper priority
+    if max_completion_tokens is not None:
+        request_config['max_completion_tokens'] = max_completion_tokens
+        request_config['max_tokens'] = max_completion_tokens  # For backward compatibility with approaches that read max_tokens
+    elif max_tokens is not None:
+        request_config['max_tokens'] = max_tokens
+
     optillm_approach = data.get('optillm_approach', server_config['approach'])
     logger.debug(data)
     server_config['mcts_depth'] = data.get('mcts_depth', server_config['mcts_depth'])
diff --git a/pyproject.toml b/pyproject.toml
index c37c5ed4..e8e3fd33 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "optillm"
-version = "0.3.3"
+version = "0.3.4"
 description = "An optimizing inference proxy for LLMs."
 readme = "README.md"
 license = "Apache-2.0"
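
Usage sketch (illustrative, not part of the patch): with the server change above, a client-supplied max_completion_tokens is copied into request_config and mirrored to max_tokens, so approaches that read these keys (the deepthink plugin in this release) can honor the client's limit instead of their built-in default. The base URL, placeholder API key, and "moa-" model prefix below are assumptions based on optillm's documented defaults; adjust them for your deployment.

    from openai import OpenAI

    # Assumes an optillm proxy on its default local port; the proxy holds the real
    # provider credentials, so any placeholder API key works here.
    client = OpenAI(base_url="http://localhost:8000/v1", api_key="placeholder")

    response = client.chat.completions.create(
        # The approach prefix on the model name selects mixture_of_agents.
        model="moa-gpt-4o-mini",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Outline a test plan for a rate limiter."},
        ],
        # Preferred parameter; the proxy forwards it into request_config and also
        # mirrors it to max_tokens for code paths that still read the deprecated name.
        max_completion_tokens=512,
    )
    print(response.choices[0].message.content)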