diff --git a/README.md b/README.md
index 08d6bff03..750354c26 100644
--- a/README.md
+++ b/README.md
@@ -107,6 +107,7 @@ from strands import Agent
 from strands.models import BedrockModel
 from strands.models.ollama import OllamaModel
 from strands.models.llamaapi import LlamaAPIModel
+from strands.models.vllm import VLLMModel
 
 # Bedrock
 bedrock_model = BedrockModel(
@@ -130,6 +131,14 @@ llama_model = LlamaAPIModel(
 )
 agent = Agent(model=llama_model)
 response = agent("Tell me about Agentic AI")
+
+# vLLM
+vllm_model = VLLMModel(
+    host="http://localhost:8000",
+    model_id="Qwen/Qwen3-4B"
+)
+agent_vllm = Agent(model=vllm_model)
+agent_vllm("Tell me about Agentic AI")
 ```
 
 Built-in providers:
diff --git a/pyproject.toml b/pyproject.toml
index 6582bdddc..b3ece88c3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -75,9 +75,12 @@ ollama = [
 llamaapi = [
     "llama-api-client>=0.1.0,<1.0.0",
 ]
+vllm = [
+    "vllm>=0.8.5",
+]
 
 [tool.hatch.envs.hatch-static-analysis]
-features = ["anthropic", "litellm", "llamaapi", "ollama"]
+features = ["anthropic", "litellm", "llamaapi", "ollama", "vllm"]
 dependencies = [
     "mypy>=1.15.0,<2.0.0",
     "ruff>=0.11.6,<0.12.0",
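Before the implementation below, it may help to see the provider's full configuration surface in one place. This is an illustrative sketch, not part of the diff: every key comes from the `VLLMConfig` TypedDict defined in `src/strands/models/vllm.py`, the values are placeholders, and `additional_args` is merged verbatim into the request payload, so whether the server honours an extra key is up to vLLM.

```python
from strands.models.vllm import VLLMModel

vllm_model = VLLMModel(
    host="http://localhost:8000",
    model_id="Qwen/Qwen3-4B",
    temperature=0.7,
    top_p=0.9,
    max_tokens=1024,
    stop_sequences=["</answer>"],
    additional_args={"presence_penalty": 0.2},  # forwarded as-is to /v1/chat/completions
)
vllm_model.update_config(temperature=0.2)  # overrides apply to subsequent requests
```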
diff --git a/src/strands/models/vllm.py b/src/strands/models/vllm.py
new file mode 100644
index 000000000..64ce042ae
--- /dev/null
+++ b/src/strands/models/vllm.py
@@ -0,0 +1,310 @@
+"""vLLM model provider.
+
+- Docs: https://docs.vllm.ai/en/latest/index.html
+"""
+import json
+import logging
+import re
+from collections import namedtuple
+from typing import Any, Iterable, Optional
+
+import requests
+from typing_extensions import TypedDict, Unpack, override
+
+from ..types.content import Messages
+from ..types.models import Model
+from ..types.streaming import StreamEvent
+from ..types.tools import ToolSpec
+
+logger = logging.getLogger(__name__)
+
+
+class VLLMModel(Model):
+    """vLLM model provider implementation for the OpenAI-compatible /v1/chat/completions endpoint."""
+
+    class VLLMConfig(TypedDict, total=False):
+        """Configuration options for vLLM models.
+
+        Attributes:
+            model_id: Model ID (e.g., "Qwen/Qwen3-4B").
+            temperature: Sampling temperature for generation.
+            top_p: Nucleus sampling parameter.
+            max_tokens: Maximum number of tokens to generate.
+            stop_sequences: Sequences at which generation stops.
+            additional_args: Extra request parameters merged verbatim into the payload.
+        """
+
+        model_id: str
+        temperature: Optional[float]
+        top_p: Optional[float]
+        max_tokens: Optional[int]
+        stop_sequences: Optional[list[str]]
+        additional_args: Optional[dict[str, Any]]
+
+    def __init__(self, host: str, **model_config: Unpack[VLLMConfig]) -> None:
+        """Initialize provider instance.
+
+        Args:
+            host: Host and port of the vLLM inference server (e.g., "http://localhost:8000").
+            **model_config: Configuration options for the vLLM model.
+        """
+        self.config = VLLMModel.VLLMConfig(**model_config)
+        self.host = host.rstrip("/")
+        logger.debug("Initializing vLLM provider with config: %s", self.config)
+
+    @override
+    def update_config(self, **model_config: Unpack[VLLMConfig]) -> None:
+        """Update the vLLM model configuration with the provided arguments.
+
+        Args:
+            **model_config: Configuration overrides.
+        """
+        self.config.update(model_config)
+
+    @override
+    def get_config(self) -> VLLMConfig:
+        """Get the vLLM model configuration.
+
+        Returns:
+            The vLLM model configuration.
+        """
+        return self.config
+
+    @override
+    def format_request(
+        self,
+        messages: Messages,
+        tool_specs: Optional[list[ToolSpec]] = None,
+        system_prompt: Optional[str] = None,
+    ) -> dict[str, Any]:
+        """Format a vLLM chat streaming request.
+
+        Args:
+            messages: List of message objects to be processed by the model.
+            tool_specs: List of tool specifications to make available to the model.
+            system_prompt: System prompt to provide context to the model.
+
+        Returns:
+            A vLLM chat streaming request.
+        """
+
+        def format_message(msg: dict[str, Any], content: dict[str, Any]) -> dict[str, Any]:
+            # Map strands content blocks onto OpenAI-style chat messages.
+            if "text" in content:
+                return {"role": msg["role"], "content": content["text"]}
+            if "toolUse" in content:
+                return {
+                    "role": "assistant",
+                    "tool_calls": [
+                        {
+                            "id": content["toolUse"]["toolUseId"],
+                            "type": "function",
+                            "function": {
+                                "name": content["toolUse"]["name"],
+                                "arguments": json.dumps(content["toolUse"]["input"]),
+                            },
+                        }
+                    ],
+                }
+            if "toolResult" in content:
+                return {
+                    "role": "tool",
+                    "tool_call_id": content["toolResult"]["toolUseId"],
+                    "content": json.dumps(content["toolResult"]["content"]),
+                }
+            return {"role": msg["role"], "content": json.dumps(content)}
+
+        chat_messages = []
+        if system_prompt:
+            chat_messages.append({"role": "system", "content": system_prompt})
+        for msg in messages:
+            for content in msg["content"]:
+                chat_messages.append(format_message(msg, content))
+
+        payload = {
+            "model": self.config["model_id"],
+            "messages": chat_messages,
+            "temperature": self.config.get("temperature", 0.7),
+            "top_p": self.config.get("top_p", 1.0),
+            "max_tokens": self.config.get("max_tokens", 2048),
+            "stream": True,
+        }
+
+        if self.config.get("stop_sequences"):
+            payload["stop"] = self.config["stop_sequences"]
+
+        if tool_specs:
+            payload["tools"] = [
+                {
+                    "type": "function",
+                    "function": {
+                        "name": tool["name"],
+                        "description": tool["description"],
+                        "parameters": tool["inputSchema"]["json"],
+                    },
+                }
+                for tool in tool_specs
+            ]
+
+        if self.config.get("additional_args"):
+            payload.update(self.config["additional_args"])
+
+        logger.debug("Formatted vLLM request:\n%s", json.dumps(payload, indent=2))
+        return payload
+
+ """ + from collections import namedtuple + + Function = namedtuple("Function", ["name", "arguments"]) + + if event.get("chunk_type") == "message_start": + return {"messageStart": {"role": "assistant"}} + + if event.get("chunk_type") == "content_start": + if event["data_type"] == "text": + return {"contentBlockStart": {"start": {}}} + + tool: Function = event["data"] + return { + "contentBlockStart": { + "start": { + "toolUse": { + "name": tool.name, + "toolUseId": tool.name, + } + } + } + } + + if event.get("chunk_type") == "content_delta": + if event["data_type"] == "text": + return {"contentBlockDelta": {"delta": {"text": event["data"]}}} + + tool: Function = event["data"] + return { + "contentBlockDelta": { + "delta": { + "toolUse": { + "input": json.dumps(tool.arguments) # This is already a dict + } + } + } + } + + if event.get("chunk_type") == "content_stop": + return {"contentBlockStop": {}} + + if event.get("chunk_type") == "message_stop": + reason = event["data"] + if reason == "tool_use": + return {"messageStop": {"stopReason": "tool_use"}} + elif reason == "length": + return {"messageStop": {"stopReason": "max_tokens"}} + else: + return {"messageStop": {"stopReason": "end_turn"}} + + if event.get("chunk_type") == "metadata": + usage = event.get("data", {}) + return { + "metadata": { + "usage": { + "inputTokens": usage.get("prompt_eval_count", 0), + "outputTokens": usage.get("eval_count", 0), + "totalTokens": usage.get("prompt_eval_count", 0) + usage.get("eval_count", 0), + }, + "metrics": { + "latencyMs": usage.get("total_duration", 0) / 1e6, + }, + } + } + + raise RuntimeError(f"chunk_type=<{event.get('chunk_type')}> | unknown type") + + @override + def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: + """Send the request to the vLLM model and get the streaming response. + + Args: + request: The formatted request to send to the vLLM model. + + Returns: + An iterable of response events from the vLLM model. 
+ """ + + Function = namedtuple("Function", ["name", "arguments"]) + + headers = {"Content-Type": "application/json"} + url = f"{self.host}/v1/chat/completions" + + accumulated_content = [] + tool_requested = False + + try: + with requests.post(url, headers=headers, data=json.dumps(request), stream=True) as response: + if response.status_code != 200: + logger.error("vLLM server error: %d - %s", response.status_code, response.text) + raise Exception(f"Request failed: {response.status_code} - {response.text}") + + yield {"chunk_type": "message_start"} + yield {"chunk_type": "content_start", "data_type": "text"} + + for line in response.iter_lines(decode_unicode=True): + if not line or not line.startswith("data: "): + continue + line = line[len("data: ") :].strip() + + if line == "[DONE]": + break + + try: + event = json.loads(line) + choices = event.get("choices", []) + if choices: + delta = choices[0].get("delta", {}) + content = delta.get("content") + if content: + accumulated_content.append(content) + + yield {"chunk_type": "content_delta", "data_type": "text", "data": content or ""} + + except json.JSONDecodeError: + logger.warning("Failed to parse line: %s", line) + continue + + yield {"chunk_type": "content_stop", "data_type": "text"} + + full_content = "".join(accumulated_content) + + tool_call_blocks = re.findall(r"(.*?)", full_content, re.DOTALL) + for idx, block in enumerate(tool_call_blocks): + try: + tool_call_data = json.loads(block.strip()) + func = Function(name=tool_call_data["name"], arguments=tool_call_data.get("arguments", {})) + func_str = f"function=Function(name='{func.name}', arguments={func.arguments})" + + yield {"chunk_type": "content_start", "data_type": "tool", "data": func} + yield {"chunk_type": "content_delta", "data_type": "tool", "data": func} + yield {"chunk_type": "content_stop", "data_type": "tool", "data": func} + tool_requested = True + + except json.JSONDecodeError: + logger.warning(f"Failed to parse tool_call block #{idx}: {block}") + continue + + yield {"chunk_type": "message_stop", "data": "tool_use" if tool_requested else "end_turn"} + + except requests.RequestException as e: + logger.error("Streaming request failed: %s", str(e)) + raise Exception("Failed to reach vLLM server") from e diff --git a/tests-integ/test_model_vllm.py b/tests-integ/test_model_vllm.py new file mode 100644 index 000000000..49df3902d --- /dev/null +++ b/tests-integ/test_model_vllm.py @@ -0,0 +1,45 @@ +import pytest +import strands +from strands import Agent +from strands.models.vllm import VLLMModel + + +@pytest.fixture +def model(): + return VLLMModel( + model_id="Qwen/Qwen3-4B", + host="http://localhost:8000", + max_tokens=128, + ) + + +@pytest.fixture +def tools(): + @strands.tool + def tool_time() -> str: + return "12:00" + + @strands.tool + def tool_weather() -> str: + return "cloudy" + + return [tool_time, tool_weather] + + +@pytest.fixture +def agent(model, tools): + return Agent(model=model, tools=tools) + + +def test_agent(agent): + # Send prompt + result = agent("What is the time and weather in Melbourne Australia?") + + # Extract plain text from the first content block + text_blocks = result.message.get("content", []) + # content is a list of dicts with 'text' keys + text = " ".join(block.get("text", "") for block in text_blocks).lower() + + # Assert that the tool outputs appear in the generated response text + assert "tool_weather" in text + #assert "cloudy" in text diff --git a/tests/strands/models/test_vllm.py b/tests/strands/models/test_vllm.py new file mode 
diff --git a/tests/strands/models/test_vllm.py b/tests/strands/models/test_vllm.py
new file mode 100644
index 000000000..f21741e4b
--- /dev/null
+++ b/tests/strands/models/test_vllm.py
@@ -0,0 +1,153 @@
+import json
+from types import SimpleNamespace
+
+import pytest
+import requests
+
+from strands.models.vllm import VLLMModel
+
+
+@pytest.fixture
+def model_id():
+    return "meta-llama/Llama-3.2-3B"
+
+
+@pytest.fixture
+def host():
+    return "http://localhost:8000"
+
+
+@pytest.fixture
+def model(model_id, host):
+    return VLLMModel(host, model_id=model_id, max_tokens=128)
+
+
+@pytest.fixture
+def messages():
+    return [{"role": "user", "content": [{"text": "Hello"}]}]
+
+
+@pytest.fixture
+def system_prompt():
+    return "You are a helpful assistant."
+
+
+def test_init_sets_config(model, model_id):
+    assert model.get_config()["model_id"] == model_id
+    assert model.host == "http://localhost:8000"
+
+
+def test_update_config_overrides(model):
+    model.update_config(temperature=0.3)
+    assert model.get_config()["temperature"] == 0.3
+
+
+def test_format_request_basic(model, messages):
+    request = model.format_request(messages)
+    assert request["model"] == model.get_config()["model_id"]
+    assert isinstance(request["messages"], list)
+    assert request["messages"][0]["role"] == "user"
+    assert request["messages"][0]["content"] == "Hello"
+    assert request["stream"] is True
+
+
+def test_format_request_with_system_prompt(model, messages, system_prompt):
+    request = model.format_request(messages, system_prompt=system_prompt)
+    assert request["messages"][0]["role"] == "system"
+    assert request["messages"][0]["content"] == system_prompt
+
+
+def test_format_chunk_text():
+    chunk = {"chunk_type": "content_delta", "data_type": "text", "data": "World"}
+    formatted = VLLMModel.format_chunk(None, chunk)
+    assert formatted == {"contentBlockDelta": {"delta": {"text": "World"}}}
+
+
+def test_format_chunk_tool_call_delta():
+    chunk = {
+        "chunk_type": "content_delta",
+        "data_type": "tool",
+        "data": SimpleNamespace(name="get_time", arguments={"timezone": "UTC"}),
+    }
+
+    formatted = VLLMModel.format_chunk(None, chunk)
+    assert "contentBlockDelta" in formatted
+    assert "toolUse" in formatted["contentBlockDelta"]["delta"]
+    assert json.loads(formatted["contentBlockDelta"]["delta"]["toolUse"]["input"])["timezone"] == "UTC"
+
+
+def test_stream_response(monkeypatch, model, messages):
+    mock_lines = [
+        'data: {"choices":[{"delta":{"content":"Hello"}}]}\n',
+        'data: {"choices":[{"delta":{"content":" world"}}]}\n',
+        "data: [DONE]\n",
+    ]
+
+    class MockResponse:
+        def __init__(self):
+            self.status_code = 200
+
+        def __enter__(self):
+            return self
+
+        def __exit__(self, *a):
+            pass
+
+        def iter_lines(self, decode_unicode=False):
+            return iter(mock_lines)
+
+    monkeypatch.setattr(requests, "post", lambda *a, **kw: MockResponse())
+
+    chunks = list(model.stream(model.format_request(messages)))
+    chunk_types = [c.get("chunk_type") for c in chunks]
+
+    assert "message_start" in chunk_types
+    assert chunk_types.count("content_delta") == 2
+    assert "content_stop" in chunk_types
+    assert "message_stop" in chunk_types
+
+
+def test_stream_tool_call(monkeypatch, model, messages):
+    tool_call = {
+        "name": "current_time",
+        "arguments": {"timezone": "UTC"},
+    }
+    tool_call_json = json.dumps(tool_call)
+    data_str = json.dumps(
+        {"choices": [{"delta": {"content": f"<tool_call>{tool_call_json}</tool_call>"}}]}
+    )
+    mock_lines = [
+        'data: {"choices":[{"delta":{"content":"Some answer before tool."}}]}\n',
+        f"data: {data_str}\n",
+        "data: [DONE]\n",
+    ]
+
+    class MockResponse:
+        def __init__(self):
+            self.status_code = 200
+
+        def __enter__(self):
+            return self
+
+        def __exit__(self, *a):
+            pass
+
+        def iter_lines(self, decode_unicode=False):
+            return iter(mock_lines)
+
+    monkeypatch.setattr(requests, "post", lambda *a, **kw: MockResponse())
+
+    chunks = list(model.stream(model.format_request(messages)))
+    tool_chunks = [c for c in chunks if c.get("chunk_type") == "content_start" and c.get("data_type") == "tool"]
+
+    assert tool_chunks
+    assert any(c.get("chunk_type") == "message_stop" and c.get("data") == "tool_use" for c in chunks)
+
+
+def test_stream_server_error(monkeypatch, model, messages):
+    class ErrorResponse:
+        def __init__(self):
+            self.status_code = 500
+            self.text = "Internal Error"
+
+        def __enter__(self):
+            return self
+
+        def __exit__(self, *a):
+            pass
+
+        def iter_lines(self, decode_unicode=False):
+            return iter([])
+
+    monkeypatch.setattr(requests, "post", lambda *a, **kw: ErrorResponse())
+
+    with pytest.raises(Exception, match="Request failed: 500"):
+        list(model.stream(model.format_request(messages)))
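Putting the pieces together, a hedged end-to-end sketch mirroring the integration test above. It assumes a vLLM OpenAI-compatible server is already running at `http://localhost:8000` (for example via `vllm serve Qwen/Qwen3-4B`) and that the served model's chat template emits `<tool_call>` blocks; the tool body here is a stub for illustration.

```python
import strands
from strands import Agent
from strands.models.vllm import VLLMModel


@strands.tool
def current_time() -> str:
    """Return the current time (stubbed for illustration)."""
    return "12:00"


# The agent drives format_request -> stream -> format_chunk internally.
model = VLLMModel(host="http://localhost:8000", model_id="Qwen/Qwen3-4B", max_tokens=256)
agent = Agent(model=model, tools=[current_time])
agent("What time is it?")
```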