From e194edbaefe0d0c480a86801fdf4133d476b78c3 Mon Sep 17 00:00:00 2001 From: Karan Gathani Date: Mon, 8 Sep 2025 11:08:01 -0700 Subject: [PATCH 1/2] Add Bedrock Anthropic provider support for test generation Extended CLI and internal logic to support 'bedrock-anthropic' as a provider for test generation. Updated help messages, provider validation, and model handling to accommodate Bedrock Anthropic, including AWS credential requirements and model ID usage. Integrated ChatBedrockAnthropic client and adjusted model validation and selection accordingly. --- shiny/_main.py | 15 ++++-- shiny/_main_generate_test.py | 22 +++++--- shiny/pytest/_generate/_main.py | 96 ++++++++++++++++++++++----------- 3 files changed, 91 insertions(+), 42 deletions(-) diff --git a/shiny/_main.py b/shiny/_main.py index f43421a71..c606c7ed8 100644 --- a/shiny/_main.py +++ b/shiny/_main.py @@ -555,14 +555,23 @@ def add() -> None: ) @click.option( "--provider", - type=click.Choice(["anthropic", "openai"]), + type=click.Choice(["anthropic", "openai", "bedrock-anthropic"]), default="anthropic", - help="AI provider to use for test generation.", + help=( + "AI provider to use for test generation. For 'bedrock-anthropic', " + "make sure your AWS credentials are configured (env vars, profile, or role) " + "and provide a Bedrock Anthropic model ID (e.g., " + "us.anthropic.claude-3-7-sonnet-20250219-v1:0)." + ), ) @click.option( "--model", type=str, - help="Specific model to use (optional). Examples: haiku3.5, sonnet, gpt-5, gpt-5-mini", + help=( + "Specific model to use (optional). Examples: haiku3.5, sonnet, gpt-5, gpt-5-mini; " + "or a Bedrock Anthropic model ID when using provider=bedrock-anthropic, e.g. " + "us.anthropic.claude-3-7-sonnet-20250219-v1:0" + ), ) # Param for app.py, param for test_name def test( diff --git a/shiny/_main_generate_test.py b/shiny/_main_generate_test.py index bdca96a7b..64fe5e8cf 100644 --- a/shiny/_main_generate_test.py +++ b/shiny/_main_generate_test.py @@ -54,19 +54,27 @@ def validate_api_key(provider: str) -> None: "env_var": "OPENAI_API_KEY", "url": "https://platform.openai.com/api-keys", }, + "bedrock-anthropic": { + "env_var": None, + "url": "https://docs.aws.amazon.com/bedrock/latest/userguide/setting-up.html", + }, } if provider not in api_configs: raise ValidationError(f"Unsupported provider: {provider}") config = api_configs[provider] - if not os.getenv(config["env_var"]): - raise ValidationError( - f"{config['env_var']} environment variable is not set.\n" - f"Please set your {provider.title()} API key:\n" - f" export {config['env_var']}='your-api-key-here'\n\n" - f"Get your API key from: {config['url']}" - ) + if provider in ("anthropic", "openai"): + env_var = config["env_var"] # type: ignore[assignment] + if not isinstance(env_var, str) or not os.getenv(env_var): + raise ValidationError( + f"{env_var} environment variable is not set.\n" + f"Please set your {provider.title()} API key:\n" + f" export {env_var}='your-api-key-here'\n\n" + f"Get your API key from: {config['url']}" + ) + else: + pass def get_app_file_path(app_file: str | None) -> Path: diff --git a/shiny/pytest/_generate/_main.py b/shiny/pytest/_generate/_main.py index 8bf75de8c..e0f43afe5 100644 --- a/shiny/pytest/_generate/_main.py +++ b/shiny/pytest/_generate/_main.py @@ -8,7 +8,7 @@ from pathlib import Path from typing import Literal, Optional, Tuple, Union -from chatlas import ChatAnthropic, ChatOpenAI, token_usage +from chatlas import ChatAnthropic, ChatBedrockAnthropic, ChatOpenAI, token_usage from dotenv import load_dotenv __all__ = [ @@ -32,6 +32,7 @@ class Config: DEFAULT_ANTHROPIC_MODEL = "claude-sonnet-4-20250514" DEFAULT_OPENAI_MODEL = "gpt-5-mini-2025-08-07" + DEFAULT_BEDROCK_ANTHROPIC_MODEL = "us.anthropic.claude-3-7-sonnet-20250219-v1:0" DEFAULT_PROVIDER = "anthropic" MAX_TOKENS = 8092 @@ -50,7 +51,9 @@ class ShinyTestGenerator: def __init__( self, - provider: Literal["anthropic", "openai"] = Config.DEFAULT_PROVIDER, + provider: Literal[ + "anthropic", "openai", "bedrock-anthropic" + ] = Config.DEFAULT_PROVIDER, api_key: Optional[str] = None, log_file: str = Config.LOG_FILE, setup_logging: bool = True, @@ -74,25 +77,28 @@ def __init__( self.setup_logging() @property - def client(self) -> Union[ChatAnthropic, ChatOpenAI]: + def client(self) -> Union[ChatAnthropic, ChatOpenAI, ChatBedrockAnthropic]: """Lazy-loaded chat client based on provider""" if self._client is None: - if not self.api_key: - env_var = ( - "ANTHROPIC_API_KEY" - if self.provider == "anthropic" - else "OPENAI_API_KEY" - ) - self.api_key = os.getenv(env_var) - if not self.api_key: - raise ValueError( - f"Missing API key for provider '{self.provider}'. Set the environment variable " - f"{'ANTHROPIC_API_KEY' if self.provider == 'anthropic' else 'OPENAI_API_KEY'} or pass api_key explicitly." - ) + if self.provider in ("anthropic", "openai"): + if not self.api_key: + env_var = ( + "ANTHROPIC_API_KEY" + if self.provider == "anthropic" + else "OPENAI_API_KEY" + ) + self.api_key = os.getenv(env_var) + if not self.api_key: + raise ValueError( + f"Missing API key for provider '{self.provider}'. Set the environment variable " + f"{'ANTHROPIC_API_KEY' if self.provider == 'anthropic' else 'OPENAI_API_KEY'} or pass api_key explicitly." + ) if self.provider == "anthropic": self._client = ChatAnthropic(api_key=self.api_key) elif self.provider == "openai": self._client = ChatOpenAI(api_key=self.api_key) + elif self.provider == "bedrock-anthropic": + self._client = ChatBedrockAnthropic() else: raise ValueError(f"Unsupported provider: {self.provider}") return self._client @@ -118,6 +124,8 @@ def default_model(self) -> str: return Config.DEFAULT_ANTHROPIC_MODEL elif self.provider == "openai": return Config.DEFAULT_OPENAI_MODEL + elif self.provider == "bedrock-anthropic": + return Config.DEFAULT_BEDROCK_ANTHROPIC_MODEL else: raise ValueError(f"Unsupported provider: {self.provider}") @@ -168,6 +176,15 @@ def _resolve_model(self, model: str) -> str: def _validate_model_for_provider(self, model: str) -> str: """Validate that the model is compatible with the current provider""" + if self.provider == "bedrock-anthropic": + resolved_model = model + if resolved_model.startswith("gpt-") or resolved_model.startswith("o1-"): + raise ValueError( + f"Model '{model}' is an OpenAI model but provider is set to 'bedrock-anthropic'. " + f"Use an Anthropic Bedrock model ID (e.g., 'anthropic.claude-3-5-sonnet-20240620-v1:0')." + ) + return resolved_model + resolved_model = self._resolve_model(model) if self.provider == "anthropic": @@ -193,18 +210,19 @@ def get_llm_response(self, prompt: str, model: Optional[str] = None) -> str: model = self._validate_model_for_provider(model) try: - if not self.api_key: - env_var = ( - "ANTHROPIC_API_KEY" - if self.provider == "anthropic" - else "OPENAI_API_KEY" - ) - self.api_key = os.getenv(env_var) - if not self.api_key: - raise ValueError( - f"Missing API key for provider '{self.provider}'. Set the environment variable " - f"{'ANTHROPIC_API_KEY' if self.provider == 'anthropic' else 'OPENAI_API_KEY'} or pass api_key." - ) + if self.provider in ("anthropic", "openai"): + if not self.api_key: + env_var = ( + "ANTHROPIC_API_KEY" + if self.provider == "anthropic" + else "OPENAI_API_KEY" + ) + self.api_key = os.getenv(env_var) + if not self.api_key: + raise ValueError( + f"Missing API key for provider '{self.provider}'. Set the environment variable " + f"{'ANTHROPIC_API_KEY' if self.provider == 'anthropic' else 'OPENAI_API_KEY'} or pass api_key." + ) # Create chat client with the specified model if self.provider == "anthropic": chat = ChatAnthropic( @@ -219,6 +237,12 @@ def get_llm_response(self, prompt: str, model: Optional[str] = None) -> str: system_prompt=self.system_prompt, api_key=self.api_key, ) + elif self.provider == "bedrock-anthropic": + chat = ChatBedrockAnthropic( + model=model, + system_prompt=self.system_prompt, + max_tokens=Config.MAX_TOKENS, + ) else: raise ValueError(f"Unsupported provider: {self.provider}") @@ -226,15 +250,12 @@ def get_llm_response(self, prompt: str, model: Optional[str] = None) -> str: response = chat.chat(prompt) elapsed = time.perf_counter() - start_time usage = token_usage() - # For Anthropic, token_usage() includes costs. For OpenAI, use chat.get_cost with model pricing. token_price = None if self.provider == "openai": token_price = Config.OPENAI_PRICING.get(model) try: - # Call to compute and cache costs internally; per-entry cost is computed below _ = chat.get_cost(options="all", token_price=token_price) except Exception: - # If cost computation fails, continue without it pass try: @@ -530,7 +551,9 @@ def generate_test_from_code( ) def switch_provider( - self, provider: Literal["anthropic", "openai"], api_key: Optional[str] = None + self, + provider: Literal["anthropic", "openai", "bedrock-anthropic"], + api_key: Optional[str] = None, ): self.provider = provider if api_key: @@ -549,6 +572,11 @@ def create_openai_generator( ) -> "ShinyTestGenerator": return cls(provider="openai", api_key=api_key, **kwargs) + @classmethod + def create_bedrock_anthropic_generator(cls, **kwargs) -> "ShinyTestGenerator": + # AWS credentials and region are resolved from environment or AWS config + return cls(provider="bedrock-anthropic", api_key=None, **kwargs) + def get_available_models(self) -> list[str]: if self.provider == "anthropic": return [ @@ -562,6 +590,10 @@ def get_available_models(self) -> list[str]: for model in Config.MODEL_ALIASES.keys() if (model.startswith("gpt-") or model.startswith("o1-")) ] + elif self.provider == "bedrock-anthropic": + # Bedrock requires full model IDs (e.g., 'us.anthropic.claude-sonnet-4-20250514-v1:0'). + # We don't provide aliases here because IDs are region/account specific. + return [] else: return [] @@ -573,7 +605,7 @@ def cli(): parser.add_argument("app_file", help="Path to the Shiny app file") parser.add_argument( "--provider", - choices=["anthropic", "openai"], + choices=["anthropic", "openai", "bedrock-anthropic"], default=Config.DEFAULT_PROVIDER, help="LLM provider to use", ) From a92f58b3402302689f2bfecaccb5d35a3c2012ba Mon Sep 17 00:00:00 2001 From: Karan Gathani Date: Thu, 25 Sep 2025 08:06:57 -0700 Subject: [PATCH 2/2] correct the links for getting started --- shiny/_main_generate_test.py | 2 +- shiny/pytest/_generate/_main.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/shiny/_main_generate_test.py b/shiny/_main_generate_test.py index 64fe5e8cf..40f181d31 100644 --- a/shiny/_main_generate_test.py +++ b/shiny/_main_generate_test.py @@ -56,7 +56,7 @@ def validate_api_key(provider: str) -> None: }, "bedrock-anthropic": { "env_var": None, - "url": "https://docs.aws.amazon.com/bedrock/latest/userguide/setting-up.html", + "url": "https://docs.aws.amazon.com/bedrock/latest/userguide/getting-started.html", }, } diff --git a/shiny/pytest/_generate/_main.py b/shiny/pytest/_generate/_main.py index 336c162d7..3d2fbeddb 100644 --- a/shiny/pytest/_generate/_main.py +++ b/shiny/pytest/_generate/_main.py @@ -180,7 +180,7 @@ def _validate_model_for_provider(self, model: str) -> str: if resolved_model.startswith("gpt-") or resolved_model.startswith("o1-"): raise ValueError( f"Model '{model}' is an OpenAI model but provider is set to 'bedrock-anthropic'. " - f"Use an Anthropic Bedrock model ID (e.g., 'anthropic.claude-3-5-sonnet-20240620-v1:0')." + f"Use an Anthropic Bedrock model ID (e.g., 'us.anthropic.claude-3-7-sonnet-20250219-v1:0')." ) return resolved_model