5 changes: 5 additions & 0 deletions dspy/adapters/base.py
@@ -78,12 +78,14 @@ def _call_postprocess(
        for output in outputs:
            output_logprobs = None
            tool_calls = None
            reasoning = None
            text = output

            if isinstance(output, dict):
                text = output["text"]
                output_logprobs = output.get("logprobs")
                tool_calls = output.get("tool_calls")
                reasoning = output.get("reasoning")

            if text:
                value = self.parse(processed_signature, text)
@@ -109,6 +111,9 @@ def _call_postprocess(
            if output_logprobs:
                value["logprobs"] = output_logprobs

            if reasoning:
                value["reasoning"] = reasoning

            values.append(value)

        return values
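For orientation, here is a minimal sketch of the per-completion dict that `_call_postprocess` now consumes and what it contributes to the parsed value. The shape is inferred from the hunks above; the field values are made up:

# Hypothetical per-choice output dict, as produced by the LM clients below.
# All keys except "text" are optional; "reasoning" is the new addition.
output = {
    "text": "Answer: 42",
    "logprobs": None,            # populated when logprobs were requested
    "tool_calls": None,          # populated when the model invoked tools
    "reasoning": "42 follows from multiplying 6 by 7.",  # native reasoning trace
}
# After parsing the text into fields, the adapter attaches the trace:
# value["reasoning"] = output["reasoning"]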
26 changes: 26 additions & 0 deletions dspy/clients/base_lm.py
@@ -179,6 +179,10 @@ def _process_completion(self, response, merged_kwargs):
        for c in response.choices:
            output = {}
            output["text"] = c.message.content if hasattr(c, "message") else c["text"]

            if hasattr(c, "message") and hasattr(c.message, "reasoning_content") and c.message.reasoning_content:
                output["reasoning"] = c.message.reasoning_content

            if merged_kwargs.get("logprobs"):
                output["logprobs"] = c.logprobs if hasattr(c, "logprobs") else c["logprobs"]
            if hasattr(c, "message") and getattr(c.message, "tool_calls", None):
@@ -203,12 +207,34 @@ def _process_response(self, response):
"""
outputs = []
tool_calls = []
reasoning_content = None

for output_item in response.output:
if output_item.type == "message":
for content_item in output_item.content:
outputs.append(content_item.text)
elif output_item.type == "function_call":
tool_calls.append(output_item.model_dump())
elif output_item.type == "reasoning":
if hasattr(output_item, "content") and output_item.content:
reasoning_content = output_item.content
elif hasattr(output_item, "summary") and output_item.summary:
if isinstance(output_item.summary, list):
summary_texts = []
for summary_item in output_item.summary:
if hasattr(summary_item, "text"):
summary_texts.append(summary_item.text)
reasoning_content = "\n\n".join(summary_texts) if summary_texts else output_item.summary
else:
reasoning_content = output_item.summary

if reasoning_content:
if len(outputs) == 1 and isinstance(outputs[0], str):
outputs = [{"text": outputs[0], "reasoning": reasoning_content}]
elif outputs and isinstance(outputs[0], str):
outputs[0] = {"text": outputs[0], "reasoning": reasoning_content}
elif outputs and isinstance(outputs[0], dict):
outputs[0]["reasoning"] = reasoning_content

if tool_calls:
outputs.append({"tool_calls": tool_calls})
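A hedged illustration of the Responses API folding logic above, with the response mocked via `SimpleNamespace` (the real SDK types differ; this only mirrors the attributes the code reads):

from types import SimpleNamespace

# One reasoning item carrying only summaries, followed by one message item.
response = SimpleNamespace(output=[
    SimpleNamespace(
        type="reasoning",
        content=None,
        summary=[
            SimpleNamespace(text="Restate the question."),
            SimpleNamespace(text="Check the arithmetic."),
        ],
    ),
    SimpleNamespace(type="message", content=[SimpleNamespace(text="The answer is 42.")]),
])

# _process_response joins the summary texts with "\n\n" and attaches them to
# the lone text output, yielding:
# [{"text": "The answer is 42.",
#   "reasoning": "Restate the question.\n\nCheck the arithmetic."}]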
13 changes: 13 additions & 0 deletions dspy/clients/lm.py
@@ -103,6 +103,19 @@ def __init__(

        self._warn_zero_temp_rollout(self.kwargs.get("temperature"), self.kwargs.get("rollout_id"))

        # Set the flag when the model supports native reasoning AND the user passed a reasoning parameter.
        reasoning_params = ["reasoning_effort", "reasoning", "thinking"]  # common reasoning parameter names
        has_reasoning_param = any(param in self.kwargs for param in reasoning_params)
        if litellm.supports_reasoning(self.model) and has_reasoning_param:
            settings.use_native_reasoning = True

        # Normalize reasoning_effort into a "reasoning" dict that requests summaries
        # (OpenAI reasoning models don't expose raw reasoning content, only summaries).
        if ("reasoning_effort" in self.kwargs and
                (self.model_type == "responses" or
                 ("openai/" in self.model.lower() and litellm.supports_reasoning(self.model)))):
            effort = self.kwargs.pop("reasoning_effort")
            self.kwargs["reasoning"] = {"effort": effort, "summary": "auto"}

    def _warn_zero_temp_rollout(self, temperature: float | None, rollout_id):
        if (
            not self._warned_zero_temp_rollout
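In practice (mirroring the tests below), the normalization rewrites a plain `reasoning_effort` kwarg into the richer `reasoning` dict for OpenAI reasoning models:

import dspy

lm = dspy.LM(
    "openai/gpt-5",
    model_type="responses",
    reasoning_effort="low",
    max_tokens=16000,
    temperature=1.0,
)
# "reasoning_effort" is popped; summaries are requested so the reasoning
# content can be surfaced even though raw traces aren't exposed.
assert "reasoning_effort" not in lm.kwargs
assert lm.kwargs["reasoning"] == {"effort": "low", "summary": "auto"}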
17 changes: 11 additions & 6 deletions dspy/predict/chain_of_thought.py
@@ -3,6 +3,7 @@
from pydantic.fields import FieldInfo

import dspy
from dspy.dsp.utils.settings import settings
from dspy.primitives.module import Module
from dspy.signatures.signature import Signature, ensure_signature

@@ -26,12 +27,16 @@ def __init__(
"""
super().__init__()
signature = ensure_signature(signature)
prefix = "Reasoning: Let's think step by step in order to"
desc = "${reasoning}"
rationale_field_type = rationale_field.annotation if rationale_field else rationale_field_type
rationale_field = rationale_field if rationale_field else dspy.OutputField(prefix=prefix, desc=desc)
extended_signature = signature.prepend(name="reasoning", field=rationale_field, type_=rationale_field_type)
self.predict = dspy.Predict(extended_signature, **config)

if getattr(settings, "use_native_reasoning", False):
self.predict = dspy.Predict(signature, **config)
else:
prefix = "Reasoning: Let's think step by step in order to"
desc = "${reasoning}"
rationale_field_type = rationale_field.annotation if rationale_field else rationale_field_type
rationale_field = rationale_field if rationale_field else dspy.OutputField(prefix=prefix, desc=desc)
extended_signature = signature.prepend(name="reasoning", field=rationale_field, type_=rationale_field_type)
self.predict = dspy.Predict(extended_signature, **config)

def forward(self, **kwargs):
return self.predict(**kwargs)
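Behaviorally, this means `ChainOfThought` becomes a thin `Predict` over the unmodified signature when native reasoning is active. A sketch, assuming an LM configured with a reasoning parameter as in the lm.py change above:

import dspy

cot = dspy.ChainOfThought("question -> answer")
pred = cot(question="What is 6 * 7?")

# With settings.use_native_reasoning set by the LM constructor, no "reasoning"
# output field is prepended to the signature; pred.reasoning comes from the
# model's own trace via the adapter. Otherwise the classic prepended
# "reasoning" field elicits step-by-step text from the prompt itself.
print(pred.answer, pred.reasoning)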
113 changes: 113 additions & 0 deletions tests/clients/test_lm.py
@@ -12,6 +12,7 @@
from openai import RateLimitError

import dspy
from dspy.dsp.utils.settings import settings
from dspy.utils.usage_tracker import track_usage


@@ -551,3 +552,115 @@ def test_responses_api_tool_calls(litellm_test_server):

    dspy_responses.assert_called_once()
    assert dspy_responses.call_args.kwargs["model"] == "openai/dspy-test-model"


def test_reasoning_effort_normalization():
    """Test that reasoning_effort gets normalized to reasoning format for OpenAI models."""
    with mock.patch("litellm.supports_reasoning", return_value=True):
        # OpenAI model with Responses API - should normalize
        lm1 = dspy.LM(
            model="openai/gpt-5",
            model_type="responses",
            reasoning_effort="low",
            max_tokens=16000,
            temperature=1.0,
        )
        assert "reasoning_effort" not in lm1.kwargs
        assert lm1.kwargs["reasoning"] == {"effort": "low", "summary": "auto"}

        # OpenAI model with Chat API - should normalize
        lm2 = dspy.LM(
            model="openai/gpt-5",
            reasoning_effort="medium",
            max_tokens=16000,
            temperature=1.0,
        )
        assert "reasoning_effort" not in lm2.kwargs
        assert lm2.kwargs["reasoning"] == {"effort": "medium", "summary": "auto"}

        # Non-OpenAI model - should NOT normalize
        lm3 = dspy.LM(
            model="deepseek-ai/DeepSeek-R1",
            reasoning_effort="low",
            max_tokens=4000,
            temperature=0.7,
        )
        assert "reasoning_effort" in lm3.kwargs
        assert "reasoning" not in lm3.kwargs


@mock.patch("litellm.supports_reasoning")
@mock.patch("dspy.dsp.utils.settings")
def test_native_reasoning_flag_setting(mock_settings, mock_supports):
"""Test that use_native_reasoning flag is set correctly."""
mock_supports.return_value = True

# Should set flag when model supports reasoning and has reasoning param
dspy.LM(model="openai/gpt-5", reasoning_effort="low", max_tokens=16000, temperature=1.0)
mock_settings.use_native_reasoning = True

mock_supports.return_value = False

# Should NOT set flag when model doesn't support reasoning
dspy.LM(model="openai/gpt-4", reasoning_effort="low", max_tokens=1000, temperature=0.7)


def test_reasoning_model_configuration():
    """Test that reasoning models can be created with proper configuration."""
    lm = dspy.LM(
        model="openai/gpt-5",
        model_type="responses",
        max_tokens=16000,
        temperature=1.0,
        reasoning_effort="low",
    )

    # Verify reasoning parameters are normalized
    assert "reasoning" in lm.kwargs
    assert lm.kwargs["reasoning"]["effort"] == "low"
    assert "max_completion_tokens" in lm.kwargs
    assert lm.kwargs["max_completion_tokens"] == 16000


def test_chain_of_thought_with_native_reasoning():
    """Test ChainOfThought with native reasoning vs manual reasoning."""

    class SimpleSignature(dspy.Signature):
        """Answer the question."""

        question: str = dspy.InputField()
        answer: str = dspy.OutputField()

    # Test with native reasoning enabled
    settings.use_native_reasoning = True
    with mock.patch("dspy.Predict") as mock_predict:
        mock_predict_instance = mock.MagicMock()
        mock_predict_instance.return_value = dspy.Prediction(answer="42", reasoning="native reasoning")
        mock_predict.return_value = mock_predict_instance

        cot = dspy.ChainOfThought(SimpleSignature)
        result = cot(question="What is the answer?")

        # Should use Predict with original signature (no reasoning field added)
        mock_predict.assert_called_once()
        call_args = mock_predict.call_args[0]
        assert call_args[0] == SimpleSignature
        assert hasattr(result, "reasoning")

    # Reset and test with native reasoning disabled (traditional ChainOfThought)
    settings.use_native_reasoning = False
    with mock.patch("dspy.Predict") as mock_predict:
        mock_predict_instance = mock.MagicMock()
        mock_predict_instance.return_value = dspy.Prediction(reasoning="step by step...", answer="42")
        mock_predict.return_value = mock_predict_instance

        cot = dspy.ChainOfThought(SimpleSignature)
        result = cot(question="What is the answer?")

        # Should use Predict with extended signature (reasoning field added)
        mock_predict.assert_called_once()
        call_args = mock_predict.call_args[0]
        # Check that the signature was extended with a reasoning field
        extended_signature = call_args[0]
        assert "reasoning" in extended_signature.fields
        assert hasattr(result, "reasoning")