Skip to content

Commit b34c6da

Browse files
committed
test(sagemaker): tests for ep additional_args
Add a test to check for insertion of endpoint config additional_args
1 parent d906733 commit b34c6da

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

src/strands/models/sagemaker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ def format_request(
283283
if target_variant:
284284
request["TargetVariant"] = target_variant
285285

286-
# Add additional args if provided
286+
# Add additional request args if provided
287287
additional_args = self.endpoint_config.get("additional_args")
288288
if additional_args:
289289
request.update(additional_args)

tests/strands/models/test_sagemaker.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def test_init_with_all_params(self, boto_session):
112112
"endpoint_name": "test-endpoint",
113113
"inference_component_name": "test-component",
114114
"region_name": "us-west-2",
115+
"additional_args": {"test_arg_name": "test_arg_value"},
115116
}
116117
payload_config = {
117118
"stream": False,
@@ -129,6 +130,7 @@ def test_init_with_all_params(self, boto_session):
129130

130131
assert model.endpoint_config["endpoint_name"] == "test-endpoint"
131132
assert model.endpoint_config["inference_component_name"] == "test-component"
133+
assert model.endpoint_config["additional_args"]["test_arg_name"] == "test_arg_value"
132134
assert model.payload_config["stream"] is False
133135
assert model.payload_config["max_tokens"] == 1024
134136
assert model.payload_config["temperature"] == 0.7
@@ -239,6 +241,22 @@ def test_get_config(self, model, endpoint_config):
239241
# assert "tools" in payload
240242
# assert payload["tools"] == []
241243

244+
def test_format_request_with_additional_args(self, boto_session, endpoint_config, messages, payload_config):
245+
"""Test formatting a request's `additional_args` where provided"""
246+
endpoint_config_ext = {
247+
**endpoint_config,
248+
"additional_args": {
249+
"extra_key": "extra_value",
250+
},
251+
}
252+
model = SageMakerAIModel(
253+
boto_session=boto_session,
254+
endpoint_config=endpoint_config_ext,
255+
payload_config=payload_config,
256+
)
257+
request = model.format_request(messages)
258+
assert request.get("extra_key") == "extra_value"
259+
242260
@pytest.mark.asyncio
243261
async def test_stream_with_streaming_enabled(self, sagemaker_client, model, messages):
244262
"""Test streaming response with streaming enabled."""

0 commit comments

Comments
 (0)