Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 6 additions & 57 deletions pinecone/grpc/base.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,14 @@
import logging
from abc import ABC, abstractmethod
from functools import wraps
from typing import Dict, Optional
from typing import Optional

import grpc
from grpc._channel import _InactiveRpcError, Channel
from grpc._channel import Channel

from .retry import RetryConfig
from .channel_factory import GrpcChannelFactory

from pinecone import Config
from .utils import _generate_request_id
from .config import GRPCClientConfig
from pinecone.utils.constants import REQUEST_ID, CLIENT_VERSION
from pinecone.exceptions.exceptions import PineconeException

_logger = logging.getLogger(__name__)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This wasn't being used, right?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently unused. `_logger` is now used in `GrpcChannelFactory`, which holds logic I recently pulled out of this base class. Removing the reference here should have happened in that PR, but it got overlooked.

from .grpc_runner import GrpcRunner


class GRPCIndexBase(ABC):
Expand All @@ -35,18 +28,12 @@ def __init__(
):
self.config = config
self.grpc_client_config = grpc_config or GRPCClientConfig()
self.retry_config = self.grpc_client_config.retry_config or RetryConfig()

self.fixed_metadata = {
"api-key": config.api_key,
"service-name": index_name,
"client-version": CLIENT_VERSION,
}
if self.grpc_client_config.additional_metadata:
self.fixed_metadata.update(self.grpc_client_config.additional_metadata)

self._endpoint_override = _endpoint_override

self.runner = GrpcRunner(
index_name=index_name, config=config, grpc_config=self.grpc_client_config
)
self.channel_factory = GrpcChannelFactory(
config=self.config, grpc_client_config=self.grpc_client_config, use_asyncio=False
)
Expand Down Expand Up @@ -91,44 +78,6 @@ def close(self):
except TypeError:
pass

def _wrap_grpc_call(
self,
func,
request,
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None,
compression=None,
):
@wraps(func)
def wrapped():
user_provided_metadata = metadata or {}
_metadata = tuple(
(k, v)
for k, v in {
**self.fixed_metadata,
**self._request_metadata(),
**user_provided_metadata,
}.items()
)
try:
return func(
request,
timeout=timeout,
metadata=_metadata,
credentials=credentials,
wait_for_ready=wait_for_ready,
compression=compression,
)
except _InactiveRpcError as e:
raise PineconeException(e._state.debug_error_string) from e

return wrapped()

def _request_metadata(self) -> Dict[str, str]:
return {REQUEST_ID: _generate_request_id()}

def __enter__(self):
return self

Expand Down
97 changes: 97 additions & 0 deletions pinecone/grpc/grpc_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from functools import wraps
from typing import Dict, Tuple, Optional

from grpc._channel import _InactiveRpcError

from pinecone import Config
from .utils import _generate_request_id
from .config import GRPCClientConfig
from pinecone.utils.constants import REQUEST_ID, CLIENT_VERSION
from pinecone.exceptions.exceptions import PineconeException
from grpc import CallCredentials, Compression
from google.protobuf.message import Message


class GrpcRunner:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea pulling all of this out into somewhere contained. 👍

def __init__(self, index_name: str, config: Config, grpc_config: GRPCClientConfig):
    """Keep the configuration objects and assemble the fixed call metadata.

    Args:
        index_name: Sent as the ``service-name`` metadata entry.
        config: Client configuration; its ``api_key`` is sent as ``api-key``.
        grpc_config: gRPC client configuration. When its
            ``additional_metadata`` is truthy, those entries are merged on
            top of the default ones (and may override them).
    """
    self.config = config
    self.grpc_client_config = grpc_config

    # Default metadata attached to every request issued through this runner.
    base_metadata = {
        "api-key": config.api_key,
        "service-name": index_name,
        "client-version": CLIENT_VERSION,
    }
    extra_metadata = grpc_config.additional_metadata
    if extra_metadata:
        base_metadata.update(extra_metadata)
    self.fixed_metadata = base_metadata

def run(
    self,
    func,
    request: Message,
    timeout: Optional[int] = None,
    metadata: Optional[Dict[str, str]] = None,
    credentials: Optional[CallCredentials] = None,
    wait_for_ready: Optional[bool] = None,
    compression: Optional[Compression] = None,
):
    """Invoke ``func`` with ``request`` and the prepared call metadata.

    Args:
        func: Callable invoked as ``func(request, timeout=..., metadata=...,
            credentials=..., wait_for_ready=..., compression=...)``.
        request: Message passed to ``func`` as its only positional argument.
        timeout: Forwarded unchanged to ``func``.
        metadata: Optional per-call entries; ``None`` is treated as empty.
            They are combined with the runner's own metadata by
            ``_prepare_metadata``.
        credentials: Forwarded unchanged to ``func``.
        wait_for_ready: Forwarded unchanged to ``func``.
        compression: Forwarded unchanged to ``func``.

    Returns:
        Whatever ``func`` returns.

    Raises:
        PineconeException: If ``func`` raises ``_InactiveRpcError``; the
            message is the error state's ``debug_error_string`` and the
            original error is chained as the cause.
    """
    call_metadata = self._prepare_metadata(metadata or {})
    try:
        return func(
            request,
            timeout=timeout,
            metadata=call_metadata,
            credentials=credentials,
            wait_for_ready=wait_for_ready,
            compression=compression,
        )
    except _InactiveRpcError as rpc_error:
        raise PineconeException(rpc_error._state.debug_error_string) from rpc_error

async def run_asyncio(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`run_asyncio` is not used or tested anywhere yet, but it will be. It is the same as `run`, with the addition of the async/await bits.

self,
func,
request: Message,
timeout: Optional[int] = None,
metadata: Optional[Dict[str, str]] = None,
credentials: Optional[CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[Compression] = None,
):
@wraps(func)
async def wrapped():
user_provided_metadata = metadata or {}
_metadata = self._prepare_metadata(user_provided_metadata)
try:
return await func(
request,
timeout=timeout,
metadata=_metadata,
credentials=credentials,
wait_for_ready=wait_for_ready,
compression=compression,
)
except _InactiveRpcError as e:
raise PineconeException(e._state.debug_error_string) from e

return await wrapped()

def _prepare_metadata(
    self, user_provided_metadata: Dict[str, str]
) -> Tuple[Tuple[str, str], ...]:
    """Combine all metadata sources into a tuple of ``(key, value)`` pairs.

    Sources are applied in order, so a later one overrides an earlier one
    on a key collision: the runner's fixed metadata, then the per-request
    metadata, then ``user_provided_metadata``.
    """
    combined = dict(self.fixed_metadata)
    combined.update(self._request_metadata())
    combined.update(user_provided_metadata)
    return tuple(combined.items())

def _request_metadata(self) -> Dict[str, str]:
    """Return the per-request metadata: a newly generated request id."""
    request_id = _generate_request_id()
    return {REQUEST_ID: request_id}
26 changes: 11 additions & 15 deletions pinecone/grpc/index_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def upsert(
if async_req:
args_dict = self._parse_non_empty_args([("namespace", namespace)])
request = UpsertRequest(vectors=vectors, **args_dict, **kwargs)
future = self._wrap_grpc_call(self.stub.Upsert.future, request, timeout=timeout)
future = self.runner.run(self.stub.Upsert.future, request, timeout=timeout)
return PineconeGrpcFuture(future)

if batch_size is None:
Expand All @@ -155,15 +155,11 @@ def upsert(
return UpsertResponse(upserted_count=total_upserted)

def _upsert_batch(
self,
vectors: List[GRPCVector],
namespace: Optional[str],
timeout: Optional[float],
**kwargs,
self, vectors: List[GRPCVector], namespace: Optional[str], timeout: Optional[int], **kwargs
) -> UpsertResponse:
args_dict = self._parse_non_empty_args([("namespace", namespace)])
request = UpsertRequest(vectors=vectors, **args_dict)
return self._wrap_grpc_call(self.stub.Upsert, request, timeout=timeout, **kwargs)
return self.runner.run(self.stub.Upsert, request, timeout=timeout, **kwargs)

def upsert_from_dataframe(
self,
Expand Down Expand Up @@ -280,10 +276,10 @@ def delete(

request = DeleteRequest(**args_dict, **kwargs)
if async_req:
future = self._wrap_grpc_call(self.stub.Delete.future, request, timeout=timeout)
future = self.runner.run(self.stub.Delete.future, request, timeout=timeout)
return PineconeGrpcFuture(future)
else:
return self._wrap_grpc_call(self.stub.Delete, request, timeout=timeout)
return self.runner.run(self.stub.Delete, request, timeout=timeout)

def fetch(
self, ids: Optional[List[str]], namespace: Optional[str] = None, **kwargs
Expand All @@ -308,7 +304,7 @@ def fetch(
args_dict = self._parse_non_empty_args([("namespace", namespace)])

request = FetchRequest(ids=ids, **args_dict, **kwargs)
response = self._wrap_grpc_call(self.stub.Fetch, request, timeout=timeout)
response = self.runner.run(self.stub.Fetch, request, timeout=timeout)
json_response = json_format.MessageToDict(response)
return parse_fetch_response(json_response)

Expand Down Expand Up @@ -388,7 +384,7 @@ def query(
request = QueryRequest(**args_dict)

timeout = kwargs.pop("timeout", None)
response = self._wrap_grpc_call(self.stub.Query, request, timeout=timeout)
response = self.runner.run(self.stub.Query, request, timeout=timeout)
json_response = json_format.MessageToDict(response)
return parse_query_response(json_response, _check_type=False)

Expand Down Expand Up @@ -451,10 +447,10 @@ def update(

request = UpdateRequest(id=id, **args_dict)
if async_req:
future = self._wrap_grpc_call(self.stub.Update.future, request, timeout=timeout)
future = self.runner.run(self.stub.Update.future, request, timeout=timeout)
return PineconeGrpcFuture(future)
else:
return self._wrap_grpc_call(self.stub.Update, request, timeout=timeout)
return self.runner.run(self.stub.Update, request, timeout=timeout)

def list_paginated(
self,
Expand Down Expand Up @@ -499,7 +495,7 @@ def list_paginated(
)
request = ListRequest(**args_dict, **kwargs)
timeout = kwargs.pop("timeout", None)
response = self._wrap_grpc_call(self.stub.List, request, timeout=timeout)
response = self.runner.run(self.stub.List, request, timeout=timeout)

if response.pagination and response.pagination.next != "":
pagination = Pagination(next=response.pagination.next)
Expand Down Expand Up @@ -572,7 +568,7 @@ def describe_index_stats(
timeout = kwargs.pop("timeout", None)

request = DescribeIndexStatsRequest(**args_dict)
response = self._wrap_grpc_call(self.stub.DescribeIndexStats, request, timeout=timeout)
response = self.runner.run(self.stub.DescribeIndexStats, request, timeout=timeout)
json_response = json_format.MessageToDict(response)
return parse_stats_response(json_response)

Expand Down
8 changes: 4 additions & 4 deletions tests/unit_grpc/test_grpc_index_describe_index_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@ def setup_method(self):
)

def test_describeIndexStats_callWithoutFilter_CalledWithoutFilter(self, mocker):
mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True)
mocker.patch.object(self.index.runner, "run", autospec=True)
self.index.describe_index_stats()
self.index._wrap_grpc_call.assert_called_once_with(
self.index.runner.run.assert_called_once_with(
self.index.stub.DescribeIndexStats, DescribeIndexStatsRequest(), timeout=None
)

def test_describeIndexStats_callWithFilter_CalledWithFilter(self, mocker, filter1):
mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True)
mocker.patch.object(self.index.runner, "run", autospec=True)
self.index.describe_index_stats(filter=filter1)
self.index._wrap_grpc_call.assert_called_once_with(
self.index.runner.run.assert_called_once_with(
self.index.stub.DescribeIndexStats,
DescribeIndexStatsRequest(filter=dict_to_proto_struct(filter1)),
timeout=None,
Expand Down
8 changes: 4 additions & 4 deletions tests/unit_grpc/test_grpc_index_fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ def setup_method(self):
)

def test_fetch_byIds_fetchByIds(self, mocker):
mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True)
mocker.patch.object(self.index.runner, "run", autospec=True)
self.index.fetch(["vec1", "vec2"])
self.index._wrap_grpc_call.assert_called_once_with(
self.index.runner.run.assert_called_once_with(
self.index.stub.Fetch, FetchRequest(ids=["vec1", "vec2"]), timeout=None
)

def test_fetch_byIdsAndNS_fetchByIdsAndNS(self, mocker):
mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True)
mocker.patch.object(self.index.runner, "run", autospec=True)
self.index.fetch(["vec1", "vec2"], namespace="ns", timeout=30)
self.index._wrap_grpc_call.assert_called_once_with(
self.index.runner.run.assert_called_once_with(
self.index.stub.Fetch, FetchRequest(ids=["vec1", "vec2"], namespace="ns"), timeout=30
)
19 changes: 0 additions & 19 deletions tests/unit_grpc/test_grpc_index_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,6 @@ def test_init_with_default_config(self):
assert index.grpc_client_config.grpc_channel_options is None
assert index.grpc_client_config.additional_metadata is None

# Default metadata, grpc equivalent to http request headers
assert len(index.fixed_metadata) == 3
assert index.fixed_metadata["api-key"] == "YOUR_API_KEY"
assert index.fixed_metadata["service-name"] == "my-index"
assert index.fixed_metadata["client-version"] is not None

def test_init_with_additional_metadata(self):
pc = PineconeGRPC(api_key="YOUR_API_KEY")
config = GRPCClientConfig(
additional_metadata={"debug-header": "value123", "debug-header2": "value456"}
)
index = pc.Index(name="my-index", host="host", grpc_config=config)
assert len(index.fixed_metadata) == 5
assert index.fixed_metadata["api-key"] == "YOUR_API_KEY"
assert index.fixed_metadata["service-name"] == "my-index"
assert index.fixed_metadata["client-version"] is not None
assert index.fixed_metadata["debug-header"] == "value123"
assert index.fixed_metadata["debug-header2"] == "value456"

def test_init_with_grpc_config_from_dict(self):
pc = PineconeGRPC(api_key="YOUR_API_KEY")
config = GRPCClientConfig._from_dict({"timeout": 10})
Expand Down
12 changes: 6 additions & 6 deletions tests/unit_grpc/test_grpc_index_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@ def setup_method(self):
)

def test_query_byVectorNoFilter_queryVectorNoFilter(self, mocker, vals1):
mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True)
mocker.patch.object(self.index.runner, "run", autospec=True)
self.index.query(top_k=10, vector=vals1)
self.index._wrap_grpc_call.assert_called_once_with(
self.index.runner.run.assert_called_once_with(
self.index.stub.Query, QueryRequest(top_k=10, vector=vals1), timeout=None
)

def test_query_byVectorWithFilter_queryVectorWithFilter(self, mocker, vals1, filter1):
mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True)
mocker.patch.object(self.index.runner, "run", autospec=True)
self.index.query(top_k=10, vector=vals1, filter=filter1, namespace="ns", timeout=10)
self.index._wrap_grpc_call.assert_called_once_with(
self.index.runner.run.assert_called_once_with(
self.index.stub.Query,
QueryRequest(
top_k=10, vector=vals1, filter=dict_to_proto_struct(filter1), namespace="ns"
Expand All @@ -32,9 +32,9 @@ def test_query_byVectorWithFilter_queryVectorWithFilter(self, mocker, vals1, fil
)

def test_query_byVecId_queryByVecId(self, mocker):
mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True)
mocker.patch.object(self.index.runner, "run", autospec=True)
self.index.query(top_k=10, id="vec1", include_metadata=True, include_values=False)
self.index._wrap_grpc_call.assert_called_once_with(
self.index.runner.run.assert_called_once_with(
self.index.stub.Query,
QueryRequest(top_k=10, id="vec1", include_metadata=True, include_values=False),
timeout=None,
Expand Down
Loading
Loading