From 19c5b2bee5916fd8f889f91e22ef78e7eec2b3ec Mon Sep 17 00:00:00 2001 From: Andrew Smith Date: Sat, 24 Feb 2024 00:38:58 +0000 Subject: [PATCH 1/3] Add RPC request builder class for additional filters --- postgrest/_async/request_builder.py | 3 ++- postgrest/_sync/request_builder.py | 3 ++- postgrest/base_request_builder.py | 41 +++++++++++++++++++++++++++++ 3 files changed, 45 insertions(+), 2 deletions(-) diff --git a/postgrest/_async/request_builder.py b/postgrest/_async/request_builder.py index e8c2020f..a3d1806b 100644 --- a/postgrest/_async/request_builder.py +++ b/postgrest/_async/request_builder.py @@ -9,6 +9,7 @@ from ..base_request_builder import ( APIResponse, BaseFilterRequestBuilder, + BaseRPCRequestBuilder, BaseSelectRequestBuilder, CountMethod, SingleAPIResponse, @@ -164,7 +165,7 @@ def __init__( # this exists for type-safety. see https://gist.github.com/anand2312/93d3abf401335fd3310d9e30112303bf class AsyncRPCFilterRequestBuilder( - BaseFilterRequestBuilder[_ReturnT], AsyncSingleRequestBuilder[_ReturnT] + BaseRPCRequestBuilder[_ReturnT], AsyncSingleRequestBuilder[_ReturnT] ): def __init__( self, diff --git a/postgrest/_sync/request_builder.py b/postgrest/_sync/request_builder.py index aaaeed2f..82848f8b 100644 --- a/postgrest/_sync/request_builder.py +++ b/postgrest/_sync/request_builder.py @@ -9,6 +9,7 @@ from ..base_request_builder import ( APIResponse, BaseFilterRequestBuilder, + BaseRPCRequestBuilder, BaseSelectRequestBuilder, CountMethod, SingleAPIResponse, @@ -164,7 +165,7 @@ def __init__( # this exists for type-safety. see https://gist.github.com/anand2312/93d3abf401335fd3310d9e30112303bf class SyncRPCFilterRequestBuilder( - BaseFilterRequestBuilder[_ReturnT], SyncSingleRequestBuilder[_ReturnT] + BaseRPCRequestBuilder[_ReturnT], SyncSingleRequestBuilder[_ReturnT] ): def __init__( self, diff --git a/postgrest/base_request_builder.py b/postgrest/base_request_builder.py index 12549d0a..a0d910ff 100644 --- a/postgrest/base_request_builder.py +++ b/postgrest/base_request_builder.py @@ -569,3 +569,44 @@ def range(self: Self, start: int, end: int) -> Self: self.headers["Range-Unit"] = "items" self.headers["Range"] = f"{start}-{end - 1}" return self + + +class BaseRPCRequestBuilder(BaseSelectRequestBuilder[_ReturnT]): + def __init__( + self, + session: Union[AsyncClient, SyncClient], + headers: Headers, + params: QueryParams, + ) -> None: + # Generic[T] is an instance of typing._GenericAlias, so doing Generic[T].__init__ + # tries to call _GenericAlias.__init__ - which is the wrong method + # The __origin__ attribute of the _GenericAlias is the actual class + get_origin_and_cast(BaseSelectRequestBuilder[_ReturnT]).__init__( + self, session, headers, params + ) + + def select( + self, + *columns: str, + ) -> Self: + """Run a SELECT query. + + Args: + *columns: The names of the columns to fetch. + Returns: + :class:`BaseSelectRequestBuilder` + """ + method, params, headers, json = pre_select(*columns, count=None) + self.params = self.params.add("select", params.get("select")) + self.headers["Prefer"] = "return=representation" + return self + + def maybe_single(self) -> Self: + """Retrieves at most one row from the result. Result must be at most one row (e.g. using `eq` on a UNIQUE column), otherwise this will result in an error.""" + self.headers["Accept"] = "application/vnd.pgrst.object+json" + return self + + def csv(self) -> Self: + """Specify that the query must retrieve data as a single CSV string.""" + self.headers["Accept"] = "text/csv" + return self From 61a856bc4e918b9c3012c818208d8b24c992f23d Mon Sep 17 00:00:00 2001 From: Andrew Smith Date: Tue, 27 Feb 2024 01:17:12 +0000 Subject: [PATCH 2/3] Add single to the request builder --- postgrest/base_request_builder.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/postgrest/base_request_builder.py b/postgrest/base_request_builder.py index a0d910ff..09b9f41b 100644 --- a/postgrest/base_request_builder.py +++ b/postgrest/base_request_builder.py @@ -601,6 +601,15 @@ def select( self.headers["Prefer"] = "return=representation" return self + def single(self) -> Self: + """Specify that the query will only return a single row in response. + + .. caution:: + The API will raise an error if the query returned more than one row. + """ + self.headers["Accept"] = "application/vnd.pgrst.object+json" + return self + def maybe_single(self) -> Self: """Retrieves at most one row from the result. Result must be at most one row (e.g. using `eq` on a UNIQUE column), otherwise this will result in an error.""" self.headers["Accept"] = "application/vnd.pgrst.object+json" From ac054221f131383e18262ccf3753d2681423851c Mon Sep 17 00:00:00 2001 From: Andrew Smith Date: Tue, 27 Feb 2024 01:18:06 +0000 Subject: [PATCH 3/3] Add integration tests for the rpc request builder --- infra/init.sql | 7 ++++ ...test_filter_request_builder_integration.py | 41 +++++++++++++++++++ ...test_filter_request_builder_integration.py | 41 +++++++++++++++++++ 3 files changed, 89 insertions(+) diff --git a/infra/init.sql b/infra/init.sql index 3aad2eb2..bc7d733f 100644 --- a/infra/init.sql +++ b/infra/init.sql @@ -69,3 +69,10 @@ insert into public.issues (id, title, tags) values (2, 'Use better names', array['is:open', 'severity:low', 'priority:medium']), (3, 'Add missing postgrest filters', array['is:open', 'severity:low', 'priority:high']), (4, 'Add alias to filters', array['is:closed', 'severity:low', 'priority:medium']); + +create or replace function public.list_stored_countries() + returns setof countries + language sql +as $function$ + select * from countries; +$function$ diff --git a/tests/_async/test_filter_request_builder_integration.py b/tests/_async/test_filter_request_builder_integration.py index a9790471..763e0e39 100644 --- a/tests/_async/test_filter_request_builder_integration.py +++ b/tests/_async/test_filter_request_builder_integration.py @@ -1,3 +1,5 @@ +import pytest + from .client import rest_client @@ -387,3 +389,42 @@ async def test_or_on_reference_table(): ], }, ] + + +async def test_rpc_with_single(): + res = ( + await rest_client() + .rpc("list_stored_countries", {}) + .select("nicename, country_name, iso") + .eq("nicename", "Albania") + .single() + .execute() + ) + + assert res.data == {"nicename": "Albania", "country_name": "ALBANIA", "iso": "AL"} + + +async def test_rpc_with_limit(): + res = ( + await rest_client() + .rpc("list_stored_countries", {}) + .select("nicename, country_name, iso") + .eq("nicename", "Albania") + .limit(1) + .execute() + ) + + assert res.data == [{"nicename": "Albania", "country_name": "ALBANIA", "iso": "AL"}] + + +@pytest.mark.skip(reason="Need to re-implement range to use query parameters") +async def test_rpc_with_range(): + res = ( + await rest_client() + .rpc("list_stored_countries", {}) + .select("nicename, iso") + .range(0, 1) + .execute() + ) + + assert res.data == [{"nicename": "Albania", "iso": "AL"}] diff --git a/tests/_sync/test_filter_request_builder_integration.py b/tests/_sync/test_filter_request_builder_integration.py index 8798744f..19fc10bb 100644 --- a/tests/_sync/test_filter_request_builder_integration.py +++ b/tests/_sync/test_filter_request_builder_integration.py @@ -1,3 +1,5 @@ +import pytest + from .client import rest_client @@ -380,3 +382,42 @@ def test_or_on_reference_table(): ], }, ] + + +def test_rpc_with_single(): + res = ( + rest_client() + .rpc("list_stored_countries", {}) + .select("nicename, country_name, iso") + .eq("nicename", "Albania") + .single() + .execute() + ) + + assert res.data == {"nicename": "Albania", "country_name": "ALBANIA", "iso": "AL"} + + +def test_rpc_with_limit(): + res = ( + rest_client() + .rpc("list_stored_countries", {}) + .select("nicename, country_name, iso") + .eq("nicename", "Albania") + .limit(1) + .execute() + ) + + assert res.data == [{"nicename": "Albania", "country_name": "ALBANIA", "iso": "AL"}] + + +@pytest.mark.skip(reason="Need to re-implement range to use query parameters") +def test_rpc_with_range(): + res = ( + rest_client() + .rpc("list_stored_countries", {}) + .select("nicename, iso") + .range(0, 1) + .execute() + ) + + assert res.data == [{"nicename": "Albania", "iso": "AL"}]