|
4 | 4 | import logging |
5 | 5 | import os |
6 | 6 | from dataclasses import dataclass |
7 | | -from typing import Any, AsyncGenerator, Literal, Optional, Type, TypedDict, TypeVar, Union, cast |
| 7 | +from typing import Any, AsyncGenerator, Literal, Optional, Type, TypedDict, TypeVar, Union |
8 | 8 |
|
9 | 9 | import boto3 |
10 | 10 | from botocore.config import Config as BotocoreConfig |
@@ -151,8 +151,8 @@ def __init__( |
151 | 151 | validate_config_keys(payload_config, self.SageMakerAIPayloadSchema) |
152 | 152 | payload_config.setdefault("stream", True) |
153 | 153 | payload_config.setdefault("tool_results_as_user_messages", False) |
154 | | - self.endpoint_config = dict(endpoint_config) |
155 | | - self.payload_config = dict(payload_config) |
| 154 | + self.endpoint_config = self.SageMakerAIEndpointConfig(**endpoint_config) |
| 155 | + self.payload_config = self.SageMakerAIPayloadSchema(**payload_config) |
156 | 156 | logger.debug( |
157 | 157 | "endpoint_config=<%s> payload_config=<%s> | initializing", self.endpoint_config, self.payload_config |
158 | 158 | ) |
@@ -193,7 +193,7 @@ def get_config(self) -> "SageMakerAIModel.SageMakerAIEndpointConfig": # type: i |
193 | 193 | Returns: |
194 | 194 | The Amazon SageMaker model configuration. |
195 | 195 | """ |
196 | | - return cast(SageMakerAIModel.SageMakerAIEndpointConfig, self.endpoint_config) |
| 196 | + return self.endpoint_config |
197 | 197 |
|
198 | 198 | @override |
199 | 199 | def format_request( |
@@ -238,6 +238,10 @@ def format_request( |
238 | 238 | }, |
239 | 239 | } |
240 | 240 |
|
| 241 | + payload_additional_args = self.payload_config.get("additional_args") |
| 242 | + if payload_additional_args: |
| 243 | + payload.update(payload_additional_args) |
| 244 | + |
241 | 245 | # Remove tools and tool_choice if tools = [] |
242 | 246 | if not payload["tools"]: |
243 | 247 | payload.pop("tools") |
@@ -273,16 +277,20 @@ def format_request( |
273 | 277 | } |
274 | 278 |
|
275 | 279 | # Add optional SageMaker parameters if provided |
276 | | - if self.endpoint_config.get("inference_component_name"): |
277 | | - request["InferenceComponentName"] = self.endpoint_config["inference_component_name"] |
278 | | - if self.endpoint_config.get("target_model"): |
279 | | - request["TargetModel"] = self.endpoint_config["target_model"] |
280 | | - if self.endpoint_config.get("target_variant"): |
281 | | - request["TargetVariant"] = self.endpoint_config["target_variant"] |
282 | | - |
283 | | - # Add additional args if provided |
284 | | - if self.endpoint_config.get("additional_args"): |
285 | | - request.update(self.endpoint_config["additional_args"].__dict__) |
| 280 | + inf_component_name = self.endpoint_config.get("inference_component_name") |
| 281 | + if inf_component_name: |
| 282 | + request["InferenceComponentName"] = inf_component_name |
| 283 | + target_model = self.endpoint_config.get("target_model") |
| 284 | + if target_model: |
| 285 | + request["TargetModel"] = target_model |
| 286 | + target_variant = self.endpoint_config.get("target_variant") |
| 287 | + if target_variant: |
| 288 | + request["TargetVariant"] = target_variant |
| 289 | + |
| 290 | + # Add additional request args if provided |
| 291 | + additional_args = self.endpoint_config.get("additional_args") |
| 292 | + if additional_args: |
| 293 | + request.update(additional_args) |
286 | 294 |
|
287 | 295 | return request |
288 | 296 |
|
|
0 commit comments