@@ -34,81 +34,103 @@ def __init__(self, simulation_depth, exploration_weight, client, model):
3434 self .model = model
3535
3636 def select (self , node : MCTSNode ) -> MCTSNode :
37+ logger .debug (f"Selecting node. Current node visits: { node .visits } , value: { node .value } " )
3738 if not node .children :
39+ logger .debug ("Node has no children. Returning current node." )
3840 return node
39- return max (node .children , key = lambda c : c .value / (c .visits + 1e-8 ) + self .exploration_weight * np .sqrt (np .log (node .visits + 1 ) / (c .visits + 1e-8 )))
41+ selected_node = max (node .children , key = lambda c : c .value / (c .visits + 1e-8 ) + self .exploration_weight * np .sqrt (np .log (node .visits + 1 ) / (c .visits + 1e-8 )))
42+ logger .debug (f"Selected child node. Visits: { selected_node .visits } , Value: { selected_node .value } " )
43+ return selected_node
4044
4145 def expand (self , node : MCTSNode ) -> MCTSNode :
46+ logger .debug (f"Expanding node. Current state: { node .state } " )
4247 actions = self .generate_actions (node .state )
43- for action in actions :
48+ logger .debug (f"Generated { len (actions )} possible actions" )
49+ for i , action in enumerate (actions ):
4450 new_state = self .apply_action (node .state , action )
4551 child = MCTSNode (new_state , parent = node )
4652 node .children .append (child )
4753 self .graph .add_edge (id (node ), id (child ))
4854 self .node_labels [id (child )] = f"Visits: { child .visits } \n Value: { child .value :.2f} "
49- return random .choice (node .children )
55+ logger .debug (f"Created child node { i + 1 } . Action: { action [:50 ]} ..." )
56+ selected_child = random .choice (node .children )
57+ logger .debug (f"Randomly selected child node for simulation. Visits: { selected_child .visits } , Value: { selected_child .value } " )
58+ return selected_child
5059
5160 def simulate (self , node : MCTSNode ) -> float :
61+ logger .debug (f"Starting simulation from node. Current query: { node .state .current_query } " )
5262 state = node .state
53- for _ in range (self .simulation_depth ):
63+ for i in range (self .simulation_depth ):
5464 if self .is_terminal (state ):
65+ logger .debug (f"Reached terminal state at depth { i } " )
5566 break
5667 action = random .choice (self .generate_actions (state ))
5768 state = self .apply_action (state , action )
58- return self .evaluate_state (state )
69+ logger .debug (f"Simulation step { i + 1 } . Action: { action [:50 ]} ..." )
70+ value = self .evaluate_state (state )
71+ logger .debug (f"Simulation complete. Final state value: { value } " )
72+ return value
5973
6074 def backpropagate (self , node : MCTSNode , value : float ):
75+ logger .debug (f"Starting backpropagation. Initial value: { value } " )
6176 while node :
6277 node .visits += 1
6378 node .value += value
6479 self .node_labels [id (node )] = f"Visits: { node .visits } \n Value: { node .value :.2f} "
80+ logger .debug (f"Updated node. Visits: { node .visits } , New value: { node .value } " )
6581 node = node .parent
6682
6783 def search (self , initial_state : DialogueState , num_simulations : int ) -> DialogueState :
84+ logger .debug (f"Starting MCTS search with { num_simulations } simulations" )
6885 if not self .root :
6986 self .root = MCTSNode (initial_state )
7087 self .graph .add_node (id (self .root ))
7188 self .node_labels [id (self .root )] = f"Root\n Visits: 0\n Value: 0.00"
89+ logger .debug ("Created root node" )
7290
73- for _ in range (num_simulations ):
91+ for i in range (num_simulations ):
92+ logger .debug (f"Starting simulation { i + 1 } " )
7493 node = self .select (self .root )
7594 if not self .is_terminal (node .state ):
7695 node = self .expand (node )
7796 value = self .simulate (node )
7897 self .backpropagate (node , value )
7998
80- return max (self .root .children , key = lambda c : c .visits ).state
99+ best_child = max (self .root .children , key = lambda c : c .visits )
100+ logger .debug (f"Search complete. Best child node: Visits: { best_child .visits } , Value: { best_child .value } " )
101+ return best_child .state
81102
82103 def generate_actions (self , state : DialogueState ) -> List [str ]:
104+ logger .debug ("Generating actions for current state" )
83105 messages = [{"role" : "system" , "content" : state .system_prompt }]
84106 messages .extend (state .conversation_history )
85107 messages .append ({"role" : "user" , "content" : state .current_query })
86- # messages.append({"role": "system", "content": "Generate 3 possible responses to the user's query. Each response should be on a new line starting with 'Response:'."})
87108
88109 completions = []
89110 n = 3
90111
112+ logger .info (f"Requesting { n } completions from the model" )
91113 response = self .client .chat .completions .create (
92- model = self .model ,
114+ model = self .model ,
93115 messages = messages ,
94116 max_tokens = 4096 ,
95117 n = n ,
96118 temperature = 1
97119 )
98120 completions = [choice .message .content .strip () for choice in response .choices ]
99- # suggested_responses = response.choices[0].message.content.split("Response:")
100- # return [resp.strip() for resp in suggested_responses if resp.strip()]
121+ logger .info (f"Received { len (completions )} completions from the model" )
101122 return completions
102123
103124 def apply_action (self , state : DialogueState , action : str ) -> DialogueState :
125+ logger .info (f"Applying action: { action [:50 ]} ..." )
104126 new_history = state .conversation_history .copy ()
105127 new_history .append ({"role" : "assistant" , "content" : action })
106128
107129 messages = [{"role" : "system" , "content" : state .system_prompt }]
108130 messages .extend (new_history )
109131 messages .append ({"role" : "system" , "content" : "Based on this conversation, what might the user ask or say next? Provide a likely user query." })
110132
111-
133+ logger . info ( "Requesting next user query from the model" )
112134 response = self .client .chat .completions .create (
113135 model = self .model ,
114136 messages = messages ,
@@ -118,14 +140,16 @@ def apply_action(self, state: DialogueState, action: str) -> DialogueState:
118140 )
119141
120142 next_query = response .choices [0 ].message .content
143+ logger .info (f"Generated next user query: { next_query } " )
121144 return DialogueState (state .system_prompt , new_history , next_query )
122145
123146 def is_terminal (self , state : DialogueState ) -> bool :
124- # Consider the state terminal if the conversation has reached a natural conclusion
125- # or if it has exceeded a certain number of turns
126- return len ( state . conversation_history ) > 10 or "goodbye" in state . current_query . lower ()
147+ is_terminal = len ( state . conversation_history ) > 10 or "goodbye" in state . current_query . lower ()
148+ logger . info ( f"Checking if state is terminal: { is_terminal } " )
149+ return is_terminal
127150
128151 def evaluate_state (self , state : DialogueState ) -> float :
152+ logger .info ("Evaluating current state" )
129153 messages = [{"role" : "system" , "content" : state .system_prompt }]
130154 messages .extend (state .conversation_history )
131155 messages .append ({"role" : "system" , "content" : "Evaluate the quality of this conversation on a scale from 0 to 1, where 0 is poor and 1 is excellent. Consider factors such as coherence, relevance, and engagement. Respond with only a number." })
@@ -140,13 +164,21 @@ def evaluate_state(self, state: DialogueState) -> float:
140164
141165 try :
142166 score = float (response .choices [0 ].message .content .strip ())
143- return max (0 , min (score , 1 )) # Ensure the score is between 0 and 1
167+ score = max (0 , min (score , 1 )) # Ensure the score is between 0 and 1
168+ logger .info (f"State evaluation score: { score } " )
169+ return score
144170 except ValueError :
171+ logger .warning ("Failed to parse evaluation score. Using default value 0.5" )
145172 return 0.5 # Default to a neutral score if parsing fails
146173
def chat_with_mcts(system_prompt: str, initial_query: str, client, model: str, num_simulations: int = 2, exploration_weight: float = 0.2,
                   simulation_depth: int = 1) -> str:
    """Run one MCTS-guided chat turn and return the assistant response chosen by the search."""
    logger.info("Starting chat with MCTS")
    logger.info(f"Parameters: num_simulations={num_simulations}, exploration_weight={exploration_weight}, simulation_depth={simulation_depth}")
    mcts = MCTS(simulation_depth=simulation_depth, exploration_weight=exploration_weight, client=client, model=model)
    initial_state = DialogueState(system_prompt, [], initial_query)
    logger.info(f"Initial query: {initial_query}")
    final_state = mcts.search(initial_state, num_simulations)
    # The last history entry is the assistant reply selected by the search;
    # an empty history yields an empty response.
    if final_state.conversation_history:
        response = final_state.conversation_history[-1]['content']
    else:
        response = ""
    logger.info(f"MCTS chat complete. Final response: {response[:100]}...")
    return response
0 commit comments