diff --git a/README.md b/README.md index 7a16324de..486b24380 100644 --- a/README.md +++ b/README.md @@ -107,6 +107,7 @@ from strands import Agent from strands.models import BedrockModel from strands.models.ollama import OllamaModel from strands.models.llamaapi import LlamaAPIModel +from strands.models.gemini import GeminiModel # Bedrock bedrock_model = BedrockModel( @@ -130,11 +131,21 @@ llama_model = LlamaAPIModel( ) agent = Agent(model=llama_model) response = agent("Tell me about Agentic AI") + +# Gemini +gemini_model = GeminiModel( + model_id="gemini-pro", + max_tokens=1024, + params={"temperature": 0.7} +) +agent = Agent(model=gemini_model) +response = agent("Tell me about Agentic AI") ``` Built-in providers: - [Amazon Bedrock](https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/) - [Anthropic](https://strandsagents.com/latest/user-guide/concepts/model-providers/anthropic/) + - [Gemini](https://strandsagents.com/latest/user-guide/concepts/model-providers/gemini/) - [LiteLLM](https://strandsagents.com/latest/user-guide/concepts/model-providers/litellm/) - [LlamaAPI](https://strandsagents.com/latest/user-guide/concepts/model-providers/llamaapi/) - [Ollama](https://strandsagents.com/latest/user-guide/concepts/model-providers/ollama/) diff --git a/pyproject.toml b/pyproject.toml index e3e3f3729..f7b5ef707 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,9 +75,12 @@ ollama = [ llamaapi = [ "llama-api-client>=0.1.0,<1.0.0", ] +gemini = [ + "google-generativeai>=0.8.5", +] [tool.hatch.envs.hatch-static-analysis] -features = ["anthropic", "litellm", "llamaapi", "ollama"] +features = ["anthropic", "litellm", "llamaapi", "ollama", "gemini"] dependencies = [ "mypy>=1.15.0,<2.0.0", "ruff>=0.11.6,<0.12.0", diff --git a/src/strands/models/gemini.py b/src/strands/models/gemini.py new file mode 100644 index 000000000..7a6c75b3f --- /dev/null +++ b/src/strands/models/gemini.py @@ -0,0 +1,269 @@ +"""Google Gemini model provider. + +- Docs: https://ai.google.dev/docs/gemini_api_overview +""" + +import base64 +import json +import logging +import mimetypes +from typing import Any, Iterable, Optional, TypedDict + +import google.generativeai.generative_models as genai # mypy: disable-error-code=import +from typing_extensions import Required, Unpack, override + +from ..types.content import ContentBlock, Messages +from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException +from ..types.models import Model +from ..types.streaming import StreamEvent +from ..types.tools import ToolSpec + +logger = logging.getLogger(__name__) + + +class GeminiModel(Model): + """Google Gemini model provider implementation.""" + + EVENT_TYPES = { + "message_start", + "content_block_start", + "content_block_delta", + "content_block_stop", + "message_stop", + } + + OVERFLOW_MESSAGES = { + "input is too long", + "input length exceeds context window", + "input and output tokens exceed your context limit", + } + + class GeminiConfig(TypedDict, total=False): + """Configuration options for Gemini models. + + Attributes: + max_tokens: Maximum number of tokens to generate. + model_id: Gemini model ID (e.g., "gemini-pro"). + For a complete list of supported models, see + https://ai.google.dev/models/gemini. + params: Additional model parameters (e.g., temperature). + For a complete list of supported parameters, see + https://ai.google.dev/docs/gemini_api_overview#generation_config. 
+ """ + + max_tokens: Required[int] + model_id: Required[str] + params: Optional[dict[str, Any]] + + def __init__(self, *, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[GeminiConfig]): + """Initialize provider instance. + + Args: + client_args: Arguments for the underlying Gemini client (e.g., api_key). + For a complete list of supported arguments, see + https://ai.google.dev/docs/gemini_api_overview#client_libraries. + **model_config: Configuration options for the Gemini model. + """ + self.config = GeminiModel.GeminiConfig(**model_config) + + logger.debug("config=<%s> | initializing", self.config) + + client_args = client_args or {} + genai.client.configure(**client_args) + self.model = genai.GenerativeModel(self.config["model_id"]) + + @override + def update_config(self, **model_config: Unpack[GeminiConfig]) -> None: # type: ignore[override] + """Update the Gemini model configuration with the provided arguments. + + Args: + **model_config: Configuration overrides. + """ + self.config.update(model_config) + self.model = genai.GenerativeModel(self.config["model_id"]) + + @override + def get_config(self) -> GeminiConfig: + """Get the Gemini model configuration. + + Returns: + The Gemini model configuration. + """ + return self.config + + def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any]: + """Format a Gemini content block. + + Args: + content: Message content. + + Returns: + Gemini formatted content block. + """ + if "image" in content: + return { + "inline_data": { + "data": base64.b64encode(content["image"]["source"]["bytes"]).decode("utf-8"), + "mime_type": mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream"), + } + } + + if "text" in content: + return {"text": content["text"]} + + return {"text": json.dumps(content)} + + def _format_request_messages(self, messages: Messages) -> list[dict[str, Any]]: + """Format a Gemini messages array. + + Args: + messages: List of message objects to be processed by the model. + + Returns: + A Gemini messages array. + """ + formatted_messages = [] + + for message in messages: + formatted_contents = [] + + for content in message["content"]: + if "cachePoint" in content: + continue + + formatted_contents.append(self._format_request_message_content(content)) + + if formatted_contents: + formatted_messages.append({"role": message["role"], "parts": formatted_contents}) + + return formatted_messages + + @override + def format_request( + self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + ) -> dict[str, Any]: + """Format a Gemini streaming request. + + Args: + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. + system_prompt: System prompt to provide context to the model. + + Returns: + A Gemini streaming request. 
+        """
+        generation_config = {"max_output_tokens": self.config["max_tokens"], **(self.config.get("params") or {})}
+
+        return {
+            "contents": self._format_request_messages(messages),
+            "generation_config": generation_config,
+            "tools": [
+                {
+                    "function_declarations": [
+                        {
+                            "name": tool_spec["name"],
+                            "description": tool_spec["description"],
+                            "parameters": tool_spec["inputSchema"]["json"],
+                        }
+                        for tool_spec in tool_specs or []
+                    ]
+                }
+            ]
+            if tool_specs
+            else None,
+            "system_instruction": system_prompt,
+        }
+
+    @override
+    def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
+        """Format the Gemini response events into standardized message chunks.
+
+        Args:
+            event: A response event from the Gemini model.
+
+        Returns:
+            The formatted chunk.
+
+        Raises:
+            RuntimeError: If chunk_type is not recognized.
+                This error should never be encountered as we control chunk_type in the stream method.
+        """
+        match event["type"]:
+            case "message_start":
+                return {"messageStart": {"role": "assistant"}}
+
+            case "content_block_start":
+                return {"contentBlockStart": {"start": {}}}
+
+            case "content_block_delta":
+                return {"contentBlockDelta": {"delta": {"text": event["text"]}}}
+
+            case "content_block_stop":
+                return {"contentBlockStop": {}}
+
+            case "message_stop":
+                return {"messageStop": {"stopReason": event["stop_reason"]}}
+
+            case "metadata":
+                return {
+                    "metadata": {
+                        "usage": {
+                            "inputTokens": event["usage"]["prompt_token_count"],
+                            "outputTokens": event["usage"]["candidates_token_count"],
+                            "totalTokens": event["usage"]["total_token_count"],
+                        },
+                        "metrics": {
+                            "latencyMs": 0,
+                        },
+                    }
+                }
+
+            case _:
+                raise RuntimeError(f"event_type=<{event['type']}> | unknown type")
+
+    @override
+    def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
+        """Send the request to the Gemini model and get the streaming response.
+
+        Args:
+            request: The formatted request to send to the Gemini model.
+
+        Returns:
+            An iterable of response events from the Gemini model.
+
+        Raises:
+            ContextWindowOverflowException: If the input exceeds the model's context window.
+            ModelThrottledException: If the request is throttled by Gemini.
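+
+        Note:
+            The events are plain dicts ("message_start", "content_block_start", repeated
+            "content_block_delta" entries carrying text, "content_block_stop", "message_stop",
+            and a final "metadata" event with token usage) that format_chunk converts into
+            standardized StreamEvent chunks.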
+ """ + try: + response = self.model.generate_content(**request, stream=True) + + yield {"type": "message_start"} + yield {"type": "content_block_start"} + + for chunk in response: + if chunk.text: + yield {"type": "content_block_delta", "text": chunk.text} + + yield {"type": "content_block_stop"} + yield {"type": "message_stop", "stop_reason": "end_turn"} + + # Get usage information + usage = response.usage_metadata + yield { + "type": "metadata", + "usage": { + "prompt_token_count": usage.prompt_token_count, + "candidates_token_count": usage.candidates_token_count, + "total_token_count": usage.total_token_count, + }, + } + + except Exception as error: + if "quota" in str(error).lower(): + raise ModelThrottledException(str(error)) from error + + if any(overflow_message in str(error).lower() for overflow_message in GeminiModel.OVERFLOW_MESSAGES): + raise ContextWindowOverflowException(str(error)) from error + + raise error diff --git a/tests-integ/test_model_gemini.py b/tests-integ/test_model_gemini.py new file mode 100644 index 000000000..231602ae1 --- /dev/null +++ b/tests-integ/test_model_gemini.py @@ -0,0 +1,51 @@ +"""Integration tests for the Gemini model provider.""" + +import os + +import pytest + +import strands +from strands import Agent +from strands.models.gemini import GeminiModel + + +@pytest.fixture +def model(): + return GeminiModel( + client_args={ + "api_key": os.getenv("GOOGLE_API_KEY"), + }, + model_id="gemini-pro", + max_tokens=512, + ) + + +@pytest.fixture +def tools(): + @strands.tool + def tool_time() -> str: + return "12:00" + + @strands.tool + def tool_weather() -> str: + return "sunny" + + return [tool_time, tool_weather] + + +@pytest.fixture +def system_prompt(): + return "You are an AI assistant that uses & instead of ." + + +@pytest.fixture +def agent(model, tools, system_prompt): + return Agent(model=model, tools=tools, system_prompt=system_prompt) + + +@pytest.mark.skipif("GOOGLE_API_KEY" not in os.environ, reason="GOOGLE_API_KEY environment variable missing") +def test_agent(agent): + result = agent("What is the time and weather in New York?") + text = result.message["content"][0]["text"].lower() + + assert all(string in text for string in ["12:00", "sunny", "&"]) diff --git a/tests/strands/models/test_gemini.py b/tests/strands/models/test_gemini.py new file mode 100644 index 000000000..f9d8bff79 --- /dev/null +++ b/tests/strands/models/test_gemini.py @@ -0,0 +1,344 @@ +"""Tests for the Gemini model provider.""" + +import json +import unittest.mock + +import pytest + +import strands +from strands.models.gemini import GeminiModel +from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException + + +@pytest.fixture +def gemini_client(): + with unittest.mock.patch.object(strands.models.gemini.genai, "GenerativeModel") as mock_client_cls: + yield mock_client_cls.return_value + + +@pytest.fixture +def model_id(): + return "gemini-pro" + + +@pytest.fixture +def max_tokens(): + return 1000 + + +@pytest.fixture +def model(gemini_client, model_id, max_tokens): + _ = gemini_client + return GeminiModel(model_id=model_id, max_tokens=max_tokens) + + +@pytest.fixture +def messages(): + return [{"role": "user", "content": [{"text": "test"}]}] + + +@pytest.fixture +def system_prompt(): + return "s1" + + +def test__init__model_configs(gemini_client, model_id, max_tokens): + _ = gemini_client + + model = GeminiModel(model_id=model_id, max_tokens=max_tokens, params={"temperature": 1}) + + tru_temperature = model.get_config().get("params") + 
exp_temperature = {"temperature": 1} + + assert tru_temperature == exp_temperature + + +def test_update_config(model, model_id): + model.update_config(model_id=model_id) + + tru_model_id = model.get_config().get("model_id") + exp_model_id = model_id + + assert tru_model_id == exp_model_id + + +def test_format_request_default(model, messages, model_id, max_tokens): + tru_request = model.format_request(messages) + exp_request = { + "contents": [{"role": "user", "parts": [{"text": "test"}]}], + "generation_config": {"max_output_tokens": max_tokens}, + "tools": None, + "system_instruction": None, + } + + assert tru_request == exp_request + + +def test_format_request_with_params(model, messages, model_id, max_tokens): + model.update_config(params={"temperature": 1}) + + tru_request = model.format_request(messages) + exp_request = { + "contents": [{"role": "user", "parts": [{"text": "test"}]}], + "generation_config": { + "max_output_tokens": max_tokens, + "temperature": 1, + }, + "tools": None, + "system_instruction": None, + } + + assert tru_request == exp_request + + +def test_format_request_with_system_prompt(model, messages, model_id, max_tokens, system_prompt): + tru_request = model.format_request(messages, system_prompt=system_prompt) + exp_request = { + "contents": [{"role": "user", "parts": [{"text": "test"}]}], + "generation_config": {"max_output_tokens": max_tokens}, + "tools": None, + "system_instruction": system_prompt, + } + + assert tru_request == exp_request + + +def test_format_request_with_image(model, model_id, max_tokens): + messages = [ + { + "role": "user", + "content": [ + { + "image": { + "format": "jpg", + "source": {"bytes": b"base64encodedimage"}, + }, + }, + ], + }, + ] + + tru_request = model.format_request(messages) + exp_request = { + "contents": [ + { + "role": "user", + "parts": [ + { + "inline_data": { + "data": "YmFzZTY0ZW5jb2RlZGltYWdl", + "mime_type": "image/jpeg", + } + } + ], + } + ], + "generation_config": {"max_output_tokens": max_tokens}, + "tools": None, + "system_instruction": None, + } + + assert tru_request == exp_request + + +def test_format_request_with_other(model, model_id, max_tokens): + messages = [ + { + "role": "user", + "content": [{"other": {"a": 1}}], + }, + ] + + tru_request = model.format_request(messages) + exp_request = { + "contents": [ + { + "role": "user", + "parts": [ + { + "text": json.dumps({"other": {"a": 1}}), + } + ], + } + ], + "generation_config": {"max_output_tokens": max_tokens}, + "tools": None, + "system_instruction": None, + } + + assert tru_request == exp_request + + +def test_format_request_with_empty_content(model, model_id, max_tokens): + messages = [ + { + "role": "user", + "content": [], + }, + ] + + tru_request = model.format_request(messages) + exp_request = { + "contents": [], + "generation_config": {"max_output_tokens": max_tokens}, + "tools": None, + "system_instruction": None, + } + + assert tru_request == exp_request + + +def test_format_chunk_message_start(model): + event = {"type": "message_start"} + + tru_chunk = model.format_chunk(event) + exp_chunk = {"messageStart": {"role": "assistant"}} + + assert tru_chunk == exp_chunk + + +def test_format_chunk_content_block_start(model): + event = { + "type": "content_block_start", + } + + tru_chunk = model.format_chunk(event) + exp_chunk = { + "contentBlockStart": { + "start": {}, + }, + } + + assert tru_chunk == exp_chunk + + +def test_format_chunk_content_block_delta(model): + event = { + "type": "content_block_delta", + "text": "hello", + } + + tru_chunk = 
model.format_chunk(event) + exp_chunk = { + "contentBlockDelta": { + "delta": {"text": "hello"}, + }, + } + + assert tru_chunk == exp_chunk + + +def test_format_chunk_content_block_stop(model): + event = {"type": "content_block_stop"} + + tru_chunk = model.format_chunk(event) + exp_chunk = {"contentBlockStop": {}} + + assert tru_chunk == exp_chunk + + +def test_format_chunk_message_stop(model): + event = {"type": "message_stop", "stop_reason": "end_turn"} + + tru_chunk = model.format_chunk(event) + exp_chunk = {"messageStop": {"stopReason": "end_turn"}} + + assert tru_chunk == exp_chunk + + +def test_format_chunk_metadata(model): + event = { + "type": "metadata", + "usage": { + "prompt_token_count": 1, + "candidates_token_count": 2, + "total_token_count": 3, + }, + } + + tru_chunk = model.format_chunk(event) + exp_chunk = { + "metadata": { + "usage": { + "inputTokens": 1, + "outputTokens": 2, + "totalTokens": 3, + }, + "metrics": { + "latencyMs": 0, + }, + }, + } + + assert tru_chunk == exp_chunk + + +def test_format_chunk_unknown(model): + event = {"type": "unknown"} + + with pytest.raises(RuntimeError, match="event_type= | unknown type"): + model.format_chunk(event) + + +def test_stream(gemini_client, model): + mock_chunk = unittest.mock.Mock(text="test") + mock_response = unittest.mock.MagicMock() + mock_response.__iter__.return_value = iter([mock_chunk]) + mock_response.usage_metadata = unittest.mock.Mock( + prompt_token_count=1, + candidates_token_count=2, + total_token_count=3, + ) + gemini_client.generate_content.return_value = mock_response + + request = {"model": "gemini-pro"} + response = model.stream(request) + + tru_events = list(response) + exp_events = [ + {"type": "message_start"}, + {"type": "content_block_start"}, + {"type": "content_block_delta", "text": "test"}, + {"type": "content_block_stop"}, + {"type": "message_stop", "stop_reason": "end_turn"}, + { + "type": "metadata", + "usage": { + "prompt_token_count": 1, + "candidates_token_count": 2, + "total_token_count": 3, + }, + }, + ] + + assert tru_events == exp_events + gemini_client.generate_content.assert_called_once_with(**request, stream=True) + + +def test_stream_quota_error(gemini_client, model): + gemini_client.generate_content.side_effect = Exception("quota exceeded") + + with pytest.raises(ModelThrottledException, match="quota exceeded"): + next(model.stream({})) + + +@pytest.mark.parametrize( + "overflow_message", + [ + "...input is too long...", + "...input length exceeds context window...", + "...input and output tokens exceed your context limit...", + ], +) +def test_stream_context_window_overflow_error(overflow_message, gemini_client, model): + gemini_client.generate_content.side_effect = Exception(overflow_message) + + with pytest.raises(ContextWindowOverflowException): + next(model.stream({})) + + +def test_stream_other_error(gemini_client, model): + gemini_client.generate_content.side_effect = Exception("other error") + + with pytest.raises(Exception, match="other error"): + next(model.stream({}))
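
The sketch below shows how the new provider can be exercised locally end to end. It is an illustrative sketch, not part of the patch: it mirrors the README example and the integration-test fixture above, assumes the package is installed with the new `gemini` extra (e.g. `pip install -e ".[gemini]"` from the repo root), and assumes `GOOGLE_API_KEY` is set in the environment.

```python
# Local smoke test for the Gemini provider (illustrative sketch).
# Assumes the `gemini` extra is installed and GOOGLE_API_KEY is exported.
import os

from strands import Agent
from strands.models.gemini import GeminiModel

# client_args are forwarded to the google-generativeai configure call in GeminiModel.__init__.
model = GeminiModel(
    client_args={"api_key": os.environ["GOOGLE_API_KEY"]},
    model_id="gemini-pro",
    max_tokens=512,
    params={"temperature": 0.7},
)

agent = Agent(model=model)
response = agent("Tell me about Agentic AI")

# The agent result exposes the final message in the same shape the tests assert on.
print(response.message["content"][0]["text"])
```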