Skip to content

Commit d4d11e6

Browse files
committed
NetworkGrid: modify get_neighbors and create get_neighborhood
1 parent 238c2c0 commit d4d11e6

File tree

4 files changed

+19
-13
lines changed

4 files changed

+19
-13
lines changed

examples/boltzmann_wealth_model_network/boltzmann_wealth_model_network/model.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def __init__(self, unique_id, model):
5757
def move(self):
5858
possible_steps = [
5959
node
60-
for node in self.model.grid.get_neighbors(self.pos, include_center=False)
60+
for node in self.model.grid.get_neighborhood(self.pos, include_center=False)
6161
if self.model.grid.is_cell_empty(node)
6262
]
6363
if len(possible_steps) > 0:
@@ -66,8 +66,7 @@ def move(self):
6666

6767
def give_money(self):
6868

69-
neighbors_nodes = self.model.grid.get_neighbors(self.pos, include_center=False)
70-
neighbors = self.model.grid.get_cell_list_contents(neighbors_nodes)
69+
neighbors = self.model.grid.get_neighbors(self.pos, include_center=False)
7170
if len(neighbors) > 0:
7271
other = self.random.choice(neighbors)
7372
other.wealth += 1

examples/virus_on_network/virus_on_network/model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,9 @@ def __init__(
124124
self.gain_resistance_chance = gain_resistance_chance
125125

126126
def try_to_infect_neighbors(self):
127-
neighbors_nodes = self.model.grid.get_neighbors(self.pos, include_center=False)
127+
neighbors_nodes = self.model.grid.get_neighborhood(
128+
self.pos, include_center=False
129+
)
128130
susceptible_neighbors = [
129131
agent
130132
for agent in self.model.grid.get_cell_list_contents(neighbors_nodes)

mesa/space.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1015,12 +1015,17 @@ def place_agent(self, agent: Agent, node_id: int) -> None:
10151015
self.G.nodes[node_id]["agent"].append(agent)
10161016
agent.pos = node_id
10171017

1018-
def get_neighbors(self, node_id: int, include_center: bool = False) -> list[int]:
1018+
def get_neighborhood(self, node_id: int, include_center: bool = False) -> list[int]:
10191019
"""Get all adjacent nodes"""
1020-
neighbors = list(self.G.neighbors(node_id))
1020+
neighborhood = list(self.G.neighbors(node_id))
10211021
if include_center:
1022-
neighbors.append(node_id)
1023-
return neighbors
1022+
neighborhood.append(node_id)
1023+
return neighborhood
1024+
1025+
def get_neighbors(self, node_id: int, include_center: bool = False) -> list[Agent]:
1026+
"""Get all agents in adjacent nodes."""
1027+
neighborhood = self.get_neighborhood(node_id, include_center)
1028+
return self.get_cell_list_contents(neighborhood)
10241029

10251030
def move_agent(self, agent: Agent, node_id: int) -> None:
10261031
"""Move an agent from its current node to a new node."""
@@ -1046,7 +1051,7 @@ def get_cell_list_contents(self, cell_list: list[int]) -> list[Agent]:
10461051
for node_id in cell_list
10471052
if not self.is_cell_empty(node_id)
10481053
]
1049-
return [item for sublist in list_of_lists for item in sublist]
1054+
return [agent for sublist in list_of_lists for agent in sublist]
10501055

10511056
def get_all_cell_contents(self) -> list[Agent]:
10521057
"""Returns a list of all the agents in the network."""

tests/test_space.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -358,11 +358,11 @@ def test_agent_positions(self):
358358

359359
def test_get_neighbors(self):
360360
assert (
361-
len(self.space.get_neighbors(0, include_center=True))
361+
len(self.space.get_neighborhood(0, include_center=True))
362362
== TestSingleNetworkGrid.GRAPH_SIZE
363363
)
364364
assert (
365-
len(self.space.get_neighbors(0, include_center=False))
365+
len(self.space.get_neighborhood(0, include_center=False))
366366
== TestSingleNetworkGrid.GRAPH_SIZE - 1
367367
)
368368

@@ -433,11 +433,11 @@ def test_agent_positions(self):
433433

434434
def test_get_neighbors(self):
435435
assert (
436-
len(self.space.get_neighbors(0, include_center=True))
436+
len(self.space.get_neighborhood(0, include_center=True))
437437
== TestMultipleNetworkGrid.GRAPH_SIZE
438438
)
439439
assert (
440-
len(self.space.get_neighbors(0, include_center=False))
440+
len(self.space.get_neighborhood(0, include_center=False))
441441
== TestMultipleNetworkGrid.GRAPH_SIZE - 1
442442
)
443443

0 commit comments

Comments
 (0)