Skip to content
This repository was archived by the owner on Sep 8, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions infra/init.sql
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,10 @@ insert into public.issues (id, title, tags) values
(2, 'Use better names', array['is:open', 'severity:low', 'priority:medium']),
(3, 'Add missing postgrest filters', array['is:open', 'severity:low', 'priority:high']),
(4, 'Add alias to filters', array['is:closed', 'severity:low', 'priority:medium']);

create or replace function public.list_stored_countries()
returns setof countries
language sql
as $function$
select * from countries;
$function$
3 changes: 2 additions & 1 deletion postgrest/_async/request_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ..base_request_builder import (
APIResponse,
BaseFilterRequestBuilder,
BaseRPCRequestBuilder,
BaseSelectRequestBuilder,
CountMethod,
SingleAPIResponse,
Expand Down Expand Up @@ -164,7 +165,7 @@ def __init__(

# this exists for type-safety. see https://gist.github.com/anand2312/93d3abf401335fd3310d9e30112303bf
class AsyncRPCFilterRequestBuilder(
BaseFilterRequestBuilder[_ReturnT], AsyncSingleRequestBuilder[_ReturnT]
BaseRPCRequestBuilder[_ReturnT], AsyncSingleRequestBuilder[_ReturnT]
):
def __init__(
self,
Expand Down
3 changes: 2 additions & 1 deletion postgrest/_sync/request_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ..base_request_builder import (
APIResponse,
BaseFilterRequestBuilder,
BaseRPCRequestBuilder,
BaseSelectRequestBuilder,
CountMethod,
SingleAPIResponse,
Expand Down Expand Up @@ -164,7 +165,7 @@ def __init__(

# this exists for type-safety. see https://gist.github.com/anand2312/93d3abf401335fd3310d9e30112303bf
class SyncRPCFilterRequestBuilder(
BaseFilterRequestBuilder[_ReturnT], SyncSingleRequestBuilder[_ReturnT]
BaseRPCRequestBuilder[_ReturnT], SyncSingleRequestBuilder[_ReturnT]
):
def __init__(
self,
Expand Down
50 changes: 50 additions & 0 deletions postgrest/base_request_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,3 +569,53 @@ def range(self: Self, start: int, end: int) -> Self:
self.headers["Range-Unit"] = "items"
self.headers["Range"] = f"{start}-{end - 1}"
return self


class BaseRPCRequestBuilder(BaseSelectRequestBuilder[_ReturnT]):
def __init__(
self,
session: Union[AsyncClient, SyncClient],
headers: Headers,
params: QueryParams,
) -> None:
# Generic[T] is an instance of typing._GenericAlias, so doing Generic[T].__init__
# tries to call _GenericAlias.__init__ - which is the wrong method
# The __origin__ attribute of the _GenericAlias is the actual class
get_origin_and_cast(BaseSelectRequestBuilder[_ReturnT]).__init__(
self, session, headers, params
)

def select(
self,
*columns: str,
) -> Self:
"""Run a SELECT query.

Args:
*columns: The names of the columns to fetch.
Returns:
:class:`BaseSelectRequestBuilder`
"""
method, params, headers, json = pre_select(*columns, count=None)
self.params = self.params.add("select", params.get("select"))
self.headers["Prefer"] = "return=representation"
return self

def single(self) -> Self:
"""Specify that the query will only return a single row in response.

.. caution::
The API will raise an error if the query returned more than one row.
"""
self.headers["Accept"] = "application/vnd.pgrst.object+json"
return self

def maybe_single(self) -> Self:
"""Retrieves at most one row from the result. Result must be at most one row (e.g. using `eq` on a UNIQUE column), otherwise this will result in an error."""
self.headers["Accept"] = "application/vnd.pgrst.object+json"
return self

def csv(self) -> Self:
"""Specify that the query must retrieve data as a single CSV string."""
self.headers["Accept"] = "text/csv"
return self
41 changes: 41 additions & 0 deletions tests/_async/test_filter_request_builder_integration.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pytest

from .client import rest_client


Expand Down Expand Up @@ -387,3 +389,42 @@ async def test_or_on_reference_table():
],
},
]


async def test_rpc_with_single():
res = (
await rest_client()
.rpc("list_stored_countries", {})
.select("nicename, country_name, iso")
.eq("nicename", "Albania")
.single()
.execute()
)

assert res.data == {"nicename": "Albania", "country_name": "ALBANIA", "iso": "AL"}


async def test_rpc_with_limit():
res = (
await rest_client()
.rpc("list_stored_countries", {})
.select("nicename, country_name, iso")
.eq("nicename", "Albania")
.limit(1)
.execute()
)

assert res.data == [{"nicename": "Albania", "country_name": "ALBANIA", "iso": "AL"}]


@pytest.mark.skip(reason="Need to re-implement range to use query parameters")
async def test_rpc_with_range():
res = (
await rest_client()
.rpc("list_stored_countries", {})
.select("nicename, iso")
.range(0, 1)
.execute()
)

assert res.data == [{"nicename": "Albania", "iso": "AL"}]
41 changes: 41 additions & 0 deletions tests/_sync/test_filter_request_builder_integration.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pytest

from .client import rest_client


Expand Down Expand Up @@ -380,3 +382,42 @@ def test_or_on_reference_table():
],
},
]


def test_rpc_with_single():
res = (
rest_client()
.rpc("list_stored_countries", {})
.select("nicename, country_name, iso")
.eq("nicename", "Albania")
.single()
.execute()
)

assert res.data == {"nicename": "Albania", "country_name": "ALBANIA", "iso": "AL"}


def test_rpc_with_limit():
res = (
rest_client()
.rpc("list_stored_countries", {})
.select("nicename, country_name, iso")
.eq("nicename", "Albania")
.limit(1)
.execute()
)

assert res.data == [{"nicename": "Albania", "country_name": "ALBANIA", "iso": "AL"}]


@pytest.mark.skip(reason="Need to re-implement range to use query parameters")
def test_rpc_with_range():
res = (
rest_client()
.rpc("list_stored_countries", {})
.select("nicename, iso")
.range(0, 1)
.execute()
)

assert res.data == [{"nicename": "Albania", "iso": "AL"}]