From cc26ba760deb40058f0c854888d8c5ef3ec5a9be Mon Sep 17 00:00:00 2001 From: Aaron Brown Date: Wed, 30 Jul 2025 16:24:20 -0500 Subject: [PATCH 01/11] Initial commit with llamacpp module; --- src/strands/models/__init__.py | 5 +- src/strands/models/llamacpp.py | 437 +++++++++++++++++++ tests/strands/models/test_llamacpp.py | 475 +++++++++++++++++++++ tests_integ/models/test_model_llamacpp.py | 495 ++++++++++++++++++++++ 4 files changed, 1410 insertions(+), 2 deletions(-) create mode 100644 src/strands/models/llamacpp.py create mode 100644 tests/strands/models/test_llamacpp.py create mode 100644 tests_integ/models/test_model_llamacpp.py diff --git a/src/strands/models/__init__.py b/src/strands/models/__init__.py index ead290a35..35036203f 100644 --- a/src/strands/models/__init__.py +++ b/src/strands/models/__init__.py @@ -3,8 +3,9 @@ This package includes an abstract base Model class along with concrete implementations for specific providers. """ -from . import bedrock, model +from . import bedrock, llamacpp, model from .bedrock import BedrockModel +from .llamacpp import LlamaCppModel from .model import Model -__all__ = ["bedrock", "model", "BedrockModel", "Model"] +__all__ = ["bedrock", "llamacpp", "model", "BedrockModel", "LlamaCppModel", "Model"] diff --git a/src/strands/models/llamacpp.py b/src/strands/models/llamacpp.py new file mode 100644 index 000000000..59cb64656 --- /dev/null +++ b/src/strands/models/llamacpp.py @@ -0,0 +1,437 @@ +"""llama.cpp model provider. + +- Docs: https://github.com/ggml-org/llama.cpp +- Server docs: https://github.com/ggml-org/llama.cpp/tree/master/tools/server +""" + +import json +import logging +from typing import Any, AsyncGenerator, Optional, Type, TypedDict, TypeVar, Union + +import httpx +from pydantic import BaseModel +from typing_extensions import Unpack, override + +from ..types.content import Messages +from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException +from ..types.streaming import StreamEvent +from ..types.tools import ToolSpec +from .openai import OpenAIModel + +logger = logging.getLogger(__name__) + +T = TypeVar("T", bound=BaseModel) + + +class LlamaCppError(Exception): + """Base exception for llama.cpp specific errors.""" + pass + + +class LlamaCppContextOverflowError(LlamaCppError, ContextWindowOverflowException): + """Raised when context window is exceeded in llama.cpp.""" + pass + + +class LlamaCppModel(OpenAIModel): + """llama.cpp model provider implementation. + + Connects to a llama.cpp server running in OpenAI-compatible mode with + support for advanced llama.cpp-specific features like grammar constraints, + Mirostat sampling, and native JSON schema validation. + + The llama.cpp server must be started with the OpenAI-compatible API enabled: + llama-server -m model.gguf --host 0.0.0.0 --port 8080 + + Example: + Basic usage: + >>> model = LlamaCppModel(base_url="http://localhost:8080/v1") + >>> model.update_config(params={"temperature": 0.7, "top_k": 40}) + + Grammar constraints: + >>> model.use_grammar_constraint(''' + ... root ::= answer + ... answer ::= "yes" | "no" + ... ''') + + Advanced sampling: + >>> model.update_config(params={ + ... "mirostat": 2, + ... "mirostat_lr": 0.1, + ... "tfs_z": 0.95, + ... "repeat_penalty": 1.1 + ... }) + """ + + class LlamaCppConfig(TypedDict, total=False): + """Configuration options for llama.cpp models. + + Attributes: + model_id: Model identifier for the loaded model in llama.cpp server. + Default is "default" as llama.cpp typically loads a single model. 
+ params: Model parameters supporting both OpenAI and llama.cpp-specific options. + + OpenAI-compatible parameters: + - max_tokens: Maximum number of tokens to generate + - temperature: Sampling temperature (0.0 to 2.0) + - top_p: Nucleus sampling parameter (0.0 to 1.0) + - frequency_penalty: Frequency penalty (-2.0 to 2.0) + - presence_penalty: Presence penalty (-2.0 to 2.0) + - stop: List of stop sequences + - seed: Random seed for reproducibility + - n: Number of completions to generate + - logprobs: Include log probabilities in output + - top_logprobs: Number of top log probabilities to include + + llama.cpp-specific parameters: + - repeat_penalty: Penalize repeat tokens (1.0 = no penalty) + - top_k: Top-k sampling (0 = disabled) + - min_p: Min-p sampling threshold (0.0 to 1.0) + - typical_p: Typical-p sampling (0.0 to 1.0) + - tfs_z: Tail-free sampling parameter (0.0 to 1.0) + - top_a: Top-a sampling parameter + - mirostat: Mirostat sampling mode (0, 1, or 2) + - mirostat_lr: Mirostat learning rate + - mirostat_ent: Mirostat target entropy + - grammar: GBNF grammar string for constrained generation + - json_schema: JSON schema for structured output + - penalty_last_n: Number of tokens to consider for penalties + - n_probs: Number of probabilities to return per token + - min_keep: Minimum tokens to keep in sampling + - ignore_eos: Ignore end-of-sequence token + - logit_bias: Token ID to bias mapping + - cache_prompt: Cache the prompt for faster generation + - slot_id: Slot ID for parallel inference + - samplers: Custom sampler order + """ + + model_id: str + params: Optional[dict[str, Any]] + + def __init__( + self, + base_url: str = "http://localhost:8080/v1", + api_key: Optional[str] = None, + timeout: Optional[Union[float, tuple[float, float]]] = None, + max_retries: Optional[int] = None, + **model_config: Unpack[LlamaCppConfig], + ) -> None: + """Initialize llama.cpp provider instance. + + Args: + base_url: Base URL for the llama.cpp server. + Default is "http://localhost:8080/v1" for local server. + api_key: Optional API key if the llama.cpp server requires authentication. + timeout: Request timeout in seconds. Can be a float or tuple of (connect, read) timeouts. + max_retries: Maximum number of retries for failed requests. + **model_config: Configuration options for the llama.cpp model. + """ + # Set default model_id if not provided + if "model_id" not in model_config: + model_config["model_id"] = "default" + + # Build OpenAI client args + client_args = { + "base_url": base_url, + "api_key": api_key or "dummy", # OpenAI client requires some API key + } + + if timeout is not None: + client_args["timeout"] = timeout + + if max_retries is not None: + client_args["max_retries"] = max_retries + + logger.debug( + "base_url=<%s>, model_id=<%s> | initializing llama.cpp provider", + base_url, + model_config.get("model_id"), + ) + + # Initialize parent OpenAI model with our client args + super().__init__(client_args=client_args, **model_config) + + def use_grammar_constraint(self, grammar: str) -> None: + """Apply a GBNF grammar constraint to the generation. + + Args: + grammar: GBNF (Backus-Naur Form) grammar string defining allowed outputs. + See https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md + + Example: + >>> # Constrain output to yes/no answers + >>> model.use_grammar_constraint(''' + ... root ::= answer + ... answer ::= "yes" | "no" + ... ''') + + >>> # JSON object grammar + >>> model.use_grammar_constraint(''' + ... root ::= object + ... 
object ::= "{" pair ("," pair)* "}" + ... pair ::= string ":" value + ... string ::= "\\"" [^"]* "\\"" + ... value ::= string | number | "true" | "false" | "null" + ... number ::= "-"? [0-9]+ ("." [0-9]+)? + ... ''') + """ + if not self.config.get("params"): + self.config["params"] = {} + self.config["params"]["grammar"] = grammar + logger.debug("Applied grammar constraint") + + def use_json_schema(self, schema: dict[str, Any]) -> None: + """Apply a JSON schema constraint for structured output. + + Args: + schema: JSON schema dictionary defining the expected output structure. + + Example: + >>> model.use_json_schema({ + ... "type": "object", + ... "properties": { + ... "name": {"type": "string"}, + ... "age": {"type": "integer", "minimum": 0} + ... }, + ... "required": ["name", "age"] + ... }) + """ + if not self.config.get("params"): + self.config["params"] = {} + self.config["params"]["json_schema"] = schema + logger.debug("Applied JSON schema constraint") + + @override + def format_request( + self, + messages: Messages, + tool_specs: Optional[list[ToolSpec]] = None, + system_prompt: Optional[str] = None, + ) -> dict[str, Any]: + """Format a request for the llama.cpp server. + + This method overrides the OpenAI format to properly handle llama.cpp-specific parameters. + + Args: + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. + system_prompt: System prompt to provide context to the model. + + Returns: + A request formatted for llama.cpp server's OpenAI-compatible API. + """ + # Build base request structure without calling super() to avoid + # parameter conflicts between OpenAI and llama.cpp specific params. + # This allows us to properly separate parameters into the appropriate + # request fields (direct vs extra_body). 
+ request = { + "messages": self.format_request_messages(messages, system_prompt), + "model": self.config["model_id"], + "stream": True, + "stream_options": {"include_usage": True}, + "tools": [ + { + "type": "function", + "function": { + "name": tool_spec["name"], + "description": tool_spec["description"], + "parameters": tool_spec["inputSchema"]["json"], + }, + } + for tool_spec in tool_specs or [] + ], + } + + # Handle parameters if provided + if self.config.get("params"): + params = self.config["params"] + + # Define llama.cpp-specific parameters that need special handling + llamacpp_specific_params = { + "repeat_penalty", + "top_k", + "min_p", + "typical_p", + "tfs_z", + "top_a", + "mirostat", + "mirostat_lr", + "mirostat_ent", + "grammar", + "json_schema", + "penalty_last_n", + "n_probs", + "min_keep", + "ignore_eos", + "logit_bias", + "cache_prompt", + "slot_id", + "samplers", + } + + # Standard OpenAI parameters that go directly in request + openai_params = { + "temperature", + "max_tokens", + "top_p", + "frequency_penalty", + "presence_penalty", + "stop", + "seed", + "n", + "logprobs", + "top_logprobs", + "response_format", + } + + # Add OpenAI parameters directly to request + for param, value in params.items(): + if param in openai_params: + request[param] = value + + # Collect llama.cpp-specific parameters for extra_body + extra_body = {} + for param, value in params.items(): + if param in llamacpp_specific_params: + extra_body[param] = value + + # Add extra_body if we have llama.cpp-specific parameters + if extra_body: + request["extra_body"] = extra_body + + return request + + @override + async def stream( + self, + messages: Messages, + tool_specs: Optional[list[ToolSpec]] = None, + system_prompt: Optional[str] = None, + **kwargs: Any, + ) -> AsyncGenerator[StreamEvent, None]: + """Stream conversation with the llama.cpp model. + + This method extends the OpenAI stream to handle llama.cpp-specific errors. + + Args: + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. + system_prompt: System prompt to provide context to the model. + **kwargs: Additional keyword arguments for future extensibility. + + Yields: + Formatted message chunks from the model. + + Raises: + LlamaCppContextOverflowError: When the context window is exceeded. + ModelThrottledException: When the llama.cpp server is overloaded. 
+ """ + try: + async for event in super().stream(messages, tool_specs, system_prompt, **kwargs): + yield event + except httpx.HTTPStatusError as e: + if e.response.status_code == 400: + # Parse error response + try: + error_data = e.response.json() + error_msg = str(error_data.get("error", {}).get("message", str(error_data))) + except (json.JSONDecodeError, KeyError, AttributeError): + error_msg = e.response.text + + # Check for context overflow + if any(term in error_msg.lower() for term in ["context", "kv cache", "slot"]): + raise LlamaCppContextOverflowError( + f"Context window exceeded: {error_msg}" + ) from e + elif e.response.status_code == 503: + raise ModelThrottledException( + "llama.cpp server is busy or overloaded" + ) from e + raise + except Exception as e: + # Handle other potential errors + error_msg = str(e).lower() + if "rate" in error_msg or "429" in str(e): + raise ModelThrottledException(str(e)) from e + raise + + @override + async def structured_output( + self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + """Get structured output using llama.cpp's native JSON schema support. + + This implementation uses llama.cpp's json_schema parameter to constrain + the model output to valid JSON matching the provided schema. + + Args: + output_model: The Pydantic model defining the expected output structure. + prompt: The prompt messages to use for generation. + system_prompt: System prompt to provide context to the model. + **kwargs: Additional keyword arguments for future extensibility. + + Yields: + Model events with the last being the structured output. + + Raises: + json.JSONDecodeError: If the model output is not valid JSON. + pydantic.ValidationError: If the output doesn't match the model schema. + """ + # Get the JSON schema from the Pydantic model + schema = output_model.model_json_schema() + + # Store current params to restore later + original_params = self.config.get("params", {}).copy() + + try: + # Configure for JSON output with schema constraint + if not self.config.get("params"): + self.config["params"] = {} + + self.config["params"]["json_schema"] = schema + self.config["params"]["cache_prompt"] = True # Cache schema processing + + # Collect the response + response_text = "" + async for event in self.stream(prompt, system_prompt=system_prompt, **kwargs): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response_text += delta["text"] + # Pass through other events + yield event + + # Parse and validate the JSON response + data = json.loads(response_text.strip()) + output_instance = output_model(**data) + yield {"output": output_instance} + + finally: + # Restore original params + self.config["params"] = original_params + + def _generate_pydantic_grammar(self, model: Type[BaseModel]) -> str: + """Generate a GBNF grammar from a Pydantic model. + + Args: + model: The Pydantic model to generate grammar for. + + Returns: + GBNF grammar string. + + Note: + This provides a basic JSON grammar. A future enhancement would + generate model-specific grammars based on the Pydantic schema. + """ + # Basic JSON grammar that works for most cases + return ''' +root ::= object +object ::= "{" pair ("," pair)* "}" +pair ::= string ":" value +string ::= "\\"" [^"]* "\\"" +value ::= string | number | boolean | null | array | object +array ::= "[" (value ("," value)*)? "]" +number ::= "-"? [0-9]+ ("." [0-9]+)? 
+boolean ::= "true" | "false" +null ::= "null" +''' diff --git a/tests/strands/models/test_llamacpp.py b/tests/strands/models/test_llamacpp.py new file mode 100644 index 000000000..829452b0a --- /dev/null +++ b/tests/strands/models/test_llamacpp.py @@ -0,0 +1,475 @@ +"""Unit tests for llama.cpp model provider.""" + +import json +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest +from openai import AsyncOpenAI +from pydantic import BaseModel + +from strands.types.content import ContentBlock, Message +from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException +from strands.models.llamacpp import LlamaCppModel, LlamaCppError, LlamaCppContextOverflowError + + +class TestLlamaCppModel: + """Test suite for LlamaCppModel.""" + + def test_init_default_config(self) -> None: + """Test initialization with default configuration.""" + model = LlamaCppModel() + + assert model.config["model_id"] == "default" + assert isinstance(model.client, AsyncOpenAI) + # Check that base_url was set correctly + assert model.client.base_url == "http://localhost:8080/v1/" + + def test_init_custom_config(self) -> None: + """Test initialization with custom configuration.""" + model = LlamaCppModel( + base_url="http://example.com:8081/v1", + model_id="llama-3-8b", + params={"temperature": 0.7, "max_tokens": 100}, + ) + + assert model.config["model_id"] == "llama-3-8b" + assert model.config["params"]["temperature"] == 0.7 + assert model.config["params"]["max_tokens"] == 100 + assert model.client.base_url == "http://example.com:8081/v1/" + + def test_format_request_basic(self) -> None: + """Test basic request formatting.""" + model = LlamaCppModel(model_id="test-model") + + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + ] + + request = model.format_request(messages) + + assert request["model"] == "test-model" + assert request["messages"][0]["role"] == "user" + # OpenAI format returns content as an array + assert request["messages"][0]["content"][0]["type"] == "text" + assert request["messages"][0]["content"][0]["text"] == "Hello" + assert request["stream"] is True + assert "extra_body" not in request # No llama.cpp params, so no extra_body + + def test_format_request_with_system_prompt(self) -> None: + """Test request formatting with system prompt.""" + model = LlamaCppModel() + + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + ] + + request = model.format_request(messages, system_prompt="You are a helpful assistant") + + assert request["messages"][0]["role"] == "system" + assert request["messages"][0]["content"] == "You are a helpful assistant" + assert request["messages"][1]["role"] == "user" + + def test_format_request_with_llamacpp_params(self) -> None: + """Test request formatting with llama.cpp specific parameters.""" + model = LlamaCppModel( + params={ + "temperature": 0.8, + "max_tokens": 50, + "repeat_penalty": 1.1, + "top_k": 40, + "min_p": 0.05, + "grammar": "root ::= 'yes' | 'no'", + } + ) + + messages = [ + {"role": "user", "content": [{"text": "Is the sky blue?"}]}, + ] + + request = model.format_request(messages) + + # Standard OpenAI params + assert request["temperature"] == 0.8 + assert request["max_tokens"] == 50 + + # llama.cpp specific params should be in extra_body + assert "extra_body" in request + assert request["extra_body"]["repeat_penalty"] == 1.1 + assert request["extra_body"]["top_k"] == 40 + assert request["extra_body"]["min_p"] == 0.05 + assert 
request["extra_body"]["grammar"] == "root ::= 'yes' | 'no'" + + def test_format_request_with_all_new_params(self) -> None: + """Test request formatting with all new llama.cpp parameters.""" + model = LlamaCppModel( + params={ + # OpenAI params + "temperature": 0.7, + "max_tokens": 100, + "top_p": 0.9, + "seed": 42, + # All llama.cpp specific params + "repeat_penalty": 1.1, + "top_k": 40, + "min_p": 0.05, + "typical_p": 0.95, + "tfs_z": 0.97, + "top_a": 0.1, + "mirostat": 2, + "mirostat_lr": 0.1, + "mirostat_ent": 5.0, + "grammar": "root ::= answer", + "json_schema": {"type": "object"}, + "penalty_last_n": 256, + "n_probs": 5, + "min_keep": 1, + "ignore_eos": False, + "logit_bias": {100: 5.0, 200: -5.0}, + "cache_prompt": True, + "slot_id": 1, + "samplers": ["top_k", "tfs_z", "typical_p"], + } + ) + + messages = [{"role": "user", "content": [{"text": "Test"}]}] + request = model.format_request(messages) + + # Check OpenAI params are in root + assert request["temperature"] == 0.7 + assert request["max_tokens"] == 100 + assert request["top_p"] == 0.9 + assert request["seed"] == 42 + + # Check all llama.cpp params are in extra_body + assert "extra_body" in request + extra = request["extra_body"] + assert extra["repeat_penalty"] == 1.1 + assert extra["top_k"] == 40 + assert extra["min_p"] == 0.05 + assert extra["typical_p"] == 0.95 + assert extra["tfs_z"] == 0.97 + assert extra["top_a"] == 0.1 + assert extra["mirostat"] == 2 + assert extra["mirostat_lr"] == 0.1 + assert extra["mirostat_ent"] == 5.0 + assert extra["grammar"] == "root ::= answer" + assert extra["json_schema"] == {"type": "object"} + assert extra["penalty_last_n"] == 256 + assert extra["n_probs"] == 5 + assert extra["min_keep"] == 1 + assert extra["ignore_eos"] == False + assert extra["logit_bias"] == {100: 5.0, 200: -5.0} + assert extra["cache_prompt"] == True + assert extra["slot_id"] == 1 + assert extra["samplers"] == ["top_k", "tfs_z", "typical_p"] + + def test_format_request_with_tools(self) -> None: + """Test request formatting with tool specifications.""" + model = LlamaCppModel() + + messages = [ + {"role": "user", "content": [{"text": "What's the weather?"}]}, + ] + + tool_specs = [ + { + "name": "get_weather", + "description": "Get current weather", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "location": {"type": "string"}, + }, + "required": ["location"], + } + }, + } + ] + + request = model.format_request(messages, tool_specs=tool_specs) + + assert "tools" in request + assert len(request["tools"]) == 1 + assert request["tools"][0]["function"]["name"] == "get_weather" + + def test_update_config(self) -> None: + """Test configuration update.""" + model = LlamaCppModel(model_id="initial-model") + + assert model.config["model_id"] == "initial-model" + + model.update_config(model_id="updated-model", params={"temperature": 0.5}) + + assert model.config["model_id"] == "updated-model" + assert model.config["params"]["temperature"] == 0.5 + + def test_get_config(self) -> None: + """Test configuration retrieval.""" + config = { + "model_id": "test-model", + "params": {"temperature": 0.9}, + } + model = LlamaCppModel(**config) + + retrieved_config = model.get_config() + + assert retrieved_config["model_id"] == "test-model" + assert retrieved_config["params"]["temperature"] == 0.9 + + @pytest.mark.asyncio + async def test_stream_basic(self) -> None: + """Test basic streaming functionality.""" + model = LlamaCppModel() + + # Create properly structured mock events + class MockDelta: + content = None + 
tool_calls = None + def __init__(self, content=None): + self.content = content + + class MockChoice: + def __init__(self, content=None, finish_reason=None): + self.delta = MockDelta(content) + self.finish_reason = finish_reason + + class MockChunk: + def __init__(self, choices, usage=None): + self.choices = choices + self.usage = usage + + mock_chunks = [ + MockChunk([MockChoice(content="Hello")]), + MockChunk( + [MockChoice(content=" world", finish_reason="stop")], + usage=MagicMock(prompt_tokens=10, completion_tokens=5, total_tokens=15) + ), + ] + + # Create async iterator + async def mock_stream(): + for chunk in mock_chunks: + yield chunk + + # Mock the create method to return a coroutine that returns the async iterator + async def mock_create(*args, **kwargs): + return mock_stream() + + with patch.object(model.client.chat.completions, "create", side_effect=mock_create): + + messages = [{"role": "user", "content": [{"text": "Hi"}]}] + + chunks = [] + async for chunk in model.stream(messages): + chunks.append(chunk) + + # Verify we got the expected chunks + assert any("messageStart" in chunk for chunk in chunks) + assert any("contentBlockDelta" in chunk and chunk["contentBlockDelta"]["delta"]["text"] == "Hello" for chunk in chunks) + assert any("contentBlockDelta" in chunk and chunk["contentBlockDelta"]["delta"]["text"] == " world" for chunk in chunks) + assert any("messageStop" in chunk for chunk in chunks) + assert any("metadata" in chunk for chunk in chunks) + + @pytest.mark.asyncio + async def test_structured_output(self) -> None: + """Test structured output functionality using the enhanced implementation.""" + + class TestOutput(BaseModel): + answer: str + confidence: float + + model = LlamaCppModel() + + # Mock successful JSON response using the new structured_output implementation + mock_response_text = '{"answer": "yes", "confidence": 0.95}' + + # Create mock stream that returns JSON + async def mock_stream(*args, **kwargs): + # Verify json_schema was set + assert "json_schema" in model.config.get("params", {}) + + yield {"messageStart": {"role": "assistant"}} + yield {"contentBlockStart": {"start": {}}} + yield {"contentBlockDelta": {"delta": {"text": mock_response_text}}} + yield {"contentBlockStop": {}} + yield {"messageStop": {"stopReason": "end_turn"}} + + with patch.object(model, "stream", side_effect=mock_stream): + messages = [{"role": "user", "content": [{"text": "Is the earth round?"}]}] + + events = [] + async for event in model.structured_output(TestOutput, messages): + events.append(event) + + # Check we got the output + output_event = next((e for e in events if "output" in e), None) + assert output_event is not None + assert output_event["output"].answer == "yes" + assert output_event["output"].confidence == 0.95 + + def test_timeout_configuration(self) -> None: + """Test timeout configuration.""" + model = LlamaCppModel(timeout=30.0) + + # The timeout should be passed to the OpenAI client + assert model.client.timeout == 30.0 + + # Test with tuple timeout + model2 = LlamaCppModel(timeout=(10.0, 60.0)) + assert model2.client.timeout == (10.0, 60.0) + + def test_max_retries_configuration(self) -> None: + """Test max retries configuration.""" + model = LlamaCppModel(max_retries=5) + + # The max_retries should be passed to the OpenAI client + assert model.client.max_retries == 5 + + def test_use_grammar_constraint(self) -> None: + """Test grammar constraint method.""" + model = LlamaCppModel() + + # Apply grammar constraint + grammar = ''' + root ::= answer + answer 
::= "yes" | "no" + ''' + model.use_grammar_constraint(grammar) + + assert model.config["params"]["grammar"] == grammar + + # Update grammar + new_grammar = 'root ::= [0-9]+' + model.use_grammar_constraint(new_grammar) + + assert model.config["params"]["grammar"] == new_grammar + + def test_use_json_schema(self) -> None: + """Test JSON schema constraint method.""" + model = LlamaCppModel() + + # Apply JSON schema + schema = { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"} + }, + "required": ["name", "age"] + } + model.use_json_schema(schema) + + assert model.config["params"]["json_schema"] == schema + + @pytest.mark.asyncio + async def test_stream_with_context_overflow_error(self) -> None: + """Test stream handling of context overflow errors.""" + model = LlamaCppModel() + + # Create HTTP error response + error_response = httpx.Response( + status_code=400, + json={"error": {"message": "Context window exceeded. Max context length is 4096 tokens"}}, + request=httpx.Request("POST", "http://test.com") + ) + error = httpx.HTTPStatusError("Bad Request", request=error_response.request, response=error_response) + + # Mock the parent stream to raise the error + with patch.object(model.client.chat.completions, "create", side_effect=error): + messages = [{"role": "user", "content": [{"text": "Very long message"}]}] + + with pytest.raises(LlamaCppContextOverflowError) as exc_info: + async for _ in model.stream(messages): + pass + + assert "Context window exceeded" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_stream_with_server_overload_error(self) -> None: + """Test stream handling of server overload errors.""" + model = LlamaCppModel() + + # Create HTTP error response for 503 + error_response = httpx.Response( + status_code=503, + text="Server is busy", + request=httpx.Request("POST", "http://test.com") + ) + error = httpx.HTTPStatusError("Service Unavailable", request=error_response.request, response=error_response) + + # Mock the parent stream to raise the error + with patch.object(model.client.chat.completions, "create", side_effect=error): + messages = [{"role": "user", "content": [{"text": "Test"}]}] + + with pytest.raises(ModelThrottledException) as exc_info: + async for _ in model.stream(messages): + pass + + assert "server is busy or overloaded" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_structured_output_with_json_schema(self) -> None: + """Test structured output using JSON schema.""" + + class TestOutput(BaseModel): + answer: str + confidence: float + + model = LlamaCppModel() + + # Mock successful JSON response + mock_response_text = '{"answer": "yes", "confidence": 0.95}' + + # Create mock stream that returns JSON + async def mock_stream(*args, **kwargs): + # Check that json_schema was set correctly + assert model.config["params"]["json_schema"] == TestOutput.model_json_schema() + + yield {"messageStart": {"role": "assistant"}} + yield {"contentBlockStart": {"start": {}}} + yield {"contentBlockDelta": {"delta": {"text": mock_response_text}}} + yield {"contentBlockStop": {}} + yield {"messageStop": {"stopReason": "end_turn"}} + + with patch.object(model, "stream", side_effect=mock_stream): + messages = [{"role": "user", "content": [{"text": "Is the earth round?"}]}] + + events = [] + async for event in model.structured_output(TestOutput, messages): + events.append(event) + + # Check we got the output + output_event = next((e for e in events if "output" in e), None) + assert output_event is not None + 
assert output_event["output"].answer == "yes" + assert output_event["output"].confidence == 0.95 + + @pytest.mark.asyncio + async def test_structured_output_invalid_json_error(self) -> None: + """Test structured output raises error for invalid JSON.""" + + class TestOutput(BaseModel): + value: int + + model = LlamaCppModel() + + # Mock stream that returns invalid JSON + async def mock_stream(*args, **kwargs): + # Check that json_schema was set correctly + assert model.config["params"]["json_schema"] == TestOutput.model_json_schema() + + yield {"messageStart": {"role": "assistant"}} + yield {"contentBlockStart": {"start": {}}} + yield {"contentBlockDelta": {"delta": {"text": "This is not valid JSON"}}} + yield {"contentBlockStop": {}} + yield {"messageStop": {"stopReason": "end_turn"}} + + with patch.object(model, "stream", side_effect=mock_stream): + messages = [{"role": "user", "content": [{"text": "Give me a number"}]}] + + with pytest.raises(json.JSONDecodeError): + async for event in model.structured_output(TestOutput, messages): + pass \ No newline at end of file diff --git a/tests_integ/models/test_model_llamacpp.py b/tests_integ/models/test_model_llamacpp.py new file mode 100644 index 000000000..d7ecd9195 --- /dev/null +++ b/tests_integ/models/test_model_llamacpp.py @@ -0,0 +1,495 @@ +"""Integration tests for llama.cpp model provider. + +These tests require a running llama.cpp server instance. +To run these tests: +1. Start llama.cpp server: llama-server -m model.gguf --host 0.0.0.0 --port 8080 +2. Run: pytest tests_integ/models/test_model_llamacpp.py + +Set LLAMACPP_TEST_URL environment variable to use a different server URL. +""" + +import os +from typing import Any, AsyncGenerator + +import pytest +from pydantic import BaseModel + +from strands.models.llamacpp import LlamaCppModel +from strands.types.content import Message + + +# Get server URL from environment or use default +LLAMACPP_URL = os.environ.get("LLAMACPP_TEST_URL", "http://localhost:8080/v1") + +# Skip these tests if LLAMACPP_SKIP_TESTS is set +pytestmark = pytest.mark.skipif( + os.environ.get("LLAMACPP_SKIP_TESTS", "true").lower() == "true", + reason="llama.cpp integration tests disabled (set LLAMACPP_SKIP_TESTS=false to enable)", +) + + +class WeatherOutput(BaseModel): + """Test output model for structured responses.""" + + temperature: float + condition: str + location: str + + +@pytest.fixture +async def llamacpp_model() -> LlamaCppModel: + """Fixture to create a llama.cpp model instance.""" + return LlamaCppModel(base_url=LLAMACPP_URL) + + +class TestLlamaCppModelIntegration: + """Integration tests for LlamaCppModel with a real server.""" + + @pytest.mark.asyncio + async def test_basic_completion(self, llamacpp_model: LlamaCppModel) -> None: + """Test basic text completion.""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Say 'Hello, World!' and nothing else."}]}, + ] + + response_text = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response_text += delta["text"] + + assert "Hello, World!" in response_text + + @pytest.mark.asyncio + async def test_system_prompt(self, llamacpp_model: LlamaCppModel) -> None: + """Test completion with system prompt.""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Who are you?"}]}, + ] + + system_prompt = "You are a helpful AI assistant named Claude." 
+ + response_text = "" + async for event in llamacpp_model.stream(messages, system_prompt=system_prompt): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response_text += delta["text"] + + # Response should reflect the system prompt + assert len(response_text) > 0 + assert "assistant" in response_text.lower() or "claude" in response_text.lower() + + @pytest.mark.asyncio + async def test_streaming_chunks(self, llamacpp_model: LlamaCppModel) -> None: + """Test that streaming returns proper chunk sequence.""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Count from 1 to 3."}]}, + ] + + chunk_types = [] + async for event in llamacpp_model.stream(messages): + chunk_types.append(next(iter(event.keys()))) + + # Verify proper chunk sequence + assert chunk_types[0] == "messageStart" + assert chunk_types[1] == "contentBlockStart" + assert "contentBlockDelta" in chunk_types + assert chunk_types[-3] == "contentBlockStop" + assert chunk_types[-2] == "messageStop" + assert chunk_types[-1] == "metadata" + + @pytest.mark.asyncio + async def test_temperature_parameter(self, llamacpp_model: LlamaCppModel) -> None: + """Test temperature parameter affects randomness.""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Generate a random word."}]}, + ] + + # Low temperature should give more consistent results + llamacpp_model.update_config(params={"temperature": 0.1, "seed": 42}) + + response1 = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response1 += delta["text"] + + # Same seed and low temperature should give similar result + response2 = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response2 += delta["text"] + + # With low temperature and same seed, responses should be very similar + assert len(response1) > 0 + assert len(response2) > 0 + + @pytest.mark.asyncio + async def test_max_tokens_limit(self, llamacpp_model: LlamaCppModel) -> None: + """Test max_tokens parameter limits response length.""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Tell me a very long story about dragons."}]}, + ] + + # Set very low token limit + llamacpp_model.update_config(params={"max_tokens": 10}) + + token_count = 0 + async for event in llamacpp_model.stream(messages): + if "metadata" in event: + usage = event["metadata"]["usage"] + token_count = usage["outputTokens"] + if "messageStop" in event: + stop_reason = event["messageStop"]["stopReason"] + + # Should stop due to max_tokens + assert token_count <= 15 # Allow small overage due to tokenization + assert stop_reason == "max_tokens" + + @pytest.mark.asyncio + async def test_structured_output(self, llamacpp_model: LlamaCppModel) -> None: + """Test structured output generation.""" + messages: list[Message] = [ + { + "role": "user", + "content": [ + { + "text": "What's the weather like in Paris? " + "Respond with temperature in Celsius, condition, and location." 
+ } + ], + }, + ] + + # Enable JSON response format for structured output + llamacpp_model.update_config(params={"response_format": {"type": "json_object"}}) + + result = None + async for event in llamacpp_model.structured_output(WeatherOutput, messages): + if "output" in event: + result = event["output"] + + assert result is not None + assert isinstance(result, WeatherOutput) + assert isinstance(result.temperature, float) + assert isinstance(result.condition, str) + assert result.location.lower() == "paris" + + @pytest.mark.asyncio + async def test_llamacpp_specific_params(self, llamacpp_model: LlamaCppModel) -> None: + """Test llama.cpp specific parameters.""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Say 'test' five times."}]}, + ] + + # Use llama.cpp specific parameters + llamacpp_model.update_config( + params={ + "repeat_penalty": 1.5, # Penalize repetition + "top_k": 10, # Limit vocabulary + "min_p": 0.1, # Min-p sampling + } + ) + + response_text = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response_text += delta["text"] + + # Response should contain "test" but with repetition penalty it might vary + assert "test" in response_text.lower() + + @pytest.mark.asyncio + async def test_advanced_sampling_params(self, llamacpp_model: LlamaCppModel) -> None: + """Test advanced sampling parameters.""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Generate a random sentence about space."}]}, + ] + + # Test advanced sampling parameters + llamacpp_model.update_config( + params={ + "temperature": 0.8, + "tfs_z": 0.95, # Tail-free sampling + "top_a": 0.1, # Top-a sampling + "typical_p": 0.9, # Typical-p sampling + "penalty_last_n": 64, # Penalty context window + "min_keep": 1, # Minimum tokens to keep + "samplers": ["top_k", "tfs_z", "typical_p", "top_p", "min_p", "temperature"] + } + ) + + response_text = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response_text += delta["text"] + + # Should generate something about space + assert len(response_text) > 0 + assert any(word in response_text.lower() for word in ["space", "star", "planet", "galaxy", "universe"]) + + @pytest.mark.asyncio + async def test_mirostat_sampling(self, llamacpp_model: LlamaCppModel) -> None: + """Test Mirostat sampling modes.""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Write a short poem."}]}, + ] + + # Test Mirostat v2 + llamacpp_model.update_config( + params={ + "mirostat": 2, + "mirostat_lr": 0.1, + "mirostat_ent": 5.0, + "seed": 42, # For reproducibility + } + ) + + response_text = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response_text += delta["text"] + + # Should generate a poem + assert len(response_text) > 20 + assert "\n" in response_text # Poems typically have line breaks + + @pytest.mark.asyncio + async def test_grammar_constraint(self, llamacpp_model: LlamaCppModel) -> None: + """Test grammar constraint feature (llama.cpp specific).""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Is the sky blue? 
Answer yes or no."}]}, + ] + + # Use the new grammar constraint method + grammar = """ + root ::= answer + answer ::= "yes" | "no" + """ + llamacpp_model.use_grammar_constraint(grammar) + + response_text = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response_text += delta["text"] + + # Response should be exactly "yes" or "no" + assert response_text.strip().lower() in ["yes", "no"] + + @pytest.mark.asyncio + async def test_json_schema_constraint(self, llamacpp_model: LlamaCppModel) -> None: + """Test JSON schema constraint feature.""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Describe the weather in JSON format with temperature and description."}]}, + ] + + # Use JSON schema constraint + schema = { + "type": "object", + "properties": { + "temperature": {"type": "number"}, + "description": {"type": "string"} + }, + "required": ["temperature", "description"] + } + llamacpp_model.use_json_schema(schema) + + response_text = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response_text += delta["text"] + + # Should be valid JSON matching the schema + import json + data = json.loads(response_text.strip()) + assert "temperature" in data + assert "description" in data + assert isinstance(data["temperature"], (int, float)) + assert isinstance(data["description"], str) + + @pytest.mark.asyncio + async def test_logit_bias(self, llamacpp_model: LlamaCppModel) -> None: + """Test logit bias feature.""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Choose between 'cat' and 'dog'."}]}, + ] + + # This is a simplified test - in reality you'd need to know the actual token IDs + # for "cat" and "dog" in the model's vocabulary + llamacpp_model.update_config( + params={ + "logit_bias": { + # These are placeholder token IDs - real implementation would need actual token IDs + 1234: 10.0, # Strong positive bias (hypothetical "cat" token) + 5678: -10.0, # Strong negative bias (hypothetical "dog" token) + }, + "seed": 42, # For reproducibility + } + ) + + response_text = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response_text += delta["text"] + + # Should generate text (exact behavior depends on actual token IDs) + assert len(response_text) > 0 + + @pytest.mark.asyncio + async def test_cache_prompt(self, llamacpp_model: LlamaCppModel) -> None: + """Test prompt caching feature.""" + messages: list[Message] = [ + {"role": "system", "content": [{"text": "You are a helpful assistant. Always be concise."}]}, + {"role": "user", "content": [{"text": "What is 2+2?"}]}, + ] + + # Enable prompt caching + llamacpp_model.update_config( + params={ + "cache_prompt": True, + "slot_id": 0, # Use specific slot for caching + } + ) + + # First request + response1 = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response1 += delta["text"] + + # Second request with same system prompt should use cache + messages2 = [ + {"role": "system", "content": [{"text": "You are a helpful assistant. 
Always be concise."}]}, + {"role": "user", "content": [{"text": "What is 3+3?"}]}, + ] + + response2 = "" + async for event in llamacpp_model.stream(messages2): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response2 += delta["text"] + + # Both should give valid responses + assert "4" in response1 + assert "6" in response2 + + @pytest.mark.asyncio + async def test_concurrent_requests(self, llamacpp_model: LlamaCppModel) -> None: + """Test handling multiple concurrent requests.""" + import asyncio + + async def make_request(prompt: str) -> str: + messages: list[Message] = [ + {"role": "user", "content": [{"text": prompt}]}, + ] + + response = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response += delta["text"] + return response + + # Make concurrent requests + prompts = [ + "Say 'one'", + "Say 'two'", + "Say 'three'", + ] + + responses = await asyncio.gather(*[make_request(p) for p in prompts]) + + # Each response should contain the expected number + assert "one" in responses[0].lower() + assert "two" in responses[1].lower() + assert "three" in responses[2].lower() + + @pytest.mark.asyncio + async def test_enhanced_structured_output(self, llamacpp_model: LlamaCppModel) -> None: + """Test enhanced structured output with native JSON schema support.""" + + class BookInfo(BaseModel): + title: str + author: str + year: int + genres: list[str] + + messages: list[Message] = [ + { + "role": "user", + "content": [ + { + "text": "Create information about a fictional science fiction book. " + "Include title, author, publication year, and 2-3 genres." + } + ], + }, + ] + + result = None + events = [] + async for event in llamacpp_model.structured_output(BookInfo, messages): + events.append(event) + if "output" in event: + result = event["output"] + + # Verify we got structured output + assert result is not None + assert isinstance(result, BookInfo) + assert isinstance(result.title, str) and len(result.title) > 0 + assert isinstance(result.author, str) and len(result.author) > 0 + assert isinstance(result.year, int) and 1900 <= result.year <= 2100 + assert isinstance(result.genres, list) and len(result.genres) >= 2 + assert all(isinstance(genre, str) for genre in result.genres) + + # Should have streamed events before the output + assert len(events) > 1 + + @pytest.mark.asyncio + async def test_context_overflow_handling(self, llamacpp_model: LlamaCppModel) -> None: + """Test proper handling of context window overflow.""" + # Create a very long message that might exceed context + long_text = "This is a test sentence. 
" * 1000 + messages: list[Message] = [ + {"role": "user", "content": [{"text": f"Summarize this text: {long_text}"}]}, + ] + + try: + response_text = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response_text += delta["text"] + + # If it succeeds, we got a response + assert len(response_text) > 0 + except Exception as e: + # If it fails, it should be our custom error + from strands.models.llamacpp import LlamaCppContextOverflowError + if isinstance(e, LlamaCppContextOverflowError): + assert "context" in str(e).lower() + else: + # Some other error - re-raise to see what it was + raise \ No newline at end of file From 71e80d25653fd6af34d27c9f2dccac3dba4f6499 Mon Sep 17 00:00:00 2001 From: Aaron Brown Date: Thu, 31 Jul 2025 10:48:04 -0500 Subject: [PATCH 02/11] Updated llamacpp integration to better align with ollama implementation; --- src/strands/models/llamacpp.py | 565 ++++++++++++++++++++++---- tests/strands/models/test_llamacpp.py | 517 +++++++++++++++-------- 2 files changed, 824 insertions(+), 258 deletions(-) diff --git a/src/strands/models/llamacpp.py b/src/strands/models/llamacpp.py index 59cb64656..5a0f5340e 100644 --- a/src/strands/models/llamacpp.py +++ b/src/strands/models/llamacpp.py @@ -1,22 +1,29 @@ """llama.cpp model provider. +Provides integration with llama.cpp servers running in OpenAI-compatible mode, +with support for advanced llama.cpp-specific features. + - Docs: https://github.com/ggml-org/llama.cpp - Server docs: https://github.com/ggml-org/llama.cpp/tree/master/tools/server +- OpenAI API compatibility: + https://github.com/ggml-org/llama.cpp/blob/master/tools/server/README.md#api-endpoints """ +import base64 import json import logging +import mimetypes from typing import Any, AsyncGenerator, Optional, Type, TypedDict, TypeVar, Union import httpx from pydantic import BaseModel from typing_extensions import Unpack, override -from ..types.content import Messages +from ..types.content import ContentBlock, Messages from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolSpec -from .openai import OpenAIModel +from .model import Model logger = logging.getLogger(__name__) @@ -24,28 +31,38 @@ class LlamaCppError(Exception): - """Base exception for llama.cpp specific errors.""" - pass + """Base exception for llama.cpp specific errors. + + This exception serves as the base class for all llama.cpp-specific errors, + allowing for targeted error handling in client code. + """ class LlamaCppContextOverflowError(LlamaCppError, ContextWindowOverflowException): - """Raised when context window is exceeded in llama.cpp.""" - pass + """Raised when context window is exceeded in llama.cpp. + + This error occurs when the combined input and output tokens exceed + the model's context window size. Common causes include: + - Long input prompts + - Extended conversations + - Large system prompts or tool definitions + """ -class LlamaCppModel(OpenAIModel): +class LlamaCppModel(Model): """llama.cpp model provider implementation. Connects to a llama.cpp server running in OpenAI-compatible mode with support for advanced llama.cpp-specific features like grammar constraints, - Mirostat sampling, and native JSON schema validation. + Mirostat sampling, native JSON schema validation, and native multimodal + support for audio and image content. 
The llama.cpp server must be started with the OpenAI-compatible API enabled: llama-server -m model.gguf --host 0.0.0.0 --port 8080 Example: Basic usage: - >>> model = LlamaCppModel(base_url="http://localhost:8080/v1") + >>> model = LlamaCppModel(base_url="http://localhost:8080") >>> model.update_config(params={"temperature": 0.7, "top_k": 40}) Grammar constraints: @@ -61,6 +78,21 @@ class LlamaCppModel(OpenAIModel): ... "tfs_z": 0.95, ... "repeat_penalty": 1.1 ... }) + + Multimodal usage (requires multimodal model like Qwen2.5-Omni): + >>> # Audio analysis + >>> audio_content = [{ + ... "audio": {"source": {"bytes": audio_bytes}, "format": "wav"}, + ... "text": "What do you hear in this audio?" + ... }] + >>> response = agent(audio_content) + + >>> # Image analysis + >>> image_content = [{ + ... "image": {"source": {"bytes": image_bytes}, "format": "png"}, + ... "text": "Describe this image" + ... }] + >>> response = agent(image_content) """ class LlamaCppConfig(TypedDict, total=False): @@ -110,37 +142,31 @@ class LlamaCppConfig(TypedDict, total=False): def __init__( self, - base_url: str = "http://localhost:8080/v1", - api_key: Optional[str] = None, + base_url: str = "http://localhost:8080", timeout: Optional[Union[float, tuple[float, float]]] = None, - max_retries: Optional[int] = None, **model_config: Unpack[LlamaCppConfig], ) -> None: """Initialize llama.cpp provider instance. Args: base_url: Base URL for the llama.cpp server. - Default is "http://localhost:8080/v1" for local server. - api_key: Optional API key if the llama.cpp server requires authentication. - timeout: Request timeout in seconds. Can be a float or tuple of (connect, read) timeouts. - max_retries: Maximum number of retries for failed requests. + Default is "http://localhost:8080" for local server. + timeout: Request timeout in seconds. Can be float or tuple of + (connect, read) timeouts. **model_config: Configuration options for the llama.cpp model. """ # Set default model_id if not provided if "model_id" not in model_config: model_config["model_id"] = "default" - # Build OpenAI client args - client_args = { - "base_url": base_url, - "api_key": api_key or "dummy", # OpenAI client requires some API key - } - - if timeout is not None: - client_args["timeout"] = timeout + self.base_url = base_url.rstrip("/") + self.config = dict(model_config) - if max_retries is not None: - client_args["max_retries"] = max_retries + # Configure HTTP client + self.client = httpx.AsyncClient( + base_url=self.base_url, + timeout=timeout or 30.0, + ) logger.debug( "base_url=<%s>, model_id=<%s> | initializing llama.cpp provider", @@ -148,11 +174,28 @@ def __init__( model_config.get("model_id"), ) - # Initialize parent OpenAI model with our client args - super().__init__(client_args=client_args, **model_config) + @override + def update_config( + self, **model_config: Unpack[LlamaCppConfig] + ) -> None: # type: ignore[override] + """Update the llama.cpp model configuration with provided arguments. + + Args: + **model_config: Configuration overrides. + """ + self.config.update(model_config) + + @override + def get_config(self) -> LlamaCppConfig: + """Get the llama.cpp model configuration. + + Returns: + The llama.cpp model configuration. + """ + return self.config # type: ignore[return-value] def use_grammar_constraint(self, grammar: str) -> None: - """Apply a GBNF grammar constraint to the generation. + r"""Apply a GBNF grammar constraint to the generation. Args: grammar: GBNF (Backus-Naur Form) grammar string defining allowed outputs. 
@@ -170,7 +213,7 @@ def use_grammar_constraint(self, grammar: str) -> None: ... root ::= object ... object ::= "{" pair ("," pair)* "}" ... pair ::= string ":" value - ... string ::= "\\"" [^"]* "\\"" + ... string ::= "\"" [^"]* "\"" ... value ::= string | number | "true" | "false" | "null" ... number ::= "-"? [0-9]+ ("." [0-9]+)? ... ''') @@ -178,7 +221,10 @@ def use_grammar_constraint(self, grammar: str) -> None: if not self.config.get("params"): self.config["params"] = {} self.config["params"]["grammar"] = grammar - logger.debug("Applied grammar constraint") + logger.debug( + "grammar=<%s> | applied grammar constraint", + grammar[:50] + "..." if len(grammar) > 50 else grammar, + ) def use_json_schema(self, schema: dict[str, Any]) -> None: """Apply a JSON schema constraint for structured output. @@ -199,10 +245,162 @@ def use_json_schema(self, schema: dict[str, Any]) -> None: if not self.config.get("params"): self.config["params"] = {} self.config["params"]["json_schema"] = schema - logger.debug("Applied JSON schema constraint") + logger.debug("schema=<%s> | applied JSON schema constraint", schema) - @override - def format_request( + def _format_message_content(self, content: ContentBlock) -> dict[str, Any]: + """Format a content block for llama.cpp. + + Args: + content: Message content. + + Returns: + llama.cpp compatible content block. + + Raises: + TypeError: If the content block type cannot be converted to a compatible format. + """ + if "document" in content: + mime_type = mimetypes.types_map.get( + f".{content['document']['format']}", "application/octet-stream" + ) + file_data = base64.b64encode(content["document"]["source"]["bytes"]).decode( + "utf-8" + ) + return { + "file": { + "file_data": f"data:{mime_type};base64,{file_data}", + "filename": content["document"]["name"], + }, + "type": "file", + } + + if "image" in content: + mime_type = mimetypes.types_map.get( + f".{content['image']['format']}", "application/octet-stream" + ) + image_data = base64.b64encode(content["image"]["source"]["bytes"]).decode( + "utf-8" + ) + return { + "image_url": { + "detail": "auto", + "format": mime_type, + "url": f"data:{mime_type};base64,{image_data}", + }, + "type": "image_url", + } + + if "audio" in content: + audio_data = base64.b64encode(content["audio"]["source"]["bytes"]).decode( + "utf-8" + ) + audio_format = content["audio"].get("format", "wav") + return { + "type": "input_audio", + "input_audio": {"data": audio_data, "format": audio_format}, + } + + if "text" in content: + return {"text": content["text"], "type": "text"} + + raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") + + def _format_tool_call(self, tool_use: dict[str, Any]) -> dict[str, Any]: + """Format a tool call for llama.cpp. + + Args: + tool_use: Tool use requested by the model. + + Returns: + llama.cpp compatible tool call. + """ + return { + "function": { + "arguments": json.dumps(tool_use["input"]), + "name": tool_use["name"], + }, + "id": tool_use["toolUseId"], + "type": "function", + } + + def _format_tool_message(self, tool_result: dict[str, Any]) -> dict[str, Any]: + """Format a tool message for llama.cpp. + + Args: + tool_result: Tool result collected from a tool execution. + + Returns: + llama.cpp compatible tool message. 
+ """ + contents = [ + {"text": json.dumps(content["json"])} if "json" in content else content + for content in tool_result["content"] + ] + + return { + "role": "tool", + "tool_call_id": tool_result["toolUseId"], + "content": [self._format_message_content(content) for content in contents], + } + + def _format_messages( + self, messages: Messages, system_prompt: Optional[str] = None + ) -> list[dict[str, Any]]: + """Format messages for llama.cpp. + + Args: + messages: List of message objects to be processed. + system_prompt: System prompt to provide context to the model. + + Returns: + Formatted messages array compatible with llama.cpp. + """ + formatted_messages = [] + + # Add system prompt if provided + if system_prompt: + formatted_messages.append({"role": "system", "content": system_prompt}) + + for message in messages: + contents = message["content"] + + formatted_contents = [ + self._format_message_content(content) + for content in contents + if not any( + block_type in content for block_type in ["toolResult", "toolUse"] + ) + ] + formatted_tool_calls = [ + self._format_tool_call(content["toolUse"]) + for content in contents + if "toolUse" in content + ] + formatted_tool_messages = [ + self._format_tool_message(content["toolResult"]) + for content in contents + if "toolResult" in content + ] + + formatted_message = { + "role": message["role"], + "content": formatted_contents, + **( + {} + if not formatted_tool_calls + else {"tool_calls": formatted_tool_calls} + ), + } + formatted_messages.append(formatted_message) + formatted_messages.extend(formatted_tool_messages) + + return [ + message + for message in formatted_messages + if message["content"] or "tool_calls" in message + ] + + def _format_request( self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, @@ -210,8 +408,6 @@ def format_request( ) -> dict[str, Any]: """Format a request for the llama.cpp server. - This method overrides the OpenAI format to properly handle llama.cpp-specific parameters. - Args: messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. @@ -220,12 +416,9 @@ def format_request( Returns: A request formatted for llama.cpp server's OpenAI-compatible API. """ - # Build base request structure without calling super() to avoid - # parameter conflicts between OpenAI and llama.cpp specific params. - # This allows us to properly separate parameters into the appropriate - # request fields (direct vs extra_body). + # Separate OpenAI-compatible and llama.cpp-specific parameters request = { - "messages": self.format_request_messages(messages, system_prompt), + "messages": self._format_messages(messages, system_prompt), "model": self.config["model_id"], "stream": True, "stream_options": {"include_usage": True}, @@ -246,7 +439,7 @@ def format_request( if self.config.get("params"): params = self.config["params"] - # Define llama.cpp-specific parameters that need special handling + # llama.cpp-specific parameters that must be passed via extra_body llamacpp_specific_params = { "repeat_penalty", "top_k", @@ -269,7 +462,7 @@ def format_request( "samplers", } - # Standard OpenAI parameters that go directly in request + # Standard OpenAI parameters that go directly in the request openai_params = { "temperature", "max_tokens", @@ -301,6 +494,84 @@ def format_request( return request + def _format_chunk(self, event: dict[str, Any]) -> StreamEvent: + """Format a llama.cpp response event into a standardized message chunk. 
+ + Args: + event: A response event from the llama.cpp server. + + Returns: + The formatted chunk. + + Raises: + RuntimeError: If chunk_type is not recognized. + """ + match event["chunk_type"]: + case "message_start": + return {"messageStart": {"role": "assistant"}} + + case "content_start": + if event["data_type"] == "tool": + return { + "contentBlockStart": { + "start": { + "toolUse": { + "name": event["data"].function.name, + "toolUseId": event["data"].id, + } + } + } + } + return {"contentBlockStart": {"start": {}}} + + case "content_delta": + if event["data_type"] == "tool": + return { + "contentBlockDelta": { + "delta": { + "toolUse": { + "input": event["data"].function.arguments or "" + } + } + } + } + if event["data_type"] == "reasoning_content": + return { + "contentBlockDelta": { + "delta": {"reasoningContent": {"text": event["data"]}} + } + } + return {"contentBlockDelta": {"delta": {"text": event["data"]}}} + + case "content_stop": + return {"contentBlockStop": {}} + + case "message_stop": + match event["data"]: + case "tool_calls": + return {"messageStop": {"stopReason": "tool_use"}} + case "length": + return {"messageStop": {"stopReason": "max_tokens"}} + case _: + return {"messageStop": {"stopReason": "end_turn"}} + + case "metadata": + return { + "metadata": { + "usage": { + "inputTokens": event["data"].prompt_tokens, + "outputTokens": event["data"].completion_tokens, + "totalTokens": event["data"].total_tokens, + }, + "metrics": { + "latencyMs": 0, # TODO: Add actual latency calculation + }, + }, + } + + case _: + raise RuntimeError(f"chunk_type=<{event['chunk_type']}> | unknown type") + @override async def stream( self, @@ -311,8 +582,6 @@ async def stream( ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the llama.cpp model. - This method extends the OpenAI stream to handle llama.cpp-specific errors. - Args: messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. @@ -327,19 +596,169 @@ async def stream( ModelThrottledException: When the llama.cpp server is overloaded. 
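+            LlamaCppContextOverflowError: When the prompt exceeds the llama.cpp server's context window.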
""" try: - async for event in super().stream(messages, tool_specs, system_prompt, **kwargs): - yield event + logger.debug("formatting request for llama.cpp server") + request = self._format_request(messages, tool_specs, system_prompt) + logger.debug("request=<%s>", request) + + logger.debug("sending request to llama.cpp server") + response = await self.client.post("/v1/chat/completions", json=request) + response.raise_for_status() + + logger.debug("processing streaming response") + yield self._format_chunk({"chunk_type": "message_start"}) + yield self._format_chunk( + {"chunk_type": "content_start", "data_type": "text"} + ) + + tool_calls = {} + usage_data = None + + async for line in response.aiter_lines(): + if not line.strip() or not line.startswith("data: "): + continue + + data_content = line[6:] # Remove "data: " prefix + if data_content.strip() == "[DONE]": + break + + try: + event = json.loads(data_content) + except json.JSONDecodeError: + continue + + # Handle usage information + if "usage" in event: + usage_data = event["usage"] + continue + + if not event.get("choices"): + continue + + choice = event["choices"][0] + delta = choice.get("delta", {}) + + # Handle content deltas + if "content" in delta and delta["content"]: + yield self._format_chunk( + { + "chunk_type": "content_delta", + "data_type": "text", + "data": delta["content"], + } + ) + + # Handle tool calls + if "tool_calls" in delta: + for tool_call in delta["tool_calls"]: + index = tool_call["index"] + if index not in tool_calls: + tool_calls[index] = [] + tool_calls[index].append(tool_call) + + # Check for finish reason + if choice.get("finish_reason"): + break + + yield self._format_chunk({"chunk_type": "content_stop"}) + + # Process tool calls + for tool_deltas in tool_calls.values(): + first_delta = tool_deltas[0] + yield self._format_chunk( + { + "chunk_type": "content_start", + "data_type": "tool", + "data": type( + "ToolCall", + (), + { + "function": type( + "Function", + (), + { + "name": first_delta.get("function", {}).get( + "name", "" + ), + }, + )(), + "id": first_delta.get("id", ""), + }, + )(), + } + ) + + for tool_delta in tool_deltas: + yield self._format_chunk( + { + "chunk_type": "content_delta", + "data_type": "tool", + "data": type( + "ToolCall", + (), + { + "function": type( + "Function", + (), + { + "arguments": tool_delta.get( + "function", {} + ).get("arguments", ""), + }, + )(), + }, + )(), + } + ) + + yield self._format_chunk({"chunk_type": "content_stop"}) + + # Send stop reason + stop_reason = ( + "tool_use" + if tool_calls + else getattr(choice, "finish_reason", "end_turn") + ) + yield self._format_chunk( + {"chunk_type": "message_stop", "data": stop_reason} + ) + + # Send usage metadata if available + if usage_data: + yield self._format_chunk( + { + "chunk_type": "metadata", + "data": type( + "Usage", + (), + { + "prompt_tokens": usage_data.get("prompt_tokens", 0), + "completion_tokens": usage_data.get( + "completion_tokens", 0 + ), + "total_tokens": usage_data.get("total_tokens", 0), + }, + )(), + } + ) + + logger.debug("finished streaming response") + except httpx.HTTPStatusError as e: if e.response.status_code == 400: - # Parse error response + # Parse error response from llama.cpp server try: error_data = e.response.json() - error_msg = str(error_data.get("error", {}).get("message", str(error_data))) + error_msg = str( + error_data.get("error", {}).get("message", str(error_data)) + ) except (json.JSONDecodeError, KeyError, AttributeError): error_msg = e.response.text - # Check for 
context overflow - if any(term in error_msg.lower() for term in ["context", "kv cache", "slot"]): + # Check for context overflow by looking for specific error indicators + if any( + term in error_msg.lower() + for term in ["context", "kv cache", "slot"] + ): raise LlamaCppContextOverflowError( f"Context window exceeded: {error_msg}" ) from e @@ -349,7 +768,7 @@ async def stream( ) from e raise except Exception as e: - # Handle other potential errors + # Handle other potential errors like rate limiting error_msg = str(e).lower() if "rate" in error_msg or "429" in str(e): raise ModelThrottledException(str(e)) from e @@ -357,7 +776,11 @@ async def stream( @override async def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any + self, + output_model: Type[T], + prompt: Messages, + system_prompt: Optional[str] = None, + **kwargs: Any, ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: """Get structured output using llama.cpp's native JSON schema support. @@ -389,16 +812,18 @@ async def structured_output( self.config["params"] = {} self.config["params"]["json_schema"] = schema - self.config["params"]["cache_prompt"] = True # Cache schema processing + self.config["params"]["cache_prompt"] = True # Collect the response response_text = "" - async for event in self.stream(prompt, system_prompt=system_prompt, **kwargs): + async for event in self.stream( + prompt, system_prompt=system_prompt, **kwargs + ): if "contentBlockDelta" in event: delta = event["contentBlockDelta"]["delta"] if "text" in delta: response_text += delta["text"] - # Pass through other events + # Forward events to caller yield event # Parse and validate the JSON response @@ -407,31 +832,5 @@ async def structured_output( yield {"output": output_instance} finally: - # Restore original params + # Restore original configuration self.config["params"] = original_params - - def _generate_pydantic_grammar(self, model: Type[BaseModel]) -> str: - """Generate a GBNF grammar from a Pydantic model. - - Args: - model: The Pydantic model to generate grammar for. - - Returns: - GBNF grammar string. - - Note: - This provides a basic JSON grammar. A future enhancement would - generate model-specific grammars based on the Pydantic schema. - """ - # Basic JSON grammar that works for most cases - return ''' -root ::= object -object ::= "{" pair ("," pair)* "}" -pair ::= string ":" value -string ::= "\\"" [^"]* "\\"" -value ::= string | number | boolean | null | array | object -array ::= "[" (value ("," value)*)? "]" -number ::= "-"? [0-9]+ ("." [0-9]+)? 
-boolean ::= "true" | "false" -null ::= "null" -''' diff --git a/tests/strands/models/test_llamacpp.py b/tests/strands/models/test_llamacpp.py index 829452b0a..5f967e6db 100644 --- a/tests/strands/models/test_llamacpp.py +++ b/tests/strands/models/test_llamacpp.py @@ -1,76 +1,79 @@ """Unit tests for llama.cpp model provider.""" +import base64 import json -from typing import Any -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, patch import httpx import pytest -from openai import AsyncOpenAI from pydantic import BaseModel -from strands.types.content import ContentBlock, Message -from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException -from strands.models.llamacpp import LlamaCppModel, LlamaCppError, LlamaCppContextOverflowError +from strands.models.llamacpp import ( + LlamaCppContextOverflowError, + LlamaCppModel, +) +from strands.types.exceptions import ( + ModelThrottledException, +) class TestLlamaCppModel: """Test suite for LlamaCppModel.""" - + def test_init_default_config(self) -> None: """Test initialization with default configuration.""" model = LlamaCppModel() - + assert model.config["model_id"] == "default" - assert isinstance(model.client, AsyncOpenAI) - # Check that base_url was set correctly - assert model.client.base_url == "http://localhost:8080/v1/" - + assert isinstance(model.client, httpx.AsyncClient) + assert model.base_url == "http://localhost:8080" + def test_init_custom_config(self) -> None: """Test initialization with custom configuration.""" model = LlamaCppModel( - base_url="http://example.com:8081/v1", + base_url="http://example.com:8081", model_id="llama-3-8b", params={"temperature": 0.7, "max_tokens": 100}, ) - + assert model.config["model_id"] == "llama-3-8b" assert model.config["params"]["temperature"] == 0.7 assert model.config["params"]["max_tokens"] == 100 - assert model.client.base_url == "http://example.com:8081/v1/" - + assert model.base_url == "http://example.com:8081" + def test_format_request_basic(self) -> None: """Test basic request formatting.""" model = LlamaCppModel(model_id="test-model") - + messages = [ {"role": "user", "content": [{"text": "Hello"}]}, ] - - request = model.format_request(messages) - + + request = model._format_request(messages) + assert request["model"] == "test-model" assert request["messages"][0]["role"] == "user" - # OpenAI format returns content as an array assert request["messages"][0]["content"][0]["type"] == "text" assert request["messages"][0]["content"][0]["text"] == "Hello" assert request["stream"] is True - assert "extra_body" not in request # No llama.cpp params, so no extra_body - + assert "extra_body" not in request + def test_format_request_with_system_prompt(self) -> None: """Test request formatting with system prompt.""" model = LlamaCppModel() - + messages = [ {"role": "user", "content": [{"text": "Hello"}]}, ] - - request = model.format_request(messages, system_prompt="You are a helpful assistant") - + + request = model._format_request( + messages, system_prompt="You are a helpful assistant" + ) + assert request["messages"][0]["role"] == "system" assert request["messages"][0]["content"] == "You are a helpful assistant" assert request["messages"][1]["role"] == "user" - + def test_format_request_with_llamacpp_params(self) -> None: """Test request formatting with llama.cpp specific parameters.""" model = LlamaCppModel( @@ -83,24 +86,24 @@ def test_format_request_with_llamacpp_params(self) -> None: "grammar": "root ::= 'yes' | 'no'", 
} ) - + messages = [ {"role": "user", "content": [{"text": "Is the sky blue?"}]}, ] - - request = model.format_request(messages) - + + request = model._format_request(messages) + # Standard OpenAI params assert request["temperature"] == 0.8 assert request["max_tokens"] == 50 - + # llama.cpp specific params should be in extra_body assert "extra_body" in request assert request["extra_body"]["repeat_penalty"] == 1.1 assert request["extra_body"]["top_k"] == 40 assert request["extra_body"]["min_p"] == 0.05 assert request["extra_body"]["grammar"] == "root ::= 'yes' | 'no'" - + def test_format_request_with_all_new_params(self) -> None: """Test request formatting with all new llama.cpp parameters.""" model = LlamaCppModel( @@ -132,16 +135,16 @@ def test_format_request_with_all_new_params(self) -> None: "samplers": ["top_k", "tfs_z", "typical_p"], } ) - + messages = [{"role": "user", "content": [{"text": "Test"}]}] - request = model.format_request(messages) - + request = model._format_request(messages) + # Check OpenAI params are in root assert request["temperature"] == 0.7 assert request["max_tokens"] == 100 assert request["top_p"] == 0.9 assert request["seed"] == 42 - + # Check all llama.cpp params are in extra_body assert "extra_body" in request extra = request["extra_body"] @@ -159,20 +162,20 @@ def test_format_request_with_all_new_params(self) -> None: assert extra["penalty_last_n"] == 256 assert extra["n_probs"] == 5 assert extra["min_keep"] == 1 - assert extra["ignore_eos"] == False + assert extra["ignore_eos"] is False assert extra["logit_bias"] == {100: 5.0, 200: -5.0} - assert extra["cache_prompt"] == True + assert extra["cache_prompt"] is True assert extra["slot_id"] == 1 assert extra["samplers"] == ["top_k", "tfs_z", "typical_p"] - + def test_format_request_with_tools(self) -> None: """Test request formatting with tool specifications.""" model = LlamaCppModel() - + messages = [ {"role": "user", "content": [{"text": "What's the weather?"}]}, ] - + tool_specs = [ { "name": "get_weather", @@ -188,24 +191,24 @@ def test_format_request_with_tools(self) -> None: }, } ] - - request = model.format_request(messages, tool_specs=tool_specs) - + + request = model._format_request(messages, tool_specs=tool_specs) + assert "tools" in request assert len(request["tools"]) == 1 assert request["tools"][0]["function"]["name"] == "get_weather" - + def test_update_config(self) -> None: """Test configuration update.""" model = LlamaCppModel(model_id="initial-model") - + assert model.config["model_id"] == "initial-model" - + model.update_config(model_id="updated-model", params={"temperature": 0.5}) - + assert model.config["model_id"] == "updated-model" assert model.config["params"]["temperature"] == 0.5 - + def test_get_config(self) -> None: """Test configuration retrieval.""" config = { @@ -213,263 +216,427 @@ def test_get_config(self) -> None: "params": {"temperature": 0.9}, } model = LlamaCppModel(**config) - + retrieved_config = model.get_config() - + assert retrieved_config["model_id"] == "test-model" assert retrieved_config["params"]["temperature"] == 0.9 - + @pytest.mark.asyncio async def test_stream_basic(self) -> None: """Test basic streaming functionality.""" model = LlamaCppModel() - - # Create properly structured mock events - class MockDelta: - content = None - tool_calls = None - def __init__(self, content=None): - self.content = content - - class MockChoice: - def __init__(self, content=None, finish_reason=None): - self.delta = MockDelta(content) - self.finish_reason = finish_reason - - class 
MockChunk: - def __init__(self, choices, usage=None): - self.choices = choices - self.usage = usage - - mock_chunks = [ - MockChunk([MockChoice(content="Hello")]), - MockChunk( - [MockChoice(content=" world", finish_reason="stop")], - usage=MagicMock(prompt_tokens=10, completion_tokens=5, total_tokens=15) - ), + + # Mock HTTP response with Server-Sent Events format + mock_response_lines = [ + 'data: {"choices": [{"delta": {"content": "Hello"}}]}', + 'data: {"choices": [{"delta": {"content": " world"}, "finish_reason": "stop"}]}', + 'data: {"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}}', + "data: [DONE]", ] - - # Create async iterator - async def mock_stream(): - for chunk in mock_chunks: - yield chunk - - # Mock the create method to return a coroutine that returns the async iterator - async def mock_create(*args, **kwargs): - return mock_stream() - - with patch.object(model.client.chat.completions, "create", side_effect=mock_create): - + + async def mock_aiter_lines(): + for line in mock_response_lines: + yield line + + mock_response = AsyncMock() + mock_response.aiter_lines = mock_aiter_lines + mock_response.raise_for_status = AsyncMock() + + with patch.object(model.client, "post", return_value=mock_response): messages = [{"role": "user", "content": [{"text": "Hi"}]}] - + chunks = [] async for chunk in model.stream(messages): chunks.append(chunk) - + # Verify we got the expected chunks assert any("messageStart" in chunk for chunk in chunks) - assert any("contentBlockDelta" in chunk and chunk["contentBlockDelta"]["delta"]["text"] == "Hello" for chunk in chunks) - assert any("contentBlockDelta" in chunk and chunk["contentBlockDelta"]["delta"]["text"] == " world" for chunk in chunks) + assert any( + "contentBlockDelta" in chunk + and chunk["contentBlockDelta"]["delta"]["text"] == "Hello" + for chunk in chunks + ) + assert any( + "contentBlockDelta" in chunk + and chunk["contentBlockDelta"]["delta"]["text"] == " world" + for chunk in chunks + ) assert any("messageStop" in chunk for chunk in chunks) - assert any("metadata" in chunk for chunk in chunks) - + @pytest.mark.asyncio async def test_structured_output(self) -> None: - """Test structured output functionality using the enhanced implementation.""" - + """Test structured output functionality.""" + class TestOutput(BaseModel): + """Test output model for structured output testing.""" + answer: str confidence: float - + model = LlamaCppModel() - + # Mock successful JSON response using the new structured_output implementation mock_response_text = '{"answer": "yes", "confidence": 0.95}' - + # Create mock stream that returns JSON - async def mock_stream(*args, **kwargs): + async def mock_stream(*_args, **_kwargs): # Verify json_schema was set assert "json_schema" in model.config.get("params", {}) - + yield {"messageStart": {"role": "assistant"}} yield {"contentBlockStart": {"start": {}}} yield {"contentBlockDelta": {"delta": {"text": mock_response_text}}} yield {"contentBlockStop": {}} yield {"messageStop": {"stopReason": "end_turn"}} - + with patch.object(model, "stream", side_effect=mock_stream): messages = [{"role": "user", "content": [{"text": "Is the earth round?"}]}] - + events = [] async for event in model.structured_output(TestOutput, messages): events.append(event) - + # Check we got the output output_event = next((e for e in events if "output" in e), None) assert output_event is not None assert output_event["output"].answer == "yes" assert output_event["output"].confidence == 0.95 - + def 
test_timeout_configuration(self) -> None: """Test timeout configuration.""" + # Test that timeout configuration is accepted without error model = LlamaCppModel(timeout=30.0) - - # The timeout should be passed to the OpenAI client - assert model.client.timeout == 30.0 - + assert model.client.timeout is not None + # Test with tuple timeout model2 = LlamaCppModel(timeout=(10.0, 60.0)) - assert model2.client.timeout == (10.0, 60.0) - + assert model2.client.timeout is not None + def test_max_retries_configuration(self) -> None: - """Test max retries configuration.""" - model = LlamaCppModel(max_retries=5) - - # The max_retries should be passed to the OpenAI client - assert model.client.max_retries == 5 - + """Test max retries configuration is handled gracefully.""" + # Since httpx doesn't use max_retries in the same way, + # we just test that the model initializes without error + model = LlamaCppModel() + assert model.config["model_id"] == "default" + def test_use_grammar_constraint(self) -> None: """Test grammar constraint method.""" model = LlamaCppModel() - + # Apply grammar constraint - grammar = ''' + grammar = """ root ::= answer answer ::= "yes" | "no" - ''' + """ model.use_grammar_constraint(grammar) - + assert model.config["params"]["grammar"] == grammar - + # Update grammar - new_grammar = 'root ::= [0-9]+' + new_grammar = "root ::= [0-9]+" model.use_grammar_constraint(new_grammar) - + assert model.config["params"]["grammar"] == new_grammar - + def test_use_json_schema(self) -> None: """Test JSON schema constraint method.""" model = LlamaCppModel() - + # Apply JSON schema schema = { "type": "object", - "properties": { - "name": {"type": "string"}, - "age": {"type": "integer"} - }, - "required": ["name", "age"] + "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}, + "required": ["name", "age"], } model.use_json_schema(schema) - + assert model.config["params"]["json_schema"] == schema - + @pytest.mark.asyncio async def test_stream_with_context_overflow_error(self) -> None: """Test stream handling of context overflow errors.""" model = LlamaCppModel() - + # Create HTTP error response error_response = httpx.Response( status_code=400, - json={"error": {"message": "Context window exceeded. Max context length is 4096 tokens"}}, - request=httpx.Request("POST", "http://test.com") + json={ + "error": { + "message": "Context window exceeded. 
Max context length is 4096 tokens" + } + }, + request=httpx.Request("POST", "http://test.com"), + ) + error = httpx.HTTPStatusError( + "Bad Request", request=error_response.request, response=error_response ) - error = httpx.HTTPStatusError("Bad Request", request=error_response.request, response=error_response) - - # Mock the parent stream to raise the error - with patch.object(model.client.chat.completions, "create", side_effect=error): + + # Mock the client to raise the error + with patch.object(model.client, "post", side_effect=error): messages = [{"role": "user", "content": [{"text": "Very long message"}]}] - + with pytest.raises(LlamaCppContextOverflowError) as exc_info: async for _ in model.stream(messages): pass - + assert "Context window exceeded" in str(exc_info.value) - + @pytest.mark.asyncio async def test_stream_with_server_overload_error(self) -> None: """Test stream handling of server overload errors.""" model = LlamaCppModel() - + # Create HTTP error response for 503 error_response = httpx.Response( status_code=503, text="Server is busy", - request=httpx.Request("POST", "http://test.com") + request=httpx.Request("POST", "http://test.com"), ) - error = httpx.HTTPStatusError("Service Unavailable", request=error_response.request, response=error_response) - - # Mock the parent stream to raise the error - with patch.object(model.client.chat.completions, "create", side_effect=error): + error = httpx.HTTPStatusError( + "Service Unavailable", + request=error_response.request, + response=error_response, + ) + + # Mock the client to raise the error + with patch.object(model.client, "post", side_effect=error): messages = [{"role": "user", "content": [{"text": "Test"}]}] - + with pytest.raises(ModelThrottledException) as exc_info: async for _ in model.stream(messages): pass - + assert "server is busy or overloaded" in str(exc_info.value) - + @pytest.mark.asyncio async def test_structured_output_with_json_schema(self) -> None: """Test structured output using JSON schema.""" - + class TestOutput(BaseModel): + """Test output model for JSON schema testing.""" + answer: str confidence: float - + model = LlamaCppModel() - + # Mock successful JSON response mock_response_text = '{"answer": "yes", "confidence": 0.95}' - + # Create mock stream that returns JSON - async def mock_stream(*args, **kwargs): + async def mock_stream(*_args, **_kwargs): # Check that json_schema was set correctly - assert model.config["params"]["json_schema"] == TestOutput.model_json_schema() - + assert ( + model.config["params"]["json_schema"] == TestOutput.model_json_schema() + ) + yield {"messageStart": {"role": "assistant"}} yield {"contentBlockStart": {"start": {}}} yield {"contentBlockDelta": {"delta": {"text": mock_response_text}}} yield {"contentBlockStop": {}} yield {"messageStop": {"stopReason": "end_turn"}} - + with patch.object(model, "stream", side_effect=mock_stream): messages = [{"role": "user", "content": [{"text": "Is the earth round?"}]}] - + events = [] async for event in model.structured_output(TestOutput, messages): events.append(event) - + # Check we got the output output_event = next((e for e in events if "output" in e), None) assert output_event is not None assert output_event["output"].answer == "yes" assert output_event["output"].confidence == 0.95 - + @pytest.mark.asyncio async def test_structured_output_invalid_json_error(self) -> None: """Test structured output raises error for invalid JSON.""" - + class TestOutput(BaseModel): + """Test output model for invalid JSON testing.""" + value: int - + 
model = LlamaCppModel() - + # Mock stream that returns invalid JSON - async def mock_stream(*args, **kwargs): + async def mock_stream(*_args, **_kwargs): # Check that json_schema was set correctly - assert model.config["params"]["json_schema"] == TestOutput.model_json_schema() - + assert ( + model.config["params"]["json_schema"] == TestOutput.model_json_schema() + ) + yield {"messageStart": {"role": "assistant"}} yield {"contentBlockStart": {"start": {}}} yield {"contentBlockDelta": {"delta": {"text": "This is not valid JSON"}}} yield {"contentBlockStop": {}} yield {"messageStop": {"stopReason": "end_turn"}} - + with patch.object(model, "stream", side_effect=mock_stream): messages = [{"role": "user", "content": [{"text": "Give me a number"}]}] - + with pytest.raises(json.JSONDecodeError): - async for event in model.structured_output(TestOutput, messages): - pass \ No newline at end of file + async for _ in model.structured_output(TestOutput, messages): + pass + + def test_format_audio_content(self) -> None: + """Test formatting of audio content for llama.cpp multimodal models.""" + model = LlamaCppModel() + + # Create test audio data + audio_bytes = b"fake audio data" + audio_content = {"audio": {"source": {"bytes": audio_bytes}, "format": "wav"}} + + # Format the content + result = model._format_message_content(audio_content) + + # Verify the structure + assert result["type"] == "input_audio" + assert "input_audio" in result + assert "data" in result["input_audio"] + assert "format" in result["input_audio"] + + # Verify the data is base64 encoded + decoded = base64.b64decode(result["input_audio"]["data"]) + assert decoded == audio_bytes + + # Verify format is preserved + assert result["input_audio"]["format"] == "wav" + + def test_format_audio_content_default_format(self) -> None: + """Test audio content formatting uses wav as default format.""" + model = LlamaCppModel() + + audio_content = { + "audio": {"source": {"bytes": b"test audio"}} + # No format specified + } + + result = model._format_message_content(audio_content) + + # Should default to wav + assert result["input_audio"]["format"] == "wav" + + def test_format_messages_with_audio(self) -> None: + """Test that _format_messages properly handles audio content.""" + model = LlamaCppModel() + + # Create messages with audio content + messages = [ + { + "role": "user", + "content": [ + {"text": "Listen to this audio:"}, + {"audio": {"source": {"bytes": b"audio data"}, "format": "mp3"}}, + ], + } + ] + + # Format the messages + result = model._format_messages(messages) + + # Check structure + assert len(result) == 1 + assert result[0]["role"] == "user" + assert len(result[0]["content"]) == 2 + + # Check text content + assert result[0]["content"][0]["type"] == "text" + assert result[0]["content"][0]["text"] == "Listen to this audio:" + + # Check audio content + assert result[0]["content"][1]["type"] == "input_audio" + assert "input_audio" in result[0]["content"][1] + assert result[0]["content"][1]["input_audio"]["format"] == "mp3" + + def test_format_messages_with_system_prompt(self) -> None: + """Test _format_messages includes system prompt.""" + model = LlamaCppModel() + + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + system_prompt = "You are a helpful assistant" + + result = model._format_messages(messages, system_prompt) + + # Should have system message first + assert len(result) == 2 + assert result[0]["role"] == "system" + assert result[0]["content"] == system_prompt + assert result[1]["role"] == "user" + + def 
test_format_messages_with_image(self) -> None: + """Test that _format_messages properly handles image content.""" + model = LlamaCppModel() + + # Create messages with image content + messages = [ + { + "role": "user", + "content": [ + {"text": "Describe this image:"}, + {"image": {"source": {"bytes": b"image data"}, "format": "png"}}, + ], + } + ] + + # Format the messages + result = model._format_messages(messages) + + # Check structure + assert len(result) == 1 + assert result[0]["role"] == "user" + assert len(result[0]["content"]) == 2 + + # Check text content + assert result[0]["content"][0]["type"] == "text" + assert result[0]["content"][0]["text"] == "Describe this image:" + + # Check image content uses standard format + assert result[0]["content"][1]["type"] == "image_url" + assert "image_url" in result[0]["content"][1] + assert "url" in result[0]["content"][1]["image_url"] + assert result[0]["content"][1]["image_url"]["url"].startswith( + "data:image/png;base64," + ) + + def test_format_messages_with_mixed_content(self) -> None: + """Test that _format_messages handles mixed audio and image content correctly.""" + model = LlamaCppModel() + + # Create messages with both audio and image content + messages = [ + { + "role": "user", + "content": [ + {"text": "Analyze this media:"}, + {"audio": {"source": {"bytes": b"audio data"}, "format": "wav"}}, + {"image": {"source": {"bytes": b"image data"}, "format": "jpg"}}, + ], + } + ] + + # Format the messages + result = model._format_messages(messages) + + # Check structure + assert len(result) == 1 + assert result[0]["role"] == "user" + assert len(result[0]["content"]) == 3 + + # Check text content + assert result[0]["content"][0]["type"] == "text" + assert result[0]["content"][0]["text"] == "Analyze this media:" + + # Check audio content uses llama.cpp specific format + assert result[0]["content"][1]["type"] == "input_audio" + assert "input_audio" in result[0]["content"][1] + assert result[0]["content"][1]["input_audio"]["format"] == "wav" + + # Check image content uses standard OpenAI format + assert result[0]["content"][2]["type"] == "image_url" + assert "image_url" in result[0]["content"][2] + assert result[0]["content"][2]["image_url"]["url"].startswith( + "data:image/jpeg;base64," + ) From edf445dcf207da36994d7e116ddffdb9727974ae Mon Sep 17 00:00:00 2001 From: Aaron Brown Date: Thu, 31 Jul 2025 11:40:24 -0500 Subject: [PATCH 03/11] fixed ruff liniting, black formatting and refined grammar support; --- README.md | 9 ++ src/strands/models/llamacpp.py | 188 +++++++++------------- tests/strands/models/test_llamacpp.py | 50 +++--- tests_integ/models/test_model_llamacpp.py | 166 +++++++++---------- 4 files changed, 190 insertions(+), 223 deletions(-) diff --git a/README.md b/README.md index 62ed54d47..143635475 100644 --- a/README.md +++ b/README.md @@ -129,6 +129,7 @@ from strands import Agent from strands.models import BedrockModel from strands.models.ollama import OllamaModel from strands.models.llamaapi import LlamaAPIModel +from strands.models.llamacpp import LlamaCppModel # Bedrock bedrock_model = BedrockModel( @@ -153,6 +154,14 @@ llama_model = LlamaAPIModel( ) agent = Agent(model=llama_model) response = agent("Tell me about Agentic AI") + +# llama.cpp +llamacpp_model = LlamaCppModel( + base_url="http://localhost:8080", + params={"temperature": 0.7, "max_tokens": 100} +) +agent = Agent(model=llamacpp_model) +response = agent("Tell me about Agentic AI") ``` Built-in providers: diff --git a/src/strands/models/llamacpp.py 
b/src/strands/models/llamacpp.py index 5a0f5340e..3863bc5f4 100644 --- a/src/strands/models/llamacpp.py +++ b/src/strands/models/llamacpp.py @@ -13,7 +13,7 @@ import json import logging import mimetypes -from typing import Any, AsyncGenerator, Optional, Type, TypedDict, TypeVar, Union +from typing import Any, AsyncGenerator, Dict, Optional, Type, TypedDict, TypeVar, Union, cast import httpx from pydantic import BaseModel @@ -163,9 +163,20 @@ def __init__( self.config = dict(model_config) # Configure HTTP client + if isinstance(timeout, tuple): + # Convert tuple to httpx.Timeout object + timeout_obj = httpx.Timeout( + connect=timeout[0] if len(timeout) > 0 else None, + read=timeout[1] if len(timeout) > 1 else None, + write=timeout[2] if len(timeout) > 2 else None, + pool=timeout[3] if len(timeout) > 3 else None, + ) + else: + timeout_obj = httpx.Timeout(timeout or 30.0) + self.client = httpx.AsyncClient( base_url=self.base_url, - timeout=timeout or 30.0, + timeout=timeout_obj, ) logger.debug( @@ -175,9 +186,7 @@ def __init__( ) @override - def update_config( - self, **model_config: Unpack[LlamaCppConfig] - ) -> None: # type: ignore[override] + def update_config(self, **model_config: Unpack[LlamaCppConfig]) -> None: # type: ignore[override] """Update the llama.cpp model configuration with provided arguments. Args: @@ -218,9 +227,11 @@ def use_grammar_constraint(self, grammar: str) -> None: ... number ::= "-"? [0-9]+ ("." [0-9]+)? ... ''') """ - if not self.config.get("params"): - self.config["params"] = {} - self.config["params"]["grammar"] = grammar + params = self.config.get("params", {}) + if not isinstance(params, dict): + params = {} + params["grammar"] = grammar + self.config["params"] = params logger.debug( "grammar=<%s> | applied grammar constraint", grammar[:50] + "..." if len(grammar) > 50 else grammar, @@ -242,12 +253,14 @@ def use_json_schema(self, schema: dict[str, Any]) -> None: ... "required": ["name", "age"] ... }) """ - if not self.config.get("params"): - self.config["params"] = {} - self.config["params"]["json_schema"] = schema + params = self.config.get("params", {}) + if not isinstance(params, dict): + params = {} + params["json_schema"] = schema + self.config["params"] = params logger.debug("schema=<%s> | applied JSON schema constraint", schema) - def _format_message_content(self, content: ContentBlock) -> dict[str, Any]: + def _format_message_content(self, content: Union[ContentBlock, Dict[str, Any]]) -> dict[str, Any]: """Format a content block for llama.cpp. Args: @@ -260,12 +273,8 @@ def _format_message_content(self, content: ContentBlock) -> dict[str, Any]: TypeError: If the content block type cannot be converted to a compatible format. 
""" if "document" in content: - mime_type = mimetypes.types_map.get( - f".{content['document']['format']}", "application/octet-stream" - ) - file_data = base64.b64encode(content["document"]["source"]["bytes"]).decode( - "utf-8" - ) + mime_type = mimetypes.types_map.get(f".{content['document']['format']}", "application/octet-stream") + file_data = base64.b64encode(content["document"]["source"]["bytes"]).decode("utf-8") return { "file": { "file_data": f"data:{mime_type};base64,{file_data}", @@ -275,12 +284,8 @@ def _format_message_content(self, content: ContentBlock) -> dict[str, Any]: } if "image" in content: - mime_type = mimetypes.types_map.get( - f".{content['image']['format']}", "application/octet-stream" - ) - image_data = base64.b64encode(content["image"]["source"]["bytes"]).decode( - "utf-8" - ) + mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream") + image_data = base64.b64encode(content["image"]["source"]["bytes"]).decode("utf-8") return { "image_url": { "detail": "auto", @@ -290,11 +295,11 @@ def _format_message_content(self, content: ContentBlock) -> dict[str, Any]: "type": "image_url", } + # Handle audio content (not in standard ContentBlock but supported by llama.cpp) if "audio" in content: - audio_data = base64.b64encode(content["audio"]["source"]["bytes"]).decode( - "utf-8" - ) - audio_format = content["audio"].get("format", "wav") + audio_content = cast(Dict[str, Any], content) + audio_data = base64.b64encode(audio_content["audio"]["source"]["bytes"]).decode("utf-8") + audio_format = audio_content["audio"].get("format", "wav") return { "type": "input_audio", "input_audio": {"data": audio_data, "format": audio_format}, @@ -343,9 +348,7 @@ def _format_tool_message(self, tool_result: dict[str, Any]) -> dict[str, Any]: "content": [self._format_message_content(content) for content in contents], } - def _format_messages( - self, messages: Messages, system_prompt: Optional[str] = None - ) -> list[dict[str, Any]]: + def _format_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: """Format messages for llama.cpp. Args: @@ -355,7 +358,7 @@ def _format_messages( Returns: Formatted messages array compatible with llama.cpp. 
""" - formatted_messages = [] + formatted_messages: list[dict[str, Any]] = [] # Add system prompt if provided if system_prompt: @@ -367,17 +370,23 @@ def _format_messages( formatted_contents = [ self._format_message_content(content) for content in contents - if not any( - block_type in content for block_type in ["toolResult", "toolUse"] - ) + if not any(block_type in content for block_type in ["toolResult", "toolUse"]) ] formatted_tool_calls = [ - self._format_tool_call(content["toolUse"]) + self._format_tool_call( + { + "name": content["toolUse"]["name"], + "input": content["toolUse"]["input"], + "toolUseId": content["toolUse"]["toolUseId"], + } + ) for content in contents if "toolUse" in content ] formatted_tool_messages = [ - self._format_tool_message(content["toolResult"]) + self._format_tool_message( + {"toolUseId": content["toolResult"]["toolUseId"], "content": content["toolResult"]["content"]} + ) for content in contents if "toolResult" in content ] @@ -385,20 +394,12 @@ def _format_messages( formatted_message = { "role": message["role"], "content": formatted_contents, - **( - {} - if not formatted_tool_calls - else {"tool_calls": formatted_tool_calls} - ), + **({} if not formatted_tool_calls else {"tool_calls": formatted_tool_calls}), } formatted_messages.append(formatted_message) formatted_messages.extend(formatted_tool_messages) - return [ - message - for message in formatted_messages - if message["content"] or "tool_calls" in message - ] + return [message for message in formatted_messages if message["content"] or "tool_calls" in message] def _format_request( self, @@ -436,10 +437,16 @@ def _format_request( } # Handle parameters if provided - if self.config.get("params"): - params = self.config["params"] + params = self.config.get("params") + if params and isinstance(params, dict): + # Grammar and json_schema go directly in request body for llama.cpp server + if "grammar" in params: + request["grammar"] = params["grammar"] + if "json_schema" in params: + request["json_schema"] = params["json_schema"] # llama.cpp-specific parameters that must be passed via extra_body + # NOTE: grammar and json_schema removed from this set llamacpp_specific_params = { "repeat_penalty", "top_k", @@ -450,8 +457,6 @@ def _format_request( "mirostat", "mirostat_lr", "mirostat_ent", - "grammar", - "json_schema", "penalty_last_n", "n_probs", "min_keep", @@ -483,7 +488,7 @@ def _format_request( request[param] = value # Collect llama.cpp-specific parameters for extra_body - extra_body = {} + extra_body: Dict[str, Any] = {} for param, value in params.items(): if param in llamacpp_specific_params: extra_body[param] = value @@ -527,20 +532,10 @@ def _format_chunk(self, event: dict[str, Any]) -> StreamEvent: case "content_delta": if event["data_type"] == "tool": return { - "contentBlockDelta": { - "delta": { - "toolUse": { - "input": event["data"].function.arguments or "" - } - } - } + "contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].function.arguments or ""}}} } if event["data_type"] == "reasoning_content": - return { - "contentBlockDelta": { - "delta": {"reasoningContent": {"text": event["data"]}} - } - } + return {"contentBlockDelta": {"delta": {"reasoningContent": {"text": event["data"]}}}} return {"contentBlockDelta": {"delta": {"text": event["data"]}}} case "content_stop": @@ -606,11 +601,9 @@ async def stream( logger.debug("processing streaming response") yield self._format_chunk({"chunk_type": "message_start"}) - yield self._format_chunk( - {"chunk_type": "content_start", "data_type": 
"text"} - ) + yield self._format_chunk({"chunk_type": "content_start", "data_type": "text"}) - tool_calls = {} + tool_calls: Dict[int, list] = {} usage_data = None async for line in response.aiter_lines(): @@ -676,9 +669,7 @@ async def stream( "Function", (), { - "name": first_delta.get("function", {}).get( - "name", "" - ), + "name": first_delta.get("function", {}).get("name", ""), }, )(), "id": first_delta.get("id", ""), @@ -700,9 +691,7 @@ async def stream( "Function", (), { - "arguments": tool_delta.get( - "function", {} - ).get("arguments", ""), + "arguments": tool_delta.get("function", {}).get("arguments", ""), }, )(), }, @@ -713,14 +702,8 @@ async def stream( yield self._format_chunk({"chunk_type": "content_stop"}) # Send stop reason - stop_reason = ( - "tool_use" - if tool_calls - else getattr(choice, "finish_reason", "end_turn") - ) - yield self._format_chunk( - {"chunk_type": "message_stop", "data": stop_reason} - ) + stop_reason = "tool_use" if tool_calls else getattr(choice, "finish_reason", "end_turn") + yield self._format_chunk({"chunk_type": "message_stop", "data": stop_reason}) # Send usage metadata if available if usage_data: @@ -732,9 +715,7 @@ async def stream( (), { "prompt_tokens": usage_data.get("prompt_tokens", 0), - "completion_tokens": usage_data.get( - "completion_tokens", 0 - ), + "completion_tokens": usage_data.get("completion_tokens", 0), "total_tokens": usage_data.get("total_tokens", 0), }, )(), @@ -748,24 +729,15 @@ async def stream( # Parse error response from llama.cpp server try: error_data = e.response.json() - error_msg = str( - error_data.get("error", {}).get("message", str(error_data)) - ) + error_msg = str(error_data.get("error", {}).get("message", str(error_data))) except (json.JSONDecodeError, KeyError, AttributeError): error_msg = e.response.text # Check for context overflow by looking for specific error indicators - if any( - term in error_msg.lower() - for term in ["context", "kv cache", "slot"] - ): - raise LlamaCppContextOverflowError( - f"Context window exceeded: {error_msg}" - ) from e + if any(term in error_msg.lower() for term in ["context", "kv cache", "slot"]): + raise LlamaCppContextOverflowError(f"Context window exceeded: {error_msg}") from e elif e.response.status_code == 503: - raise ModelThrottledException( - "llama.cpp server is busy or overloaded" - ) from e + raise ModelThrottledException("llama.cpp server is busy or overloaded") from e raise except Exception as e: # Handle other potential errors like rate limiting @@ -804,27 +776,27 @@ async def structured_output( schema = output_model.model_json_schema() # Store current params to restore later - original_params = self.config.get("params", {}).copy() + params = self.config.get("params", {}) + original_params = dict(params) if isinstance(params, dict) else {} try: # Configure for JSON output with schema constraint - if not self.config.get("params"): - self.config["params"] = {} - - self.config["params"]["json_schema"] = schema - self.config["params"]["cache_prompt"] = True + params = self.config.get("params", {}) + if not isinstance(params, dict): + params = {} + params["json_schema"] = schema + params["cache_prompt"] = True + self.config["params"] = params # Collect the response response_text = "" - async for event in self.stream( - prompt, system_prompt=system_prompt, **kwargs - ): + async for event in self.stream(prompt, system_prompt=system_prompt, **kwargs): if "contentBlockDelta" in event: delta = event["contentBlockDelta"]["delta"] if "text" in delta: response_text += 
delta["text"] # Forward events to caller - yield event + yield cast(Dict[str, Union[T, Any]], event) # Parse and validate the JSON response data = json.loads(response_text.strip()) diff --git a/tests/strands/models/test_llamacpp.py b/tests/strands/models/test_llamacpp.py index 5f967e6db..f200385ee 100644 --- a/tests/strands/models/test_llamacpp.py +++ b/tests/strands/models/test_llamacpp.py @@ -66,9 +66,7 @@ def test_format_request_with_system_prompt(self) -> None: {"role": "user", "content": [{"text": "Hello"}]}, ] - request = model._format_request( - messages, system_prompt="You are a helpful assistant" - ) + request = model._format_request(messages, system_prompt="You are a helpful assistant") assert request["messages"][0]["role"] == "system" assert request["messages"][0]["content"] == "You are a helpful assistant" @@ -97,12 +95,14 @@ def test_format_request_with_llamacpp_params(self) -> None: assert request["temperature"] == 0.8 assert request["max_tokens"] == 50 - # llama.cpp specific params should be in extra_body + # Grammar and json_schema go directly in request for llama.cpp + assert request["grammar"] == "root ::= 'yes' | 'no'" + + # Other llama.cpp specific params should be in extra_body assert "extra_body" in request assert request["extra_body"]["repeat_penalty"] == 1.1 assert request["extra_body"]["top_k"] == 40 assert request["extra_body"]["min_p"] == 0.05 - assert request["extra_body"]["grammar"] == "root ::= 'yes' | 'no'" def test_format_request_with_all_new_params(self) -> None: """Test request formatting with all new llama.cpp parameters.""" @@ -145,7 +145,11 @@ def test_format_request_with_all_new_params(self) -> None: assert request["top_p"] == 0.9 assert request["seed"] == 42 - # Check all llama.cpp params are in extra_body + # Grammar and json_schema go directly in request for llama.cpp + assert request["grammar"] == "root ::= answer" + assert request["json_schema"] == {"type": "object"} + + # Check all other llama.cpp params are in extra_body assert "extra_body" in request extra = request["extra_body"] assert extra["repeat_penalty"] == 1.1 @@ -157,8 +161,6 @@ def test_format_request_with_all_new_params(self) -> None: assert extra["mirostat"] == 2 assert extra["mirostat_lr"] == 0.1 assert extra["mirostat_ent"] == 5.0 - assert extra["grammar"] == "root ::= answer" - assert extra["json_schema"] == {"type": "object"} assert extra["penalty_last_n"] == 256 assert extra["n_probs"] == 5 assert extra["min_keep"] == 1 @@ -253,13 +255,11 @@ async def mock_aiter_lines(): # Verify we got the expected chunks assert any("messageStart" in chunk for chunk in chunks) assert any( - "contentBlockDelta" in chunk - and chunk["contentBlockDelta"]["delta"]["text"] == "Hello" + "contentBlockDelta" in chunk and chunk["contentBlockDelta"]["delta"]["text"] == "Hello" for chunk in chunks ) assert any( - "contentBlockDelta" in chunk - and chunk["contentBlockDelta"]["delta"]["text"] == " world" + "contentBlockDelta" in chunk and chunk["contentBlockDelta"]["delta"]["text"] == " world" for chunk in chunks ) assert any("messageStop" in chunk for chunk in chunks) @@ -361,16 +361,10 @@ async def test_stream_with_context_overflow_error(self) -> None: # Create HTTP error response error_response = httpx.Response( status_code=400, - json={ - "error": { - "message": "Context window exceeded. Max context length is 4096 tokens" - } - }, + json={"error": {"message": "Context window exceeded. 
Max context length is 4096 tokens"}}, request=httpx.Request("POST", "http://test.com"), ) - error = httpx.HTTPStatusError( - "Bad Request", request=error_response.request, response=error_response - ) + error = httpx.HTTPStatusError("Bad Request", request=error_response.request, response=error_response) # Mock the client to raise the error with patch.object(model.client, "post", side_effect=error): @@ -427,9 +421,7 @@ class TestOutput(BaseModel): # Create mock stream that returns JSON async def mock_stream(*_args, **_kwargs): # Check that json_schema was set correctly - assert ( - model.config["params"]["json_schema"] == TestOutput.model_json_schema() - ) + assert model.config["params"]["json_schema"] == TestOutput.model_json_schema() yield {"messageStart": {"role": "assistant"}} yield {"contentBlockStart": {"start": {}}} @@ -464,9 +456,7 @@ class TestOutput(BaseModel): # Mock stream that returns invalid JSON async def mock_stream(*_args, **_kwargs): # Check that json_schema was set correctly - assert ( - model.config["params"]["json_schema"] == TestOutput.model_json_schema() - ) + assert model.config["params"]["json_schema"] == TestOutput.model_json_schema() yield {"messageStart": {"role": "assistant"}} yield {"contentBlockStart": {"start": {}}} @@ -597,9 +587,7 @@ def test_format_messages_with_image(self) -> None: assert result[0]["content"][1]["type"] == "image_url" assert "image_url" in result[0]["content"][1] assert "url" in result[0]["content"][1]["image_url"] - assert result[0]["content"][1]["image_url"]["url"].startswith( - "data:image/png;base64," - ) + assert result[0]["content"][1]["image_url"]["url"].startswith("data:image/png;base64,") def test_format_messages_with_mixed_content(self) -> None: """Test that _format_messages handles mixed audio and image content correctly.""" @@ -637,6 +625,4 @@ def test_format_messages_with_mixed_content(self) -> None: # Check image content uses standard OpenAI format assert result[0]["content"][2]["type"] == "image_url" assert "image_url" in result[0]["content"][2] - assert result[0]["content"][2]["image_url"]["url"].startswith( - "data:image/jpeg;base64," - ) + assert result[0]["content"][2]["image_url"]["url"].startswith("data:image/jpeg;base64,") diff --git a/tests_integ/models/test_model_llamacpp.py b/tests_integ/models/test_model_llamacpp.py index d7ecd9195..e1b302504 100644 --- a/tests_integ/models/test_model_llamacpp.py +++ b/tests_integ/models/test_model_llamacpp.py @@ -9,7 +9,6 @@ """ import os -from typing import Any, AsyncGenerator import pytest from pydantic import BaseModel @@ -17,7 +16,6 @@ from strands.models.llamacpp import LlamaCppModel from strands.types.content import Message - # Get server URL from environment or use default LLAMACPP_URL = os.environ.get("LLAMACPP_TEST_URL", "http://localhost:8080/v1") @@ -30,7 +28,7 @@ class WeatherOutput(BaseModel): """Test output model for structured responses.""" - + temperature: float condition: str location: str @@ -44,54 +42,54 @@ async def llamacpp_model() -> LlamaCppModel: class TestLlamaCppModelIntegration: """Integration tests for LlamaCppModel with a real server.""" - + @pytest.mark.asyncio async def test_basic_completion(self, llamacpp_model: LlamaCppModel) -> None: """Test basic text completion.""" messages: list[Message] = [ {"role": "user", "content": [{"text": "Say 'Hello, World!' 
and nothing else."}]}, ] - + response_text = "" async for event in llamacpp_model.stream(messages): if "contentBlockDelta" in event: delta = event["contentBlockDelta"]["delta"] if "text" in delta: response_text += delta["text"] - + assert "Hello, World!" in response_text - + @pytest.mark.asyncio async def test_system_prompt(self, llamacpp_model: LlamaCppModel) -> None: """Test completion with system prompt.""" messages: list[Message] = [ {"role": "user", "content": [{"text": "Who are you?"}]}, ] - + system_prompt = "You are a helpful AI assistant named Claude." - + response_text = "" async for event in llamacpp_model.stream(messages, system_prompt=system_prompt): if "contentBlockDelta" in event: delta = event["contentBlockDelta"]["delta"] if "text" in delta: response_text += delta["text"] - + # Response should reflect the system prompt assert len(response_text) > 0 assert "assistant" in response_text.lower() or "claude" in response_text.lower() - + @pytest.mark.asyncio async def test_streaming_chunks(self, llamacpp_model: LlamaCppModel) -> None: """Test that streaming returns proper chunk sequence.""" messages: list[Message] = [ {"role": "user", "content": [{"text": "Count from 1 to 3."}]}, ] - + chunk_types = [] async for event in llamacpp_model.stream(messages): chunk_types.append(next(iter(event.keys()))) - + # Verify proper chunk sequence assert chunk_types[0] == "messageStart" assert chunk_types[1] == "contentBlockStart" @@ -99,24 +97,24 @@ async def test_streaming_chunks(self, llamacpp_model: LlamaCppModel) -> None: assert chunk_types[-3] == "contentBlockStop" assert chunk_types[-2] == "messageStop" assert chunk_types[-1] == "metadata" - + @pytest.mark.asyncio async def test_temperature_parameter(self, llamacpp_model: LlamaCppModel) -> None: """Test temperature parameter affects randomness.""" messages: list[Message] = [ {"role": "user", "content": [{"text": "Generate a random word."}]}, ] - + # Low temperature should give more consistent results llamacpp_model.update_config(params={"temperature": 0.1, "seed": 42}) - + response1 = "" async for event in llamacpp_model.stream(messages): if "contentBlockDelta" in event: delta = event["contentBlockDelta"]["delta"] if "text" in delta: response1 += delta["text"] - + # Same seed and low temperature should give similar result response2 = "" async for event in llamacpp_model.stream(messages): @@ -124,21 +122,21 @@ async def test_temperature_parameter(self, llamacpp_model: LlamaCppModel) -> Non delta = event["contentBlockDelta"]["delta"] if "text" in delta: response2 += delta["text"] - + # With low temperature and same seed, responses should be very similar assert len(response1) > 0 assert len(response2) > 0 - + @pytest.mark.asyncio async def test_max_tokens_limit(self, llamacpp_model: LlamaCppModel) -> None: """Test max_tokens parameter limits response length.""" messages: list[Message] = [ {"role": "user", "content": [{"text": "Tell me a very long story about dragons."}]}, ] - + # Set very low token limit llamacpp_model.update_config(params={"max_tokens": 10}) - + token_count = 0 async for event in llamacpp_model.stream(messages): if "metadata" in event: @@ -146,11 +144,11 @@ async def test_max_tokens_limit(self, llamacpp_model: LlamaCppModel) -> None: token_count = usage["outputTokens"] if "messageStop" in event: stop_reason = event["messageStop"]["stopReason"] - + # Should stop due to max_tokens assert token_count <= 15 # Allow small overage due to tokenization assert stop_reason == "max_tokens" - + @pytest.mark.asyncio async def 
test_structured_output(self, llamacpp_model: LlamaCppModel) -> None: """Test structured output generation.""" @@ -165,28 +163,28 @@ async def test_structured_output(self, llamacpp_model: LlamaCppModel) -> None: ], }, ] - + # Enable JSON response format for structured output llamacpp_model.update_config(params={"response_format": {"type": "json_object"}}) - + result = None async for event in llamacpp_model.structured_output(WeatherOutput, messages): if "output" in event: result = event["output"] - + assert result is not None assert isinstance(result, WeatherOutput) assert isinstance(result.temperature, float) assert isinstance(result.condition, str) assert result.location.lower() == "paris" - + @pytest.mark.asyncio async def test_llamacpp_specific_params(self, llamacpp_model: LlamaCppModel) -> None: """Test llama.cpp specific parameters.""" messages: list[Message] = [ {"role": "user", "content": [{"text": "Say 'test' five times."}]}, ] - + # Use llama.cpp specific parameters llamacpp_model.update_config( params={ @@ -195,55 +193,55 @@ async def test_llamacpp_specific_params(self, llamacpp_model: LlamaCppModel) -> "min_p": 0.1, # Min-p sampling } ) - + response_text = "" async for event in llamacpp_model.stream(messages): if "contentBlockDelta" in event: delta = event["contentBlockDelta"]["delta"] if "text" in delta: response_text += delta["text"] - + # Response should contain "test" but with repetition penalty it might vary assert "test" in response_text.lower() - + @pytest.mark.asyncio async def test_advanced_sampling_params(self, llamacpp_model: LlamaCppModel) -> None: """Test advanced sampling parameters.""" messages: list[Message] = [ {"role": "user", "content": [{"text": "Generate a random sentence about space."}]}, ] - + # Test advanced sampling parameters llamacpp_model.update_config( params={ "temperature": 0.8, "tfs_z": 0.95, # Tail-free sampling - "top_a": 0.1, # Top-a sampling + "top_a": 0.1, # Top-a sampling "typical_p": 0.9, # Typical-p sampling "penalty_last_n": 64, # Penalty context window "min_keep": 1, # Minimum tokens to keep - "samplers": ["top_k", "tfs_z", "typical_p", "top_p", "min_p", "temperature"] + "samplers": ["top_k", "tfs_z", "typical_p", "top_p", "min_p", "temperature"], } ) - + response_text = "" async for event in llamacpp_model.stream(messages): if "contentBlockDelta" in event: delta = event["contentBlockDelta"]["delta"] if "text" in delta: response_text += delta["text"] - + # Should generate something about space assert len(response_text) > 0 assert any(word in response_text.lower() for word in ["space", "star", "planet", "galaxy", "universe"]) - - @pytest.mark.asyncio + + @pytest.mark.asyncio async def test_mirostat_sampling(self, llamacpp_model: LlamaCppModel) -> None: """Test Mirostat sampling modes.""" messages: list[Message] = [ {"role": "user", "content": [{"text": "Write a short poem."}]}, ] - + # Test Mirostat v2 llamacpp_model.update_config( params={ @@ -253,105 +251,106 @@ async def test_mirostat_sampling(self, llamacpp_model: LlamaCppModel) -> None: "seed": 42, # For reproducibility } ) - + response_text = "" async for event in llamacpp_model.stream(messages): if "contentBlockDelta" in event: delta = event["contentBlockDelta"]["delta"] if "text" in delta: response_text += delta["text"] - + # Should generate a poem assert len(response_text) > 20 assert "\n" in response_text # Poems typically have line breaks - + @pytest.mark.asyncio async def test_grammar_constraint(self, llamacpp_model: LlamaCppModel) -> None: """Test grammar constraint feature 
(llama.cpp specific).""" messages: list[Message] = [ {"role": "user", "content": [{"text": "Is the sky blue? Answer yes or no."}]}, ] - + # Use the new grammar constraint method grammar = """ root ::= answer answer ::= "yes" | "no" """ llamacpp_model.use_grammar_constraint(grammar) - + response_text = "" async for event in llamacpp_model.stream(messages): if "contentBlockDelta" in event: delta = event["contentBlockDelta"]["delta"] if "text" in delta: response_text += delta["text"] - + # Response should be exactly "yes" or "no" assert response_text.strip().lower() in ["yes", "no"] - + @pytest.mark.asyncio async def test_json_schema_constraint(self, llamacpp_model: LlamaCppModel) -> None: """Test JSON schema constraint feature.""" messages: list[Message] = [ - {"role": "user", "content": [{"text": "Describe the weather in JSON format with temperature and description."}]}, + { + "role": "user", + "content": [{"text": "Describe the weather in JSON format with temperature and description."}], + }, ] - + # Use JSON schema constraint schema = { "type": "object", - "properties": { - "temperature": {"type": "number"}, - "description": {"type": "string"} - }, - "required": ["temperature", "description"] + "properties": {"temperature": {"type": "number"}, "description": {"type": "string"}}, + "required": ["temperature", "description"], } llamacpp_model.use_json_schema(schema) - + response_text = "" async for event in llamacpp_model.stream(messages): if "contentBlockDelta" in event: delta = event["contentBlockDelta"]["delta"] if "text" in delta: response_text += delta["text"] - + # Should be valid JSON matching the schema import json + data = json.loads(response_text.strip()) assert "temperature" in data assert "description" in data assert isinstance(data["temperature"], (int, float)) assert isinstance(data["description"], str) - + @pytest.mark.asyncio async def test_logit_bias(self, llamacpp_model: LlamaCppModel) -> None: """Test logit bias feature.""" messages: list[Message] = [ {"role": "user", "content": [{"text": "Choose between 'cat' and 'dog'."}]}, ] - + # This is a simplified test - in reality you'd need to know the actual token IDs # for "cat" and "dog" in the model's vocabulary llamacpp_model.update_config( params={ "logit_bias": { # These are placeholder token IDs - real implementation would need actual token IDs - 1234: 10.0, # Strong positive bias (hypothetical "cat" token) + 1234: 10.0, # Strong positive bias (hypothetical "cat" token) 5678: -10.0, # Strong negative bias (hypothetical "dog" token) }, "seed": 42, # For reproducibility } ) - + response_text = "" async for event in llamacpp_model.stream(messages): if "contentBlockDelta" in event: delta = event["contentBlockDelta"]["delta"] if "text" in delta: response_text += delta["text"] - + # Should generate text (exact behavior depends on actual token IDs) assert len(response_text) > 0 - + @pytest.mark.asyncio async def test_cache_prompt(self, llamacpp_model: LlamaCppModel) -> None: """Test prompt caching feature.""" @@ -359,7 +358,7 @@ async def test_cache_prompt(self, llamacpp_model: LlamaCppModel) -> None: {"role": "system", "content": [{"text": "You are a helpful assistant. 
Always be concise."}]}, {"role": "user", "content": [{"text": "What is 2+2?"}]}, ] - + # Enable prompt caching llamacpp_model.update_config( params={ @@ -367,7 +366,7 @@ async def test_cache_prompt(self, llamacpp_model: LlamaCppModel) -> None: "slot_id": 0, # Use specific slot for caching } ) - + # First request response1 = "" async for event in llamacpp_model.stream(messages): @@ -375,34 +374,34 @@ async def test_cache_prompt(self, llamacpp_model: LlamaCppModel) -> None: delta = event["contentBlockDelta"]["delta"] if "text" in delta: response1 += delta["text"] - + # Second request with same system prompt should use cache messages2 = [ {"role": "system", "content": [{"text": "You are a helpful assistant. Always be concise."}]}, {"role": "user", "content": [{"text": "What is 3+3?"}]}, ] - + response2 = "" async for event in llamacpp_model.stream(messages2): if "contentBlockDelta" in event: delta = event["contentBlockDelta"]["delta"] if "text" in delta: response2 += delta["text"] - + # Both should give valid responses assert "4" in response1 assert "6" in response2 - + @pytest.mark.asyncio async def test_concurrent_requests(self, llamacpp_model: LlamaCppModel) -> None: """Test handling multiple concurrent requests.""" import asyncio - + async def make_request(prompt: str) -> str: messages: list[Message] = [ {"role": "user", "content": [{"text": prompt}]}, ] - + response = "" async for event in llamacpp_model.stream(messages): if "contentBlockDelta" in event: @@ -410,31 +409,31 @@ async def make_request(prompt: str) -> str: if "text" in delta: response += delta["text"] return response - + # Make concurrent requests prompts = [ "Say 'one'", - "Say 'two'", + "Say 'two'", "Say 'three'", ] - + responses = await asyncio.gather(*[make_request(p) for p in prompts]) - + # Each response should contain the expected number assert "one" in responses[0].lower() assert "two" in responses[1].lower() assert "three" in responses[2].lower() - + @pytest.mark.asyncio async def test_enhanced_structured_output(self, llamacpp_model: LlamaCppModel) -> None: """Test enhanced structured output with native JSON schema support.""" - + class BookInfo(BaseModel): title: str author: str year: int genres: list[str] - + messages: list[Message] = [ { "role": "user", @@ -446,14 +445,14 @@ class BookInfo(BaseModel): ], }, ] - + result = None events = [] async for event in llamacpp_model.structured_output(BookInfo, messages): events.append(event) if "output" in event: result = event["output"] - + # Verify we got structured output assert result is not None assert isinstance(result, BookInfo) @@ -462,10 +461,10 @@ class BookInfo(BaseModel): assert isinstance(result.year, int) and 1900 <= result.year <= 2100 assert isinstance(result.genres, list) and len(result.genres) >= 2 assert all(isinstance(genre, str) for genre in result.genres) - + # Should have streamed events before the output assert len(events) > 1 - + @pytest.mark.asyncio async def test_context_overflow_handling(self, llamacpp_model: LlamaCppModel) -> None: """Test proper handling of context window overflow.""" @@ -474,7 +473,7 @@ async def test_context_overflow_handling(self, llamacpp_model: LlamaCppModel) -> messages: list[Message] = [ {"role": "user", "content": [{"text": f"Summarize this text: {long_text}"}]}, ] - + try: response_text = "" async for event in llamacpp_model.stream(messages): @@ -482,14 +481,15 @@ async def test_context_overflow_handling(self, llamacpp_model: LlamaCppModel) -> delta = event["contentBlockDelta"]["delta"] if "text" in delta: 
response_text += delta["text"] - + # If it succeeds, we got a response assert len(response_text) > 0 except Exception as e: # If it fails, it should be our custom error from strands.models.llamacpp import LlamaCppContextOverflowError + if isinstance(e, LlamaCppContextOverflowError): assert "context" in str(e).lower() else: # Some other error - re-raise to see what it was - raise \ No newline at end of file + raise From abbc4605b1f8abe2aad3bf923b3fa2b2810af0ad Mon Sep 17 00:00:00 2001 From: Aaron Brown Date: Thu, 31 Jul 2025 12:28:01 -0500 Subject: [PATCH 04/11] fixed tool calling bug; --- src/strands/models/llamacpp.py | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/src/strands/models/llamacpp.py b/src/strands/models/llamacpp.py index 3863bc5f4..149d7381d 100644 --- a/src/strands/models/llamacpp.py +++ b/src/strands/models/llamacpp.py @@ -13,7 +13,17 @@ import json import logging import mimetypes -from typing import Any, AsyncGenerator, Dict, Optional, Type, TypedDict, TypeVar, Union, cast +from typing import ( + Any, + AsyncGenerator, + Dict, + Optional, + Type, + TypedDict, + TypeVar, + Union, + cast, +) import httpx from pydantic import BaseModel @@ -385,7 +395,10 @@ def _format_messages(self, messages: Messages, system_prompt: Optional[str] = No ] formatted_tool_messages = [ self._format_tool_message( - {"toolUseId": content["toolResult"]["toolUseId"], "content": content["toolResult"]["content"]} + { + "toolUseId": content["toolResult"]["toolUseId"], + "content": content["toolResult"]["content"], + } ) for content in contents if "toolResult" in content @@ -605,6 +618,7 @@ async def stream( tool_calls: Dict[int, list] = {} usage_data = None + finish_reason = None async for line in response.aiter_lines(): if not line.strip() or not line.startswith("data: "): @@ -650,6 +664,7 @@ async def stream( # Check for finish reason if choice.get("finish_reason"): + finish_reason = choice.get("finish_reason") break yield self._format_chunk({"chunk_type": "content_stop"}) @@ -702,7 +717,12 @@ async def stream( yield self._format_chunk({"chunk_type": "content_stop"}) # Send stop reason - stop_reason = "tool_use" if tool_calls else getattr(choice, "finish_reason", "end_turn") + logger.debug("finish_reason=%s, tool_calls=%s", finish_reason, bool(tool_calls)) + if finish_reason == "tool_calls" or tool_calls: + stop_reason = "tool_calls" # Changed from "tool_use" to match format_chunk expectations + else: + stop_reason = finish_reason or "end_turn" + logger.debug("stop_reason=%s", stop_reason) yield self._format_chunk({"chunk_type": "message_stop", "data": stop_reason}) # Send usage metadata if available From 0fe024909ae5369c9f927e75192db032e066f27c Mon Sep 17 00:00:00 2001 From: Aaron Brown Date: Mon, 4 Aug 2025 10:36:05 -0500 Subject: [PATCH 05/11] fix to bug 565; --- src/strands/tools/decorator.py | 144 ++++++++++++++++++++++++++++++--- 1 file changed, 134 insertions(+), 10 deletions(-) diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index 5ec324b68..ffac02ead 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -66,6 +66,98 @@ def my_tool(param1: str, param2: int = 42) -> dict: logger = logging.getLogger(__name__) +def _resolve_json_schema_references(schema: dict[str, Any]) -> dict[str, Any]: + """Resolve all $ref references in a JSON schema by inlining definitions. + + Some model providers (e.g., Bedrock via LiteLLM) don't support JSON Schema + $ref references. 
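For concreteness, a minimal sketch of the kind of $defs/$ref schema this function targets, using hypothetical Color/Paint models; the exact output of Pydantic v2's model_json_schema() varies slightly across releases, so the commented result is approximate:

    from enum import Enum

    from pydantic import BaseModel

    class Color(str, Enum):
        RED = "red"
        BLUE = "blue"

    class Paint(BaseModel):
        color: Color

    # Paint.model_json_schema() returns roughly:
    # {
    #     "$defs": {"Color": {"enum": ["red", "blue"], "title": "Color", "type": "string"}},
    #     "properties": {"color": {"$ref": "#/$defs/Color"}},
    #     "required": ["color"],
    #     "title": "Paint",
    #     "type": "object",
    # }
    # i.e. the enum lives under $defs and the property only carries a $ref,
    # which is the shape that gets flattened before the spec reaches providers
    # that reject $ref.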
This function flattens the schema by replacing all $ref + occurrences with their actual definitions from the $defs section. + + This is particularly important for Pydantic-generated schemas that use $defs + for enum types, as these would otherwise cause validation errors with certain + model providers. + + Args: + schema: A JSON schema dict that may contain $ref references and a $defs section + + Returns: + A new schema dict with all $ref references replaced by their definitions. + The $defs section is removed from the result. + + Example: + Input schema with $ref: + { + "$defs": {"Color": {"type": "string", "enum": ["red", "blue"]}}, + "properties": {"color": {"$ref": "#/$defs/Color"}} + } + + Output schema with resolved reference: + { + "properties": {"color": {"type": "string", "enum": ["red", "blue"]}} + } + """ + # Get definitions if they exist + defs = schema.get("$defs", {}) + if not defs: + return schema + + def resolve_node(node: Any) -> Any: + """Recursively process a schema node, replacing any $ref with actual definitions. + + Args: + node: Any value from the schema (dict, list, or primitive) + + Returns: + The node with all $ref references resolved + """ + if not isinstance(node, dict): + return node + + # If this node is a $ref, replace it with the referenced definition + if "$ref" in node: + # Extract the definition name from the reference (e.g., "#/$defs/Color" -> "Color") + ref_name = node["$ref"].split("/")[-1] + if ref_name in defs: + # Copy the referenced definition to avoid modifying the original + resolved = defs[ref_name].copy() + # Preserve any additional properties from the $ref node (e.g., "default", "description") + for key, value in node.items(): + if key != "$ref": + resolved[key] = value + # Recursively resolve in case the definition itself contains references + return resolve_node(resolved) + # If reference not found, return as-is (shouldn't happen with valid schemas) + return node + + # For dict nodes, recursively process all values + result: dict[str, Any] = {} + for key, value in node.items(): + if isinstance(value, list): + # For arrays, resolve each item + result[key] = [resolve_node(item) for item in value] + elif isinstance(value, dict): + # For objects, check if this is a properties dict that needs special handling + if key == "properties" and isinstance(value, dict): + # Ensure all property definitions are fully resolved + result[key] = { + prop_name: resolve_node(prop_schema) + for prop_name, prop_schema in value.items() + } + else: + result[key] = resolve_node(value) + else: + # Primitive values are copied as-is + result[key] = value + return result + + # Process the entire schema, excluding the $defs section from the result + result = { + key: resolve_node(value) for key, value in schema.items() if key != "$defs" + } + + return result + + # Type for wrapped function T = TypeVar("T", bound=Callable[..., Any]) @@ -101,7 +193,8 @@ def __init__(self, func: Callable[..., Any]) -> None: # Get parameter descriptions from parsed docstring self.param_descriptions = { - param.arg_name: param.description or f"Parameter {param.arg_name}" for param in self.doc.params + param.arg_name: param.description or f"Parameter {param.arg_name}" + for param in self.doc.params } # Create a Pydantic model for validation @@ -131,7 +224,10 @@ def _create_input_model(self) -> Type[BaseModel]: description = self.param_descriptions.get(name, f"Parameter {name}") # Create Field with description and default - field_definitions[name] = (param_type, Field(default=default, 
description=description)) + field_definitions[name] = ( + param_type, + Field(default=default, description=description), + ) # Create model name based on function name model_name = f"{self.func.__name__.capitalize()}Tool" @@ -173,8 +269,17 @@ def extract_metadata(self) -> ToolSpec: # Clean up Pydantic-specific schema elements self._clean_pydantic_schema(input_schema) + # Flatten schema by resolving $ref references to their definitions + # This is required for compatibility with model providers that don't support + # JSON Schema $ref (e.g., Bedrock/Anthropic via LiteLLM) + input_schema = _resolve_json_schema_references(input_schema) + # Create tool specification - tool_spec: ToolSpec = {"name": func_name, "description": description, "inputSchema": {"json": input_schema}} + tool_spec: ToolSpec = { + "name": func_name, + "description": description, + "inputSchema": {"json": input_schema}, + } return tool_spec @@ -206,7 +311,9 @@ def _clean_pydantic_schema(self, schema: dict[str, Any]) -> None: if "anyOf" in prop_schema: any_of = prop_schema["anyOf"] # Handle Optional[Type] case (represented as anyOf[Type, null]) - if len(any_of) == 2 and any(item.get("type") == "null" for item in any_of): + if len(any_of) == 2 and any( + item.get("type") == "null" for item in any_of + ): # Find the non-null type for item in any_of: if item.get("type") != "null": @@ -250,7 +357,9 @@ def validate_input(self, input_data: dict[str, Any]) -> dict[str, Any]: except Exception as e: # Re-raise with more detailed error message error_msg = str(e) - raise ValueError(f"Validation failed for input parameters: {error_msg}") from e + raise ValueError( + f"Validation failed for input parameters: {error_msg}" + ) from e P = ParamSpec("P") # Captures all parameters @@ -296,7 +405,9 @@ def __init__( functools.update_wrapper(wrapper=self, wrapped=self._tool_func) - def __get__(self, instance: Any, obj_type: Optional[Type] = None) -> "DecoratedFunctionTool[P, R]": + def __get__( + self, instance: Any, obj_type: Optional[Type] = None + ) -> "DecoratedFunctionTool[P, R]": """Descriptor protocol implementation for proper method binding. This method enables the decorated function to work correctly when used as a class method. @@ -325,7 +436,9 @@ def my_tool(): if instance is not None and not inspect.ismethod(self._tool_func): # Create a bound method tool_func = self._tool_func.__get__(instance, instance.__class__) - return DecoratedFunctionTool(self._tool_name, self._tool_spec, tool_func, self._metadata) + return DecoratedFunctionTool( + self._tool_name, self._tool_spec, tool_func, self._metadata + ) return self @@ -372,7 +485,9 @@ def tool_type(self) -> str: return "function" @override - async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any) -> ToolGenerator: + async def stream( + self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any + ) -> ToolGenerator: """Stream the tool with a tool use specification. This method handles tool use streams from a Strands Agent. 
It validates the input, @@ -403,7 +518,10 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw validated_input = self._metadata.validate_input(tool_input) # Pass along the agent if provided and expected by the function - if "agent" in invocation_state and "agent" in self._metadata.signature.parameters: + if ( + "agent" in invocation_state + and "agent" in self._metadata.signature.parameters + ): validated_input["agent"] = invocation_state.get("agent") # "Too few arguments" expected, hence the type ignore @@ -468,6 +586,8 @@ def get_display_properties(self) -> dict[str, str]: # Handle @decorator @overload def tool(__func: Callable[P, R]) -> DecoratedFunctionTool[P, R]: ... + + # Handle @decorator() @overload def tool( @@ -475,6 +595,8 @@ def tool( inputSchema: Optional[JSONSchema] = None, name: Optional[str] = None, ) -> Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]: ... + + # Suppressing the type error because we want callers to be able to use both `tool` and `tool()` at the # call site, but the actual implementation handles that and it's not representable via the type-system def tool( # type: ignore @@ -482,7 +604,9 @@ def tool( # type: ignore description: Optional[str] = None, inputSchema: Optional[JSONSchema] = None, name: Optional[str] = None, -) -> Union[DecoratedFunctionTool[P, R], Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]]: +) -> Union[ + DecoratedFunctionTool[P, R], Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]] +]: """Decorator that transforms a Python function into a Strands tool. This decorator seamlessly enables a function to be called both as a regular Python function and as a Strands tool. From aed40e2dcd8ef2ed85df6e8f9a650df02393a64c Mon Sep 17 00:00:00 2001 From: Aaron Brown Date: Mon, 4 Aug 2025 10:42:15 -0500 Subject: [PATCH 06/11] Revert "fix to bug 565;" This reverts commit 0fe024909ae5369c9f927e75192db032e066f27c. --- src/strands/tools/decorator.py | 144 +++------------------------------ 1 file changed, 10 insertions(+), 134 deletions(-) diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index ffac02ead..5ec324b68 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -66,98 +66,6 @@ def my_tool(param1: str, param2: int = 42) -> dict: logger = logging.getLogger(__name__) -def _resolve_json_schema_references(schema: dict[str, Any]) -> dict[str, Any]: - """Resolve all $ref references in a JSON schema by inlining definitions. - - Some model providers (e.g., Bedrock via LiteLLM) don't support JSON Schema - $ref references. This function flattens the schema by replacing all $ref - occurrences with their actual definitions from the $defs section. - - This is particularly important for Pydantic-generated schemas that use $defs - for enum types, as these would otherwise cause validation errors with certain - model providers. - - Args: - schema: A JSON schema dict that may contain $ref references and a $defs section - - Returns: - A new schema dict with all $ref references replaced by their definitions. - The $defs section is removed from the result. 
- - Example: - Input schema with $ref: - { - "$defs": {"Color": {"type": "string", "enum": ["red", "blue"]}}, - "properties": {"color": {"$ref": "#/$defs/Color"}} - } - - Output schema with resolved reference: - { - "properties": {"color": {"type": "string", "enum": ["red", "blue"]}} - } - """ - # Get definitions if they exist - defs = schema.get("$defs", {}) - if not defs: - return schema - - def resolve_node(node: Any) -> Any: - """Recursively process a schema node, replacing any $ref with actual definitions. - - Args: - node: Any value from the schema (dict, list, or primitive) - - Returns: - The node with all $ref references resolved - """ - if not isinstance(node, dict): - return node - - # If this node is a $ref, replace it with the referenced definition - if "$ref" in node: - # Extract the definition name from the reference (e.g., "#/$defs/Color" -> "Color") - ref_name = node["$ref"].split("/")[-1] - if ref_name in defs: - # Copy the referenced definition to avoid modifying the original - resolved = defs[ref_name].copy() - # Preserve any additional properties from the $ref node (e.g., "default", "description") - for key, value in node.items(): - if key != "$ref": - resolved[key] = value - # Recursively resolve in case the definition itself contains references - return resolve_node(resolved) - # If reference not found, return as-is (shouldn't happen with valid schemas) - return node - - # For dict nodes, recursively process all values - result: dict[str, Any] = {} - for key, value in node.items(): - if isinstance(value, list): - # For arrays, resolve each item - result[key] = [resolve_node(item) for item in value] - elif isinstance(value, dict): - # For objects, check if this is a properties dict that needs special handling - if key == "properties" and isinstance(value, dict): - # Ensure all property definitions are fully resolved - result[key] = { - prop_name: resolve_node(prop_schema) - for prop_name, prop_schema in value.items() - } - else: - result[key] = resolve_node(value) - else: - # Primitive values are copied as-is - result[key] = value - return result - - # Process the entire schema, excluding the $defs section from the result - result = { - key: resolve_node(value) for key, value in schema.items() if key != "$defs" - } - - return result - - # Type for wrapped function T = TypeVar("T", bound=Callable[..., Any]) @@ -193,8 +101,7 @@ def __init__(self, func: Callable[..., Any]) -> None: # Get parameter descriptions from parsed docstring self.param_descriptions = { - param.arg_name: param.description or f"Parameter {param.arg_name}" - for param in self.doc.params + param.arg_name: param.description or f"Parameter {param.arg_name}" for param in self.doc.params } # Create a Pydantic model for validation @@ -224,10 +131,7 @@ def _create_input_model(self) -> Type[BaseModel]: description = self.param_descriptions.get(name, f"Parameter {name}") # Create Field with description and default - field_definitions[name] = ( - param_type, - Field(default=default, description=description), - ) + field_definitions[name] = (param_type, Field(default=default, description=description)) # Create model name based on function name model_name = f"{self.func.__name__.capitalize()}Tool" @@ -269,17 +173,8 @@ def extract_metadata(self) -> ToolSpec: # Clean up Pydantic-specific schema elements self._clean_pydantic_schema(input_schema) - # Flatten schema by resolving $ref references to their definitions - # This is required for compatibility with model providers that don't support - # JSON Schema $ref 
(e.g., Bedrock/Anthropic via LiteLLM) - input_schema = _resolve_json_schema_references(input_schema) - # Create tool specification - tool_spec: ToolSpec = { - "name": func_name, - "description": description, - "inputSchema": {"json": input_schema}, - } + tool_spec: ToolSpec = {"name": func_name, "description": description, "inputSchema": {"json": input_schema}} return tool_spec @@ -311,9 +206,7 @@ def _clean_pydantic_schema(self, schema: dict[str, Any]) -> None: if "anyOf" in prop_schema: any_of = prop_schema["anyOf"] # Handle Optional[Type] case (represented as anyOf[Type, null]) - if len(any_of) == 2 and any( - item.get("type") == "null" for item in any_of - ): + if len(any_of) == 2 and any(item.get("type") == "null" for item in any_of): # Find the non-null type for item in any_of: if item.get("type") != "null": @@ -357,9 +250,7 @@ def validate_input(self, input_data: dict[str, Any]) -> dict[str, Any]: except Exception as e: # Re-raise with more detailed error message error_msg = str(e) - raise ValueError( - f"Validation failed for input parameters: {error_msg}" - ) from e + raise ValueError(f"Validation failed for input parameters: {error_msg}") from e P = ParamSpec("P") # Captures all parameters @@ -405,9 +296,7 @@ def __init__( functools.update_wrapper(wrapper=self, wrapped=self._tool_func) - def __get__( - self, instance: Any, obj_type: Optional[Type] = None - ) -> "DecoratedFunctionTool[P, R]": + def __get__(self, instance: Any, obj_type: Optional[Type] = None) -> "DecoratedFunctionTool[P, R]": """Descriptor protocol implementation for proper method binding. This method enables the decorated function to work correctly when used as a class method. @@ -436,9 +325,7 @@ def my_tool(): if instance is not None and not inspect.ismethod(self._tool_func): # Create a bound method tool_func = self._tool_func.__get__(instance, instance.__class__) - return DecoratedFunctionTool( - self._tool_name, self._tool_spec, tool_func, self._metadata - ) + return DecoratedFunctionTool(self._tool_name, self._tool_spec, tool_func, self._metadata) return self @@ -485,9 +372,7 @@ def tool_type(self) -> str: return "function" @override - async def stream( - self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any - ) -> ToolGenerator: + async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any) -> ToolGenerator: """Stream the tool with a tool use specification. This method handles tool use streams from a Strands Agent. It validates the input, @@ -518,10 +403,7 @@ async def stream( validated_input = self._metadata.validate_input(tool_input) # Pass along the agent if provided and expected by the function - if ( - "agent" in invocation_state - and "agent" in self._metadata.signature.parameters - ): + if "agent" in invocation_state and "agent" in self._metadata.signature.parameters: validated_input["agent"] = invocation_state.get("agent") # "Too few arguments" expected, hence the type ignore @@ -586,8 +468,6 @@ def get_display_properties(self) -> dict[str, str]: # Handle @decorator @overload def tool(__func: Callable[P, R]) -> DecoratedFunctionTool[P, R]: ... - - # Handle @decorator() @overload def tool( @@ -595,8 +475,6 @@ def tool( inputSchema: Optional[JSONSchema] = None, name: Optional[str] = None, ) -> Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]: ... 
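For reference, the two overloads above correspond to the two call styles below; a minimal sketch, assuming the import path mirrors this file's location (strands.tools.decorator) rather than any top-level re-export:

    from strands.tools.decorator import tool

    @tool
    def add(a: int, b: int) -> int:
        """Add two integers.

        Args:
            a: First addend.
            b: Second addend.
        """
        return a + b

    @tool(name="multiply", description="Multiply two integers")
    def mul(a: int, b: int) -> int:
        return a * b

    # Per the decorator docstring below, both remain callable as plain Python functions:
    assert add(2, 3) == 5
    assert mul(4, 5) == 20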
- - # Suppressing the type error because we want callers to be able to use both `tool` and `tool()` at the # call site, but the actual implementation handles that and it's not representable via the type-system def tool( # type: ignore @@ -604,9 +482,7 @@ def tool( # type: ignore description: Optional[str] = None, inputSchema: Optional[JSONSchema] = None, name: Optional[str] = None, -) -> Union[ - DecoratedFunctionTool[P, R], Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]] -]: +) -> Union[DecoratedFunctionTool[P, R], Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]]: """Decorator that transforms a Python function into a Strands tool. This decorator seamlessly enables a function to be called both as a regular Python function and as a Strands tool. From 9cf8e25270fbfb1de2514a2cf8c9dec252d5b1ec Mon Sep 17 00:00:00 2001 From: Aaron Brown Date: Mon, 18 Aug 2025 09:43:27 -0500 Subject: [PATCH 07/11] Removed public grammar/schema methods, now use standard params pattern; --- src/strands/models/llamacpp.py | 90 +++-------------------- tests/strands/models/test_llamacpp.py | 30 +++----- tests_integ/models/test_model_llamacpp.py | 12 +-- 3 files changed, 29 insertions(+), 103 deletions(-) diff --git a/src/strands/models/llamacpp.py b/src/strands/models/llamacpp.py index 149d7381d..d673f0f44 100644 --- a/src/strands/models/llamacpp.py +++ b/src/strands/models/llamacpp.py @@ -13,6 +13,7 @@ import json import logging import mimetypes +import time from typing import ( Any, AsyncGenerator, @@ -40,25 +41,6 @@ T = TypeVar("T", bound=BaseModel) -class LlamaCppError(Exception): - """Base exception for llama.cpp specific errors. - - This exception serves as the base class for all llama.cpp-specific errors, - allowing for targeted error handling in client code. - """ - - -class LlamaCppContextOverflowError(LlamaCppError, ContextWindowOverflowException): - """Raised when context window is exceeded in llama.cpp. - - This error occurs when the combined input and output tokens exceed - the model's context window size. Common causes include: - - Long input prompts - - Extended conversations - - Large system prompts or tool definitions - """ - - class LlamaCppModel(Model): """llama.cpp model provider implementation. @@ -213,62 +195,7 @@ def get_config(self) -> LlamaCppConfig: """ return self.config # type: ignore[return-value] - def use_grammar_constraint(self, grammar: str) -> None: - r"""Apply a GBNF grammar constraint to the generation. - - Args: - grammar: GBNF (Backus-Naur Form) grammar string defining allowed outputs. - See https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md - - Example: - >>> # Constrain output to yes/no answers - >>> model.use_grammar_constraint(''' - ... root ::= answer - ... answer ::= "yes" | "no" - ... ''') - - >>> # JSON object grammar - >>> model.use_grammar_constraint(''' - ... root ::= object - ... object ::= "{" pair ("," pair)* "}" - ... pair ::= string ":" value - ... string ::= "\"" [^"]* "\"" - ... value ::= string | number | "true" | "false" | "null" - ... number ::= "-"? [0-9]+ ("." [0-9]+)? - ... ''') - """ - params = self.config.get("params", {}) - if not isinstance(params, dict): - params = {} - params["grammar"] = grammar - self.config["params"] = params - logger.debug( - "grammar=<%s> | applied grammar constraint", - grammar[:50] + "..." if len(grammar) > 50 else grammar, - ) - def use_json_schema(self, schema: dict[str, Any]) -> None: - """Apply a JSON schema constraint for structured output. 
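As a usage sketch of the replacement pattern this commit introduces, with constraints supplied through params instead of the removed helper methods (mirroring the updated unit and integration tests):

    from strands.models.llamacpp import LlamaCppModel

    # GBNF grammar constraint, now passed via params at construction time.
    model = LlamaCppModel(params={"grammar": 'root ::= "yes" | "no"'})

    # JSON schema constraint, likewise via params (here set through update_config).
    model.update_config(
        params={
            "json_schema": {
                "type": "object",
                "properties": {"name": {"type": "string"}},
                "required": ["name"],
            }
        }
    )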
- - Args: - schema: JSON schema dictionary defining the expected output structure. - - Example: - >>> model.use_json_schema({ - ... "type": "object", - ... "properties": { - ... "name": {"type": "string"}, - ... "age": {"type": "integer", "minimum": 0} - ... }, - ... "required": ["name", "age"] - ... }) - """ - params = self.config.get("params", {}) - if not isinstance(params, dict): - params = {} - params["json_schema"] = schema - self.config["params"] = params - logger.debug("schema=<%s> | applied JSON schema constraint", schema) def _format_message_content(self, content: Union[ContentBlock, Dict[str, Any]]) -> dict[str, Any]: """Format a content block for llama.cpp. @@ -459,7 +386,8 @@ def _format_request( request["json_schema"] = params["json_schema"] # llama.cpp-specific parameters that must be passed via extra_body - # NOTE: grammar and json_schema removed from this set + # NOTE: grammar and json_schema are NOT in this set because llama.cpp server + # expects them directly in the request body for proper constraint application llamacpp_specific_params = { "repeat_penalty", "top_k", @@ -572,7 +500,7 @@ def _format_chunk(self, event: dict[str, Any]) -> StreamEvent: "totalTokens": event["data"].total_tokens, }, "metrics": { - "latencyMs": 0, # TODO: Add actual latency calculation + "latencyMs": event.get("latency_ms", 0), }, }, } @@ -600,9 +528,12 @@ async def stream( Formatted message chunks from the model. Raises: - LlamaCppContextOverflowError: When the context window is exceeded. + ContextWindowOverflowException: When the context window is exceeded. ModelThrottledException: When the llama.cpp server is overloaded. """ + # Track request start time for latency calculation + start_time = time.perf_counter() + try: logger.debug("formatting request for llama.cpp server") request = self._format_request(messages, tool_specs, system_prompt) @@ -727,6 +658,8 @@ async def stream( # Send usage metadata if available if usage_data: + # Calculate latency + latency_ms = int((time.perf_counter() - start_time) * 1000) yield self._format_chunk( { "chunk_type": "metadata", @@ -739,6 +672,7 @@ async def stream( "total_tokens": usage_data.get("total_tokens", 0), }, )(), + "latency_ms": latency_ms, } ) @@ -755,7 +689,7 @@ async def stream( # Check for context overflow by looking for specific error indicators if any(term in error_msg.lower() for term in ["context", "kv cache", "slot"]): - raise LlamaCppContextOverflowError(f"Context window exceeded: {error_msg}") from e + raise ContextWindowOverflowException(f"Context window exceeded: {error_msg}") from e elif e.response.status_code == 503: raise ModelThrottledException("llama.cpp server is busy or overloaded") from e raise diff --git a/tests/strands/models/test_llamacpp.py b/tests/strands/models/test_llamacpp.py index f200385ee..6f0bcebe6 100644 --- a/tests/strands/models/test_llamacpp.py +++ b/tests/strands/models/test_llamacpp.py @@ -8,11 +8,9 @@ import pytest from pydantic import BaseModel -from strands.models.llamacpp import ( - LlamaCppContextOverflowError, - LlamaCppModel, -) +from strands.models.llamacpp import LlamaCppModel from strands.types.exceptions import ( + ContextWindowOverflowException, ModelThrottledException, ) @@ -320,36 +318,30 @@ def test_max_retries_configuration(self) -> None: model = LlamaCppModel() assert model.config["model_id"] == "default" - def test_use_grammar_constraint(self) -> None: - """Test grammar constraint method.""" - model = LlamaCppModel() - - # Apply grammar constraint + def 
test_grammar_constraint_via_params(self) -> None: + """Test grammar constraint via params.""" grammar = """ root ::= answer answer ::= "yes" | "no" """ - model.use_grammar_constraint(grammar) + model = LlamaCppModel(params={"grammar": grammar}) assert model.config["params"]["grammar"] == grammar - # Update grammar + # Update grammar via update_config new_grammar = "root ::= [0-9]+" - model.use_grammar_constraint(new_grammar) + model.update_config(params={"grammar": new_grammar}) assert model.config["params"]["grammar"] == new_grammar - def test_use_json_schema(self) -> None: - """Test JSON schema constraint method.""" - model = LlamaCppModel() - - # Apply JSON schema + def test_json_schema_via_params(self) -> None: + """Test JSON schema constraint via params.""" schema = { "type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}, "required": ["name", "age"], } - model.use_json_schema(schema) + model = LlamaCppModel(params={"json_schema": schema}) assert model.config["params"]["json_schema"] == schema @@ -370,7 +362,7 @@ async def test_stream_with_context_overflow_error(self) -> None: with patch.object(model.client, "post", side_effect=error): messages = [{"role": "user", "content": [{"text": "Very long message"}]}] - with pytest.raises(LlamaCppContextOverflowError) as exc_info: + with pytest.raises(ContextWindowOverflowException) as exc_info: async for _ in model.stream(messages): pass diff --git a/tests_integ/models/test_model_llamacpp.py b/tests_integ/models/test_model_llamacpp.py index e1b302504..131f0fb88 100644 --- a/tests_integ/models/test_model_llamacpp.py +++ b/tests_integ/models/test_model_llamacpp.py @@ -270,12 +270,12 @@ async def test_grammar_constraint(self, llamacpp_model: LlamaCppModel) -> None: {"role": "user", "content": [{"text": "Is the sky blue? 
Answer yes or no."}]}, ] - # Use the new grammar constraint method + # Set grammar constraint via params grammar = """ root ::= answer answer ::= "yes" | "no" """ - llamacpp_model.use_grammar_constraint(grammar) + llamacpp_model.update_config(params={"grammar": grammar}) response_text = "" async for event in llamacpp_model.stream(messages): @@ -297,13 +297,13 @@ async def test_json_schema_constraint(self, llamacpp_model: LlamaCppModel) -> No }, ] - # Use JSON schema constraint + # Set JSON schema constraint via params schema = { "type": "object", "properties": {"temperature": {"type": "number"}, "description": {"type": "string"}}, "required": ["temperature", "description"], } - llamacpp_model.use_json_schema(schema) + llamacpp_model.update_config(params={"json_schema": schema}) response_text = "" async for event in llamacpp_model.stream(messages): @@ -486,9 +486,9 @@ async def test_context_overflow_handling(self, llamacpp_model: LlamaCppModel) -> assert len(response_text) > 0 except Exception as e: # If it fails, it should be our custom error - from strands.models.llamacpp import LlamaCppContextOverflowError + from strands.types.exceptions import ContextWindowOverflowException - if isinstance(e, LlamaCppContextOverflowError): + if isinstance(e, ContextWindowOverflowException): assert "context" in str(e).lower() else: # Some other error - re-raise to see what it was From 02e555366986e2160d37f94d3dc172bc3b82b00f Mon Sep 17 00:00:00 2001 From: Aaron Brown Date: Mon, 18 Aug 2025 09:45:45 -0500 Subject: [PATCH 08/11] cleaned up README per feedback; --- README.md | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/README.md b/README.md index 143635475..44d10b67e 100644 --- a/README.md +++ b/README.md @@ -154,20 +154,13 @@ llama_model = LlamaAPIModel( ) agent = Agent(model=llama_model) response = agent("Tell me about Agentic AI") - -# llama.cpp -llamacpp_model = LlamaCppModel( - base_url="http://localhost:8080", - params={"temperature": 0.7, "max_tokens": 100} -) -agent = Agent(model=llamacpp_model) -response = agent("Tell me about Agentic AI") ``` Built-in providers: - [Amazon Bedrock](https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/) - [Anthropic](https://strandsagents.com/latest/user-guide/concepts/model-providers/anthropic/) - [LiteLLM](https://strandsagents.com/latest/user-guide/concepts/model-providers/litellm/) + - [llama.cpp](https://strandsagents.com/latest/user-guide/concepts/model-providers/llamacpp/) - [LlamaAPI](https://strandsagents.com/latest/user-guide/concepts/model-providers/llamaapi/) - [Ollama](https://strandsagents.com/latest/user-guide/concepts/model-providers/ollama/) - [OpenAI](https://strandsagents.com/latest/user-guide/concepts/model-providers/openai/) From 7a9e462c9d18a6880e0242895ef988e5bd3fac49 Mon Sep 17 00:00:00 2001 From: Aaron Brown Date: Mon, 18 Aug 2025 09:54:47 -0500 Subject: [PATCH 09/11] Removed LlamaCppModel from models/__init__.py; --- src/strands/models/__init__.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/strands/models/__init__.py b/src/strands/models/__init__.py index 35036203f..ead290a35 100644 --- a/src/strands/models/__init__.py +++ b/src/strands/models/__init__.py @@ -3,9 +3,8 @@ This package includes an abstract base Model class along with concrete implementations for specific providers. """ -from . import bedrock, llamacpp, model +from . 
import bedrock, model from .bedrock import BedrockModel -from .llamacpp import LlamaCppModel from .model import Model -__all__ = ["bedrock", "llamacpp", "model", "BedrockModel", "LlamaCppModel", "Model"] +__all__ = ["bedrock", "model", "BedrockModel", "Model"] From 5fd3c35d9e19288c30a6a86aa2ff4e946c97da02 Mon Sep 17 00:00:00 2001 From: Aaron Brown Date: Mon, 18 Aug 2025 09:55:41 -0500 Subject: [PATCH 10/11] Refactor integration tests to flat functions matching other provider tests; --- src/strands/models/llamacpp.py | 4 +- tests/strands/models/test_llamacpp.py | 1133 +++++++++++---------- tests_integ/models/test_model_llamacpp.py | 875 ++++++++-------- 3 files changed, 1022 insertions(+), 990 deletions(-) diff --git a/src/strands/models/llamacpp.py b/src/strands/models/llamacpp.py index d673f0f44..73b248462 100644 --- a/src/strands/models/llamacpp.py +++ b/src/strands/models/llamacpp.py @@ -195,8 +195,6 @@ def get_config(self) -> LlamaCppConfig: """ return self.config # type: ignore[return-value] - - def _format_message_content(self, content: Union[ContentBlock, Dict[str, Any]]) -> dict[str, Any]: """Format a content block for llama.cpp. @@ -533,7 +531,7 @@ async def stream( """ # Track request start time for latency calculation start_time = time.perf_counter() - + try: logger.debug("formatting request for llama.cpp server") request = self._format_request(messages, tool_specs, system_prompt) diff --git a/tests/strands/models/test_llamacpp.py b/tests/strands/models/test_llamacpp.py index 6f0bcebe6..e5b2614c0 100644 --- a/tests/strands/models/test_llamacpp.py +++ b/tests/strands/models/test_llamacpp.py @@ -15,606 +15,625 @@ ) -class TestLlamaCppModel: - """Test suite for LlamaCppModel.""" - - def test_init_default_config(self) -> None: - """Test initialization with default configuration.""" - model = LlamaCppModel() - - assert model.config["model_id"] == "default" - assert isinstance(model.client, httpx.AsyncClient) - assert model.base_url == "http://localhost:8080" - - def test_init_custom_config(self) -> None: - """Test initialization with custom configuration.""" - model = LlamaCppModel( - base_url="http://example.com:8081", - model_id="llama-3-8b", - params={"temperature": 0.7, "max_tokens": 100}, - ) +def test_init_default_config() -> None: + """Test initialization with default configuration.""" + model = LlamaCppModel() - assert model.config["model_id"] == "llama-3-8b" - assert model.config["params"]["temperature"] == 0.7 - assert model.config["params"]["max_tokens"] == 100 - assert model.base_url == "http://example.com:8081" - - def test_format_request_basic(self) -> None: - """Test basic request formatting.""" - model = LlamaCppModel(model_id="test-model") - - messages = [ - {"role": "user", "content": [{"text": "Hello"}]}, - ] - - request = model._format_request(messages) - - assert request["model"] == "test-model" - assert request["messages"][0]["role"] == "user" - assert request["messages"][0]["content"][0]["type"] == "text" - assert request["messages"][0]["content"][0]["text"] == "Hello" - assert request["stream"] is True - assert "extra_body" not in request - - def test_format_request_with_system_prompt(self) -> None: - """Test request formatting with system prompt.""" - model = LlamaCppModel() - - messages = [ - {"role": "user", "content": [{"text": "Hello"}]}, - ] - - request = model._format_request(messages, system_prompt="You are a helpful assistant") - - assert request["messages"][0]["role"] == "system" - assert request["messages"][0]["content"] == "You are a helpful 
assistant" - assert request["messages"][1]["role"] == "user" - - def test_format_request_with_llamacpp_params(self) -> None: - """Test request formatting with llama.cpp specific parameters.""" - model = LlamaCppModel( - params={ - "temperature": 0.8, - "max_tokens": 50, - "repeat_penalty": 1.1, - "top_k": 40, - "min_p": 0.05, - "grammar": "root ::= 'yes' | 'no'", - } - ) + assert model.config["model_id"] == "default" + assert isinstance(model.client, httpx.AsyncClient) + assert model.base_url == "http://localhost:8080" - messages = [ - {"role": "user", "content": [{"text": "Is the sky blue?"}]}, - ] - - request = model._format_request(messages) - - # Standard OpenAI params - assert request["temperature"] == 0.8 - assert request["max_tokens"] == 50 - - # Grammar and json_schema go directly in request for llama.cpp - assert request["grammar"] == "root ::= 'yes' | 'no'" - - # Other llama.cpp specific params should be in extra_body - assert "extra_body" in request - assert request["extra_body"]["repeat_penalty"] == 1.1 - assert request["extra_body"]["top_k"] == 40 - assert request["extra_body"]["min_p"] == 0.05 - - def test_format_request_with_all_new_params(self) -> None: - """Test request formatting with all new llama.cpp parameters.""" - model = LlamaCppModel( - params={ - # OpenAI params - "temperature": 0.7, - "max_tokens": 100, - "top_p": 0.9, - "seed": 42, - # All llama.cpp specific params - "repeat_penalty": 1.1, - "top_k": 40, - "min_p": 0.05, - "typical_p": 0.95, - "tfs_z": 0.97, - "top_a": 0.1, - "mirostat": 2, - "mirostat_lr": 0.1, - "mirostat_ent": 5.0, - "grammar": "root ::= answer", - "json_schema": {"type": "object"}, - "penalty_last_n": 256, - "n_probs": 5, - "min_keep": 1, - "ignore_eos": False, - "logit_bias": {100: 5.0, 200: -5.0}, - "cache_prompt": True, - "slot_id": 1, - "samplers": ["top_k", "tfs_z", "typical_p"], - } - ) - messages = [{"role": "user", "content": [{"text": "Test"}]}] - request = model._format_request(messages) - - # Check OpenAI params are in root - assert request["temperature"] == 0.7 - assert request["max_tokens"] == 100 - assert request["top_p"] == 0.9 - assert request["seed"] == 42 - - # Grammar and json_schema go directly in request for llama.cpp - assert request["grammar"] == "root ::= answer" - assert request["json_schema"] == {"type": "object"} - - # Check all other llama.cpp params are in extra_body - assert "extra_body" in request - extra = request["extra_body"] - assert extra["repeat_penalty"] == 1.1 - assert extra["top_k"] == 40 - assert extra["min_p"] == 0.05 - assert extra["typical_p"] == 0.95 - assert extra["tfs_z"] == 0.97 - assert extra["top_a"] == 0.1 - assert extra["mirostat"] == 2 - assert extra["mirostat_lr"] == 0.1 - assert extra["mirostat_ent"] == 5.0 - assert extra["penalty_last_n"] == 256 - assert extra["n_probs"] == 5 - assert extra["min_keep"] == 1 - assert extra["ignore_eos"] is False - assert extra["logit_bias"] == {100: 5.0, 200: -5.0} - assert extra["cache_prompt"] is True - assert extra["slot_id"] == 1 - assert extra["samplers"] == ["top_k", "tfs_z", "typical_p"] - - def test_format_request_with_tools(self) -> None: - """Test request formatting with tool specifications.""" - model = LlamaCppModel() - - messages = [ - {"role": "user", "content": [{"text": "What's the weather?"}]}, - ] - - tool_specs = [ - { - "name": "get_weather", - "description": "Get current weather", - "inputSchema": { - "json": { - "type": "object", - "properties": { - "location": {"type": "string"}, - }, - "required": ["location"], - } - }, - } - ] - 
- request = model._format_request(messages, tool_specs=tool_specs) - - assert "tools" in request - assert len(request["tools"]) == 1 - assert request["tools"][0]["function"]["name"] == "get_weather" - - def test_update_config(self) -> None: - """Test configuration update.""" - model = LlamaCppModel(model_id="initial-model") - - assert model.config["model_id"] == "initial-model" - - model.update_config(model_id="updated-model", params={"temperature": 0.5}) - - assert model.config["model_id"] == "updated-model" - assert model.config["params"]["temperature"] == 0.5 - - def test_get_config(self) -> None: - """Test configuration retrieval.""" - config = { - "model_id": "test-model", - "params": {"temperature": 0.9}, +def test_init_custom_config() -> None: + """Test initialization with custom configuration.""" + model = LlamaCppModel( + base_url="http://example.com:8081", + model_id="llama-3-8b", + params={"temperature": 0.7, "max_tokens": 100}, + ) + + assert model.config["model_id"] == "llama-3-8b" + assert model.config["params"]["temperature"] == 0.7 + assert model.config["params"]["max_tokens"] == 100 + assert model.base_url == "http://example.com:8081" + + +def test_format_request_basic() -> None: + """Test basic request formatting.""" + model = LlamaCppModel(model_id="test-model") + + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + ] + + request = model._format_request(messages) + + assert request["model"] == "test-model" + assert request["messages"][0]["role"] == "user" + assert request["messages"][0]["content"][0]["type"] == "text" + assert request["messages"][0]["content"][0]["text"] == "Hello" + assert request["stream"] is True + assert "extra_body" not in request + + +def test_format_request_with_system_prompt() -> None: + """Test request formatting with system prompt.""" + model = LlamaCppModel() + + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + ] + + request = model._format_request(messages, system_prompt="You are a helpful assistant") + + assert request["messages"][0]["role"] == "system" + assert request["messages"][0]["content"] == "You are a helpful assistant" + assert request["messages"][1]["role"] == "user" + + +def test_format_request_with_llamacpp_params() -> None: + """Test request formatting with llama.cpp specific parameters.""" + model = LlamaCppModel( + params={ + "temperature": 0.8, + "max_tokens": 50, + "repeat_penalty": 1.1, + "top_k": 40, + "min_p": 0.05, + "grammar": "root ::= 'yes' | 'no'", } - model = LlamaCppModel(**config) - - retrieved_config = model.get_config() - - assert retrieved_config["model_id"] == "test-model" - assert retrieved_config["params"]["temperature"] == 0.9 - - @pytest.mark.asyncio - async def test_stream_basic(self) -> None: - """Test basic streaming functionality.""" - model = LlamaCppModel() - - # Mock HTTP response with Server-Sent Events format - mock_response_lines = [ - 'data: {"choices": [{"delta": {"content": "Hello"}}]}', - 'data: {"choices": [{"delta": {"content": " world"}, "finish_reason": "stop"}]}', - 'data: {"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}}', - "data: [DONE]", - ] - - async def mock_aiter_lines(): - for line in mock_response_lines: - yield line - - mock_response = AsyncMock() - mock_response.aiter_lines = mock_aiter_lines - mock_response.raise_for_status = AsyncMock() - - with patch.object(model.client, "post", return_value=mock_response): - messages = [{"role": "user", "content": [{"text": "Hi"}]}] - - chunks = [] - async for chunk in 
model.stream(messages): - chunks.append(chunk) - - # Verify we got the expected chunks - assert any("messageStart" in chunk for chunk in chunks) - assert any( - "contentBlockDelta" in chunk and chunk["contentBlockDelta"]["delta"]["text"] == "Hello" - for chunk in chunks - ) - assert any( - "contentBlockDelta" in chunk and chunk["contentBlockDelta"]["delta"]["text"] == " world" - for chunk in chunks - ) - assert any("messageStop" in chunk for chunk in chunks) - - @pytest.mark.asyncio - async def test_structured_output(self) -> None: - """Test structured output functionality.""" - - class TestOutput(BaseModel): - """Test output model for structured output testing.""" - - answer: str - confidence: float - - model = LlamaCppModel() - - # Mock successful JSON response using the new structured_output implementation - mock_response_text = '{"answer": "yes", "confidence": 0.95}' - - # Create mock stream that returns JSON - async def mock_stream(*_args, **_kwargs): - # Verify json_schema was set - assert "json_schema" in model.config.get("params", {}) - - yield {"messageStart": {"role": "assistant"}} - yield {"contentBlockStart": {"start": {}}} - yield {"contentBlockDelta": {"delta": {"text": mock_response_text}}} - yield {"contentBlockStop": {}} - yield {"messageStop": {"stopReason": "end_turn"}} - - with patch.object(model, "stream", side_effect=mock_stream): - messages = [{"role": "user", "content": [{"text": "Is the earth round?"}]}] - - events = [] - async for event in model.structured_output(TestOutput, messages): - events.append(event) - - # Check we got the output - output_event = next((e for e in events if "output" in e), None) - assert output_event is not None - assert output_event["output"].answer == "yes" - assert output_event["output"].confidence == 0.95 - - def test_timeout_configuration(self) -> None: - """Test timeout configuration.""" - # Test that timeout configuration is accepted without error - model = LlamaCppModel(timeout=30.0) - assert model.client.timeout is not None - - # Test with tuple timeout - model2 = LlamaCppModel(timeout=(10.0, 60.0)) - assert model2.client.timeout is not None - - def test_max_retries_configuration(self) -> None: - """Test max retries configuration is handled gracefully.""" - # Since httpx doesn't use max_retries in the same way, - # we just test that the model initializes without error - model = LlamaCppModel() - assert model.config["model_id"] == "default" - - def test_grammar_constraint_via_params(self) -> None: - """Test grammar constraint via params.""" - grammar = """ - root ::= answer - answer ::= "yes" | "no" - """ - model = LlamaCppModel(params={"grammar": grammar}) - - assert model.config["params"]["grammar"] == grammar - - # Update grammar via update_config - new_grammar = "root ::= [0-9]+" - model.update_config(params={"grammar": new_grammar}) - - assert model.config["params"]["grammar"] == new_grammar - - def test_json_schema_via_params(self) -> None: - """Test JSON schema constraint via params.""" - schema = { - "type": "object", - "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}, - "required": ["name", "age"], + ) + + messages = [ + {"role": "user", "content": [{"text": "Is the sky blue?"}]}, + ] + + request = model._format_request(messages) + + # Standard OpenAI params + assert request["temperature"] == 0.8 + assert request["max_tokens"] == 50 + + # Grammar and json_schema go directly in request for llama.cpp + assert request["grammar"] == "root ::= 'yes' | 'no'" + + # Other llama.cpp specific params should be in 
extra_body + assert "extra_body" in request + assert request["extra_body"]["repeat_penalty"] == 1.1 + assert request["extra_body"]["top_k"] == 40 + assert request["extra_body"]["min_p"] == 0.05 + + +def test_format_request_with_all_new_params() -> None: + """Test request formatting with all new llama.cpp parameters.""" + model = LlamaCppModel( + params={ + # OpenAI params + "temperature": 0.7, + "max_tokens": 100, + "top_p": 0.9, + "seed": 42, + # All llama.cpp specific params + "repeat_penalty": 1.1, + "top_k": 40, + "min_p": 0.05, + "typical_p": 0.95, + "tfs_z": 0.97, + "top_a": 0.1, + "mirostat": 2, + "mirostat_lr": 0.1, + "mirostat_ent": 5.0, + "grammar": "root ::= answer", + "json_schema": {"type": "object"}, + "penalty_last_n": 256, + "n_probs": 5, + "min_keep": 1, + "ignore_eos": False, + "logit_bias": {100: 5.0, 200: -5.0}, + "cache_prompt": True, + "slot_id": 1, + "samplers": ["top_k", "tfs_z", "typical_p"], } - model = LlamaCppModel(params={"json_schema": schema}) + ) + + messages = [{"role": "user", "content": [{"text": "Test"}]}] + request = model._format_request(messages) + + # Check OpenAI params are in root + assert request["temperature"] == 0.7 + assert request["max_tokens"] == 100 + assert request["top_p"] == 0.9 + assert request["seed"] == 42 + + # Grammar and json_schema go directly in request for llama.cpp + assert request["grammar"] == "root ::= answer" + assert request["json_schema"] == {"type": "object"} + + # Check all other llama.cpp params are in extra_body + assert "extra_body" in request + extra = request["extra_body"] + assert extra["repeat_penalty"] == 1.1 + assert extra["top_k"] == 40 + assert extra["min_p"] == 0.05 + assert extra["typical_p"] == 0.95 + assert extra["tfs_z"] == 0.97 + assert extra["top_a"] == 0.1 + assert extra["mirostat"] == 2 + assert extra["mirostat_lr"] == 0.1 + assert extra["mirostat_ent"] == 5.0 + assert extra["penalty_last_n"] == 256 + assert extra["n_probs"] == 5 + assert extra["min_keep"] == 1 + assert extra["ignore_eos"] is False + assert extra["logit_bias"] == {100: 5.0, 200: -5.0} + assert extra["cache_prompt"] is True + assert extra["slot_id"] == 1 + assert extra["samplers"] == ["top_k", "tfs_z", "typical_p"] + + +def test_format_request_with_tools() -> None: + """Test request formatting with tool specifications.""" + model = LlamaCppModel() + + messages = [ + {"role": "user", "content": [{"text": "What's the weather?"}]}, + ] + + tool_specs = [ + { + "name": "get_weather", + "description": "Get current weather", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "location": {"type": "string"}, + }, + "required": ["location"], + } + }, + } + ] - assert model.config["params"]["json_schema"] == schema + request = model._format_request(messages, tool_specs=tool_specs) - @pytest.mark.asyncio - async def test_stream_with_context_overflow_error(self) -> None: - """Test stream handling of context overflow errors.""" - model = LlamaCppModel() + assert "tools" in request + assert len(request["tools"]) == 1 + assert request["tools"][0]["function"]["name"] == "get_weather" - # Create HTTP error response - error_response = httpx.Response( - status_code=400, - json={"error": {"message": "Context window exceeded. 
Max context length is 4096 tokens"}}, - request=httpx.Request("POST", "http://test.com"), - ) - error = httpx.HTTPStatusError("Bad Request", request=error_response.request, response=error_response) - # Mock the client to raise the error - with patch.object(model.client, "post", side_effect=error): - messages = [{"role": "user", "content": [{"text": "Very long message"}]}] +def test_update_config() -> None: + """Test configuration update.""" + model = LlamaCppModel(model_id="initial-model") + + assert model.config["model_id"] == "initial-model" + + model.update_config(model_id="updated-model", params={"temperature": 0.5}) + + assert model.config["model_id"] == "updated-model" + assert model.config["params"]["temperature"] == 0.5 - with pytest.raises(ContextWindowOverflowException) as exc_info: - async for _ in model.stream(messages): - pass - assert "Context window exceeded" in str(exc_info.value) +def test_get_config() -> None: + """Test configuration retrieval.""" + config = { + "model_id": "test-model", + "params": {"temperature": 0.9}, + } + model = LlamaCppModel(**config) - @pytest.mark.asyncio - async def test_stream_with_server_overload_error(self) -> None: - """Test stream handling of server overload errors.""" - model = LlamaCppModel() + retrieved_config = model.get_config() - # Create HTTP error response for 503 - error_response = httpx.Response( - status_code=503, - text="Server is busy", - request=httpx.Request("POST", "http://test.com"), + assert retrieved_config["model_id"] == "test-model" + assert retrieved_config["params"]["temperature"] == 0.9 + + +@pytest.mark.asyncio +async def test_stream_basic() -> None: + """Test basic streaming functionality.""" + model = LlamaCppModel() + + # Mock HTTP response with Server-Sent Events format + mock_response_lines = [ + 'data: {"choices": [{"delta": {"content": "Hello"}}]}', + 'data: {"choices": [{"delta": {"content": " world"}, "finish_reason": "stop"}]}', + 'data: {"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}}', + "data: [DONE]", + ] + + async def mock_aiter_lines(): + for line in mock_response_lines: + yield line + + mock_response = AsyncMock() + mock_response.aiter_lines = mock_aiter_lines + mock_response.raise_for_status = AsyncMock() + + with patch.object(model.client, "post", return_value=mock_response): + messages = [{"role": "user", "content": [{"text": "Hi"}]}] + + chunks = [] + async for chunk in model.stream(messages): + chunks.append(chunk) + + # Verify we got the expected chunks + assert any("messageStart" in chunk for chunk in chunks) + assert any( + "contentBlockDelta" in chunk and chunk["contentBlockDelta"]["delta"]["text"] == "Hello" for chunk in chunks ) - error = httpx.HTTPStatusError( - "Service Unavailable", - request=error_response.request, - response=error_response, + assert any( + "contentBlockDelta" in chunk and chunk["contentBlockDelta"]["delta"]["text"] == " world" for chunk in chunks ) + assert any("messageStop" in chunk for chunk in chunks) + + +@pytest.mark.asyncio +async def test_structured_output() -> None: + """Test structured output functionality.""" + + class TestOutput(BaseModel): + """Test output model for structured output testing.""" + + answer: str + confidence: float + + model = LlamaCppModel() + + # Mock successful JSON response using the new structured_output implementation + mock_response_text = '{"answer": "yes", "confidence": 0.95}' + + # Create mock stream that returns JSON + async def mock_stream(*_args, **_kwargs): + # Verify json_schema was set + assert 
"json_schema" in model.config.get("params", {}) + + yield {"messageStart": {"role": "assistant"}} + yield {"contentBlockStart": {"start": {}}} + yield {"contentBlockDelta": {"delta": {"text": mock_response_text}}} + yield {"contentBlockStop": {}} + yield {"messageStop": {"stopReason": "end_turn"}} + + with patch.object(model, "stream", side_effect=mock_stream): + messages = [{"role": "user", "content": [{"text": "Is the earth round?"}]}] + + events = [] + async for event in model.structured_output(TestOutput, messages): + events.append(event) + + # Check we got the output + output_event = next((e for e in events if "output" in e), None) + assert output_event is not None + assert output_event["output"].answer == "yes" + assert output_event["output"].confidence == 0.95 + + +def test_timeout_configuration() -> None: + """Test timeout configuration.""" + # Test that timeout configuration is accepted without error + model = LlamaCppModel(timeout=30.0) + assert model.client.timeout is not None + + # Test with tuple timeout + model2 = LlamaCppModel(timeout=(10.0, 60.0)) + assert model2.client.timeout is not None - # Mock the client to raise the error - with patch.object(model.client, "post", side_effect=error): - messages = [{"role": "user", "content": [{"text": "Test"}]}] - with pytest.raises(ModelThrottledException) as exc_info: - async for _ in model.stream(messages): - pass +def test_max_retries_configuration() -> None: + """Test max retries configuration is handled gracefully.""" + # Since httpx doesn't use max_retries in the same way, + # we just test that the model initializes without error + model = LlamaCppModel() + assert model.config["model_id"] == "default" - assert "server is busy or overloaded" in str(exc_info.value) - @pytest.mark.asyncio - async def test_structured_output_with_json_schema(self) -> None: - """Test structured output using JSON schema.""" +def test_grammar_constraint_via_params() -> None: + """Test grammar constraint via params.""" + grammar = """ + root ::= answer + answer ::= "yes" | "no" + """ + model = LlamaCppModel(params={"grammar": grammar}) - class TestOutput(BaseModel): - """Test output model for JSON schema testing.""" + assert model.config["params"]["grammar"] == grammar - answer: str - confidence: float + # Update grammar via update_config + new_grammar = "root ::= [0-9]+" + model.update_config(params={"grammar": new_grammar}) - model = LlamaCppModel() + assert model.config["params"]["grammar"] == new_grammar - # Mock successful JSON response - mock_response_text = '{"answer": "yes", "confidence": 0.95}' - # Create mock stream that returns JSON - async def mock_stream(*_args, **_kwargs): - # Check that json_schema was set correctly - assert model.config["params"]["json_schema"] == TestOutput.model_json_schema() +def test_json_schema_via_params() -> None: + """Test JSON schema constraint via params.""" + schema = { + "type": "object", + "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}, + "required": ["name", "age"], + } + model = LlamaCppModel(params={"json_schema": schema}) - yield {"messageStart": {"role": "assistant"}} - yield {"contentBlockStart": {"start": {}}} - yield {"contentBlockDelta": {"delta": {"text": mock_response_text}}} - yield {"contentBlockStop": {}} - yield {"messageStop": {"stopReason": "end_turn"}} + assert model.config["params"]["json_schema"] == schema - with patch.object(model, "stream", side_effect=mock_stream): - messages = [{"role": "user", "content": [{"text": "Is the earth round?"}]}] - events = [] - async 
for event in model.structured_output(TestOutput, messages): - events.append(event) +@pytest.mark.asyncio +async def test_stream_with_context_overflow_error() -> None: + """Test stream handling of context overflow errors.""" + model = LlamaCppModel() - # Check we got the output - output_event = next((e for e in events if "output" in e), None) - assert output_event is not None - assert output_event["output"].answer == "yes" - assert output_event["output"].confidence == 0.95 + # Create HTTP error response + error_response = httpx.Response( + status_code=400, + json={"error": {"message": "Context window exceeded. Max context length is 4096 tokens"}}, + request=httpx.Request("POST", "http://test.com"), + ) + error = httpx.HTTPStatusError("Bad Request", request=error_response.request, response=error_response) - @pytest.mark.asyncio - async def test_structured_output_invalid_json_error(self) -> None: - """Test structured output raises error for invalid JSON.""" + # Mock the client to raise the error + with patch.object(model.client, "post", side_effect=error): + messages = [{"role": "user", "content": [{"text": "Very long message"}]}] - class TestOutput(BaseModel): - """Test output model for invalid JSON testing.""" + with pytest.raises(ContextWindowOverflowException) as exc_info: + async for _ in model.stream(messages): + pass - value: int + assert "Context window exceeded" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_stream_with_server_overload_error() -> None: + """Test stream handling of server overload errors.""" + model = LlamaCppModel() + + # Create HTTP error response for 503 + error_response = httpx.Response( + status_code=503, + text="Server is busy", + request=httpx.Request("POST", "http://test.com"), + ) + error = httpx.HTTPStatusError( + "Service Unavailable", + request=error_response.request, + response=error_response, + ) - model = LlamaCppModel() + # Mock the client to raise the error + with patch.object(model.client, "post", side_effect=error): + messages = [{"role": "user", "content": [{"text": "Test"}]}] + + with pytest.raises(ModelThrottledException) as exc_info: + async for _ in model.stream(messages): + pass + + assert "server is busy or overloaded" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_structured_output_with_json_schema() -> None: + """Test structured output using JSON schema.""" + + class TestOutput(BaseModel): + """Test output model for JSON schema testing.""" + + answer: str + confidence: float + + model = LlamaCppModel() + + # Mock successful JSON response + mock_response_text = '{"answer": "yes", "confidence": 0.95}' + + # Create mock stream that returns JSON + async def mock_stream(*_args, **_kwargs): + # Check that json_schema was set correctly + assert model.config["params"]["json_schema"] == TestOutput.model_json_schema() + + yield {"messageStart": {"role": "assistant"}} + yield {"contentBlockStart": {"start": {}}} + yield {"contentBlockDelta": {"delta": {"text": mock_response_text}}} + yield {"contentBlockStop": {}} + yield {"messageStop": {"stopReason": "end_turn"}} + + with patch.object(model, "stream", side_effect=mock_stream): + messages = [{"role": "user", "content": [{"text": "Is the earth round?"}]}] + + events = [] + async for event in model.structured_output(TestOutput, messages): + events.append(event) + + # Check we got the output + output_event = next((e for e in events if "output" in e), None) + assert output_event is not None + assert output_event["output"].answer == "yes" + assert 
output_event["output"].confidence == 0.95 + + +@pytest.mark.asyncio +async def test_structured_output_invalid_json_error() -> None: + """Test structured output raises error for invalid JSON.""" + + class TestOutput(BaseModel): + """Test output model for invalid JSON testing.""" + + value: int + + model = LlamaCppModel() + + # Mock stream that returns invalid JSON + async def mock_stream(*_args, **_kwargs): + # Check that json_schema was set correctly + assert model.config["params"]["json_schema"] == TestOutput.model_json_schema() + + yield {"messageStart": {"role": "assistant"}} + yield {"contentBlockStart": {"start": {}}} + yield {"contentBlockDelta": {"delta": {"text": "This is not valid JSON"}}} + yield {"contentBlockStop": {}} + yield {"messageStop": {"stopReason": "end_turn"}} + + with patch.object(model, "stream", side_effect=mock_stream): + messages = [{"role": "user", "content": [{"text": "Give me a number"}]}] - # Mock stream that returns invalid JSON - async def mock_stream(*_args, **_kwargs): - # Check that json_schema was set correctly - assert model.config["params"]["json_schema"] == TestOutput.model_json_schema() + with pytest.raises(json.JSONDecodeError): + async for _ in model.structured_output(TestOutput, messages): + pass - yield {"messageStart": {"role": "assistant"}} - yield {"contentBlockStart": {"start": {}}} - yield {"contentBlockDelta": {"delta": {"text": "This is not valid JSON"}}} - yield {"contentBlockStop": {}} - yield {"messageStop": {"stopReason": "end_turn"}} - with patch.object(model, "stream", side_effect=mock_stream): - messages = [{"role": "user", "content": [{"text": "Give me a number"}]}] +def test_format_audio_content() -> None: + """Test formatting of audio content for llama.cpp multimodal models.""" + model = LlamaCppModel() - with pytest.raises(json.JSONDecodeError): - async for _ in model.structured_output(TestOutput, messages): - pass + # Create test audio data + audio_bytes = b"fake audio data" + audio_content = {"audio": {"source": {"bytes": audio_bytes}, "format": "wav"}} - def test_format_audio_content(self) -> None: - """Test formatting of audio content for llama.cpp multimodal models.""" - model = LlamaCppModel() + # Format the content + result = model._format_message_content(audio_content) - # Create test audio data - audio_bytes = b"fake audio data" - audio_content = {"audio": {"source": {"bytes": audio_bytes}, "format": "wav"}} + # Verify the structure + assert result["type"] == "input_audio" + assert "input_audio" in result + assert "data" in result["input_audio"] + assert "format" in result["input_audio"] - # Format the content - result = model._format_message_content(audio_content) + # Verify the data is base64 encoded + decoded = base64.b64decode(result["input_audio"]["data"]) + assert decoded == audio_bytes - # Verify the structure - assert result["type"] == "input_audio" - assert "input_audio" in result - assert "data" in result["input_audio"] - assert "format" in result["input_audio"] + # Verify format is preserved + assert result["input_audio"]["format"] == "wav" - # Verify the data is base64 encoded - decoded = base64.b64decode(result["input_audio"]["data"]) - assert decoded == audio_bytes - # Verify format is preserved - assert result["input_audio"]["format"] == "wav" +def test_format_audio_content_default_format() -> None: + """Test audio content formatting uses wav as default format.""" + model = LlamaCppModel() - def test_format_audio_content_default_format(self) -> None: - """Test audio content formatting uses wav as default 
format.""" - model = LlamaCppModel() + audio_content = { + "audio": {"source": {"bytes": b"test audio"}} + # No format specified + } - audio_content = { - "audio": {"source": {"bytes": b"test audio"}} - # No format specified + result = model._format_message_content(audio_content) + + # Should default to wav + assert result["input_audio"]["format"] == "wav" + + +def test_format_messages_with_audio() -> None: + """Test that _format_messages properly handles audio content.""" + model = LlamaCppModel() + + # Create messages with audio content + messages = [ + { + "role": "user", + "content": [ + {"text": "Listen to this audio:"}, + {"audio": {"source": {"bytes": b"audio data"}, "format": "mp3"}}, + ], + } + ] + + # Format the messages + result = model._format_messages(messages) + + # Check structure + assert len(result) == 1 + assert result[0]["role"] == "user" + assert len(result[0]["content"]) == 2 + + # Check text content + assert result[0]["content"][0]["type"] == "text" + assert result[0]["content"][0]["text"] == "Listen to this audio:" + + # Check audio content + assert result[0]["content"][1]["type"] == "input_audio" + assert "input_audio" in result[0]["content"][1] + assert result[0]["content"][1]["input_audio"]["format"] == "mp3" + + +def test_format_messages_with_system_prompt() -> None: + """Test _format_messages includes system prompt.""" + model = LlamaCppModel() + + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + system_prompt = "You are a helpful assistant" + + result = model._format_messages(messages, system_prompt) + + # Should have system message first + assert len(result) == 2 + assert result[0]["role"] == "system" + assert result[0]["content"] == system_prompt + assert result[1]["role"] == "user" + + +def test_format_messages_with_image() -> None: + """Test that _format_messages properly handles image content.""" + model = LlamaCppModel() + + # Create messages with image content + messages = [ + { + "role": "user", + "content": [ + {"text": "Describe this image:"}, + {"image": {"source": {"bytes": b"image data"}, "format": "png"}}, + ], } + ] + + # Format the messages + result = model._format_messages(messages) + + # Check structure + assert len(result) == 1 + assert result[0]["role"] == "user" + assert len(result[0]["content"]) == 2 + + # Check text content + assert result[0]["content"][0]["type"] == "text" + assert result[0]["content"][0]["text"] == "Describe this image:" + + # Check image content uses standard format + assert result[0]["content"][1]["type"] == "image_url" + assert "image_url" in result[0]["content"][1] + assert "url" in result[0]["content"][1]["image_url"] + assert result[0]["content"][1]["image_url"]["url"].startswith("data:image/png;base64,") + + +def test_format_messages_with_mixed_content() -> None: + """Test that _format_messages handles mixed audio and image content correctly.""" + model = LlamaCppModel() + + # Create messages with both audio and image content + messages = [ + { + "role": "user", + "content": [ + {"text": "Analyze this media:"}, + {"audio": {"source": {"bytes": b"audio data"}, "format": "wav"}}, + {"image": {"source": {"bytes": b"image data"}, "format": "jpg"}}, + ], + } + ] + + # Format the messages + result = model._format_messages(messages) + + # Check structure + assert len(result) == 1 + assert result[0]["role"] == "user" + assert len(result[0]["content"]) == 3 + + # Check text content + assert result[0]["content"][0]["type"] == "text" + assert result[0]["content"][0]["text"] == "Analyze this media:" + + # 
Check audio content uses llama.cpp specific format + assert result[0]["content"][1]["type"] == "input_audio" + assert "input_audio" in result[0]["content"][1] + assert result[0]["content"][1]["input_audio"]["format"] == "wav" - result = model._format_message_content(audio_content) - - # Should default to wav - assert result["input_audio"]["format"] == "wav" - - def test_format_messages_with_audio(self) -> None: - """Test that _format_messages properly handles audio content.""" - model = LlamaCppModel() - - # Create messages with audio content - messages = [ - { - "role": "user", - "content": [ - {"text": "Listen to this audio:"}, - {"audio": {"source": {"bytes": b"audio data"}, "format": "mp3"}}, - ], - } - ] - - # Format the messages - result = model._format_messages(messages) - - # Check structure - assert len(result) == 1 - assert result[0]["role"] == "user" - assert len(result[0]["content"]) == 2 - - # Check text content - assert result[0]["content"][0]["type"] == "text" - assert result[0]["content"][0]["text"] == "Listen to this audio:" - - # Check audio content - assert result[0]["content"][1]["type"] == "input_audio" - assert "input_audio" in result[0]["content"][1] - assert result[0]["content"][1]["input_audio"]["format"] == "mp3" - - def test_format_messages_with_system_prompt(self) -> None: - """Test _format_messages includes system prompt.""" - model = LlamaCppModel() - - messages = [{"role": "user", "content": [{"text": "Hello"}]}] - system_prompt = "You are a helpful assistant" - - result = model._format_messages(messages, system_prompt) - - # Should have system message first - assert len(result) == 2 - assert result[0]["role"] == "system" - assert result[0]["content"] == system_prompt - assert result[1]["role"] == "user" - - def test_format_messages_with_image(self) -> None: - """Test that _format_messages properly handles image content.""" - model = LlamaCppModel() - - # Create messages with image content - messages = [ - { - "role": "user", - "content": [ - {"text": "Describe this image:"}, - {"image": {"source": {"bytes": b"image data"}, "format": "png"}}, - ], - } - ] - - # Format the messages - result = model._format_messages(messages) - - # Check structure - assert len(result) == 1 - assert result[0]["role"] == "user" - assert len(result[0]["content"]) == 2 - - # Check text content - assert result[0]["content"][0]["type"] == "text" - assert result[0]["content"][0]["text"] == "Describe this image:" - - # Check image content uses standard format - assert result[0]["content"][1]["type"] == "image_url" - assert "image_url" in result[0]["content"][1] - assert "url" in result[0]["content"][1]["image_url"] - assert result[0]["content"][1]["image_url"]["url"].startswith("data:image/png;base64,") - - def test_format_messages_with_mixed_content(self) -> None: - """Test that _format_messages handles mixed audio and image content correctly.""" - model = LlamaCppModel() - - # Create messages with both audio and image content - messages = [ - { - "role": "user", - "content": [ - {"text": "Analyze this media:"}, - {"audio": {"source": {"bytes": b"audio data"}, "format": "wav"}}, - {"image": {"source": {"bytes": b"image data"}, "format": "jpg"}}, - ], - } - ] - - # Format the messages - result = model._format_messages(messages) - - # Check structure - assert len(result) == 1 - assert result[0]["role"] == "user" - assert len(result[0]["content"]) == 3 - - # Check text content - assert result[0]["content"][0]["type"] == "text" - assert result[0]["content"][0]["text"] == "Analyze this 
media:" - - # Check audio content uses llama.cpp specific format - assert result[0]["content"][1]["type"] == "input_audio" - assert "input_audio" in result[0]["content"][1] - assert result[0]["content"][1]["input_audio"]["format"] == "wav" - - # Check image content uses standard OpenAI format - assert result[0]["content"][2]["type"] == "image_url" - assert "image_url" in result[0]["content"][2] - assert result[0]["content"][2]["image_url"]["url"].startswith("data:image/jpeg;base64,") + # Check image content uses standard OpenAI format + assert result[0]["content"][2]["type"] == "image_url" + assert "image_url" in result[0]["content"][2] + assert result[0]["content"][2]["image_url"]["url"].startswith("data:image/jpeg;base64,") diff --git a/tests_integ/models/test_model_llamacpp.py b/tests_integ/models/test_model_llamacpp.py index 131f0fb88..95047e7ab 100644 --- a/tests_integ/models/test_model_llamacpp.py +++ b/tests_integ/models/test_model_llamacpp.py @@ -40,307 +40,456 @@ async def llamacpp_model() -> LlamaCppModel: return LlamaCppModel(base_url=LLAMACPP_URL) -class TestLlamaCppModelIntegration: - """Integration tests for LlamaCppModel with a real server.""" - - @pytest.mark.asyncio - async def test_basic_completion(self, llamacpp_model: LlamaCppModel) -> None: - """Test basic text completion.""" - messages: list[Message] = [ - {"role": "user", "content": [{"text": "Say 'Hello, World!' and nothing else."}]}, - ] - - response_text = "" - async for event in llamacpp_model.stream(messages): - if "contentBlockDelta" in event: - delta = event["contentBlockDelta"]["delta"] - if "text" in delta: - response_text += delta["text"] - - assert "Hello, World!" in response_text - - @pytest.mark.asyncio - async def test_system_prompt(self, llamacpp_model: LlamaCppModel) -> None: - """Test completion with system prompt.""" - messages: list[Message] = [ - {"role": "user", "content": [{"text": "Who are you?"}]}, - ] - - system_prompt = "You are a helpful AI assistant named Claude." 
- - response_text = "" - async for event in llamacpp_model.stream(messages, system_prompt=system_prompt): - if "contentBlockDelta" in event: - delta = event["contentBlockDelta"]["delta"] - if "text" in delta: - response_text += delta["text"] - - # Response should reflect the system prompt - assert len(response_text) > 0 - assert "assistant" in response_text.lower() or "claude" in response_text.lower() - - @pytest.mark.asyncio - async def test_streaming_chunks(self, llamacpp_model: LlamaCppModel) -> None: - """Test that streaming returns proper chunk sequence.""" - messages: list[Message] = [ - {"role": "user", "content": [{"text": "Count from 1 to 3."}]}, - ] - - chunk_types = [] - async for event in llamacpp_model.stream(messages): - chunk_types.append(next(iter(event.keys()))) - - # Verify proper chunk sequence - assert chunk_types[0] == "messageStart" - assert chunk_types[1] == "contentBlockStart" - assert "contentBlockDelta" in chunk_types - assert chunk_types[-3] == "contentBlockStop" - assert chunk_types[-2] == "messageStop" - assert chunk_types[-1] == "metadata" - - @pytest.mark.asyncio - async def test_temperature_parameter(self, llamacpp_model: LlamaCppModel) -> None: - """Test temperature parameter affects randomness.""" - messages: list[Message] = [ - {"role": "user", "content": [{"text": "Generate a random word."}]}, - ] - - # Low temperature should give more consistent results - llamacpp_model.update_config(params={"temperature": 0.1, "seed": 42}) - - response1 = "" - async for event in llamacpp_model.stream(messages): - if "contentBlockDelta" in event: - delta = event["contentBlockDelta"]["delta"] - if "text" in delta: - response1 += delta["text"] - - # Same seed and low temperature should give similar result - response2 = "" - async for event in llamacpp_model.stream(messages): - if "contentBlockDelta" in event: - delta = event["contentBlockDelta"]["delta"] - if "text" in delta: - response2 += delta["text"] - - # With low temperature and same seed, responses should be very similar - assert len(response1) > 0 - assert len(response2) > 0 - - @pytest.mark.asyncio - async def test_max_tokens_limit(self, llamacpp_model: LlamaCppModel) -> None: - """Test max_tokens parameter limits response length.""" - messages: list[Message] = [ - {"role": "user", "content": [{"text": "Tell me a very long story about dragons."}]}, - ] - - # Set very low token limit - llamacpp_model.update_config(params={"max_tokens": 10}) - - token_count = 0 - async for event in llamacpp_model.stream(messages): - if "metadata" in event: - usage = event["metadata"]["usage"] - token_count = usage["outputTokens"] - if "messageStop" in event: - stop_reason = event["messageStop"]["stopReason"] - - # Should stop due to max_tokens - assert token_count <= 15 # Allow small overage due to tokenization - assert stop_reason == "max_tokens" - - @pytest.mark.asyncio - async def test_structured_output(self, llamacpp_model: LlamaCppModel) -> None: - """Test structured output generation.""" - messages: list[Message] = [ - { - "role": "user", - "content": [ - { - "text": "What's the weather like in Paris? " - "Respond with temperature in Celsius, condition, and location." 
- } - ], - }, - ] - - # Enable JSON response format for structured output - llamacpp_model.update_config(params={"response_format": {"type": "json_object"}}) - - result = None - async for event in llamacpp_model.structured_output(WeatherOutput, messages): - if "output" in event: - result = event["output"] - - assert result is not None - assert isinstance(result, WeatherOutput) - assert isinstance(result.temperature, float) - assert isinstance(result.condition, str) - assert result.location.lower() == "paris" - - @pytest.mark.asyncio - async def test_llamacpp_specific_params(self, llamacpp_model: LlamaCppModel) -> None: - """Test llama.cpp specific parameters.""" - messages: list[Message] = [ - {"role": "user", "content": [{"text": "Say 'test' five times."}]}, - ] - - # Use llama.cpp specific parameters - llamacpp_model.update_config( - params={ - "repeat_penalty": 1.5, # Penalize repetition - "top_k": 10, # Limit vocabulary - "min_p": 0.1, # Min-p sampling - } - ) - - response_text = "" - async for event in llamacpp_model.stream(messages): - if "contentBlockDelta" in event: - delta = event["contentBlockDelta"]["delta"] - if "text" in delta: - response_text += delta["text"] - - # Response should contain "test" but with repetition penalty it might vary - assert "test" in response_text.lower() - - @pytest.mark.asyncio - async def test_advanced_sampling_params(self, llamacpp_model: LlamaCppModel) -> None: - """Test advanced sampling parameters.""" - messages: list[Message] = [ - {"role": "user", "content": [{"text": "Generate a random sentence about space."}]}, - ] - - # Test advanced sampling parameters - llamacpp_model.update_config( - params={ - "temperature": 0.8, - "tfs_z": 0.95, # Tail-free sampling - "top_a": 0.1, # Top-a sampling - "typical_p": 0.9, # Typical-p sampling - "penalty_last_n": 64, # Penalty context window - "min_keep": 1, # Minimum tokens to keep - "samplers": ["top_k", "tfs_z", "typical_p", "top_p", "min_p", "temperature"], - } - ) - - response_text = "" - async for event in llamacpp_model.stream(messages): - if "contentBlockDelta" in event: - delta = event["contentBlockDelta"]["delta"] - if "text" in delta: - response_text += delta["text"] - - # Should generate something about space - assert len(response_text) > 0 - assert any(word in response_text.lower() for word in ["space", "star", "planet", "galaxy", "universe"]) - - @pytest.mark.asyncio - async def test_mirostat_sampling(self, llamacpp_model: LlamaCppModel) -> None: - """Test Mirostat sampling modes.""" - messages: list[Message] = [ - {"role": "user", "content": [{"text": "Write a short poem."}]}, - ] +# Integration tests for LlamaCppModel with a real server + + +@pytest.mark.asyncio +async def test_basic_completion(llamacpp_model: LlamaCppModel) -> None: + """Test basic text completion.""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Say 'Hello, World!' and nothing else."}]}, + ] + + response_text = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response_text += delta["text"] + + assert "Hello, World!" in response_text + + +@pytest.mark.asyncio +async def test_system_prompt(llamacpp_model: LlamaCppModel) -> None: + """Test completion with system prompt.""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Who are you?"}]}, + ] + + system_prompt = "You are a helpful AI assistant named Claude." 
+ + response_text = "" + async for event in llamacpp_model.stream(messages, system_prompt=system_prompt): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response_text += delta["text"] + + # Response should reflect the system prompt + assert len(response_text) > 0 + assert "assistant" in response_text.lower() or "claude" in response_text.lower() + + +@pytest.mark.asyncio +async def test_streaming_chunks(llamacpp_model: LlamaCppModel) -> None: + """Test that streaming returns proper chunk sequence.""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Count from 1 to 3."}]}, + ] + + chunk_types = [] + async for event in llamacpp_model.stream(messages): + chunk_types.append(next(iter(event.keys()))) + + # Verify proper chunk sequence + assert chunk_types[0] == "messageStart" + assert chunk_types[1] == "contentBlockStart" + assert "contentBlockDelta" in chunk_types + assert chunk_types[-3] == "contentBlockStop" + assert chunk_types[-2] == "messageStop" + assert chunk_types[-1] == "metadata" + + +@pytest.mark.asyncio +async def test_temperature_parameter(llamacpp_model: LlamaCppModel) -> None: + """Test temperature parameter affects randomness.""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Generate a random word."}]}, + ] + + # Low temperature should give more consistent results + llamacpp_model.update_config(params={"temperature": 0.1, "seed": 42}) + + response1 = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response1 += delta["text"] + + # Same seed and low temperature should give similar result + response2 = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response2 += delta["text"] + + # With low temperature and same seed, responses should be very similar + assert len(response1) > 0 + assert len(response2) > 0 + + +@pytest.mark.asyncio +async def test_max_tokens_limit(llamacpp_model: LlamaCppModel) -> None: + """Test max_tokens parameter limits response length.""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Tell me a very long story about dragons."}]}, + ] + + # Set very low token limit + llamacpp_model.update_config(params={"max_tokens": 10}) + + token_count = 0 + async for event in llamacpp_model.stream(messages): + if "metadata" in event: + usage = event["metadata"]["usage"] + token_count = usage["outputTokens"] + if "messageStop" in event: + stop_reason = event["messageStop"]["stopReason"] + + # Should stop due to max_tokens + assert token_count <= 15 # Allow small overage due to tokenization + assert stop_reason == "max_tokens" + + +@pytest.mark.asyncio +async def test_structured_output(llamacpp_model: LlamaCppModel) -> None: + """Test structured output generation.""" + messages: list[Message] = [ + { + "role": "user", + "content": [ + { + "text": "What's the weather like in Paris? " + "Respond with temperature in Celsius, condition, and location." 
+ } + ], + }, + ] + + # Enable JSON response format for structured output + llamacpp_model.update_config(params={"response_format": {"type": "json_object"}}) + + result = None + async for event in llamacpp_model.structured_output(WeatherOutput, messages): + if "output" in event: + result = event["output"] + + assert result is not None + assert isinstance(result, WeatherOutput) + assert isinstance(result.temperature, float) + assert isinstance(result.condition, str) + assert result.location.lower() == "paris" + + +@pytest.mark.asyncio +async def test_llamacpp_specific_params(llamacpp_model: LlamaCppModel) -> None: + """Test llama.cpp specific parameters.""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Say 'test' five times."}]}, + ] + + # Use llama.cpp specific parameters + llamacpp_model.update_config( + params={ + "repeat_penalty": 1.5, # Penalize repetition + "top_k": 10, # Limit vocabulary + "min_p": 0.1, # Min-p sampling + } + ) + + response_text = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response_text += delta["text"] + + # Response should contain "test" but with repetition penalty it might vary + assert "test" in response_text.lower() + + +@pytest.mark.asyncio +async def test_advanced_sampling_params(llamacpp_model: LlamaCppModel) -> None: + """Test advanced sampling parameters.""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Generate a random sentence about space."}]}, + ] + + # Test advanced sampling parameters + llamacpp_model.update_config( + params={ + "temperature": 0.8, + "tfs_z": 0.95, # Tail-free sampling + "top_a": 0.1, # Top-a sampling + "typical_p": 0.9, # Typical-p sampling + "penalty_last_n": 64, # Penalty context window + "min_keep": 1, # Minimum tokens to keep + "samplers": ["top_k", "tfs_z", "typical_p", "top_p", "min_p", "temperature"], + } + ) + + response_text = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response_text += delta["text"] + + # Should generate something about space + assert len(response_text) > 0 + assert any(word in response_text.lower() for word in ["space", "star", "planet", "galaxy", "universe"]) + + +@pytest.mark.asyncio +async def test_mirostat_sampling(llamacpp_model: LlamaCppModel) -> None: + """Test Mirostat sampling modes.""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Write a short poem."}]}, + ] + + # Test Mirostat v2 + llamacpp_model.update_config( + params={ + "mirostat": 2, + "mirostat_lr": 0.1, + "mirostat_ent": 5.0, + "seed": 42, # For reproducibility + } + ) - # Test Mirostat v2 - llamacpp_model.update_config( - params={ - "mirostat": 2, - "mirostat_lr": 0.1, - "mirostat_ent": 5.0, - "seed": 42, # For reproducibility - } - ) + response_text = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response_text += delta["text"] - response_text = "" - async for event in llamacpp_model.stream(messages): - if "contentBlockDelta" in event: - delta = event["contentBlockDelta"]["delta"] - if "text" in delta: - response_text += delta["text"] + # Should generate a poem + assert len(response_text) > 20 + assert "\n" in response_text # Poems typically have line breaks - # Should generate a poem - assert len(response_text) > 20 
- assert "\n" in response_text # Poems typically have line breaks - @pytest.mark.asyncio - async def test_grammar_constraint(self, llamacpp_model: LlamaCppModel) -> None: - """Test grammar constraint feature (llama.cpp specific).""" - messages: list[Message] = [ - {"role": "user", "content": [{"text": "Is the sky blue? Answer yes or no."}]}, - ] +@pytest.mark.asyncio +async def test_grammar_constraint(llamacpp_model: LlamaCppModel) -> None: + """Test grammar constraint feature (llama.cpp specific).""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Is the sky blue? Answer yes or no."}]}, + ] - # Set grammar constraint via params - grammar = """ + # Set grammar constraint via params + grammar = """ root ::= answer answer ::= "yes" | "no" """ - llamacpp_model.update_config(params={"grammar": grammar}) - - response_text = "" - async for event in llamacpp_model.stream(messages): - if "contentBlockDelta" in event: - delta = event["contentBlockDelta"]["delta"] - if "text" in delta: - response_text += delta["text"] - - # Response should be exactly "yes" or "no" - assert response_text.strip().lower() in ["yes", "no"] - - @pytest.mark.asyncio - async def test_json_schema_constraint(self, llamacpp_model: LlamaCppModel) -> None: - """Test JSON schema constraint feature.""" - messages: list[Message] = [ - { - "role": "user", - "content": [{"text": "Describe the weather in JSON format with temperature and description."}], + llamacpp_model.update_config(params={"grammar": grammar}) + + response_text = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response_text += delta["text"] + + # Response should be exactly "yes" or "no" + assert response_text.strip().lower() in ["yes", "no"] + + +@pytest.mark.asyncio +async def test_json_schema_constraint(llamacpp_model: LlamaCppModel) -> None: + """Test JSON schema constraint feature.""" + messages: list[Message] = [ + { + "role": "user", + "content": [{"text": "Describe the weather in JSON format with temperature and description."}], + }, + ] + + # Set JSON schema constraint via params + schema = { + "type": "object", + "properties": {"temperature": {"type": "number"}, "description": {"type": "string"}}, + "required": ["temperature", "description"], + } + llamacpp_model.update_config(params={"json_schema": schema}) + + response_text = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response_text += delta["text"] + + # Should be valid JSON matching the schema + import json + + data = json.loads(response_text.strip()) + assert "temperature" in data + assert "description" in data + assert isinstance(data["temperature"], (int, float)) + assert isinstance(data["description"], str) + + +@pytest.mark.asyncio +async def test_logit_bias(llamacpp_model: LlamaCppModel) -> None: + """Test logit bias feature.""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Choose between 'cat' and 'dog'."}]}, + ] + + # This is a simplified test - in reality you'd need to know the actual token IDs + # for "cat" and "dog" in the model's vocabulary + llamacpp_model.update_config( + params={ + "logit_bias": { + # These are placeholder token IDs - real implementation would need actual token IDs + 1234: 10.0, # Strong positive bias (hypothetical "cat" token) + 5678: -10.0, # Strong negative bias (hypothetical "dog" token) }, - 
] - - # Set JSON schema constraint via params - schema = { - "type": "object", - "properties": {"temperature": {"type": "number"}, "description": {"type": "string"}}, - "required": ["temperature", "description"], + "seed": 42, # For reproducibility } - llamacpp_model.update_config(params={"json_schema": schema}) + ) + + response_text = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response_text += delta["text"] + + # Should generate text (exact behavior depends on actual token IDs) + assert len(response_text) > 0 + + +@pytest.mark.asyncio +async def test_cache_prompt(llamacpp_model: LlamaCppModel) -> None: + """Test prompt caching feature.""" + messages: list[Message] = [ + {"role": "system", "content": [{"text": "You are a helpful assistant. Always be concise."}]}, + {"role": "user", "content": [{"text": "What is 2+2?"}]}, + ] + + # Enable prompt caching + llamacpp_model.update_config( + params={ + "cache_prompt": True, + "slot_id": 0, # Use specific slot for caching + } + ) + + # First request + response1 = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response1 += delta["text"] + + # Second request with same system prompt should use cache + messages2 = [ + {"role": "system", "content": [{"text": "You are a helpful assistant. Always be concise."}]}, + {"role": "user", "content": [{"text": "What is 3+3?"}]}, + ] + + response2 = "" + async for event in llamacpp_model.stream(messages2): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response2 += delta["text"] + + # Both should give valid responses + assert "4" in response1 + assert "6" in response2 + + +@pytest.mark.asyncio +async def test_concurrent_requests(llamacpp_model: LlamaCppModel) -> None: + """Test handling multiple concurrent requests.""" + import asyncio + + async def make_request(prompt: str) -> str: + messages: list[Message] = [ + {"role": "user", "content": [{"text": prompt}]}, + ] - response_text = "" + response = "" async for event in llamacpp_model.stream(messages): if "contentBlockDelta" in event: delta = event["contentBlockDelta"]["delta"] if "text" in delta: - response_text += delta["text"] - - # Should be valid JSON matching the schema - import json - - data = json.loads(response_text.strip()) - assert "temperature" in data - assert "description" in data - assert isinstance(data["temperature"], (int, float)) - assert isinstance(data["description"], str) - - @pytest.mark.asyncio - async def test_logit_bias(self, llamacpp_model: LlamaCppModel) -> None: - """Test logit bias feature.""" - messages: list[Message] = [ - {"role": "user", "content": [{"text": "Choose between 'cat' and 'dog'."}]}, - ] - - # This is a simplified test - in reality you'd need to know the actual token IDs - # for "cat" and "dog" in the model's vocabulary - llamacpp_model.update_config( - params={ - "logit_bias": { - # These are placeholder token IDs - real implementation would need actual token IDs - 1234: 10.0, # Strong positive bias (hypothetical "cat" token) - 5678: -10.0, # Strong negative bias (hypothetical "dog" token) - }, - "seed": 42, # For reproducibility - } - ) - + response += delta["text"] + return response + + # Make concurrent requests + prompts = [ + "Say 'one'", + "Say 'two'", + "Say 'three'", + ] + + responses = await 
asyncio.gather(*[make_request(p) for p in prompts]) + + # Each response should contain the expected number + assert "one" in responses[0].lower() + assert "two" in responses[1].lower() + assert "three" in responses[2].lower() + + +@pytest.mark.asyncio +async def test_enhanced_structured_output(llamacpp_model: LlamaCppModel) -> None: + """Test enhanced structured output with native JSON schema support.""" + + class BookInfo(BaseModel): + title: str + author: str + year: int + genres: list[str] + + messages: list[Message] = [ + { + "role": "user", + "content": [ + { + "text": "Create information about a fictional science fiction book. " + "Include title, author, publication year, and 2-3 genres." + } + ], + }, + ] + + result = None + events = [] + async for event in llamacpp_model.structured_output(BookInfo, messages): + events.append(event) + if "output" in event: + result = event["output"] + + # Verify we got structured output + assert result is not None + assert isinstance(result, BookInfo) + assert isinstance(result.title, str) and len(result.title) > 0 + assert isinstance(result.author, str) and len(result.author) > 0 + assert isinstance(result.year, int) and 1900 <= result.year <= 2100 + assert isinstance(result.genres, list) and len(result.genres) >= 2 + assert all(isinstance(genre, str) for genre in result.genres) + + # Should have streamed events before the output + assert len(events) > 1 + + +@pytest.mark.asyncio +async def test_context_overflow_handling(llamacpp_model: LlamaCppModel) -> None: + """Test proper handling of context window overflow.""" + # Create a very long message that might exceed context + long_text = "This is a test sentence. " * 1000 + messages: list[Message] = [ + {"role": "user", "content": [{"text": f"Summarize this text: {long_text}"}]}, + ] + + try: response_text = "" async for event in llamacpp_model.stream(messages): if "contentBlockDelta" in event: @@ -348,148 +497,14 @@ async def test_logit_bias(self, llamacpp_model: LlamaCppModel) -> None: if "text" in delta: response_text += delta["text"] - # Should generate text (exact behavior depends on actual token IDs) + # If it succeeds, we got a response assert len(response_text) > 0 - - @pytest.mark.asyncio - async def test_cache_prompt(self, llamacpp_model: LlamaCppModel) -> None: - """Test prompt caching feature.""" - messages: list[Message] = [ - {"role": "system", "content": [{"text": "You are a helpful assistant. Always be concise."}]}, - {"role": "user", "content": [{"text": "What is 2+2?"}]}, - ] - - # Enable prompt caching - llamacpp_model.update_config( - params={ - "cache_prompt": True, - "slot_id": 0, # Use specific slot for caching - } - ) - - # First request - response1 = "" - async for event in llamacpp_model.stream(messages): - if "contentBlockDelta" in event: - delta = event["contentBlockDelta"]["delta"] - if "text" in delta: - response1 += delta["text"] - - # Second request with same system prompt should use cache - messages2 = [ - {"role": "system", "content": [{"text": "You are a helpful assistant. 
Always be concise."}]}, - {"role": "user", "content": [{"text": "What is 3+3?"}]}, - ] - - response2 = "" - async for event in llamacpp_model.stream(messages2): - if "contentBlockDelta" in event: - delta = event["contentBlockDelta"]["delta"] - if "text" in delta: - response2 += delta["text"] - - # Both should give valid responses - assert "4" in response1 - assert "6" in response2 - - @pytest.mark.asyncio - async def test_concurrent_requests(self, llamacpp_model: LlamaCppModel) -> None: - """Test handling multiple concurrent requests.""" - import asyncio - - async def make_request(prompt: str) -> str: - messages: list[Message] = [ - {"role": "user", "content": [{"text": prompt}]}, - ] - - response = "" - async for event in llamacpp_model.stream(messages): - if "contentBlockDelta" in event: - delta = event["contentBlockDelta"]["delta"] - if "text" in delta: - response += delta["text"] - return response - - # Make concurrent requests - prompts = [ - "Say 'one'", - "Say 'two'", - "Say 'three'", - ] - - responses = await asyncio.gather(*[make_request(p) for p in prompts]) - - # Each response should contain the expected number - assert "one" in responses[0].lower() - assert "two" in responses[1].lower() - assert "three" in responses[2].lower() - - @pytest.mark.asyncio - async def test_enhanced_structured_output(self, llamacpp_model: LlamaCppModel) -> None: - """Test enhanced structured output with native JSON schema support.""" - - class BookInfo(BaseModel): - title: str - author: str - year: int - genres: list[str] - - messages: list[Message] = [ - { - "role": "user", - "content": [ - { - "text": "Create information about a fictional science fiction book. " - "Include title, author, publication year, and 2-3 genres." - } - ], - }, - ] - - result = None - events = [] - async for event in llamacpp_model.structured_output(BookInfo, messages): - events.append(event) - if "output" in event: - result = event["output"] - - # Verify we got structured output - assert result is not None - assert isinstance(result, BookInfo) - assert isinstance(result.title, str) and len(result.title) > 0 - assert isinstance(result.author, str) and len(result.author) > 0 - assert isinstance(result.year, int) and 1900 <= result.year <= 2100 - assert isinstance(result.genres, list) and len(result.genres) >= 2 - assert all(isinstance(genre, str) for genre in result.genres) - - # Should have streamed events before the output - assert len(events) > 1 - - @pytest.mark.asyncio - async def test_context_overflow_handling(self, llamacpp_model: LlamaCppModel) -> None: - """Test proper handling of context window overflow.""" - # Create a very long message that might exceed context - long_text = "This is a test sentence. 
" * 1000 - messages: list[Message] = [ - {"role": "user", "content": [{"text": f"Summarize this text: {long_text}"}]}, - ] - - try: - response_text = "" - async for event in llamacpp_model.stream(messages): - if "contentBlockDelta" in event: - delta = event["contentBlockDelta"]["delta"] - if "text" in delta: - response_text += delta["text"] - - # If it succeeds, we got a response - assert len(response_text) > 0 - except Exception as e: - # If it fails, it should be our custom error - from strands.types.exceptions import ContextWindowOverflowException - - if isinstance(e, ContextWindowOverflowException): - assert "context" in str(e).lower() - else: - # Some other error - re-raise to see what it was - raise + except Exception as e: + # If it fails, it should be our custom error + from strands.types.exceptions import ContextWindowOverflowException + + if isinstance(e, ContextWindowOverflowException): + assert "context" in str(e).lower() + else: + # Some other error - re-raise to see what it was + raise From 11e6c11bda9fd870a31cf95944b104e7cac0f3c3 Mon Sep 17 00:00:00 2001 From: Aaron Brown Date: Mon, 18 Aug 2025 10:09:06 -0500 Subject: [PATCH 11/11] Grammar constraints via params example refined; --- src/strands/models/llamacpp.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/strands/models/llamacpp.py b/src/strands/models/llamacpp.py index 73b248462..94a225a06 100644 --- a/src/strands/models/llamacpp.py +++ b/src/strands/models/llamacpp.py @@ -57,11 +57,13 @@ class LlamaCppModel(Model): >>> model = LlamaCppModel(base_url="http://localhost:8080") >>> model.update_config(params={"temperature": 0.7, "top_k": 40}) - Grammar constraints: - >>> model.use_grammar_constraint(''' - ... root ::= answer - ... answer ::= "yes" | "no" - ... ''') + Grammar constraints via params: + >>> model.update_config(params={ + ... "grammar": ''' + ... root ::= answer + ... answer ::= "yes" | "no" + ... ''' + ... }) Advanced sampling: >>> model.update_config(params={