From 6af5b759f8c8cda3761c0e6b39fe01d3882ce00a Mon Sep 17 00:00:00 2001 From: arnavsinghvi11 Date: Wed, 3 Sep 2025 23:03:43 -0700 Subject: [PATCH 1/2] support for native reasoning in CoT for reasoning models --- dspy/adapters/base.py | 5 ++ dspy/clients/base_lm.py | 29 ++++++++ dspy/clients/lm.py | 13 ++++ dspy/predict/chain_of_thought.py | 17 +++-- tests/clients/test_lm.py | 113 +++++++++++++++++++++++++++++++ 5 files changed, 171 insertions(+), 6 deletions(-) diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py index eaea79563c..4c65d917a4 100644 --- a/dspy/adapters/base.py +++ b/dspy/adapters/base.py @@ -78,12 +78,14 @@ def _call_postprocess( for output in outputs: output_logprobs = None tool_calls = None + reasoning = None text = output if isinstance(output, dict): text = output["text"] output_logprobs = output.get("logprobs") tool_calls = output.get("tool_calls") + reasoning = output.get("reasoning") if text: value = self.parse(processed_signature, text) @@ -109,6 +111,9 @@ def _call_postprocess( if output_logprobs: value["logprobs"] = output_logprobs + if reasoning: + value["reasoning"] = reasoning + values.append(value) return values diff --git a/dspy/clients/base_lm.py b/dspy/clients/base_lm.py index 6f86da5632..fb0a338a2a 100644 --- a/dspy/clients/base_lm.py +++ b/dspy/clients/base_lm.py @@ -179,6 +179,10 @@ def _process_completion(self, response, merged_kwargs): for c in response.choices: output = {} output["text"] = c.message.content if hasattr(c, "message") else c["text"] + + if hasattr(c, "message") and hasattr(c.message, "reasoning_content") and c.message.reasoning_content: + output["reasoning"] = c.message.reasoning_content + if merged_kwargs.get("logprobs"): output["logprobs"] = c.logprobs if hasattr(c, "logprobs") else c["logprobs"] if hasattr(c, "message") and getattr(c.message, "tool_calls", None): @@ -203,12 +207,37 @@ def _process_response(self, response): """ outputs = [] tool_calls = [] + reasoning_content = None + for output_item in response.output: if output_item.type == "message": for content_item in output_item.content: outputs.append(content_item.text) elif output_item.type == "function_call": tool_calls.append(output_item.model_dump()) + elif output_item.type == "reasoning": + if hasattr(output_item, 'content') and output_item.content: + reasoning_content = output_item.content + elif hasattr(output_item, 'summary') and output_item.summary: + if isinstance(output_item.summary, list): + summary_texts = [] + for summary_item in output_item.summary: + if hasattr(summary_item, 'text'): + summary_texts.append(summary_item.text) + reasoning_content = "\n\n".join(summary_texts) if summary_texts else output_item.summary + else: + reasoning_content = output_item.summary + + if len(outputs) == 1 and isinstance(outputs[0], str): + result = {"text": outputs[0]} + if reasoning_content: + result["reasoning"] = reasoning_content + outputs = [result] + elif reasoning_content: + if outputs and isinstance(outputs[0], str): + outputs[0] = {"text": outputs[0], "reasoning": reasoning_content} + elif outputs and isinstance(outputs[0], dict): + outputs[0]["reasoning"] = reasoning_content if tool_calls: outputs.append({"tool_calls": tool_calls}) diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py index 3c133a7c03..420652bf83 100644 --- a/dspy/clients/lm.py +++ b/dspy/clients/lm.py @@ -103,6 +103,19 @@ def __init__( self._warn_zero_temp_rollout(self.kwargs.get("temperature"), self.kwargs.get("rollout_id")) + # Set flag if model supports native reasoning AND user specified any 
reasoning parameter + reasoning_params = ['reasoning_effort', 'reasoning', 'thinking'] # Common reasoning parameter names + has_reasoning_param = any(param in self.kwargs for param in reasoning_params) + if litellm.supports_reasoning(self.model) and has_reasoning_param: + settings.use_native_reasoning = True + + # Normalize reasoning_effort to get reasoning summaries (for OpenAI reasoning models which don't expose reasoning content) + if ('reasoning_effort' in self.kwargs and + (self.model_type == "responses" or + ('openai/' in self.model.lower() and litellm.supports_reasoning(self.model)))): + effort = self.kwargs.pop('reasoning_effort') + self.kwargs['reasoning'] = {'effort': effort, 'summary': 'auto'} + def _warn_zero_temp_rollout(self, temperature: float | None, rollout_id): if ( not self._warned_zero_temp_rollout diff --git a/dspy/predict/chain_of_thought.py b/dspy/predict/chain_of_thought.py index 96afef8588..cddf034c9a 100644 --- a/dspy/predict/chain_of_thought.py +++ b/dspy/predict/chain_of_thought.py @@ -5,6 +5,7 @@ import dspy from dspy.primitives.module import Module from dspy.signatures.signature import Signature, ensure_signature +from dspy.dsp.utils.settings import settings class ChainOfThought(Module): @@ -26,12 +27,16 @@ def __init__( """ super().__init__() signature = ensure_signature(signature) - prefix = "Reasoning: Let's think step by step in order to" - desc = "${reasoning}" - rationale_field_type = rationale_field.annotation if rationale_field else rationale_field_type - rationale_field = rationale_field if rationale_field else dspy.OutputField(prefix=prefix, desc=desc) - extended_signature = signature.prepend(name="reasoning", field=rationale_field, type_=rationale_field_type) - self.predict = dspy.Predict(extended_signature, **config) + + if getattr(settings, 'use_native_reasoning', False): + self.predict = dspy.Predict(signature, **config) + else: + prefix = "Reasoning: Let's think step by step in order to" + desc = "${reasoning}" + rationale_field_type = rationale_field.annotation if rationale_field else rationale_field_type + rationale_field = rationale_field if rationale_field else dspy.OutputField(prefix=prefix, desc=desc) + extended_signature = signature.prepend(name="reasoning", field=rationale_field, type_=rationale_field_type) + self.predict = dspy.Predict(extended_signature, **config) def forward(self, **kwargs): return self.predict(**kwargs) diff --git a/tests/clients/test_lm.py b/tests/clients/test_lm.py index 54068e512d..92f9decdf5 100644 --- a/tests/clients/test_lm.py +++ b/tests/clients/test_lm.py @@ -13,6 +13,7 @@ import dspy from dspy.utils.usage_tracker import track_usage +from dspy.dsp.utils.settings import settings def make_response(output_blocks): @@ -551,3 +552,115 @@ def test_responses_api_tool_calls(litellm_test_server): dspy_responses.assert_called_once() assert dspy_responses.call_args.kwargs["model"] == "openai/dspy-test-model" + + +def test_reasoning_effort_normalization(): + """Test that reasoning_effort gets normalized to reasoning format for OpenAI models.""" + with mock.patch("litellm.supports_reasoning", return_value=True): + # OpenAI model with Responses API - should normalize + lm1 = dspy.LM( + model="openai/gpt-5", + model_type="responses", + reasoning_effort="low", + max_tokens=16000, + temperature=1.0 + ) + assert "reasoning_effort" not in lm1.kwargs + assert lm1.kwargs["reasoning"] == {"effort": "low", "summary": "auto"} + + # OpenAI model with Chat API - should normalize + lm2 = dspy.LM( + model="openai/gpt-5", + 
reasoning_effort="medium", + max_tokens=16000, + temperature=1.0 + ) + assert "reasoning_effort" not in lm2.kwargs + assert lm2.kwargs["reasoning"] == {"effort": "medium", "summary": "auto"} + + # Non-OpenAI model - should NOT normalize + lm3 = dspy.LM( + model="deepseek-ai/DeepSeek-R1", + reasoning_effort="low", + max_tokens=4000, + temperature=0.7 + ) + assert "reasoning_effort" in lm3.kwargs + assert "reasoning" not in lm3.kwargs + + +@mock.patch("litellm.supports_reasoning") +@mock.patch("dspy.dsp.utils.settings") +def test_native_reasoning_flag_setting(mock_settings, mock_supports): + """Test that use_native_reasoning flag is set correctly.""" + mock_supports.return_value = True + + # Should set flag when model supports reasoning and has reasoning param + dspy.LM(model="openai/gpt-5", reasoning_effort="low", max_tokens=16000, temperature=1.0) + mock_settings.use_native_reasoning = True + + mock_supports.return_value = False + + # Should NOT set flag when model doesn't support reasoning + dspy.LM(model="openai/gpt-4", reasoning_effort="low", max_tokens=1000, temperature=0.7) + + +def test_reasoning_content_extraction(): + """Test that reasoning models can be created with proper configuration.""" + # Test that reasoning models are properly configured + lm = dspy.LM( + model="openai/gpt-5", + model_type="responses", + max_tokens=16000, + temperature=1.0, + reasoning_effort="low" + ) + + # Verify reasoning parameters are normalized + assert "reasoning" in lm.kwargs + assert lm.kwargs["reasoning"]["effort"] == "low" + assert "max_completion_tokens" in lm.kwargs + assert lm.kwargs["max_completion_tokens"] == 16000 + + +def test_chain_of_thought_with_native_reasoning(): + """Test ChainOfThought with native reasoning vs manual reasoning.""" + + class SimpleSignature(dspy.Signature): + """Answer the question.""" + question: str = dspy.InputField() + answer: str = dspy.OutputField() + + # Test with native reasoning enabled + settings.use_native_reasoning = True + with mock.patch("dspy.Predict") as mock_predict: + mock_predict_instance = mock.MagicMock() + mock_predict_instance.return_value = dspy.Prediction(answer="42", reasoning="native reasoning") + mock_predict.return_value = mock_predict_instance + + cot = dspy.ChainOfThought(SimpleSignature) + result = cot(question="What is the answer?") + + # Should use Predict with original signature (no reasoning field added) + mock_predict.assert_called_once() + call_args = mock_predict.call_args[0] + assert call_args[0] == SimpleSignature + assert hasattr(result, 'reasoning') + + # Reset and test with native reasoning disabled (traditional ChainOfThought) + settings.use_native_reasoning = False + with mock.patch("dspy.Predict") as mock_predict: + mock_predict_instance = mock.MagicMock() + mock_predict_instance.return_value = dspy.Prediction(reasoning="step by step...", answer="42") + mock_predict.return_value = mock_predict_instance + + cot = dspy.ChainOfThought(SimpleSignature) + result = cot(question="What is the answer?") + + # Should use Predict with extended signature (reasoning field added) + mock_predict.assert_called_once() + call_args = mock_predict.call_args[0] + # Check that signature was extended with reasoning field + extended_signature = call_args[0] + assert 'reasoning' in extended_signature.fields + assert hasattr(result, 'reasoning') From c699a1fc874434aa4c88a671b13c3fd94063538d Mon Sep 17 00:00:00 2001 From: arnavsinghvi11 Date: Wed, 3 Sep 2025 23:17:36 -0700 Subject: [PATCH 2/2] ruff and test --- dspy/clients/base_lm.py | 23 
++++++++--------- dspy/clients/lm.py | 12 ++++----- dspy/predict/chain_of_thought.py | 6 ++--- tests/clients/test_lm.py | 44 ++++++++++++++++---------------- 4 files changed, 41 insertions(+), 44 deletions(-) diff --git a/dspy/clients/base_lm.py b/dspy/clients/base_lm.py index fb0a338a2a..97a2f71e2c 100644 --- a/dspy/clients/base_lm.py +++ b/dspy/clients/base_lm.py @@ -179,10 +179,10 @@ def _process_completion(self, response, merged_kwargs): for c in response.choices: output = {} output["text"] = c.message.content if hasattr(c, "message") else c["text"] - + if hasattr(c, "message") and hasattr(c.message, "reasoning_content") and c.message.reasoning_content: output["reasoning"] = c.message.reasoning_content - + if merged_kwargs.get("logprobs"): output["logprobs"] = c.logprobs if hasattr(c, "logprobs") else c["logprobs"] if hasattr(c, "message") and getattr(c.message, "tool_calls", None): @@ -208,7 +208,7 @@ def _process_response(self, response): outputs = [] tool_calls = [] reasoning_content = None - + for output_item in response.output: if output_item.type == "message": for content_item in output_item.content: @@ -216,25 +216,22 @@ def _process_response(self, response): elif output_item.type == "function_call": tool_calls.append(output_item.model_dump()) elif output_item.type == "reasoning": - if hasattr(output_item, 'content') and output_item.content: + if hasattr(output_item, "content") and output_item.content: reasoning_content = output_item.content - elif hasattr(output_item, 'summary') and output_item.summary: + elif hasattr(output_item, "summary") and output_item.summary: if isinstance(output_item.summary, list): summary_texts = [] for summary_item in output_item.summary: - if hasattr(summary_item, 'text'): + if hasattr(summary_item, "text"): summary_texts.append(summary_item.text) reasoning_content = "\n\n".join(summary_texts) if summary_texts else output_item.summary else: reasoning_content = output_item.summary - if len(outputs) == 1 and isinstance(outputs[0], str): - result = {"text": outputs[0]} - if reasoning_content: - result["reasoning"] = reasoning_content - outputs = [result] - elif reasoning_content: - if outputs and isinstance(outputs[0], str): + if reasoning_content: + if len(outputs) == 1 and isinstance(outputs[0], str): + outputs = [{"text": outputs[0], "reasoning": reasoning_content}] + elif outputs and isinstance(outputs[0], str): outputs[0] = {"text": outputs[0], "reasoning": reasoning_content} elif outputs and isinstance(outputs[0], dict): outputs[0]["reasoning"] = reasoning_content diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py index 420652bf83..3ab39eb6a0 100644 --- a/dspy/clients/lm.py +++ b/dspy/clients/lm.py @@ -104,17 +104,17 @@ def __init__( self._warn_zero_temp_rollout(self.kwargs.get("temperature"), self.kwargs.get("rollout_id")) # Set flag if model supports native reasoning AND user specified any reasoning parameter - reasoning_params = ['reasoning_effort', 'reasoning', 'thinking'] # Common reasoning parameter names + reasoning_params = ["reasoning_effort", "reasoning", "thinking"] # Common reasoning parameter names has_reasoning_param = any(param in self.kwargs for param in reasoning_params) if litellm.supports_reasoning(self.model) and has_reasoning_param: settings.use_native_reasoning = True # Normalize reasoning_effort to get reasoning summaries (for OpenAI reasoning models which don't expose reasoning content) - if ('reasoning_effort' in self.kwargs and - (self.model_type == "responses" or - ('openai/' in self.model.lower() and 
litellm.supports_reasoning(self.model)))): - effort = self.kwargs.pop('reasoning_effort') - self.kwargs['reasoning'] = {'effort': effort, 'summary': 'auto'} + if ("reasoning_effort" in self.kwargs and + (self.model_type == "responses" or + ("openai/" in self.model.lower() and litellm.supports_reasoning(self.model)))): + effort = self.kwargs.pop("reasoning_effort") + self.kwargs["reasoning"] = {"effort": effort, "summary": "auto"} def _warn_zero_temp_rollout(self, temperature: float | None, rollout_id): if ( diff --git a/dspy/predict/chain_of_thought.py b/dspy/predict/chain_of_thought.py index cddf034c9a..2d9681407f 100644 --- a/dspy/predict/chain_of_thought.py +++ b/dspy/predict/chain_of_thought.py @@ -3,9 +3,9 @@ from pydantic.fields import FieldInfo import dspy +from dspy.dsp.utils.settings import settings from dspy.primitives.module import Module from dspy.signatures.signature import Signature, ensure_signature -from dspy.dsp.utils.settings import settings class ChainOfThought(Module): @@ -27,8 +27,8 @@ def __init__( """ super().__init__() signature = ensure_signature(signature) - - if getattr(settings, 'use_native_reasoning', False): + + if getattr(settings, "use_native_reasoning", False): self.predict = dspy.Predict(signature, **config) else: prefix = "Reasoning: Let's think step by step in order to" diff --git a/tests/clients/test_lm.py b/tests/clients/test_lm.py index 92f9decdf5..8756c922f1 100644 --- a/tests/clients/test_lm.py +++ b/tests/clients/test_lm.py @@ -12,8 +12,8 @@ from openai import RateLimitError import dspy -from dspy.utils.usage_tracker import track_usage from dspy.dsp.utils.settings import settings +from dspy.utils.usage_tracker import track_usage def make_response(output_blocks): @@ -559,7 +559,7 @@ def test_reasoning_effort_normalization(): with mock.patch("litellm.supports_reasoning", return_value=True): # OpenAI model with Responses API - should normalize lm1 = dspy.LM( - model="openai/gpt-5", + model="openai/gpt-5", model_type="responses", reasoning_effort="low", max_tokens=16000, @@ -568,10 +568,10 @@ def test_reasoning_effort_normalization(): assert "reasoning_effort" not in lm1.kwargs assert lm1.kwargs["reasoning"] == {"effort": "low", "summary": "auto"} - # OpenAI model with Chat API - should normalize + # OpenAI model with Chat API - should normalize lm2 = dspy.LM( model="openai/gpt-5", - reasoning_effort="medium", + reasoning_effort="medium", max_tokens=16000, temperature=1.0 ) @@ -594,13 +594,13 @@ def test_reasoning_effort_normalization(): def test_native_reasoning_flag_setting(mock_settings, mock_supports): """Test that use_native_reasoning flag is set correctly.""" mock_supports.return_value = True - + # Should set flag when model supports reasoning and has reasoning param dspy.LM(model="openai/gpt-5", reasoning_effort="low", max_tokens=16000, temperature=1.0) mock_settings.use_native_reasoning = True - + mock_supports.return_value = False - + # Should NOT set flag when model doesn't support reasoning dspy.LM(model="openai/gpt-4", reasoning_effort="low", max_tokens=1000, temperature=0.7) @@ -609,13 +609,13 @@ def test_reasoning_content_extraction(): """Test that reasoning models can be created with proper configuration.""" # Test that reasoning models are properly configured lm = dspy.LM( - model="openai/gpt-5", - model_type="responses", - max_tokens=16000, + model="openai/gpt-5", + model_type="responses", + max_tokens=16000, temperature=1.0, reasoning_effort="low" ) - + # Verify reasoning parameters are normalized assert "reasoning" in lm.kwargs 
assert lm.kwargs["reasoning"]["effort"] == "low" @@ -625,42 +625,42 @@ def test_reasoning_content_extraction(): def test_chain_of_thought_with_native_reasoning(): """Test ChainOfThought with native reasoning vs manual reasoning.""" - + class SimpleSignature(dspy.Signature): """Answer the question.""" question: str = dspy.InputField() answer: str = dspy.OutputField() - + # Test with native reasoning enabled settings.use_native_reasoning = True with mock.patch("dspy.Predict") as mock_predict: mock_predict_instance = mock.MagicMock() mock_predict_instance.return_value = dspy.Prediction(answer="42", reasoning="native reasoning") mock_predict.return_value = mock_predict_instance - + cot = dspy.ChainOfThought(SimpleSignature) result = cot(question="What is the answer?") - + # Should use Predict with original signature (no reasoning field added) mock_predict.assert_called_once() call_args = mock_predict.call_args[0] assert call_args[0] == SimpleSignature - assert hasattr(result, 'reasoning') - - # Reset and test with native reasoning disabled (traditional ChainOfThought) + assert hasattr(result, "reasoning") + + # Reset and test with native reasoning disabled (traditional ChainOfThought) settings.use_native_reasoning = False with mock.patch("dspy.Predict") as mock_predict: mock_predict_instance = mock.MagicMock() mock_predict_instance.return_value = dspy.Prediction(reasoning="step by step...", answer="42") mock_predict.return_value = mock_predict_instance - + cot = dspy.ChainOfThought(SimpleSignature) result = cot(question="What is the answer?") - + # Should use Predict with extended signature (reasoning field added) mock_predict.assert_called_once() call_args = mock_predict.call_args[0] # Check that signature was extended with reasoning field extended_signature = call_args[0] - assert 'reasoning' in extended_signature.fields - assert hasattr(result, 'reasoning') + assert "reasoning" in extended_signature.fields + assert hasattr(result, "reasoning")
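
Usage note (not part of either commit): a minimal sketch of how the native-reasoning path added above would be exercised end to end. The model name, credentials, and the sample question are assumptions; the flow relies only on what the diffs show: LM.__init__ normalizes `reasoning_effort` and sets `settings.use_native_reasoning`, ChainOfThought then keeps the original signature instead of prepending a prompted rationale field, and the adapter copies the provider's reasoning into the returned prediction.

    import dspy

    # Assumed reasoning-capable model and API credentials; any model for which
    # litellm.supports_reasoning() returns True takes the same path.
    lm = dspy.LM(
        "openai/gpt-5",
        model_type="responses",
        reasoning_effort="low",  # normalized by LM.__init__ to {"effort": "low", "summary": "auto"}
        max_tokens=16000,
        temperature=1.0,
    )
    dspy.configure(lm=lm)

    # With settings.use_native_reasoning set, ChainOfThought does not add a
    # "reasoning" output field to the signature; the model's own reasoning
    # (or reasoning summary) is surfaced instead.
    cot = dspy.ChainOfThought("question -> answer")
    pred = cot(question="What is 12 * 13?")

    print(pred.answer)     # parsed from the message text
    print(pred.reasoning)  # native reasoning content attached by the adapter

Since the flag lives on the process-global dspy settings, constructing such an LM affects every ChainOfThought built afterwards in the same process, regardless of which LM is active when the module is called.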