From afc0f8a80d63d1aaa258f92a560bfdf9649111a3 Mon Sep 17 00:00:00 2001 From: Jan Kwakkel Date: Wed, 21 Aug 2024 19:35:25 +0200 Subject: [PATCH 1/5] seperate apply and do in AgentSet --- mesa/agent.py | 54 ++++++++++++++++++++++++++++++++++++++++----- tests/test_agent.py | 36 +++++++++++++++++++++++++----- 2 files changed, 79 insertions(+), 11 deletions(-) diff --git a/mesa/agent.py b/mesa/agent.py index e65b25f69c2..4406879c527 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -11,6 +11,7 @@ import contextlib import copy import operator +import warnings import weakref from collections.abc import Callable, Iterable, Iterator, MutableSet, Sequence from random import Random @@ -217,24 +218,66 @@ def _update(self, agents: Iterable[Agent]): return self def do( - self, method: str | Callable, *args, return_results: bool = False, **kwargs - ) -> AgentSet | list[Any]: + self, method: str | Callable, *args, **kwargs + ) -> AgentSet: """ Invoke a method or function on each agent in the AgentSet. Args: - method (str, callable): the callable to do on each agents + method (str, callable): the callable to do on each agent * in case of str, the name of the method to call on each agent. * in case of callable, the function to be called with each agent as first argument - return_results (bool, optional): If True, returns the results of the method calls; otherwise, returns the AgentSet itself. Defaults to False, so you can chain method calls. *args: Variable length argument list passed to the callable being called. **kwargs: Arbitrary keyword arguments passed to the callable being called. Returns: AgentSet | list[Any]: The results of the callable calls if return_results is True, otherwise the AgentSet itself. """ + try: + return_results = kwargs.pop("return_results") + except KeyError: + return_results = False + else: + warnings.warn("Using return_results is deprecated. Use AgenSet.do in case of return_results=False, and " + "AgentSet.apply in case of return_results=True", stacklevel=2) + + if return_results: + return self.apply(method, *args, **kwargs) + + + # we iterate over the actual weakref keys and check if weakref is alive before calling the method + if isinstance(method, str): + for agentref in self._agents.keyrefs(): + if (agent := agentref()) is not None: + getattr(agent, method)(*args, **kwargs) + else: + for agentref in self._agents.keyrefs(): + if (agent := agentref()) is not None: + method(agent, *args, **kwargs) + + return self + + + def apply( + self, method: str | Callable, *args, **kwargs + ) -> list[Any]: + """ + Invoke a method or function on each agent in the AgentSet. + + Args: + method (str, callable): the callable to apply on each agent + + * in case of str, the name of the method to call on each agent. + * in case of callable, the function to be called with each agent as first argument + + *args: Variable length argument list passed to the callable being called. + **kwargs: Arbitrary keyword arguments passed to the callable being called. + + Returns: + list[Any]: The results of the callable calls + """ # we iterate over the actual weakref keys and check if weakref is alive before calling the method if isinstance(method, str): res = [ @@ -249,7 +292,8 @@ def do( if (agent := agentref()) is not None ] - return res if return_results else self + return res + def get(self, attr_names: str | list[str]) -> list[Any]: """ diff --git a/tests/test_agent.py b/tests/test_agent.py index f4e64ce5338..99b92fbc930 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -83,12 +83,6 @@ def test_function(agent): assert all( a1 == a2.unique_id for a1, a2 in zip(agentset.get("unique_id"), agentset) ) - assert all( - a1 == a2.unique_id - for a1, a2 in zip( - agentset.do("get_unique_identifier", return_results=True), agentset - ) - ) assert agentset == agentset.do("get_unique_identifier") agentset.discard(agents[0]) @@ -276,6 +270,36 @@ def remove_function(agent): assert len(agentset) == 0 + +def test_agentset_apply_str(): + model = Model() + agents = [TestAgent(model.next_id(), model) for _ in range(10)] + agentset = AgentSet(agents, model) + + with pytest.raises(AttributeError): + agentset.do("non_existing_method") + + results = agentset.apply("get_unique_identifier") + assert all(i == entry for i, entry in zip(results, range(1, 11))) + +def test_agentset_apply_callable(): + model = Model() + agents = [TestAgent(model.next_id(), model) for _ in range(10)] + agentset = AgentSet(agents, model) + + # Test callable with non-existent function + with pytest.raises(AttributeError): + agentset.apply(lambda agent: agent.non_existing_method()) + + # tests for addition and removal in do using callables + # do iterates, so no error should be raised to change size while iterating + # related to issue #1595 + + results = agentset.apply(lambda agent: agent.unique_id) + assert all(i == entry for i, entry in zip(results, range(1, 11))) + + + def test_agentset_get_attribute(): model = Model() agents = [TestAgent(model.next_id(), model) for _ in range(10)] From f4154c183cd4823befbbb9ea093331167f08bcbd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 21 Aug 2024 17:43:56 +0000 Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mesa/agent.py | 18 +++++++----------- tests/test_agent.py | 3 +-- 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/mesa/agent.py b/mesa/agent.py index 4406879c527..191b4cd07d7 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -217,9 +217,7 @@ def _update(self, agents: Iterable[Agent]): self._agents = weakref.WeakKeyDictionary({agent: None for agent in agents}) return self - def do( - self, method: str | Callable, *args, **kwargs - ) -> AgentSet: + def do(self, method: str | Callable, *args, **kwargs) -> AgentSet: """ Invoke a method or function on each agent in the AgentSet. @@ -240,13 +238,15 @@ def do( except KeyError: return_results = False else: - warnings.warn("Using return_results is deprecated. Use AgenSet.do in case of return_results=False, and " - "AgentSet.apply in case of return_results=True", stacklevel=2) + warnings.warn( + "Using return_results is deprecated. Use AgenSet.do in case of return_results=False, and " + "AgentSet.apply in case of return_results=True", + stacklevel=2, + ) if return_results: return self.apply(method, *args, **kwargs) - # we iterate over the actual weakref keys and check if weakref is alive before calling the method if isinstance(method, str): for agentref in self._agents.keyrefs(): @@ -259,10 +259,7 @@ def do( return self - - def apply( - self, method: str | Callable, *args, **kwargs - ) -> list[Any]: + def apply(self, method: str | Callable, *args, **kwargs) -> list[Any]: """ Invoke a method or function on each agent in the AgentSet. @@ -294,7 +291,6 @@ def apply( return res - def get(self, attr_names: str | list[str]) -> list[Any]: """ Retrieve the specified attribute(s) from each agent in the AgentSet. diff --git a/tests/test_agent.py b/tests/test_agent.py index 99b92fbc930..d15f42cbd2a 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -270,7 +270,6 @@ def remove_function(agent): assert len(agentset) == 0 - def test_agentset_apply_str(): model = Model() agents = [TestAgent(model.next_id(), model) for _ in range(10)] @@ -282,6 +281,7 @@ def test_agentset_apply_str(): results = agentset.apply("get_unique_identifier") assert all(i == entry for i, entry in zip(results, range(1, 11))) + def test_agentset_apply_callable(): model = Model() agents = [TestAgent(model.next_id(), model) for _ in range(10)] @@ -299,7 +299,6 @@ def test_agentset_apply_callable(): assert all(i == entry for i, entry in zip(results, range(1, 11))) - def test_agentset_get_attribute(): model = Model() agents = [TestAgent(model.next_id(), model) for _ in range(10)] From fe4d5c1ea6debcf052d5be2a1fec4502cfa274e6 Mon Sep 17 00:00:00 2001 From: Jan Kwakkel Date: Thu, 22 Aug 2024 10:41:05 +0200 Subject: [PATCH 3/5] change apply to map --- mesa/agent.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mesa/agent.py b/mesa/agent.py index 191b4cd07d7..8c8257049f7 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -259,9 +259,9 @@ def do(self, method: str | Callable, *args, **kwargs) -> AgentSet: return self - def apply(self, method: str | Callable, *args, **kwargs) -> list[Any]: + def map(self, method: str | Callable, *args, **kwargs) -> list[Any]: """ - Invoke a method or function on each agent in the AgentSet. + Invoke a method or function on each agent in the AgentSet and return the results. Args: method (str, callable): the callable to apply on each agent From 6eb973c256fe389a9e18cd402ba68a5af5ca659a Mon Sep 17 00:00:00 2001 From: Jan Kwakkel Date: Thu, 22 Aug 2024 10:45:03 +0200 Subject: [PATCH 4/5] change map in tests --- tests/test_agent.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_agent.py b/tests/test_agent.py index d15f42cbd2a..1541a46b6c2 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -270,7 +270,7 @@ def remove_function(agent): assert len(agentset) == 0 -def test_agentset_apply_str(): +def test_agentset_map_str(): model = Model() agents = [TestAgent(model.next_id(), model) for _ in range(10)] agentset = AgentSet(agents, model) @@ -278,24 +278,24 @@ def test_agentset_apply_str(): with pytest.raises(AttributeError): agentset.do("non_existing_method") - results = agentset.apply("get_unique_identifier") + results = agentset.map("get_unique_identifier") assert all(i == entry for i, entry in zip(results, range(1, 11))) -def test_agentset_apply_callable(): +def test_agentset_map_callable(): model = Model() agents = [TestAgent(model.next_id(), model) for _ in range(10)] agentset = AgentSet(agents, model) # Test callable with non-existent function with pytest.raises(AttributeError): - agentset.apply(lambda agent: agent.non_existing_method()) + agentset.map(lambda agent: agent.non_existing_method()) # tests for addition and removal in do using callables # do iterates, so no error should be raised to change size while iterating # related to issue #1595 - results = agentset.apply(lambda agent: agent.unique_id) + results = agentset.map(lambda agent: agent.unique_id) assert all(i == entry for i, entry in zip(results, range(1, 11))) From de2863023cf773978bd52d3079348cf2ddddee4e Mon Sep 17 00:00:00 2001 From: Jan Kwakkel Date: Thu, 22 Aug 2024 10:55:56 +0200 Subject: [PATCH 5/5] change 2 last uses of apply to map --- mesa/agent.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mesa/agent.py b/mesa/agent.py index 8c8257049f7..8805673193d 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -240,12 +240,12 @@ def do(self, method: str | Callable, *args, **kwargs) -> AgentSet: else: warnings.warn( "Using return_results is deprecated. Use AgenSet.do in case of return_results=False, and " - "AgentSet.apply in case of return_results=True", + "AgentSet.map in case of return_results=True", stacklevel=2, ) if return_results: - return self.apply(method, *args, **kwargs) + return self.map(method, *args, **kwargs) # we iterate over the actual weakref keys and check if weakref is alive before calling the method if isinstance(method, str):