Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,5 @@ tests_sync/

# version files
.tool-versions

.vscode/
1 change: 1 addition & 0 deletions aredis_om/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@
RedisModelError,
VectorFieldOptions,
)
from .model.types import Coordinates, GeoFilter
22 changes: 18 additions & 4 deletions aredis_om/model/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import abc
import dataclasses
import decimal
import json
import logging
import operator
Expand Down Expand Up @@ -45,6 +44,7 @@
from .encoders import jsonable_encoder
from .render_tree import render_tree
from .token_escaper import TokenEscaper
from .types import Coordinates, CoordinateType, GeoFilter


model_registry = {}
Expand Down Expand Up @@ -405,7 +405,6 @@ class RediSearchFieldTypes(Enum):
GEO = "GEO"


# TODO: How to handle Geo fields?
DEFAULT_PAGE_SIZE = 1000


Expand Down Expand Up @@ -535,8 +534,12 @@ def validate_sort_fields(self, sort_fields: List[str]):
def resolve_field_type(field: "FieldInfo", op: Operators) -> RediSearchFieldTypes:
field_info: Union[FieldInfo, PydanticFieldInfo] = field

typ = get_outer_type(field_info)

if getattr(field_info, "primary_key", None) is True:
return RediSearchFieldTypes.TAG
elif typ in [CoordinateType, Coordinates]:
return RediSearchFieldTypes.GEO
elif op is Operators.LIKE:
fts = getattr(field_info, "full_text_search", None)
if fts is not True: # Could be PydanticUndefined
Expand All @@ -552,7 +555,6 @@ def resolve_field_type(field: "FieldInfo", op: Operators) -> RediSearchFieldType
if not isinstance(field_type, type):
field_type = field_type.__origin__

# TODO: GEO fields
container_type = get_origin(field_type)

if is_supported_container_type(container_type):
Expand Down Expand Up @@ -726,6 +728,15 @@ def resolve_value(
field_name=field_name, expanded_value=expanded_value
)

elif field_type is RediSearchFieldTypes.GEO:
if not isinstance(value, GeoFilter):
raise QuerySyntaxError(
"You can only use a GeoFilter object with a GEO field."
)

if op is Operators.EQ:
result += f"@{field_name}:[{value}]"

return result

def resolve_redisearch_pagination(self):
Expand Down Expand Up @@ -1804,6 +1815,8 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo):
schema = cls.schema_for_type(name, embedded_cls, field_info)
elif typ is bool:
schema = f"{name} TAG"
elif typ in [CoordinateType, Coordinates]:
schema = f"{name} GEO"
elif is_numeric_type(typ):
vector_options: Optional[VectorFieldOptions] = getattr(
field_info, "vector_options", None
Expand Down Expand Up @@ -2107,7 +2120,6 @@ def schema_for_type(
else typ
)

# TODO: GEO field
if is_vector and vector_options:
schema = f"{path} AS {index_field_name} {vector_options.schema}"
elif parent_is_container_type or parent_is_model_in_container:
Expand All @@ -2128,6 +2140,8 @@ def schema_for_type(
schema += " CASESENSITIVE"
elif typ is bool:
schema = f"{path} AS {index_field_name} TAG"
elif typ in [CoordinateType, Coordinates]:
schema = f"{path} AS {index_field_name} GEO"
elif is_numeric_type(typ):
schema = f"{path} AS {index_field_name} NUMERIC"
elif issubclass(typ, str):
Expand Down
50 changes: 50 additions & 0 deletions aredis_om/model/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import Annotated, Any, Literal

from pydantic import BeforeValidator, PlainSerializer
from pydantic_extra_types.coordinate import Coordinate


RadiusUnit = Literal["m", "km", "mi", "ft"]


class GeoFilter:
def __init__(self, longitude: float, latitude: float, radius: float, unit: RadiusUnit):
self.longitude = longitude
self.latitude = latitude
self.radius = radius
self.unit = unit

def __str__(self):
return f"{self.longitude} {self.latitude} {self.radius} {self.unit}"


CoordinateType = Coordinate


def parse_redis(v: Any):
"""
The pydantic coordinate type expects a string in the format 'latitude,longitude'.
Redis expects a string in the format 'longitude,latitude'.
This validator transforms the input from Redis into the expected format for pydantic.
"""
if isinstance(v, str):
parts = v.split(",")

if len(parts) != 2:
raise ValueError("Invalid coordinate format")

return (parts[1], parts[0])

return v


Coordinates = Annotated[
CoordinateType,
PlainSerializer(
lambda v: f"{v.longitude},{v.latitude}",
return_type=str,
when_used="unless-none",
),
BeforeValidator(parse_redis),
]
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "redis-om"
version = "1.0.2-beta"
version = "1.0.3-beta"
description = "Object mappings, and more, for Redis."
authors = ["Redis OSS <[email protected]>"]
maintainers = ["Redis OSS <[email protected]>"]
Expand Down Expand Up @@ -46,6 +46,7 @@ typing-extensions = "^4.4.0"
hiredis = ">=2.2.3,<4.0.0"
more-itertools = ">=8.14,<11.0"
setuptools = ">=70.0"
pydantic-extra-types = "^2.10.5"

[tool.poetry.group.dev.dependencies]
mypy = "^1.9.0"
Expand Down
110 changes: 110 additions & 0 deletions tests/test_hash_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
import pytest_asyncio

from aredis_om import (
Coordinates,
Field,
GeoFilter,
HashModel,
Migrator,
NotFoundError,
Expand Down Expand Up @@ -1054,3 +1056,111 @@ class Meta:

rematerialized = await Model.find(Model.first_name == "Steve").first()
assert rematerialized.pk == model.pk


@py_test_mark_asyncio
async def test_can_search_on_coordinates(key_prefix, redis):
class Location(HashModel, index=True):
coordinates: Coordinates = Field(index=True)

class Meta:
global_key_prefix = key_prefix
database = redis

await Migrator().run()

latitude = 45.5231
longitude = -122.6765

loc = Location(coordinates=(latitude, longitude))

await loc.save()

rematerialized: Location = await Location.find(
Location.coordinates
== GeoFilter(longitude=longitude, latitude=latitude, radius=10, unit="mi")
).first()

assert rematerialized.pk == loc.pk
assert rematerialized.coordinates.latitude == latitude
assert rematerialized.coordinates.longitude == longitude


@py_test_mark_asyncio
async def test_does_not_return_coordinates_if_outside_radius(key_prefix, redis):
class Location(HashModel, index=True):
coordinates: Coordinates = Field(index=True)

class Meta:
global_key_prefix = key_prefix
database = redis

await Migrator().run()

latitude = 45.5231
longitude = -122.6765

loc = Location(coordinates=(latitude, longitude))

await loc.save()

with pytest.raises(NotFoundError):
await Location.find(
Location.coordinates
== GeoFilter(longitude=0, latitude=0, radius=0.1, unit="mi")
).first()


@py_test_mark_asyncio
async def test_does_not_return_coordinates_if_location_is_none(key_prefix, redis):
class Location(HashModel, index=True):
coordinates: Optional[Coordinates] = Field(index=True)

class Meta:
global_key_prefix = key_prefix
database = redis

await Migrator().run()

loc = Location(coordinates=None)

await loc.save()

with pytest.raises(NotFoundError):
await Location.find(
Location.coordinates
== GeoFilter(longitude=0, latitude=0, radius=0.1, unit="mi")
).first()


@py_test_mark_asyncio
async def test_can_search_on_multiple_fields_with_geo_filter(key_prefix, redis):
class Location(HashModel, index=True):
coordinates: Coordinates = Field(index=True)
name: str = Field(index=True)

class Meta:
global_key_prefix = key_prefix
database = redis

await Migrator().run()

latitude = 45.5231
longitude = -122.6765

loc1 = Location(coordinates=(latitude, longitude), name="Portland")
loc2 = Location(coordinates=(latitude + 0.01, longitude + 0.01), name="Nearby")

await loc1.save()
await loc2.save()

rematerialized: List[Location] = await Location.find(
(
Location.coordinates
== GeoFilter(longitude=longitude, latitude=latitude, radius=10, unit="mi")
)
& (Location.name == "Portland")
).all()

assert len(rematerialized) == 1
assert rematerialized[0].pk == loc1.pk
Loading