6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
@@ -28,3 +28,9 @@ repos:
- id: isort
name: "Sort Imports"
args: ["--profile", "black"]

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.5
hooks:
# Run the linter.
- id: ruff
9 changes: 9 additions & 0 deletions engine/base_client/__init__.py
@@ -6,3 +6,12 @@

class IncompatibilityError(Exception):
pass


__all__ = [
"BaseClient",
"BaseConfigurator",
"BaseSearcher",
"BaseUploader",
"IncompatibilityError",
]
1 change: 0 additions & 1 deletion engine/base_client/client.py
@@ -1,7 +1,6 @@
import json
import os
from datetime import datetime
from pathlib import Path
from typing import List

from benchmark import ROOT_DIR
6 changes: 2 additions & 4 deletions engine/base_client/upload.py
@@ -1,6 +1,6 @@
import time
from multiprocessing import get_context
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List

import tqdm

@@ -90,9 +90,7 @@ def post_upload(cls, distance):
return {}

@classmethod
def upload_batch(
cls, ids: List[int], vectors: List[list], metadata: List[Optional[dict]]
):
def upload_batch(cls, batch: List[Record]):
raise NotImplementedError()

@classmethod
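The new upload_batch signature (and the search_one changes later in this diff) pivot on two dataclasses imported from dataset_reader.base_reader. Their definitions are not part of this diff; what follows is a minimal sketch inferred from the attributes the changed code accesses (record.id, record.vector, record.metadata, query.vector, query.meta_conditions) — the real classes may carry extra fields.

# Hypothetical reconstruction of the dataclasses in dataset_reader/base_reader.py,
# inferred only from the attributes this diff accesses.
from dataclasses import dataclass
from typing import Optional


@dataclass
class Record:
    id: int
    vector: list
    metadata: Optional[dict] = None


@dataclass
class Query:
    vector: list
    meta_conditions: Optional[dict] = None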
4 changes: 2 additions & 2 deletions engine/base_client/utils.py
@@ -1,9 +1,9 @@
from typing import Any, Iterable
from typing import Iterable, List

from dataset_reader.base_reader import Record


def iter_batches(records: Iterable[Record], n: int) -> Iterable[Record]:
def iter_batches(records: Iterable[Record], n: int) -> Iterable[List[Record]]:
batch = []

for record in records:
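The rest of the iter_batches body is collapsed in the diff. A minimal sketch of the presumable full generator, assuming the conventional accumulate-and-flush batching pattern:

# Sketch of the complete generator (the diff truncates the body).
def iter_batches(records: Iterable[Record], n: int) -> Iterable[List[Record]]:
    batch = []
    for record in records:
        batch.append(record)
        if len(batch) >= n:
            yield batch
            batch = []
    if batch:
        # Flush the final, possibly smaller batch.
        yield batch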
6 changes: 6 additions & 0 deletions engine/clients/elasticsearch/__init__.py
@@ -1,3 +1,9 @@
from engine.clients.elasticsearch.configure import ElasticConfigurator
from engine.clients.elasticsearch.search import ElasticSearcher
from engine.clients.elasticsearch.upload import ElasticUploader

__all__ = [
"ElasticConfigurator",
"ElasticSearcher",
"ElasticUploader",
]
7 changes: 4 additions & 3 deletions engine/clients/elasticsearch/search.py
@@ -4,6 +4,7 @@

from elasticsearch import Elasticsearch

from dataset_reader.base_reader import Query
from engine.base_client.search import BaseSearcher
from engine.clients.elasticsearch.config import (
ELASTIC_INDEX,
@@ -46,15 +47,15 @@ def init_client(cls, host, distance, connection_params: dict, search_params: dict):
cls.search_params = search_params

@classmethod
def search_one(cls, vector, meta_conditions, top) -> List[Tuple[int, float]]:
def search_one(cls, query: Query, top: int) -> List[Tuple[int, float]]:
knn = {
"field": "vector",
"query_vector": vector,
"query_vector": query.vector,
"k": top,
**{"num_candidates": 100, **cls.search_params},
}

meta_conditions = cls.parser.parse(meta_conditions)
meta_conditions = cls.parser.parse(query.meta_conditions)
if meta_conditions:
knn["filter"] = meta_conditions

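The nested splat in the knn clause makes num_candidates=100 a default that any caller-supplied value overrides, since later keys win in a dict literal. An illustration with assumed values:

search_params = {"num_candidates": 500}          # illustrative caller override
knn = {"k": 10, **{"num_candidates": 100, **search_params}}
assert knn == {"k": 10, "num_candidates": 500}   # the override beats the 100 default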
18 changes: 6 additions & 12 deletions engine/clients/elasticsearch/upload.py
@@ -1,9 +1,10 @@
import multiprocessing as mp
import uuid
from typing import List, Optional
from typing import List

from elasticsearch import Elasticsearch

from dataset_reader.base_reader import Record
from engine.base_client.upload import BaseUploader
from engine.clients.elasticsearch.config import (
ELASTIC_INDEX,
@@ -44,19 +45,12 @@ def init_client(cls, host, distance, connection_params, upload_params):
cls.upload_params = upload_params

@classmethod
def upload_batch(
cls, ids: List[int], vectors: List[list], metadata: Optional[List[dict]]
):
if metadata is None:
metadata = [{}] * len(vectors)
def upload_batch(cls, batch: List[Record]):
operations = []
for idx, vector, payload in zip(ids, vectors, metadata):
vector_id = uuid.UUID(int=idx).hex
for record in batch:
vector_id = uuid.UUID(int=record.id).hex
operations.append({"index": {"_id": vector_id}})
if payload:
operations.append({"vector": vector, **payload})
else:
operations.append({"vector": vector})
operations.append({"vector": record.vector, **(record.metadata or {})})

cls.client.bulk(
index=ELASTIC_INDEX,
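The operations list follows the Elasticsearch bulk API's alternating action/source layout. For a single record (id=1, a 2-d vector, one metadata field — all values made up for illustration) it would contain:

# Illustrative bulk payload for one record.
operations = [
    {"index": {"_id": "00000000000000000000000000000001"}},  # uuid.UUID(int=1).hex
    {"vector": [0.1, 0.2], "color": "red"},                  # source doc: vector + metadata
]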
6 changes: 6 additions & 0 deletions engine/clients/milvus/__init__.py
@@ -1,3 +1,9 @@
from engine.clients.milvus.configure import MilvusConfigurator
from engine.clients.milvus.search import MilvusSearcher
from engine.clients.milvus.upload import MilvusUploader

__all__ = [
"MilvusConfigurator",
"MilvusSearcher",
"MilvusUploader",
]
7 changes: 4 additions & 3 deletions engine/clients/milvus/search.py
@@ -3,6 +3,7 @@

from pymilvus import Collection, connections

from dataset_reader.base_reader import Query
from engine.base_client.search import BaseSearcher
from engine.clients.milvus.config import (
DISTANCE_MAPPING,
@@ -37,15 +38,15 @@ def get_mp_start_method(cls):
return "forkserver" if "forkserver" in mp.get_all_start_methods() else "spawn"

@classmethod
def search_one(cls, vector, meta_conditions, top) -> List[Tuple[int, float]]:
def search_one(cls, query: Query, top: int) -> List[Tuple[int, float]]:
param = {"metric_type": cls.distance, "params": cls.search_params["params"]}
try:
res = cls.collection.search(
data=[vector],
data=[query.vector],
anns_field="vector",
param=param,
limit=top,
expr=cls.parser.parse(meta_conditions),
expr=cls.parser.parse(query.meta_conditions),
)
except Exception as e:
import ipdb
21 changes: 14 additions & 7 deletions engine/clients/milvus/upload.py
@@ -1,5 +1,5 @@
import multiprocessing as mp
from typing import List, Optional
from typing import List

from pymilvus import (
Collection,
@@ -8,6 +8,7 @@
wait_for_index_building_complete,
)

from dataset_reader.base_reader import Record
from engine.base_client.upload import BaseUploader
from engine.clients.milvus.config import (
DISTANCE_MAPPING,
@@ -41,20 +42,26 @@ def init_client(cls, host, distance, connection_params, upload_params):
cls.distance = DISTANCE_MAPPING[distance]

@classmethod
def upload_batch(
cls, ids: List[int], vectors: List[list], metadata: Optional[List[dict]]
):
if metadata is not None:
def upload_batch(cls, batch: List[Record]):
has_metadata = any(record.metadata for record in batch)
if has_metadata:
field_values = [
[
payload.get(field_schema.name) or DTYPE_DEFAULT[field_schema.dtype]
for payload in metadata
record.metadata.get(field_schema.name)
or DTYPE_DEFAULT[field_schema.dtype]
for record in batch
]
for field_schema in cls.collection.schema.fields
if field_schema.name not in ["id", "vector"]
]
else:
field_values = []

ids, vectors = [], []
for record in batch:
ids.append(record.id)
vectors.append(record.vector)

cls.collection.insert([ids, vectors] + field_values)

@classmethod
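pymilvus's Collection.insert expects column-oriented data, which is why the row-oriented batch is transposed into parallel id and vector lists before the call. With two illustrative records:

# Two made-up records transposed into Milvus's column layout.
batch = [Record(id=1, vector=[0.1, 0.2]), Record(id=2, vector=[0.3, 0.4])]
ids = [record.id for record in batch]          # [1, 2]
vectors = [record.vector for record in batch]  # [[0.1, 0.2], [0.3, 0.4]]
# cls.collection.insert([ids, vectors] + field_values)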
6 changes: 6 additions & 0 deletions engine/clients/opensearch/__init__.py
@@ -1,3 +1,9 @@
from engine.clients.opensearch.configure import OpenSearchConfigurator
from engine.clients.opensearch.search import OpenSearchSearcher
from engine.clients.opensearch.upload import OpenSearchUploader

__all__ = [
"OpenSearchConfigurator",
"OpenSearchSearcher",
"OpenSearchUploader",
]
15 changes: 8 additions & 7 deletions engine/clients/opensearch/search.py
@@ -4,6 +4,7 @@

from opensearchpy import OpenSearch

from dataset_reader.base_reader import Query
from engine.base_client.search import BaseSearcher
from engine.clients.opensearch.config import (
OPENSEARCH_INDEX,
@@ -46,29 +47,29 @@ def init_client(cls, host, distance, connection_params: dict, search_params: dict):
cls.search_params = search_params

@classmethod
def search_one(cls, vector, meta_conditions, top) -> List[Tuple[int, float]]:
query = {
def search_one(cls, query: Query, top: int) -> List[Tuple[int, float]]:
opensearch_query = {
"knn": {
"vector": {
"vector": vector,
"vector": query.vector,
"k": top,
}
}
}

meta_conditions = cls.parser.parse(meta_conditions)
meta_conditions = cls.parser.parse(query.meta_conditions)
if meta_conditions:
query = {
opensearch_query = {
"bool": {
"must": [query],
"must": [opensearch_query],
"filter": meta_conditions,
}
}

res = cls.client.search(
index=OPENSEARCH_INDEX,
body={
"query": query,
"query": opensearch_query,
"size": top,
},
params={
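Renaming the local to opensearch_query stops it shadowing the new query: Query parameter. When a filter is present, the body sent to OpenSearch nests the knn clause under bool.must; an illustrative request body (the term filter is an assumption about the parser's output shape):

# Illustrative filtered request body with made-up vector and filter.
body = {
    "query": {
        "bool": {
            "must": [{"knn": {"vector": {"vector": [0.1, 0.2], "k": 10}}}],
            "filter": {"term": {"color": "red"}},
        }
    },
    "size": 10,
}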
18 changes: 6 additions & 12 deletions engine/clients/opensearch/upload.py
@@ -1,9 +1,10 @@
import multiprocessing as mp
import uuid
from typing import List, Optional
from typing import List

from opensearchpy import OpenSearch

from dataset_reader.base_reader import Record
from engine.base_client.upload import BaseUploader
from engine.clients.opensearch.config import (
OPENSEARCH_INDEX,
@@ -44,19 +45,12 @@ def init_client(cls, host, distance, connection_params, upload_params):
cls.upload_params = upload_params

@classmethod
def upload_batch(
cls, ids: List[int], vectors: List[list], metadata: Optional[List[dict]]
):
if metadata is None:
metadata = [{}] * len(vectors)
def upload_batch(cls, batch: List[Record]):
operations = []
for idx, vector, payload in zip(ids, vectors, metadata):
vector_id = uuid.UUID(int=idx).hex
for record in batch:
vector_id = uuid.UUID(int=record.id).hex
operations.append({"index": {"_id": vector_id}})
if payload:
operations.append({"vector": vector, **payload})
else:
operations.append({"vector": vector})
operations.append({"vector": record.vector, **(record.metadata or {})})

cls.client.bulk(
index=OPENSEARCH_INDEX,
4 changes: 3 additions & 1 deletion engine/clients/pgvector/configure.py
@@ -46,7 +46,9 @@ def recreate(self, dataset: Dataset, collection_params):
)

self.conn.execute(
f"CREATE INDEX on items USING hnsw(embedding {hnsw_distance_type}) WITH (m = {collection_params['hnsw_config']['m']}, ef_construction = {collection_params['hnsw_config']['ef_construct']})"
f"CREATE INDEX on items USING hnsw(embedding {hnsw_distance_type}) WITH "
f"(m = {collection_params['hnsw_config']['m']}, "
f"ef_construction = {collection_params['hnsw_config']['ef_construct']})"
)

self.conn.close()
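The reflowed f-string still renders to a single CREATE INDEX statement. With illustrative parameters (m=16, ef_construct=128, cosine ops — all assumed):

# Illustrative rendering; hnsw_distance_type and the config values are assumptions.
hnsw_distance_type = "vector_cosine_ops"
collection_params = {"hnsw_config": {"m": 16, "ef_construct": 128}}
sql = (
    f"CREATE INDEX on items USING hnsw(embedding {hnsw_distance_type}) WITH "
    f"(m = {collection_params['hnsw_config']['m']}, "
    f"ef_construction = {collection_params['hnsw_config']['ef_construct']})"
)
# -> CREATE INDEX on items USING hnsw(embedding vector_cosine_ops)
#    WITH (m = 16, ef_construction = 128)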
12 changes: 6 additions & 6 deletions engine/clients/pgvector/search.py
@@ -1,10 +1,10 @@
import multiprocessing as mp
from typing import List, Tuple

import numpy as np
import psycopg
from pgvector.psycopg import register_vector

from dataset_reader.base_reader import Query
from engine.base_client.distances import Distance
from engine.base_client.search import BaseSearcher
from engine.clients.pgvector.config import get_db_config
@@ -27,19 +27,19 @@ def init_client(cls, host, distance, connection_params: dict, search_params: dict):
cls.search_params = search_params["search_params"]

@classmethod
def search_one(cls, vector, meta_conditions, top) -> List[Tuple[int, float]]:
def search_one(cls, query: Query, top: int) -> List[Tuple[int, float]]:
cls.cur.execute(f"SET hnsw.ef_search = {cls.search_params['hnsw_ef']}")

if cls.distance == Distance.COSINE:
query = f"SELECT id, embedding <=> %s AS _score FROM items ORDER BY _score LIMIT {top};"
sql_query = f"SELECT id, embedding <=> %s AS _score FROM items ORDER BY _score LIMIT {top};"
elif cls.distance == Distance.L2:
query = f"SELECT id, embedding <-> %s AS _score FROM items ORDER BY _score LIMIT {top};"
sql_query = f"SELECT id, embedding <-> %s AS _score FROM items ORDER BY _score LIMIT {top};"
else:
raise NotImplementedError(f"Unsupported distance metric {cls.distance}")

cls.cur.execute(
query,
(np.array(vector),),
sql_query,
(np.array(query.vector),),
)
return cls.cur.fetchall()

13 changes: 5 additions & 8 deletions engine/clients/pgvector/upload.py
@@ -1,9 +1,10 @@
from typing import List, Optional
from typing import List

import numpy as np
import psycopg
from pgvector.psycopg import register_vector

from dataset_reader.base_reader import Record
from engine.base_client.upload import BaseUploader
from engine.clients.pgvector.config import get_db_config

@@ -21,15 +22,11 @@ def init_client(cls, host, distance, connection_params, upload_params):
cls.upload_params = upload_params

@classmethod
def upload_batch(
cls, ids: List[int], vectors: List[list], metadata: Optional[List[dict]]
):
vectors = np.array(vectors)

def upload_batch(cls, batch: List[Record]):
# Copy is faster than insert
with cls.cur.copy("COPY items (id, embedding) FROM STDIN") as copy:
for i, embedding in zip(ids, vectors):
copy.write_row((i, embedding))
for record in batch:
copy.write_row((record.id, np.array(record.vector)))

@classmethod
def delete_client(cls):
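For contrast with the "Copy is faster than insert" comment: the row-by-row alternative would look roughly like the sketch below (psycopg 3's executemany against the same assumed items table), which pays per-statement overhead that COPY avoids.

# Slower row-by-row alternative to COPY, shown for comparison only.
cls.cur.executemany(
    "INSERT INTO items (id, embedding) VALUES (%s, %s)",
    [(record.id, np.array(record.vector)) for record in batch],
)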
6 changes: 6 additions & 0 deletions engine/clients/qdrant/__init__.py
@@ -1,3 +1,9 @@
from engine.clients.qdrant.configure import QdrantConfigurator
from engine.clients.qdrant.search import QdrantSearcher
from engine.clients.qdrant.upload import QdrantUploader

__all__ = [
"QdrantConfigurator",
"QdrantSearcher",
"QdrantUploader",
]