diff --git a/mesa/examples/advanced/wolf_sheep/app.py b/mesa/examples/advanced/wolf_sheep/app.py index 94261021b6a..fa66d2bf11e 100644 --- a/mesa/examples/advanced/wolf_sheep/app.py +++ b/mesa/examples/advanced/wolf_sheep/app.py @@ -1,5 +1,6 @@ from mesa.examples.advanced.wolf_sheep.agents import GrassPatch, Sheep, Wolf from mesa.examples.advanced.wolf_sheep.model import WolfSheep +from mesa.experimental.devs import ABMSimulator from mesa.visualization import ( Slider, SolaraViz, @@ -36,7 +37,11 @@ def wolf_sheep_portrayal(agent): model_params = { - # The following line is an example to showcase StaticText. + "seed": { + "type": "InputText", + "value": 42, + "label": "Random Seed", + }, "grass": { "type": "Select", "value": True, @@ -59,26 +64,32 @@ def wolf_sheep_portrayal(agent): } -def post_process(ax): +def post_process_space(ax): ax.set_aspect("equal") ax.set_xticks([]) ax.set_yticks([]) +def post_process_lines(ax): + ax.legend(loc="center left", bbox_to_anchor=(1, 0.9)) + + space_component = make_space_component( - wolf_sheep_portrayal, draw_grid=False, post_process=post_process + wolf_sheep_portrayal, draw_grid=False, post_process=post_process_space ) lineplot_component = make_plot_component( - {"Wolves": "tab:orange", "Sheep": "tab:cyan", "Grass": "tab:green"} + {"Wolves": "tab:orange", "Sheep": "tab:cyan", "Grass": "tab:green"}, + post_process=post_process_lines, ) -model = WolfSheep(grass=True) - +simulator = ABMSimulator() +model = WolfSheep(simulator, grass=True) page = SolaraViz( model, components=[space_component, lineplot_component], model_params=model_params, name="Wolf Sheep", + simulator=simulator, ) page # noqa diff --git a/mesa/examples/advanced/wolf_sheep/model.py b/mesa/examples/advanced/wolf_sheep/model.py index 6f8887d0491..cc6ec6acc9f 100644 --- a/mesa/examples/advanced/wolf_sheep/model.py +++ b/mesa/examples/advanced/wolf_sheep/model.py @@ -30,8 +30,8 @@ class WolfSheep(Model): def __init__( self, - height=20, width=20, + height=20, initial_sheep=100, initial_wolves=50, sheep_reproduce=0.04, diff --git a/mesa/visualization/solara_viz.py b/mesa/visualization/solara_viz.py index 023d449faf2..9819e1c94e1 100644 --- a/mesa/visualization/solara_viz.py +++ b/mesa/visualization/solara_viz.py @@ -32,6 +32,7 @@ import solara import mesa.visualization.components.altair_components as components_altair +from mesa.experimental.devs.simulator import Simulator from mesa.visualization.user_param import Slider from mesa.visualization.utils import force_update, update_counter @@ -42,10 +43,12 @@ @solara.component def SolaraViz( model: Model | solara.Reactive[Model], + *, components: list[reacton.core.Component] | list[Callable[[Model], reacton.core.Component]] | Literal["default"] = "default", play_interval: int = 100, + simulator: Simulator | None = None, model_params=None, name: str | None = None, ): @@ -65,6 +68,7 @@ def SolaraViz( Defaults to "default", which uses the default Altair space visualization. play_interval (int, optional): Interval for playing the model steps in milliseconds. This controls the speed of the model's automatic stepping. Defaults to 100 ms. + simulator: A simulator that controls the model (optional) model_params (dict, optional): Parameters for (re-)instantiating a model. Can include user-adjustable parameters and fixed parameters. Defaults to None. name (str | None, optional): Name of the visualization. Defaults to the models class name. @@ -92,21 +96,6 @@ def SolaraViz( if not isinstance(model, solara.Reactive): model = solara.use_reactive(model) # noqa: SH102, RUF100 - def connect_to_model(): - # Patch the step function to force updates - original_step = model.value.step - - def step(): - original_step() - force_update() - - model.value.step = step - # Add a trigger to model itself - model.value.force_update = force_update - force_update() - - solara.use_effect(connect_to_model, [model.value]) - # set up reactive model_parameters shared by ModelCreator and ModelController reactive_model_parameters = solara.use_reactive({}) @@ -115,11 +104,19 @@ def step(): with solara.Sidebar(), solara.Column(): with solara.Card("Controls"): - ModelController( - model, - model_parameters=reactive_model_parameters, - play_interval=play_interval, - ) + if not isinstance(simulator, Simulator): + ModelController( + model, + model_parameters=reactive_model_parameters, + play_interval=play_interval, + ) + else: + SimulatorController( + model, + simulator, + model_parameters=reactive_model_parameters, + play_interval=play_interval, + ) with solara.Card("Model Parameters"): ModelCreator( model, model_params, model_parameters=reactive_model_parameters @@ -207,6 +204,7 @@ def do_step(): """Advance the model by one step.""" model.value.step() running.value = model.value.running + force_update() def do_reset(): """Reset the model to its initial state.""" @@ -234,6 +232,73 @@ def do_play_pause(): ) +@solara.component +def SimulatorController( + model: solara.Reactive[Model], + simulator, + *, + model_parameters: dict | solara.Reactive[dict] = None, + play_interval: int = 100, +): + """Create controls for model execution (step, play, pause, reset). + + Args: + model: Reactive model instance + simulator: Simulator instance + model_parameters: Reactive parameters for (re-)instantiating a model. + play_interval: Interval for playing the model steps in milliseconds. + + """ + playing = solara.use_reactive(False) + running = solara.use_reactive(True) + if model_parameters is None: + model_parameters = {} + model_parameters = solara.use_reactive(model_parameters) + + async def step(): + while playing.value and running.value: + await asyncio.sleep(play_interval / 1000) + do_step() + + solara.lab.use_task( + step, dependencies=[playing.value, running.value], prefer_threaded=False + ) + + def do_step(): + """Advance the model by one step.""" + simulator.run_for(1) + running.value = model.value.running + force_update() + + def do_reset(): + """Reset the model to its initial state.""" + playing.value = False + running.value = True + simulator.reset() + model.value = model.value = model.value.__class__( + simulator, **model_parameters.value + ) + + def do_play_pause(): + """Toggle play/pause.""" + playing.value = not playing.value + + with solara.Row(justify="space-between"): + solara.Button(label="Reset", color="primary", on_click=do_reset) + solara.Button( + label="▶" if not playing.value else "❚❚", + color="primary", + on_click=do_play_pause, + disabled=not running.value, + ) + solara.Button( + label="Step", + color="primary", + on_click=do_step, + disabled=playing.value or not running.value, + ) + + def split_model_params(model_params): """Split model parameters into user-adjustable and fixed parameters. @@ -324,9 +389,7 @@ def ModelCreator( } def on_change(name, value): - new_model_parameters = {**model_parameters.value, name: value} - model.value = model.value.__class__(**new_model_parameters) - model_parameters.value = new_model_parameters + model_parameters.value = {**model_parameters.value, name: value} UserInputs(user_params, on_change=on_change)