diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py index eaea79563c..4c65d917a4 100644 --- a/dspy/adapters/base.py +++ b/dspy/adapters/base.py @@ -78,12 +78,14 @@ def _call_postprocess( for output in outputs: output_logprobs = None tool_calls = None + reasoning = None text = output if isinstance(output, dict): text = output["text"] output_logprobs = output.get("logprobs") tool_calls = output.get("tool_calls") + reasoning = output.get("reasoning") if text: value = self.parse(processed_signature, text) @@ -109,6 +111,9 @@ def _call_postprocess( if output_logprobs: value["logprobs"] = output_logprobs + if reasoning: + value["reasoning"] = reasoning + values.append(value) return values diff --git a/dspy/clients/base_lm.py b/dspy/clients/base_lm.py index 6f86da5632..97a2f71e2c 100644 --- a/dspy/clients/base_lm.py +++ b/dspy/clients/base_lm.py @@ -179,6 +179,10 @@ def _process_completion(self, response, merged_kwargs): for c in response.choices: output = {} output["text"] = c.message.content if hasattr(c, "message") else c["text"] + + if hasattr(c, "message") and hasattr(c.message, "reasoning_content") and c.message.reasoning_content: + output["reasoning"] = c.message.reasoning_content + if merged_kwargs.get("logprobs"): output["logprobs"] = c.logprobs if hasattr(c, "logprobs") else c["logprobs"] if hasattr(c, "message") and getattr(c.message, "tool_calls", None): @@ -203,12 +207,34 @@ def _process_response(self, response): """ outputs = [] tool_calls = [] + reasoning_content = None + for output_item in response.output: if output_item.type == "message": for content_item in output_item.content: outputs.append(content_item.text) elif output_item.type == "function_call": tool_calls.append(output_item.model_dump()) + elif output_item.type == "reasoning": + if hasattr(output_item, "content") and output_item.content: + reasoning_content = output_item.content + elif hasattr(output_item, "summary") and output_item.summary: + if isinstance(output_item.summary, list): + summary_texts = [] + for summary_item in output_item.summary: + if hasattr(summary_item, "text"): + summary_texts.append(summary_item.text) + reasoning_content = "\n\n".join(summary_texts) if summary_texts else output_item.summary + else: + reasoning_content = output_item.summary + + if reasoning_content: + if len(outputs) == 1 and isinstance(outputs[0], str): + outputs = [{"text": outputs[0], "reasoning": reasoning_content}] + elif outputs and isinstance(outputs[0], str): + outputs[0] = {"text": outputs[0], "reasoning": reasoning_content} + elif outputs and isinstance(outputs[0], dict): + outputs[0]["reasoning"] = reasoning_content if tool_calls: outputs.append({"tool_calls": tool_calls}) diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py index 3c133a7c03..3ab39eb6a0 100644 --- a/dspy/clients/lm.py +++ b/dspy/clients/lm.py @@ -103,6 +103,19 @@ def __init__( self._warn_zero_temp_rollout(self.kwargs.get("temperature"), self.kwargs.get("rollout_id")) + # Set flag if model supports native reasoning AND user specified any reasoning parameter + reasoning_params = ["reasoning_effort", "reasoning", "thinking"] # Common reasoning parameter names + has_reasoning_param = any(param in self.kwargs for param in reasoning_params) + if litellm.supports_reasoning(self.model) and has_reasoning_param: + settings.use_native_reasoning = True + + # Normalize reasoning_effort to get reasoning summaries (for OpenAI reasoning models which don't expose reasoning content) + if ("reasoning_effort" in self.kwargs and + (self.model_type == "responses" or + 
("openai/" in self.model.lower() and litellm.supports_reasoning(self.model)))): + effort = self.kwargs.pop("reasoning_effort") + self.kwargs["reasoning"] = {"effort": effort, "summary": "auto"} + def _warn_zero_temp_rollout(self, temperature: float | None, rollout_id): if ( not self._warned_zero_temp_rollout diff --git a/dspy/predict/chain_of_thought.py b/dspy/predict/chain_of_thought.py index 96afef8588..2d9681407f 100644 --- a/dspy/predict/chain_of_thought.py +++ b/dspy/predict/chain_of_thought.py @@ -3,6 +3,7 @@ from pydantic.fields import FieldInfo import dspy +from dspy.dsp.utils.settings import settings from dspy.primitives.module import Module from dspy.signatures.signature import Signature, ensure_signature @@ -26,12 +27,16 @@ def __init__( """ super().__init__() signature = ensure_signature(signature) - prefix = "Reasoning: Let's think step by step in order to" - desc = "${reasoning}" - rationale_field_type = rationale_field.annotation if rationale_field else rationale_field_type - rationale_field = rationale_field if rationale_field else dspy.OutputField(prefix=prefix, desc=desc) - extended_signature = signature.prepend(name="reasoning", field=rationale_field, type_=rationale_field_type) - self.predict = dspy.Predict(extended_signature, **config) + + if getattr(settings, "use_native_reasoning", False): + self.predict = dspy.Predict(signature, **config) + else: + prefix = "Reasoning: Let's think step by step in order to" + desc = "${reasoning}" + rationale_field_type = rationale_field.annotation if rationale_field else rationale_field_type + rationale_field = rationale_field if rationale_field else dspy.OutputField(prefix=prefix, desc=desc) + extended_signature = signature.prepend(name="reasoning", field=rationale_field, type_=rationale_field_type) + self.predict = dspy.Predict(extended_signature, **config) def forward(self, **kwargs): return self.predict(**kwargs) diff --git a/tests/clients/test_lm.py b/tests/clients/test_lm.py index 54068e512d..8756c922f1 100644 --- a/tests/clients/test_lm.py +++ b/tests/clients/test_lm.py @@ -12,6 +12,7 @@ from openai import RateLimitError import dspy +from dspy.dsp.utils.settings import settings from dspy.utils.usage_tracker import track_usage @@ -551,3 +552,115 @@ def test_responses_api_tool_calls(litellm_test_server): dspy_responses.assert_called_once() assert dspy_responses.call_args.kwargs["model"] == "openai/dspy-test-model" + + +def test_reasoning_effort_normalization(): + """Test that reasoning_effort gets normalized to reasoning format for OpenAI models.""" + with mock.patch("litellm.supports_reasoning", return_value=True): + # OpenAI model with Responses API - should normalize + lm1 = dspy.LM( + model="openai/gpt-5", + model_type="responses", + reasoning_effort="low", + max_tokens=16000, + temperature=1.0 + ) + assert "reasoning_effort" not in lm1.kwargs + assert lm1.kwargs["reasoning"] == {"effort": "low", "summary": "auto"} + + # OpenAI model with Chat API - should normalize + lm2 = dspy.LM( + model="openai/gpt-5", + reasoning_effort="medium", + max_tokens=16000, + temperature=1.0 + ) + assert "reasoning_effort" not in lm2.kwargs + assert lm2.kwargs["reasoning"] == {"effort": "medium", "summary": "auto"} + + # Non-OpenAI model - should NOT normalize + lm3 = dspy.LM( + model="deepseek-ai/DeepSeek-R1", + reasoning_effort="low", + max_tokens=4000, + temperature=0.7 + ) + assert "reasoning_effort" in lm3.kwargs + assert "reasoning" not in lm3.kwargs + + +@mock.patch("litellm.supports_reasoning") +@mock.patch("dspy.dsp.utils.settings") 
+def test_native_reasoning_flag_setting(mock_settings, mock_supports):
+    """Test that use_native_reasoning flag is set correctly."""
+    mock_supports.return_value = True
+
+    # Should set flag when model supports reasoning and has reasoning param
+    dspy.LM(model="openai/gpt-5", reasoning_effort="low", max_tokens=16000, temperature=1.0)
+    assert mock_settings.use_native_reasoning is True
+
+    # Should NOT set flag when model doesn't support reasoning
+    mock_supports.return_value = False
+    mock_settings.use_native_reasoning = False
+
+    dspy.LM(model="openai/gpt-4", reasoning_effort="low", max_tokens=1000, temperature=0.7)
+    assert mock_settings.use_native_reasoning is False
+
+
+def test_reasoning_content_extraction():
+    """Test that reasoning models can be created with proper configuration."""
+    lm = dspy.LM(
+        model="openai/gpt-5",
+        model_type="responses",
+        max_tokens=16000,
+        temperature=1.0,
+        reasoning_effort="low"
+    )
+
+    # Verify reasoning parameters are normalized
+    assert "reasoning" in lm.kwargs
+    assert lm.kwargs["reasoning"]["effort"] == "low"
+    assert "max_completion_tokens" in lm.kwargs
+    assert lm.kwargs["max_completion_tokens"] == 16000
+
+
+def test_chain_of_thought_with_native_reasoning():
+    """Test ChainOfThought with native reasoning vs manual reasoning."""
+
+    class SimpleSignature(dspy.Signature):
+        """Answer the question."""
+        question: str = dspy.InputField()
+        answer: str = dspy.OutputField()
+
+    # Test with native reasoning enabled
+    settings.use_native_reasoning = True
+    with mock.patch("dspy.Predict") as mock_predict:
+        mock_predict_instance = mock.MagicMock()
+        mock_predict_instance.return_value = dspy.Prediction(answer="42", reasoning="native reasoning")
+        mock_predict.return_value = mock_predict_instance
+
+        cot = dspy.ChainOfThought(SimpleSignature)
+        result = cot(question="What is the answer?")
+
+        # Should use Predict with original signature (no reasoning field added)
+        mock_predict.assert_called_once()
+        call_args = mock_predict.call_args[0]
+        assert call_args[0] == SimpleSignature
+        assert hasattr(result, "reasoning")
+
+    # Reset and test with native reasoning disabled (traditional ChainOfThought)
+    settings.use_native_reasoning = False
+    with mock.patch("dspy.Predict") as mock_predict:
+        mock_predict_instance = mock.MagicMock()
+        mock_predict_instance.return_value = dspy.Prediction(reasoning="step by step...", answer="42")
+        mock_predict.return_value = mock_predict_instance
+
+        cot = dspy.ChainOfThought(SimpleSignature)
+        result = cot(question="What is the answer?")
+
+        # Should use Predict with extended signature (reasoning field added)
+        mock_predict.assert_called_once()
+        call_args = mock_predict.call_args[0]
+        # Check that signature was extended with reasoning field
+        extended_signature = call_args[0]
+        assert "reasoning" in extended_signature.fields
+        assert hasattr(result, "reasoning")
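
Usage note (illustrative, not part of the patch): a minimal sketch of how the native-reasoning path would be exercised once the changes above are applied. The model name, effort level, and token limits are placeholders, and the printed fields assume the adapter and ChainOfThought behavior introduced in this diff.

import dspy

# Assumes the patch above is applied. With a reasoning-capable model and a
# reasoning parameter supplied, LM.__init__ sets settings.use_native_reasoning,
# ChainOfThought skips the injected "reasoning" output field, and the adapter
# copies the provider's reasoning summary onto the prediction.
lm = dspy.LM(
    model="openai/gpt-5",        # placeholder reasoning-capable model
    model_type="responses",
    reasoning_effort="low",      # normalized to {"effort": "low", "summary": "auto"}
    max_tokens=16000,
    temperature=1.0,
)
dspy.configure(lm=lm)

cot = dspy.ChainOfThought("question -> answer")
pred = cot(question="What is 12 * 13?")
print(pred.answer)
print(pred.reasoning)  # native reasoning summary rather than a prompted rationale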