diff --git a/benchmarks/BoltzmannWealth/boltzmann_wealth.py b/benchmarks/BoltzmannWealth/boltzmann_wealth.py index 9faffa2f0c3..f6b1bf3d6f9 100644 --- a/benchmarks/BoltzmannWealth/boltzmann_wealth.py +++ b/benchmarks/BoltzmannWealth/boltzmann_wealth.py @@ -38,7 +38,6 @@ def __init__(self, seed=None, n=100, width=10, height=10): self.datacollector.collect(self) def step(self): - self._advance_time() self.agents.shuffle().do("step") # collect data self.datacollector.collect(self) diff --git a/mesa/batchrunner.py b/mesa/batchrunner.py index 80ba23cafdf..342d3196707 100644 --- a/mesa/batchrunner.py +++ b/mesa/batchrunner.py @@ -132,14 +132,14 @@ def _model_run_func( """ run_id, iteration, kwargs = run model = model_cls(**kwargs) - while model.running and model._steps <= max_steps: + while model.running and model.steps <= max_steps: model.step() data = [] - steps = list(range(0, model._steps, data_collection_period)) - if not steps or steps[-1] != model._steps - 1: - steps.append(model._steps - 1) + steps = list(range(0, model.steps, data_collection_period)) + if not steps or steps[-1] != model.steps - 1: + steps.append(model.steps - 1) for step in steps: model_data, all_agents_data = _collect_data(model, step) diff --git a/mesa/datacollection.py b/mesa/datacollection.py index 712482af8e2..5a82e71c20d 100644 --- a/mesa/datacollection.py +++ b/mesa/datacollection.py @@ -180,7 +180,7 @@ def _record_agents(self, model): rep_funcs = self.agent_reporters.values() def get_reports(agent): - _prefix = (agent.model._steps, agent.unique_id) + _prefix = (agent.model.steps, agent.unique_id) reports = tuple(rep(agent) for rep in rep_funcs) return _prefix + reports @@ -216,7 +216,7 @@ def collect(self, model): if self.agent_reporters: agent_records = self._record_agents(model) - self._agent_records[model._steps] = list(agent_records) + self._agent_records[model.steps] = list(agent_records) def add_table_row(self, table_name, row, ignore_missing=False): """Add a row dictionary to a specific table. diff --git a/mesa/model.py b/mesa/model.py index 6261c79b2d7..9387e6d1031 100644 --- a/mesa/model.py +++ b/mesa/model.py @@ -17,8 +17,6 @@ from mesa.agent import Agent, AgentSet from mesa.datacollection import DataCollector -TimeT = float | int - class Model: """Base class for models in the Mesa ABM library. @@ -35,6 +33,8 @@ class Model: Properties: agents: An AgentSet containing all agents in the model agent_types: A list of different agent types present in the model. + steps: An integer representing the number of steps the model has taken. + It increases automatically at the start of each step() call. Methods: get_agents_of_type: Returns an AgentSet of agents of the specified type. @@ -62,10 +62,6 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Any: # advance. obj._seed = random.random() obj.random = random.Random(obj._seed) - - # TODO: Remove these 2 lines just before Mesa 3.0 - obj._steps = 0 - obj._time = 0 return obj def __init__(self, *args: Any, **kwargs: Any) -> None: @@ -77,11 +73,20 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self.running = True self.schedule = None self.current_id = 0 + self.steps: int = 0 self._setup_agent_registration() - self._steps: int = 0 - self._time: TimeT = 0 # the model's clock + # Wrap the user-defined step method + self._user_step = self.step + self.step = self._wrapped_step + + def _wrapped_step(self, *args: Any, **kwargs: Any) -> None: + """Automatically increments time and steps after calling the user's step method.""" + # Automatically increment time and step counters + self.steps += 1 + # Call the original user-defined step method + self._user_step(*args, **kwargs) @property def agents(self) -> AgentSet: @@ -180,11 +185,6 @@ def run_model(self) -> None: def step(self) -> None: """A single step. Fill in here.""" - def _advance_time(self, deltat: TimeT = 1): - """Increment the model's steps counter and clock.""" - self._steps += 1 - self._time += deltat - def next_id(self) -> int: """Return the next unique ID for agents, increment current_id""" self.current_id += 1 diff --git a/mesa/time.py b/mesa/time.py index e47563b4880..f50aa89b478 100644 --- a/mesa/time.py +++ b/mesa/time.py @@ -69,8 +69,6 @@ def __init__(self, model: Model, agents: Iterable[Agent] | None = None) -> None: self.model = model self.steps = 0 self.time: TimeT = 0 - self._original_step = self.step - self.step = self._wrapped_step if agents is None: agents = [] @@ -113,11 +111,6 @@ def step(self) -> None: self.steps += 1 self.time += 1 - def _wrapped_step(self): - """Wrapper for the step method to include time and step updating.""" - self._original_step() - self.model._advance_time() - def get_agent_count(self) -> int: """Returns the current number of agents in the queue.""" return len(self._agents) diff --git a/mesa/visualization/solara_viz.py b/mesa/visualization/solara_viz.py index 8ae53cfd314..eb05ff8ec0a 100644 --- a/mesa/visualization/solara_viz.py +++ b/mesa/visualization/solara_viz.py @@ -244,7 +244,7 @@ def do_step(): """Advance the model by one step.""" model.step() previous_step.value = current_step.value - current_step.value = model._steps + current_step.value = model.steps def do_play(): """Run the model continuously.""" diff --git a/tests/test_batch_run.py b/tests/test_batch_run.py index f5ddb0ae9f3..2de7c44735f 100644 --- a/tests/test_batch_run.py +++ b/tests/test_batch_run.py @@ -83,8 +83,8 @@ def get_local_model_param(self): return 42 def step(self): - self.datacollector.collect(self) self.schedule.step() + self.datacollector.collect(self) def test_batch_run(): diff --git a/tests/test_examples.py b/tests/test_examples.py index e5c0381f065..25b8e0e07b9 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -66,3 +66,4 @@ def test_examples(self): model = model_class() for _ in range(10): model.step() + self.assertEqual(model.steps, 10) diff --git a/tests/test_model.py b/tests/test_model.py index 874d45f935f..837d6b0049a 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -9,7 +9,9 @@ def test_model_set_up(): assert model.current_id == 0 assert model.current_id + 1 == model.next_id() assert model.current_id == 1 + assert model.steps == 0 model.step() + assert model.steps == 1 def test_running(): @@ -18,12 +20,12 @@ class TestModel(Model): def step(self): """Increase steps until 10.""" - self.steps += 1 if self.steps == 10: self.running = False model = TestModel() model.run_model() + assert model.steps == 10 def test_seed(seed=23):