8 changes: 6 additions & 2 deletions newrelic/agent.py
@@ -153,7 +153,10 @@ def __asgi_application(*args, **kwargs):
from newrelic.api.message_transaction import (
wrap_message_transaction as __wrap_message_transaction,
)
from newrelic.api.ml_model import get_ai_message_ids as __get_ai_message_ids
from newrelic.api.ml_model import get_llm_message_ids as __get_llm_message_ids
from newrelic.api.ml_model import (
record_llm_feedback_event as __record_llm_feedback_event,
)
from newrelic.api.ml_model import wrap_mlmodel as __wrap_mlmodel
from newrelic.api.profile_trace import ProfileTraceWrapper as __ProfileTraceWrapper
from newrelic.api.profile_trace import profile_trace as __profile_trace
@@ -341,4 +344,5 @@ def __asgi_application(*args, **kwargs):
insert_html_snippet = __wrap_api_call(__insert_html_snippet, "insert_html_snippet")
verify_body_exists = __wrap_api_call(__verify_body_exists, "verify_body_exists")
wrap_mlmodel = __wrap_api_call(__wrap_mlmodel, "wrap_mlmodel")
get_ai_message_ids = __wrap_api_call(__get_ai_message_ids, "get_ai_message_ids")
get_llm_message_ids = __wrap_api_call(__get_llm_message_ids, "get_llm_message_ids")
record_llm_feedback_event = __wrap_api_call(__record_llm_feedback_event, "record_llm_feedback_event")
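With these re-exports in place, both the renamed and the new ml_model APIs become reachable from the public newrelic.agent namespace, alongside wrap_mlmodel. A minimal usage sketch (not part of the diff; assumes an initialized agent and an active transaction, and every id below is a placeholder):

```python
# Sketch only: both APIs are now importable from the top-level agent module.
import newrelic.agent

# Returns [] with a warning when called outside a transaction.
ids = newrelic.agent.get_llm_message_ids(response_id="resp-123")
newrelic.agent.record_llm_feedback_event(message_id="msg-123", rating=1)
```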
34 changes: 32 additions & 2 deletions newrelic/api/ml_model.py
@@ -13,6 +13,7 @@
# limitations under the License.

import sys
import uuid
import warnings

from newrelic.api.transaction import current_transaction
@@ -37,7 +38,7 @@ def wrap_mlmodel(model, name=None, version=None, feature_names=None, label_names
model._nr_wrapped_metadata = metadata


def get_ai_message_ids(response_id=None):
def get_llm_message_ids(response_id=None):
transaction = current_transaction()
if response_id and transaction:
nr_message_ids = getattr(transaction, "_nr_message_ids", {})
@@ -50,5 +51,34 @@ def get_ai_message_ids(response_id=None):
conversation_id, request_id, ids = message_id_info

return [{"conversation_id": conversation_id, "request_id": request_id, "message_id": _id} for _id in ids]
warnings.warn("No message ids found. get_ai_message_ids must be called within the scope of a transaction.")
warnings.warn("No message ids found. get_llm_message_ids must be called within the scope of a transaction.")
return []


def record_llm_feedback_event(
message_id, rating, conversation_id=None, request_id=None, category=None, message=None, metadata=None
):
transaction = current_transaction()
if not transaction:
warnings.warn(
"No message feedback events will be recorded. record_llm_feedback_event must be called within the "
"scope of a transaction."
)
return

feedback_message_id = str(uuid.uuid4())
metadata = metadata or {}

feedback_message_event = {
"id": feedback_message_id,
"message_id": message_id,
"rating": rating,
"conversation_id": conversation_id or "",
"request_id": request_id or "",
"category": category or "",
"message": message or "",
"ingest_source": "Python",
}
feedback_message_event.update(metadata)

transaction.record_ml_event("LlmFeedbackMessage", feedback_message_event)
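As the new function shows, the event id is a freshly generated UUID, every optional field defaults to an empty string, and caller-supplied metadata is merged last via dict.update(), so metadata keys can override the defaults. A hedged usage sketch, assuming a transaction started via background_task; the ids and metadata are placeholders:

```python
from newrelic.api.background_task import background_task
from newrelic.api.ml_model import record_llm_feedback_event

@background_task()
def submit_feedback():
    # In practice message_id comes from get_llm_message_ids(); the tests
    # pass both numeric and string ratings (1, "Good").
    record_llm_feedback_event(
        message_id="message-id-0",
        rating=1,
        category="informative",
        metadata={"foo": "bar"},  # merged into the LlmFeedbackMessage event
    )

submit_feedback()
```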
95 changes: 95 additions & 0 deletions tests/agent_features/test_record_llm_feedback_event.py
@@ -0,0 +1,95 @@
# Copyright 2010 New Relic, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from testing_support.fixtures import reset_core_stats_engine
from testing_support.validators.validate_ml_event_count import validate_ml_event_count
from testing_support.validators.validate_ml_events import validate_ml_events

from newrelic.api.background_task import background_task
from newrelic.api.ml_model import record_llm_feedback_event


@reset_core_stats_engine()
def test_record_llm_feedback_event_all_args_supplied():
llm_feedback_all_args_recorded_events = [
(
{"type": "LlmFeedbackMessage"},
{
"id": None,
"category": "informative",
"rating": 1,
"message_id": "message_id",
"request_id": "request_id",
"conversation_id": "conversation_id",
"ingest_source": "Python",
"message": "message",
"foo": "bar",
},
),
]

@validate_ml_events(llm_feedback_all_args_recorded_events)
@background_task()
def _test():
record_llm_feedback_event(
rating=1,
message_id="message_id",
category="informative",
request_id="request_id",
conversation_id="conversation_id",
message="message",
metadata={"foo": "bar"},
)

_test()


@reset_core_stats_engine()
def test_record_llm_feedback_event_required_args_supplied():
llm_feedback_required_args_recorded_events = [
(
{"type": "LlmFeedbackMessage"},
{
"id": None,
"category": "",
"rating": "Good",
"message_id": "message_id",
"request_id": "",
"conversation_id": "",
"ingest_source": "Python",
"message": "",
},
),
]

@validate_ml_events(llm_feedback_required_args_recorded_events)
@background_task()
def _test():
record_llm_feedback_event(message_id="message_id", rating="Good")

_test()


@reset_core_stats_engine()
@validate_ml_event_count(count=0)
def test_record_llm_feedback_event_outside_txn():
record_llm_feedback_event(
rating="Good",
message_id="message_id",
category="informative",
request_id="request_id",
conversation_id="conversation_id",
message="message",
metadata={"foo": "bar"},
)
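The OpenAI integration tests that follow exercise the intended end-to-end flow: fetch the per-message ids for a chat completion response, then record feedback against each one. A condensed sketch of that flow under the openai 0.x API (the model and prompt are illustrative):

```python
import openai
from newrelic.api.background_task import background_task
from newrelic.api.ml_model import get_llm_message_ids, record_llm_feedback_event

@background_task()
def rate_completion():
    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello!"}],
    )
    # One dict per message, carrying conversation_id, request_id, message_id.
    for ids in get_llm_message_ids(response.id):
        record_llm_feedback_event(
            message_id=ids["message_id"],
            request_id=ids["request_id"],
            conversation_id=ids["conversation_id"],
            rating=1,
        )
```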
@@ -14,9 +14,10 @@

import openai
from testing_support.fixtures import reset_core_stats_engine
from testing_support.validators.validate_ml_event_count import validate_ml_event_count

from newrelic.api.background_task import background_task
from newrelic.api.ml_model import get_ai_message_ids
from newrelic.api.ml_model import get_llm_message_ids, record_llm_feedback_event
from newrelic.api.transaction import add_custom_attribute, current_transaction

_test_openai_chat_completion_messages_1 = (
@@ -100,20 +101,20 @@

@reset_core_stats_engine()
@background_task()
def test_get_ai_message_ids_when_nr_message_ids_not_set():
message_ids = get_ai_message_ids("request-id-1")
def test_get_llm_message_ids_when_nr_message_ids_not_set():
message_ids = get_llm_message_ids("request-id-1")
assert message_ids == []


@reset_core_stats_engine()
def test_get_ai_message_ids_outside_transaction():
message_ids = get_ai_message_ids("request-id-1")
def test_get_llm_message_ids_outside_transaction():
message_ids = get_llm_message_ids("request-id-1")
assert message_ids == []


@reset_core_stats_engine()
@background_task()
def test_get_ai_message_ids_multiple_async(loop, set_trace_info):
def test_get_llm_message_ids_multiple_async(loop, set_trace_info):
set_trace_info()
add_custom_attribute("conversation_id", "my-awesome-id")

@@ -128,10 +129,10 @@ async def _run():

results = loop.run_until_complete(_run())

message_ids = [m for m in get_ai_message_ids(results[0].id)]
message_ids = [m for m in get_llm_message_ids(results[0].id)]
assert message_ids == expected_message_ids_1

message_ids = [m for m in get_ai_message_ids(results[1].id)]
message_ids = [m for m in get_llm_message_ids(results[1].id)]
assert message_ids == expected_message_ids_2

# Make sure we aren't causing a memory leak.
@@ -141,7 +142,7 @@ async def _run():

@reset_core_stats_engine()
@background_task()
def test_get_ai_message_ids_multiple_async_no_conversation_id(loop, set_trace_info):
def test_get_llm_message_ids_multiple_async_no_conversation_id(loop, set_trace_info):
set_trace_info()

async def _run():
@@ -155,10 +156,10 @@ async def _run():

results = loop.run_until_complete(_run())

message_ids = [m for m in get_ai_message_ids(results[0].id)]
message_ids = [m for m in get_llm_message_ids(results[0].id)]
assert message_ids == expected_message_ids_1_no_conversation_id

message_ids = [m for m in get_ai_message_ids(results[1].id)]
message_ids = [m for m in get_llm_message_ids(results[1].id)]
assert message_ids == expected_message_ids_2_no_conversation_id

# Make sure we aren't causing a memory leak.
@@ -167,21 +168,33 @@ async def _run():


@reset_core_stats_engine()
# Three chat completion messages and one chat completion summary for each create call (8 in total)
# Three feedback events for the first create call
@validate_ml_event_count(11)
@background_task()
def test_get_ai_message_ids_multiple_sync(set_trace_info):
def test_get_llm_message_ids_multiple_sync(set_trace_info):
set_trace_info()
add_custom_attribute("conversation_id", "my-awesome-id")

results = openai.ChatCompletion.create(
model="gpt-3.5-turbo", messages=_test_openai_chat_completion_messages_1, temperature=0.7, max_tokens=100
)
message_ids = [m for m in get_ai_message_ids(results.id)]
message_ids = [m for m in get_llm_message_ids(results.id)]
assert message_ids == expected_message_ids_1

for message_id in message_ids:
record_llm_feedback_event(
category="informative",
rating=1,
message_id=message_id.get("message_id"),
request_id=message_id.get("request_id"),
conversation_id=message_id.get("conversation_id"),
)

results = openai.ChatCompletion.create(
model="gpt-3.5-turbo", messages=_test_openai_chat_completion_messages_2, temperature=0.7, max_tokens=100
)
message_ids = [m for m in get_ai_message_ids(results.id)]
message_ids = [m for m in get_llm_message_ids(results.id)]
assert message_ids == expected_message_ids_2

# Make sure we aren't causing a memory leak.
@@ -190,20 +203,30 @@ def test_get_ai_message_ids_multiple_sync(set_trace_info):


@reset_core_stats_engine()
@validate_ml_event_count(11)
@background_task()
def test_get_ai_message_ids_multiple_sync_no_conversation_id(set_trace_info):
def test_get_llm_message_ids_multiple_sync_no_conversation_id(set_trace_info):
set_trace_info()

results = openai.ChatCompletion.create(
model="gpt-3.5-turbo", messages=_test_openai_chat_completion_messages_1, temperature=0.7, max_tokens=100
)
message_ids = [m for m in get_ai_message_ids(results.id)]
message_ids = [m for m in get_llm_message_ids(results.id)]
assert message_ids == expected_message_ids_1_no_conversation_id

for message_id in message_ids:
record_llm_feedback_event(
category="informative",
rating=1,
message_id=message_id.get("message_id"),
request_id=message_id.get("request_id"),
conversation_id=message_id.get("conversation_id"),
)

results = openai.ChatCompletion.create(
model="gpt-3.5-turbo", messages=_test_openai_chat_completion_messages_2, temperature=0.7, max_tokens=100
)
message_ids = [m for m in get_ai_message_ids(results.id)]
message_ids = [m for m in get_llm_message_ids(results.id)]
assert message_ids == expected_message_ids_2_no_conversation_id

# Make sure we aren't causing a memory leak.