diff --git a/README.md b/README.md
index 08d6bff03..750354c26 100644
--- a/README.md
+++ b/README.md
@@ -107,6 +107,7 @@ from strands import Agent
from strands.models import BedrockModel
from strands.models.ollama import OllamaModel
from strands.models.llamaapi import LlamaAPIModel
+from strands.models.vllm import VLLMModel
# Bedrock
bedrock_model = BedrockModel(
@@ -130,6 +131,14 @@ llama_model = LlamaAPIModel(
)
agent = Agent(model=llama_model)
response = agent("Tell me about Agentic AI")
+
+# vLLM
+vllm_model = VLLMModel(
+ host="http://localhost:8000",
+ model_id="Qwen/Qwen3-4B"
+)
+agent_vllm = Agent(model=vllm_model)
+agent_vllm("Tell me about Agentic AI")
```
Built-in providers:
diff --git a/pyproject.toml b/pyproject.toml
index 6582bdddc..b3ece88c3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -75,9 +75,12 @@ ollama = [
llamaapi = [
"llama-api-client>=0.1.0,<1.0.0",
]
+vllm = [
+ "vllm>=0.8.5",
+]
[tool.hatch.envs.hatch-static-analysis]
-features = ["anthropic", "litellm", "llamaapi", "ollama"]
+features = ["anthropic", "litellm", "llamaapi", "ollama","vllm"]
dependencies = [
"mypy>=1.15.0,<2.0.0",
"ruff>=0.11.6,<0.12.0",
diff --git a/src/strands/models/vllm.py b/src/strands/models/vllm.py
new file mode 100644
index 000000000..64ce042ae
--- /dev/null
+++ b/src/strands/models/vllm.py
@@ -0,0 +1,310 @@
+"""vLLM model provider.
+
+- Docs: https://docs.vllm.ai/en/latest/index.html
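+
+Example:
+    from strands import Agent
+    from strands.models.vllm import VLLMModel
+
+    model = VLLMModel(host="http://localhost:8000", model_id="Qwen/Qwen3-4B")
+    agent = Agent(model=model)
+    agent("Tell me about Agentic AI")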
+"""
+import json
+import logging
+import re
+from collections import namedtuple
+from typing import Any, Iterable, Optional
+
+import requests
+from typing_extensions import TypedDict, Unpack, override
+
+from ..types.content import Messages
+from ..types.models import Model
+from ..types.streaming import StreamEvent
+from ..types.tools import ToolSpec
+
+logger = logging.getLogger(__name__)
+
+
+class VLLMModel(Model):
+ """vLLM model provider implementation for OpenAI compatible /v1/chat/completions endpoint."""
+
+ class VLLMConfig(TypedDict, total=False):
+ """Configuration options for vLLM models.
+
+ Attributes:
+ model_id: Model ID (e.g., "Qwen/Qwen3-4B").
+            temperature: Sampling temperature.
+            top_p: Nucleus sampling parameter.
+            max_tokens: Maximum number of tokens to generate.
+            stop_sequences: Sequences at which generation stops.
+            additional_args: Additional request parameters passed through to the endpoint.
+ """
+
+ model_id: str
+ temperature: Optional[float]
+ top_p: Optional[float]
+ max_tokens: Optional[int]
+ stop_sequences: Optional[list[str]]
+ additional_args: Optional[dict[str, Any]]
+
+ def __init__(self, host: str, **model_config: Unpack[VLLMConfig]) -> None:
+ """Initialize provider instance.
+
+ Args:
+            host: Host and port of the vLLM inference server (e.g., "http://localhost:8000").
+            **model_config: Configuration options for the vLLM model.
+ """
+ self.config = VLLMModel.VLLMConfig(**model_config)
+ self.host = host.rstrip("/")
+ logger.debug("Initializing vLLM provider with config: %s", self.config)
+
+ @override
+ def update_config(self, **model_config: Unpack[VLLMConfig]) -> None:
+ """Update the vLLM model configuration with the provided arguments.
+
+ Args:
+ **model_config: Configuration overrides.
+ """
+ self.config.update(model_config)
+
+ @override
+ def get_config(self) -> VLLMConfig:
+ """Get the vLLM model configuration.
+
+ Returns:
+ The vLLM model configuration.
+ """
+ return self.config
+
+ @override
+ def format_request(
+ self,
+ messages: Messages,
+ tool_specs: Optional[list[ToolSpec]] = None,
+ system_prompt: Optional[str] = None,
+ ) -> dict[str, Any]:
+ """Format a vLLM chat streaming request.
+
+ Args:
+ messages: List of message objects to be processed by the model.
+ tool_specs: List of tool specifications to make available to the model.
+ system_prompt: System prompt to provide context to the model.
+
+ Returns:
+ A vLLM chat streaming request.
+ """
+
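+        # Map each strands content block (text / toolUse / toolResult) to an
+        # OpenAI-style chat message for the /v1/chat/completions endpoint.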
+ def format_message(msg: dict[str, Any], content: dict[str, Any]) -> dict[str, Any]:
+ if "text" in content:
+ return {"role": msg["role"], "content": content["text"]}
+ if "toolUse" in content:
+ return {
+ "role": "assistant",
+ "tool_calls": [
+ {
+ "id": content["toolUse"]["toolUseId"],
+ "type": "function",
+ "function": {
+ "name": content["toolUse"]["name"],
+ "arguments": json.dumps(content["toolUse"]["input"]),
+ },
+ }
+ ],
+ }
+ if "toolResult" in content:
+ return {
+ "role": "tool",
+ "tool_call_id": content["toolResult"]["toolUseId"],
+ "content": json.dumps(content["toolResult"]["content"]),
+ }
+ return {"role": msg["role"], "content": json.dumps(content)}
+
+ chat_messages = []
+ if system_prompt:
+ chat_messages.append({"role": "system", "content": system_prompt})
+ for msg in messages:
+ for content in msg["content"]:
+ chat_messages.append(format_message(msg, content))
+
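+        # Base OpenAI chat-completions payload; streaming is always enabled.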
+ payload = {
+ "model": self.config["model_id"],
+ "messages": chat_messages,
+ "temperature": self.config.get("temperature", 0.7),
+ "top_p": self.config.get("top_p", 1.0),
+ "max_tokens": self.config.get("max_tokens", 2048),
+ "stream": True,
+ }
+
+ if self.config.get("stop_sequences"):
+ payload["stop"] = self.config["stop_sequences"]
+
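+        # Advertise tools using the OpenAI function-calling schema.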
+ if tool_specs:
+ payload["tools"] = [
+ {
+ "type": "function",
+ "function": {
+ "name": tool["name"],
+ "description": tool["description"],
+ "parameters": tool["inputSchema"]["json"],
+ },
+ }
+ for tool in tool_specs
+ ]
+
+ if self.config.get("additional_args"):
+ payload.update(self.config["additional_args"])
+
+ logger.debug("Formatted vLLM Request:\n%s", json.dumps(payload, indent=2))
+ return payload
+
+ @override
+ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
+ """Format the vLLM response events into standardized message chunks.
+
+ Args:
+ event: A response event from the vLLM model.
+
+ Returns:
+ The formatted chunk.
+
+ Raises:
+ RuntimeError: If chunk_type is not recognized.
+ This error should never be encountered as we control chunk_type in the stream method.
+ """
+
+ Function = namedtuple("Function", ["name", "arguments"])
+
+ if event.get("chunk_type") == "message_start":
+ return {"messageStart": {"role": "assistant"}}
+
+ if event.get("chunk_type") == "content_start":
+ if event["data_type"] == "text":
+ return {"contentBlockStart": {"start": {}}}
+
+ tool: Function = event["data"]
+ return {
+ "contentBlockStart": {
+ "start": {
+ "toolUse": {
+ "name": tool.name,
+ "toolUseId": tool.name,
+ }
+ }
+ }
+ }
+
+ if event.get("chunk_type") == "content_delta":
+ if event["data_type"] == "text":
+ return {"contentBlockDelta": {"delta": {"text": event["data"]}}}
+
+ tool: Function = event["data"]
+ return {
+ "contentBlockDelta": {
+ "delta": {
+ "toolUse": {
+ "input": json.dumps(tool.arguments) # This is already a dict
+ }
+ }
+ }
+ }
+
+ if event.get("chunk_type") == "content_stop":
+ return {"contentBlockStop": {}}
+
+ if event.get("chunk_type") == "message_stop":
+ reason = event["data"]
+ if reason == "tool_use":
+ return {"messageStop": {"stopReason": "tool_use"}}
+ elif reason == "length":
+ return {"messageStop": {"stopReason": "max_tokens"}}
+ else:
+ return {"messageStop": {"stopReason": "end_turn"}}
+
+ if event.get("chunk_type") == "metadata":
+ usage = event.get("data", {})
+ return {
+ "metadata": {
+ "usage": {
+ "inputTokens": usage.get("prompt_eval_count", 0),
+ "outputTokens": usage.get("eval_count", 0),
+ "totalTokens": usage.get("prompt_eval_count", 0) + usage.get("eval_count", 0),
+ },
+ "metrics": {
+ "latencyMs": usage.get("total_duration", 0) / 1e6,
+ },
+ }
+ }
+
+ raise RuntimeError(f"chunk_type=<{event.get('chunk_type')}> | unknown type")
+
+ @override
+ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
+ """Send the request to the vLLM model and get the streaming response.
+
+ Args:
+ request: The formatted request to send to the vLLM model.
+
+ Returns:
+ An iterable of response events from the vLLM model.
+ """
+
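+        # Lightweight container for tool calls parsed out of the streamed text.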
+ Function = namedtuple("Function", ["name", "arguments"])
+
+ headers = {"Content-Type": "application/json"}
+ url = f"{self.host}/v1/chat/completions"
+
+ accumulated_content = []
+ tool_requested = False
+
+ try:
+ with requests.post(url, headers=headers, data=json.dumps(request), stream=True) as response:
+ if response.status_code != 200:
+ logger.error("vLLM server error: %d - %s", response.status_code, response.text)
+ raise Exception(f"Request failed: {response.status_code} - {response.text}")
+
+ yield {"chunk_type": "message_start"}
+ yield {"chunk_type": "content_start", "data_type": "text"}
+
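+                # vLLM streams OpenAI-style Server-Sent Events: payload lines are
+                # prefixed with "data: " and the stream ends with "data: [DONE]".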
+ for line in response.iter_lines(decode_unicode=True):
+ if not line or not line.startswith("data: "):
+ continue
+ line = line[len("data: ") :].strip()
+
+ if line == "[DONE]":
+ break
+
+ try:
+ event = json.loads(line)
+ choices = event.get("choices", [])
+ if choices:
+ delta = choices[0].get("delta", {})
+ content = delta.get("content")
+ if content:
+ accumulated_content.append(content)
+
+ yield {"chunk_type": "content_delta", "data_type": "text", "data": content or ""}
+
+ except json.JSONDecodeError:
+ logger.warning("Failed to parse line: %s", line)
+ continue
+
+ yield {"chunk_type": "content_stop", "data_type": "text"}
+
+ full_content = "".join(accumulated_content)
+
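+                # Qwen-style models emit tool calls inline as <tool_call>{...}</tool_call>
+                # blocks; extract them from the accumulated text and re-emit as tool events.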
+ tool_call_blocks = re.findall(r"(.*?)", full_content, re.DOTALL)
+ for idx, block in enumerate(tool_call_blocks):
+ try:
+ tool_call_data = json.loads(block.strip())
+ func = Function(name=tool_call_data["name"], arguments=tool_call_data.get("arguments", {}))
+ func_str = f"function=Function(name='{func.name}', arguments={func.arguments})"
+
+ yield {"chunk_type": "content_start", "data_type": "tool", "data": func}
+ yield {"chunk_type": "content_delta", "data_type": "tool", "data": func}
+ yield {"chunk_type": "content_stop", "data_type": "tool", "data": func}
+ tool_requested = True
+
+ except json.JSONDecodeError:
+ logger.warning(f"Failed to parse tool_call block #{idx}: {block}")
+ continue
+
+ yield {"chunk_type": "message_stop", "data": "tool_use" if tool_requested else "end_turn"}
+
+ except requests.RequestException as e:
+ logger.error("Streaming request failed: %s", str(e))
+ raise Exception("Failed to reach vLLM server") from e
diff --git a/tests-integ/test_model_vllm.py b/tests-integ/test_model_vllm.py
new file mode 100644
index 000000000..49df3902d
--- /dev/null
+++ b/tests-integ/test_model_vllm.py
@@ -0,0 +1,45 @@
+import pytest
+import strands
+from strands import Agent
+from strands.models.vllm import VLLMModel
+
+
+@pytest.fixture
+def model():
+ return VLLMModel(
+ model_id="Qwen/Qwen3-4B",
+ host="http://localhost:8000",
+ max_tokens=128,
+ )
+
+
+@pytest.fixture
+def tools():
+ @strands.tool
+ def tool_time() -> str:
+ return "12:00"
+
+ @strands.tool
+ def tool_weather() -> str:
+ return "cloudy"
+
+ return [tool_time, tool_weather]
+
+
+@pytest.fixture
+def agent(model, tools):
+ return Agent(model=model, tools=tools)
+
+
+def test_agent(agent):
+ # Send prompt
+ result = agent("What is the time and weather in Melbourne Australia?")
+
+ # Extract plain text from the first content block
+ text_blocks = result.message.get("content", [])
+ # content is a list of dicts with 'text' keys
+ text = " ".join(block.get("text", "") for block in text_blocks).lower()
+
+    # Assert that the weather tool is referenced in the generated response text
+    assert "tool_weather" in text
+    # assert "cloudy" in text
diff --git a/tests/strands/models/test_vllm.py b/tests/strands/models/test_vllm.py
new file mode 100644
index 000000000..f21741e4b
--- /dev/null
+++ b/tests/strands/models/test_vllm.py
@@ -0,0 +1,153 @@
+import json
+from types import SimpleNamespace
+
+import pytest
+import requests
+
+from strands.models.vllm import VLLMModel
+
+
+@pytest.fixture
+def model_id():
+ return "meta-llama/Llama-3.2-3B"
+
+
+@pytest.fixture
+def host():
+ return "http://localhost:8000"
+
+
+@pytest.fixture
+def model(model_id, host):
+ return VLLMModel(host, model_id=model_id, max_tokens=128)
+
+
+@pytest.fixture
+def messages():
+ return [{"role": "user", "content": [{"text": "Hello"}]}]
+
+
+@pytest.fixture
+def system_prompt():
+ return "You are a helpful assistant."
+
+
+def test_init_sets_config(model, model_id):
+ assert model.get_config()["model_id"] == model_id
+ assert model.host == "http://localhost:8000"
+
+
+def test_update_config_overrides(model):
+ model.update_config(temperature=0.3)
+ assert model.get_config()["temperature"] == 0.3
+
+
+def test_format_request_basic(model, messages):
+ request = model.format_request(messages)
+ assert request["model"] == model.get_config()["model_id"]
+ assert isinstance(request["messages"], list)
+ assert request["messages"][0]["role"] == "user"
+ assert request["messages"][0]["content"] == "Hello"
+ assert request["stream"] is True
+
+
+def test_format_request_with_system_prompt(model, messages, system_prompt):
+ request = model.format_request(messages, system_prompt=system_prompt)
+ assert request["messages"][0]["role"] == "system"
+ assert request["messages"][0]["content"] == system_prompt
+
+
+def test_format_chunk_text():
+ chunk = {"chunk_type": "content_delta", "data_type": "text", "data": "World"}
+ formatted = VLLMModel.format_chunk(None, chunk)
+ assert formatted == {"contentBlockDelta": {"delta": {"text": "World"}}}
+
+
+def test_format_chunk_tool_call_delta():
+ chunk = {
+ "chunk_type": "content_delta",
+ "data_type": "tool",
+ "data": SimpleNamespace(name="get_time", arguments={"timezone": "UTC"}),
+ }
+
+ formatted = VLLMModel.format_chunk(None, chunk)
+ assert "contentBlockDelta" in formatted
+ assert "toolUse" in formatted["contentBlockDelta"]["delta"]
+ assert json.loads(formatted["contentBlockDelta"]["delta"]["toolUse"]["input"])["timezone"] == "UTC"
+
+
+def test_stream_response(monkeypatch, model, messages):
+ mock_lines = [
+ 'data: {"choices":[{"delta":{"content":"Hello"}}]}\n',
+ 'data: {"choices":[{"delta":{"content":" world"}}]}\n',
+ "data: [DONE]\n",
+ ]
+
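+    # Minimal stand-in for a streaming requests.Response, used as a context manager.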
+ class MockResponse:
+ def __init__(self):
+ self.status_code = 200
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, *a): pass
+
+ def iter_lines(self, decode_unicode=False):
+ return iter(mock_lines)
+
+ monkeypatch.setattr(requests, "post", lambda *a, **kw: MockResponse())
+
+ chunks = list(model.stream(model.format_request(messages)))
+ chunk_types = [c.get("chunk_type") for c in chunks]
+
+ assert "message_start" in chunk_types
+ assert chunk_types.count("content_delta") == 2
+ assert "content_stop" in chunk_types
+ assert "message_stop" in chunk_types
+
+
+def test_stream_tool_call(monkeypatch, model, messages):
+ tool_call = {
+ "name": "current_time",
+ "arguments": {"timezone": "UTC"},
+ }
+ tool_call_json = json.dumps(tool_call)
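+    # Tool calls are expected inline, wrapped in <tool_call>...</tool_call> markers.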
+ data_str = json.dumps({
+ "choices": [
+ {"delta": {"content": f"{tool_call_json}"}}
+ ]
+ })
+ mock_lines = [
+ 'data: {"choices":[{"delta":{"content":"Some answer before tool."}}]}\n',
+ f"data: {data_str}\n",
+ "data: [DONE]\n",
+ ]
+
+ class MockResponse:
+ def __init__(self): self.status_code = 200
+ def __enter__(self): return self
+ def __exit__(self, *a): pass
+ def iter_lines(self, decode_unicode=False): return iter(mock_lines)
+
+ monkeypatch.setattr(requests, "post", lambda *a, **kw: MockResponse())
+
+ chunks = list(model.stream(model.format_request(messages)))
+ tool_chunks = [c for c in chunks if c.get("chunk_type") == "content_start" and c.get("data_type") == "tool"]
+
+ assert tool_chunks
+ assert any("tool_use" in c.get("chunk_type", "") or "tool" in c.get("data_type", "") for c in chunks)
+
+
+def test_stream_server_error(monkeypatch, model, messages):
+ class ErrorResponse:
+ def __init__(self):
+ self.status_code = 500
+ self.text = "Internal Error"
+ def __enter__(self): return self
+ def __exit__(self, *a): pass
+ def iter_lines(self, decode_unicode=False): return iter([])
+
+ monkeypatch.setattr(requests, "post", lambda *a, **kw: ErrorResponse())
+
+ with pytest.raises(Exception, match="Request failed: 500"):
+ list(model.stream(model.format_request(messages)))