diff --git a/README.md b/README.md index 62ed54d47..44d10b67e 100644 --- a/README.md +++ b/README.md @@ -129,6 +129,7 @@ from strands import Agent from strands.models import BedrockModel from strands.models.ollama import OllamaModel from strands.models.llamaapi import LlamaAPIModel +from strands.models.llamacpp import LlamaCppModel # Bedrock bedrock_model = BedrockModel( @@ -159,6 +160,7 @@ Built-in providers: - [Amazon Bedrock](https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/) - [Anthropic](https://strandsagents.com/latest/user-guide/concepts/model-providers/anthropic/) - [LiteLLM](https://strandsagents.com/latest/user-guide/concepts/model-providers/litellm/) + - [llama.cpp](https://strandsagents.com/latest/user-guide/concepts/model-providers/llamacpp/) - [LlamaAPI](https://strandsagents.com/latest/user-guide/concepts/model-providers/llamaapi/) - [Ollama](https://strandsagents.com/latest/user-guide/concepts/model-providers/ollama/) - [OpenAI](https://strandsagents.com/latest/user-guide/concepts/model-providers/openai/) diff --git a/src/strands/models/llamacpp.py b/src/strands/models/llamacpp.py new file mode 100644 index 000000000..94a225a06 --- /dev/null +++ b/src/strands/models/llamacpp.py @@ -0,0 +1,762 @@ +"""llama.cpp model provider. + +Provides integration with llama.cpp servers running in OpenAI-compatible mode, +with support for advanced llama.cpp-specific features. + +- Docs: https://github.com/ggml-org/llama.cpp +- Server docs: https://github.com/ggml-org/llama.cpp/tree/master/tools/server +- OpenAI API compatibility: + https://github.com/ggml-org/llama.cpp/blob/master/tools/server/README.md#api-endpoints +""" + +import base64 +import json +import logging +import mimetypes +import time +from typing import ( + Any, + AsyncGenerator, + Dict, + Optional, + Type, + TypedDict, + TypeVar, + Union, + cast, +) + +import httpx +from pydantic import BaseModel +from typing_extensions import Unpack, override + +from ..types.content import ContentBlock, Messages +from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException +from ..types.streaming import StreamEvent +from ..types.tools import ToolSpec +from .model import Model + +logger = logging.getLogger(__name__) + +T = TypeVar("T", bound=BaseModel) + + +class LlamaCppModel(Model): + """llama.cpp model provider implementation. + + Connects to a llama.cpp server running in OpenAI-compatible mode with + support for advanced llama.cpp-specific features like grammar constraints, + Mirostat sampling, native JSON schema validation, and native multimodal + support for audio and image content. + + The llama.cpp server must be started with the OpenAI-compatible API enabled: + llama-server -m model.gguf --host 0.0.0.0 --port 8080 + + Example: + Basic usage: + >>> model = LlamaCppModel(base_url="http://localhost:8080") + >>> model.update_config(params={"temperature": 0.7, "top_k": 40}) + + Grammar constraints via params: + >>> model.update_config(params={ + ... "grammar": ''' + ... root ::= answer + ... answer ::= "yes" | "no" + ... ''' + ... }) + + Advanced sampling: + >>> model.update_config(params={ + ... "mirostat": 2, + ... "mirostat_lr": 0.1, + ... "tfs_z": 0.95, + ... "repeat_penalty": 1.1 + ... }) + + Multimodal usage (requires multimodal model like Qwen2.5-Omni): + >>> # Audio analysis + >>> audio_content = [{ + ... "audio": {"source": {"bytes": audio_bytes}, "format": "wav"}, + ... "text": "What do you hear in this audio?" + ... 
}] + >>> response = agent(audio_content) + + >>> # Image analysis + >>> image_content = [{ + ... "image": {"source": {"bytes": image_bytes}, "format": "png"}, + ... "text": "Describe this image" + ... }] + >>> response = agent(image_content) + """ + + class LlamaCppConfig(TypedDict, total=False): + """Configuration options for llama.cpp models. + + Attributes: + model_id: Model identifier for the loaded model in llama.cpp server. + Default is "default" as llama.cpp typically loads a single model. + params: Model parameters supporting both OpenAI and llama.cpp-specific options. + + OpenAI-compatible parameters: + - max_tokens: Maximum number of tokens to generate + - temperature: Sampling temperature (0.0 to 2.0) + - top_p: Nucleus sampling parameter (0.0 to 1.0) + - frequency_penalty: Frequency penalty (-2.0 to 2.0) + - presence_penalty: Presence penalty (-2.0 to 2.0) + - stop: List of stop sequences + - seed: Random seed for reproducibility + - n: Number of completions to generate + - logprobs: Include log probabilities in output + - top_logprobs: Number of top log probabilities to include + + llama.cpp-specific parameters: + - repeat_penalty: Penalize repeat tokens (1.0 = no penalty) + - top_k: Top-k sampling (0 = disabled) + - min_p: Min-p sampling threshold (0.0 to 1.0) + - typical_p: Typical-p sampling (0.0 to 1.0) + - tfs_z: Tail-free sampling parameter (0.0 to 1.0) + - top_a: Top-a sampling parameter + - mirostat: Mirostat sampling mode (0, 1, or 2) + - mirostat_lr: Mirostat learning rate + - mirostat_ent: Mirostat target entropy + - grammar: GBNF grammar string for constrained generation + - json_schema: JSON schema for structured output + - penalty_last_n: Number of tokens to consider for penalties + - n_probs: Number of probabilities to return per token + - min_keep: Minimum tokens to keep in sampling + - ignore_eos: Ignore end-of-sequence token + - logit_bias: Token ID to bias mapping + - cache_prompt: Cache the prompt for faster generation + - slot_id: Slot ID for parallel inference + - samplers: Custom sampler order + """ + + model_id: str + params: Optional[dict[str, Any]] + + def __init__( + self, + base_url: str = "http://localhost:8080", + timeout: Optional[Union[float, tuple[float, float]]] = None, + **model_config: Unpack[LlamaCppConfig], + ) -> None: + """Initialize llama.cpp provider instance. + + Args: + base_url: Base URL for the llama.cpp server. + Default is "http://localhost:8080" for local server. + timeout: Request timeout in seconds. Can be float or tuple of + (connect, read) timeouts. + **model_config: Configuration options for the llama.cpp model. 
+ """ + # Set default model_id if not provided + if "model_id" not in model_config: + model_config["model_id"] = "default" + + self.base_url = base_url.rstrip("/") + self.config = dict(model_config) + + # Configure HTTP client + if isinstance(timeout, tuple): + # Convert tuple to httpx.Timeout object + timeout_obj = httpx.Timeout( + connect=timeout[0] if len(timeout) > 0 else None, + read=timeout[1] if len(timeout) > 1 else None, + write=timeout[2] if len(timeout) > 2 else None, + pool=timeout[3] if len(timeout) > 3 else None, + ) + else: + timeout_obj = httpx.Timeout(timeout or 30.0) + + self.client = httpx.AsyncClient( + base_url=self.base_url, + timeout=timeout_obj, + ) + + logger.debug( + "base_url=<%s>, model_id=<%s> | initializing llama.cpp provider", + base_url, + model_config.get("model_id"), + ) + + @override + def update_config(self, **model_config: Unpack[LlamaCppConfig]) -> None: # type: ignore[override] + """Update the llama.cpp model configuration with provided arguments. + + Args: + **model_config: Configuration overrides. + """ + self.config.update(model_config) + + @override + def get_config(self) -> LlamaCppConfig: + """Get the llama.cpp model configuration. + + Returns: + The llama.cpp model configuration. + """ + return self.config # type: ignore[return-value] + + def _format_message_content(self, content: Union[ContentBlock, Dict[str, Any]]) -> dict[str, Any]: + """Format a content block for llama.cpp. + + Args: + content: Message content. + + Returns: + llama.cpp compatible content block. + + Raises: + TypeError: If the content block type cannot be converted to a compatible format. + """ + if "document" in content: + mime_type = mimetypes.types_map.get(f".{content['document']['format']}", "application/octet-stream") + file_data = base64.b64encode(content["document"]["source"]["bytes"]).decode("utf-8") + return { + "file": { + "file_data": f"data:{mime_type};base64,{file_data}", + "filename": content["document"]["name"], + }, + "type": "file", + } + + if "image" in content: + mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream") + image_data = base64.b64encode(content["image"]["source"]["bytes"]).decode("utf-8") + return { + "image_url": { + "detail": "auto", + "format": mime_type, + "url": f"data:{mime_type};base64,{image_data}", + }, + "type": "image_url", + } + + # Handle audio content (not in standard ContentBlock but supported by llama.cpp) + if "audio" in content: + audio_content = cast(Dict[str, Any], content) + audio_data = base64.b64encode(audio_content["audio"]["source"]["bytes"]).decode("utf-8") + audio_format = audio_content["audio"].get("format", "wav") + return { + "type": "input_audio", + "input_audio": {"data": audio_data, "format": audio_format}, + } + + if "text" in content: + return {"text": content["text"], "type": "text"} + + raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") + + def _format_tool_call(self, tool_use: dict[str, Any]) -> dict[str, Any]: + """Format a tool call for llama.cpp. + + Args: + tool_use: Tool use requested by the model. + + Returns: + llama.cpp compatible tool call. + """ + return { + "function": { + "arguments": json.dumps(tool_use["input"]), + "name": tool_use["name"], + }, + "id": tool_use["toolUseId"], + "type": "function", + } + + def _format_tool_message(self, tool_result: dict[str, Any]) -> dict[str, Any]: + """Format a tool message for llama.cpp. + + Args: + tool_result: Tool result collected from a tool execution. 
+ + Returns: + llama.cpp compatible tool message. + """ + contents = [ + {"text": json.dumps(content["json"])} if "json" in content else content + for content in tool_result["content"] + ] + + return { + "role": "tool", + "tool_call_id": tool_result["toolUseId"], + "content": [self._format_message_content(content) for content in contents], + } + + def _format_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + """Format messages for llama.cpp. + + Args: + messages: List of message objects to be processed. + system_prompt: System prompt to provide context to the model. + + Returns: + Formatted messages array compatible with llama.cpp. + """ + formatted_messages: list[dict[str, Any]] = [] + + # Add system prompt if provided + if system_prompt: + formatted_messages.append({"role": "system", "content": system_prompt}) + + for message in messages: + contents = message["content"] + + formatted_contents = [ + self._format_message_content(content) + for content in contents + if not any(block_type in content for block_type in ["toolResult", "toolUse"]) + ] + formatted_tool_calls = [ + self._format_tool_call( + { + "name": content["toolUse"]["name"], + "input": content["toolUse"]["input"], + "toolUseId": content["toolUse"]["toolUseId"], + } + ) + for content in contents + if "toolUse" in content + ] + formatted_tool_messages = [ + self._format_tool_message( + { + "toolUseId": content["toolResult"]["toolUseId"], + "content": content["toolResult"]["content"], + } + ) + for content in contents + if "toolResult" in content + ] + + formatted_message = { + "role": message["role"], + "content": formatted_contents, + **({} if not formatted_tool_calls else {"tool_calls": formatted_tool_calls}), + } + formatted_messages.append(formatted_message) + formatted_messages.extend(formatted_tool_messages) + + return [message for message in formatted_messages if message["content"] or "tool_calls" in message] + + def _format_request( + self, + messages: Messages, + tool_specs: Optional[list[ToolSpec]] = None, + system_prompt: Optional[str] = None, + ) -> dict[str, Any]: + """Format a request for the llama.cpp server. + + Args: + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. + system_prompt: System prompt to provide context to the model. + + Returns: + A request formatted for llama.cpp server's OpenAI-compatible API. 
+ """ + # Separate OpenAI-compatible and llama.cpp-specific parameters + request = { + "messages": self._format_messages(messages, system_prompt), + "model": self.config["model_id"], + "stream": True, + "stream_options": {"include_usage": True}, + "tools": [ + { + "type": "function", + "function": { + "name": tool_spec["name"], + "description": tool_spec["description"], + "parameters": tool_spec["inputSchema"]["json"], + }, + } + for tool_spec in tool_specs or [] + ], + } + + # Handle parameters if provided + params = self.config.get("params") + if params and isinstance(params, dict): + # Grammar and json_schema go directly in request body for llama.cpp server + if "grammar" in params: + request["grammar"] = params["grammar"] + if "json_schema" in params: + request["json_schema"] = params["json_schema"] + + # llama.cpp-specific parameters that must be passed via extra_body + # NOTE: grammar and json_schema are NOT in this set because llama.cpp server + # expects them directly in the request body for proper constraint application + llamacpp_specific_params = { + "repeat_penalty", + "top_k", + "min_p", + "typical_p", + "tfs_z", + "top_a", + "mirostat", + "mirostat_lr", + "mirostat_ent", + "penalty_last_n", + "n_probs", + "min_keep", + "ignore_eos", + "logit_bias", + "cache_prompt", + "slot_id", + "samplers", + } + + # Standard OpenAI parameters that go directly in the request + openai_params = { + "temperature", + "max_tokens", + "top_p", + "frequency_penalty", + "presence_penalty", + "stop", + "seed", + "n", + "logprobs", + "top_logprobs", + "response_format", + } + + # Add OpenAI parameters directly to request + for param, value in params.items(): + if param in openai_params: + request[param] = value + + # Collect llama.cpp-specific parameters for extra_body + extra_body: Dict[str, Any] = {} + for param, value in params.items(): + if param in llamacpp_specific_params: + extra_body[param] = value + + # Add extra_body if we have llama.cpp-specific parameters + if extra_body: + request["extra_body"] = extra_body + + return request + + def _format_chunk(self, event: dict[str, Any]) -> StreamEvent: + """Format a llama.cpp response event into a standardized message chunk. + + Args: + event: A response event from the llama.cpp server. + + Returns: + The formatted chunk. + + Raises: + RuntimeError: If chunk_type is not recognized. 
+ """ + match event["chunk_type"]: + case "message_start": + return {"messageStart": {"role": "assistant"}} + + case "content_start": + if event["data_type"] == "tool": + return { + "contentBlockStart": { + "start": { + "toolUse": { + "name": event["data"].function.name, + "toolUseId": event["data"].id, + } + } + } + } + return {"contentBlockStart": {"start": {}}} + + case "content_delta": + if event["data_type"] == "tool": + return { + "contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].function.arguments or ""}}} + } + if event["data_type"] == "reasoning_content": + return {"contentBlockDelta": {"delta": {"reasoningContent": {"text": event["data"]}}}} + return {"contentBlockDelta": {"delta": {"text": event["data"]}}} + + case "content_stop": + return {"contentBlockStop": {}} + + case "message_stop": + match event["data"]: + case "tool_calls": + return {"messageStop": {"stopReason": "tool_use"}} + case "length": + return {"messageStop": {"stopReason": "max_tokens"}} + case _: + return {"messageStop": {"stopReason": "end_turn"}} + + case "metadata": + return { + "metadata": { + "usage": { + "inputTokens": event["data"].prompt_tokens, + "outputTokens": event["data"].completion_tokens, + "totalTokens": event["data"].total_tokens, + }, + "metrics": { + "latencyMs": event.get("latency_ms", 0), + }, + }, + } + + case _: + raise RuntimeError(f"chunk_type=<{event['chunk_type']}> | unknown type") + + @override + async def stream( + self, + messages: Messages, + tool_specs: Optional[list[ToolSpec]] = None, + system_prompt: Optional[str] = None, + **kwargs: Any, + ) -> AsyncGenerator[StreamEvent, None]: + """Stream conversation with the llama.cpp model. + + Args: + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. + system_prompt: System prompt to provide context to the model. + **kwargs: Additional keyword arguments for future extensibility. + + Yields: + Formatted message chunks from the model. + + Raises: + ContextWindowOverflowException: When the context window is exceeded. + ModelThrottledException: When the llama.cpp server is overloaded. 
+ """ + # Track request start time for latency calculation + start_time = time.perf_counter() + + try: + logger.debug("formatting request for llama.cpp server") + request = self._format_request(messages, tool_specs, system_prompt) + logger.debug("request=<%s>", request) + + logger.debug("sending request to llama.cpp server") + response = await self.client.post("/v1/chat/completions", json=request) + response.raise_for_status() + + logger.debug("processing streaming response") + yield self._format_chunk({"chunk_type": "message_start"}) + yield self._format_chunk({"chunk_type": "content_start", "data_type": "text"}) + + tool_calls: Dict[int, list] = {} + usage_data = None + finish_reason = None + + async for line in response.aiter_lines(): + if not line.strip() or not line.startswith("data: "): + continue + + data_content = line[6:] # Remove "data: " prefix + if data_content.strip() == "[DONE]": + break + + try: + event = json.loads(data_content) + except json.JSONDecodeError: + continue + + # Handle usage information + if "usage" in event: + usage_data = event["usage"] + continue + + if not event.get("choices"): + continue + + choice = event["choices"][0] + delta = choice.get("delta", {}) + + # Handle content deltas + if "content" in delta and delta["content"]: + yield self._format_chunk( + { + "chunk_type": "content_delta", + "data_type": "text", + "data": delta["content"], + } + ) + + # Handle tool calls + if "tool_calls" in delta: + for tool_call in delta["tool_calls"]: + index = tool_call["index"] + if index not in tool_calls: + tool_calls[index] = [] + tool_calls[index].append(tool_call) + + # Check for finish reason + if choice.get("finish_reason"): + finish_reason = choice.get("finish_reason") + break + + yield self._format_chunk({"chunk_type": "content_stop"}) + + # Process tool calls + for tool_deltas in tool_calls.values(): + first_delta = tool_deltas[0] + yield self._format_chunk( + { + "chunk_type": "content_start", + "data_type": "tool", + "data": type( + "ToolCall", + (), + { + "function": type( + "Function", + (), + { + "name": first_delta.get("function", {}).get("name", ""), + }, + )(), + "id": first_delta.get("id", ""), + }, + )(), + } + ) + + for tool_delta in tool_deltas: + yield self._format_chunk( + { + "chunk_type": "content_delta", + "data_type": "tool", + "data": type( + "ToolCall", + (), + { + "function": type( + "Function", + (), + { + "arguments": tool_delta.get("function", {}).get("arguments", ""), + }, + )(), + }, + )(), + } + ) + + yield self._format_chunk({"chunk_type": "content_stop"}) + + # Send stop reason + logger.debug("finish_reason=%s, tool_calls=%s", finish_reason, bool(tool_calls)) + if finish_reason == "tool_calls" or tool_calls: + stop_reason = "tool_calls" # Changed from "tool_use" to match format_chunk expectations + else: + stop_reason = finish_reason or "end_turn" + logger.debug("stop_reason=%s", stop_reason) + yield self._format_chunk({"chunk_type": "message_stop", "data": stop_reason}) + + # Send usage metadata if available + if usage_data: + # Calculate latency + latency_ms = int((time.perf_counter() - start_time) * 1000) + yield self._format_chunk( + { + "chunk_type": "metadata", + "data": type( + "Usage", + (), + { + "prompt_tokens": usage_data.get("prompt_tokens", 0), + "completion_tokens": usage_data.get("completion_tokens", 0), + "total_tokens": usage_data.get("total_tokens", 0), + }, + )(), + "latency_ms": latency_ms, + } + ) + + logger.debug("finished streaming response") + + except httpx.HTTPStatusError as e: + if 
e.response.status_code == 400: + # Parse error response from llama.cpp server + try: + error_data = e.response.json() + error_msg = str(error_data.get("error", {}).get("message", str(error_data))) + except (json.JSONDecodeError, KeyError, AttributeError): + error_msg = e.response.text + + # Check for context overflow by looking for specific error indicators + if any(term in error_msg.lower() for term in ["context", "kv cache", "slot"]): + raise ContextWindowOverflowException(f"Context window exceeded: {error_msg}") from e + elif e.response.status_code == 503: + raise ModelThrottledException("llama.cpp server is busy or overloaded") from e + raise + except Exception as e: + # Handle other potential errors like rate limiting + error_msg = str(e).lower() + if "rate" in error_msg or "429" in str(e): + raise ModelThrottledException(str(e)) from e + raise + + @override + async def structured_output( + self, + output_model: Type[T], + prompt: Messages, + system_prompt: Optional[str] = None, + **kwargs: Any, + ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + """Get structured output using llama.cpp's native JSON schema support. + + This implementation uses llama.cpp's json_schema parameter to constrain + the model output to valid JSON matching the provided schema. + + Args: + output_model: The Pydantic model defining the expected output structure. + prompt: The prompt messages to use for generation. + system_prompt: System prompt to provide context to the model. + **kwargs: Additional keyword arguments for future extensibility. + + Yields: + Model events with the last being the structured output. + + Raises: + json.JSONDecodeError: If the model output is not valid JSON. + pydantic.ValidationError: If the output doesn't match the model schema. + """ + # Get the JSON schema from the Pydantic model + schema = output_model.model_json_schema() + + # Store current params to restore later + params = self.config.get("params", {}) + original_params = dict(params) if isinstance(params, dict) else {} + + try: + # Configure for JSON output with schema constraint + params = self.config.get("params", {}) + if not isinstance(params, dict): + params = {} + params["json_schema"] = schema + params["cache_prompt"] = True + self.config["params"] = params + + # Collect the response + response_text = "" + async for event in self.stream(prompt, system_prompt=system_prompt, **kwargs): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response_text += delta["text"] + # Forward events to caller + yield cast(Dict[str, Union[T, Any]], event) + + # Parse and validate the JSON response + data = json.loads(response_text.strip()) + output_instance = output_model(**data) + yield {"output": output_instance} + + finally: + # Restore original configuration + self.config["params"] = original_params diff --git a/tests/strands/models/test_llamacpp.py b/tests/strands/models/test_llamacpp.py new file mode 100644 index 000000000..e5b2614c0 --- /dev/null +++ b/tests/strands/models/test_llamacpp.py @@ -0,0 +1,639 @@ +"""Unit tests for llama.cpp model provider.""" + +import base64 +import json +from unittest.mock import AsyncMock, patch + +import httpx +import pytest +from pydantic import BaseModel + +from strands.models.llamacpp import LlamaCppModel +from strands.types.exceptions import ( + ContextWindowOverflowException, + ModelThrottledException, +) + + +def test_init_default_config() -> None: + """Test initialization with default configuration.""" + model = LlamaCppModel() + + 
assert model.config["model_id"] == "default" + assert isinstance(model.client, httpx.AsyncClient) + assert model.base_url == "http://localhost:8080" + + +def test_init_custom_config() -> None: + """Test initialization with custom configuration.""" + model = LlamaCppModel( + base_url="http://example.com:8081", + model_id="llama-3-8b", + params={"temperature": 0.7, "max_tokens": 100}, + ) + + assert model.config["model_id"] == "llama-3-8b" + assert model.config["params"]["temperature"] == 0.7 + assert model.config["params"]["max_tokens"] == 100 + assert model.base_url == "http://example.com:8081" + + +def test_format_request_basic() -> None: + """Test basic request formatting.""" + model = LlamaCppModel(model_id="test-model") + + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + ] + + request = model._format_request(messages) + + assert request["model"] == "test-model" + assert request["messages"][0]["role"] == "user" + assert request["messages"][0]["content"][0]["type"] == "text" + assert request["messages"][0]["content"][0]["text"] == "Hello" + assert request["stream"] is True + assert "extra_body" not in request + + +def test_format_request_with_system_prompt() -> None: + """Test request formatting with system prompt.""" + model = LlamaCppModel() + + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + ] + + request = model._format_request(messages, system_prompt="You are a helpful assistant") + + assert request["messages"][0]["role"] == "system" + assert request["messages"][0]["content"] == "You are a helpful assistant" + assert request["messages"][1]["role"] == "user" + + +def test_format_request_with_llamacpp_params() -> None: + """Test request formatting with llama.cpp specific parameters.""" + model = LlamaCppModel( + params={ + "temperature": 0.8, + "max_tokens": 50, + "repeat_penalty": 1.1, + "top_k": 40, + "min_p": 0.05, + "grammar": "root ::= 'yes' | 'no'", + } + ) + + messages = [ + {"role": "user", "content": [{"text": "Is the sky blue?"}]}, + ] + + request = model._format_request(messages) + + # Standard OpenAI params + assert request["temperature"] == 0.8 + assert request["max_tokens"] == 50 + + # Grammar and json_schema go directly in request for llama.cpp + assert request["grammar"] == "root ::= 'yes' | 'no'" + + # Other llama.cpp specific params should be in extra_body + assert "extra_body" in request + assert request["extra_body"]["repeat_penalty"] == 1.1 + assert request["extra_body"]["top_k"] == 40 + assert request["extra_body"]["min_p"] == 0.05 + + +def test_format_request_with_all_new_params() -> None: + """Test request formatting with all new llama.cpp parameters.""" + model = LlamaCppModel( + params={ + # OpenAI params + "temperature": 0.7, + "max_tokens": 100, + "top_p": 0.9, + "seed": 42, + # All llama.cpp specific params + "repeat_penalty": 1.1, + "top_k": 40, + "min_p": 0.05, + "typical_p": 0.95, + "tfs_z": 0.97, + "top_a": 0.1, + "mirostat": 2, + "mirostat_lr": 0.1, + "mirostat_ent": 5.0, + "grammar": "root ::= answer", + "json_schema": {"type": "object"}, + "penalty_last_n": 256, + "n_probs": 5, + "min_keep": 1, + "ignore_eos": False, + "logit_bias": {100: 5.0, 200: -5.0}, + "cache_prompt": True, + "slot_id": 1, + "samplers": ["top_k", "tfs_z", "typical_p"], + } + ) + + messages = [{"role": "user", "content": [{"text": "Test"}]}] + request = model._format_request(messages) + + # Check OpenAI params are in root + assert request["temperature"] == 0.7 + assert request["max_tokens"] == 100 + assert request["top_p"] == 0.9 + assert 
request["seed"] == 42 + + # Grammar and json_schema go directly in request for llama.cpp + assert request["grammar"] == "root ::= answer" + assert request["json_schema"] == {"type": "object"} + + # Check all other llama.cpp params are in extra_body + assert "extra_body" in request + extra = request["extra_body"] + assert extra["repeat_penalty"] == 1.1 + assert extra["top_k"] == 40 + assert extra["min_p"] == 0.05 + assert extra["typical_p"] == 0.95 + assert extra["tfs_z"] == 0.97 + assert extra["top_a"] == 0.1 + assert extra["mirostat"] == 2 + assert extra["mirostat_lr"] == 0.1 + assert extra["mirostat_ent"] == 5.0 + assert extra["penalty_last_n"] == 256 + assert extra["n_probs"] == 5 + assert extra["min_keep"] == 1 + assert extra["ignore_eos"] is False + assert extra["logit_bias"] == {100: 5.0, 200: -5.0} + assert extra["cache_prompt"] is True + assert extra["slot_id"] == 1 + assert extra["samplers"] == ["top_k", "tfs_z", "typical_p"] + + +def test_format_request_with_tools() -> None: + """Test request formatting with tool specifications.""" + model = LlamaCppModel() + + messages = [ + {"role": "user", "content": [{"text": "What's the weather?"}]}, + ] + + tool_specs = [ + { + "name": "get_weather", + "description": "Get current weather", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "location": {"type": "string"}, + }, + "required": ["location"], + } + }, + } + ] + + request = model._format_request(messages, tool_specs=tool_specs) + + assert "tools" in request + assert len(request["tools"]) == 1 + assert request["tools"][0]["function"]["name"] == "get_weather" + + +def test_update_config() -> None: + """Test configuration update.""" + model = LlamaCppModel(model_id="initial-model") + + assert model.config["model_id"] == "initial-model" + + model.update_config(model_id="updated-model", params={"temperature": 0.5}) + + assert model.config["model_id"] == "updated-model" + assert model.config["params"]["temperature"] == 0.5 + + +def test_get_config() -> None: + """Test configuration retrieval.""" + config = { + "model_id": "test-model", + "params": {"temperature": 0.9}, + } + model = LlamaCppModel(**config) + + retrieved_config = model.get_config() + + assert retrieved_config["model_id"] == "test-model" + assert retrieved_config["params"]["temperature"] == 0.9 + + +@pytest.mark.asyncio +async def test_stream_basic() -> None: + """Test basic streaming functionality.""" + model = LlamaCppModel() + + # Mock HTTP response with Server-Sent Events format + mock_response_lines = [ + 'data: {"choices": [{"delta": {"content": "Hello"}}]}', + 'data: {"choices": [{"delta": {"content": " world"}, "finish_reason": "stop"}]}', + 'data: {"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}}', + "data: [DONE]", + ] + + async def mock_aiter_lines(): + for line in mock_response_lines: + yield line + + mock_response = AsyncMock() + mock_response.aiter_lines = mock_aiter_lines + mock_response.raise_for_status = AsyncMock() + + with patch.object(model.client, "post", return_value=mock_response): + messages = [{"role": "user", "content": [{"text": "Hi"}]}] + + chunks = [] + async for chunk in model.stream(messages): + chunks.append(chunk) + + # Verify we got the expected chunks + assert any("messageStart" in chunk for chunk in chunks) + assert any( + "contentBlockDelta" in chunk and chunk["contentBlockDelta"]["delta"]["text"] == "Hello" for chunk in chunks + ) + assert any( + "contentBlockDelta" in chunk and chunk["contentBlockDelta"]["delta"]["text"] == " world" for 
chunk in chunks + ) + assert any("messageStop" in chunk for chunk in chunks) + + +@pytest.mark.asyncio +async def test_structured_output() -> None: + """Test structured output functionality.""" + + class TestOutput(BaseModel): + """Test output model for structured output testing.""" + + answer: str + confidence: float + + model = LlamaCppModel() + + # Mock successful JSON response using the new structured_output implementation + mock_response_text = '{"answer": "yes", "confidence": 0.95}' + + # Create mock stream that returns JSON + async def mock_stream(*_args, **_kwargs): + # Verify json_schema was set + assert "json_schema" in model.config.get("params", {}) + + yield {"messageStart": {"role": "assistant"}} + yield {"contentBlockStart": {"start": {}}} + yield {"contentBlockDelta": {"delta": {"text": mock_response_text}}} + yield {"contentBlockStop": {}} + yield {"messageStop": {"stopReason": "end_turn"}} + + with patch.object(model, "stream", side_effect=mock_stream): + messages = [{"role": "user", "content": [{"text": "Is the earth round?"}]}] + + events = [] + async for event in model.structured_output(TestOutput, messages): + events.append(event) + + # Check we got the output + output_event = next((e for e in events if "output" in e), None) + assert output_event is not None + assert output_event["output"].answer == "yes" + assert output_event["output"].confidence == 0.95 + + +def test_timeout_configuration() -> None: + """Test timeout configuration.""" + # Test that timeout configuration is accepted without error + model = LlamaCppModel(timeout=30.0) + assert model.client.timeout is not None + + # Test with tuple timeout + model2 = LlamaCppModel(timeout=(10.0, 60.0)) + assert model2.client.timeout is not None + + +def test_max_retries_configuration() -> None: + """Test max retries configuration is handled gracefully.""" + # Since httpx doesn't use max_retries in the same way, + # we just test that the model initializes without error + model = LlamaCppModel() + assert model.config["model_id"] == "default" + + +def test_grammar_constraint_via_params() -> None: + """Test grammar constraint via params.""" + grammar = """ + root ::= answer + answer ::= "yes" | "no" + """ + model = LlamaCppModel(params={"grammar": grammar}) + + assert model.config["params"]["grammar"] == grammar + + # Update grammar via update_config + new_grammar = "root ::= [0-9]+" + model.update_config(params={"grammar": new_grammar}) + + assert model.config["params"]["grammar"] == new_grammar + + +def test_json_schema_via_params() -> None: + """Test JSON schema constraint via params.""" + schema = { + "type": "object", + "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}, + "required": ["name", "age"], + } + model = LlamaCppModel(params={"json_schema": schema}) + + assert model.config["params"]["json_schema"] == schema + + +@pytest.mark.asyncio +async def test_stream_with_context_overflow_error() -> None: + """Test stream handling of context overflow errors.""" + model = LlamaCppModel() + + # Create HTTP error response + error_response = httpx.Response( + status_code=400, + json={"error": {"message": "Context window exceeded. 
Max context length is 4096 tokens"}}, + request=httpx.Request("POST", "http://test.com"), + ) + error = httpx.HTTPStatusError("Bad Request", request=error_response.request, response=error_response) + + # Mock the client to raise the error + with patch.object(model.client, "post", side_effect=error): + messages = [{"role": "user", "content": [{"text": "Very long message"}]}] + + with pytest.raises(ContextWindowOverflowException) as exc_info: + async for _ in model.stream(messages): + pass + + assert "Context window exceeded" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_stream_with_server_overload_error() -> None: + """Test stream handling of server overload errors.""" + model = LlamaCppModel() + + # Create HTTP error response for 503 + error_response = httpx.Response( + status_code=503, + text="Server is busy", + request=httpx.Request("POST", "http://test.com"), + ) + error = httpx.HTTPStatusError( + "Service Unavailable", + request=error_response.request, + response=error_response, + ) + + # Mock the client to raise the error + with patch.object(model.client, "post", side_effect=error): + messages = [{"role": "user", "content": [{"text": "Test"}]}] + + with pytest.raises(ModelThrottledException) as exc_info: + async for _ in model.stream(messages): + pass + + assert "server is busy or overloaded" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_structured_output_with_json_schema() -> None: + """Test structured output using JSON schema.""" + + class TestOutput(BaseModel): + """Test output model for JSON schema testing.""" + + answer: str + confidence: float + + model = LlamaCppModel() + + # Mock successful JSON response + mock_response_text = '{"answer": "yes", "confidence": 0.95}' + + # Create mock stream that returns JSON + async def mock_stream(*_args, **_kwargs): + # Check that json_schema was set correctly + assert model.config["params"]["json_schema"] == TestOutput.model_json_schema() + + yield {"messageStart": {"role": "assistant"}} + yield {"contentBlockStart": {"start": {}}} + yield {"contentBlockDelta": {"delta": {"text": mock_response_text}}} + yield {"contentBlockStop": {}} + yield {"messageStop": {"stopReason": "end_turn"}} + + with patch.object(model, "stream", side_effect=mock_stream): + messages = [{"role": "user", "content": [{"text": "Is the earth round?"}]}] + + events = [] + async for event in model.structured_output(TestOutput, messages): + events.append(event) + + # Check we got the output + output_event = next((e for e in events if "output" in e), None) + assert output_event is not None + assert output_event["output"].answer == "yes" + assert output_event["output"].confidence == 0.95 + + +@pytest.mark.asyncio +async def test_structured_output_invalid_json_error() -> None: + """Test structured output raises error for invalid JSON.""" + + class TestOutput(BaseModel): + """Test output model for invalid JSON testing.""" + + value: int + + model = LlamaCppModel() + + # Mock stream that returns invalid JSON + async def mock_stream(*_args, **_kwargs): + # Check that json_schema was set correctly + assert model.config["params"]["json_schema"] == TestOutput.model_json_schema() + + yield {"messageStart": {"role": "assistant"}} + yield {"contentBlockStart": {"start": {}}} + yield {"contentBlockDelta": {"delta": {"text": "This is not valid JSON"}}} + yield {"contentBlockStop": {}} + yield {"messageStop": {"stopReason": "end_turn"}} + + with patch.object(model, "stream", side_effect=mock_stream): + messages = [{"role": "user", "content": 
[{"text": "Give me a number"}]}] + + with pytest.raises(json.JSONDecodeError): + async for _ in model.structured_output(TestOutput, messages): + pass + + +def test_format_audio_content() -> None: + """Test formatting of audio content for llama.cpp multimodal models.""" + model = LlamaCppModel() + + # Create test audio data + audio_bytes = b"fake audio data" + audio_content = {"audio": {"source": {"bytes": audio_bytes}, "format": "wav"}} + + # Format the content + result = model._format_message_content(audio_content) + + # Verify the structure + assert result["type"] == "input_audio" + assert "input_audio" in result + assert "data" in result["input_audio"] + assert "format" in result["input_audio"] + + # Verify the data is base64 encoded + decoded = base64.b64decode(result["input_audio"]["data"]) + assert decoded == audio_bytes + + # Verify format is preserved + assert result["input_audio"]["format"] == "wav" + + +def test_format_audio_content_default_format() -> None: + """Test audio content formatting uses wav as default format.""" + model = LlamaCppModel() + + audio_content = { + "audio": {"source": {"bytes": b"test audio"}} + # No format specified + } + + result = model._format_message_content(audio_content) + + # Should default to wav + assert result["input_audio"]["format"] == "wav" + + +def test_format_messages_with_audio() -> None: + """Test that _format_messages properly handles audio content.""" + model = LlamaCppModel() + + # Create messages with audio content + messages = [ + { + "role": "user", + "content": [ + {"text": "Listen to this audio:"}, + {"audio": {"source": {"bytes": b"audio data"}, "format": "mp3"}}, + ], + } + ] + + # Format the messages + result = model._format_messages(messages) + + # Check structure + assert len(result) == 1 + assert result[0]["role"] == "user" + assert len(result[0]["content"]) == 2 + + # Check text content + assert result[0]["content"][0]["type"] == "text" + assert result[0]["content"][0]["text"] == "Listen to this audio:" + + # Check audio content + assert result[0]["content"][1]["type"] == "input_audio" + assert "input_audio" in result[0]["content"][1] + assert result[0]["content"][1]["input_audio"]["format"] == "mp3" + + +def test_format_messages_with_system_prompt() -> None: + """Test _format_messages includes system prompt.""" + model = LlamaCppModel() + + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + system_prompt = "You are a helpful assistant" + + result = model._format_messages(messages, system_prompt) + + # Should have system message first + assert len(result) == 2 + assert result[0]["role"] == "system" + assert result[0]["content"] == system_prompt + assert result[1]["role"] == "user" + + +def test_format_messages_with_image() -> None: + """Test that _format_messages properly handles image content.""" + model = LlamaCppModel() + + # Create messages with image content + messages = [ + { + "role": "user", + "content": [ + {"text": "Describe this image:"}, + {"image": {"source": {"bytes": b"image data"}, "format": "png"}}, + ], + } + ] + + # Format the messages + result = model._format_messages(messages) + + # Check structure + assert len(result) == 1 + assert result[0]["role"] == "user" + assert len(result[0]["content"]) == 2 + + # Check text content + assert result[0]["content"][0]["type"] == "text" + assert result[0]["content"][0]["text"] == "Describe this image:" + + # Check image content uses standard format + assert result[0]["content"][1]["type"] == "image_url" + assert "image_url" in result[0]["content"][1] + 
assert "url" in result[0]["content"][1]["image_url"] + assert result[0]["content"][1]["image_url"]["url"].startswith("data:image/png;base64,") + + +def test_format_messages_with_mixed_content() -> None: + """Test that _format_messages handles mixed audio and image content correctly.""" + model = LlamaCppModel() + + # Create messages with both audio and image content + messages = [ + { + "role": "user", + "content": [ + {"text": "Analyze this media:"}, + {"audio": {"source": {"bytes": b"audio data"}, "format": "wav"}}, + {"image": {"source": {"bytes": b"image data"}, "format": "jpg"}}, + ], + } + ] + + # Format the messages + result = model._format_messages(messages) + + # Check structure + assert len(result) == 1 + assert result[0]["role"] == "user" + assert len(result[0]["content"]) == 3 + + # Check text content + assert result[0]["content"][0]["type"] == "text" + assert result[0]["content"][0]["text"] == "Analyze this media:" + + # Check audio content uses llama.cpp specific format + assert result[0]["content"][1]["type"] == "input_audio" + assert "input_audio" in result[0]["content"][1] + assert result[0]["content"][1]["input_audio"]["format"] == "wav" + + # Check image content uses standard OpenAI format + assert result[0]["content"][2]["type"] == "image_url" + assert "image_url" in result[0]["content"][2] + assert result[0]["content"][2]["image_url"]["url"].startswith("data:image/jpeg;base64,") diff --git a/tests_integ/models/test_model_llamacpp.py b/tests_integ/models/test_model_llamacpp.py new file mode 100644 index 000000000..95047e7ab --- /dev/null +++ b/tests_integ/models/test_model_llamacpp.py @@ -0,0 +1,510 @@ +"""Integration tests for llama.cpp model provider. + +These tests require a running llama.cpp server instance. +To run these tests: +1. Start llama.cpp server: llama-server -m model.gguf --host 0.0.0.0 --port 8080 +2. Run: pytest tests_integ/models/test_model_llamacpp.py + +Set LLAMACPP_TEST_URL environment variable to use a different server URL. +""" + +import os + +import pytest +from pydantic import BaseModel + +from strands.models.llamacpp import LlamaCppModel +from strands.types.content import Message + +# Get server URL from environment or use default +LLAMACPP_URL = os.environ.get("LLAMACPP_TEST_URL", "http://localhost:8080/v1") + +# Skip these tests if LLAMACPP_SKIP_TESTS is set +pytestmark = pytest.mark.skipif( + os.environ.get("LLAMACPP_SKIP_TESTS", "true").lower() == "true", + reason="llama.cpp integration tests disabled (set LLAMACPP_SKIP_TESTS=false to enable)", +) + + +class WeatherOutput(BaseModel): + """Test output model for structured responses.""" + + temperature: float + condition: str + location: str + + +@pytest.fixture +async def llamacpp_model() -> LlamaCppModel: + """Fixture to create a llama.cpp model instance.""" + return LlamaCppModel(base_url=LLAMACPP_URL) + + +# Integration tests for LlamaCppModel with a real server + + +@pytest.mark.asyncio +async def test_basic_completion(llamacpp_model: LlamaCppModel) -> None: + """Test basic text completion.""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Say 'Hello, World!' and nothing else."}]}, + ] + + response_text = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response_text += delta["text"] + + assert "Hello, World!" 
in response_text + + +@pytest.mark.asyncio +async def test_system_prompt(llamacpp_model: LlamaCppModel) -> None: + """Test completion with system prompt.""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Who are you?"}]}, + ] + + system_prompt = "You are a helpful AI assistant named Claude." + + response_text = "" + async for event in llamacpp_model.stream(messages, system_prompt=system_prompt): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response_text += delta["text"] + + # Response should reflect the system prompt + assert len(response_text) > 0 + assert "assistant" in response_text.lower() or "claude" in response_text.lower() + + +@pytest.mark.asyncio +async def test_streaming_chunks(llamacpp_model: LlamaCppModel) -> None: + """Test that streaming returns proper chunk sequence.""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Count from 1 to 3."}]}, + ] + + chunk_types = [] + async for event in llamacpp_model.stream(messages): + chunk_types.append(next(iter(event.keys()))) + + # Verify proper chunk sequence + assert chunk_types[0] == "messageStart" + assert chunk_types[1] == "contentBlockStart" + assert "contentBlockDelta" in chunk_types + assert chunk_types[-3] == "contentBlockStop" + assert chunk_types[-2] == "messageStop" + assert chunk_types[-1] == "metadata" + + +@pytest.mark.asyncio +async def test_temperature_parameter(llamacpp_model: LlamaCppModel) -> None: + """Test temperature parameter affects randomness.""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Generate a random word."}]}, + ] + + # Low temperature should give more consistent results + llamacpp_model.update_config(params={"temperature": 0.1, "seed": 42}) + + response1 = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response1 += delta["text"] + + # Same seed and low temperature should give similar result + response2 = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response2 += delta["text"] + + # With low temperature and same seed, responses should be very similar + assert len(response1) > 0 + assert len(response2) > 0 + + +@pytest.mark.asyncio +async def test_max_tokens_limit(llamacpp_model: LlamaCppModel) -> None: + """Test max_tokens parameter limits response length.""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Tell me a very long story about dragons."}]}, + ] + + # Set very low token limit + llamacpp_model.update_config(params={"max_tokens": 10}) + + token_count = 0 + async for event in llamacpp_model.stream(messages): + if "metadata" in event: + usage = event["metadata"]["usage"] + token_count = usage["outputTokens"] + if "messageStop" in event: + stop_reason = event["messageStop"]["stopReason"] + + # Should stop due to max_tokens + assert token_count <= 15 # Allow small overage due to tokenization + assert stop_reason == "max_tokens" + + +@pytest.mark.asyncio +async def test_structured_output(llamacpp_model: LlamaCppModel) -> None: + """Test structured output generation.""" + messages: list[Message] = [ + { + "role": "user", + "content": [ + { + "text": "What's the weather like in Paris? " + "Respond with temperature in Celsius, condition, and location." 
+ } + ], + }, + ] + + # Enable JSON response format for structured output + llamacpp_model.update_config(params={"response_format": {"type": "json_object"}}) + + result = None + async for event in llamacpp_model.structured_output(WeatherOutput, messages): + if "output" in event: + result = event["output"] + + assert result is not None + assert isinstance(result, WeatherOutput) + assert isinstance(result.temperature, float) + assert isinstance(result.condition, str) + assert result.location.lower() == "paris" + + +@pytest.mark.asyncio +async def test_llamacpp_specific_params(llamacpp_model: LlamaCppModel) -> None: + """Test llama.cpp specific parameters.""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Say 'test' five times."}]}, + ] + + # Use llama.cpp specific parameters + llamacpp_model.update_config( + params={ + "repeat_penalty": 1.5, # Penalize repetition + "top_k": 10, # Limit vocabulary + "min_p": 0.1, # Min-p sampling + } + ) + + response_text = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response_text += delta["text"] + + # Response should contain "test" but with repetition penalty it might vary + assert "test" in response_text.lower() + + +@pytest.mark.asyncio +async def test_advanced_sampling_params(llamacpp_model: LlamaCppModel) -> None: + """Test advanced sampling parameters.""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Generate a random sentence about space."}]}, + ] + + # Test advanced sampling parameters + llamacpp_model.update_config( + params={ + "temperature": 0.8, + "tfs_z": 0.95, # Tail-free sampling + "top_a": 0.1, # Top-a sampling + "typical_p": 0.9, # Typical-p sampling + "penalty_last_n": 64, # Penalty context window + "min_keep": 1, # Minimum tokens to keep + "samplers": ["top_k", "tfs_z", "typical_p", "top_p", "min_p", "temperature"], + } + ) + + response_text = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response_text += delta["text"] + + # Should generate something about space + assert len(response_text) > 0 + assert any(word in response_text.lower() for word in ["space", "star", "planet", "galaxy", "universe"]) + + +@pytest.mark.asyncio +async def test_mirostat_sampling(llamacpp_model: LlamaCppModel) -> None: + """Test Mirostat sampling modes.""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Write a short poem."}]}, + ] + + # Test Mirostat v2 + llamacpp_model.update_config( + params={ + "mirostat": 2, + "mirostat_lr": 0.1, + "mirostat_ent": 5.0, + "seed": 42, # For reproducibility + } + ) + + response_text = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response_text += delta["text"] + + # Should generate a poem + assert len(response_text) > 20 + assert "\n" in response_text # Poems typically have line breaks + + +@pytest.mark.asyncio +async def test_grammar_constraint(llamacpp_model: LlamaCppModel) -> None: + """Test grammar constraint feature (llama.cpp specific).""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Is the sky blue? 
Answer yes or no."}]}, + ] + + # Set grammar constraint via params + grammar = """ + root ::= answer + answer ::= "yes" | "no" + """ + llamacpp_model.update_config(params={"grammar": grammar}) + + response_text = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response_text += delta["text"] + + # Response should be exactly "yes" or "no" + assert response_text.strip().lower() in ["yes", "no"] + + +@pytest.mark.asyncio +async def test_json_schema_constraint(llamacpp_model: LlamaCppModel) -> None: + """Test JSON schema constraint feature.""" + messages: list[Message] = [ + { + "role": "user", + "content": [{"text": "Describe the weather in JSON format with temperature and description."}], + }, + ] + + # Set JSON schema constraint via params + schema = { + "type": "object", + "properties": {"temperature": {"type": "number"}, "description": {"type": "string"}}, + "required": ["temperature", "description"], + } + llamacpp_model.update_config(params={"json_schema": schema}) + + response_text = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response_text += delta["text"] + + # Should be valid JSON matching the schema + import json + + data = json.loads(response_text.strip()) + assert "temperature" in data + assert "description" in data + assert isinstance(data["temperature"], (int, float)) + assert isinstance(data["description"], str) + + +@pytest.mark.asyncio +async def test_logit_bias(llamacpp_model: LlamaCppModel) -> None: + """Test logit bias feature.""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Choose between 'cat' and 'dog'."}]}, + ] + + # This is a simplified test - in reality you'd need to know the actual token IDs + # for "cat" and "dog" in the model's vocabulary + llamacpp_model.update_config( + params={ + "logit_bias": { + # These are placeholder token IDs - real implementation would need actual token IDs + 1234: 10.0, # Strong positive bias (hypothetical "cat" token) + 5678: -10.0, # Strong negative bias (hypothetical "dog" token) + }, + "seed": 42, # For reproducibility + } + ) + + response_text = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response_text += delta["text"] + + # Should generate text (exact behavior depends on actual token IDs) + assert len(response_text) > 0 + + +@pytest.mark.asyncio +async def test_cache_prompt(llamacpp_model: LlamaCppModel) -> None: + """Test prompt caching feature.""" + messages: list[Message] = [ + {"role": "system", "content": [{"text": "You are a helpful assistant. Always be concise."}]}, + {"role": "user", "content": [{"text": "What is 2+2?"}]}, + ] + + # Enable prompt caching + llamacpp_model.update_config( + params={ + "cache_prompt": True, + "slot_id": 0, # Use specific slot for caching + } + ) + + # First request + response1 = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response1 += delta["text"] + + # Second request with same system prompt should use cache + messages2 = [ + {"role": "system", "content": [{"text": "You are a helpful assistant. 
Always be concise."}]}, + {"role": "user", "content": [{"text": "What is 3+3?"}]}, + ] + + response2 = "" + async for event in llamacpp_model.stream(messages2): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response2 += delta["text"] + + # Both should give valid responses + assert "4" in response1 + assert "6" in response2 + + +@pytest.mark.asyncio +async def test_concurrent_requests(llamacpp_model: LlamaCppModel) -> None: + """Test handling multiple concurrent requests.""" + import asyncio + + async def make_request(prompt: str) -> str: + messages: list[Message] = [ + {"role": "user", "content": [{"text": prompt}]}, + ] + + response = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response += delta["text"] + return response + + # Make concurrent requests + prompts = [ + "Say 'one'", + "Say 'two'", + "Say 'three'", + ] + + responses = await asyncio.gather(*[make_request(p) for p in prompts]) + + # Each response should contain the expected number + assert "one" in responses[0].lower() + assert "two" in responses[1].lower() + assert "three" in responses[2].lower() + + +@pytest.mark.asyncio +async def test_enhanced_structured_output(llamacpp_model: LlamaCppModel) -> None: + """Test enhanced structured output with native JSON schema support.""" + + class BookInfo(BaseModel): + title: str + author: str + year: int + genres: list[str] + + messages: list[Message] = [ + { + "role": "user", + "content": [ + { + "text": "Create information about a fictional science fiction book. " + "Include title, author, publication year, and 2-3 genres." + } + ], + }, + ] + + result = None + events = [] + async for event in llamacpp_model.structured_output(BookInfo, messages): + events.append(event) + if "output" in event: + result = event["output"] + + # Verify we got structured output + assert result is not None + assert isinstance(result, BookInfo) + assert isinstance(result.title, str) and len(result.title) > 0 + assert isinstance(result.author, str) and len(result.author) > 0 + assert isinstance(result.year, int) and 1900 <= result.year <= 2100 + assert isinstance(result.genres, list) and len(result.genres) >= 2 + assert all(isinstance(genre, str) for genre in result.genres) + + # Should have streamed events before the output + assert len(events) > 1 + + +@pytest.mark.asyncio +async def test_context_overflow_handling(llamacpp_model: LlamaCppModel) -> None: + """Test proper handling of context window overflow.""" + # Create a very long message that might exceed context + long_text = "This is a test sentence. " * 1000 + messages: list[Message] = [ + {"role": "user", "content": [{"text": f"Summarize this text: {long_text}"}]}, + ] + + try: + response_text = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response_text += delta["text"] + + # If it succeeds, we got a response + assert len(response_text) > 0 + except Exception as e: + # If it fails, it should be our custom error + from strands.types.exceptions import ContextWindowOverflowException + + if isinstance(e, ContextWindowOverflowException): + assert "context" in str(e).lower() + else: + # Some other error - re-raise to see what it was + raise
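
For reference, a minimal end-to-end sketch of the async API this PR adds. It is not part of the diff: the `Answer` schema, the prompts, and the parameter values are illustrative assumptions, and a llama.cpp server is assumed to be running locally (e.g. `llama-server -m model.gguf --host 0.0.0.0 --port 8080`). The `LlamaCppModel`, `stream`, and `structured_output` calls mirror the code introduced above.

# Hedged usage sketch for the new provider; names marked below are assumptions.
import asyncio

from pydantic import BaseModel

from strands.models.llamacpp import LlamaCppModel


class Answer(BaseModel):
    """Hypothetical output schema used only for this example."""

    verdict: str
    confidence: float


async def main() -> None:
    # Mix of OpenAI-compatible and llama.cpp-specific params, as documented
    # in LlamaCppConfig; llama.cpp-specific keys are routed to extra_body.
    model = LlamaCppModel(
        base_url="http://localhost:8080",
        params={"temperature": 0.2, "top_k": 40, "cache_prompt": True},
    )

    messages = [{"role": "user", "content": [{"text": "Is the sky blue?"}]}]

    # Plain streaming: collect text deltas as they arrive.
    text = ""
    async for event in model.stream(messages, system_prompt="Answer briefly."):
        delta = event.get("contentBlockDelta", {}).get("delta", {})
        if "text" in delta:
            text += delta["text"]
    print(text)

    # Structured output: the provider sets llama.cpp's json_schema constraint
    # from the Pydantic model and yields the parsed instance last.
    async for event in model.structured_output(Answer, messages):
        if "output" in event:
            print(event["output"])


if __name__ == "__main__":
    asyncio.run(main())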