Skip to content

Commit 27a3038

Browse files
committed
Update mcts.py
Add logger statements
1 parent 57f9310 commit 27a3038

File tree

1 file changed

+49
-17
lines changed

1 file changed

+49
-17
lines changed

mcts.py

Lines changed: 49 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -34,81 +34,103 @@ def __init__(self, simulation_depth, exploration_weight, client, model):
3434
self.model = model
3535

3636
def select(self, node: MCTSNode) -> MCTSNode:
    """Descend from *node* to a leaf using the UCT selection rule.

    The previous implementation looked only one level down, so the search
    tree could never grow deeper than two levels and `search` kept
    re-expanding the same children.  We now follow the best-scoring child
    repeatedly until a childless (leaf) node is reached.

    Args:
        node: Node to start the descent from (usually the root).

    Returns:
        The most promising leaf node according to UCT.
    """
    logger.debug("Selecting node. Current node visits: %s, value: %s", node.visits, node.value)
    while node.children:
        # Bind the parent's visit count before reassigning `node` so the
        # exploration term uses the correct parent statistics.
        parent_visits = node.visits
        # UCT: exploitation (mean value) + exploration bonus; the 1e-8
        # avoids division by zero for unvisited children.
        node = max(
            node.children,
            key=lambda c: c.value / (c.visits + 1e-8)
            + self.exploration_weight * np.sqrt(np.log(parent_visits + 1) / (c.visits + 1e-8)),
        )
        logger.debug("Selected child node. Visits: %s, Value: %s", node.visits, node.value)
    logger.debug("Node has no children. Returning current node.")
    return node
4044

4145
def expand(self, node: MCTSNode) -> MCTSNode:
    """Expand *node* by creating one child per generated action.

    Each candidate action is applied to the node's state to produce a new
    child, which is also registered in the visualisation graph.

    Args:
        node: The leaf node to expand.

    Returns:
        A uniformly random child to run the simulation from, or *node*
        itself when no child exists (previously ``random.choice`` raised
        ``IndexError`` on an empty child list).
    """
    logger.debug("Expanding node. Current state: %s", node.state)
    actions = self.generate_actions(node.state)
    logger.debug("Generated %d possible actions", len(actions))
    for i, action in enumerate(actions):
        new_state = self.apply_action(node.state, action)
        child = MCTSNode(new_state, parent=node)
        node.children.append(child)
        # Mirror the tree structure in the graph used for visualisation.
        self.graph.add_edge(id(node), id(child))
        self.node_labels[id(child)] = f"Visits: {child.visits}\nValue: {child.value:.2f}"
        logger.debug("Created child node %d. Action: %s...", i + 1, action[:50])
    if not node.children:
        # No actions were generated and the node had no previous children;
        # fall back to the node itself instead of crashing.
        logger.warning("expand() produced no children; returning the node unchanged")
        return node
    selected_child = random.choice(node.children)
    logger.debug("Randomly selected child node for simulation. Visits: %s, Value: %s",
                 selected_child.visits, selected_child.value)
    return selected_child
5059

5160
def simulate(self, node: MCTSNode) -> float:
    """Run a random rollout from *node* and score the resulting state.

    Applies up to ``self.simulation_depth`` randomly chosen actions,
    stopping early when a terminal state is reached, then evaluates the
    final state.  Debug logging now uses lazy ``%`` formatting so the
    f-string (and the ``action[:50]`` slice) is not evaluated per step
    when DEBUG logging is disabled.

    Args:
        node: Node to start the rollout from.

    Returns:
        The evaluation score of the state reached by the rollout.
    """
    logger.debug("Starting simulation from node. Current query: %s", node.state.current_query)
    state = node.state
    for i in range(self.simulation_depth):
        if self.is_terminal(state):
            logger.debug("Reached terminal state at depth %d", i)
            break
        action = random.choice(self.generate_actions(state))
        state = self.apply_action(state, action)
        logger.debug("Simulation step %d. Action: %s...", i + 1, action[:50])
    value = self.evaluate_state(state)
    logger.debug("Simulation complete. Final state value: %s", value)
    return value
5973

6074
def backpropagate(self, node: MCTSNode, value: float):
    """Propagate a simulation result up the tree.

    Walks from *node* to the root, incrementing each ancestor's visit
    count, accumulating *value* into its running total, and refreshing
    the label shown in the visualisation graph.

    Args:
        node: Node the simulation was run from.
        value: Evaluation score to add along the path to the root.
    """
    logger.debug(f"Starting backpropagation. Initial value: {value}")
    current = node
    while current:
        current.visits += 1
        current.value += value
        # Keep the visualisation label in sync with the new statistics.
        self.node_labels[id(current)] = f"Visits: {current.visits}\nValue: {current.value:.2f}"
        logger.debug(f"Updated node. Visits: {current.visits}, New value: {current.value}")
        current = current.parent
6682

6783
def search(self, initial_state: DialogueState, num_simulations: int) -> DialogueState:
    """Run the MCTS loop and return the state of the most-visited child.

    Args:
        initial_state: Dialogue state used to create the root on the
            first call; later calls reuse the existing tree and ignore it.
        num_simulations: Number of select/expand/simulate/backpropagate
            iterations to perform.

    Returns:
        The state of the root's most-visited child, or the root's own
        state when no child was ever created (e.g. ``num_simulations == 0``
        or an immediately terminal root) — previously this crashed with
        ``ValueError: max() arg is an empty sequence``.
    """
    logger.debug("Starting MCTS search with %d simulations", num_simulations)
    if not self.root:
        self.root = MCTSNode(initial_state)
        self.graph.add_node(id(self.root))
        self.node_labels[id(self.root)] = f"Root\nVisits: 0\nValue: 0.00"
        logger.debug("Created root node")

    for i in range(num_simulations):
        logger.debug("Starting simulation %d", i + 1)
        node = self.select(self.root)
        if not self.is_terminal(node.state):
            node = self.expand(node)
        value = self.simulate(node)
        self.backpropagate(node, value)

    if not self.root.children:
        # Nothing was expanded; avoid max() on an empty sequence.
        logger.warning("Search finished without expanding any children; returning root state")
        return self.root.state
    best_child = max(self.root.children, key=lambda c: c.visits)
    logger.debug("Search complete. Best child node: Visits: %s, Value: %s",
                 best_child.visits, best_child.value)
    return best_child.state
81102

82103
def generate_actions(self, state: DialogueState) -> List[str]:
    """Ask the model for candidate assistant responses to the current query.

    Builds the chat transcript (system prompt + conversation history +
    current user query) and requests ``n`` independent completions in a
    single API call.  The dead ``completions = []`` initialisation that
    was immediately overwritten has been removed.

    Args:
        state: Dialogue state to generate responses for.

    Returns:
        A list of stripped completion strings, one per sampled choice.
    """
    logger.debug("Generating actions for current state")
    messages = [{"role": "system", "content": state.system_prompt}]
    messages.extend(state.conversation_history)
    messages.append({"role": "user", "content": state.current_query})

    n = 3  # number of candidate responses sampled per expansion

    logger.info("Requesting %d completions from the model", n)
    response = self.client.chat.completions.create(
        model=self.model,
        messages=messages,
        max_tokens=4096,
        n=n,
        temperature=1
    )
    completions = [choice.message.content.strip() for choice in response.choices]
    logger.info("Received %d completions from the model", len(completions))
    return completions
102123

103124
def apply_action(self, state: DialogueState, action: str) -> DialogueState:
125+
logger.info(f"Applying action: {action[:50]}...")
104126
new_history = state.conversation_history.copy()
105127
new_history.append({"role": "assistant", "content": action})
106128

107129
messages = [{"role": "system", "content": state.system_prompt}]
108130
messages.extend(new_history)
109131
messages.append({"role": "system", "content": "Based on this conversation, what might the user ask or say next? Provide a likely user query."})
110132

111-
133+
logger.info("Requesting next user query from the model")
112134
response = self.client.chat.completions.create(
113135
model=self.model,
114136
messages=messages,
@@ -118,14 +140,16 @@ def apply_action(self, state: DialogueState, action: str) -> DialogueState:
118140
)
119141

120142
next_query = response.choices[0].message.content
143+
logger.info(f"Generated next user query: {next_query}")
121144
return DialogueState(state.system_prompt, new_history, next_query)
122145

123146
def is_terminal(self, state: DialogueState) -> bool:
124-
# Consider the state terminal if the conversation has reached a natural conclusion
125-
# or if it has exceeded a certain number of turns
126-
return len(state.conversation_history) > 10 or "goodbye" in state.current_query.lower()
147+
is_terminal = len(state.conversation_history) > 10 or "goodbye" in state.current_query.lower()
148+
logger.info(f"Checking if state is terminal: {is_terminal}")
149+
return is_terminal
127150

128151
def evaluate_state(self, state: DialogueState) -> float:
152+
logger.info("Evaluating current state")
129153
messages = [{"role": "system", "content": state.system_prompt}]
130154
messages.extend(state.conversation_history)
131155
messages.append({"role": "system", "content": "Evaluate the quality of this conversation on a scale from 0 to 1, where 0 is poor and 1 is excellent. Consider factors such as coherence, relevance, and engagement. Respond with only a number."})
@@ -140,13 +164,21 @@ def evaluate_state(self, state: DialogueState) -> float:
140164

141165
try:
142166
score = float(response.choices[0].message.content.strip())
143-
return max(0, min(score, 1)) # Ensure the score is between 0 and 1
167+
score = max(0, min(score, 1)) # Ensure the score is between 0 and 1
168+
logger.info(f"State evaluation score: {score}")
169+
return score
144170
except ValueError:
171+
logger.warning("Failed to parse evaluation score. Using default value 0.5")
145172
return 0.5 # Default to a neutral score if parsing fails
146173

147174
def chat_with_mcts(system_prompt: str, initial_query: str, client, model: str, num_simulations: int = 2, exploration_weight: float = 0.2,
                   simulation_depth: int = 1) -> str:
    """Answer *initial_query* by running an MCTS search over dialogue states.

    Constructs an MCTS searcher, runs it from a fresh dialogue state, and
    returns the assistant's last message in the best conversation found
    (an empty string if the search produced no conversation history).
    """
    logger.info("Starting chat with MCTS")
    logger.info(f"Parameters: num_simulations={num_simulations}, exploration_weight={exploration_weight}, simulation_depth={simulation_depth}")
    searcher = MCTS(simulation_depth=simulation_depth, exploration_weight=exploration_weight, client=client, model=model)
    start_state = DialogueState(system_prompt, [], initial_query)
    logger.info(f"Initial query: {initial_query}")
    best_state = searcher.search(start_state, num_simulations)
    if best_state.conversation_history:
        response = best_state.conversation_history[-1]['content']
    else:
        response = ""
    logger.info(f"MCTS chat complete. Final response: {response[:100]}...")
    return response

0 commit comments

Comments
 (0)