From f164f6afefb80e11163112b3d3a72fc5c6c62bb3 Mon Sep 17 00:00:00 2001 From: Brandon Corfman Date: Sun, 12 Oct 2025 14:28:11 -0400 Subject: [PATCH] feat: Add probabilistic transitions with weighted random selection Add support for weighted transitions that allow non-deterministic state transitions based on configurable probabilities. This is useful for game AI, simulations, and randomized workflows. Key features: - Add optional 'weight' parameter to Transition class - Add optional 'random_seed' parameter to StateMachine for deterministic testing - Implement weighted selection in both sync and async engines - Automatically display probability percentages in state diagrams - Full backward compatibility (no weights = original first-match behavior) - Zero/negative weights are ignored - Conditions (guards/validators) work seamlessly with weighted transitions - Complete pickling support Changes: - statemachine/transition.py: Add weight parameter and repr support - statemachine/statemachine.py: Add random_seed and Random instance - statemachine/engines/sync.py: Implement weighted selection logic - statemachine/engines/async_.py: Implement weighted selection logic - statemachine/contrib/diagram.py: Add probability labels to diagrams - README.md: Add probabilistic transitions to features list - docs/transitions.md: Add comprehensive documentation with examples - tests/test_probabilistic_transitions.py: 20 comprehensive tests - tests/examples/game_character_idle_machine.py: Working example All 348 existing tests pass + 20 new tests = 368 total passing tests --- README.md | 1 + docs/transitions.md | 71 +++ statemachine/contrib/diagram.py | 23 +- statemachine/engines/async_.py | 40 +- statemachine/engines/sync.py | 41 +- statemachine/statemachine.py | 13 + statemachine/transition.py | 12 +- tests/examples/game_character_idle_machine.py | 105 ++++ tests/test_probabilistic_transitions.py | 543 ++++++++++++++++++ 9 files changed, 837 insertions(+), 12 deletions(-) create mode 100644 tests/examples/game_character_idle_machine.py create mode 100644 tests/test_probabilistic_transitions.py diff --git a/README.md b/README.md index 7e9274fa..09e382df 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ machines in sync or asynchonous Python codebases. - ✨ **Basic components**: Easily define **States**, **Events**, and **Transitions** to model your logic. - ⚙️ **Actions and handlers**: Attach actions and handlers to states, events, and transitions to control behavior dynamically. - 🛡️ **Conditional transitions**: Implement **Guards** and **Validators** to conditionally control transitions, ensuring they only occur when specific conditions are met. +- 🎲 **Probabilistic transitions**: Define weighted transitions for non-deterministic behavior, perfect for game AI, simulations, and randomized workflows. - 🚀 **Full async support**: Enjoy full asynchronous support. Await events, and dispatch callbacks asynchronously for seamless integration with async codebases. - 🔄 **Full sync support**: Use the same state machine from synchronous codebases without any modifications. - 🎨 **Declarative and simple API**: Utilize a clean, elegant, and readable API to define your state machine, making it easy to maintain and understand. diff --git a/docs/transitions.md b/docs/transitions.md index 32d17236..3b3a8cc7 100644 --- a/docs/transitions.md +++ b/docs/transitions.md @@ -163,6 +163,77 @@ the event name is used to describe the transition. ``` +### Probabilistic transitions + +```{versionadded} 2.5.0 +Probabilistic transitions allow you to define weighted random selection when multiple transitions +share the same event from the same source state. +``` + +Probabilistic transitions are useful for: +- Game AI with non-deterministic behavior +- Simulations requiring randomness +- Idle animations in games +- Randomized workflows + +When you define multiple transitions with the same event and source state, you can assign weights +to control the probability of each transition being chosen: + +```py +>>> class GameCharacter(StateMachine): +... standing = State(initial=True) +... shift_weight = State() +... adjust_hair = State() +... bang_shield = State() +... +... # Weighted transitions: 70/20/10 probability split +... idle = ( +... standing.to(shift_weight, event="idle", weight=70) +... | standing.to(adjust_hair, event="idle", weight=20) +... | standing.to(bang_shield, event="idle", weight=10) +... ) +... +... # Return transitions +... finish = ( +... shift_weight.to(standing) +... | adjust_hair.to(standing) +... | bang_shield.to(standing) +... ) + +``` + +The `weight` parameter controls the relative probability of each transition. In the example above: +- `shift_weight` has a 70% chance (70/(70+20+10)) +- `adjust_hair` has a 20% chance (20/(70+20+10)) +- `bang_shield` has a 10% chance (10/(70+20+10)) + +```{note} +Weights are relative, not absolute. The actual probability is calculated as `weight / sum(all_weights)`. +``` + +**Key behaviors:** + +1. **Deterministic testing**: Use `random_seed` parameter for reproducible behavior: + +```py +>>> character = GameCharacter(random_seed=42) + +``` + +2. **Zero/negative weights ignored**: Transitions with weight ≤ 0 are excluded from selection. + +3. **Mixed weighted/unweighted**: When any transition has a weight, only weighted transitions are considered. + +4. **Conditions still apply**: Guards and validators filter transitions before weight-based selection. + +5. **Backward compatibility**: If no weights are specified, the first matching transition is used (original behavior). + +```{tip} +Probabilistic transitions integrate seamlessly with guards and validators. The weight-based selection +happens first among matching transitions, then conditions are evaluated to determine if the selected +transition can execute. +``` + ## Events An event is an external signal that something has happened. diff --git a/statemachine/contrib/diagram.py b/statemachine/contrib/diagram.py index ee0d14f4..345948f3 100644 --- a/statemachine/contrib/diagram.py +++ b/statemachine/contrib/diagram.py @@ -127,10 +127,31 @@ def _transition_as_edge(self, transition): cond = ", ".join([str(cond) for cond in transition.cond]) if cond: cond = f"\n[{cond}]" + + # Calculate probability label if this transition has a weight + probability_label = "" + if transition.weight is not None and transition.weight > 0: + # Find all transitions from the same source with the same event + same_event_transitions = [ + t for t in transition.source.transitions + if t.match(transition.event) and t.weight is not None and t.weight > 0 + ] + + if len(same_event_transitions) > 1: + # Calculate probability as percentage + total_weight = sum(t.weight for t in same_event_transitions) + probability = (transition.weight / total_weight) * 100 + + # Format as percentage if it's a clean calculation + if probability == int(probability): + probability_label = f" [{int(probability)}%]" + else: + probability_label = f" [{probability:.1f}%]" + return pydot.Edge( transition.source.id, transition.target.id, - label=f"{transition.event}{cond}", + label=f"{transition.event}{probability_label}{cond}", color="blue", fontname=self.font_name, fontsize=self.transition_font_size, diff --git a/statemachine/engines/async_.py b/statemachine/engines/async_.py index 9d2b3f9f..5640c9ea 100644 --- a/statemachine/engines/async_.py +++ b/statemachine/engines/async_.py @@ -83,10 +83,42 @@ async def _trigger(self, trigger_data: TriggerData): return self._sentinel state = self.sm.current_state - for transition in state.transitions: - if not transition.match(trigger_data.event): - continue - + + # Collect all matching transitions + matching_transitions = [ + t for t in state.transitions if t.match(trigger_data.event) + ] + + if not matching_transitions: + if not self.sm.allow_event_without_transition: + raise TransitionNotAllowed(trigger_data.event, state) + return None + + # Check if any transition has a positive weight + weighted_transitions = [ + t for t in matching_transitions if t.weight is not None and t.weight > 0 + ] + + # If we have weighted transitions, select one randomly + if weighted_transitions: + weights = [t.weight for t in weighted_transitions] + selected_transition = self.sm._random.choices(weighted_transitions, weights=weights, k=1)[0] + executed, result = await self._activate(trigger_data, selected_transition) + if executed: + return result + # If the selected transition failed its conditions, try others + for transition in weighted_transitions: + if transition == selected_transition: + continue + executed, result = await self._activate(trigger_data, transition) + if executed: + return result + if not self.sm.allow_event_without_transition: + raise TransitionNotAllowed(trigger_data.event, state) + return None + + # Otherwise, use first-match behavior (backward compatible) + for transition in matching_transitions: executed, result = await self._activate(trigger_data, transition) if not executed: continue diff --git a/statemachine/engines/sync.py b/statemachine/engines/sync.py index 4400cd08..3e476863 100644 --- a/statemachine/engines/sync.py +++ b/statemachine/engines/sync.py @@ -84,14 +84,45 @@ def _trigger(self, trigger_data: TriggerData): return self._sentinel state = self.sm.current_state - for transition in state.transitions: - if not transition.match(trigger_data.event): - continue - + + # Collect all matching transitions + matching_transitions = [ + t for t in state.transitions if t.match(trigger_data.event) + ] + + if not matching_transitions: + if not self.sm.allow_event_without_transition: + raise TransitionNotAllowed(trigger_data.event, state) + return None + + # Check if any transition has a positive weight + weighted_transitions = [ + t for t in matching_transitions if t.weight is not None and t.weight > 0 + ] + + # If we have weighted transitions, select one randomly + if weighted_transitions: + weights = [t.weight for t in weighted_transitions] + selected_transition = self.sm._random.choices(weighted_transitions, weights=weights, k=1)[0] + executed, result = self._activate(trigger_data, selected_transition) + if executed: + return result + # If the selected transition failed its conditions, try others + for transition in weighted_transitions: + if transition == selected_transition: + continue + executed, result = self._activate(trigger_data, transition) + if executed: + return result + if not self.sm.allow_event_without_transition: + raise TransitionNotAllowed(trigger_data.event, state) + return None + + # Otherwise, use first-match behavior (backward compatible) + for transition in matching_transitions: executed, result = self._activate(trigger_data, transition) if not executed: continue - break else: if not self.sm.allow_event_without_transition: diff --git a/statemachine/statemachine.py b/statemachine/statemachine.py index e5fe2628..6688c309 100644 --- a/statemachine/statemachine.py +++ b/statemachine/statemachine.py @@ -1,3 +1,4 @@ +import random import warnings from inspect import isawaitable from typing import TYPE_CHECKING @@ -53,6 +54,11 @@ class StateMachine(metaclass=StateMachineMetaclass): listeners: An optional list of objects that provies attributes to be used as callbacks. See :ref:`listeners` for more details. + random_seed: An optional seed for the random number generator used in probabilistic + transitions. When multiple transitions share the same event from the same source + state and have weights assigned, one will be chosen randomly. Setting a seed + ensures deterministic behavior for testing and reproducibility. Default: ``None``. + """ TransitionNotAllowed = TransitionNotAllowed @@ -74,6 +80,7 @@ def __init__( rtc: bool = True, allow_event_without_transition: bool = False, listeners: "List[object] | None" = None, + random_seed: Any = None, ): self.model = model if model is not None else Model() self.state_field = state_field @@ -81,6 +88,7 @@ def __init__( self.allow_event_without_transition = allow_event_without_transition self._callbacks = CallbacksRegistry() self._states_for_instance: Dict[State, State] = {} + self._random = random.Random(random_seed) self._listeners: Dict[Any, Any] = {} """Listeners that provides attributes to be used as callbacks.""" @@ -130,17 +138,22 @@ def __repr__(self): def __getstate__(self): state = self.__dict__.copy() state["_rtc"] = self._engine._rtc + state["_random_state"] = self._random.getstate() del state["_callbacks"] del state["_states_for_instance"] del state["_engine"] + del state["_random"] return state def __setstate__(self, state): listeners = state.pop("_listeners") rtc = state.pop("_rtc") + random_state = state.pop("_random_state") self.__dict__.update(state) self._callbacks = CallbacksRegistry() self._states_for_instance: Dict[State, State] = {} + self._random = random.Random() + self._random.setstate(random_state) self._listeners: Dict[Any, Any] = {} diff --git a/statemachine/transition.py b/statemachine/transition.py index a9044f0f..9abbb962 100644 --- a/statemachine/transition.py +++ b/statemachine/transition.py @@ -34,6 +34,10 @@ class Transition: before the transition is executed. after (Optional[Union[str, Callable, List[Callable]]]): The callbacks to be invoked after the transition is executed. + weight (Optional[float]): The weight for probabilistic transition selection. When multiple + transitions share the same event from the same source state and at least one has a + positive weight, the transition will be chosen randomly based on the weights. + Default ``None``. """ def __init__( @@ -48,10 +52,12 @@ def __init__( on=None, before=None, after=None, + weight=None, ): self.source = source self.target = target self.internal = internal + self.weight = weight if internal and source is not target: raise InvalidDefinition("Internal transitions should be self-transitions.") @@ -75,9 +81,10 @@ def __init__( ) def __repr__(self): + weight_str = f", weight={self.weight!r}" if self.weight is not None else "" return ( f"{type(self).__name__}({self.source!r}, {self.target!r}, event={self.event!r}, " - f"internal={self.internal!r})" + f"internal={self.internal!r}{weight_str})" ) def __str__(self): @@ -137,8 +144,9 @@ def _copy_with_args(self, **kwargs): target = kwargs.pop("target", self.target) event = kwargs.pop("event", self.event) internal = kwargs.pop("internal", self.internal) + weight = kwargs.pop("weight", self.weight) new_transition = Transition( - source=source, target=target, event=event, internal=internal, **kwargs + source=source, target=target, event=event, internal=internal, weight=weight, **kwargs ) for spec in self._specs: new_spec = deepcopy(spec) diff --git a/tests/examples/game_character_idle_machine.py b/tests/examples/game_character_idle_machine.py new file mode 100644 index 00000000..6e89bf11 --- /dev/null +++ b/tests/examples/game_character_idle_machine.py @@ -0,0 +1,105 @@ +""" +Example: Game Character Idle Animations with Probabilistic Transitions +====================================================================== + +This example demonstrates how to use weighted transitions to create +realistic idle animations for a game character. The character will randomly +choose different idle animations based on weighted probabilities. + +The character has a standing state, and when idle, will probabilistically +transition to different animations: +- 70% chance: Shift weight from one foot to the other +- 20% chance: Run hand through hair +- 10% chance: Bang sword against shield + +After performing an idle animation, the character returns to standing. +""" + +from statemachine import State, StateMachine + + +class GameCharacter(StateMachine): + """A game character with weighted idle animations.""" + + # States + standing = State("Standing", initial=True) + shift_weight = State("Shifting Weight") + adjust_hair = State("Adjusting Hair") + bang_shield = State("Banging Shield") + + # Weighted idle transitions - 70/20/10 split + idle = ( + standing.to(shift_weight, event="idle", weight=70) + | standing.to(adjust_hair, event="idle", weight=20) + | standing.to(bang_shield, event="idle", weight=10) + ) + + # Return to standing after each animation + finish = ( + shift_weight.to(standing) + | adjust_hair.to(standing) + | bang_shield.to(standing) + ) + + def __init__(self, random_seed=None): + """Initialize the character. + + Args: + random_seed: Optional seed for deterministic behavior in tests. + """ + self.animation_log = [] + super().__init__(random_seed=random_seed) + + def on_enter_shift_weight(self): + """Called when entering shift_weight state.""" + self.animation_log.append("shift_weight") + print(" → Character shifts weight from one foot to the other") + + def on_enter_adjust_hair(self): + """Called when entering adjust_hair state.""" + self.animation_log.append("adjust_hair") + print(" → Character runs hand through hair") + + def on_enter_bang_shield(self): + """Called when entering bang_shield state.""" + self.animation_log.append("bang_shield") + print(" → Character bangs sword against shield") + + +def main(): + """Run the example.""" + print("Game Character Idle Animations Example") + print("=" * 50) + print() + + # Create a character with a seed for reproducible demonstration + character = GameCharacter(random_seed=42) + + print("Current state:", character.current_state.name) + print() + + # Trigger idle animations 10 times + print("Triggering 10 idle animations:") + print() + + for i in range(1000): + print(f"Idle #{i+1}:") + character.idle() + character.finish() + + print() + print("Animation summary:") + from collections import Counter + + counts = Counter(character.animation_log) + for anim, count in counts.most_common(): + percentage = (count / len(character.animation_log)) * 100 + print(f" {anim}: {count} times ({percentage:.0f}%)") + + print() + print("Expected distribution: shift_weight ~70%, adjust_hair ~20%, bang_shield ~10%") + + +if __name__ == "__main__": + main() + diff --git a/tests/test_probabilistic_transitions.py b/tests/test_probabilistic_transitions.py new file mode 100644 index 00000000..0a2c54ca --- /dev/null +++ b/tests/test_probabilistic_transitions.py @@ -0,0 +1,543 @@ +"""Tests for probabilistic transitions with weighted random selection.""" +import pickle +from collections import Counter + +import pytest + +from statemachine import State +from statemachine import StateMachine + + +# Test Fixtures + + +@pytest.fixture +def weighted_idle_machine(): + """A game character with weighted idle animations.""" + + class CharacterMachine(StateMachine): + standing = State(initial=True) + shift_weight = State() + adjust_hair = State() + bang_shield = State() + + # Weighted idle transitions from standing to itself + idle = ( + standing.to(shift_weight, event="idle", weight=70) + | standing.to(adjust_hair, event="idle", weight=20) + | standing.to(bang_shield, event="idle", weight=10) + ) + + # Return transitions + finish = ( + shift_weight.to(standing) + | adjust_hair.to(standing) + | bang_shield.to(standing) + ) + + def __init__(self, random_seed=None): + self.animations = [] + super().__init__(random_seed=random_seed) + + def on_enter_shift_weight(self): + self.animations.append("shift_weight") + + def on_enter_adjust_hair(self): + self.animations.append("adjust_hair") + + def on_enter_bang_shield(self): + self.animations.append("bang_shield") + + return CharacterMachine + + +# Module-level class for pickle test +class SimpleWeightedMachine(StateMachine): + """Simple machine with two weighted transitions.""" + + a = State(initial=True) + b = State() + c = State() + + go = a.to(b, event="go", weight=75) | a.to(c, event="go", weight=25) + + +@pytest.fixture +def simple_weighted_machine(): + """Simple machine with two weighted transitions.""" + return SimpleWeightedMachine + + +@pytest.fixture +def mixed_weighted_machine(): + """Machine with both weighted and unweighted transitions.""" + + class MixedMachine(StateMachine): + start = State(initial=True) + weighted_a = State() + weighted_b = State() + unweighted = State() + + # Mix of weighted and unweighted transitions + mixed_event = ( + start.to(weighted_a, event="mixed_event", weight=50) + | start.to(weighted_b, event="mixed_event", weight=50) + | start.to(unweighted, event="mixed_event") # No weight + ) + + return MixedMachine + + +@pytest.fixture +def conditional_weighted_machine(): + """Machine with weighted transitions that also have conditions.""" + + class ConditionalMachine(StateMachine): + start = State(initial=True) + allowed_dest = State() + blocked_dest = State() + + go = ( + start.to(allowed_dest, event="go", weight=50, cond="is_allowed") + | start.to(blocked_dest, event="go", weight=50) + ) + + def __init__(self, allow=True, random_seed=None): + self.allow = allow + super().__init__(random_seed=random_seed) + + def is_allowed(self): + return self.allow + + return ConditionalMachine + + +@pytest.fixture +def zero_negative_weight_machine(): + """Machine with zero and negative weights.""" + + class ZeroNegativeMachine(StateMachine): + start = State(initial=True) + valid_a = State() + valid_b = State() + zero_weight = State() + negative_weight = State() + + go = ( + start.to(valid_a, event="go", weight=50) + | start.to(valid_b, event="go", weight=50) + | start.to(zero_weight, event="go", weight=0) + | start.to(negative_weight, event="go", weight=-10) + ) + + return ZeroNegativeMachine + + +@pytest.fixture +def no_weight_machine(): + """Machine with no weights (backward compatibility test).""" + + class NoWeightMachine(StateMachine): + start = State(initial=True) + middle = State() + end = State() + + advance = start.to(middle) | middle.to(end) + + return NoWeightMachine + + +# Test Cases + + +def test_deterministic_weighted_selection(simple_weighted_machine): + """Test that weighted selection is deterministic with a seed.""" + sm1 = simple_weighted_machine(random_seed=42) + sm2 = simple_weighted_machine(random_seed=42) + + results1 = [] + results2 = [] + + for _ in range(10): + sm1.send("go") + results1.append(sm1.current_state.id) + # Reset to initial state + sm1.current_state = sm1.a + + sm2.send("go") + results2.append(sm2.current_state.id) + # Reset to initial state + sm2.current_state = sm2.a + + # Results should be identical with same seed + assert results1 == results2 + + +def test_weighted_distribution(simple_weighted_machine): + """Test that weighted transitions follow the expected distribution.""" + sm = simple_weighted_machine(random_seed=12345) + + results = Counter() + num_trials = 1000 + + for _ in range(num_trials): + sm.send("go") + results[sm.current_state.id] += 1 + # Reset to initial state + sm.current_state = sm.a + + # With 75/25 split and 1000 trials, expect roughly 750/250 + # Allow for statistical variance (use generous bounds) + assert 700 <= results["b"] <= 800 + assert 200 <= results["c"] <= 300 + + +def test_weighted_idle_animations(weighted_idle_machine): + """Test game character idle animation selection.""" + sm = weighted_idle_machine(random_seed=99) + + # Trigger idle multiple times + for _ in range(10): + sm.idle() + # Return to standing + sm.finish() + + # Should have 10 animations recorded + assert len(sm.animations) == 10 + + # Count distribution + animation_counts = Counter(sm.animations) + + # With weights 70/20/10, shift_weight should be most common + # This is probabilistic, so we just check we have variety + assert "shift_weight" in animation_counts + assert len(animation_counts) >= 2 # At least 2 different animations + + +def test_zero_and_negative_weights_ignored(zero_negative_weight_machine): + """Test that zero and negative weights are ignored.""" + sm = zero_negative_weight_machine(random_seed=777) + + results = Counter() + num_trials = 100 + + for _ in range(num_trials): + sm.send("go") + results[sm.current_state.id] += 1 + # Reset to initial state + sm.current_state = sm.start + + # Only valid_a and valid_b should be reached + assert results["valid_a"] > 0 + assert results["valid_b"] > 0 + assert results["zero_weight"] == 0 + assert results["negative_weight"] == 0 + + +def test_mixed_weighted_and_unweighted(mixed_weighted_machine): + """Test that when weights exist, only weighted transitions are considered.""" + sm = mixed_weighted_machine(random_seed=555) + + results = Counter() + num_trials = 100 + + for _ in range(num_trials): + sm.send("mixed_event") + results[sm.current_state.id] += 1 + # Reset to initial state + sm.current_state = sm.start + + # Only weighted_a and weighted_b should be reached + assert results["weighted_a"] > 0 + assert results["weighted_b"] > 0 + assert results["unweighted"] == 0 # Unweighted should be ignored + + +def test_conditions_apply_to_weighted_transitions(conditional_weighted_machine): + """Test that conditions still filter weighted transitions.""" + # First test with allow=True (both transitions can be chosen) + sm_allowed = conditional_weighted_machine(allow=True, random_seed=111) + + results_allowed = Counter() + for _ in range(50): + sm_allowed.send("go") + results_allowed[sm_allowed.current_state.id] += 1 + sm_allowed.current_state = sm_allowed.start + + # Both destinations should be reachable + assert results_allowed["allowed_dest"] > 0 + assert results_allowed["blocked_dest"] > 0 + + # Now test with allow=False (first transition blocked by condition) + sm_blocked = conditional_weighted_machine(allow=False, random_seed=222) + + results_blocked = Counter() + for _ in range(50): + sm_blocked.send("go") + results_blocked[sm_blocked.current_state.id] += 1 + sm_blocked.current_state = sm_blocked.start + + # Only blocked_dest should be reachable + assert results_blocked["allowed_dest"] == 0 + assert results_blocked["blocked_dest"] == 50 + + +def test_no_weights_backward_compatibility(no_weight_machine): + """Test that machines without weights work as before.""" + sm = no_weight_machine() + + # Should transition start -> middle + sm.advance() + assert sm.current_state.id == "middle" + + # Should transition middle -> end + sm.advance() + assert sm.current_state.id == "end" + + +def test_transition_with_weight_parameter(): + """Test that Transition accepts weight parameter.""" + from statemachine.transition import Transition + + source = State("Source", initial=True) + target = State("Target") + + # Create transition with weight + transition = Transition(source, target, event="go", weight=75) + + assert transition.weight == 75 + + +def test_transition_without_weight_parameter(): + """Test that Transition works without weight (default None).""" + from statemachine.transition import Transition + + source = State("Source", initial=True) + target = State("Target") + + # Create transition without weight + transition = Transition(source, target, event="go") + + assert transition.weight is None + + +def test_statemachine_random_seed_parameter(): + """Test that StateMachine accepts random_seed parameter.""" + + class TestMachine(StateMachine): + a = State(initial=True) + b = State() + go = a.to(b) + + sm = TestMachine(random_seed=12345) + assert sm._random is not None + + +def test_pickle_state_machine_with_weights(simple_weighted_machine): + """Test that state machines with weights can be pickled and unpickled.""" + sm1 = simple_weighted_machine(random_seed=999) + + # Trigger a transition + sm1.send("go") + state_after_first = sm1.current_state.id + + # Pickle and unpickle + pickled = pickle.dumps(sm1) + sm2 = pickle.loads(pickled) + + # State should be preserved + assert sm2.current_state.id == state_after_first + + # Reset both to initial state + sm1.current_state = sm1.a + sm2.current_state = sm2.a + + # Random state should be preserved, so next transitions should match + sm1.send("go") + sm2.send("go") + + assert sm1.current_state.id == sm2.current_state.id + + +# Async Tests + + +@pytest.fixture +def async_weighted_machine(): + """Async machine with weighted transitions.""" + + class AsyncWeightedMachine(StateMachine): + start = State(initial=True) + dest_a = State() + dest_b = State() + + go = start.to(dest_a, event="go", weight=60) | start.to( + dest_b, event="go", weight=40 + ) + + async def on_enter_dest_a(self): + self.entered = "dest_a" + + async def on_enter_dest_b(self): + self.entered = "dest_b" + + return AsyncWeightedMachine + + +async def test_async_weighted_selection(async_weighted_machine): + """Test that weighted selection works with async state machines.""" + sm = async_weighted_machine(random_seed=42) + + results = Counter() + num_trials = 100 + + for _ in range(num_trials): + await sm.go() + results[sm.current_state.id] += 1 + # Reset to initial state + sm.current_state = sm.start + + # Both destinations should be reached + assert results["dest_a"] > 0 + assert results["dest_b"] > 0 + + # With 60/40 split, dest_a should be more common + assert results["dest_a"] > results["dest_b"] + + +def test_async_weighted_from_sync_context(async_weighted_machine): + """Test that async weighted machine can be used from sync context.""" + sm = async_weighted_machine(random_seed=42) + + # Should work from sync context + sm.go() + assert sm.current_state.id in ["dest_a", "dest_b"] + + +def test_weight_in_transition_repr(): + """Test that weight appears in transition repr when present.""" + from statemachine.transition import Transition + + source = State("Source", initial=True) + target = State("Target") + + transition = Transition(source, target, event="go", weight=75) + repr_str = repr(transition) + + # Should include weight in representation + assert "weight=75" in repr_str + + +def test_all_zero_weights_falls_back_to_first_match(): + """Test that when all weights are zero/negative, falls back to first match.""" + + class AllZeroWeightMachine(StateMachine): + start = State(initial=True) + first_dest = State() + second_dest = State() + + go = ( + start.to(first_dest, event="go", weight=0) + | start.to(second_dest, event="go", weight=0) + ) + + sm = AllZeroWeightMachine(random_seed=42) + + # With all zero weights, should fall back to first match behavior + for _ in range(10): + sm.send("go") + # First transition in order should be selected + assert sm.current_state.id == "first_dest" + sm.current_state = sm.start + + +def test_single_weighted_transition(): + """Test that a single weighted transition works correctly.""" + + class SingleWeightedMachine(StateMachine): + start = State(initial=True) + end = State() + + go = start.to(end, event="go", weight=100) + + sm = SingleWeightedMachine(random_seed=42) + sm.send("go") + + # Should always transition to end + assert sm.current_state.id == "end" + + +# Diagram Tests + + +def test_diagram_shows_probability_labels(simple_weighted_machine): + """Test that diagrams show probability labels on weighted transitions.""" + sm = simple_weighted_machine(random_seed=42) + + # Generate the diagram + graph = sm._graph() + dot_string = graph.to_string() + + # Check that probability labels are present + assert "[75%]" in dot_string, "Expected 75% probability label in diagram" + assert "[25%]" in dot_string, "Expected 25% probability label in diagram" + + # Check that the event name is still present + assert "go" in dot_string + + +def test_diagram_without_weights_no_probability_labels(no_weight_machine): + """Test that diagrams without weights don't show probability labels.""" + sm = no_weight_machine() + + # Generate the diagram + graph = sm._graph() + dot_string = graph.to_string() + + # Should not have percentage labels + assert "[" not in dot_string or "]" not in dot_string or "%" not in dot_string + + +def test_diagram_with_single_weighted_transition(): + """Test diagram with only one weighted transition (no probability shown).""" + + class SingleWeightMachine(StateMachine): + start = State(initial=True) + end = State() + + go = start.to(end, event="go", weight=100) + + sm = SingleWeightMachine() + graph = sm._graph() + dot_string = graph.to_string() + + # Single weighted transition should not show probability + # (no ambiguity, always 100%) + # The label should just be "go" without percentage + assert "go" in dot_string + + +def test_diagram_probability_calculation(): + """Test that diagram calculates correct probabilities for complex weights.""" + + class ComplexWeightMachine(StateMachine): + start = State(initial=True) + option_a = State() + option_b = State() + option_c = State() + + choose = ( + start.to(option_a, event="choose", weight=10) + | start.to(option_b, event="choose", weight=20) + | start.to(option_c, event="choose", weight=70) + ) + + sm = ComplexWeightMachine() + graph = sm._graph() + dot_string = graph.to_string() + + # Check that probabilities are correctly calculated + assert "[10%]" in dot_string + assert "[20%]" in dot_string + assert "[70%]" in dot_string +