Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion app/filters/circuit.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,28 @@
import uuid
from typing import Annotated

from fastapi_filter import with_prefix

from app.db.model import Circuit
from app.dependencies.filter import FilterDepends
from app.filters.common import NameFilterMixin
from app.filters.base import CustomFilter
from app.filters.common import IdFilterMixin, NameFilterMixin
from app.filters.scientific_artifact import ScientificArtifactFilter


class NestedCircuitFilter(IdFilterMixin, NameFilterMixin, CustomFilter):
"""Circuit filter with limited fields for nesting."""

scale: str | None = None
scale__in: list[str] | None = None

build_category: str | None = None
build_category__in: list[str] | None = None

class Constants(CustomFilter.Constants):
model = Circuit


class CircuitFilter(ScientificArtifactFilter, NameFilterMixin):
atlas_id: uuid.UUID | None = None
root_circuit_id: uuid.UUID | None = None
Expand Down Expand Up @@ -39,3 +55,4 @@ class Constants(ScientificArtifactFilter.Constants):


CircuitFilterDep = Annotated[CircuitFilter, FilterDepends(CircuitFilter)]
NestedCircuitFilterDep = FilterDepends(with_prefix("circuit", NestedCircuitFilter))
2 changes: 2 additions & 0 deletions app/filters/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from app.db.model import Simulation
from app.dependencies.filter import FilterDepends
from app.filters.base import CustomFilter
from app.filters.circuit import NestedCircuitFilter, NestedCircuitFilterDep
from app.filters.common import (
ContributionFilterMixin,
CreationFilterMixin,
Expand All @@ -18,6 +19,7 @@
class SimulationFilterBase(NameFilterMixin, IdFilterMixin, CustomFilter):
entity_id: uuid.UUID | None = None
entity_id__in: list[uuid.UUID] | None = None
circuit: Annotated[NestedCircuitFilter | None, NestedCircuitFilterDep] = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although I am ok with this workaround, I am afraid that someone will decide to set a different entity in entity_id breaking this endpoint. On the other hand we control what is added in Simulation's entity_id, so we just need to be careful. However, people forget..

@GianlucaFicarelli wdyt?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can modify the filter when we'll allow different types of entities, but do we already know which ones to expect? Not sure how easy is to use a sort of polymorphism here, or if we need different filters.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IDK, probably other simulations (single neuron, synaptome) will be using the same logic in the future



class NestedSimulationFilter(SimulationFilterBase):
Expand Down
6 changes: 6 additions & 0 deletions app/queries/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from app.db.model import (
Agent,
BrainRegion,
Circuit,
Contribution,
EModel,
Entity,
Expand Down Expand Up @@ -72,6 +73,7 @@ def _get_alias[T: type[Identifiable]](db_cls: T, name: str | None = None) -> T:
simulation_alias = _get_alias(Simulation)
used_alias = _get_alias(Entity, "used")
generated_alias = _get_alias(Entity, "generated")
circuit_alias = _get_alias(Circuit)

name_to_facet_query_params: dict[str, FacetQueryParams] = {
"agent": {
Expand Down Expand Up @@ -111,6 +113,7 @@ def _get_alias[T: type[Identifiable]](db_cls: T, name: str | None = None) -> T:
"pre_region": {"id": pre_region_alias.id, "label": pre_region_alias.name},
"post_region": {"id": post_region_alias.id, "label": post_region_alias.name},
"simulation": {"id": simulation_alias.id, "label": simulation_alias.name},
"simulation.circuit": {"id": circuit_alias.id, "label": circuit_alias.name},
}
filter_joins = {
"species": lambda q: q.join(Species, db_model_class.species_id == Species.id),
Expand Down Expand Up @@ -180,6 +183,9 @@ def _get_alias[T: type[Identifiable]](db_cls: T, name: str | None = None) -> T:
"simulation": lambda q: q.outerjoin(
simulation_alias, db_model_class.id == simulation_alias.simulation_campaign_id
),
"simulation.circuit": lambda q: q.join(
circuit_alias, simulation_alias.entity_id == circuit_alias.id
),
"used": lambda q: q.outerjoin(
Usage, db_model_class.id == Usage.usage_activity_id
).outerjoin(used_alias, Usage.usage_entity_id == used_alias.id),
Expand Down
5 changes: 4 additions & 1 deletion app/service/simulation_campaign.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from app.db.model import (
Agent,
Circuit,
Simulation,
SimulationCampaign,
)
Expand Down Expand Up @@ -84,20 +85,22 @@ def read_many(
created_by_alias = aliased(Agent, flat=True)
updated_by_alias = aliased(Agent, flat=True)
simulation_alias = aliased(Simulation, flat=True)

circuit_alias = aliased(Circuit, flat=True)
aliases: Aliases = {
Agent: {
"contribution": agent_alias,
"created_by": created_by_alias,
"updated_by": updated_by_alias,
},
Simulation: simulation_alias,
Circuit: circuit_alias,
}
facet_keys = filter_keys = [
"created_by",
"updated_by",
"contribution",
"simulation",
"simulation.circuit",
]
name_to_facet_query_params, filter_joins = query_params_factory(
db_model_class=SimulationCampaign,
Expand Down
234 changes: 233 additions & 1 deletion tests/test_simulation_campaign.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from app.db.model import Simulation, SimulationCampaign
from app.db.model import Circuit, CircuitBuildCategory, CircuitScale, Simulation, SimulationCampaign
from app.db.types import EntityType

from .utils import (
Expand Down Expand Up @@ -125,3 +125,235 @@ def test_filtering(client, models, simulation_json_data, person_id):
client.get, url=ROUTE, params={"simulation__name": simulation_json_data["name"]}
).json()["data"]
assert len(data) == 3


@pytest.fixture
def multiple_circuits(db, brain_atlas_id, subject_id, brain_region_id, license_id, person_id):
circuits_data = [
{
"name": "micro-circuit-1",
"description": "Micro Circuit 1",
"has_morphologies": True,
"has_point_neurons": False,
"has_electrical_cell_models": True,
"has_spines": False,
"number_neurons": 100,
"number_synapses": 1000,
"number_connections": 50,
"scale": CircuitScale.microcircuit,
"build_category": CircuitBuildCategory.computational_model,
"atlas_id": brain_atlas_id,
"subject_id": subject_id,
"brain_region_id": brain_region_id,
"license_id": license_id,
"authorized_public": False,
"created_by_id": person_id,
"updated_by_id": person_id,
"authorized_project_id": PROJECT_ID,
},
{
"name": "micro-circuit-2",
"description": "Micro Circuit 2",
"has_morphologies": False,
"has_point_neurons": True,
"has_electrical_cell_models": False,
"has_spines": True,
"number_neurons": 1000,
"number_synapses": 10000,
"number_connections": 200,
"scale": CircuitScale.microcircuit,
"build_category": CircuitBuildCategory.em_reconstruction,
"atlas_id": brain_atlas_id,
"subject_id": subject_id,
"brain_region_id": brain_region_id,
"license_id": license_id,
"authorized_public": False,
"created_by_id": person_id,
"updated_by_id": person_id,
"authorized_project_id": PROJECT_ID,
},
{
"name": "pair-circuit-1",
"description": "Pair Circuit 1",
"has_morphologies": True,
"has_point_neurons": True,
"has_electrical_cell_models": True,
"has_spines": True,
"number_neurons": 10000,
"number_synapses": 100000,
"number_connections": 500,
"scale": CircuitScale.pair,
"build_category": CircuitBuildCategory.computational_model,
"atlas_id": brain_atlas_id,
"subject_id": subject_id,
"brain_region_id": brain_region_id,
"license_id": license_id,
"authorized_public": False,
"created_by_id": person_id,
"updated_by_id": person_id,
"authorized_project_id": PROJECT_ID,
},
]

circuits = [add_db(db, Circuit(**circuit_data)) for circuit_data in circuits_data]
return circuits


@pytest.fixture
def campaigns_with_different_circuits(
db, json_data, person_id, simulation_json_data, multiple_circuits
):
campaigns = []

for i, circuit in enumerate(multiple_circuits):
campaign = add_db(
db,
MODEL(
**(
json_data
| {
"name": f"campaign-circuit-{i}",
"description": f"Campaign for circuit {i}",
"created_by_id": person_id,
"updated_by_id": person_id,
"authorized_project_id": PROJECT_ID,
}
)
),
)
campaigns.append(campaign)

add_db(
db,
Simulation(
**simulation_json_data
| {
"name": f"simulation-circuit-{i}",
"simulation_campaign_id": campaign.id,
"entity_id": circuit.id,
"created_by_id": person_id,
"updated_by_id": person_id,
"authorized_project_id": PROJECT_ID,
}
),
)

return campaigns


def test_filter_by_circuit_id(client, campaigns_with_different_circuits, multiple_circuits): # noqa: ARG001
first_circuit_id = str(multiple_circuits[0].id)
data = assert_request(client.get, url=ROUTE, params={"circuit__id": first_circuit_id}).json()[
"data"
]

assert len(data) == 1
assert data[0]["name"] == "campaign-circuit-0"

second_circuit_id = str(multiple_circuits[1].id)
data = assert_request(client.get, url=ROUTE, params={"circuit__id": second_circuit_id}).json()[
"data"
]

assert len(data) == 1
assert data[0]["name"] == "campaign-circuit-1"


def test_filter_by_circuit_name(client, campaigns_with_different_circuits, multiple_circuits): # noqa: ARG001
data = assert_request(
client.get, url=ROUTE, params={"circuit__name": "micro-circuit-1"}
).json()["data"]

assert len(data) == 1
assert data[0]["name"] == "campaign-circuit-0"

data = assert_request(
client.get, url=ROUTE, params={"circuit__name__in": "micro-circuit-2"}
).json()["data"]

assert len(data) == 1
assert data[0]["name"] == "campaign-circuit-1"


def test_filter_by_circuit_scale(client, campaigns_with_different_circuits, multiple_circuits): # noqa: ARG001
data = assert_request(
client.get, url=ROUTE, params={"circuit__scale": CircuitScale.microcircuit}
).json()["data"]

assert len(data) == 2

data = assert_request(
client.get, url=ROUTE, params={"circuit__scale": CircuitScale.pair}
).json()["data"]

assert len(data) == 1


def test_filter_by_circuit_scale_empty(
client,
campaigns_with_different_circuits, # noqa: ARG001
multiple_circuits, # noqa: ARG001
):
data = assert_request(
client.get, url=ROUTE, params={"circuit__scale": CircuitScale.small}
).json()["data"]

assert len(data) == 0


def test_filter_by_circuit_scale_in(client, campaigns_with_different_circuits, multiple_circuits): # noqa: ARG001
data = assert_request(
client.get,
url=ROUTE,
params={"circuit__scale__in": [CircuitScale.microcircuit, CircuitScale.pair]},
).json()["data"]

assert len(data) == 3
campaign_names = {campaign["name"] for campaign in data}
assert campaign_names == {"campaign-circuit-0", "campaign-circuit-1", "campaign-circuit-2"}


def test_filter_by_circuit_build_category(
client,
campaigns_with_different_circuits, # noqa: ARG001
multiple_circuits, # noqa: ARG001
):
data = assert_request(
client.get,
url=ROUTE,
params={"circuit__build_category": CircuitBuildCategory.computational_model},
).json()["data"]

assert len(data) == 2
campaign_names = {campaign["name"] for campaign in data}
assert campaign_names == {"campaign-circuit-0", "campaign-circuit-2"}

data = assert_request(
client.get,
url=ROUTE,
params={"circuit__build_category": CircuitBuildCategory.em_reconstruction},
).json()["data"]

assert len(data) == 1
assert data[0]["name"] == "campaign-circuit-1"


def test_filter_by_circuit_build_category_in(
client,
campaigns_with_different_circuits, # noqa: ARG001
multiple_circuits, # noqa: ARG001
):
data = assert_request(
client.get,
url=ROUTE,
params={
"circuit__build_category__in": [
CircuitBuildCategory.computational_model,
CircuitBuildCategory.em_reconstruction,
],
},
).json()["data"]

assert len(data) == 3
campaign_names = {campaign["name"] for campaign in data}
assert campaign_names == {"campaign-circuit-0", "campaign-circuit-1", "campaign-circuit-2"}