@@ -60,6 +60,10 @@ def count_tokens(s):
60
60
"tags" : SPANDATA .AI_TAGS ,
61
61
}
62
62
63
# Models whose tokens are already counted by a dedicated Sentry
# integration; we skip manual token counting for these so usage
# is not recorded twice.
NO_COLLECT_TOKEN_MODELS = ["openai-chat"]
63
67
64
68
class LangchainIntegration (Integration ):
65
69
identifier = "langchain"
@@ -82,6 +86,8 @@ class WatchedSpan:
82
86
span = None # type: Span
83
87
num_completion_tokens = 0 # type: int
84
88
num_prompt_tokens = 0 # type: int
89
+ no_collect_tokens = False # type: bool
90
+ children = [] # type: List[WatchedSpan]
85
91
86
92
def __init__ (self , span ):
87
93
# type: (Span) -> None
@@ -104,7 +110,8 @@ def gc_span_map(self):
104
110
# type: () -> None
105
111
106
112
while len (self .span_map ) > self .max_span_map_size :
107
- self .span_map .popitem (last = False )[1 ].span .__exit__ (None , None , None )
113
+ run_id , watched_span = self .span_map .popitem (last = False )
114
+ self ._exit_span (watched_span , run_id )
108
115
109
116
def _handle_error (self , run_id , error ):
110
117
# type: (UUID, Any) -> None
@@ -125,24 +132,30 @@ def _normalize_langchain_message(self, message):
125
132
return parsed
126
133
127
134
def _create_span(self, run_id, parent_id, **kwargs):
    # type: (SentryLangchainCallback, UUID, Optional[Any], Any) -> WatchedSpan
    """Start and register a WatchedSpan for a langchain run.

    If ``parent_id`` refers to a run we are already watching, the new
    span is started as a child of that run's span and recorded in the
    parent's ``children`` list; otherwise a fresh top-level span is
    started.  The span is entered, stored in ``self.span_map`` under
    ``run_id``, and the map is garbage-collected so it cannot grow
    without bound.  Returns the new WatchedSpan.
    """
    if "origin" not in kwargs:
        kwargs["origin"] = "auto.ai.langchain"

    watched_span = None  # type: Optional[WatchedSpan]
    if parent_id:
        # Use .get(), not [] — the parent entry may already have been
        # evicted by gc_span_map (and the declared type is Optional);
        # a plain index would raise KeyError here.
        parent_span = self.span_map.get(parent_id)  # type: Optional[WatchedSpan]
        if parent_span:
            watched_span = WatchedSpan(parent_span.span.start_child(**kwargs))
            parent_span.children.append(watched_span)
    if watched_span is None:
        watched_span = WatchedSpan(sentry_sdk.start_span(**kwargs))

    watched_span.span.__enter__()
    self.span_map[run_id] = watched_span
    self.gc_span_map()
    return watched_span
153
+
154
def _exit_span(self, span_data, run_id):
    # type: (SentryLangchainCallback, WatchedSpan, UUID) -> None
    """Finish a watched span and drop its entry from the run-id map.

    Callers are not consistent about map state: the on_* handlers call
    this while ``run_id`` is still in ``self.span_map``, but
    ``gc_span_map`` calls it *after* ``popitem`` has already removed
    the entry — so the removal must tolerate a missing key.
    """
    span_data.span.__exit__(None, None, None)
    # pop with a default instead of `del`: avoids a KeyError when the
    # entry was already evicted by gc_span_map's popitem(last=False).
    self.span_map.pop(run_id, None)
146
159
147
160
def on_llm_start (
148
161
self ,
@@ -162,12 +175,13 @@ def on_llm_start(
162
175
return
163
176
all_params = kwargs .get ("invocation_params" , {})
164
177
all_params .update (serialized .get ("kwargs" , {}))
165
- span = self ._create_span (
178
+ watched_span = self ._create_span (
166
179
run_id ,
167
180
kwargs .get ("parent_run_id" ),
168
181
op = OP .LANGCHAIN_RUN ,
169
182
description = kwargs .get ("name" ) or "Langchain LLM call" ,
170
183
)
184
+ span = watched_span .span
171
185
if should_send_default_pii () and self .include_prompts :
172
186
set_data_normalized (span , SPANDATA .AI_INPUT_MESSAGES , prompts )
173
187
for k , v in DATA_FIELDS .items ():
@@ -182,15 +196,19 @@ def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
182
196
return
183
197
all_params = kwargs .get ("invocation_params" , {})
184
198
all_params .update (serialized .get ("kwargs" , {}))
185
- span = self ._create_span (
199
+ watched_span = self ._create_span (
186
200
run_id ,
187
201
kwargs .get ("parent_run_id" ),
188
202
op = OP .LANGCHAIN_CHAT_COMPLETIONS_CREATE ,
189
203
description = kwargs .get ("name" ) or "Langchain Chat Model" ,
190
204
)
205
+ span = watched_span .span
191
206
model = all_params .get (
192
207
"model" , all_params .get ("model_name" , all_params .get ("model_id" ))
193
208
)
209
+ watched_span .no_collect_tokens = any (
210
+ x in all_params .get ("_type" , "" ) for x in NO_COLLECT_TOKEN_MODELS
211
+ )
194
212
if not model and "anthropic" in all_params .get ("_type" ):
195
213
model = "claude-2"
196
214
if model :
@@ -207,11 +225,12 @@ def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
207
225
for k , v in DATA_FIELDS .items ():
208
226
if k in all_params :
209
227
set_data_normalized (span , v , all_params [k ])
210
- for list_ in messages :
211
- for message in list_ :
212
- self .span_map [run_id ].num_prompt_tokens += count_tokens (
213
- message .content
214
- ) + count_tokens (message .type )
228
+ if not watched_span .no_collect_tokens :
229
+ for list_ in messages :
230
+ for message in list_ :
231
+ self .span_map [run_id ].num_prompt_tokens += count_tokens (
232
+ message .content
233
+ ) + count_tokens (message .type )
215
234
216
235
def on_llm_new_token (self , token , * , run_id , ** kwargs ):
217
236
# type: (SentryLangchainCallback, str, UUID, Any) -> Any
@@ -220,7 +239,7 @@ def on_llm_new_token(self, token, *, run_id, **kwargs):
220
239
if not run_id or run_id not in self .span_map :
221
240
return
222
241
span_data = self .span_map [run_id ]
223
- if not span_data :
242
+ if not span_data or span_data . no_collect_tokens :
224
243
return
225
244
span_data .num_completion_tokens += count_tokens (token )
226
245
@@ -246,22 +265,22 @@ def on_llm_end(self, response, *, run_id, **kwargs):
246
265
[[x .text for x in list_ ] for list_ in response .generations ],
247
266
)
248
267
249
- if token_usage :
250
- record_token_usage (
251
- span_data .span ,
252
- token_usage .get ("prompt_tokens" ),
253
- token_usage .get ("completion_tokens" ),
254
- token_usage .get ("total_tokens" ),
255
- )
256
- else :
257
- record_token_usage (
258
- span_data .span ,
259
- span_data .num_prompt_tokens ,
260
- span_data .num_completion_tokens ,
261
- )
268
+ if not span_data .no_collect_tokens :
269
+ if token_usage :
270
+ record_token_usage (
271
+ span_data .span ,
272
+ token_usage .get ("prompt_tokens" ),
273
+ token_usage .get ("completion_tokens" ),
274
+ token_usage .get ("total_tokens" ),
275
+ )
276
+ else :
277
+ record_token_usage (
278
+ span_data .span ,
279
+ span_data .num_prompt_tokens ,
280
+ span_data .num_completion_tokens ,
281
+ )
262
282
263
- span_data .span .__exit__ (None , None , None )
264
- del self .span_map [run_id ]
283
+ self ._exit_span (span_data , run_id )
265
284
266
285
def on_llm_error (self , error , * , run_id , ** kwargs ):
267
286
# type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Any) -> Any
@@ -275,7 +294,7 @@ def on_chain_start(self, serialized, inputs, *, run_id, **kwargs):
275
294
with capture_internal_exceptions ():
276
295
if not run_id :
277
296
return
278
- span = self ._create_span (
297
+ watched_span = self ._create_span (
279
298
run_id ,
280
299
kwargs .get ("parent_run_id" ),
281
300
op = (
@@ -287,7 +306,7 @@ def on_chain_start(self, serialized, inputs, *, run_id, **kwargs):
287
306
)
288
307
metadata = kwargs .get ("metadata" )
289
308
if metadata :
290
- set_data_normalized (span , SPANDATA .AI_METADATA , metadata )
309
+ set_data_normalized (watched_span . span , SPANDATA .AI_METADATA , metadata )
291
310
292
311
def on_chain_end (self , outputs , * , run_id , ** kwargs ):
293
312
# type: (SentryLangchainCallback, Dict[str, Any], UUID, Any) -> Any
@@ -299,8 +318,7 @@ def on_chain_end(self, outputs, *, run_id, **kwargs):
299
318
span_data = self .span_map [run_id ]
300
319
if not span_data :
301
320
return
302
- span_data .span .__exit__ (None , None , None )
303
- del self .span_map [run_id ]
321
+ self ._exit_span (span_data , run_id )
304
322
305
323
def on_chain_error (self , error , * , run_id , ** kwargs ):
306
324
# type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Any) -> Any
@@ -312,14 +330,16 @@ def on_agent_action(self, action, *, run_id, **kwargs):
312
330
with capture_internal_exceptions ():
313
331
if not run_id :
314
332
return
315
- span = self ._create_span (
333
+ watched_span = self ._create_span (
316
334
run_id ,
317
335
kwargs .get ("parent_run_id" ),
318
336
op = OP .LANGCHAIN_AGENT ,
319
337
description = action .tool or "AI tool usage" ,
320
338
)
321
339
if action .tool_input and should_send_default_pii () and self .include_prompts :
322
- set_data_normalized (span , SPANDATA .AI_INPUT_MESSAGES , action .tool_input )
340
+ set_data_normalized (
341
+ watched_span .span , SPANDATA .AI_INPUT_MESSAGES , action .tool_input
342
+ )
323
343
324
344
def on_agent_finish (self , finish , * , run_id , ** kwargs ):
325
345
# type: (SentryLangchainCallback, AgentFinish, UUID, Any) -> Any
@@ -334,16 +354,15 @@ def on_agent_finish(self, finish, *, run_id, **kwargs):
334
354
set_data_normalized (
335
355
span_data .span , SPANDATA .AI_RESPONSES , finish .return_values .items ()
336
356
)
337
- span_data .span .__exit__ (None , None , None )
338
- del self .span_map [run_id ]
357
+ self ._exit_span (span_data , run_id )
339
358
340
359
def on_tool_start (self , serialized , input_str , * , run_id , ** kwargs ):
341
360
# type: (SentryLangchainCallback, Dict[str, Any], str, UUID, Any) -> Any
342
361
"""Run when tool starts running."""
343
362
with capture_internal_exceptions ():
344
363
if not run_id :
345
364
return
346
- span = self ._create_span (
365
+ watched_span = self ._create_span (
347
366
run_id ,
348
367
kwargs .get ("parent_run_id" ),
349
368
op = OP .LANGCHAIN_TOOL ,
@@ -353,11 +372,13 @@ def on_tool_start(self, serialized, input_str, *, run_id, **kwargs):
353
372
)
354
373
if should_send_default_pii () and self .include_prompts :
355
374
set_data_normalized (
356
- span , SPANDATA .AI_INPUT_MESSAGES , kwargs .get ("inputs" , [input_str ])
375
+ watched_span .span ,
376
+ SPANDATA .AI_INPUT_MESSAGES ,
377
+ kwargs .get ("inputs" , [input_str ]),
357
378
)
358
379
if kwargs .get ("metadata" ):
359
380
set_data_normalized (
360
- span , SPANDATA .AI_METADATA , kwargs .get ("metadata" )
381
+ watched_span . span , SPANDATA .AI_METADATA , kwargs .get ("metadata" )
361
382
)
362
383
363
384
def on_tool_end (self , output , * , run_id , ** kwargs ):
@@ -372,8 +393,7 @@ def on_tool_end(self, output, *, run_id, **kwargs):
372
393
return
373
394
if should_send_default_pii () and self .include_prompts :
374
395
set_data_normalized (span_data .span , SPANDATA .AI_RESPONSES , output )
375
- span_data .span .__exit__ (None , None , None )
376
- del self .span_map [run_id ]
396
+ self ._exit_span (span_data , run_id )
377
397
378
398
def on_tool_error (self , error , * args , run_id , ** kwargs ):
379
399
# type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Any) -> Any
0 commit comments