diff --git a/postgrest_py/base_request_builder.py b/postgrest_py/base_request_builder.py index 350a71d0..83a4d7b7 100644 --- a/postgrest_py/base_request_builder.py +++ b/postgrest_py/base_request_builder.py @@ -125,61 +125,61 @@ def filter(self, column: str, operator: str, criteria: str): self.session.params = self.session.params.add(key, val) return self - def eq(self, column: str, value: str): + def eq(self, column: str, value: Any): return self.filter(column, Filters.EQ, sanitize_param(value)) - def neq(self, column: str, value: str): + def neq(self, column: str, value: Any): return self.filter(column, Filters.NEQ, sanitize_param(value)) - def gt(self, column: str, value: str): + def gt(self, column: str, value: Any): return self.filter(column, Filters.GT, sanitize_param(value)) - def gte(self, column: str, value: str): + def gte(self, column: str, value: Any): return self.filter(column, Filters.GTE, sanitize_param(value)) - def lt(self, column: str, value: str): + def lt(self, column: str, value: Any): return self.filter(column, Filters.LT, sanitize_param(value)) - def lte(self, column: str, value: str): + def lte(self, column: str, value: Any): return self.filter(column, Filters.LTE, sanitize_param(value)) - def is_(self, column: str, value: str): + def is_(self, column: str, value: Any): return self.filter(column, Filters.IS, sanitize_param(value)) - def like(self, column: str, pattern: str): + def like(self, column: str, pattern: Any): return self.filter(column, Filters.LIKE, sanitize_pattern_param(pattern)) - def ilike(self, column: str, pattern: str): + def ilike(self, column: str, pattern: Any): return self.filter(column, Filters.ILIKE, sanitize_pattern_param(pattern)) - def fts(self, column: str, query: str): + def fts(self, column: str, query: Any): return self.filter(column, Filters.FTS, sanitize_param(query)) - def plfts(self, column: str, query: str): + def plfts(self, column: str, query: Any): return self.filter(column, Filters.PLFTS, sanitize_param(query)) - def phfts(self, column: str, query: str): + def phfts(self, column: str, query: Any): return self.filter(column, Filters.PHFTS, sanitize_param(query)) - def wfts(self, column: str, query: str): + def wfts(self, column: str, query: Any): return self.filter(column, Filters.WFTS, sanitize_param(query)) - def in_(self, column: str, values: Iterable[str]): + def in_(self, column: str, values: Iterable[Any]): values = map(sanitize_param, values) values = ",".join(values) return self.filter(column, Filters.IN, f"({values})") - def cs(self, column: str, values: Iterable[str]): + def cs(self, column: str, values: Iterable[Any]): values = map(sanitize_param, values) values = ",".join(values) return self.filter(column, Filters.CS, f"{{{values}}}") - def cd(self, column: str, values: Iterable[str]): + def cd(self, column: str, values: Iterable[Any]): values = map(sanitize_param, values) values = ",".join(values) return self.filter(column, Filters.CD, f"{{{values}}}") - def ov(self, column: str, values: Iterable[str]): + def ov(self, column: str, values: Iterable[Any]): values = map(sanitize_param, values) values = ",".join(values) return self.filter(column, Filters.OV, f"{{{values}}}") diff --git a/postgrest_py/utils.py b/postgrest_py/utils.py index 7540ac36..c58be38f 100644 --- a/postgrest_py/utils.py +++ b/postgrest_py/utils.py @@ -1,3 +1,5 @@ +from typing import Any + from httpx import AsyncClient # noqa: F401 from httpx import Client as BaseClient # noqa: F401 @@ -7,11 +9,12 @@ def aclose(self) -> None: self.close() -def sanitize_param(param: str) -> str: +def sanitize_param(param: Any) -> str: + param_str = str(param) reserved_chars = ",.:()" - if any(char in param for char in reserved_chars): - return f"%22{param}%22" - return param + if any(char in param_str for char in reserved_chars): + return f"%22{param_str}%22" + return param_str def sanitize_pattern_param(pattern: str) -> str: