Skip to content

Commit 6cef136

Browse files
committed
Avoid double counting tokens with explicit blocklist
1 parent 517c0b9 commit 6cef136

File tree

2 files changed: +89 additions, −48 deletions

sentry_sdk/integrations/langchain.py

Lines changed: 65 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ def count_tokens(s):
6060
"tags": SPANDATA.AI_TAGS,
6161
}
6262

63+
# To avoid double collecting tokens, we do *not* measure
64+
# token counts for models for which we have an explicit integration
65+
NO_COLLECT_TOKEN_MODELS = ["openai-chat"]
66+
6367

6468
class LangchainIntegration(Integration):
6569
identifier = "langchain"
@@ -82,6 +86,8 @@ class WatchedSpan:
8286
span = None # type: Span
8387
num_completion_tokens = 0 # type: int
8488
num_prompt_tokens = 0 # type: int
89+
no_collect_tokens = False # type: bool
90+
children = [] # type: List[WatchedSpan]
8591

8692
def __init__(self, span):
8793
# type: (Span) -> None
@@ -104,7 +110,8 @@ def gc_span_map(self):
104110
# type: () -> None
105111

106112
while len(self.span_map) > self.max_span_map_size:
107-
self.span_map.popitem(last=False)[1].span.__exit__(None, None, None)
113+
run_id, watched_span = self.span_map.popitem(last=False)
114+
self._exit_span(watched_span, run_id)
108115

109116
def _handle_error(self, run_id, error):
110117
# type: (UUID, Any) -> None
@@ -125,24 +132,30 @@ def _normalize_langchain_message(self, message):
125132
return parsed
126133

127134
def _create_span(self, run_id, parent_id, **kwargs):
128-
# type: (SentryLangchainCallback, UUID, Optional[Any], Any) -> Span
135+
# type: (SentryLangchainCallback, UUID, Optional[Any], Any) -> WatchedSpan
129136

130137
if "origin" not in kwargs:
131138
kwargs["origin"] = "auto.ai.langchain"
132139

133-
span = None # type: Optional[Span]
140+
watched_span = None # type: Optional[WatchedSpan]
134141
if parent_id:
135142
parent_span = self.span_map[parent_id] # type: Optional[WatchedSpan]
136143
if parent_span:
137-
span = parent_span.span.start_child(**kwargs)
138-
if span is None:
139-
span = sentry_sdk.start_span(**kwargs)
144+
watched_span = WatchedSpan(parent_span.span.start_child(**kwargs))
145+
parent_span.children.append(watched_span)
146+
if watched_span is None:
147+
watched_span = WatchedSpan(sentry_sdk.start_span(**kwargs))
140148

141-
span.__enter__()
142-
watched_span = WatchedSpan(span)
149+
watched_span.span.__enter__()
143150
self.span_map[run_id] = watched_span
144151
self.gc_span_map()
145-
return span
152+
return watched_span
153+
154+
def _exit_span(self, span_data, run_id):
155+
# type: (SentryLangchainCallback, WatchedSpan, UUID) -> None
156+
157+
span_data.span.__exit__(None, None, None)
158+
del self.span_map[run_id]
146159

147160
def on_llm_start(
148161
self,
@@ -162,12 +175,13 @@ def on_llm_start(
162175
return
163176
all_params = kwargs.get("invocation_params", {})
164177
all_params.update(serialized.get("kwargs", {}))
165-
span = self._create_span(
178+
watched_span = self._create_span(
166179
run_id,
167180
kwargs.get("parent_run_id"),
168181
op=OP.LANGCHAIN_RUN,
169182
description=kwargs.get("name") or "Langchain LLM call",
170183
)
184+
span = watched_span.span
171185
if should_send_default_pii() and self.include_prompts:
172186
set_data_normalized(span, SPANDATA.AI_INPUT_MESSAGES, prompts)
173187
for k, v in DATA_FIELDS.items():
@@ -182,15 +196,19 @@ def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
182196
return
183197
all_params = kwargs.get("invocation_params", {})
184198
all_params.update(serialized.get("kwargs", {}))
185-
span = self._create_span(
199+
watched_span = self._create_span(
186200
run_id,
187201
kwargs.get("parent_run_id"),
188202
op=OP.LANGCHAIN_CHAT_COMPLETIONS_CREATE,
189203
description=kwargs.get("name") or "Langchain Chat Model",
190204
)
205+
span = watched_span.span
191206
model = all_params.get(
192207
"model", all_params.get("model_name", all_params.get("model_id"))
193208
)
209+
watched_span.no_collect_tokens = any(
210+
x in all_params.get("_type", "") for x in NO_COLLECT_TOKEN_MODELS
211+
)
194212
if not model and "anthropic" in all_params.get("_type"):
195213
model = "claude-2"
196214
if model:
@@ -207,11 +225,12 @@ def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
207225
for k, v in DATA_FIELDS.items():
208226
if k in all_params:
209227
set_data_normalized(span, v, all_params[k])
210-
for list_ in messages:
211-
for message in list_:
212-
self.span_map[run_id].num_prompt_tokens += count_tokens(
213-
message.content
214-
) + count_tokens(message.type)
228+
if not watched_span.no_collect_tokens:
229+
for list_ in messages:
230+
for message in list_:
231+
self.span_map[run_id].num_prompt_tokens += count_tokens(
232+
message.content
233+
) + count_tokens(message.type)
215234

216235
def on_llm_new_token(self, token, *, run_id, **kwargs):
217236
# type: (SentryLangchainCallback, str, UUID, Any) -> Any
@@ -220,7 +239,7 @@ def on_llm_new_token(self, token, *, run_id, **kwargs):
220239
if not run_id or run_id not in self.span_map:
221240
return
222241
span_data = self.span_map[run_id]
223-
if not span_data:
242+
if not span_data or span_data.no_collect_tokens:
224243
return
225244
span_data.num_completion_tokens += count_tokens(token)
226245

@@ -246,22 +265,22 @@ def on_llm_end(self, response, *, run_id, **kwargs):
246265
[[x.text for x in list_] for list_ in response.generations],
247266
)
248267

249-
if token_usage:
250-
record_token_usage(
251-
span_data.span,
252-
token_usage.get("prompt_tokens"),
253-
token_usage.get("completion_tokens"),
254-
token_usage.get("total_tokens"),
255-
)
256-
else:
257-
record_token_usage(
258-
span_data.span,
259-
span_data.num_prompt_tokens,
260-
span_data.num_completion_tokens,
261-
)
268+
if not span_data.no_collect_tokens:
269+
if token_usage:
270+
record_token_usage(
271+
span_data.span,
272+
token_usage.get("prompt_tokens"),
273+
token_usage.get("completion_tokens"),
274+
token_usage.get("total_tokens"),
275+
)
276+
else:
277+
record_token_usage(
278+
span_data.span,
279+
span_data.num_prompt_tokens,
280+
span_data.num_completion_tokens,
281+
)
262282

263-
span_data.span.__exit__(None, None, None)
264-
del self.span_map[run_id]
283+
self._exit_span(span_data, run_id)
265284

266285
def on_llm_error(self, error, *, run_id, **kwargs):
267286
# type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Any) -> Any
@@ -275,7 +294,7 @@ def on_chain_start(self, serialized, inputs, *, run_id, **kwargs):
275294
with capture_internal_exceptions():
276295
if not run_id:
277296
return
278-
span = self._create_span(
297+
watched_span = self._create_span(
279298
run_id,
280299
kwargs.get("parent_run_id"),
281300
op=(
@@ -287,7 +306,7 @@ def on_chain_start(self, serialized, inputs, *, run_id, **kwargs):
287306
)
288307
metadata = kwargs.get("metadata")
289308
if metadata:
290-
set_data_normalized(span, SPANDATA.AI_METADATA, metadata)
309+
set_data_normalized(watched_span.span, SPANDATA.AI_METADATA, metadata)
291310

292311
def on_chain_end(self, outputs, *, run_id, **kwargs):
293312
# type: (SentryLangchainCallback, Dict[str, Any], UUID, Any) -> Any
@@ -299,8 +318,7 @@ def on_chain_end(self, outputs, *, run_id, **kwargs):
299318
span_data = self.span_map[run_id]
300319
if not span_data:
301320
return
302-
span_data.span.__exit__(None, None, None)
303-
del self.span_map[run_id]
321+
self._exit_span(span_data, run_id)
304322

305323
def on_chain_error(self, error, *, run_id, **kwargs):
306324
# type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Any) -> Any
@@ -312,14 +330,16 @@ def on_agent_action(self, action, *, run_id, **kwargs):
312330
with capture_internal_exceptions():
313331
if not run_id:
314332
return
315-
span = self._create_span(
333+
watched_span = self._create_span(
316334
run_id,
317335
kwargs.get("parent_run_id"),
318336
op=OP.LANGCHAIN_AGENT,
319337
description=action.tool or "AI tool usage",
320338
)
321339
if action.tool_input and should_send_default_pii() and self.include_prompts:
322-
set_data_normalized(span, SPANDATA.AI_INPUT_MESSAGES, action.tool_input)
340+
set_data_normalized(
341+
watched_span.span, SPANDATA.AI_INPUT_MESSAGES, action.tool_input
342+
)
323343

324344
def on_agent_finish(self, finish, *, run_id, **kwargs):
325345
# type: (SentryLangchainCallback, AgentFinish, UUID, Any) -> Any
@@ -334,16 +354,15 @@ def on_agent_finish(self, finish, *, run_id, **kwargs):
334354
set_data_normalized(
335355
span_data.span, SPANDATA.AI_RESPONSES, finish.return_values.items()
336356
)
337-
span_data.span.__exit__(None, None, None)
338-
del self.span_map[run_id]
357+
self._exit_span(span_data, run_id)
339358

340359
def on_tool_start(self, serialized, input_str, *, run_id, **kwargs):
341360
# type: (SentryLangchainCallback, Dict[str, Any], str, UUID, Any) -> Any
342361
"""Run when tool starts running."""
343362
with capture_internal_exceptions():
344363
if not run_id:
345364
return
346-
span = self._create_span(
365+
watched_span = self._create_span(
347366
run_id,
348367
kwargs.get("parent_run_id"),
349368
op=OP.LANGCHAIN_TOOL,
@@ -353,11 +372,13 @@ def on_tool_start(self, serialized, input_str, *, run_id, **kwargs):
353372
)
354373
if should_send_default_pii() and self.include_prompts:
355374
set_data_normalized(
356-
span, SPANDATA.AI_INPUT_MESSAGES, kwargs.get("inputs", [input_str])
375+
watched_span.span,
376+
SPANDATA.AI_INPUT_MESSAGES,
377+
kwargs.get("inputs", [input_str]),
357378
)
358379
if kwargs.get("metadata"):
359380
set_data_normalized(
360-
span, SPANDATA.AI_METADATA, kwargs.get("metadata")
381+
watched_span.span, SPANDATA.AI_METADATA, kwargs.get("metadata")
361382
)
362383

363384
def on_tool_end(self, output, *, run_id, **kwargs):
@@ -372,8 +393,7 @@ def on_tool_end(self, output, *, run_id, **kwargs):
372393
return
373394
if should_send_default_pii() and self.include_prompts:
374395
set_data_normalized(span_data.span, SPANDATA.AI_RESPONSES, output)
375-
span_data.span.__exit__(None, None, None)
376-
del self.span_map[run_id]
396+
self._exit_span(span_data, run_id)
377397

378398
def on_tool_error(self, error, *args, run_id, **kwargs):
379399
# type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Any) -> Any

tests/integrations/langchain/test_langchain.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def get_word_length(word: str) -> int:
2020

2121

2222
global stream_result_mock # type: Mock
23+
global llm_type # type: str
2324

2425

2526
class MockOpenAI(ChatOpenAI):
@@ -33,14 +34,26 @@ def _stream(
3334
for x in stream_result_mock():
3435
yield x
3536

37+
@property
38+
def _llm_type(self) -> str:
39+
return llm_type
40+
3641

3742
@pytest.mark.parametrize(
38-
"send_default_pii, include_prompts",
39-
[(True, True), (True, False), (False, True), (False, False)],
43+
"send_default_pii, include_prompts, use_unknown_llm_type",
44+
[
45+
(True, True, False),
46+
(True, False, False),
47+
(False, True, False),
48+
(False, False, True),
49+
],
4050
)
4151
def test_langchain_agent(
42-
sentry_init, capture_events, send_default_pii, include_prompts
52+
sentry_init, capture_events, send_default_pii, include_prompts, use_unknown_llm_type
4353
):
54+
global llm_type
55+
llm_type = "acme-llm" if use_unknown_llm_type else "openai-chat"
56+
4457
sentry_init(
4558
integrations=[LangchainIntegration(include_prompts=include_prompts)],
4659
traces_sample_rate=1.0,
@@ -144,6 +157,14 @@ def test_langchain_agent(
144157
# We can't guarantee anything about the "shape" of the langchain execution graph
145158
assert len(list(x for x in tx["spans"] if x["op"] == "ai.run.langchain")) > 0
146159

160+
if use_unknown_llm_type:
161+
assert "ai_prompt_tokens_used" in chat_spans[0]["measurements"]
162+
assert "ai_total_tokens_used" in chat_spans[0]["measurements"]
163+
else:
164+
# important: to avoid double counting, we do *not* measure
165+
# tokens used if we have an explicit integration (e.g. OpenAI)
166+
assert "measurements" not in chat_spans[0]
167+
147168
if send_default_pii and include_prompts:
148169
assert (
149170
"You are very powerful"

Comments (0) — no commit comments