@@ -112,6 +112,7 @@ def test_init_with_all_params(self, boto_session):
112112 "endpoint_name" : "test-endpoint" ,
113113 "inference_component_name" : "test-component" ,
114114 "region_name" : "us-west-2" ,
115+ "additional_args" : {"test_arg_name" : "test_arg_value" },
115116 }
116117 payload_config = {
117118 "stream" : False ,
@@ -129,6 +130,7 @@ def test_init_with_all_params(self, boto_session):
129130
130131 assert model .endpoint_config ["endpoint_name" ] == "test-endpoint"
131132 assert model .endpoint_config ["inference_component_name" ] == "test-component"
133+ assert model .endpoint_config ["additional_args" ]["test_arg_name" ] == "test_arg_value"
132134 assert model .payload_config ["stream" ] is False
133135 assert model .payload_config ["max_tokens" ] == 1024
134136 assert model .payload_config ["temperature" ] == 0.7
@@ -239,6 +241,22 @@ def test_get_config(self, model, endpoint_config):
239241 # assert "tools" in payload
240242 # assert payload["tools"] == []
241243
244+ def test_format_request_with_additional_args (self , boto_session , endpoint_config , messages , payload_config ):
245+ """Test formatting a request's `additional_args` where provided"""
246+ endpoint_config_ext = {
247+ ** endpoint_config ,
248+ "additional_args" : {
249+ "extra_key" : "extra_value" ,
250+ },
251+ }
252+ model = SageMakerAIModel (
253+ boto_session = boto_session ,
254+ endpoint_config = endpoint_config_ext ,
255+ payload_config = payload_config ,
256+ )
257+ request = model .format_request (messages )
258+ assert request .get ("extra_key" ) == "extra_value"
259+
242260 @pytest .mark .asyncio
243261 async def test_stream_with_streaming_enabled (self , sagemaker_client , model , messages ):
244262 """Test streaming response with streaming enabled."""
0 commit comments