
Commit a87ceae

AWS Bedrock embedding instrumentation
1 parent 989e38c commit a87ceae

File tree

6 files changed: +3366 -32 lines changed

newrelic/hooks/external_botocore.py

Lines changed: 63 additions & 13 deletions
@@ -95,12 +95,12 @@ def create_chat_completion_message_event(
        transaction.record_ml_event("LlmChatCompletionMessage", chat_completion_message_dict)


-def extract_bedrock_titan_model(request_body, response_body):
+def extract_bedrock_titan_text_model(request_body, response_body):
    response_body = json.loads(response_body)
    request_body = json.loads(request_body)

    input_tokens = response_body["inputTextTokenCount"]
-    completion_tokens = sum(result["tokenCount"] for result in response_body["results"])
+    completion_tokens = sum(result["tokenCount"] for result in response_body.get("results", []))
    total_tokens = input_tokens + completion_tokens

    request_config = request_body.get("textGenerationConfig", {})
@@ -121,6 +121,20 @@ def extract_bedrock_titan_model(request_body, response_body):
    return message_list, chat_completion_summary_dict


+def extract_bedrock_titan_embedding_model(request_body, response_body):
+    response_body = json.loads(response_body)
+    request_body = json.loads(request_body)
+
+    input_tokens = response_body["inputTextTokenCount"]
+
+    embedding_dict = {
+        "input": request_body.get("inputText", ""),
+        "response.usage.prompt_tokens": input_tokens,
+        "response.usage.total_tokens": input_tokens,
+    }
+    return embedding_dict
+
+
def extract_bedrock_ai21_j2_model(request_body, response_body):
    response_body = json.loads(response_body)
    request_body = json.loads(request_body)
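For illustration, a minimal, self-contained sketch of how the new embedding extractor behaves. The payloads below are invented and carry only the fields the extractor actually reads ("inputText" on the request and "inputTextTokenCount" on the response):

import json

# Mirrors extract_bedrock_titan_embedding_model() added above, inlined here so
# the sketch runs without the agent installed.
def titan_embedding_extractor(request_body, response_body):
    response_body = json.loads(response_body)
    request_body = json.loads(request_body)

    input_tokens = response_body["inputTextTokenCount"]
    return {
        "input": request_body.get("inputText", ""),
        "response.usage.prompt_tokens": input_tokens,
        "response.usage.total_tokens": input_tokens,
    }

# Hypothetical Titan embedding payloads, shaped the way the extractor expects.
request_body = json.dumps({"inputText": "Hello world"})
response_body = json.dumps({"inputTextTokenCount": 2, "embedding": [0.12, -0.03]})

print(titan_embedding_extractor(request_body, response_body))
# {'input': 'Hello world', 'response.usage.prompt_tokens': 2, 'response.usage.total_tokens': 2}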
@@ -159,11 +173,12 @@ def extract_bedrock_cohere_model(request_body, response_body):
    return message_list, chat_completion_summary_dict


-MODEL_EXTRACTORS = {
-    "amazon.titan": extract_bedrock_titan_model,
-    "ai21.j2": extract_bedrock_ai21_j2_model,
-    "cohere": extract_bedrock_cohere_model,
-}
+MODEL_EXTRACTORS = [  # Order is important here, avoiding dictionaries
+    ("amazon.titan-embed", extract_bedrock_titan_embedding_model),
+    ("amazon.titan", extract_bedrock_titan_text_model),
+    ("ai21.j2", extract_bedrock_ai21_j2_model),
+    ("cohere", extract_bedrock_cohere_model),
+]


@function_wrapper
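A note on the "Order is important" comment: extractors are matched by model-ID prefix, so the more specific "amazon.titan-embed" entry has to be tried before the generic "amazon.titan" entry. A standalone sketch of that lookup (the model IDs are illustrative):

# First matching prefix wins, so specific prefixes must precede generic ones.
EXTRACTORS = [
    ("amazon.titan-embed", "embedding extractor"),
    ("amazon.titan", "text extractor"),
    ("ai21.j2", "ai21 extractor"),
    ("cohere", "cohere extractor"),
]

def pick_extractor(model_id):
    for prefix, extractor in EXTRACTORS:
        if model_id.startswith(prefix):
            return extractor
    return None  # no extractor registered for this model

assert pick_extractor("amazon.titan-embed-text-v1") == "embedding extractor"
assert pick_extractor("amazon.titan-text-express-v1") == "text extractor"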
@@ -194,7 +209,7 @@ def wrap_bedrock_runtime_invoke_model(wrapped, instance, args, kwargs):
        return response

    # Determine extractor by model type
-    for extractor_name, extractor in MODEL_EXTRACTORS.items():
+    for extractor_name, extractor in MODEL_EXTRACTORS:
        if model.startswith(extractor_name):
            break
    else:
@@ -213,7 +228,45 @@ def wrap_bedrock_runtime_invoke_model(wrapped, instance, args, kwargs):
    # Read and replace response streaming bodies
    response_body = response["body"].read()
    response["body"] = StreamingBody(BytesIO(response_body), len(response_body))
+    response_headers = response["ResponseMetadata"]["HTTPHeaders"]
+
+    if model.startswith("amazon.titan-embed"):  # Only available embedding models
+        handle_embedding_event(instance, transaction, extractor, model, response_body, response_headers, request_body, ft.duration)
+    else:
+        handle_chat_completion_event(instance, transaction, extractor, model, response_body, response_headers, request_body, ft.duration)
+
+    return response

+def handle_embedding_event(client, transaction, extractor, model, response_body, response_headers, request_body, duration):
+    embedding_id = str(uuid.uuid4())
+    available_metadata = get_trace_linking_metadata()
+    span_id = available_metadata.get("span.id", "")
+    trace_id = available_metadata.get("trace.id", "")
+
+    request_id = response_headers.get("x-amzn-requestid", "")
+    settings = transaction.settings if transaction.settings is not None else global_settings()
+
+    embedding_dict = extractor(request_body, response_body)
+
+    embedding_dict.update({
+        "vendor": "bedrock",
+        "ingest_source": "Python",
+        "id": embedding_id,
+        "appName": settings.app_name,
+        "span_id": span_id,
+        "trace_id": trace_id,
+        "request_id": request_id,
+        "transaction_id": transaction._transaction_id,
+        "api_key_last_four_digits": client._request_signer._credentials.access_key[-4:],
+        "duration": duration,
+        "request.model": model,
+        "response.model": model,
+    })
+
+    transaction.record_ml_event("LlmEmbedding", embedding_dict)
+
+
+def handle_chat_completion_event(client, transaction, extractor, model, response_body, response_headers, request_body, duration):
    custom_attrs_dict = transaction._custom_params
    conversation_id = custom_attrs_dict.get("conversation_id", "")
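For reference, a rough sketch of the "LlmEmbedding" event that handle_embedding_event() records. Every value below is invented; the keys mirror the extractor output plus the update() call above:

llm_embedding_event = {
    # From the extractor:
    "input": "Hello world",
    "response.usage.prompt_tokens": 2,
    "response.usage.total_tokens": 2,
    # Added by handle_embedding_event():
    "vendor": "bedrock",
    "ingest_source": "Python",
    "id": "0f1e2d3c-4b5a-6978-8796-a5b4c3d2e1f0",  # uuid4
    "appName": "Example Application",
    "span_id": "span id from trace linking metadata",
    "trace_id": "trace id from trace linking metadata",
    "request_id": "value of the x-amzn-requestid response header",
    "transaction_id": "current transaction id",
    "api_key_last_four_digits": "MPLE",  # last 4 characters of the AWS access key
    "duration": 0.123,  # ft.duration of the wrapped invoke_model call
    "request.model": "amazon.titan-embed-text-v1",
    "response.model": "amazon.titan-embed-text-v1",
}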

@@ -222,7 +275,6 @@ def wrap_bedrock_runtime_invoke_model(wrapped, instance, args, kwargs):
    span_id = available_metadata.get("span.id", "")
    trace_id = available_metadata.get("trace.id", "")

-    response_headers = response["ResponseMetadata"]["HTTPHeaders"]
    request_id = response_headers.get("x-amzn-requestid", "")
    settings = transaction.settings if transaction.settings is not None else global_settings()

@@ -232,15 +284,15 @@ def wrap_bedrock_runtime_invoke_model(wrapped, instance, args, kwargs):
        {
            "vendor": "bedrock",
            "ingest_source": "Python",
-            "api_key_last_four_digits": instance._request_signer._credentials.access_key[-4:],
+            "api_key_last_four_digits": client._request_signer._credentials.access_key[-4:],
            "id": chat_completion_id,
            "appName": settings.app_name,
            "conversation_id": conversation_id,
            "span_id": span_id,
            "trace_id": trace_id,
            "transaction_id": transaction._transaction_id,
            "request_id": request_id,
-            "duration": ft.duration,
+            "duration": duration,
            "request.model": model,
            "response.model": model,  # Duplicate data required by the UI
        }
@@ -261,8 +313,6 @@ def wrap_bedrock_runtime_invoke_model(wrapped, instance, args, kwargs):
        response_id=response_id,
    )

-    return response
-

CUSTOM_TRACE_POINTS = {
    ("sns", "publish"): message_trace("SNS", "Produce", "Topic", extract(("TopicArn", "TargetArn"), "PhoneNumber")),
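Finally, a hedged usage sketch of the call path this commit instruments. The model ID, region, and prompt are illustrative, and background_task() is just one way to have an active transaction for the ML events to attach to:

import json

import boto3
from newrelic.agent import background_task

@background_task(name="titan-embedding-example")
def embed(text):
    # botocore's InvokeModel call is what wrap_bedrock_runtime_invoke_model hooks.
    client = boto3.client("bedrock-runtime", region_name="us-east-1")
    response = client.invoke_model(
        modelId="amazon.titan-embed-text-v1",  # illustrative embedding model ID
        body=json.dumps({"inputText": text}),
        contentType="application/json",
        accept="application/json",
    )
    # The instrumentation reads (and restores) the streaming body, then records
    # an LlmEmbedding event because the model ID starts with "amazon.titan-embed".
    return json.loads(response["body"].read())

if __name__ == "__main__":
    embed("What is observability?")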
