Skip to content

Commit 5e083ac

Browse files
authored
Add filters for circuit in simulation campaign (#311)
1 parent 45b7223 commit 5e083ac

File tree

5 files changed

+263
-3
lines changed

5 files changed

+263
-3
lines changed

app/filters/circuit.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,28 @@
11
import uuid
22
from typing import Annotated
33

4+
from fastapi_filter import with_prefix
5+
46
from app.db.model import Circuit
57
from app.dependencies.filter import FilterDepends
6-
from app.filters.common import NameFilterMixin
8+
from app.filters.base import CustomFilter
9+
from app.filters.common import IdFilterMixin, NameFilterMixin
710
from app.filters.scientific_artifact import ScientificArtifactFilter
811

912

13+
class NestedCircuitFilter(IdFilterMixin, NameFilterMixin, CustomFilter):
14+
"""Circuit filter with limited fields for nesting."""
15+
16+
scale: str | None = None
17+
scale__in: list[str] | None = None
18+
19+
build_category: str | None = None
20+
build_category__in: list[str] | None = None
21+
22+
class Constants(CustomFilter.Constants):
23+
model = Circuit
24+
25+
1026
class CircuitFilter(ScientificArtifactFilter, NameFilterMixin):
1127
atlas_id: uuid.UUID | None = None
1228
root_circuit_id: uuid.UUID | None = None
@@ -39,3 +55,4 @@ class Constants(ScientificArtifactFilter.Constants):
3955

4056

4157
CircuitFilterDep = Annotated[CircuitFilter, FilterDepends(CircuitFilter)]
58+
NestedCircuitFilterDep = FilterDepends(with_prefix("circuit", NestedCircuitFilter))

app/filters/simulation.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from app.db.model import Simulation
77
from app.dependencies.filter import FilterDepends
88
from app.filters.base import CustomFilter
9+
from app.filters.circuit import NestedCircuitFilter, NestedCircuitFilterDep
910
from app.filters.common import (
1011
ContributionFilterMixin,
1112
CreationFilterMixin,
@@ -18,6 +19,7 @@
1819
class SimulationFilterBase(NameFilterMixin, IdFilterMixin, CustomFilter):
1920
entity_id: uuid.UUID | None = None
2021
entity_id__in: list[uuid.UUID] | None = None
22+
circuit: Annotated[NestedCircuitFilter | None, NestedCircuitFilterDep] = None
2123

2224

2325
class NestedSimulationFilter(SimulationFilterBase):

app/queries/factory.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from app.db.model import (
44
Agent,
55
BrainRegion,
6+
Circuit,
67
Contribution,
78
EModel,
89
Entity,
@@ -72,6 +73,7 @@ def _get_alias[T: type[Identifiable]](db_cls: T, name: str | None = None) -> T:
7273
simulation_alias = _get_alias(Simulation)
7374
used_alias = _get_alias(Entity, "used")
7475
generated_alias = _get_alias(Entity, "generated")
76+
circuit_alias = _get_alias(Circuit)
7577

7678
name_to_facet_query_params: dict[str, FacetQueryParams] = {
7779
"agent": {
@@ -111,6 +113,7 @@ def _get_alias[T: type[Identifiable]](db_cls: T, name: str | None = None) -> T:
111113
"pre_region": {"id": pre_region_alias.id, "label": pre_region_alias.name},
112114
"post_region": {"id": post_region_alias.id, "label": post_region_alias.name},
113115
"simulation": {"id": simulation_alias.id, "label": simulation_alias.name},
116+
"simulation.circuit": {"id": circuit_alias.id, "label": circuit_alias.name},
114117
}
115118
filter_joins = {
116119
"species": lambda q: q.join(Species, db_model_class.species_id == Species.id),
@@ -180,6 +183,9 @@ def _get_alias[T: type[Identifiable]](db_cls: T, name: str | None = None) -> T:
180183
"simulation": lambda q: q.outerjoin(
181184
simulation_alias, db_model_class.id == simulation_alias.simulation_campaign_id
182185
),
186+
"simulation.circuit": lambda q: q.join(
187+
circuit_alias, simulation_alias.entity_id == circuit_alias.id
188+
),
183189
"used": lambda q: q.outerjoin(
184190
Usage, db_model_class.id == Usage.usage_activity_id
185191
).outerjoin(used_alias, Usage.usage_entity_id == used_alias.id),

app/service/simulation_campaign.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from app.db.model import (
88
Agent,
9+
Circuit,
910
Simulation,
1011
SimulationCampaign,
1112
)
@@ -84,20 +85,22 @@ def read_many(
8485
created_by_alias = aliased(Agent, flat=True)
8586
updated_by_alias = aliased(Agent, flat=True)
8687
simulation_alias = aliased(Simulation, flat=True)
87-
88+
circuit_alias = aliased(Circuit, flat=True)
8889
aliases: Aliases = {
8990
Agent: {
9091
"contribution": agent_alias,
9192
"created_by": created_by_alias,
9293
"updated_by": updated_by_alias,
9394
},
9495
Simulation: simulation_alias,
96+
Circuit: circuit_alias,
9597
}
9698
facet_keys = filter_keys = [
9799
"created_by",
98100
"updated_by",
99101
"contribution",
100102
"simulation",
103+
"simulation.circuit",
101104
]
102105
name_to_facet_query_params, filter_joins = query_params_factory(
103106
db_model_class=SimulationCampaign,

tests/test_simulation_campaign.py

Lines changed: 233 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22

3-
from app.db.model import Simulation, SimulationCampaign
3+
from app.db.model import Circuit, CircuitBuildCategory, CircuitScale, Simulation, SimulationCampaign
44
from app.db.types import EntityType
55

66
from .utils import (
@@ -125,3 +125,235 @@ def test_filtering(client, models, simulation_json_data, person_id):
125125
client.get, url=ROUTE, params={"simulation__name": simulation_json_data["name"]}
126126
).json()["data"]
127127
assert len(data) == 3
128+
129+
130+
@pytest.fixture
131+
def multiple_circuits(db, brain_atlas_id, subject_id, brain_region_id, license_id, person_id):
132+
circuits_data = [
133+
{
134+
"name": "micro-circuit-1",
135+
"description": "Micro Circuit 1",
136+
"has_morphologies": True,
137+
"has_point_neurons": False,
138+
"has_electrical_cell_models": True,
139+
"has_spines": False,
140+
"number_neurons": 100,
141+
"number_synapses": 1000,
142+
"number_connections": 50,
143+
"scale": CircuitScale.microcircuit,
144+
"build_category": CircuitBuildCategory.computational_model,
145+
"atlas_id": brain_atlas_id,
146+
"subject_id": subject_id,
147+
"brain_region_id": brain_region_id,
148+
"license_id": license_id,
149+
"authorized_public": False,
150+
"created_by_id": person_id,
151+
"updated_by_id": person_id,
152+
"authorized_project_id": PROJECT_ID,
153+
},
154+
{
155+
"name": "micro-circuit-2",
156+
"description": "Micro Circuit 2",
157+
"has_morphologies": False,
158+
"has_point_neurons": True,
159+
"has_electrical_cell_models": False,
160+
"has_spines": True,
161+
"number_neurons": 1000,
162+
"number_synapses": 10000,
163+
"number_connections": 200,
164+
"scale": CircuitScale.microcircuit,
165+
"build_category": CircuitBuildCategory.em_reconstruction,
166+
"atlas_id": brain_atlas_id,
167+
"subject_id": subject_id,
168+
"brain_region_id": brain_region_id,
169+
"license_id": license_id,
170+
"authorized_public": False,
171+
"created_by_id": person_id,
172+
"updated_by_id": person_id,
173+
"authorized_project_id": PROJECT_ID,
174+
},
175+
{
176+
"name": "pair-circuit-1",
177+
"description": "Pair Circuit 1",
178+
"has_morphologies": True,
179+
"has_point_neurons": True,
180+
"has_electrical_cell_models": True,
181+
"has_spines": True,
182+
"number_neurons": 10000,
183+
"number_synapses": 100000,
184+
"number_connections": 500,
185+
"scale": CircuitScale.pair,
186+
"build_category": CircuitBuildCategory.computational_model,
187+
"atlas_id": brain_atlas_id,
188+
"subject_id": subject_id,
189+
"brain_region_id": brain_region_id,
190+
"license_id": license_id,
191+
"authorized_public": False,
192+
"created_by_id": person_id,
193+
"updated_by_id": person_id,
194+
"authorized_project_id": PROJECT_ID,
195+
},
196+
]
197+
198+
circuits = [add_db(db, Circuit(**circuit_data)) for circuit_data in circuits_data]
199+
return circuits
200+
201+
202+
@pytest.fixture
203+
def campaigns_with_different_circuits(
204+
db, json_data, person_id, simulation_json_data, multiple_circuits
205+
):
206+
campaigns = []
207+
208+
for i, circuit in enumerate(multiple_circuits):
209+
campaign = add_db(
210+
db,
211+
MODEL(
212+
**(
213+
json_data
214+
| {
215+
"name": f"campaign-circuit-{i}",
216+
"description": f"Campaign for circuit {i}",
217+
"created_by_id": person_id,
218+
"updated_by_id": person_id,
219+
"authorized_project_id": PROJECT_ID,
220+
}
221+
)
222+
),
223+
)
224+
campaigns.append(campaign)
225+
226+
add_db(
227+
db,
228+
Simulation(
229+
**simulation_json_data
230+
| {
231+
"name": f"simulation-circuit-{i}",
232+
"simulation_campaign_id": campaign.id,
233+
"entity_id": circuit.id,
234+
"created_by_id": person_id,
235+
"updated_by_id": person_id,
236+
"authorized_project_id": PROJECT_ID,
237+
}
238+
),
239+
)
240+
241+
return campaigns
242+
243+
244+
def test_filter_by_circuit_id(client, campaigns_with_different_circuits, multiple_circuits): # noqa: ARG001
245+
first_circuit_id = str(multiple_circuits[0].id)
246+
data = assert_request(client.get, url=ROUTE, params={"circuit__id": first_circuit_id}).json()[
247+
"data"
248+
]
249+
250+
assert len(data) == 1
251+
assert data[0]["name"] == "campaign-circuit-0"
252+
253+
second_circuit_id = str(multiple_circuits[1].id)
254+
data = assert_request(client.get, url=ROUTE, params={"circuit__id": second_circuit_id}).json()[
255+
"data"
256+
]
257+
258+
assert len(data) == 1
259+
assert data[0]["name"] == "campaign-circuit-1"
260+
261+
262+
def test_filter_by_circuit_name(client, campaigns_with_different_circuits, multiple_circuits): # noqa: ARG001
263+
data = assert_request(
264+
client.get, url=ROUTE, params={"circuit__name": "micro-circuit-1"}
265+
).json()["data"]
266+
267+
assert len(data) == 1
268+
assert data[0]["name"] == "campaign-circuit-0"
269+
270+
data = assert_request(
271+
client.get, url=ROUTE, params={"circuit__name__in": "micro-circuit-2"}
272+
).json()["data"]
273+
274+
assert len(data) == 1
275+
assert data[0]["name"] == "campaign-circuit-1"
276+
277+
278+
def test_filter_by_circuit_scale(client, campaigns_with_different_circuits, multiple_circuits): # noqa: ARG001
279+
data = assert_request(
280+
client.get, url=ROUTE, params={"circuit__scale": CircuitScale.microcircuit}
281+
).json()["data"]
282+
283+
assert len(data) == 2
284+
285+
data = assert_request(
286+
client.get, url=ROUTE, params={"circuit__scale": CircuitScale.pair}
287+
).json()["data"]
288+
289+
assert len(data) == 1
290+
291+
292+
def test_filter_by_circuit_scale_empty(
293+
client,
294+
campaigns_with_different_circuits, # noqa: ARG001
295+
multiple_circuits, # noqa: ARG001
296+
):
297+
data = assert_request(
298+
client.get, url=ROUTE, params={"circuit__scale": CircuitScale.small}
299+
).json()["data"]
300+
301+
assert len(data) == 0
302+
303+
304+
def test_filter_by_circuit_scale_in(client, campaigns_with_different_circuits, multiple_circuits): # noqa: ARG001
305+
data = assert_request(
306+
client.get,
307+
url=ROUTE,
308+
params={"circuit__scale__in": [CircuitScale.microcircuit, CircuitScale.pair]},
309+
).json()["data"]
310+
311+
assert len(data) == 3
312+
campaign_names = {campaign["name"] for campaign in data}
313+
assert campaign_names == {"campaign-circuit-0", "campaign-circuit-1", "campaign-circuit-2"}
314+
315+
316+
def test_filter_by_circuit_build_category(
317+
client,
318+
campaigns_with_different_circuits, # noqa: ARG001
319+
multiple_circuits, # noqa: ARG001
320+
):
321+
data = assert_request(
322+
client.get,
323+
url=ROUTE,
324+
params={"circuit__build_category": CircuitBuildCategory.computational_model},
325+
).json()["data"]
326+
327+
assert len(data) == 2
328+
campaign_names = {campaign["name"] for campaign in data}
329+
assert campaign_names == {"campaign-circuit-0", "campaign-circuit-2"}
330+
331+
data = assert_request(
332+
client.get,
333+
url=ROUTE,
334+
params={"circuit__build_category": CircuitBuildCategory.em_reconstruction},
335+
).json()["data"]
336+
337+
assert len(data) == 1
338+
assert data[0]["name"] == "campaign-circuit-1"
339+
340+
341+
def test_filter_by_circuit_build_category_in(
342+
client,
343+
campaigns_with_different_circuits, # noqa: ARG001
344+
multiple_circuits, # noqa: ARG001
345+
):
346+
data = assert_request(
347+
client.get,
348+
url=ROUTE,
349+
params={
350+
"circuit__build_category__in": [
351+
CircuitBuildCategory.computational_model,
352+
CircuitBuildCategory.em_reconstruction,
353+
],
354+
},
355+
).json()["data"]
356+
357+
assert len(data) == 3
358+
campaign_names = {campaign["name"] for campaign in data}
359+
assert campaign_names == {"campaign-circuit-0", "campaign-circuit-1", "campaign-circuit-2"}

0 commit comments

Comments
 (0)