From 22303a07da4748ef746ac17b2f5539ef787c9372 Mon Sep 17 00:00:00 2001 From: SiddharthBansal007 Date: Wed, 12 Nov 2025 22:23:13 +0530 Subject: [PATCH 1/4] fix #2716: missing agent type hints --- mesa/agent.py | 56 +++++++++++++++++++++++++++++++++------------------ mesa/model.py | 23 +++++++++++---------- 2 files changed, 48 insertions(+), 31 deletions(-) diff --git a/mesa/agent.py b/mesa/agent.py index 93cd6287528..412b40ab209 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -19,7 +19,7 @@ from random import Random # mypy -from typing import TYPE_CHECKING, Any, Literal, overload +from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, overload import numpy as np @@ -29,8 +29,21 @@ from mesa.model import Model from mesa.space import Position - -class Agent: + # Type variables for better static typing + M = TypeVar("M", bound=Model) + # Agent typevar bound to Agent (forward reference) for use in AgentSet etc. + A = TypeVar("A", bound="Agent") + T = TypeVar("T", bound="Agent") +else: + # At runtime, TypeVars do not need to be bound to concrete types. Provide + # plain TypeVars so code runs normally while type checkers still see the + # correct bounds. + M = TypeVar("M") + A = TypeVar("A") + T = TypeVar("T") + + +class Agent(Generic[M]): # noqa: UP046 """Base class for a model agent in Mesa. Attributes: @@ -48,7 +61,7 @@ class Agent: # so, unique_id is unique relative to a model, and counting starts from 1 _ids = defaultdict(functools.partial(itertools.count, 1)) - def __init__(self, model: Model, *args, **kwargs) -> None: + def __init__(self, model: M, *args, **kwargs) -> None: """Create a new agent. Args: @@ -62,7 +75,10 @@ def __init__(self, model: Model, *args, **kwargs) -> None: """ super().__init__(*args, **kwargs) - self.model: Model = model + # Preserve the more specific model type for static type checkers by + # typing Agent as Generic[M]. At runtime this remains the Model + # instance passed in. + self.model: M = model self.unique_id: int = next(self._ids[model]) self.pos: Position | None = None self.model.register_agent(self) @@ -85,7 +101,7 @@ def advance(self) -> None: # noqa: D102 pass @classmethod - def create_agents(cls, model: Model, n: int, *args, **kwargs) -> AgentSet[Agent]: + def create_agents(cls: type[T], model: M, n: int, *args, **kwargs) -> AgentSet[T]: """Create N agents. Args: @@ -146,7 +162,7 @@ def rng(self) -> np.random.Generator: return self.model.rng -class AgentSet(MutableSet, Sequence): +class AgentSet(Generic[A], MutableSet[A], Sequence[A]): # noqa: UP046 """A collection class that represents an ordered set of agents within an agent-based model (ABM). This class extends both MutableSet and Sequence, providing set-like functionality with order preservation and @@ -171,7 +187,7 @@ class AgentSet(MutableSet, Sequence): def __init__( self, - agents: Iterable[Agent], + agents: Iterable[A], random: Random | None = None, ): """Initializes the AgentSet with a collection of agents and a reference to the model. @@ -200,11 +216,11 @@ def __len__(self) -> int: """Return the number of agents in the AgentSet.""" return len(self._agents) - def __iter__(self) -> Iterator[Agent]: + def __iter__(self) -> Iterator[A]: """Provide an iterator over the agents in the AgentSet.""" return self._agents.keys() - def __contains__(self, agent: Agent) -> bool: + def __contains__(self, agent: A) -> bool: """Check if an agent is in the AgentSet. Can be used like `agent in agentset`.""" return agent in self._agents @@ -213,7 +229,7 @@ def select( filter_func: Callable[[Agent], bool] | None = None, at_most: int | float = float("inf"), inplace: bool = False, - agent_type: type[Agent] | None = None, + agent_type: type[A] | None = None, ) -> AgentSet: """Select a subset of agents from the AgentSet based on a filter function and/or quantity limit. @@ -256,7 +272,7 @@ def agent_generator(filter_func, agent_type, at_most): return AgentSet(agents, self.random) if not inplace else self._update(agents) - def shuffle(self, inplace: bool = False) -> AgentSet: + def shuffle(self, inplace: bool = False) -> AgentSet[A]: """Randomly shuffle the order of agents in the AgentSet. Args: @@ -307,7 +323,7 @@ def sort( else self._update(sorted_agents) ) - def _update(self, agents: Iterable[Agent]): + def _update(self, agents: Iterable[A]): """Update the AgentSet with a new set of agents. This is a private method primarily used internally by other methods like select, shuffle, and sort. @@ -315,7 +331,7 @@ def _update(self, agents: Iterable[Agent]): self._agents = weakref.WeakKeyDictionary(dict.fromkeys(agents)) return self - def do(self, method: str | Callable, *args, **kwargs) -> AgentSet: + def do(self, method: str | Callable, *args, **kwargs) -> AgentSet[A]: """Invoke a method or function on each agent in the AgentSet. Args: @@ -342,7 +358,7 @@ def do(self, method: str | Callable, *args, **kwargs) -> AgentSet: return self - def shuffle_do(self, method: str | Callable, *args, **kwargs) -> AgentSet: + def shuffle_do(self, method: str | Callable, *args, **kwargs) -> AgentSet[A]: """Shuffle the agents in the AgentSet and then invoke a method or function on each agent. It's a fast, optimized version of calling shuffle() followed by do(). @@ -488,7 +504,7 @@ def get( "should be one of 'error' or 'default'" ) - def set(self, attr_name: str, value: Any) -> AgentSet: + def set(self, attr_name: str, value: Any) -> AgentSet[A]: """Set a specified attribute to a given value for all agents in the AgentSet. Args: @@ -502,7 +518,7 @@ def set(self, attr_name: str, value: Any) -> AgentSet: setattr(agent, attr_name, value) return self - def __getitem__(self, item: int | slice) -> Agent: + def __getitem__(self, item: int | slice) -> A: """Retrieve an agent or a slice of agents from the AgentSet. Args: @@ -513,7 +529,7 @@ def __getitem__(self, item: int | slice) -> Agent: """ return list(self._agents.keys())[item] - def add(self, agent: Agent): + def add(self, agent: A): """Add an agent to the AgentSet. Args: @@ -524,7 +540,7 @@ def add(self, agent: Agent): """ self._agents[agent] = None - def discard(self, agent: Agent): + def discard(self, agent: A): """Remove an agent from the AgentSet if it exists. This method does not raise an error if the agent is not present. @@ -538,7 +554,7 @@ def discard(self, agent: Agent): with contextlib.suppress(KeyError): del self._agents[agent] - def remove(self, agent: Agent): + def remove(self, agent: A): """Remove an agent from the AgentSet. This method raises an error if the agent is not present. diff --git a/mesa/model.py b/mesa/model.py index f53d92b1633..fb615bf5363 100644 --- a/mesa/model.py +++ b/mesa/model.py @@ -12,13 +12,16 @@ from collections.abc import Sequence # mypy -from typing import Any +from typing import Any, Generic, TypeVar import numpy as np from mesa.agent import Agent, AgentSet from mesa.mesa_logging import create_module_logger, method_logger +# Type variable for agent types so Model can be parameterized by its Agent class +A = TypeVar("A", bound=Agent) + SeedLike = int | np.integer | Sequence[int] | np.random.SeedSequence RNGLike = np.random.Generator | np.random.BitGenerator @@ -26,7 +29,7 @@ _mesa_logger = create_module_logger() -class Model: +class Model(Generic[A]): # noqa: UP046 """Base class for models in the Mesa ABM library. This class serves as a foundational structure for creating agent-based models. @@ -107,11 +110,9 @@ def __init__( # setup agent registration data structures self._agents = {} # the hard references to all agents in the model self._agents_by_type: dict[ - type[Agent], AgentSet - ] = {} # a dict with an agentset for each class of agents - self._all_agents = AgentSet( - [], random=self.random - ) # an agenset with all agents + type[A], AgentSet[A] + ] = {} # agentset per agent class + self._all_agents: AgentSet[A] = AgentSet([], random=self.random) def _wrapped_step(self, *args: Any, **kwargs: Any) -> None: """Automatically increments time and steps after calling the user's step method.""" @@ -122,7 +123,7 @@ def _wrapped_step(self, *args: Any, **kwargs: Any) -> None: self._user_step(*args, **kwargs) @property - def agents(self) -> AgentSet: + def agents(self) -> AgentSet[A]: """Provides an AgentSet of all agents in the model, combining agents from all types.""" return self._all_agents @@ -140,11 +141,11 @@ def agent_types(self) -> list[type]: return list(self._agents_by_type.keys()) @property - def agents_by_type(self) -> dict[type[Agent], AgentSet]: + def agents_by_type(self) -> dict[type[A], AgentSet[A]]: """A dictionary where the keys are agent types and the values are the corresponding AgentSets.""" return self._agents_by_type - def register_agent(self, agent): + def register_agent(self, agent: A): """Register the agent with the model. Args: @@ -174,7 +175,7 @@ def register_agent(self, agent): f"registered {agent.__class__.__name__} with agent_id {agent.unique_id}" ) - def deregister_agent(self, agent): + def deregister_agent(self, agent: A): """Deregister the agent with the model. Args: From b4ed50aadd67a2e9bf5732931affcc92f675517c Mon Sep 17 00:00:00 2001 From: SiddharthBansal007 Date: Thu, 13 Nov 2025 01:13:51 +0530 Subject: [PATCH 2/4] Addressed the comment and updated the codebase --- mesa/agent.py | 32 +++++++++----------------------- mesa/model.py | 8 ++++---- mesa/space.py | 4 +--- 3 files changed, 14 insertions(+), 30 deletions(-) diff --git a/mesa/agent.py b/mesa/agent.py index 412b40ab209..66ef659d6ff 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -19,31 +19,16 @@ from random import Random # mypy -from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, overload +from typing import TYPE_CHECKING, Any, Literal, overload import numpy as np if TYPE_CHECKING: - # We ensure that these are not imported during runtime to prevent cyclic - # dependency. from mesa.model import Model from mesa.space import Position - # Type variables for better static typing - M = TypeVar("M", bound=Model) - # Agent typevar bound to Agent (forward reference) for use in AgentSet etc. - A = TypeVar("A", bound="Agent") - T = TypeVar("T", bound="Agent") -else: - # At runtime, TypeVars do not need to be bound to concrete types. Provide - # plain TypeVars so code runs normally while type checkers still see the - # correct bounds. - M = TypeVar("M") - A = TypeVar("A") - T = TypeVar("T") - - -class Agent(Generic[M]): # noqa: UP046 + +class Agent[M: Model]: """Base class for a model agent in Mesa. Attributes: @@ -75,9 +60,8 @@ def __init__(self, model: M, *args, **kwargs) -> None: """ super().__init__(*args, **kwargs) - # Preserve the more specific model type for static type checkers by - # typing Agent as Generic[M]. At runtime this remains the Model - # instance passed in. + # Preserve the more specific model type for static type checkers. + # At runtime this remains the Model instance passed in. self.model: M = model self.unique_id: int = next(self._ids[model]) self.pos: Position | None = None @@ -101,7 +85,9 @@ def advance(self) -> None: # noqa: D102 pass @classmethod - def create_agents(cls: type[T], model: M, n: int, *args, **kwargs) -> AgentSet[T]: + def create_agents[T: Agent]( + cls: type[T], model: Model, n: int, *args, **kwargs + ) -> AgentSet[T]: """Create N agents. Args: @@ -162,7 +148,7 @@ def rng(self) -> np.random.Generator: return self.model.rng -class AgentSet(Generic[A], MutableSet[A], Sequence[A]): # noqa: UP046 +class AgentSet[A: Agent](MutableSet[A], Sequence[A]): """A collection class that represents an ordered set of agents within an agent-based model (ABM). This class extends both MutableSet and Sequence, providing set-like functionality with order preservation and diff --git a/mesa/model.py b/mesa/model.py index fb615bf5363..872d48df24c 100644 --- a/mesa/model.py +++ b/mesa/model.py @@ -12,15 +12,15 @@ from collections.abc import Sequence # mypy -from typing import Any, Generic, TypeVar +from typing import TYPE_CHECKING, Any import numpy as np from mesa.agent import Agent, AgentSet from mesa.mesa_logging import create_module_logger, method_logger -# Type variable for agent types so Model can be parameterized by its Agent class -A = TypeVar("A", bound=Agent) +if TYPE_CHECKING: + pass SeedLike = int | np.integer | Sequence[int] | np.random.SeedSequence RNGLike = np.random.Generator | np.random.BitGenerator @@ -29,7 +29,7 @@ _mesa_logger = create_module_logger() -class Model(Generic[A]): # noqa: UP046 +class Model[A: Agent]: """Base class for models in the Mesa ABM library. This class serves as a foundational structure for creating agent-based models. diff --git a/mesa/space.py b/mesa/space.py index c9cd04f4223..5d7a899e826 100644 --- a/mesa/space.py +++ b/mesa/space.py @@ -33,7 +33,7 @@ import warnings from collections.abc import Callable, Iterable, Iterator, Sequence from numbers import Real -from typing import Any, TypeVar, cast, overload +from typing import Any, cast, overload from warnings import warn with contextlib.suppress(ImportError): @@ -58,8 +58,6 @@ GridContent = Agent | None MultiGridContent = list[Agent] -F = TypeVar("F", bound=Callable[..., Any]) - def accept_tuple_argument[F: Callable[..., Any]](wrapped_function: F) -> F: """Decorator to allow grid methods that take a list of (x, y) coord tuples to also handle a single position. From 61f7e1682d979f8402c547d5962e38802c949bcb Mon Sep 17 00:00:00 2001 From: SiddharthBansal007 Date: Thu, 13 Nov 2025 15:54:39 +0530 Subject: [PATCH 3/4] Addressed the review and updated the comments --- mesa/agent.py | 3 +-- mesa/model.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/mesa/agent.py b/mesa/agent.py index 66ef659d6ff..80fb2b438bd 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -60,8 +60,7 @@ def __init__(self, model: M, *args, **kwargs) -> None: """ super().__init__(*args, **kwargs) - # Preserve the more specific model type for static type checkers. - # At runtime this remains the Model instance passed in. + self.model: M = model self.unique_id: int = next(self._ids[model]) self.pos: Position | None = None diff --git a/mesa/model.py b/mesa/model.py index 872d48df24c..cca6a10a7e3 100644 --- a/mesa/model.py +++ b/mesa/model.py @@ -111,8 +111,8 @@ def __init__( self._agents = {} # the hard references to all agents in the model self._agents_by_type: dict[ type[A], AgentSet[A] - ] = {} # agentset per agent class - self._all_agents: AgentSet[A] = AgentSet([], random=self.random) + ] = {} # a dict with an agentset for each class of agents + self._all_agents: AgentSet[A] = AgentSet([], random=self.random) # an agenset with all agents def _wrapped_step(self, *args: Any, **kwargs: Any) -> None: """Automatically increments time and steps after calling the user's step method.""" From 129f3c25c061a27bdd4b1f338548313bf2a9d571 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 13 Nov 2025 10:25:01 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mesa/agent.py | 1 - mesa/model.py | 4 +++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/mesa/agent.py b/mesa/agent.py index 80fb2b438bd..e8bf2bc8c71 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -60,7 +60,6 @@ def __init__(self, model: M, *args, **kwargs) -> None: """ super().__init__(*args, **kwargs) - self.model: M = model self.unique_id: int = next(self._ids[model]) self.pos: Position | None = None diff --git a/mesa/model.py b/mesa/model.py index cca6a10a7e3..c49403e9f41 100644 --- a/mesa/model.py +++ b/mesa/model.py @@ -112,7 +112,9 @@ def __init__( self._agents_by_type: dict[ type[A], AgentSet[A] ] = {} # a dict with an agentset for each class of agents - self._all_agents: AgentSet[A] = AgentSet([], random=self.random) # an agenset with all agents + self._all_agents: AgentSet[A] = AgentSet( + [], random=self.random + ) # an agenset with all agents def _wrapped_step(self, *args: Any, **kwargs: Any) -> None: """Automatically increments time and steps after calling the user's step method."""