Semantic search with pgvector #340
@@ -0,0 +1,117 @@

```python
"""add pgvector

Revision ID: 7aa80d34dbdd
Revises: 02b804d687ee
Create Date: 2025-08-25 12:34:47.832367

"""

from typing import Sequence, Union
import os
import random

import openai
from alembic import op
from pgvector.sqlalchemy import Vector
import sqlalchemy as sa
from sqlalchemy import text

# revision identifiers, used by Alembic.
revision: str = "7aa80d34dbdd"
down_revision: Union[str, None] = "02b804d687ee"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def generate_embeddings_for_existing_data():
    """Generate embeddings for existing brain_region, species, and strain data."""
    # Get connection
    connection = op.get_bind()

    # Collect all entity data
    all_entities = []

    # Collect brain region data
    brain_regions = connection.execute(text("SELECT id, name FROM brain_region")).fetchall()
    for brain_region in brain_regions:
        all_entities.append(("brain_region", brain_region.id, brain_region.name))

    # Collect species data
    species = connection.execute(text("SELECT id, name FROM species")).fetchall()
    for sp in species:
        all_entities.append(("species", sp.id, sp.name))

    # Collect strain data
    strains = connection.execute(text("SELECT id, name FROM strain")).fetchall()
    for strain in strains:
        all_entities.append(("strain", strain.id, strain.name))

    # Generate embeddings based on available API key
    api_key = os.getenv("OPENAI_API_KEY")

    if api_key:
        # Use OpenAI API for real embeddings
        client = openai.OpenAI(api_key=api_key)

        # Generate all embeddings in a single API call
        names = [entity[2] for entity in all_entities]
        response = client.embeddings.create(model="text-embedding-3-small", input=names)
```
> **Review comment:** For confirmation, is this request below the limits indicated in https://platform.openai.com/docs/api-reference/embeddings/create?
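On the rate-limit question: if the number of names ever outgrows what a single embeddings request accepts, the input list can be chunked client-side. A minimal sketch; the batch cap of 2048 is an assumption to verify against the API reference, and the commented usage refers to the migration code above:

```python
def chunked(items: list[str], batch_size: int = 2048) -> list[list[str]]:
    """Split `items` into consecutive batches of at most `batch_size`.

    2048 is an assumed per-request input cap for the embeddings endpoint;
    check the current OpenAI API reference before relying on it.
    """
    return [items[i : i + batch_size] for i in range(0, len(items), batch_size)]


# Hypothetical usage inside the migration: one API call per batch.
# embeddings = []
# for batch in chunked(names):
#     response = client.embeddings.create(model="text-embedding-3-small", input=batch)
#     embeddings.extend(e.embedding for e in response.data)
```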
```python
        # Extract embeddings from response
        embeddings = [embedding.embedding for embedding in response.data]
    else:
        # Use random vectors when OpenAI key is not provided
```
> **Review comment:** Do we really want random vectors? Since it's nullable, can't null be used?

> **Reply:** Fair point. We avoided making the column nullable because IMO the documentation does not really document the behavior that well. There is some mention in the README at https://github.com/pgvector/pgvector, but it is in the section on an indexing algorithm, HNSW (which we don't use). Essentially, we wanted to avoid having to append queries with `IS NOT NULL` filters.
```python
        embeddings = []
        for _ in all_entities:
            random_embedding = [random.random() for _ in range(1536)]
            embeddings.append(random_embedding)

    # Update database with generated embeddings (shared logic)
    for (table_name, entity_id, _), embedding in zip(all_entities, embeddings):
        # Convert embedding to string format for pgvector
        embedding_str = str(embedding)
        connection.execute(
            text(f"UPDATE {table_name} SET embedding = :embedding WHERE id = :id"),
            {"embedding": embedding_str, "id": entity_id},
        )


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    # Enable the pgvector extension
    op.execute("CREATE EXTENSION IF NOT EXISTS vector;")
```
> **Review comment:** The manual query to create the extension doesn't seem to work well with alembic automatic migration. To fix it, it's possible to add in triggers.py the following code:
>
> ```python
> entities += [
>     PGExtension(schema="public", signature="vector"),
> ]
> ```
>
> With this change, alembic would also be able to automatically generate the commands
>
> ```python
> public_vector = PGExtension(schema="public", signature="vector")
> op.create_entity(public_vector)
> ```
>
> and
>
> ```python
> public_vector = PGExtension(schema="public", signature="vector")
> op.drop_entity(public_vector)
> ```
>
> We could also rename triggers.py to something more generic, such as entities.py or pg_entities.py or sql_entities.py, although entity is also a class and a table so it may be a bit misleading, but I don't have a better name.
> **Review comment:** As a reminder: in staging and production there is PostgreSQL 17.4, and the (pg)vector extension seems available.
```python
    op.add_column(
        "brain_region",
        sa.Column("embedding", Vector(dim=1536), nullable=True),
    )

    op.add_column(
        "species",
        sa.Column("embedding", Vector(dim=1536), nullable=True),
    )

    op.add_column(
        "strain",
        sa.Column("embedding", Vector(dim=1536), nullable=True),
    )

    # Generate embeddings for existing data
    generate_embeddings_for_existing_data()

    # Now make columns non-nullable
    op.alter_column("brain_region", "embedding", nullable=False)
    op.alter_column("species", "embedding", nullable=False)
    op.alter_column("strain", "embedding", nullable=False)
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column("strain", "embedding")
    op.drop_column("species", "embedding")
    op.drop_column("brain_region", "embedding")

    # Disable the pgvector extension
    op.execute("DROP EXTENSION IF EXISTS vector;")
    # ### end Alembic commands ###
```
@@ -4,6 +4,7 @@

```python
from uuid import UUID

import sqlalchemy as sa
from pgvector.sqlalchemy import Vector
from sqlalchemy import (
    BigInteger,
    DateTime,
```
@@ -175,12 +176,14 @@ class BrainRegion(Identifiable):

```python
    hierarchy_id: Mapped[uuid.UUID] = mapped_column(
        ForeignKey("brain_region_hierarchy.id"), index=True
    )
    embedding: Mapped[Vector] = mapped_column(Vector(1536), nullable=False)
```
> **Review comment:** To avoid repeating the vector definition, we could define a mixin that can be reused if we want to add semantic search to more tables.
```python
class Species(Identifiable):
    __tablename__ = "species"
    name: Mapped[str] = mapped_column(unique=True, index=True)
    taxonomy_id: Mapped[str] = mapped_column(unique=True, index=True)
    embedding: Mapped[Vector] = mapped_column(Vector(1536), nullable=False)


class Strain(Identifiable):
```
@@ -189,6 +192,7 @@ class Strain(Identifiable):

```python
    taxonomy_id: Mapped[str] = mapped_column(unique=True, index=True)
    species_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("species.id"), index=True)
    species = relationship("Species", uselist=False)
    embedding: Mapped[Vector] = mapped_column(Vector(1536), nullable=False)

    __table_args__ = (
        # needed for the composite foreign key in SpeciesMixin
```
@@ -238,6 +238,7 @@ def router_read_many[T: BaseModel, I: Identifiable](  # noqa: PLR0913

```python
    name_to_facet_query_params: dict[str, FacetQueryParams] | None,
    filter_model: CustomFilter[I],
    filter_joins: dict[str, ApplyOperations] | None = None,
    embedding: list[float] | None = None,
) -> ListResponse[T]:
    """Read multiple models from the database.
```
@@ -258,6 +259,7 @@ def router_read_many[T: BaseModel, I: Identifiable](  # noqa: PLR0913

```python
        filter_joins: mapping of filter names to join functions. The keys should match both:
            - the nested filters attributes, to choose which joins should be applied for filtering.
            - the keys in `name_to_facet_query_params`, for retrieving the facets.
        embedding: optional list of floats representing an embedding vector for semantic search.

    Returns:
        the list of model data, pagination, and facets as a Pydantic model.
```

@@ -291,6 +293,15 @@ def router_read_many[T: BaseModel, I: Identifiable](  # noqa: PLR0913

```python
        .limit(pagination_request.page_size)
    )

    # Add semantic similarity ordering if embedding is provided and model has embedding field
    if embedding is not None and hasattr(db_model_class, "embedding"):
        # Remove existing ordering clauses and replace with semantic similarity ordering
```
> **Review comment:** As quickly discussed before, this is an important thing to agree on, because it depends on the use case for the semantic search. It means that: …

> **Reply:** We initially wanted to raise an exception, but then we realized that somehow by default the Swagger UI attaches the …
>
> Point taken that one can provide multiple entries in the …
>
> I can make the same argument about …
```python
        if getattr(data_query, "_order_by_clauses", None):
            # Clear existing ordering by setting _order_by_clauses to empty tuple
            data_query._order_by_clauses = ()  # noqa: SLF001

        data_query = data_query.order_by(db_model_class.embedding.l2_distance(embedding))  # type: ignore[attr-defined]

    if apply_data_query_operations:
        data_query = apply_data_query_operations(data_query)
```
@@ -1,6 +1,7 @@

```python
from uuid import UUID

from pydantic import BaseModel, ConfigDict, Field
from pydantic.json_schema import SkipJsonSchema

from app.schemas.agent import CreatedByUpdatedByMixin
from app.schemas.base import CreationMixin, IdentifiableMixin
```

@@ -10,26 +11,28 @@ class SpeciesCreate(BaseModel):

```python
    model_config = ConfigDict(from_attributes=True)
    name: str
    taxonomy_id: str
    embedding: SkipJsonSchema[list[float] | None] = None
```
> **Review comment:** If I'm not wrong, the …
```python
class SpeciesRead(SpeciesCreate, CreationMixin, CreatedByUpdatedByMixin, IdentifiableMixin):
    embedding: SkipJsonSchema[list[float] | None] = Field(default=None, exclude=True)


class NestedSpeciesRead(SpeciesCreate, IdentifiableMixin):
    embedding: SkipJsonSchema[list[float] | None] = Field(default=None, exclude=True)


class StrainCreate(BaseModel):
    model_config = ConfigDict(from_attributes=True)
    name: str
    taxonomy_id: str
    species_id: UUID
    embedding: SkipJsonSchema[list[float] | None] = None


class StrainRead(StrainCreate, CreationMixin, CreatedByUpdatedByMixin, IdentifiableMixin):
    embedding: SkipJsonSchema[list[float] | None] = Field(default=None, exclude=True)


class NestedStrainRead(StrainCreate, IdentifiableMixin):
    embedding: SkipJsonSchema[list[float] | None] = Field(default=None, exclude=True)
```
@@ -10,14 +10,21 @@

```python
from app.filters.brain_region import BrainRegionFilterDep
from app.schemas.base import BrainRegionRead
from app.schemas.types import ListResponse
from app.utils.embedding import generate_embedding


def read_many(
    *,
    db: SessionDep,
    pagination_request: PaginationQuery,
    brain_region_filter: BrainRegionFilterDep,
    semantic_search: str | None = None,
) -> ListResponse[BrainRegionRead]:
    embedding = None

    if semantic_search is not None:
        embedding = generate_embedding(semantic_search)
```
> **Review comment:** I agree that the call to …
```python
    return app.queries.common.router_read_many(
        db=db,
        db_model_class=BrainRegion,
```

@@ -32,6 +39,7 @@ def read_many(

```python
        response_schema_class=BrainRegionRead,
        name_to_facet_query_params=None,
        filter_model=brain_region_filter,
        embedding=embedding,
    )
```
@@ -12,6 +12,7 @@

```python
from app.queries.factory import query_params_factory
from app.schemas.species import StrainCreate, StrainRead
from app.schemas.types import ListResponse
from app.utils.embedding import generate_embedding


def _load(query: sa.Select):
```

@@ -23,10 +24,17 @@ def _load(query: sa.Select):

```python
def read_many(
    *,
    db: SessionDep,
    pagination_request: PaginationQuery,
    strain_filter: StrainFilterDep,
    semantic_search: str | None = None,
) -> ListResponse[StrainRead]:
    embedding = None

    if semantic_search is not None:
        embedding = generate_embedding(semantic_search)

    facet_keys = filter_keys = [
        "created_by",
        "updated_by",
```

@@ -52,6 +60,7 @@ def read_many(

```python
        name_to_facet_query_params=name_to_facet_query_params,
        filter_model=strain_filter,
        filter_joins=filter_joins,
        embedding=embedding,
    )
```

@@ -69,6 +78,9 @@ def read_one(id_: uuid.UUID, db: SessionDep) -> StrainRead:

```python
def create_one(
    json_model: StrainCreate, db: SessionDep, user_context: AdminContextDep
) -> StrainRead:
    # Generate embedding using OpenAI API
```
> **Review comment:** Minor: this comment has been removed in `read_many` but not in the `create_one` endpoints.
```python
    json_model.embedding = generate_embedding(json_model.name)

    return app.queries.common.router_create_one(
        db=db,
        db_model_class=Strain,
```
@@ -0,0 +1,32 @@

```python
"""Utility functions for generating embeddings using OpenAI API."""

import openai

from app.config import settings


def generate_embedding(text: str, model: str = "text-embedding-3-small") -> list[float]:
    """Generate an embedding for the given text using OpenAI API.

    Args:
        text: The text to generate an embedding for.
        model: The OpenAI embedding model to use (default: text-embedding-3-small).

    Returns:
        A list of floats representing the embedding vector.

    Raises:
        ValueError: If the OpenAI API key is not configured.
    """
    if settings.OPENAI_API_KEY is None:
```
> **Review comment:** I don't think we need to check the key every time; having it in the config will be enough.

> **Reply:** It was mostly for the type checker, so that we can call … Note that we want entitycore to be runnable (locally) without that env variable (of course, the semantic search features would not work in that case).
```python
        message = "OpenAI API key is not configured."
        raise ValueError(message)

    openai_api_key = settings.OPENAI_API_KEY.get_secret_value()

    # Generate embedding using OpenAI API
    client = openai.OpenAI(api_key=openai_api_key)
```
> **Review comment:** Are there alternatives instead of calling the OpenAI API? Is there any rate limiting?

> **Reply:** Yes, there are many alternatives (other API providers and even self-hosting custom embedding models). However, since … The clear downside is that entitycore will make calls to a 3rd-party API from now on, which requires an API key.

> **Review comment:** I think this is an important point: do we have any idea if they have an SLA?

> **Reply:** Yes, there is rate limiting; check https://platform.openai.com/docs/guides/rate-limits/usage-tiers?context=tier-one. The reason why we did not implement any retry logic or exception handling: … However, if you want, we can write some extra logic.
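If such extra logic is ever wanted, a minimal backoff wrapper could look like this (a sketch, not part of the PR; `max_attempts` and the delays are arbitrary, and in the real code the caught exception would presumably be `openai.RateLimitError` rather than a bare `Exception`):

```python
import time
from typing import Callable, TypeVar

T = TypeVar("T")


def with_retries(
    func: Callable[[], T],
    max_attempts: int = 3,
    base_delay: float = 1.0,
    retry_on: type[Exception] = Exception,
) -> T:
    """Call `func`, retrying with exponential backoff on `retry_on` errors."""
    for attempt in range(max_attempts):
        try:
            return func()
        except retry_on:
            if attempt == max_attempts - 1:
                raise  # out of attempts: propagate the last error
            time.sleep(base_delay * 2**attempt)
    raise AssertionError("unreachable for max_attempts >= 1")


# Hypothetical usage around the embeddings call:
# response = with_retries(
#     lambda: client.embeddings.create(model=model, input=text),
#     retry_on=openai.RateLimitError,
# )
```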
```python
    response = client.embeddings.create(model=model, input=text)

    # Return the generated embedding
    return response.data[0].embedding
```
> **Review comment:** This function would propagate any error (such as a missing API key, or errors from the external API) and cause a generic error 500, which would require someone to check the logs on the server side.
> **Review comment:** The migration should fail if executed in staging or production and the API key isn't defined by mistake. There is an env variable `ENVIRONMENT` that represents whether the image is built for `prod` or `dev`, so it's not exactly the same as the running environment, but it should be enough to check that.
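A possible guard along those lines at the top of the migration. This is a sketch under assumptions: the `ENVIRONMENT` values `prod`/`dev` are taken from the comment above, and the exact check should match how the variable is actually set in the images:

```python
import os


def check_api_key_for_environment() -> None:
    """Fail fast if a prod-built image runs this migration without an API key."""
    environment = os.getenv("ENVIRONMENT", "dev")
    if environment == "prod" and not os.getenv("OPENAI_API_KEY"):
        msg = "OPENAI_API_KEY must be set when running migrations in prod"
        raise RuntimeError(msg)
```

With this in place, a dev or local run still falls back to random vectors as before, while a misconfigured prod migration aborts instead of silently writing random embeddings.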