117 changes: 117 additions & 0 deletions alembic/versions/20250825_123447_7aa80d34dbdd_add_pgvector.py
@@ -0,0 +1,117 @@
"""add pgvector

Revision ID: 7aa80d34dbdd
Revises: 02b804d687ee
Create Date: 2025-08-25 12:34:47.832367

"""

from typing import Sequence, Union
import os
import random

import openai
from alembic import op
from pgvector.sqlalchemy import Vector
import sqlalchemy as sa
from sqlalchemy import text

# revision identifiers, used by Alembic.
revision: str = "7aa80d34dbdd"
down_revision: Union[str, None] = "02b804d687ee"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def generate_embeddings_for_existing_data():
"""Generate embeddings for existing brain_region, species, and strain data."""
# Get connection
connection = op.get_bind()

# Collect all entity data
all_entities = []

# Collect brain region data
brain_regions = connection.execute(text("SELECT id, name FROM brain_region")).fetchall()
for brain_region in brain_regions:
all_entities.append(("brain_region", brain_region.id, brain_region.name))

# Collect species data
species = connection.execute(text("SELECT id, name FROM species")).fetchall()
for sp in species:
all_entities.append(("species", sp.id, sp.name))

# Collect strain data
strains = connection.execute(text("SELECT id, name FROM strain")).fetchall()
for strain in strains:
all_entities.append(("strain", strain.id, strain.name))

# Generate embeddings based on available API key
api_key = os.getenv("OPENAI_API_KEY")

if api_key:
Comment (Collaborator):
The migration should fail if it is executed in staging or production and the API key is missing by mistake.
There is an env variable ENVIRONMENT that indicates whether the image was built for prod or dev, so it's not exactly the same as the running environment, but it should be enough for this check.

# Use OpenAI API for real embeddings
client = openai.OpenAI(api_key=api_key)

# Generate all embeddings in a single API call
names = [entity[2] for entity in all_entities]
response = client.embeddings.create(model="text-embedding-3-small", input=names)
Comment (Collaborator):
For confirmation, is this request below the limits indicated in https://platform.openai.com/docs/api-reference/embeddings/create?
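For reference, the embeddings endpoint caps the number of inputs per request (2,048 items at the time of writing, plus per-model token limits — worth re-checking against the linked docs). If the tables ever grow past that cap, the single call could be split into batches; a hedged sketch:

```python
def batched(items: list[str], max_batch: int = 2048) -> list[list[str]]:
    """Split inputs so each embeddings request stays under the item cap."""
    return [items[i:i + max_batch] for i in range(0, len(items), max_batch)]


# Usage inside the migration (sketch):
# embeddings = []
# for batch in batched(names):
#     response = client.embeddings.create(model="text-embedding-3-small", input=batch)
#     embeddings.extend(item.embedding for item in response.data)
```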


# Extract embeddings from response
embeddings = [embedding.embedding for embedding in response.data]
else:
# Use random vectors when OpenAI key is not provided
Comment (Collaborator):
Do we really want random vectors? Since it's nullable, can't null be used?
How does the search work w/ null values?

Reply (Author @jankrepl, Sep 2, 2025):
Fair point. We avoided making the column nullable because, IMO, the documentation does not describe the NULL behavior that well. There is some mention in the README at https://github.com/pgvector/pgvector, but it is in the section on the HNSW indexing algorithm (which we don't use).

Essentially, we wanted to avoid having to append queries with WHERE embedding IS NOT NULL. Anyway, if you find some good resource that documents the behavior then I would be happy to change it.

embeddings = []
for _ in all_entities:
random_embedding = [random.random() for _ in range(1536)]
embeddings.append(random_embedding)

# Update database with generated embeddings (shared logic)
for (table_name, entity_id, _), embedding in zip(all_entities, embeddings):
# Convert embedding to string format for pgvector
embedding_str = str(embedding)
connection.execute(
text(f"UPDATE {table_name} SET embedding = :embedding WHERE id = :id"),
{"embedding": embedding_str, "id": entity_id},
)


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
# Enable the pgvector extension
op.execute("CREATE EXTENSION IF NOT EXISTS vector;")
Comment (Collaborator):
The manual query to create the extension doesn't seem to work well with alembic automatic migration: running make migration on the up to date branch would create a migration file that reverts the change, trying to remove the extension.

To fix it, it's possible to add in triggers.py the following code:

entities += [
    PGExtension(schema="public", signature="vector"),
]

With this change, alembic would also be able to automatically generate the commands

    public_vector = PGExtension(schema="public", signature="vector")
    op.create_entity(public_vector)

and

    public_vector = PGExtension(schema="public", signature="vector")
    op.drop_entity(public_vector)

We could also rename triggers.py to something more generic, such as entities.py, pg_entities.py, or sql_entities.py; entity is also a class and a table, so those names may be a bit misleading, but I don't have a better one.


Comment (Collaborator):
As a reminder: in staging and production there is postgresql 17.4 and the (pg)vector extension seems available.
However, before releasing and deploying we should ensure that no other actions are needed to activate the extension.

op.add_column(
"brain_region",
sa.Column("embedding", Vector(dim=1536), nullable=True),
)

op.add_column(
"species",
sa.Column("embedding", Vector(dim=1536), nullable=True),
)

op.add_column(
"strain",
sa.Column("embedding", Vector(dim=1536), nullable=True),
)

# Generate embeddings for existing data
generate_embeddings_for_existing_data()

# Now make columns non-nullable
op.alter_column("brain_region", "embedding", nullable=False)
op.alter_column("species", "embedding", nullable=False)
op.alter_column("strain", "embedding", nullable=False)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("strain", "embedding")
op.drop_column("species", "embedding")
op.drop_column("brain_region", "embedding")

# Disable the pgvector extension
op.execute("DROP EXTENSION IF EXISTS vector;")
# ### end Alembic commands ###
4 changes: 3 additions & 1 deletion app/config.py
@@ -1,6 +1,6 @@
from typing import Literal

from pydantic import PostgresDsn, field_validator
from pydantic import PostgresDsn, SecretStr, field_validator
from pydantic_core.core_schema import ValidationInfo
from pydantic_settings import BaseSettings, SettingsConfigDict

@@ -62,6 +62,8 @@ class Settings(BaseSettings):
DB_POOL_PRE_PING: bool = False
DB_MAX_OVERFLOW: int = 0

OPENAI_API_KEY: SecretStr | None = None

@field_validator("DB_URI", mode="before")
@classmethod
def build_db_uri(cls, v: str, info: ValidationInfo) -> str:
4 changes: 4 additions & 0 deletions app/db/model.py
@@ -4,6 +4,7 @@
from uuid import UUID

import sqlalchemy as sa
from pgvector.sqlalchemy import Vector
from sqlalchemy import (
BigInteger,
DateTime,
@@ -175,12 +176,14 @@ class BrainRegion(Identifiable):
hierarchy_id: Mapped[uuid.UUID] = mapped_column(
ForeignKey("brain_region_hierarchy.id"), index=True
)
embedding: Mapped[Vector] = mapped_column(Vector(1536), nullable=False)
Comment (Collaborator):
To avoid repeating the vector definition, we could define a mixin that can be reused also if we want to add semantic search to more tables.
It would be helpful also to add a comment to explain why the dim of the vector is 1536.



class Species(Identifiable):
__tablename__ = "species"
name: Mapped[str] = mapped_column(unique=True, index=True)
taxonomy_id: Mapped[str] = mapped_column(unique=True, index=True)
embedding: Mapped[Vector] = mapped_column(Vector(1536), nullable=False)


class Strain(Identifiable):
@@ -189,6 +192,7 @@ class Strain(Identifiable):
taxonomy_id: Mapped[str] = mapped_column(unique=True, index=True)
species_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("species.id"), index=True)
species = relationship("Species", uselist=False)
embedding: Mapped[Vector] = mapped_column(Vector(1536), nullable=False)

__table_args__ = (
# needed for the composite foreign key in SpeciesMixin
11 changes: 11 additions & 0 deletions app/queries/common.py
@@ -238,6 +238,7 @@ def router_read_many[T: BaseModel, I: Identifiable](  # noqa: PLR0913
name_to_facet_query_params: dict[str, FacetQueryParams] | None,
filter_model: CustomFilter[I],
filter_joins: dict[str, ApplyOperations] | None = None,
embedding: list[float] | None = None,
) -> ListResponse[T]:
"""Read multiple models from the database.

@@ -258,6 +259,7 @@ def router_read_many[T: BaseModel, I: Identifiable](  # noqa: PLR0913
filter_joins: mapping of filter names to join functions. The keys should match both:
- the nested filters attributes, to choose which joins should be applied for filtering.
- the keys in `name_to_facet_query_params`, for retrieving the facets.
embedding: optional list of floats representing an embedding vector for semantic search.

Returns:
the list of model data, pagination, and facets as a Pydantic model.
@@ -291,6 +293,15 @@ def router_read_many[T: BaseModel, I: Identifiable](  # noqa: PLR0913
.limit(pagination_request.page_size)
)

# Add semantic similarity ordering if embedding is provided and model has embedding field
if embedding is not None and hasattr(db_model_class, "embedding"):
# Remove existing ordering clauses and replace with semantic similarity ordering
Comment (Collaborator):
As quickly discussed before, this is an important thing to agree on, because it depends on the use case for the semantic search. It means that:

  • the order of records is different (species, strain and brain_region are ordered by name by default. we don't allow ordering by other attributes, but it's technically possible)
  • in general, the user cannot order the results by other attributes. In any case, with this approach it could be better to explicitly raise an error if order_by is explicitly passed by the user together with the semantic search, instead of silently ignoring it.
  • all the records are returned, even the ones with very low similarity (that will be at the bottom of the list or in the last pages)

Reply (Author @jankrepl, Aug 27, 2025):
in general, the user cannot order the results by other attributes. In any case with this approach it could be better to explicitly raise an error if the order_by is explicitly passed by the user together with the semantic search, instead of silently ignore it.

We initially wanted to raise an exception, but then we realized that the Swagger UI attaches the order_by=name query param by default. That is why we decided to just ignore it when semantic_search is provided too. Alternatively, we could make sure order_by=name is not sent by default. Feel free to decide.

the order of records is different (species, strain and brain_region are ordered by name by default. we don't allow ordering by other attributes, but it's technically possible)

Point taken that one can provide multiple entries in the order_by. Now, there are two options

  1. semantic_search comes first. Then I would argue there is no need for other tiebreaker entries (e.g. name, created_date,...) since no two entries in the DB will have the same distance. The current PR basically implements this case since it disregards all the other order_by clauses.
  2. semantic_search does not come first. If for some reason someone wants to sort by creation day (not date) first and only then by semantic search, then I agree that the current PR does not support it. However, I would argue that it is not that useful. It can be implemented, but I would leave that for a future PR since the interaction between the FastAPI filter library and this custom semantic_search is not obvious.

all the records are returned, even the ones with very low similarity (that will be at the bottom of the list or in the last pages)

I can make the same argument about order_by=name or any other column you order by. As discussed in person, we will use this endpoint by adding the page_size=5 and not caring about the other pages.

if getattr(data_query, "_order_by_clauses", None):
# Clear existing ordering by setting _order_by_clauses to empty tuple
data_query._order_by_clauses = () # noqa: SLF001

data_query = data_query.order_by(db_model_class.embedding.l2_distance(embedding)) # type: ignore[attr-defined]

if apply_data_query_operations:
data_query = apply_data_query_operations(data_query)

13 changes: 8 additions & 5 deletions app/schemas/species.py
@@ -1,6 +1,7 @@
from uuid import UUID

from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field
from pydantic.json_schema import SkipJsonSchema

from app.schemas.agent import CreatedByUpdatedByMixin
from app.schemas.base import CreationMixin, IdentifiableMixin
@@ -10,26 +11,28 @@ class SpeciesCreate(BaseModel):
model_config = ConfigDict(from_attributes=True)
name: str
taxonomy_id: str
embedding: SkipJsonSchema[list[float] | None] = None
Comment (Collaborator):
If I'm not wrong the embedding attribute in the create schema is needed only because it's set in the create_one method.
Since embedding is completely internal, it seems cleaner to not even define it in the pydantic schema, and pass an additional parameter to router_create_one, similarly to what has been already done in router_read_many.
In either case, the embedding that is defined in the read schemas doesn't seem to be used, so it can be removed.



class SpeciesRead(SpeciesCreate, CreationMixin, CreatedByUpdatedByMixin, IdentifiableMixin):
pass
embedding: SkipJsonSchema[list[float] | None] = Field(default=None, exclude=True)


class NestedSpeciesRead(SpeciesCreate, IdentifiableMixin):
pass
embedding: SkipJsonSchema[list[float] | None] = Field(default=None, exclude=True)


class StrainCreate(BaseModel):
model_config = ConfigDict(from_attributes=True)
name: str
taxonomy_id: str
species_id: UUID
embedding: SkipJsonSchema[list[float] | None] = None


class StrainRead(StrainCreate, CreationMixin, CreatedByUpdatedByMixin, IdentifiableMixin):
pass
embedding: SkipJsonSchema[list[float] | None] = Field(default=None, exclude=True)


class NestedStrainRead(StrainCreate, IdentifiableMixin):
pass
embedding: SkipJsonSchema[list[float] | None] = Field(default=None, exclude=True)
8 changes: 8 additions & 0 deletions app/service/brain_region.py
@@ -10,14 +10,21 @@
from app.filters.brain_region import BrainRegionFilterDep
from app.schemas.base import BrainRegionRead
from app.schemas.types import ListResponse
from app.utils.embedding import generate_embedding


def read_many(
*,
db: SessionDep,
pagination_request: PaginationQuery,
brain_region_filter: BrainRegionFilterDep,
semantic_search: str | None = None,
) -> ListResponse[BrainRegionRead]:
embedding = None

if semantic_search is not None:
embedding = generate_embedding(semantic_search)
Comment (Collaborator):
I agree that the call to generate_embedding can be done in the service layer because it doesn't really belong to the query/db layer, but on the other side it's an internal detail, so we could pass directly semantic_search to router_read_many, and call generate_embedding there? The same can be done also for router_create_one.
This would also avoid repeating this call in each read_many endpoints where it's needed, especially, if we think to extend it to other endpoints later.


return app.queries.common.router_read_many(
db=db,
db_model_class=BrainRegion,
@@ -32,6 +39,7 @@ def read_many(
response_schema_class=BrainRegionRead,
name_to_facet_query_params=None,
filter_model=brain_region_filter,
embedding=embedding,
)


11 changes: 11 additions & 0 deletions app/service/species.py
@@ -12,6 +12,7 @@
from app.queries.factory import query_params_factory
from app.schemas.species import SpeciesCreate, SpeciesRead
from app.schemas.types import ListResponse
from app.utils.embedding import generate_embedding


def _load(query: sa.Select):
@@ -43,6 +44,9 @@ def create_one(
species: SpeciesCreate,
user_context: AdminContextDep,
) -> SpeciesRead:
# Generate embedding using OpenAI API
species.embedding = generate_embedding(species.name)

return app.queries.common.router_create_one(
db=db,
db_model_class=Species,
@@ -58,7 +62,13 @@ def read_many(
db: SessionDep,
pagination_request: PaginationQuery,
species_filter: SpeciesFilterDep,
semantic_search: str | None = None,
) -> ListResponse[SpeciesRead]:
embedding = None

if semantic_search is not None:
embedding = generate_embedding(semantic_search)

facet_keys = filter_keys = [
"created_by",
"updated_by",
@@ -84,4 +94,5 @@
name_to_facet_query_params=name_to_facet_query_params,
filter_model=species_filter,
filter_joins=filter_joins,
embedding=embedding,
)
12 changes: 12 additions & 0 deletions app/service/strain.py
@@ -12,6 +12,7 @@
from app.queries.factory import query_params_factory
from app.schemas.species import StrainCreate, StrainRead
from app.schemas.types import ListResponse
from app.utils.embedding import generate_embedding


def _load(query: sa.Select):
@@ -23,10 +24,17 @@ def _load(query: sa.Select):


def read_many(
*,
db: SessionDep,
pagination_request: PaginationQuery,
strain_filter: StrainFilterDep,
semantic_search: str | None = None,
) -> ListResponse[StrainRead]:
embedding = None

if semantic_search is not None:
embedding = generate_embedding(semantic_search)

facet_keys = filter_keys = [
"created_by",
"updated_by",
@@ -52,6 +60,7 @@
name_to_facet_query_params=name_to_facet_query_params,
filter_model=strain_filter,
filter_joins=filter_joins,
embedding=embedding,
)


@@ -69,6 +78,9 @@ def read_one(id_: uuid.UUID, db: SessionDep) -> StrainRead:
def create_one(
json_model: StrainCreate, db: SessionDep, user_context: AdminContextDep
) -> StrainRead:
# Generate embedding using OpenAI API
Comment (Collaborator):
Minor: this comment has been removed in read_many but not in the create_one functions.

json_model.embedding = generate_embedding(json_model.name)

return app.queries.common.router_create_one(
db=db,
db_model_class=Strain,
32 changes: 32 additions & 0 deletions app/utils/embedding.py
@@ -0,0 +1,32 @@
"""Utility functions for generating embeddings using OpenAI API."""

import openai

from app.config import settings


def generate_embedding(text: str, model: str = "text-embedding-3-small") -> list[float]:
"""Generate an embedding for the given text using OpenAI API.

Args:
text: The text to generate an embedding for
model: The OpenAI embedding model to use (default: text-embedding-3-small)

Returns:
A list of floats representing the embedding vector

Raises:
ValueError: If OpenAI API key is not configured
"""
if settings.OPENAI_API_KEY is None:
Comment (Collaborator):
I don't think we need to check the key every time; having it in the config will be enough.

Reply (Author):
It was mostly for the type checker, so that we can call get_secret_value().

Note that we still want entitycore to be runnable (locally) without that env variable (of course, the semantic search features would not work in that case).

message = "OpenAI API key is not configured."
raise ValueError(message)

openai_api_key = settings.OPENAI_API_KEY.get_secret_value()

# Generate embedding using OpenAI API
client = openai.OpenAI(api_key=openai_api_key)
Comment (Collaborator):
Are there alternatives instead of calling the OpenAI API? Is there a rate limiting?

Reply (Author):
Yes. There are many alternatives (other API providers, and even self-hosting custom embedding models). However, since neuroagent uses OpenAI in production, we thought we would keep things simple for now.

However, the clear downside is that entitycore will make calls to a 3rd-party API from now on, which requires an API key.

Comment (Collaborator):
Is there a rate limiting?

I think this is an important point - do we have any idea if they have an SLA?
I see https://status.openai.com/ but I didn't find latency numbers.

Reply (Author):
Yes. There is rate limiting. Check https://platform.openai.com/docs/guides/rate-limits/usage-tiers?context=tier-one

The reasons why we did not implement any retry logic or exception handling:

  • GET - semantic_search is optional so the default behavior is not to use semantic search
  • in POST - we assume that the three entities this PR is concerned with (species, strain and brain region) won't change that much (or at all)

However, if you want we can write some extra logic.
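If retry logic is ever added, a simple exponential-backoff wrapper would likely be enough for these low-volume calls — a sketch, deliberately not tied to specific OpenAI exception types (which ones to treat as retriable is left open):

```python
import random
import time


def call_with_retries(fn, attempts=3, base_delay=1.0, retriable=(Exception,)):
    """Retry fn with exponential backoff and jitter; re-raise on final failure."""
    for attempt in range(attempts):
        try:
            return fn()
        except retriable:
            if attempt == attempts - 1:
                raise
            # 1s, 2s, 4s, ... plus a little jitter to avoid thundering herds
            time.sleep(base_delay * 2**attempt + random.uniform(0, 0.1))


# Usage (sketch):
# embedding = call_with_retries(
#     lambda: client.embeddings.create(model=model, input=text).data[0].embedding
# )
```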

response = client.embeddings.create(model=model, input=text)

# Return the generated embedding
return response.data[0].embedding
Comment (Collaborator @GianlucaFicarelli, Sep 9, 2025):
This function would propagate any error (such as a missing API key, or errors from the external API) and cause a generic 500 error that would require someone to check the logs on the server side.
Instead of propagating the errors, we could improve the error handling, but I'm also ok if we do that in a separate PR.
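One possible shape for that error handling — wrapping backend failures in a dedicated exception that an API-level handler can map to something more specific than a 500 (the names here are invented for illustration):

```python
class EmbeddingUnavailableError(Exception):
    """The embedding backend (or its configuration) cannot serve the request."""


def safe_generate_embedding(generate, text: str) -> list[float]:
    """Call the real generate_embedding, converting failures to a typed error.

    A FastAPI exception handler registered for EmbeddingUnavailableError
    could then translate it into e.g. a 503 with a clear message, instead
    of a generic 500 that forces someone to read the server logs.
    """
    try:
        return generate(text)
    except Exception as exc:
        raise EmbeddingUnavailableError(f"embedding generation failed: {exc}") from exc
```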

2 changes: 1 addition & 1 deletion docker-compose.run.yml
@@ -29,7 +29,7 @@ services:

db:
profiles: [run]
image: postgres:17-alpine
image: pgvector/pgvector:0.8.0-pg17-trixie
environment:
- POSTGRES_USER=entitycore
- POSTGRES_PASSWORD=entitycore