 import traceback
 import platform
 import sys
+import re
 
 from optillm.cot_decoding import cot_decode
 from optillm.entropy_decoding import entropy_decode
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
+def count_reasoning_tokens(text: str, tokenizer=None) -> int:
+    """
+    Count tokens within <think>...</think> tags in the given text.
+
+    Args:
+        text: The text to analyze
+        tokenizer: Optional tokenizer instance for precise counting
+
+    Returns:
+        Number of reasoning tokens (0 if no think tags found)
+    """
+    if not text or not isinstance(text, str):
+        return 0
+
+    # Extract all content within <think>...</think> tags,
+    # handling both complete and truncated think blocks.
+
+    # First, find all complete <think>...</think> blocks
+    complete_pattern = r'<think>(.*?)</think>'
+    complete_matches = re.findall(complete_pattern, text, re.DOTALL)
+
+    # Then check for an unclosed <think> tag (truncated response):
+    # this finds a <think> that has no matching </think> after it
+    truncated_pattern = r'<think>(?!.*</think>)(.*)$'
+    truncated_match = re.search(truncated_pattern, text, re.DOTALL)
+
+    # Combine all thinking content
+    thinking_content = ''.join(complete_matches)
+    if truncated_match:
+        thinking_content += truncated_match.group(1)
+
+    if not thinking_content:
+        return 0
+
+    if tokenizer and hasattr(tokenizer, 'encode'):
+        # Use the tokenizer for precise counting
+        try:
+            tokens = tokenizer.encode(thinking_content)
+            return len(tokens)
+        except Exception as e:
+            logger.warning(f"Failed to count tokens with tokenizer: {e}")
+
+    # Fallback: rough estimate (~4 chars per token; minimum 1 token for non-empty content)
+    content_length = len(thinking_content.strip())
+    return max(1, content_length // 4) if content_length > 0 else 0
+
 # MLX Support for Apple Silicon
 try:
     import mlx.core as mx
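
For reference, a minimal sketch of how the new helper behaves (illustrative only; the sample strings are assumptions, not part of this commit, and no tokenizer is passed, so the ~4-chars-per-token fallback estimate applies):

# Illustrative usage of count_reasoning_tokens, assuming the helper above is in scope.
complete = "<think>Let me check the math first.</think>The answer is 4."
truncated = "<think>Still reasoning when the response was cut off"
plain = "The answer is 4."

print(count_reasoning_tokens(complete))   # 28 chars of think content // 4 -> 7
print(count_reasoning_tokens(truncated))  # unclosed <think> block still counts -> > 0
print(count_reasoning_tokens(plain))      # no think tags -> 0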
@@ -1502,10 +1549,11 @@ def __init__(
         self.message.logprobs = logprobs
 
 class ChatCompletionUsage:
-    def __init__(self, prompt_tokens: int, completion_tokens: int, total_tokens: int):
+    def __init__(self, prompt_tokens: int, completion_tokens: int, total_tokens: int, reasoning_tokens: int = 0):
         self.prompt_tokens = prompt_tokens
         self.completion_tokens = completion_tokens
         self.total_tokens = total_tokens
+        self.reasoning_tokens = reasoning_tokens
 
 class ChatCompletion:
     def __init__(self, response_dict: Dict):
@@ -1547,7 +1595,10 @@ def model_dump(self) -> Dict:
             "usage": {
                 "prompt_tokens": self.usage.prompt_tokens,
                 "completion_tokens": self.usage.completion_tokens,
-                "total_tokens": self.usage.total_tokens
+                "total_tokens": self.usage.total_tokens,
+                "completion_tokens_details": {
+                    "reasoning_tokens": getattr(self.usage, 'reasoning_tokens', 0)
+                }
             }
         }
 
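The serialized usage block that model_dump() now produces nests the count under "completion_tokens_details", mirroring how the OpenAI API reports reasoning tokens. An example of the resulting shape (all values hypothetical):

usage_payload = {
    "prompt_tokens": 42,
    "completion_tokens": 512,
    "total_tokens": 554,
    "completion_tokens_details": {
        "reasoning_tokens": 318
    }
}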
@@ -1766,15 +1817,15 @@ def create(
 
             logger.debug(f"ThinkDeeper tokens: user={user_max_tokens}, thinking={max_thinking_tokens}, adjusted={adjusted_max_tokens}")
 
-            result = thinkdeeper_decode_mlx(
+            result, reasoning_tokens = thinkdeeper_decode_mlx(
                 pipeline.model,
                 pipeline.tokenizer,
                 messages,
                 thinkdeeper_config_with_tokens
             )
         else:
             logger.info("Using PyTorch ThinkDeeper implementation")
-            result = thinkdeeper_decode(
+            result, reasoning_tokens = thinkdeeper_decode(
                 pipeline.current_model,
                 pipeline.tokenizer,
                 messages,
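
Both call sites now unpack a two-tuple, which presumes thinkdeeper_decode and thinkdeeper_decode_mlx (updated elsewhere in this commit) return the generated text together with its reasoning-token count. A sketch of the assumed contract, with placeholder arguments:

# Assumed contract (sketch; model, tokenizer, messages, config are placeholders):
result, reasoning_tokens = thinkdeeper_decode(model, tokenizer, messages, config)
# result is the generated text; reasoning_tokens counts tokens inside <think> tags.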
@@ -1850,6 +1901,11 @@ def create(
             prompt_tokens = len(pipeline.tokenizer.encode(prompt))
             completion_tokens = sum(token_counts)
 
+            # Calculate reasoning tokens from all responses
+            total_reasoning_tokens = 0
+            for response in responses:
+                total_reasoning_tokens += count_reasoning_tokens(response, pipeline.tokenizer)
+
             # Create OpenAI-compatible response format
             response_dict = {
                 "id": f"chatcmpl-{int(time.time()*1000)}",
@@ -1871,7 +1927,8 @@ def create(
                 "usage": {
                     "prompt_tokens": prompt_tokens,
                     "completion_tokens": completion_tokens,
-                    "total_tokens": completion_tokens + prompt_tokens
+                    "total_tokens": completion_tokens + prompt_tokens,
+                    "reasoning_tokens": total_reasoning_tokens
                 }
             }
 
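Note that this local-inference path reports reasoning_tokens as a flat key on the usage dict, whereas model_dump() nests it under completion_tokens_details. A sketch of how the value is expected to land on the new ChatCompletionUsage field (assuming the response_dict parsing forwards it; simplified, not code from this commit):

# Hypothetical flow-through of the flat usage key:
usage = ChatCompletionUsage(
    prompt_tokens=prompt_tokens,
    completion_tokens=completion_tokens,
    total_tokens=prompt_tokens + completion_tokens,
    reasoning_tokens=total_reasoning_tokens,
)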