From 36e3e1ee3aebe803e041608bba2b90a34f2d7737 Mon Sep 17 00:00:00 2001
From: SimFG <1142838399@qq.com>
Date: Fri, 10 Oct 2025 03:57:41 +0000
Subject: [PATCH 1/5] feat: add bypass_n option to LangchainLLMWrapper for n-completion control

---
 src/ragas/llms/base.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/ragas/llms/base.py b/src/ragas/llms/base.py
index 8e39cb582..e3f5a1f10 100644
--- a/src/ragas/llms/base.py
+++ b/src/ragas/llms/base.py
@@ -149,6 +149,7 @@ def __init__(
         is_finished_parser: t.Optional[t.Callable[[LLMResult], bool]] = None,
         cache: t.Optional[CacheInterface] = None,
         bypass_temperature: bool = False,
+        bypass_n: bool = False,
     ):
         super().__init__(cache=cache)
         self.langchain_llm = langchain_llm
@@ -158,6 +159,8 @@ def __init__(
         self.is_finished_parser = is_finished_parser
         # Certain LLMs (e.g., OpenAI o1 series) do not support temperature
         self.bypass_temperature = bypass_temperature
+        # Certain Reason LLMs do not support n
+        self.bypass_n = bypass_n
 
     def is_finished(self, response: LLMResult) -> bool:
         """
@@ -225,7 +228,7 @@ def generate_text(
             old_temperature = self.langchain_llm.temperature  # type: ignore
             self.langchain_llm.temperature = temperature  # type: ignore
 
-        if is_multiple_completion_supported(self.langchain_llm):
+        if is_multiple_completion_supported(self.langchain_llm) && and not self.bypass_n:
             result = self.langchain_llm.generate_prompt(
                 prompts=[prompt],
                 n=n,
@@ -278,7 +281,7 @@ async def agenerate_text(
             self.langchain_llm.temperature = temperature  # type: ignore
 
         # handle n
-        if hasattr(self.langchain_llm, "n"):
+        if hasattr(self.langchain_llm, "n") and not self.bypass_n:
             self.langchain_llm.n = n  # type: ignore
         result = await self.langchain_llm.agenerate_prompt(
             prompts=[prompt],

From dfb6a725116b5a27fdc681be8ca23896f5191e24 Mon Sep 17 00:00:00 2001
From: SimFG <1142838399@qq.com>
Date: Mon, 13 Oct 2025 10:12:42 +0800
Subject: [PATCH 2/5] Update src/ragas/llms/base.py

Co-authored-by: Ani <5357586+anistark@users.noreply.github.com>
---
 src/ragas/llms/base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/ragas/llms/base.py b/src/ragas/llms/base.py
index e3f5a1f10..610b1efe7 100644
--- a/src/ragas/llms/base.py
+++ b/src/ragas/llms/base.py
@@ -228,7 +228,7 @@ def generate_text(
             old_temperature = self.langchain_llm.temperature  # type: ignore
             self.langchain_llm.temperature = temperature  # type: ignore
 
-        if is_multiple_completion_supported(self.langchain_llm) && and not self.bypass_n:
+        if is_multiple_completion_supported(self.langchain_llm) and not self.bypass_n:
             result = self.langchain_llm.generate_prompt(
                 prompts=[prompt],
                 n=n,

From e796d3d47c21221dde10e327879b20b5873edc56 Mon Sep 17 00:00:00 2001
From: SimFG <1142838399@qq.com>
Date: Mon, 13 Oct 2025 11:16:33 +0800
Subject: [PATCH 3/5] test(llm): add tests for bypass_n functionality in LangchainLLMWrapper

Add comprehensive test cases to verify the behavior of bypass_n parameter
in LangchainLLMWrapper. Tests cover both sync and async methods, default
behavior, and interaction with multiple completion support.
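For reference, a minimal usage sketch of the option these tests exercise; the
ChatOpenAI import and the "o1-mini" model name below are illustrative
assumptions, not part of this change:

    from langchain_openai import ChatOpenAI

    from ragas.llms.base import LangchainLLMWrapper

    # Hypothetical setup: o1-style reasoning models reject both the temperature
    # and n parameters, so both bypass flags are enabled on the wrapper.
    llm = LangchainLLMWrapper(
        ChatOpenAI(model="o1-mini"),
        bypass_temperature=True,
        bypass_n=True,
    )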
---
 tests/unit/llms/test_llm.py | 184 +++++++++++++++++++++++++++++++++++-
 1 file changed, 183 insertions(+), 1 deletion(-)

diff --git a/tests/unit/llms/test_llm.py b/tests/unit/llms/test_llm.py
index 059bfe875..ac9fd27ec 100644
--- a/tests/unit/llms/test_llm.py
+++ b/tests/unit/llms/test_llm.py
@@ -1,10 +1,13 @@
 from __future__ import annotations
 
 import typing as t
+from unittest.mock import MagicMock, patch
 
+import pytest
 from langchain_core.outputs import Generation, LLMResult
+from langchain_core.prompt_values import PromptValue
 
-from ragas.llms.base import BaseRagasLLM
+from ragas.llms.base import BaseRagasLLM, LangchainLLMWrapper
 
 if t.TYPE_CHECKING:
     from langchain_core.prompt_values import PromptValue
@@ -38,3 +41,182 @@ async def agenerate_text(
 
     def is_finished(self, response: LLMResult) -> bool:
         return True
+
+
+class MockLangchainLLM:
+    """Mock Langchain LLM for testing bypass_n functionality."""
+
+    def __init__(self):
+        self.n = None  # This makes hasattr(self.langchain_llm, "n") return True
+        self.temperature = None
+        self.model_name = "mock-model"
+
+    def generate_prompt(self, prompts, n=None, stop=None, callbacks=None):
+        # Track if n was passed to the method
+        self._n_passed = n
+        # Simulate the behavior where if n is passed, we return n generations per prompt
+        # If n is not passed, we return one generation per prompt
+        num_prompts = len(prompts)
+        if n is not None:
+            # If n is specified, return n generations for each prompt
+            generations = [[Generation(text="test response")] * n for _ in range(num_prompts)]
+        else:
+            # If n is not specified, return one generation per prompt
+            generations = [[Generation(text="test response")] for _ in range(num_prompts)]
+        return LLMResult(generations=generations)
+
+    async def agenerate_prompt(self, prompts, n=None, stop=None, callbacks=None):
+        # Track if n was passed to the method 
+        self._n_passed = n
+        # If n is not passed as parameter but self.n is set, use self.n
+        if n is None and hasattr(self, 'n') and self.n is not None:
+            n = self.n
+        # Simulate the behavior where if n is passed, we return n generations per prompt
+        # If n is not passed, we return one generation per prompt
+        num_prompts = len(prompts)
+        if n is not None:
+            # If n is specified, return n generations for each prompt
+            generations = [[Generation(text="test response")] * n for _ in range(num_prompts)]
+        else:
+            # If n is not specified, return one generation per prompt
+            generations = [[Generation(text="test response")] for _ in range(num_prompts)]
+        return LLMResult(generations=generations)
+
+
+def create_mock_prompt():
+    """Create a mock prompt for testing."""
+    prompt = MagicMock(spec=PromptValue)
+    prompt.to_string.return_value = "test prompt"
+    return prompt
+
+
+class TestLangchainLLMWrapperBypassN:
+    """Test bypass_n functionality in LangchainLLMWrapper."""
+
+    def test_bypass_n_true_sync_does_not_pass_n(self):
+        """Test that when bypass_n=True, n is not passed to underlying LLM in sync method."""
+        mock_llm = MockLangchainLLM()
+        # Mock is_multiple_completion_supported to return True for this test
+        with patch('ragas.llms.base.is_multiple_completion_supported', return_value=True):
+            wrapper = LangchainLLMWrapper(langchain_llm=mock_llm, bypass_n=True)
+            prompt = create_mock_prompt()
+
+            # Call generate_text with n=3
+            result = wrapper.generate_text(prompt, n=3)
+
+            # Verify that n was not passed to the underlying LLM
+            assert mock_llm._n_passed is None
+            # When bypass_n=True, the wrapper should duplicate prompts instead of passing n
+            # The result should still have 3 generations (created by duplicating prompts)
+            assert len(result.generations[0]) == 3
+
+    def test_bypass_n_false_sync_passes_n(self):
+        """Test that when bypass_n=False (default), n is passed to underlying LLM in sync method."""
+        mock_llm = MockLangchainLLM()
+        # Mock is_multiple_completion_supported to return True for this test
+        with patch('ragas.llms.base.is_multiple_completion_supported', return_value=True):
+            wrapper = LangchainLLMWrapper(langchain_llm=mock_llm, bypass_n=False)
+            prompt = create_mock_prompt()
+
+            # Call generate_text with n=3
+            result = wrapper.generate_text(prompt, n=3)
+
+            # Verify that n was passed to the underlying LLM
+            assert mock_llm._n_passed == 3
+            # Result should have 3 generations
+            assert len(result.generations[0]) == 3
+
+    @pytest.mark.asyncio
+    async def test_bypass_n_true_async_does_not_pass_n(self):
+        """Test that when bypass_n=True, n is not passed to underlying LLM in async method."""
+        mock_llm = MockLangchainLLM()
+        wrapper = LangchainLLMWrapper(langchain_llm=mock_llm, bypass_n=True)
+        prompt = create_mock_prompt()
+
+        # Call agenerate_text with n=3
+        result = await wrapper.agenerate_text(prompt, n=3)
+
+        # Verify that n was not passed to the underlying LLM
+        assert mock_llm._n_passed is None
+        # When bypass_n=True, the wrapper should duplicate prompts instead of passing n
+        # The result should still have 3 generations (created by duplicating prompts)
+        assert len(result.generations[0]) == 3
+
+    @pytest.mark.asyncio
+    async def test_bypass_n_false_async_passes_n(self):
+        """Test that when bypass_n=False (default), n is passed to underlying LLM in async method."""
+        mock_llm = MockLangchainLLM()
+        wrapper = LangchainLLMWrapper(langchain_llm=mock_llm, bypass_n=False)
+        prompt = create_mock_prompt()
+
+        # Call agenerate_text with n=3
+        result = await wrapper.agenerate_text(prompt, n=3)
+
+        # Verify that n was passed to the underlying LLM (via n attribute)
+        assert mock_llm.n == 3
+        # Result should have 3 generations
+        assert len(result.generations[0]) == 3
+
+    def test_default_bypass_n_behavior(self):
+        """Test that default behavior (bypass_n=False) remains unchanged."""
+        mock_llm = MockLangchainLLM()
+        # Mock is_multiple_completion_supported to return True for this test
+        with patch('ragas.llms.base.is_multiple_completion_supported', return_value=True):
+            # Create wrapper without explicitly setting bypass_n (should default to False)
+            wrapper = LangchainLLMWrapper(langchain_llm=mock_llm)
+            prompt = create_mock_prompt()
+
+            # Call generate_text with n=2
+            result = wrapper.generate_text(prompt, n=2)
+
+            # Verify that n was passed to the underlying LLM (default behavior)
+            assert mock_llm._n_passed == 2
+            assert len(result.generations[0]) == 2
+
+    @pytest.mark.asyncio
+    async def test_default_bypass_n_behavior_async(self):
+        """Test that default behavior (bypass_n=False) remains unchanged in async method."""
+        mock_llm = MockLangchainLLM()
+        # Create wrapper without explicitly setting bypass_n (should default to False)
+        wrapper = LangchainLLMWrapper(langchain_llm=mock_llm)
+        prompt = create_mock_prompt()
+
+        # Call agenerate_text with n=2
+        result = await wrapper.agenerate_text(prompt, n=2)
+
+        # Verify that n was passed to the underlying LLM (default behavior)
+        assert mock_llm.n == 2
+        assert len(result.generations[0]) == 2
+
+    def test_bypass_n_true_with_multiple_completion_supported(self):
+        """Test bypass_n=True with LLM that supports multiple completions."""
+        # Create a mock LLM that would normally support multiple completions
+        mock_llm = MockLangchainLLM()
+        # Mock the is_multiple_completion_supported to return True for this test
+        with patch('ragas.llms.base.is_multiple_completion_supported', return_value=True):
+            wrapper = LangchainLLMWrapper(langchain_llm=mock_llm, bypass_n=True)
+            prompt = create_mock_prompt()
+
+            # Call generate_text with n=3
+            result = wrapper.generate_text(prompt, n=3)
+
+            # Verify that n was not passed to the underlying LLM due to bypass_n=True
+            assert mock_llm._n_passed is None
+            # Result should still have 3 generations (created by duplicating prompts)
+            assert len(result.generations[0]) == 3
+
+    @pytest.mark.asyncio
+    async def test_bypass_n_true_with_multiple_completion_supported_async(self):
+        """Test bypass_n=True with LLM that supports multiple completions in async method."""
+        mock_llm = MockLangchainLLM()
+        with patch('ragas.llms.base.is_multiple_completion_supported', return_value=True):
+            wrapper = LangchainLLMWrapper(langchain_llm=mock_llm, bypass_n=True)
+            prompt = create_mock_prompt()
+
+            # Call agenerate_text with n=3
+            result = await wrapper.agenerate_text(prompt, n=3)
+
+            # Verify that n was not passed to the underlying LLM due to bypass_n=True
+            assert mock_llm._n_passed is None
+            # Result should still have 3 generations
+            assert len(result.generations[0]) == 3

From f0553e583212da71341c76b1ba681c7a6855dafd Mon Sep 17 00:00:00 2001
From: SimFG <1142838399@qq.com>
Date: Tue, 14 Oct 2025 10:27:36 +0800
Subject: [PATCH 4/5] test(format):fix the test_llm.py lint error

---
 tests/unit/llms/test_llm.py | 94 ++++++++++++++++++++++---------------
 1 file changed, 56 insertions(+), 38 deletions(-)

diff --git a/tests/unit/llms/test_llm.py b/tests/unit/llms/test_llm.py
index ac9fd27ec..cc453c4b2 100644
--- a/tests/unit/llms/test_llm.py
+++ b/tests/unit/llms/test_llm.py
@@ -45,12 +45,12 @@ def is_finished(self, response: LLMResult) -> bool:
 
 class MockLangchainLLM:
     """Mock Langchain LLM for testing bypass_n functionality."""
-    
+
     def __init__(self):
         self.n = None  # This makes hasattr(self.langchain_llm, "n") return True
         self.temperature = None
         self.model_name = "mock-model"
-    
+
     def generate_prompt(self, prompts, n=None, stop=None, callbacks=None):
         # Track if n was passed to the method
         self._n_passed = n
@@ -59,27 +59,35 @@ def generate_prompt(self, prompts, n=None, stop=None, callbacks=None):
         num_prompts = len(prompts)
         if n is not None:
             # If n is specified, return n generations for each prompt
-            generations = [[Generation(text="test response")] * n for _ in range(num_prompts)]
+            generations = [
+                [Generation(text="test response")] * n for _ in range(num_prompts)
+            ]
         else:
             # If n is not specified, return one generation per prompt
-            generations = [[Generation(text="test response")] for _ in range(num_prompts)]
+            generations = [
+                [Generation(text="test response")] for _ in range(num_prompts)
+            ]
         return LLMResult(generations=generations)
-    
+
     async def agenerate_prompt(self, prompts, n=None, stop=None, callbacks=None):
-        # Track if n was passed to the method 
+        # Track if n was passed to the method
         self._n_passed = n
         # If n is not passed as parameter but self.n is set, use self.n
-        if n is None and hasattr(self, 'n') and self.n is not None:
+        if n is None and hasattr(self, "n") and self.n is not None:
             n = self.n
         # Simulate the behavior where if n is passed, we return n generations per prompt
         # If n is not passed, we return one generation per prompt
         num_prompts = len(prompts)
         if n is not None:
             # If n is specified, return n generations for each prompt
-            generations = [[Generation(text="test response")] * n for _ in range(num_prompts)]
response")] * n for _ in range(num_prompts)] + generations = [ + [Generation(text="test response")] * n for _ in range(num_prompts) + ] else: # If n is not specified, return one generation per prompt - generations = [[Generation(text="test response")] for _ in range(num_prompts)] + generations = [ + [Generation(text="test response")] for _ in range(num_prompts) + ] return LLMResult(generations=generations) @@ -92,87 +100,93 @@ def create_mock_prompt(): class TestLangchainLLMWrapperBypassN: """Test bypass_n functionality in LangchainLLMWrapper.""" - + def test_bypass_n_true_sync_does_not_pass_n(self): """Test that when bypass_n=True, n is not passed to underlying LLM in sync method.""" mock_llm = MockLangchainLLM() # Mock is_multiple_completion_supported to return True for this test - with patch('ragas.llms.base.is_multiple_completion_supported', return_value=True): + with patch( + "ragas.llms.base.is_multiple_completion_supported", return_value=True + ): wrapper = LangchainLLMWrapper(langchain_llm=mock_llm, bypass_n=True) prompt = create_mock_prompt() - + # Call generate_text with n=3 result = wrapper.generate_text(prompt, n=3) - + # Verify that n was not passed to the underlying LLM assert mock_llm._n_passed is None # When bypass_n=True, the wrapper should duplicate prompts instead of passing n # The result should still have 3 generations (created by duplicating prompts) assert len(result.generations[0]) == 3 - + def test_bypass_n_false_sync_passes_n(self): """Test that when bypass_n=False (default), n is passed to underlying LLM in sync method.""" mock_llm = MockLangchainLLM() # Mock is_multiple_completion_supported to return True for this test - with patch('ragas.llms.base.is_multiple_completion_supported', return_value=True): + with patch( + "ragas.llms.base.is_multiple_completion_supported", return_value=True + ): wrapper = LangchainLLMWrapper(langchain_llm=mock_llm, bypass_n=False) prompt = create_mock_prompt() - + # Call generate_text with n=3 result = wrapper.generate_text(prompt, n=3) - + # Verify that n was passed to the underlying LLM assert mock_llm._n_passed == 3 # Result should have 3 generations assert len(result.generations[0]) == 3 - + @pytest.mark.asyncio async def test_bypass_n_true_async_does_not_pass_n(self): """Test that when bypass_n=True, n is not passed to underlying LLM in async method.""" mock_llm = MockLangchainLLM() wrapper = LangchainLLMWrapper(langchain_llm=mock_llm, bypass_n=True) prompt = create_mock_prompt() - + # Call agenerate_text with n=3 result = await wrapper.agenerate_text(prompt, n=3) - + # Verify that n was not passed to the underlying LLM assert mock_llm._n_passed is None # When bypass_n=True, the wrapper should duplicate prompts instead of passing n # The result should still have 3 generations (created by duplicating prompts) assert len(result.generations[0]) == 3 - + @pytest.mark.asyncio async def test_bypass_n_false_async_passes_n(self): """Test that when bypass_n=False (default), n is passed to underlying LLM in async method.""" mock_llm = MockLangchainLLM() wrapper = LangchainLLMWrapper(langchain_llm=mock_llm, bypass_n=False) prompt = create_mock_prompt() - + # Call agenerate_text with n=3 result = await wrapper.agenerate_text(prompt, n=3) - + # Verify that n was passed to the underlying LLM (via n attribute) assert mock_llm.n == 3 # Result should have 3 generations assert len(result.generations[0]) == 3 - + def test_default_bypass_n_behavior(self): """Test that default behavior (bypass_n=False) remains unchanged.""" mock_llm = 
         # Mock is_multiple_completion_supported to return True for this test
-        with patch('ragas.llms.base.is_multiple_completion_supported', return_value=True):
+        with patch(
+            "ragas.llms.base.is_multiple_completion_supported", return_value=True
+        ):
             # Create wrapper without explicitly setting bypass_n (should default to False)
             wrapper = LangchainLLMWrapper(langchain_llm=mock_llm)
             prompt = create_mock_prompt()
-            
+
             # Call generate_text with n=2
             result = wrapper.generate_text(prompt, n=2)
-            
+
             # Verify that n was passed to the underlying LLM (default behavior)
             assert mock_llm._n_passed == 2
             assert len(result.generations[0]) == 2
-    
+
     @pytest.mark.asyncio
     async def test_default_bypass_n_behavior_async(self):
         """Test that default behavior (bypass_n=False) remains unchanged in async method."""
@@ -180,42 +194,46 @@ async def test_default_bypass_n_behavior_async(self):
         # Create wrapper without explicitly setting bypass_n (should default to False)
         wrapper = LangchainLLMWrapper(langchain_llm=mock_llm)
         prompt = create_mock_prompt()
-        
+
         # Call agenerate_text with n=2
         result = await wrapper.agenerate_text(prompt, n=2)
-        
+
         # Verify that n was passed to the underlying LLM (default behavior)
         assert mock_llm.n == 2
         assert len(result.generations[0]) == 2
-    
+
     def test_bypass_n_true_with_multiple_completion_supported(self):
         """Test bypass_n=True with LLM that supports multiple completions."""
         # Create a mock LLM that would normally support multiple completions
         mock_llm = MockLangchainLLM()
         # Mock the is_multiple_completion_supported to return True for this test
-        with patch('ragas.llms.base.is_multiple_completion_supported', return_value=True):
+        with patch(
+            "ragas.llms.base.is_multiple_completion_supported", return_value=True
+        ):
             wrapper = LangchainLLMWrapper(langchain_llm=mock_llm, bypass_n=True)
             prompt = create_mock_prompt()
-            
+
             # Call generate_text with n=3
             result = wrapper.generate_text(prompt, n=3)
-            
+
             # Verify that n was not passed to the underlying LLM due to bypass_n=True
             assert mock_llm._n_passed is None
             # Result should still have 3 generations (created by duplicating prompts)
             assert len(result.generations[0]) == 3
-    
+
     @pytest.mark.asyncio
     async def test_bypass_n_true_with_multiple_completion_supported_async(self):
         """Test bypass_n=True with LLM that supports multiple completions in async method."""
         mock_llm = MockLangchainLLM()
-        with patch('ragas.llms.base.is_multiple_completion_supported', return_value=True):
+        with patch(
+            "ragas.llms.base.is_multiple_completion_supported", return_value=True
+        ):
             wrapper = LangchainLLMWrapper(langchain_llm=mock_llm, bypass_n=True)
             prompt = create_mock_prompt()
-            
+
             # Call agenerate_text with n=3
             result = await wrapper.agenerate_text(prompt, n=3)
-            
+
             # Verify that n was not passed to the underlying LLM due to bypass_n=True
             assert mock_llm._n_passed is None
             # Result should still have 3 generations

From bfdbaa30d8276ed0310a5793ca431e5bc136d0ea Mon Sep 17 00:00:00 2001
From: SimFG <1142838399@qq.com>
Date: Tue, 14 Oct 2025 12:15:27 +0000
Subject: [PATCH 5/5] check the review comment

---
 src/ragas/llms/base.py      | 2 +-
 tests/unit/llms/test_llm.py | 3 ---
 2 files changed, 1 insertion(+), 4 deletions(-)

diff --git a/src/ragas/llms/base.py b/src/ragas/llms/base.py
index 610b1efe7..74d9b8cd7 100644
--- a/src/ragas/llms/base.py
+++ b/src/ragas/llms/base.py
@@ -159,7 +159,7 @@ def __init__(
         self.is_finished_parser = is_finished_parser
         # Certain LLMs (e.g., OpenAI o1 series) do not support temperature
         self.bypass_temperature = bypass_temperature
-        # Certain Reason LLMs do not support n
+        # Certain reasoning LLMs (e.g., OpenAI o1 series) do not support n parameter
         self.bypass_n = bypass_n
 
     def is_finished(self, response: LLMResult) -> bool:
diff --git a/tests/unit/llms/test_llm.py b/tests/unit/llms/test_llm.py
index cc453c4b2..0002b0714 100644
--- a/tests/unit/llms/test_llm.py
+++ b/tests/unit/llms/test_llm.py
@@ -9,9 +9,6 @@
 
 from ragas.llms.base import BaseRagasLLM, LangchainLLMWrapper
 
-if t.TYPE_CHECKING:
-    from langchain_core.prompt_values import PromptValue
-
 
 class FakeTestLLM(BaseRagasLLM):
     def llm(self):
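Assuming a standard pytest setup with the pytest-asyncio plugin installed
(which the @pytest.mark.asyncio markers above require), the new tests added
in this series can presumably be selected by name with:

    pytest tests/unit/llms/test_llm.py -k bypass_n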