From 72caffc6d1afdb46d8ee4d9b9fb02919d36bc3ee Mon Sep 17 00:00:00 2001 From: KShivendu Date: Tue, 9 Apr 2024 13:50:55 +0530 Subject: [PATCH 1/8] refactor: Update all engines to use Query and Record dataclasses --- engine/clients/elasticsearch/search.py | 5 +++-- engine/clients/elasticsearch/upload.py | 16 ++++++---------- engine/clients/milvus/search.py | 7 ++++--- engine/clients/milvus/upload.py | 15 +++++++++------ engine/clients/opensearch/search.py | 5 +++-- engine/clients/opensearch/upload.py | 18 +++++++----------- engine/clients/pgvector/search.py | 11 ++++++----- engine/clients/pgvector/upload.py | 9 ++++----- engine/clients/qdrant/search.py | 2 +- engine/clients/redis/search.py | 11 ++++++----- engine/clients/redis/upload.py | 13 ++++++------- engine/clients/weaviate/search.py | 7 ++++--- engine/clients/weaviate/upload.py | 15 ++++++++------- 13 files changed, 67 insertions(+), 67 deletions(-) diff --git a/engine/clients/elasticsearch/search.py b/engine/clients/elasticsearch/search.py index 29d20ec5..bcbb6d94 100644 --- a/engine/clients/elasticsearch/search.py +++ b/engine/clients/elasticsearch/search.py @@ -4,6 +4,7 @@ from elasticsearch import Elasticsearch +from dataset_reader.base_reader import Query from engine.base_client.search import BaseSearcher from engine.clients.elasticsearch.config import ( ELASTIC_INDEX, @@ -46,10 +47,10 @@ def init_client(cls, host, distance, connection_params: dict, search_params: dic cls.search_params = search_params @classmethod - def search_one(cls, vector, meta_conditions, top) -> List[Tuple[int, float]]: + def search_one(cls, query: Query, top: int) -> List[Tuple[int, float]]: knn = { "field": "vector", - "query_vector": vector, + "query_vector": query.vector, "k": top, **{"num_candidates": 100, **cls.search_params}, } diff --git a/engine/clients/elasticsearch/upload.py b/engine/clients/elasticsearch/upload.py index 0d5c6f2b..e170fa73 100644 --- a/engine/clients/elasticsearch/upload.py +++ b/engine/clients/elasticsearch/upload.py @@ -4,6 +4,7 @@ from elasticsearch import Elasticsearch +from dataset_reader.base_reader import Record from engine.base_client.upload import BaseUploader from engine.clients.elasticsearch.config import ( ELASTIC_INDEX, @@ -44,19 +45,14 @@ def init_client(cls, host, distance, connection_params, upload_params): cls.upload_params = upload_params @classmethod - def upload_batch( - cls, ids: List[int], vectors: List[list], metadata: Optional[List[dict]] - ): + def upload_batch(cls, batch: List[Record]): if metadata is None: - metadata = [{}] * len(vectors) + metadata = [{}] * len(batch) operations = [] - for idx, vector, payload in zip(ids, vectors, metadata): - vector_id = uuid.UUID(int=idx).hex + for record in batch: + vector_id = uuid.UUID(int=record.idx).hex operations.append({"index": {"_id": vector_id}}) - if payload: - operations.append({"vector": vector, **payload}) - else: - operations.append({"vector": vector}) + operations.append({"vector": record.vector, **(record.metadata or {})}) cls.client.bulk( index=ELASTIC_INDEX, diff --git a/engine/clients/milvus/search.py b/engine/clients/milvus/search.py index 9b155f7b..80493206 100644 --- a/engine/clients/milvus/search.py +++ b/engine/clients/milvus/search.py @@ -3,6 +3,7 @@ from pymilvus import Collection, connections +from dataset_reader.base_reader import Query from engine.base_client.search import BaseSearcher from engine.clients.milvus.config import ( DISTANCE_MAPPING, @@ -37,15 +38,15 @@ def get_mp_start_method(cls): return "forkserver" if "forkserver" in mp.get_all_start_methods() else "spawn" @classmethod - def search_one(cls, vector, meta_conditions, top) -> List[Tuple[int, float]]: + def search_one(cls, query: Query, top: int) -> List[Tuple[int, float]]: param = {"metric_type": cls.distance, "params": cls.search_params["params"]} try: res = cls.collection.search( - data=[vector], + data=[query.vector], anns_field="vector", param=param, limit=top, - expr=cls.parser.parse(meta_conditions), + expr=cls.parser.parse(query.meta_conditions), ) except Exception as e: import ipdb diff --git a/engine/clients/milvus/upload.py b/engine/clients/milvus/upload.py index 8f897a45..ff6375d4 100644 --- a/engine/clients/milvus/upload.py +++ b/engine/clients/milvus/upload.py @@ -8,6 +8,7 @@ wait_for_index_building_complete, ) +from dataset_reader.base_reader import Record from engine.base_client.upload import BaseUploader from engine.clients.milvus.config import ( DISTANCE_MAPPING, @@ -41,20 +42,22 @@ def init_client(cls, host, distance, connection_params, upload_params): cls.distance = DISTANCE_MAPPING[distance] @classmethod - def upload_batch( - cls, ids: List[int], vectors: List[list], metadata: Optional[List[dict]] - ): - if metadata is not None: + def upload_batch(cls, batch: List[Record]): + has_metadata = any(record.metadata for record in batch) + if has_metadata: field_values = [ [ - payload.get(field_schema.name) or DTYPE_DEFAULT[field_schema.dtype] - for payload in metadata + record.metadata.get(field_schema.name) + or DTYPE_DEFAULT[field_schema.dtype] + for record in batch ] for field_schema in cls.collection.schema.fields if field_schema.name not in ["id", "vector"] ] else: field_values = [] + ids = [record.idx for record in batch] + vectors = [record.vector for record in batch] cls.collection.insert([ids, vectors] + field_values) @classmethod diff --git a/engine/clients/opensearch/search.py b/engine/clients/opensearch/search.py index 8f388380..4f04f0e4 100644 --- a/engine/clients/opensearch/search.py +++ b/engine/clients/opensearch/search.py @@ -4,6 +4,7 @@ from opensearchpy import OpenSearch +from dataset_reader.base_reader import Query from engine.base_client.search import BaseSearcher from engine.clients.opensearch.config import ( OPENSEARCH_INDEX, @@ -46,11 +47,11 @@ def init_client(cls, host, distance, connection_params: dict, search_params: dic cls.search_params = search_params @classmethod - def search_one(cls, vector, meta_conditions, top) -> List[Tuple[int, float]]: + def search_one(cls, query: Query, top: int) -> List[Tuple[int, float]]: query = { "knn": { "vector": { - "vector": vector, + "vector": query.vector, "k": top, } } diff --git a/engine/clients/opensearch/upload.py b/engine/clients/opensearch/upload.py index 46a7151d..3a42bf5f 100644 --- a/engine/clients/opensearch/upload.py +++ b/engine/clients/opensearch/upload.py @@ -1,9 +1,10 @@ import multiprocessing as mp import uuid -from typing import List, Optional +from typing import List from opensearchpy import OpenSearch +from dataset_reader.base_reader import Record from engine.base_client.upload import BaseUploader from engine.clients.opensearch.config import ( OPENSEARCH_INDEX, @@ -44,19 +45,14 @@ def init_client(cls, host, distance, connection_params, upload_params): cls.upload_params = upload_params @classmethod - def upload_batch( - cls, ids: List[int], vectors: List[list], metadata: Optional[List[dict]] - ): + def upload_batch(cls, batch: List[Record]): if metadata is None: - metadata = [{}] * len(vectors) + metadata = [{}] * len(batch) operations = [] - for idx, vector, payload in zip(ids, vectors, metadata): - vector_id = uuid.UUID(int=idx).hex + for record in batch: + vector_id = uuid.UUID(int=record.id).hex operations.append({"index": {"_id": vector_id}}) - if payload: - operations.append({"vector": vector, **payload}) - else: - operations.append({"vector": vector}) + operations.append({"vector": record.vector, **(record.metadata or {})}) cls.client.bulk( index=OPENSEARCH_INDEX, diff --git a/engine/clients/pgvector/search.py b/engine/clients/pgvector/search.py index fa8bde5a..c9e470b9 100644 --- a/engine/clients/pgvector/search.py +++ b/engine/clients/pgvector/search.py @@ -5,6 +5,7 @@ import psycopg from pgvector.psycopg import register_vector +from dataset_reader.base_reader import Query from engine.base_client.distances import Distance from engine.base_client.search import BaseSearcher from engine.clients.pgvector.config import get_db_config @@ -27,19 +28,19 @@ def init_client(cls, host, distance, connection_params: dict, search_params: dic cls.search_params = search_params["search_params"] @classmethod - def search_one(cls, vector, meta_conditions, top) -> List[Tuple[int, float]]: + def search_one(cls, query: Query, top: int) -> List[Tuple[int, float]]: cls.cur.execute(f"SET hnsw.ef_search = {cls.search_params['hnsw_ef']}") if cls.distance == Distance.COSINE: - query = f"SELECT id, embedding <=> %s AS _score FROM items ORDER BY _score LIMIT {top};" + sql_query = f"SELECT id, embedding <=> %s AS _score FROM items ORDER BY _score LIMIT {top};" elif cls.distance == Distance.L2: - query = f"SELECT id, embedding <-> %s AS _score FROM items ORDER BY _score LIMIT {top};" + sql_query = f"SELECT id, embedding <-> %s AS _score FROM items ORDER BY _score LIMIT {top};" else: raise NotImplementedError(f"Unsupported distance metric {cls.distance}") cls.cur.execute( - query, - (np.array(vector),), + sql_query, + (np.array(query.vector),), ) return cls.cur.fetchall() diff --git a/engine/clients/pgvector/upload.py b/engine/clients/pgvector/upload.py index 8d59ee7f..3aee0a74 100644 --- a/engine/clients/pgvector/upload.py +++ b/engine/clients/pgvector/upload.py @@ -4,6 +4,7 @@ import psycopg from pgvector.psycopg import register_vector +from dataset_reader.base_reader import Record from engine.base_client.upload import BaseUploader from engine.clients.pgvector.config import get_db_config @@ -21,15 +22,13 @@ def init_client(cls, host, distance, connection_params, upload_params): cls.upload_params = upload_params @classmethod - def upload_batch( - cls, ids: List[int], vectors: List[list], metadata: Optional[List[dict]] - ): + def upload_batch(cls, batch: List[Record]): vectors = np.array(vectors) # Copy is faster than insert with cls.cur.copy("COPY items (id, embedding) FROM STDIN") as copy: - for i, embedding in zip(ids, vectors): - copy.write_row((i, embedding)) + for record in batch: + copy.write_row((record.id, record.vector)) @classmethod def delete_client(cls): diff --git a/engine/clients/qdrant/search.py b/engine/clients/qdrant/search.py index fd3c04eb..759d2d99 100644 --- a/engine/clients/qdrant/search.py +++ b/engine/clients/qdrant/search.py @@ -35,7 +35,7 @@ def init_client(cls, host, distance, connection_params: dict, search_params: dic # return "forkserver" if "forkserver" in mp.get_all_start_methods() else "spawn" @classmethod - def search_one(cls, query: Query, top) -> List[Tuple[int, float]]: + def search_one(cls, query: Query, top: int) -> List[Tuple[int, float]]: # Can query only one till we introduce re-ranking in the benchmarks if query.sparse_vector is None: query_vector = query.vector diff --git a/engine/clients/redis/search.py b/engine/clients/redis/search.py index dca31919..feccfece 100644 --- a/engine/clients/redis/search.py +++ b/engine/clients/redis/search.py @@ -3,8 +3,9 @@ import numpy as np from redis import Redis, RedisCluster -from redis.commands.search.query import Query +from redis.commands.search.query import Query as RedisQuery +from dataset_reader.base_reader import Query as DatasetQuery from engine.base_client.search import BaseSearcher from engine.clients.redis.config import ( REDIS_AUTH, @@ -41,8 +42,8 @@ def init_client(cls, host, distance, connection_params: dict, search_params: dic cls._ft = cls.conns[random.randint(0, len(cls.conns)) - 1].ft() @classmethod - def search_one(cls, vector, meta_conditions, top) -> List[Tuple[int, float]]: - conditions = cls.parser.parse(meta_conditions) + def search_one(cls, query: DatasetQuery, top: int) -> List[Tuple[int, float]]: + conditions = cls.parser.parse(query.meta_conditions) if conditions is None: prefilter_condition = "*" params = {} @@ -50,7 +51,7 @@ def search_one(cls, vector, meta_conditions, top) -> List[Tuple[int, float]]: prefilter_condition, params = conditions q = ( - Query( + RedisQuery( f"{prefilter_condition}=>[KNN $K @vector $vec_param {cls.knn_conditions} AS vector_score]" ) .sort_by("vector_score", asc=True) @@ -62,7 +63,7 @@ def search_one(cls, vector, meta_conditions, top) -> List[Tuple[int, float]]: .timeout(REDIS_QUERY_TIMEOUT) ) params_dict = { - "vec_param": np.array(vector).astype(np.float32).tobytes(), + "vec_param": np.array(query.vector).astype(np.float32).tobytes(), "K": top, "EF": cls.search_params["search_params"]["ef"], **params, diff --git a/engine/clients/redis/upload.py b/engine/clients/redis/upload.py index 89bc0a3b..2211f662 100644 --- a/engine/clients/redis/upload.py +++ b/engine/clients/redis/upload.py @@ -3,6 +3,7 @@ import numpy as np from redis import Redis, RedisCluster +from dataset_reader.base_reader import Record from engine.base_client.upload import BaseUploader from engine.clients.redis.config import ( REDIS_AUTH, @@ -26,14 +27,12 @@ def init_client(cls, host, distance, connection_params, upload_params): cls.upload_params = upload_params @classmethod - def upload_batch( - cls, ids: List[int], vectors: List[list], metadata: Optional[List[dict]] - ): + def upload_batch(cls, batch: List[Record]): p = cls.client.pipeline(transaction=False) - for i in range(len(ids)): - idx = ids[i] - vec = vectors[i] - meta = metadata[i] if metadata else {} + for record in batch: + idx = record.id + vec = record.vector + meta = record.metadata or {} geopoints = {} payload = {} if meta is not None: diff --git a/engine/clients/weaviate/search.py b/engine/clients/weaviate/search.py index 4218be92..cd9abef2 100644 --- a/engine/clients/weaviate/search.py +++ b/engine/clients/weaviate/search.py @@ -7,6 +7,7 @@ from weaviate.collections import Collection from weaviate.connect import ConnectionParams +from dataset_reader.base_reader import Query from engine.base_client.search import BaseSearcher from engine.clients.weaviate.config import WEAVIATE_CLASS_NAME, WEAVIATE_DEFAULT_PORT from engine.clients.weaviate.parser import WeaviateConditionParser @@ -32,10 +33,10 @@ def init_client(cls, host, distance, connection_params: dict, search_params: dic cls.client = client @classmethod - def search_one(self, vector, meta_conditions, top) -> List[Tuple[int, float]]: + def search_one(self, query: Query, top: int) -> List[Tuple[int, float]]: res = self.collection.query.near_vector( - near_vector=vector, - filters=self.parser.parse(meta_conditions), + near_vector=query.vector, + filters=self.parser.parse(query.meta_conditions), limit=top, return_metadata=MetadataQuery(distance=True), return_properties=[], diff --git a/engine/clients/weaviate/upload.py b/engine/clients/weaviate/upload.py index ad52f64f..03c00498 100644 --- a/engine/clients/weaviate/upload.py +++ b/engine/clients/weaviate/upload.py @@ -5,6 +5,7 @@ from weaviate.classes.data import DataObject from weaviate.connect import ConnectionParams +from dataset_reader.base_reader import Record from engine.base_client.upload import BaseUploader from engine.clients.weaviate.config import WEAVIATE_CLASS_NAME, WEAVIATE_DEFAULT_PORT @@ -28,14 +29,14 @@ def init_client(cls, host, distance, connection_params, upload_params): ) @classmethod - def upload_batch( - cls, ids: List[int], vectors: List[list], metadata: List[Optional[dict]] - ): + def upload_batch(cls, batch: List[Record]): objects = [] - for i in range(len(ids)): - id = uuid.UUID(int=ids[i]) - property = metadata[i] or {} - objects.append(DataObject(properties=property, vector=vectors[i], uuid=id)) + for record in batch: + id = uuid.UUID(int=record.id) + property = record.metadata or {} + objects.append( + DataObject(properties=property, vector=record.vector, uuid=id) + ) if len(objects) > 0: cls.collection.data.insert_many(objects) From a7e6698f7b6e9bb04ed2b97681c498780eef8809 Mon Sep 17 00:00:00 2001 From: KShivendu Date: Tue, 9 Apr 2024 14:23:31 +0530 Subject: [PATCH 2/8] feat: Add ruff in pre-commit hooks --- .pre-commit-config.yaml | 6 ++++++ engine/base_client/__init__.py | 9 +++++++++ engine/base_client/client.py | 1 - engine/base_client/upload.py | 2 +- engine/base_client/utils.py | 2 +- engine/clients/elasticsearch/__init__.py | 6 ++++++ engine/clients/elasticsearch/search.py | 2 +- engine/clients/elasticsearch/upload.py | 4 +--- engine/clients/milvus/__init__.py | 6 ++++++ engine/clients/milvus/upload.py | 2 +- engine/clients/opensearch/__init__.py | 6 ++++++ engine/clients/opensearch/search.py | 2 +- engine/clients/opensearch/upload.py | 2 -- engine/clients/pgvector/search.py | 1 - engine/clients/pgvector/upload.py | 6 ++---- engine/clients/qdrant/__init__.py | 6 ++++++ engine/clients/qdrant/upload.py | 2 +- engine/clients/redis/__init__.py | 6 ++++++ engine/clients/redis/upload.py | 2 +- engine/clients/weaviate/__init__.py | 6 ++++++ engine/clients/weaviate/search.py | 1 - engine/clients/weaviate/upload.py | 2 +- run.py | 2 +- 23 files changed, 63 insertions(+), 21 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 754906cd..b172eb4f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,3 +28,9 @@ repos: - id: isort name: "Sort Imports" args: ["--profile", "black"] + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.3.5 + hooks: + # Run the linter. + - id: ruff diff --git a/engine/base_client/__init__.py b/engine/base_client/__init__.py index a5495554..5528bb47 100644 --- a/engine/base_client/__init__.py +++ b/engine/base_client/__init__.py @@ -6,3 +6,12 @@ class IncompatibilityError(Exception): pass + + +__all__ = [ + "BaseClient", + "BaseConfigurator", + "BaseSearcher", + "BaseUploader", + "IncompatibilityError", +] diff --git a/engine/base_client/client.py b/engine/base_client/client.py index 840a9353..def2f53b 100644 --- a/engine/base_client/client.py +++ b/engine/base_client/client.py @@ -1,7 +1,6 @@ import json import os from datetime import datetime -from pathlib import Path from typing import List from benchmark import ROOT_DIR diff --git a/engine/base_client/upload.py b/engine/base_client/upload.py index 260f7c54..597e8750 100644 --- a/engine/base_client/upload.py +++ b/engine/base_client/upload.py @@ -1,6 +1,6 @@ import time from multiprocessing import get_context -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional import tqdm diff --git a/engine/base_client/utils.py b/engine/base_client/utils.py index 39d4fbe9..899ae9ee 100644 --- a/engine/base_client/utils.py +++ b/engine/base_client/utils.py @@ -1,4 +1,4 @@ -from typing import Any, Iterable +from typing import Iterable from dataset_reader.base_reader import Record diff --git a/engine/clients/elasticsearch/__init__.py b/engine/clients/elasticsearch/__init__.py index 24288e97..c1802087 100644 --- a/engine/clients/elasticsearch/__init__.py +++ b/engine/clients/elasticsearch/__init__.py @@ -1,3 +1,9 @@ from engine.clients.elasticsearch.configure import ElasticConfigurator from engine.clients.elasticsearch.search import ElasticSearcher from engine.clients.elasticsearch.upload import ElasticUploader + +__all__ = [ + "ElasticConfigurator", + "ElasticSearcher", + "ElasticUploader", +] diff --git a/engine/clients/elasticsearch/search.py b/engine/clients/elasticsearch/search.py index bcbb6d94..f51adc15 100644 --- a/engine/clients/elasticsearch/search.py +++ b/engine/clients/elasticsearch/search.py @@ -55,7 +55,7 @@ def search_one(cls, query: Query, top: int) -> List[Tuple[int, float]]: **{"num_candidates": 100, **cls.search_params}, } - meta_conditions = cls.parser.parse(meta_conditions) + meta_conditions = cls.parser.parse(query.meta_conditions) if meta_conditions: knn["filter"] = meta_conditions diff --git a/engine/clients/elasticsearch/upload.py b/engine/clients/elasticsearch/upload.py index e170fa73..71c1267c 100644 --- a/engine/clients/elasticsearch/upload.py +++ b/engine/clients/elasticsearch/upload.py @@ -1,6 +1,6 @@ import multiprocessing as mp import uuid -from typing import List, Optional +from typing import List from elasticsearch import Elasticsearch @@ -46,8 +46,6 @@ def init_client(cls, host, distance, connection_params, upload_params): @classmethod def upload_batch(cls, batch: List[Record]): - if metadata is None: - metadata = [{}] * len(batch) operations = [] for record in batch: vector_id = uuid.UUID(int=record.idx).hex diff --git a/engine/clients/milvus/__init__.py b/engine/clients/milvus/__init__.py index ca400c86..31abe17b 100644 --- a/engine/clients/milvus/__init__.py +++ b/engine/clients/milvus/__init__.py @@ -1,3 +1,9 @@ from engine.clients.milvus.configure import MilvusConfigurator from engine.clients.milvus.search import MilvusSearcher from engine.clients.milvus.upload import MilvusUploader + +__all__ = [ + "MilvusConfigurator", + "MilvusSearcher", + "MilvusUploader", +] diff --git a/engine/clients/milvus/upload.py b/engine/clients/milvus/upload.py index ff6375d4..4a5def5a 100644 --- a/engine/clients/milvus/upload.py +++ b/engine/clients/milvus/upload.py @@ -1,5 +1,5 @@ import multiprocessing as mp -from typing import List, Optional +from typing import List from pymilvus import ( Collection, diff --git a/engine/clients/opensearch/__init__.py b/engine/clients/opensearch/__init__.py index 686bfcde..e4c6c59a 100644 --- a/engine/clients/opensearch/__init__.py +++ b/engine/clients/opensearch/__init__.py @@ -1,3 +1,9 @@ from engine.clients.opensearch.configure import OpenSearchConfigurator from engine.clients.opensearch.search import OpenSearchSearcher from engine.clients.opensearch.upload import OpenSearchUploader + +__all__ = [ + "OpenSearchConfigurator", + "OpenSearchSearcher", + "OpenSearchUploader", +] diff --git a/engine/clients/opensearch/search.py b/engine/clients/opensearch/search.py index 4f04f0e4..30b882df 100644 --- a/engine/clients/opensearch/search.py +++ b/engine/clients/opensearch/search.py @@ -57,7 +57,7 @@ def search_one(cls, query: Query, top: int) -> List[Tuple[int, float]]: } } - meta_conditions = cls.parser.parse(meta_conditions) + meta_conditions = cls.parser.parse(query.meta_conditions) if meta_conditions: query = { "bool": { diff --git a/engine/clients/opensearch/upload.py b/engine/clients/opensearch/upload.py index 3a42bf5f..0bc2427e 100644 --- a/engine/clients/opensearch/upload.py +++ b/engine/clients/opensearch/upload.py @@ -46,8 +46,6 @@ def init_client(cls, host, distance, connection_params, upload_params): @classmethod def upload_batch(cls, batch: List[Record]): - if metadata is None: - metadata = [{}] * len(batch) operations = [] for record in batch: vector_id = uuid.UUID(int=record.id).hex diff --git a/engine/clients/pgvector/search.py b/engine/clients/pgvector/search.py index c9e470b9..c3799654 100644 --- a/engine/clients/pgvector/search.py +++ b/engine/clients/pgvector/search.py @@ -1,4 +1,3 @@ -import multiprocessing as mp from typing import List, Tuple import numpy as np diff --git a/engine/clients/pgvector/upload.py b/engine/clients/pgvector/upload.py index 3aee0a74..74bb4486 100644 --- a/engine/clients/pgvector/upload.py +++ b/engine/clients/pgvector/upload.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List import numpy as np import psycopg @@ -23,12 +23,10 @@ def init_client(cls, host, distance, connection_params, upload_params): @classmethod def upload_batch(cls, batch: List[Record]): - vectors = np.array(vectors) - # Copy is faster than insert with cls.cur.copy("COPY items (id, embedding) FROM STDIN") as copy: for record in batch: - copy.write_row((record.id, record.vector)) + copy.write_row((record.id, np.array(record.vector))) @classmethod def delete_client(cls): diff --git a/engine/clients/qdrant/__init__.py b/engine/clients/qdrant/__init__.py index 03642803..2c95ffc8 100644 --- a/engine/clients/qdrant/__init__.py +++ b/engine/clients/qdrant/__init__.py @@ -1,3 +1,9 @@ from engine.clients.qdrant.configure import QdrantConfigurator from engine.clients.qdrant.search import QdrantSearcher from engine.clients.qdrant.upload import QdrantUploader + +__all__ = [ + "QdrantConfigurator", + "QdrantSearcher", + "QdrantUploader", +] diff --git a/engine/clients/qdrant/upload.py b/engine/clients/qdrant/upload.py index bb1f83db..a5c2dbbe 100644 --- a/engine/clients/qdrant/upload.py +++ b/engine/clients/qdrant/upload.py @@ -46,7 +46,7 @@ def upload_batch(cls, batch: List[Record]): vectors.append(vector) payloads.append(point.metadata or {}) - res = cls.client.upsert( + _ = cls.client.upsert( collection_name=QDRANT_COLLECTION_NAME, points=Batch.model_construct( ids=ids, diff --git a/engine/clients/redis/__init__.py b/engine/clients/redis/__init__.py index a1437747..75f3b150 100644 --- a/engine/clients/redis/__init__.py +++ b/engine/clients/redis/__init__.py @@ -1,3 +1,9 @@ from engine.clients.redis.configure import RedisConfigurator from engine.clients.redis.search import RedisSearcher from engine.clients.redis.upload import RedisUploader + +__all__ = [ + "RedisConfigurator", + "RedisSearcher", + "RedisUploader", +] diff --git a/engine/clients/redis/upload.py b/engine/clients/redis/upload.py index 2211f662..cd4b888b 100644 --- a/engine/clients/redis/upload.py +++ b/engine/clients/redis/upload.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List import numpy as np from redis import Redis, RedisCluster diff --git a/engine/clients/weaviate/__init__.py b/engine/clients/weaviate/__init__.py index d8d90121..2e8abba5 100644 --- a/engine/clients/weaviate/__init__.py +++ b/engine/clients/weaviate/__init__.py @@ -1,3 +1,9 @@ from engine.clients.weaviate.configure import WeaviateConfigurator from engine.clients.weaviate.search import WeaviateSearcher from engine.clients.weaviate.upload import WeaviateUploader + +__all__ = [ + "WeaviateConfigurator", + "WeaviateSearcher", + "WeaviateUploader", +] diff --git a/engine/clients/weaviate/search.py b/engine/clients/weaviate/search.py index cd9abef2..def87b38 100644 --- a/engine/clients/weaviate/search.py +++ b/engine/clients/weaviate/search.py @@ -1,4 +1,3 @@ -import uuid from typing import List, Tuple from weaviate import WeaviateClient diff --git a/engine/clients/weaviate/upload.py b/engine/clients/weaviate/upload.py index 03c00498..05790490 100644 --- a/engine/clients/weaviate/upload.py +++ b/engine/clients/weaviate/upload.py @@ -1,5 +1,5 @@ import uuid -from typing import List, Optional +from typing import List from weaviate import WeaviateClient from weaviate.classes.data import DataObject diff --git a/run.py b/run.py index 32cf6509..1c20e88a 100644 --- a/run.py +++ b/run.py @@ -76,7 +76,7 @@ def run( f"Skipping {engine_name} - {dataset_name}, incompatible params:", e ) continue - except KeyboardInterrupt as e: + except KeyboardInterrupt: traceback.print_exc() exit(1) except Exception as e: From 63d3a028d545e1e224eeb81ee98558188f261528 Mon Sep 17 00:00:00 2001 From: KShivendu Date: Tue, 16 Apr 2024 20:34:00 +0530 Subject: [PATCH 3/8] fix: Type mismatches --- engine/base_client/upload.py | 6 ++---- engine/clients/milvus/upload.py | 9 +++++++-- engine/clients/weaviate/upload.py | 6 +++--- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/engine/base_client/upload.py b/engine/base_client/upload.py index 597e8750..55ee4055 100644 --- a/engine/base_client/upload.py +++ b/engine/base_client/upload.py @@ -1,6 +1,6 @@ import time from multiprocessing import get_context -from typing import Iterable, List, Optional +from typing import Iterable, List import tqdm @@ -90,9 +90,7 @@ def post_upload(cls, distance): return {} @classmethod - def upload_batch( - cls, ids: List[int], vectors: List[list], metadata: List[Optional[dict]] - ): + def upload_batch(cls, batch: List[Record]): raise NotImplementedError() @classmethod diff --git a/engine/clients/milvus/upload.py b/engine/clients/milvus/upload.py index 4a5def5a..eece1d3c 100644 --- a/engine/clients/milvus/upload.py +++ b/engine/clients/milvus/upload.py @@ -44,6 +44,7 @@ def init_client(cls, host, distance, connection_params, upload_params): @classmethod def upload_batch(cls, batch: List[Record]): has_metadata = any(record.metadata for record in batch) + field_values = [] if has_metadata: field_values = [ [ @@ -56,8 +57,12 @@ def upload_batch(cls, batch: List[Record]): ] else: field_values = [] - ids = [record.idx for record in batch] - vectors = [record.vector for record in batch] + + ids, vectors = [], [] + for record in batch: + ids.append(record.idx) + vectors.append(record.vector) + cls.collection.insert([ids, vectors] + field_values) @classmethod diff --git a/engine/clients/weaviate/upload.py b/engine/clients/weaviate/upload.py index 05790490..9715ad1d 100644 --- a/engine/clients/weaviate/upload.py +++ b/engine/clients/weaviate/upload.py @@ -32,10 +32,10 @@ def init_client(cls, host, distance, connection_params, upload_params): def upload_batch(cls, batch: List[Record]): objects = [] for record in batch: - id = uuid.UUID(int=record.id) - property = record.metadata or {} + _id = uuid.UUID(int=record.id) + _property = record.metadata or {} objects.append( - DataObject(properties=property, vector=record.vector, uuid=id) + DataObject(properties=_property, vector=record.vector, uuid=_id) ) if len(objects) > 0: cls.collection.data.insert_many(objects) From 2620cb3c338ba03cd0ad4ed3def840cccb5f2c3c Mon Sep 17 00:00:00 2001 From: KShivendu Date: Tue, 16 Apr 2024 21:04:46 +0530 Subject: [PATCH 4/8] fix: Redis search client types and var names --- engine/clients/redis/configure.py | 14 +++++++------- engine/clients/redis/search.py | 25 +++++++++++++++++-------- 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/engine/clients/redis/configure.py b/engine/clients/redis/configure.py index ccf3776c..a5e6fe82 100644 --- a/engine/clients/redis/configure.py +++ b/engine/clients/redis/configure.py @@ -36,24 +36,24 @@ class RedisConfigurator(BaseConfigurator): def __init__(self, host, collection_params: dict, connection_params: dict): super().__init__(host, collection_params, connection_params) redis_constructor = RedisCluster if REDIS_CLUSTER else Redis - self._is_cluster = True if REDIS_CLUSTER else False + self.is_cluster = REDIS_CLUSTER self.client = redis_constructor( host=host, port=REDIS_PORT, password=REDIS_AUTH, username=REDIS_USER ) def clean(self): conns = [self.client] - if self._is_cluster: + if self.is_cluster: conns = [ self.client.get_redis_connection(node) for node in self.client.get_primaries() ] for conn in conns: - index = conn.ft() + search_namespace = conn.ft() try: - index.dropindex(delete_documents=True) + search_namespace.dropindex(delete_documents=True) except redis.ResponseError as e: - if "Unknown Index name" not in e.__str__(): + if "Unknown Index name" not in str(e): print(e) def recreate(self, dataset: Dataset, collection_params): @@ -90,7 +90,7 @@ def recreate(self, dataset: Dataset, collection_params): ] + payload_fields conns = [self.client] - if self._is_cluster: + if self.is_cluster: conns = [ self.client.get_redis_connection(node) for node in self.client.get_primaries() @@ -100,7 +100,7 @@ def recreate(self, dataset: Dataset, collection_params): try: search_namespace.create_index(fields=index_fields) except redis.ResponseError as e: - if "Index already exists" not in e.__str__(): + if "Index already exists" not in str(e): raise e diff --git a/engine/clients/redis/search.py b/engine/clients/redis/search.py index feccfece..604042ae 100644 --- a/engine/clients/redis/search.py +++ b/engine/clients/redis/search.py @@ -1,9 +1,10 @@ import random -from typing import List, Tuple +from typing import List, Tuple, Optional, Union import numpy as np from redis import Redis, RedisCluster from redis.commands.search.query import Query as RedisQuery +from redis.commands.search import Search as RedisSearchIndex from dataset_reader.base_reader import Query as DatasetQuery from engine.base_client.search import BaseSearcher @@ -19,9 +20,14 @@ class RedisSearcher(BaseSearcher): search_params = {} - client = None + client: Union[RedisCluster, Redis] = None parser = RedisConditionParser() + knn_conditions: str + is_cluster: bool + conns: List[Union[RedisCluster, Redis]] + search_namespace: RedisSearchIndex + @classmethod def init_client(cls, host, distance, connection_params: dict, search_params: dict): redis_constructor = RedisCluster if REDIS_CLUSTER else Redis @@ -29,17 +35,20 @@ def init_client(cls, host, distance, connection_params: dict, search_params: dic host=host, port=REDIS_PORT, password=REDIS_AUTH, username=REDIS_USER ) cls.search_params = search_params - cls.knn_conditions = "EF_RUNTIME $EF" - cls._is_cluster = True if REDIS_CLUSTER else False + # In the case of CLUSTER API enabled we randomly select the starting primary shard # when doing the client initialization to evenly distribute the load among the cluster - cls.conns = [cls.client] - if cls._is_cluster: + if REDIS_CLUSTER: cls.conns = [ cls.client.get_redis_connection(node) for node in cls.client.get_primaries() ] - cls._ft = cls.conns[random.randint(0, len(cls.conns)) - 1].ft() + else: + cls.conns = [cls.client] + + cls.is_cluster = REDIS_CLUSTER + cls.search_namespace = random.choice(cls.conns).ft() + cls.knn_conditions = "EF_RUNTIME $EF" @classmethod def search_one(cls, query: DatasetQuery, top: int) -> List[Tuple[int, float]]: @@ -68,6 +77,6 @@ def search_one(cls, query: DatasetQuery, top: int) -> List[Tuple[int, float]]: "EF": cls.search_params["search_params"]["ef"], **params, } - results = cls._ft.search(q, query_params=params_dict) + results = cls.search_namespace.search(q, query_params=params_dict) return [(int(result.id), float(result.vector_score)) for result in results.docs] From 47d944a9ac6df1f5f89c93fc41a40d3812ccaec1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 16 Apr 2024 15:34:57 +0000 Subject: [PATCH 5/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- engine/clients/redis/search.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/engine/clients/redis/search.py b/engine/clients/redis/search.py index 604042ae..42ad1202 100644 --- a/engine/clients/redis/search.py +++ b/engine/clients/redis/search.py @@ -1,10 +1,10 @@ import random -from typing import List, Tuple, Optional, Union +from typing import List, Optional, Tuple, Union import numpy as np from redis import Redis, RedisCluster -from redis.commands.search.query import Query as RedisQuery from redis.commands.search import Search as RedisSearchIndex +from redis.commands.search.query import Query as RedisQuery from dataset_reader.base_reader import Query as DatasetQuery from engine.base_client.search import BaseSearcher From 3e1b33f27e97a1758b3fab76d3cb968c77bffb4e Mon Sep 17 00:00:00 2001 From: KShivendu Date: Tue, 16 Apr 2024 21:43:21 +0530 Subject: [PATCH 6/8] fix: Type issues detected by linter --- engine/clients/elasticsearch/upload.py | 2 +- engine/clients/milvus/upload.py | 3 +-- engine/clients/opensearch/search.py | 8 ++++---- engine/clients/pgvector/configure.py | 4 +++- engine/clients/redis/search.py | 2 +- engine/clients/weaviate/search.py | 6 +++--- 6 files changed, 13 insertions(+), 12 deletions(-) diff --git a/engine/clients/elasticsearch/upload.py b/engine/clients/elasticsearch/upload.py index 71c1267c..c577263a 100644 --- a/engine/clients/elasticsearch/upload.py +++ b/engine/clients/elasticsearch/upload.py @@ -48,7 +48,7 @@ def init_client(cls, host, distance, connection_params, upload_params): def upload_batch(cls, batch: List[Record]): operations = [] for record in batch: - vector_id = uuid.UUID(int=record.idx).hex + vector_id = uuid.UUID(int=record.id).hex operations.append({"index": {"_id": vector_id}}) operations.append({"vector": record.vector, **(record.metadata or {})}) diff --git a/engine/clients/milvus/upload.py b/engine/clients/milvus/upload.py index eece1d3c..8c3768e1 100644 --- a/engine/clients/milvus/upload.py +++ b/engine/clients/milvus/upload.py @@ -44,7 +44,6 @@ def init_client(cls, host, distance, connection_params, upload_params): @classmethod def upload_batch(cls, batch: List[Record]): has_metadata = any(record.metadata for record in batch) - field_values = [] if has_metadata: field_values = [ [ @@ -60,7 +59,7 @@ def upload_batch(cls, batch: List[Record]): ids, vectors = [], [] for record in batch: - ids.append(record.idx) + ids.append(record.id) vectors.append(record.vector) cls.collection.insert([ids, vectors] + field_values) diff --git a/engine/clients/opensearch/search.py b/engine/clients/opensearch/search.py index 30b882df..6fb1a2ea 100644 --- a/engine/clients/opensearch/search.py +++ b/engine/clients/opensearch/search.py @@ -48,7 +48,7 @@ def init_client(cls, host, distance, connection_params: dict, search_params: dic @classmethod def search_one(cls, query: Query, top: int) -> List[Tuple[int, float]]: - query = { + opensearch_query = { "knn": { "vector": { "vector": query.vector, @@ -59,9 +59,9 @@ def search_one(cls, query: Query, top: int) -> List[Tuple[int, float]]: meta_conditions = cls.parser.parse(query.meta_conditions) if meta_conditions: - query = { + opensearch_query = { "bool": { - "must": [query], + "must": [opensearch_query], "filter": meta_conditions, } } @@ -69,7 +69,7 @@ def search_one(cls, query: Query, top: int) -> List[Tuple[int, float]]: res = cls.client.search( index=OPENSEARCH_INDEX, body={ - "query": query, + "query": opensearch_query, "size": top, }, params={ diff --git a/engine/clients/pgvector/configure.py b/engine/clients/pgvector/configure.py index d5587431..486c8273 100644 --- a/engine/clients/pgvector/configure.py +++ b/engine/clients/pgvector/configure.py @@ -46,7 +46,9 @@ def recreate(self, dataset: Dataset, collection_params): ) self.conn.execute( - f"CREATE INDEX on items USING hnsw(embedding {hnsw_distance_type}) WITH (m = {collection_params['hnsw_config']['m']}, ef_construction = {collection_params['hnsw_config']['ef_construct']})" + f"CREATE INDEX on items USING hnsw(embedding {hnsw_distance_type}) WITH " + f"(m = {collection_params['hnsw_config']['m']}, " + f"ef_construction = {collection_params['hnsw_config']['ef_construct']})" ) self.conn.close() diff --git a/engine/clients/redis/search.py b/engine/clients/redis/search.py index 42ad1202..f73a4abd 100644 --- a/engine/clients/redis/search.py +++ b/engine/clients/redis/search.py @@ -1,5 +1,5 @@ import random -from typing import List, Optional, Tuple, Union +from typing import List, Tuple, Union import numpy as np from redis import Redis, RedisCluster diff --git a/engine/clients/weaviate/search.py b/engine/clients/weaviate/search.py index def87b38..aeea43db 100644 --- a/engine/clients/weaviate/search.py +++ b/engine/clients/weaviate/search.py @@ -32,10 +32,10 @@ def init_client(cls, host, distance, connection_params: dict, search_params: dic cls.client = client @classmethod - def search_one(self, query: Query, top: int) -> List[Tuple[int, float]]: - res = self.collection.query.near_vector( + def search_one(cls, query: Query, top: int) -> List[Tuple[int, float]]: + res = cls.collection.query.near_vector( near_vector=query.vector, - filters=self.parser.parse(query.meta_conditions), + filters=cls.parser.parse(query.meta_conditions), limit=top, return_metadata=MetadataQuery(distance=True), return_properties=[], From 8869a7af4e4b40276ace33ef15f8960951a537dd Mon Sep 17 00:00:00 2001 From: KShivendu Date: Tue, 16 Apr 2024 22:08:48 +0530 Subject: [PATCH 7/8] fix: iter_batches func type --- engine/base_client/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/engine/base_client/utils.py b/engine/base_client/utils.py index 899ae9ee..1b0da967 100644 --- a/engine/base_client/utils.py +++ b/engine/base_client/utils.py @@ -1,9 +1,9 @@ -from typing import Iterable +from typing import Iterable, List from dataset_reader.base_reader import Record -def iter_batches(records: Iterable[Record], n: int) -> Iterable[Record]: +def iter_batches(records: Iterable[Record], n: int) -> Iterable[List[Record]]: batch = [] for record in records: From 604d34714f0df2e924ff8e9210b676507e51d5fb Mon Sep 17 00:00:00 2001 From: KShivendu Date: Tue, 16 Apr 2024 22:23:41 +0530 Subject: [PATCH 8/8] refactor: knn_conditions should be class level constant --- engine/clients/redis/search.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/engine/clients/redis/search.py b/engine/clients/redis/search.py index f73a4abd..b7582bc1 100644 --- a/engine/clients/redis/search.py +++ b/engine/clients/redis/search.py @@ -22,8 +22,8 @@ class RedisSearcher(BaseSearcher): search_params = {} client: Union[RedisCluster, Redis] = None parser = RedisConditionParser() + knn_conditions = "EF_RUNTIME $EF" - knn_conditions: str is_cluster: bool conns: List[Union[RedisCluster, Redis]] search_namespace: RedisSearchIndex @@ -48,7 +48,6 @@ def init_client(cls, host, distance, connection_params: dict, search_params: dic cls.is_cluster = REDIS_CLUSTER cls.search_namespace = random.choice(cls.conns).ft() - cls.knn_conditions = "EF_RUNTIME $EF" @classmethod def search_one(cls, query: DatasetQuery, top: int) -> List[Tuple[int, float]]: