Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 46 additions & 6 deletions mesa/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import contextlib
import copy
import operator
import warnings
import weakref
from collections.abc import Callable, Iterable, Iterator, MutableSet, Sequence
from random import Random
Expand Down Expand Up @@ -216,25 +217,64 @@ def _update(self, agents: Iterable[Agent]):
self._agents = weakref.WeakKeyDictionary({agent: None for agent in agents})
return self

def do(
self, method: str | Callable, *args, return_results: bool = False, **kwargs
) -> AgentSet | list[Any]:
def do(self, method: str | Callable, *args, **kwargs) -> AgentSet:
"""
Invoke a method or function on each agent in the AgentSet.

Args:
method (str, callable): the callable to do on each agents
method (str, callable): the callable to do on each agent

* in case of str, the name of the method to call on each agent.
* in case of callable, the function to be called with each agent as first argument

return_results (bool, optional): If True, returns the results of the method calls; otherwise, returns the AgentSet itself. Defaults to False, so you can chain method calls.
*args: Variable length argument list passed to the callable being called.
**kwargs: Arbitrary keyword arguments passed to the callable being called.

Returns:
AgentSet | list[Any]: The results of the callable calls if return_results is True, otherwise the AgentSet itself.
"""
try:
return_results = kwargs.pop("return_results")
except KeyError:
return_results = False
else:
warnings.warn(
"Using return_results is deprecated. Use AgenSet.do in case of return_results=False, and "
"AgentSet.map in case of return_results=True",
stacklevel=2,
)

if return_results:
return self.map(method, *args, **kwargs)

# we iterate over the actual weakref keys and check if weakref is alive before calling the method
if isinstance(method, str):
for agentref in self._agents.keyrefs():
if (agent := agentref()) is not None:
getattr(agent, method)(*args, **kwargs)
else:
for agentref in self._agents.keyrefs():
if (agent := agentref()) is not None:
method(agent, *args, **kwargs)

return self

def map(self, method: str | Callable, *args, **kwargs) -> list[Any]:
"""
Invoke a method or function on each agent in the AgentSet and return the results.

Args:
method (str, callable): the callable to apply on each agent

* in case of str, the name of the method to call on each agent.
* in case of callable, the function to be called with each agent as first argument

*args: Variable length argument list passed to the callable being called.
**kwargs: Arbitrary keyword arguments passed to the callable being called.

Returns:
list[Any]: The results of the callable calls
"""
# we iterate over the actual weakref keys and check if weakref is alive before calling the method
if isinstance(method, str):
res = [
Expand All @@ -249,7 +289,7 @@ def do(
if (agent := agentref()) is not None
]

return res if return_results else self
return res

def get(self, attr_names: str | list[str]) -> list[Any]:
"""
Expand Down
35 changes: 29 additions & 6 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,6 @@ def test_function(agent):
assert all(
a1 == a2.unique_id for a1, a2 in zip(agentset.get("unique_id"), agentset)
)
assert all(
a1 == a2.unique_id
for a1, a2 in zip(
agentset.do("get_unique_identifier", return_results=True), agentset
)
)
assert agentset == agentset.do("get_unique_identifier")

agentset.discard(agents[0])
Expand Down Expand Up @@ -276,6 +270,35 @@ def remove_function(agent):
assert len(agentset) == 0


def test_agentset_map_str():
model = Model()
agents = [TestAgent(model.next_id(), model) for _ in range(10)]
agentset = AgentSet(agents, model)

with pytest.raises(AttributeError):
agentset.do("non_existing_method")

results = agentset.map("get_unique_identifier")
assert all(i == entry for i, entry in zip(results, range(1, 11)))


def test_agentset_map_callable():
model = Model()
agents = [TestAgent(model.next_id(), model) for _ in range(10)]
agentset = AgentSet(agents, model)

# Test callable with non-existent function
with pytest.raises(AttributeError):
agentset.map(lambda agent: agent.non_existing_method())

# tests for addition and removal in do using callables
# do iterates, so no error should be raised to change size while iterating
# related to issue #1595

results = agentset.map(lambda agent: agent.unique_id)
assert all(i == entry for i, entry in zip(results, range(1, 11)))


def test_agentset_get_attribute():
model = Model()
agents = [TestAgent(model.next_id(), model) for _ in range(10)]
Expand Down