@@ -95,12 +95,12 @@ def create_chat_completion_message_event(
9595 transaction .record_ml_event ("LlmChatCompletionMessage" , chat_completion_message_dict )
9696
9797
98- def extract_bedrock_titan_model (request_body , response_body ):
98+ def extract_bedrock_titan_text_model (request_body , response_body ):
9999 response_body = json .loads (response_body )
100100 request_body = json .loads (request_body )
101101
102102 input_tokens = response_body ["inputTextTokenCount" ]
103- completion_tokens = sum (result ["tokenCount" ] for result in response_body [ "results" ] )
103+ completion_tokens = sum (result ["tokenCount" ] for result in response_body . get ( "results" , []) )
104104 total_tokens = input_tokens + completion_tokens
105105
106106 request_config = request_body .get ("textGenerationConfig" , {})
@@ -121,6 +121,20 @@ def extract_bedrock_titan_model(request_body, response_body):
121121 return message_list , chat_completion_summary_dict
122122
123123
124+ def extract_bedrock_titan_embedding_model (request_body , response_body ):
125+ response_body = json .loads (response_body )
126+ request_body = json .loads (request_body )
127+
128+ input_tokens = response_body ["inputTextTokenCount" ]
129+
130+ embedding_dict = {
131+ "input" : request_body .get ("inputText" , "" ),
132+ "response.usage.prompt_tokens" : input_tokens ,
133+ "response.usage.total_tokens" : input_tokens ,
134+ }
135+ return embedding_dict
136+
137+
124138def extract_bedrock_ai21_j2_model (request_body , response_body ):
125139 response_body = json .loads (response_body )
126140 request_body = json .loads (request_body )
@@ -159,11 +173,12 @@ def extract_bedrock_cohere_model(request_body, response_body):
159173 return message_list , chat_completion_summary_dict
160174
161175
162- MODEL_EXTRACTORS = {
163- "amazon.titan" : extract_bedrock_titan_model ,
164- "ai21.j2" : extract_bedrock_ai21_j2_model ,
165- "cohere" : extract_bedrock_cohere_model ,
166- }
176+ MODEL_EXTRACTORS = [ # Order is important here, avoiding dictionaries
177+ ("amazon.titan-embed" , extract_bedrock_titan_embedding_model ),
178+ ("amazon.titan" , extract_bedrock_titan_text_model ),
179+ ("ai21.j2" , extract_bedrock_ai21_j2_model ),
180+ ("cohere" , extract_bedrock_cohere_model ),
181+ ]
167182
168183
169184@function_wrapper
@@ -194,7 +209,7 @@ def wrap_bedrock_runtime_invoke_model(wrapped, instance, args, kwargs):
194209 return response
195210
196211 # Determine extractor by model type
197- for extractor_name , extractor in MODEL_EXTRACTORS . items () :
212+ for extractor_name , extractor in MODEL_EXTRACTORS :
198213 if model .startswith (extractor_name ):
199214 break
200215 else :
@@ -213,7 +228,45 @@ def wrap_bedrock_runtime_invoke_model(wrapped, instance, args, kwargs):
213228 # Read and replace response streaming bodies
214229 response_body = response ["body" ].read ()
215230 response ["body" ] = StreamingBody (BytesIO (response_body ), len (response_body ))
231+ response_headers = response ["ResponseMetadata" ]["HTTPHeaders" ]
232+
233+ if model .startswith ("amazon.titan-embed" ): # Only available embedding models
234+ handle_embedding_event (instance , transaction , extractor , model , response_body , response_headers , request_body , ft .duration )
235+ else :
236+ handle_chat_completion_event (instance , transaction , extractor , model , response_body , response_headers , request_body , ft .duration )
237+
238+ return response
216239
240+ def handle_embedding_event (client , transaction , extractor , model , response_body , response_headers , request_body , duration ):
241+ embedding_id = str (uuid .uuid4 ())
242+ available_metadata = get_trace_linking_metadata ()
243+ span_id = available_metadata .get ("span.id" , "" )
244+ trace_id = available_metadata .get ("trace.id" , "" )
245+
246+ request_id = response_headers .get ("x-amzn-requestid" , "" )
247+ settings = transaction .settings if transaction .settings is not None else global_settings ()
248+
249+ embedding_dict = extractor (request_body , response_body )
250+
251+ embedding_dict .update ({
252+ "vendor" : "bedrock" ,
253+ "ingest_source" : "Python" ,
254+ "id" : embedding_id ,
255+ "appName" : settings .app_name ,
256+ "span_id" : span_id ,
257+ "trace_id" : trace_id ,
258+ "request_id" : request_id ,
259+ "transaction_id" : transaction ._transaction_id ,
260+ "api_key_last_four_digits" : client ._request_signer ._credentials .access_key [- 4 :],
261+ "duration" : duration ,
262+ "request.model" : model ,
263+ "response.model" : model ,
264+ })
265+
266+ transaction .record_ml_event ("LlmEmbedding" , embedding_dict )
267+
268+
269+ def handle_chat_completion_event (client , transaction , extractor , model , response_body , response_headers , request_body , duration ):
217270 custom_attrs_dict = transaction ._custom_params
218271 conversation_id = custom_attrs_dict .get ("conversation_id" , "" )
219272
@@ -222,7 +275,6 @@ def wrap_bedrock_runtime_invoke_model(wrapped, instance, args, kwargs):
222275 span_id = available_metadata .get ("span.id" , "" )
223276 trace_id = available_metadata .get ("trace.id" , "" )
224277
225- response_headers = response ["ResponseMetadata" ]["HTTPHeaders" ]
226278 request_id = response_headers .get ("x-amzn-requestid" , "" )
227279 settings = transaction .settings if transaction .settings is not None else global_settings ()
228280
@@ -232,15 +284,15 @@ def wrap_bedrock_runtime_invoke_model(wrapped, instance, args, kwargs):
232284 {
233285 "vendor" : "bedrock" ,
234286 "ingest_source" : "Python" ,
235- "api_key_last_four_digits" : instance ._request_signer ._credentials .access_key [- 4 :],
287+ "api_key_last_four_digits" : client ._request_signer ._credentials .access_key [- 4 :],
236288 "id" : chat_completion_id ,
237289 "appName" : settings .app_name ,
238290 "conversation_id" : conversation_id ,
239291 "span_id" : span_id ,
240292 "trace_id" : trace_id ,
241293 "transaction_id" : transaction ._transaction_id ,
242294 "request_id" : request_id ,
243- "duration" : ft . duration ,
295+ "duration" : duration ,
244296 "request.model" : model ,
245297 "response.model" : model , # Duplicate data required by the UI
246298 }
@@ -261,8 +313,6 @@ def wrap_bedrock_runtime_invoke_model(wrapped, instance, args, kwargs):
261313 response_id = response_id ,
262314 )
263315
264- return response
265-
266316
267317CUSTOM_TRACE_POINTS = {
268318 ("sns" , "publish" ): message_trace ("SNS" , "Produce" , "Topic" , extract (("TopicArn" , "TargetArn" ), "PhoneNumber" )),
0 commit comments